[Core]: Use unique_ptr in DeviceMgr

In order to take advantage of the type system to help enforce ownership, this
change refactors DeviceMgr to use std::unique_ptr<Device> instead of Device*'s.
It also updates all callers to use the new types.

PiperOrigin-RevId: 222645861
This commit is contained in:
Brennan Saeta 2018-11-23 14:36:05 -08:00 committed by TensorFlower Gardener
parent a1532717be
commit 809ed3c835
61 changed files with 272 additions and 270 deletions

View File

@ -50,6 +50,7 @@ tf_cuda_library(
],
"//conditions:default": [],
}) + [
"@com_google_absl//absl/memory",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include <vector>
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
@ -80,7 +81,7 @@ tensorflow::Status GetAllRemoteDevices(
const std::vector<string>& remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) {
std::vector<tensorflow::Device*> remote_devices;
std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
tensorflow::Status status;
// TODO(nareshmodi) do this in parallel instead of serially.
for (const string& remote_worker : remote_workers) {
@ -93,7 +94,7 @@ tensorflow::Status GetAllRemoteDevices(
status = s;
if (s.ok()) {
for (tensorflow::Device* d : *devices) {
remote_devices.push_back(d);
remote_devices.emplace_back(d);
}
}
n.Notify();
@ -101,7 +102,7 @@ tensorflow::Status GetAllRemoteDevices(
n.WaitForNotification();
}
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
new tensorflow::DeviceMgr(remote_devices));
new tensorflow::DeviceMgr(std::move(remote_devices)));
TF_RETURN_IF_ERROR(status);
@ -262,13 +263,13 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<tensorflow::Device>> devices;
status->status = tensorflow::DeviceFactory::AddDevices(
opts->session_options.options, "/job:localhost/replica:0/task:0",
&devices);
if (!status->status.ok()) return nullptr;
std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
new tensorflow::DeviceMgr(devices));
new tensorflow::DeviceMgr(std::move(devices)));
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());

View File

@ -42,14 +42,8 @@ class BuildXlaOpsTest : public ::testing::Test {
.ok());
}
void TearDown() override {
for (Device* device : devices_) {
delete device;
}
}
private:
std::vector<Device*> devices_;
std::vector<std::unique_ptr<Device>> devices_;
};
using ::tensorflow::testing::FindNodeByName;

View File

@ -59,8 +59,9 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 1});
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices_));
options, "/job:localhost/replica:0/task:0", &devices));
FunctionDefLibrary proto;
for (const auto& fdef : flib) {
@ -69,7 +70,7 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
OpRegistry::Global(), proto);
OptimizerOptions opts;
device_mgr_ = absl::make_unique<DeviceMgr>(devices_);
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
@ -77,7 +78,6 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
}
FunctionLibraryRuntime* flr_;
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;

View File

@ -34,15 +34,9 @@ namespace tensorflow {
//
// It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to
// make this more direct, but probably not worth it solely for this test.
std::vector<Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices));
auto delete_devices = gtl::MakeCleanup([&] {
for (Device* d : devices) {
delete d;
}
});
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
opt_options.session_options = session_options;

View File

@ -386,7 +386,7 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
TF_ASSERT_OK(s.ToGraph(graph.get()));
// This is needed to register the XLA_GPU device.
std::vector<Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
TF_ASSERT_OK(DeviceFactory::AddDevices(
SessionOptions(), "/job:localhost/replica:0/task:0", &devices));
@ -400,10 +400,6 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
TF_ASSERT_OK(PartiallyDecluster(&graph));
EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
for (Device* d : devices) {
delete d;
}
}
TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) {

View File

@ -31,12 +31,12 @@ namespace tensorflow {
class XlaCpuDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override;
std::vector<std::unique_ptr<Device>>* devices) override;
};
Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
const string& name_prefix,
std::vector<Device*>* devices) {
Status XlaCpuDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {
XlaDeviceFlags* flags = GetXlaDeviceFlags();
bool compile_on_demand = flags->tf_xla_compile_on_demand;
@ -63,8 +63,7 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
options.device_ordinal = 0;
options.compilation_device_name = DEVICE_CPU_XLA_JIT;
options.use_multiple_streams = false;
auto device = absl::make_unique<XlaDevice>(session_options, options);
devices->push_back(device.release());
devices->push_back(absl::make_unique<XlaDevice>(session_options, options));
return Status::OK();
}

View File

@ -29,12 +29,12 @@ namespace tensorflow {
class XlaGpuDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override;
std::vector<std::unique_ptr<Device>>* devices) override;
};
Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
const string& name_prefix,
std::vector<Device*>* devices) {
Status XlaGpuDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {
XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
registration.autoclustering_policy =
@ -70,7 +70,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
return status;
}
devices->push_back(device.release());
devices->push_back(std::move(device));
}
return Status::OK();
}

View File

@ -33,12 +33,12 @@ constexpr std::array<DataType, 9> kExecAllTypes = {
class XlaInterpreterDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override;
std::vector<std::unique_ptr<Device>>* devices) override;
};
Status XlaInterpreterDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<Device*>* devices) {
std::vector<std::unique_ptr<Device>>* devices) {
static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(
DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT);
(void)registrations;
@ -61,8 +61,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
options.device_ordinal = 0;
options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
options.use_multiple_streams = false;
auto device = absl::make_unique<XlaDevice>(session_options, options);
devices->push_back(device.release());
devices->push_back(absl::make_unique<XlaDevice>(session_options, options));
return Status::OK();
}

View File

@ -380,7 +380,7 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
initialization_status_(Status::OK()),
next_step_id_(1),
device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
device_mgr_({device_}) {
device_mgr_(absl::WrapUnique(device_)) {
CHECK(!options_.device_type.type_string().empty());
if (options_.populate_resource_manager) {
initialization_status_ =

View File

@ -2963,6 +2963,7 @@ tf_cuda_library(
":lib_internal",
":proto_text",
":protos_all_cc",
"@com_google_absl//absl/memory",
"//third_party/eigen3",
"//tensorflow/core/grappler:grappler_item",
] + mkl_deps(),
@ -3816,6 +3817,7 @@ tf_cc_tests_gpu(
":test",
":test_main",
":testlib",
"@com_google_absl//absl/memory",
],
)
@ -3844,6 +3846,7 @@ tf_cc_tests_gpu(
":test",
":test_main",
":testlib",
"@com_google_absl//absl/memory",
],
)
@ -4411,6 +4414,7 @@ tf_cc_test(
"//tensorflow/core/kernels:random_ops",
"//tensorflow/core/kernels:shape_ops",
"//third_party/eigen3",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)

View File

@ -38,8 +38,9 @@ class CollectiveExecutorMgrTest : public ::testing::Test {
auto* device_count = options.config.mutable_device_count();
string task_name = "/job:localhost/replica:0/task:0";
device_count->insert({"CPU", NUM_DEVS});
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
device_mgr_.reset(new DeviceMgr(devices_));
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices)));
std::unique_ptr<DeviceResolverInterface> drl(
new DeviceResolverLocal(device_mgr_.get()));
std::unique_ptr<ParamResolverInterface> prl(
@ -50,7 +51,6 @@ class CollectiveExecutorMgrTest : public ::testing::Test {
}
std::unique_ptr<CollectiveExecutorMgr> cme_;
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
};

View File

@ -37,8 +37,9 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
string task_name = "/job:localhost/replica:0/task:0";
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", NUM_DEVS});
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
device_mgr_.reset(new DeviceMgr(devices_));
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices)));
drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
prl_.reset(new CollectiveParamResolverLocal(device_mgr_.get(), drl_.get(),
task_name));
@ -73,7 +74,6 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
}
}
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<DeviceResolverLocal> drl_;
std::unique_ptr<CollectiveParamResolverLocal> prl_;

View File

@ -42,8 +42,9 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test {
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", NUM_DEVS});
TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices_));
device_mgr_.reset(new DeviceMgr(devices_));
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices)));
drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
prl_.reset(new CollectiveParamResolverLocal(device_mgr_.get(), drl_.get(),
kTaskName));
@ -51,7 +52,6 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test {
kStepId));
}
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<DeviceResolverLocal> drl_;
std::unique_ptr<CollectiveParamResolverLocal> prl_;

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@ -89,9 +90,9 @@ DeviceFactory* DeviceFactory::GetFactory(const string& device_type) {
return it->second.factory.get();
}
Status DeviceFactory::AddDevices(const SessionOptions& options,
const string& name_prefix,
std::vector<Device*>* devices) {
Status DeviceFactory::AddDevices(
const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {
// CPU first. A CPU device is required.
auto cpu_factory = GetFactory("CPU");
if (!cpu_factory) {
@ -116,16 +117,16 @@ Status DeviceFactory::AddDevices(const SessionOptions& options,
return Status::OK();
}
Device* DeviceFactory::NewDevice(const string& type,
const SessionOptions& options,
const string& name_prefix) {
std::unique_ptr<Device> DeviceFactory::NewDevice(const string& type,
const SessionOptions& options,
const string& name_prefix) {
auto device_factory = GetFactory(type);
if (!device_factory) {
return nullptr;
}
SessionOptions opt = options;
(*opt.config.mutable_device_count())[type] = 1;
std::vector<Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(device_factory->CreateDevices(opt, name_prefix, &devices));
int expected_num_devices = 1;
auto iter = options.config.device_count().find(type);
@ -133,7 +134,7 @@ Device* DeviceFactory::NewDevice(const string& type,
expected_num_devices = iter->second;
}
DCHECK_EQ(devices.size(), static_cast<size_t>(expected_num_devices));
return devices[0];
return std::move(devices[0]);
}
} // namespace tensorflow

View File

@ -40,18 +40,19 @@ class DeviceFactory {
// CPU devices are added first.
static Status AddDevices(const SessionOptions& options,
const string& name_prefix,
std::vector<Device*>* devices);
std::vector<std::unique_ptr<Device>>* devices);
// Helper for tests. Create a single device of type "type". The
// returned device is always numbered zero, so if creating multiple
// devices of the same type, supply distinct name_prefix arguments.
static Device* NewDevice(const string& type, const SessionOptions& options,
const string& name_prefix);
static std::unique_ptr<Device> NewDevice(const string& type,
const SessionOptions& options,
const string& name_prefix);
// Most clients should call AddDevices() instead.
virtual Status CreateDevices(const SessionOptions& options,
const string& name_prefix,
std::vector<Device*>* devices) = 0;
virtual Status CreateDevices(
const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) = 0;
// Return the device priority number for a "device_type" string.
//

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include <memory>
#include <vector>
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
@ -24,32 +25,32 @@ limitations under the License.
namespace tensorflow {
DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
: name_backing_store_(128) {
for (Device* d : devices) {
DeviceMgr::DeviceMgr(std::vector<std::unique_ptr<Device>> devices)
: devices_(std::move(devices)), name_backing_store_(128) {
for (auto& d : devices_) {
CHECK(d->device_mgr_ == nullptr);
d->device_mgr_ = this;
devices_.push_back(d);
// Register under the (1) full name and (2) canonical name.
for (const string& name :
DeviceNameUtils::GetNamesForDeviceMappings(d->parsed_name())) {
device_map_[CopyToBackingStore(name)] = d;
device_map_[CopyToBackingStore(name)] = d.get();
}
// Register under the (3) local name and (4) legacy local name.
for (const string& name :
DeviceNameUtils::GetLocalNamesForDeviceMappings(d->parsed_name())) {
device_map_[CopyToBackingStore(name)] = d;
device_map_[CopyToBackingStore(name)] = d.get();
}
device_type_counts_[d->device_type()]++;
}
}
DeviceMgr::~DeviceMgr() {
// TODO(b/37437134): Remove destructor after converting to std::unique_ptr.
for (Device* p : devices_) delete p;
}
DeviceMgr::DeviceMgr(std::unique_ptr<Device> device)
: DeviceMgr([&device] {
std::vector<std::unique_ptr<Device>> vector;
vector.push_back(std::move(device));
return vector;
}()) {}
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
size_t n = s.size();
@ -61,18 +62,22 @@ StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
void DeviceMgr::ListDeviceAttributes(
std::vector<DeviceAttributes>* devices) const {
devices->reserve(devices_.size());
for (Device* dev : devices_) {
for (const auto& dev : devices_) {
devices->emplace_back(dev->attributes());
}
}
std::vector<Device*> DeviceMgr::ListDevices() const {
return std::vector<Device*>(devices_.begin(), devices_.end());
std::vector<Device*> devices(devices_.size());
for (size_t i = 0; i < devices_.size(); ++i) {
devices[i] = devices_[i].get();
}
return devices;
}
string DeviceMgr::DebugString() const {
string out;
for (Device* dev : devices_) {
for (const auto& dev : devices_) {
strings::StrAppend(&out, dev->name(), "\n");
}
return out;
@ -80,7 +85,7 @@ string DeviceMgr::DebugString() const {
string DeviceMgr::DeviceMappingString() const {
string out;
for (Device* dev : devices_) {
for (const auto& dev : devices_) {
if (!dev->attributes().physical_device_desc().empty()) {
strings::StrAppend(&out, dev->name(), " -> ",
dev->attributes().physical_device_desc(), "\n");
@ -107,7 +112,7 @@ Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
void DeviceMgr::ClearContainers(gtl::ArraySlice<string> containers) const {
Status s;
for (Device* dev : devices_) {
for (const auto& dev : devices_) {
if (containers.empty()) {
s.Update(dev->resource_manager()->Cleanup(
dev->resource_manager()->default_container()));

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
@ -34,15 +35,17 @@ class DeviceAttributes;
class DeviceMgr {
public:
// Takes ownership of each device in 'devices'.
// Constructs a DeviceMgr from a list of devices.
// TODO(zhifengc): Other initialization information.
// TODO(b/37437134): Use std::unique_ptr's to track ownership.
explicit DeviceMgr(const std::vector<Device*>& devices);
~DeviceMgr();
explicit DeviceMgr(std::vector<std::unique_ptr<Device>> devices);
// Constructs a DeviceMgr managing a single device.
explicit DeviceMgr(std::unique_ptr<Device> device);
// Returns attributes of all devices.
void ListDeviceAttributes(std::vector<DeviceAttributes>* devices) const;
// Returns raw pointers to the underlying devices.
std::vector<Device*> ListDevices() const;
// Returns a string listing all devices.
@ -62,9 +65,7 @@ class DeviceMgr {
int NumDeviceType(const string& type) const;
private:
// TODO(b/37437134): Use std::unique_ptr's to track ownership.
typedef gtl::InlinedVector<Device*, 8> DeviceVec;
DeviceVec devices_;
const std::vector<std::unique_ptr<Device>> devices_;
StringPiece CopyToBackingStore(StringPiece s);

View File

@ -36,12 +36,12 @@ class DeviceResolverLocalTest : public ::testing::Test {
string task_name = "/job:localhost/replica:0/task:0";
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", NUM_DEVS});
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
device_mgr_.reset(new DeviceMgr(devices_));
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices)));
drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
}
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<DeviceResolverLocal> drl_;
};

View File

@ -57,7 +57,7 @@ class DeviceSetTest : public ::testing::Test {
class DummyFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override {
std::vector<std::unique_ptr<Device>>* devices) override {
return Status::OK();
}
};

View File

@ -155,12 +155,12 @@ class DirectSessionFactory : public SessionFactory {
if (options.config.graph_options().build_cost_model() > 0) {
EnableCPUAllocatorFullStats(true);
}
std::vector<Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices));
DirectSession* session =
new DirectSession(options, new DeviceMgr(devices), this);
new DirectSession(options, new DeviceMgr(std::move(devices)), this);
{
mutex_lock l(sessions_lock_);
sessions_.push_back(session);

View File

@ -181,6 +181,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/memory",
],
)

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <vector>
#include "absl/memory/memory.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
@ -37,12 +38,13 @@ namespace {
class TestEnv {
public:
TestEnv() : flib_def_(OpRegistry::Global(), {}) {
Device* device =
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
device_mgr_.reset(new DeviceMgr({device}));
flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(),
device, TF_GRAPH_DEF_VERSION,
&flib_def_, nullptr, {}, nullptr);
std::vector<std::unique_ptr<Device>> devices;
devices.push_back(
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
flib_runtime_ = NewFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), device_mgr_->ListDevices()[0],
TF_GRAPH_DEF_VERSION, &flib_def_, nullptr, {}, nullptr);
}
FunctionLibraryRuntime* function_library_runtime() const {

View File

@ -53,17 +53,17 @@ class ExecutorTest : public ::testing::Test {
// when the test completes.
CHECK(rendez_->Unref());
delete exec_;
delete device_;
}
// Resets executor_ with a new executor based on a graph 'gdef'.
void Create(std::unique_ptr<const Graph> graph) {
const int version = graph->versions().producer();
LocalExecutorParams params;
params.device = device_;
params.device = device_.get();
params.create_kernel = [this, version](const NodeDef& ndef,
OpKernel** kernel) {
return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel);
return CreateNonCachedKernel(device_.get(), nullptr, ndef, version,
kernel);
};
params.delete_kernel = [](OpKernel* kernel) {
DeleteNonCachedKernel(kernel);
@ -83,7 +83,7 @@ class ExecutorTest : public ::testing::Test {
}
thread::ThreadPool* thread_pool_ = nullptr;
Device* device_ = nullptr;
std::unique_ptr<Device> device_;
Executor* exec_ = nullptr;
StepStatsCollector step_stats_collector_;
StepStats step_stats_;

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <atomic>
#include <utility>
#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
@ -147,14 +148,15 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 3});
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices_));
options, "/job:localhost/replica:0/task:0", &devices));
FunctionDefLibrary proto;
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
OptimizerOptions opts;
device_mgr_.reset(new DeviceMgr(devices_));
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, default_thread_pool, nullptr /* cluster_flr */));
@ -358,7 +360,6 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
FunctionLibraryRuntime* flr0_;
FunctionLibraryRuntime* flr1_;
FunctionLibraryRuntime* flr2_;
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;

View File

@ -54,14 +54,15 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 3});
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices_));
options, "/job:localhost/replica:0/task:0", &devices));
FunctionDefLibrary proto;
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
OptimizerOptions opts;
device_mgr_.reset(new DeviceMgr(devices_));
device_mgr_.reset(new DeviceMgr(std::move(devices)));
pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, default_thread_pool, nullptr /* cluster_flr */));
@ -194,7 +195,6 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
FunctionLibraryRuntime* flr0_;
FunctionLibraryRuntime* flr1_;
FunctionLibraryRuntime* flr2_;
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;

View File

@ -907,9 +907,9 @@ Allocator* BaseGPUDevice::GetScopedAllocator(AllocatorAttributes attr,
const int BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength = 1000;
const int BaseGPUDeviceFactory::InterconnectMap::kStreamExecutorStrength = 1;
Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
const string& name_prefix,
std::vector<Device*>* devices) {
Status BaseGPUDeviceFactory::CreateDevices(
const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {
TF_RETURN_IF_ERROR(ValidateGPUMachineManager());
se::Platform* gpu_manager = GPUMachineManager();
if (gpu_manager == nullptr) {
@ -1073,12 +1073,10 @@ static string GetShortDeviceDescription(PlatformGpuId platform_gpu_id,
// LINT.ThenChange(//tensorflow/python/platform/test.py)
}
Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
const string& name_prefix,
TfGpuId tf_gpu_id,
int64 memory_limit,
const DeviceLocality& dev_locality,
std::vector<Device*>* devices) {
Status BaseGPUDeviceFactory::CreateGPUDevice(
const SessionOptions& options, const string& name_prefix, TfGpuId tf_gpu_id,
int64 memory_limit, const DeviceLocality& dev_locality,
std::vector<std::unique_ptr<Device>>* devices) {
CHECK_GE(tf_gpu_id.value(), 0);
const string device_name =
strings::StrCat(name_prefix, "/device:GPU:", tf_gpu_id.value());
@ -1108,7 +1106,7 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
// different (which should be an error).
//
// TODO(laigd): report error if memory_limit doesn't match stats.bytes_limit.
BaseGPUDevice* gpu_device = CreateGPUDevice(
std::unique_ptr<BaseGPUDevice> gpu_device = CreateGPUDevice(
options, device_name, static_cast<Bytes>(stats.bytes_limit), dev_locality,
tf_gpu_id, GetShortDeviceDescription(platform_gpu_id, desc),
gpu_allocator, ProcessState::singleton()->GetCPUAllocator(numa_node));
@ -1116,7 +1114,7 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
<< (stats.bytes_limit >> 20) << " MB memory) -> physical GPU ("
<< GetShortDeviceDescription(platform_gpu_id, desc) << ")";
TF_RETURN_IF_ERROR(gpu_device->Init(options));
devices->push_back(gpu_device);
devices->push_back(std::move(gpu_device));
return Status::OK();
}

View File

@ -166,7 +166,7 @@ class BaseGPUDevice : public LocalDevice {
class BaseGPUDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override;
std::vector<std::unique_ptr<Device>>* devices) override;
struct InterconnectMap {
// Name of interconnect technology, if known.
@ -207,15 +207,13 @@ class BaseGPUDeviceFactory : public DeviceFactory {
Status CreateGPUDevice(const SessionOptions& options,
const string& name_prefix, TfGpuId tf_gpu_id,
int64 memory_limit, const DeviceLocality& dev_locality,
std::vector<Device*>* devices);
std::vector<std::unique_ptr<Device>>* devices);
virtual BaseGPUDevice* CreateGPUDevice(const SessionOptions& options,
const string& name, Bytes memory_limit,
const DeviceLocality& dev_locality,
TfGpuId tf_gpu_id,
const string& physical_device_desc,
Allocator* gpu_allocator,
Allocator* cpu_allocator) = 0;
virtual std::unique_ptr<BaseGPUDevice> CreateGPUDevice(
const SessionOptions& options, const string& name, Bytes memory_limit,
const DeviceLocality& dev_locality, TfGpuId tf_gpu_id,
const string& physical_device_desc, Allocator* gpu_allocator,
Allocator* cpu_allocator) = 0;
// Returns into 'ids' the list of valid platform GPU ids, in the order that
// they should map to TF GPU ids "/device:GPU:0", "/device:GPU:1", etc,

View File

@ -59,15 +59,14 @@ class GPUDevice : public BaseGPUDevice {
class GPUDeviceFactory : public BaseGPUDeviceFactory {
private:
BaseGPUDevice* CreateGPUDevice(const SessionOptions& options,
const string& name, Bytes memory_limit,
const DeviceLocality& locality,
TfGpuId tf_gpu_id,
const string& physical_device_desc,
Allocator* gpu_allocator,
Allocator* cpu_allocator) override {
return new GPUDevice(options, name, memory_limit, locality, tf_gpu_id,
physical_device_desc, gpu_allocator, cpu_allocator);
std::unique_ptr<BaseGPUDevice> CreateGPUDevice(
const SessionOptions& options, const string& name, Bytes memory_limit,
const DeviceLocality& locality, TfGpuId tf_gpu_id,
const string& physical_device_desc, Allocator* gpu_allocator,
Allocator* cpu_allocator) override {
return absl::make_unique<GPUDevice>(options, name, memory_limit, locality,
tf_gpu_id, physical_device_desc,
gpu_allocator, cpu_allocator);
}
};
@ -108,7 +107,7 @@ class GPUCompatibleCPUDevice : public ThreadPoolDevice {
class GPUCompatibleCPUDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override {
std::vector<std::unique_ptr<Device>>* devices) override {
int n = 1;
auto iter = options.config.device_count().find("CPU");
if (iter != options.config.device_count().end()) {
@ -116,7 +115,7 @@ class GPUCompatibleCPUDeviceFactory : public DeviceFactory {
}
for (int i = 0; i < n; i++) {
string name = strings::StrCat(name_prefix, "/device:CPU:", i);
devices->push_back(new GPUCompatibleCPUDevice(
devices->push_back(absl::make_unique<GPUCompatibleCPUDevice>(
options, name, Bytes(256 << 20), DeviceLocality(), cpu_allocator()));
}

View File

@ -33,7 +33,7 @@ namespace {
TEST(GPUDeviceOnNonGPUMachineTest, CreateGPUDevicesOnNonGPUMachine) {
SessionOptions opts;
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<tensorflow::Device>> devices;
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, "/job:localhost/replica:0/task:0", &devices));
EXPECT_TRUE(devices.empty());

View File

@ -88,7 +88,7 @@ class GPUDeviceTest : public ::testing::Test {
TEST_F(GPUDeviceTest, FailedToParseVisibleDeviceList) {
SessionOptions opts = MakeSessionOptions("0,abc");
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@ -97,7 +97,7 @@ TEST_F(GPUDeviceTest, FailedToParseVisibleDeviceList) {
TEST_F(GPUDeviceTest, InvalidGpuId) {
SessionOptions opts = MakeSessionOptions("100");
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@ -107,7 +107,7 @@ TEST_F(GPUDeviceTest, InvalidGpuId) {
TEST_F(GPUDeviceTest, DuplicateEntryInVisibleDeviceList) {
SessionOptions opts = MakeSessionOptions("0,0");
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@ -117,7 +117,7 @@ TEST_F(GPUDeviceTest, DuplicateEntryInVisibleDeviceList) {
TEST_F(GPUDeviceTest, VirtualDeviceConfigConflictsWithMemoryFractionSettings) {
SessionOptions opts = MakeSessionOptions("0", 0.1, 1, {{}});
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@ -129,7 +129,7 @@ TEST_F(GPUDeviceTest, GpuDeviceCountTooSmall) {
// device_count is 0, but with one entry in visible_device_list and one
// (empty) VirtualDevices messages.
SessionOptions opts = MakeSessionOptions("0", 0, 0, {{}});
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::UNKNOWN);
@ -141,7 +141,7 @@ TEST_F(GPUDeviceTest, NotEnoughGpuInVisibleDeviceList) {
// Single entry in visible_device_list with two (empty) VirtualDevices
// messages.
SessionOptions opts = MakeSessionOptions("0", 0, 8, {{}, {}});
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::UNKNOWN);
@ -155,7 +155,7 @@ TEST_F(GPUDeviceTest, VirtualDeviceConfigConflictsWithVisibleDeviceList) {
// Three entries in visible_device_list with two (empty) VirtualDevices
// messages.
SessionOptions opts = MakeSessionOptions("0,1", 0, 8, {{}});
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@ -169,39 +169,36 @@ TEST_F(GPUDeviceTest, VirtualDeviceConfigConflictsWithVisibleDeviceList) {
TEST_F(GPUDeviceTest, EmptyVirtualDeviceConfig) {
// It'll create single virtual device when the virtual device config is empty.
SessionOptions opts = MakeSessionOptions("0");
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices));
EXPECT_EQ(1, devices.size());
EXPECT_GE(devices[0]->attributes().memory_limit(), 0);
gtl::STLDeleteElements(&devices);
}
TEST_F(GPUDeviceTest, SingleVirtualDeviceWithNoMemoryLimit) {
// It'll create single virtual device for the gpu in question when
// memory_limit_mb is unset.
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{}});
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices));
EXPECT_EQ(1, devices.size());
EXPECT_GE(devices[0]->attributes().memory_limit(), 0);
gtl::STLDeleteElements(&devices);
}
TEST_F(GPUDeviceTest, SingleVirtualDeviceWithMemoryLimit) {
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123}});
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices));
EXPECT_EQ(1, devices.size());
EXPECT_EQ(123 << 20, devices[0]->attributes().memory_limit());
gtl::STLDeleteElements(&devices);
}
TEST_F(GPUDeviceTest, MultipleVirtualDevices) {
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}});
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices));
EXPECT_EQ(2, devices.size());
@ -219,7 +216,6 @@ TEST_F(GPUDeviceTest, MultipleVirtualDevices) {
devices[1]->attributes().locality().links().link(0).type());
EXPECT_EQ(BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength,
devices[1]->attributes().locality().links().link(0).strength());
gtl::STLDeleteElements(&devices);
}
// Enabling unified memory on pre-Pascal GPUs results in an initialization
@ -236,7 +232,7 @@ TEST_F(GPUDeviceTest, UnifiedMemoryUnavailableOnPrePascalGpus) {
opts.config.mutable_gpu_options()
->mutable_experimental()
->set_use_unified_memory(true);
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INTERNAL);
@ -259,7 +255,7 @@ TEST_F(GPUDeviceTest, UnifiedMemoryAllocation) {
}
SessionOptions opts = MakeSessionOptions("0", kGpuMemoryFraction);
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
TF_ASSERT_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices));
ASSERT_EQ(1, devices.size());
@ -278,8 +274,6 @@ TEST_F(GPUDeviceTest, UnifiedMemoryAllocation) {
(memory_limit >> 20) << 20);
EXPECT_NE(ptr, nullptr);
allocator->DeallocateRaw(ptr);
gtl::STLDeleteElements(&devices);
}
} // namespace tensorflow

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
#include <algorithm>
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@ -217,7 +218,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
<< " num_devices_per_worker=" << num_devices_per_worker;
int total_num_devices = num_workers * num_devices_per_worker;
device_type_ = device_type;
std::vector<Device*> local_devices;
std::vector<std::unique_ptr<Device>> local_devices;
SessionOptions sess_opts;
sess_opts.env = Env::Default();
Bytes mem_limit(4 << 20);
@ -227,7 +228,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
if (device_type == DEVICE_CPU) {
string dev_name = strings::StrCat("/job:worker/replica:0/task:", wi,
"/device:CPU:", di);
local_devices.push_back(new ThreadPoolDevice(
local_devices.push_back(absl::make_unique<ThreadPoolDevice>(
sess_opts, dev_name, mem_limit, dev_locality, cpu_allocator()));
} else if (device_type == DEVICE_GPU && !gpu_devices_.empty()) {
int dev_idx = (wi * num_devices_per_worker) + di;
@ -235,7 +236,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
LOG(INFO) << "dev_mgr has access to limited GPUs, reusing for more "
"than one ring node.";
} else {
local_devices.push_back(gpu_devices_[dev_idx]);
local_devices.push_back(std::move(gpu_devices_[dev_idx]));
}
} else {
LOG(FATAL) << "Unsupported device_type " << device_type;
@ -243,7 +244,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
}
}
if (!dev_mgr_ || device_type == DEVICE_CPU) {
dev_mgr_.reset(new DeviceMgr(local_devices));
dev_mgr_.reset(new DeviceMgr(std::move(local_devices)));
}
if (!gpu_ring_order_) gpu_ring_order_.reset(new string());
dev_resolver_.reset(new DeviceResolverLocal(dev_mgr_.get()));
@ -714,7 +715,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
std::unique_ptr<DeviceResolverLocal> dev_resolver_;
std::vector<DeviceInstance*> instances_;
CollectiveParams col_params_;
std::vector<tensorflow::Device*> gpu_devices_;
std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
std::unique_ptr<string> gpu_ring_order_;
mutex mu_;

View File

@ -75,12 +75,12 @@ Benchmark::Benchmark(const string& device, Graph* g,
const int graph_def_version = g->versions().producer();
LocalExecutorParams params;
params.device = device_;
params.device = device_.get();
params.function_library = nullptr;
params.create_kernel = [this, graph_def_version](const NodeDef& ndef,
OpKernel** kernel) {
return CreateNonCachedKernel(device_, nullptr, ndef, graph_def_version,
kernel);
return CreateNonCachedKernel(device_.get(), nullptr, ndef,
graph_def_version, kernel);
};
params.delete_kernel = [](OpKernel* kernel) {
DeleteNonCachedKernel(kernel);
@ -107,7 +107,7 @@ Benchmark::~Benchmark() {
// run kernel destructors that may attempt to access state borrowed from
// `device_`, such as the resource manager.
exec_.reset();
delete device_;
device_.reset();
delete pool_;
}
}

View File

@ -55,7 +55,7 @@ class Benchmark {
private:
thread::ThreadPool* pool_ = nullptr;
Device* device_ = nullptr;
std::unique_ptr<Device> device_ = nullptr;
Rendezvous* rendez_ = nullptr;
std::unique_ptr<Executor> exec_;

View File

@ -92,7 +92,7 @@ class FakeDevice : public Device {
class DummyFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override {
std::vector<std::unique_ptr<Device>>* devices) override {
return Status::OK();
}
};

View File

@ -62,9 +62,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 2});
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0",
&devices_));
device_mgr_.reset(new DeviceMgr(devices_));
&devices));
device0_ = devices[0].get();
device1_ = devices[1].get();
device_mgr_.reset(new DeviceMgr(std::move(devices)));
FunctionDefLibrary proto;
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
@ -138,8 +141,9 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
return Status::OK();
}
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
Device* device0_ = nullptr; // Not owned. (Owned by device_mgr_.)
Device* device1_ = nullptr; // Not owned. (Owned by device_mgr_.)
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<TestClusterFLR> cluster_flr_;
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
@ -165,16 +169,16 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
FunctionLibraryRuntime* flr =
proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:0");
EXPECT_NE(flr, nullptr);
EXPECT_EQ(flr->device(), devices_[0]);
EXPECT_EQ(flr->device(), device0_);
flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/device:CPU:0");
EXPECT_NE(flr, nullptr);
EXPECT_EQ(flr->device(), devices_[0]);
EXPECT_EQ(flr->device(), device0_);
flr = proc_flr_->GetFLR("/device:CPU:0");
EXPECT_NE(flr, nullptr);
EXPECT_EQ(flr->device(), devices_[0]);
EXPECT_EQ(flr->device(), device0_);
flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:1");
EXPECT_NE(flr, nullptr);
EXPECT_EQ(flr->device(), devices_[1]);
EXPECT_EQ(flr->device(), device1_);
flr = proc_flr_->GetFLR("abc");
EXPECT_EQ(flr, nullptr);
rendezvous_->Unref();

View File

@ -14,15 +14,14 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "absl/memory/memory.h"
namespace tensorflow {
// TODO(saeta): Convert to returning a std::unique_ptr?
/* static */
Device* RenamedDevice::NewRenamedDevice(const string& new_base,
Device* underlying,
bool owns_underlying,
bool isolate_session_state) {
std::unique_ptr<Device> RenamedDevice::NewRenamedDevice(
const string& new_base, Device* underlying, bool owns_underlying,
bool isolate_session_state) {
DeviceNameUtils::ParsedName parsed_name;
CHECK(DeviceNameUtils::ParseFullName(new_base, &parsed_name));
DeviceNameUtils::ParsedName underlying_parsed_name =
@ -36,8 +35,9 @@ Device* RenamedDevice::NewRenamedDevice(const string& new_base,
parsed_name.id);
DeviceAttributes attributes(underlying->attributes());
attributes.set_name(name);
return new RenamedDevice(underlying, attributes, owns_underlying,
isolate_session_state);
// Call absl::WrapUnique to access private constructor.
return absl::WrapUnique(new RenamedDevice(
underlying, attributes, owns_underlying, isolate_session_state));
}
RenamedDevice::RenamedDevice(Device* underlying,

View File

@ -28,9 +28,10 @@ namespace tensorflow {
// session.
class RenamedDevice : public Device {
public:
static Device* NewRenamedDevice(const string& new_base, Device* underlying,
bool owns_underlying,
bool isolate_session_state);
static std::unique_ptr<Device> NewRenamedDevice(const string& new_base,
Device* underlying,
bool owns_underlying,
bool isolate_session_state);
~RenamedDevice() override;

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/ring_reducer.h"
#include <algorithm>
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/device.h"
@ -157,7 +158,7 @@ class RingReducerTest : public ::testing::Test {
InitGPUDevices();
#endif
device_type_ = device_type;
std::vector<Device*> local_devices;
std::vector<std::unique_ptr<Device>> local_devices;
SessionOptions sess_opts;
sess_opts.env = Env::Default();
Bytes mem_limit(4 << 20);
@ -167,7 +168,7 @@ class RingReducerTest : public ::testing::Test {
if (device_type == DEVICE_CPU) {
string dev_name =
strings::StrCat("/job:worker/replica:0/task:", wi, "/cpu:", di);
local_devices.push_back(new ThreadPoolDevice(
local_devices.push_back(absl::make_unique<ThreadPoolDevice>(
sess_opts, dev_name, mem_limit, dev_locality, cpu_allocator()));
} else if (device_type == DEVICE_GPU && !gpu_devices_.empty()) {
int dev_idx = (wi * num_devices) + di;
@ -175,7 +176,7 @@ class RingReducerTest : public ::testing::Test {
LOG(INFO) << "dev_mgr has access to limited GPUs, reusing for more "
"than one ring node.";
} else {
local_devices.push_back(gpu_devices_[dev_idx]);
local_devices.push_back(std::move(gpu_devices_[dev_idx]));
}
} else {
LOG(FATAL) << "Unsupported device_type " << device_type;
@ -185,7 +186,7 @@ class RingReducerTest : public ::testing::Test {
if (!dev_mgr_ || device_type == DEVICE_CPU) {
LOG(ERROR) << "resetting dev_mgr for " << local_devices.size()
<< " devices: ";
dev_mgr_.reset(new DeviceMgr(local_devices));
dev_mgr_.reset(new DeviceMgr(std::move(local_devices)));
}
if (!gpu_ring_order_) gpu_ring_order_.reset(new string());
dev_resolver_.reset(new DeviceResolverLocal(dev_mgr_.get()));
@ -544,7 +545,7 @@ class RingReducerTest : public ::testing::Test {
std::unique_ptr<DeviceResolverLocal> dev_resolver_;
std::vector<DeviceInstance*> instances_;
CollectiveParams col_params_;
std::vector<tensorflow::Device*> gpu_devices_;
std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
std::unique_ptr<string> gpu_ring_order_;
mutex mu_;

View File

@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Register a factory that provides CPU devices.
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include <vector>
// Register a factory that provides CPU devices.
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/process_state.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/numa.h"
#include "tensorflow/core/public/session_options.h"
@ -29,7 +30,7 @@ namespace tensorflow {
class ThreadPoolDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override {
std::vector<std::unique_ptr<Device>>* devices) override {
int num_numa_nodes = port::NUMANumNodes();
int n = 1;
auto iter = options.config.device_count().find("CPU");
@ -38,7 +39,7 @@ class ThreadPoolDeviceFactory : public DeviceFactory {
}
for (int i = 0; i < n; i++) {
string name = strings::StrCat(name_prefix, "/device:CPU:", i);
ThreadPoolDevice* tpd = nullptr;
std::unique_ptr<ThreadPoolDevice> tpd;
if (options.config.experimental().use_numa_affinity()) {
int numa_node = i % num_numa_nodes;
if (numa_node != i) {
@ -49,15 +50,15 @@ class ThreadPoolDeviceFactory : public DeviceFactory {
}
DeviceLocality dev_locality;
dev_locality.set_numa_node(numa_node);
tpd = new ThreadPoolDevice(
tpd = absl::make_unique<ThreadPoolDevice>(
options, name, Bytes(256 << 20), dev_locality,
ProcessState::singleton()->GetCPUAllocator(numa_node));
} else {
tpd = new ThreadPoolDevice(
tpd = absl::make_unique<ThreadPoolDevice>(
options, name, Bytes(256 << 20), DeviceLocality(),
ProcessState::singleton()->GetCPUAllocator(port::kNUMANoAffinity));
}
devices->push_back(tpd);
devices->push_back(std::move(tpd));
}
return Status::OK();

View File

@ -624,6 +624,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"@com_google_absl//absl/memory",
],
)

View File

@ -29,7 +29,8 @@ limitations under the License.
namespace tensorflow {
namespace {
static Device* NewDevice(const string& type, const string& name) {
static std::unique_ptr<Device> NewDevice(const string& type,
const string& name) {
class FakeDevice : public Device {
public:
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
@ -40,7 +41,7 @@ static Device* NewDevice(const string& type, const string& name) {
attr.set_name(name);
attr.set_device_type(type);
attr.mutable_locality()->set_numa_node(3); // a non-default value
return new FakeDevice(attr);
return absl::make_unique<FakeDevice>(attr);
}
class FakeWorker : public TestWorkerInterface {
@ -156,16 +157,16 @@ class DeviceResDistTest : public ::testing::Test {
void DefineWorker(const ConfigProto& config, const string& worker_name,
const string& device_type, int num_devices) {
std::vector<Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
for (int i = 0; i < num_devices; ++i) {
devices.push_back(NewDevice(
device_type,
strings::StrCat(worker_name, "/device:", device_type, ":", i)));
}
DeviceMgr* dev_mgr = new DeviceMgr(devices);
DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
device_mgrs_.push_back(dev_mgr);
std::vector<string>* dv = &dev_by_task_[worker_name];
for (auto d : devices) {
for (auto* d : dev_mgr->ListDevices()) {
dv->push_back(d->name());
}
DeviceResolverDistributed* dev_res =

View File

@ -41,7 +41,8 @@ limitations under the License.
namespace tensorflow {
namespace {
static Device* NewDevice(const string& type, const string& name) {
static std::unique_ptr<Device> NewDevice(const string& type,
const string& name) {
class FakeDevice : public Device {
public:
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
@ -52,7 +53,7 @@ static Device* NewDevice(const string& type, const string& name) {
attr.set_name(name);
attr.set_device_type(type);
attr.mutable_locality()->set_numa_node(3); // a non-default value
return new FakeDevice(attr);
return absl::make_unique<FakeDevice>(attr);
}
static int64 kStepId = 123;
@ -211,16 +212,16 @@ class CollRMADistTest : public ::testing::Test {
void DefineWorker(const ConfigProto& config, const string& worker_name,
const string& device_type, int num_devices) {
std::vector<Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
for (int i = 0; i < num_devices; ++i) {
devices.push_back(NewDevice(
device_type,
strings::StrCat(worker_name, "/device:", device_type, ":", i)));
}
DeviceMgr* dev_mgr = new DeviceMgr(devices);
DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
device_mgrs_.push_back(dev_mgr);
std::vector<string>* dv = &dev_by_task_[worker_name];
for (auto d : devices) {
for (auto d : dev_mgr->ListDevices()) {
dv->push_back(d->name());
}
DeviceResolverDistributed* dev_res =

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/distributed_runtime/test_utils.h"
#include "tensorflow/core/lib/core/notification.h"
@ -41,8 +42,8 @@ class TestableDeviceResolverDistributed : public DeviceResolverDistributed {
// Create a fake 'Device' whose only interesting attribute is a non-default
// DeviceLocality.
static Device* NewDevice(const string& type, const string& name,
int numa_node) {
static std::unique_ptr<Device> NewDevice(const string& type, const string& name,
int numa_node) {
class FakeDevice : public Device {
public:
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
@ -53,7 +54,7 @@ static Device* NewDevice(const string& type, const string& name,
attr.set_name(name);
attr.set_device_type(type);
attr.mutable_locality()->set_numa_node(numa_node);
return new FakeDevice(attr);
return absl::make_unique<FakeDevice>(attr);
}
// Create a fake WorkerInterface that responds to requests without RPCs,
@ -151,19 +152,19 @@ class DeviceResDistTest : public ::testing::Test {
void DefineWorker(const string& worker_name, const string& device_type,
int num_devices) {
std::vector<Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
for (int i = 0; i < num_devices; ++i) {
devices.push_back(NewDevice(
device_type,
strings::StrCat(worker_name, "/device:", device_type, ":", i), i));
}
DeviceMgr* dev_mgr = new DeviceMgr(devices);
DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
TestableDeviceResolverDistributed* dev_res =
new TestableDeviceResolverDistributed(dev_mgr, &wc_, worker_name);
resolvers_[worker_name] = dev_res;
device_mgrs_.push_back(dev_mgr);
std::vector<string>* dv = &dev_by_task_[worker_name];
for (auto d : devices) {
for (auto* d : dev_mgr->ListDevices()) {
dv->push_back(d->name());
}
FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res);

View File

@ -87,7 +87,7 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
return tensorflow::errors::Internal(
"invalid eager env_ or env_->rendezvous_mgr.");
}
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<tensorflow::Device>> devices;
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
// TODO(nareshmodi): Correctly set the SessionOptions.
@ -97,12 +97,12 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
request->server_def().task_index()),
&devices));
response->mutable_device_attributes()->Reserve(devices.size());
for (auto& d : devices) {
for (const auto& d : devices) {
*response->add_device_attributes() = d->attributes();
}
std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
new tensorflow::DeviceMgr(devices));
new tensorflow::DeviceMgr(std::move(devices)));
auto* r = env_->rendezvous_mgr->Find(request->rendezvous_id());
auto session_name = strings::StrCat("eager_", request->rendezvous_id());

View File

@ -68,12 +68,9 @@ class EagerServiceImplTest : public ::testing::Test {
worker_env_.rendezvous_mgr = &rendezvous_mgr_;
worker_env_.session_mgr = session_mgr_.get();
Device* device = DeviceFactory::NewDevice(
"CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0");
worker_env_.local_devices = {device};
device_mgr_.reset(new DeviceMgr(worker_env_.local_devices));
device_mgr_ = absl::make_unique<DeviceMgr>(DeviceFactory::NewDevice(
"CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
worker_env_.local_devices = device_mgr_->ListDevices();
worker_env_.device_mgr = device_mgr_.get();
}

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <cstring>
#include <limits>
#include <memory>
#include <vector>
#include "grpc/support/alloc.h"
#include "grpcpp/grpcpp.h"
@ -156,10 +157,12 @@ Status GrpcServer::Init(
string name_prefix =
strings::StrCat("/job:", server_def_.job_name(), "/replica:0",
"/task:", server_def_.task_index());
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix,
&master_env_.local_devices));
worker_env_.local_devices = master_env_.local_devices;
worker_env_.device_mgr = new DeviceMgr(worker_env_.local_devices);
std::vector<std::unique_ptr<Device>> devices;
TF_RETURN_IF_ERROR(
DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
worker_env_.device_mgr = new DeviceMgr(std::move(devices));
master_env_.local_devices = worker_env_.device_mgr->ListDevices();
worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr
? new RpcRendezvousMgr(&worker_env_)
: rendezvous_mgr_func(&worker_env_);

View File

@ -42,8 +42,9 @@ class RpcCollectiveExecutorMgrTest : public ::testing::Test {
WorkerCacheInterface* worker_cache = nullptr;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", NUM_DEVS});
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
device_mgr_.reset(new DeviceMgr(devices_));
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices)));
std::unique_ptr<DeviceResolverDistributed> dr(new DeviceResolverDistributed(
device_mgr_.get(), worker_cache, task_name));
std::unique_ptr<CollectiveParamResolverDistributed> cpr(
@ -57,7 +58,6 @@ class RpcCollectiveExecutorMgrTest : public ::testing::Test {
}
std::unique_ptr<RpcCollectiveExecutorMgr> cme_;
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
};

View File

@ -78,13 +78,13 @@ Status SessionMgr::CreateSession(const string& session,
if (isolate_session_state) {
// Create a private copy of the DeviceMgr for the WorkerSession.
std::vector<Device*> renamed_devices;
std::vector<std::unique_ptr<Device>> renamed_devices;
for (Device* d : worker_env_->local_devices) {
renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
worker_name, d, false, isolate_session_state));
}
auto device_mgr = MakeUnique<DeviceMgr>(renamed_devices);
auto device_mgr = MakeUnique<DeviceMgr>(std::move(renamed_devices));
auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
worker_session.reset(
new WorkerSession(session, worker_name,

View File

@ -46,11 +46,9 @@ class SessionMgrTest : public ::testing::Test {
SessionMgrTest()
: mgr_(&env_, "/job:mnist/replica:0/task:0",
std::unique_ptr<WorkerCacheInterface>(), factory_) {
Device* device =
FakeDevice::MakeCPU("/job:mnist/replica:0/task:0/device:fakecpu:0")
.release();
env_.local_devices = {device};
device_mgr_.reset(new DeviceMgr(env_.local_devices));
device_mgr_ = absl::make_unique<DeviceMgr>(
FakeDevice::MakeCPU("/job:mnist/replica:0/task:0/device:fakecpu:0"));
env_.local_devices = device_mgr_->ListDevices();
env_.device_mgr = device_mgr_.get();
}

View File

@ -102,10 +102,11 @@ Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
}
// Instantiate all variables for function library runtime creation.
std::vector<Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices));
std::unique_ptr<DeviceMgr> dvc_mgr(new DeviceMgr(devices));
Device* cpu_device = devices[0].get();
std::unique_ptr<DeviceMgr> dvc_mgr(new DeviceMgr(std::move(devices)));
FunctionLibraryDefinition function_library(OpRegistry::Global(),
graph_def.library());
Env* env = Env::Default();
@ -124,7 +125,7 @@ Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env,
graph_def.versions().producer(),
&function_library, *optimizer_opts));
FunctionLibraryRuntime* flr = pflr->GetFLR(devices[0]->name());
FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name());
// Create the GraphOptimizer to optimize the graph def.
GraphConstructorOptions graph_ctor_opts;
@ -137,7 +138,7 @@ Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
// Optimize the graph.
::tensorflow::GraphOptimizer optimizer(*optimizer_opts);
optimizer.Optimize(flr, env, devices[0], &graphptr, /*shape_map=*/nullptr);
optimizer.Optimize(flr, env, cpu_device, &graphptr, /*shape_map=*/nullptr);
graphptr->ToGraphDef(output_graph_def);
// The default values of attributes might have been stripped by the optimizer.

View File

@ -142,7 +142,6 @@ cc_library(
":graph_optimizer",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
@ -150,6 +149,7 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/utils:functions",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)

View File

@ -16,7 +16,9 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
#include <unordered_map>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@ -343,14 +345,15 @@ class FunctionOptimizerContext {
DeviceAttributes attr;
attr.set_name("/device:CPU:0");
attr.set_device_type("CPU");
Device* device = new FakeCPUDevice(env, attr);
device_mgr_.reset(new DeviceMgr({device}));
std::vector<std::unique_ptr<Device>> devices;
devices.push_back(absl::make_unique<FakeCPUDevice>(env, attr));
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
OptimizerOptions optimizer_opts;
optimizer_opts.set_do_function_inlining(true);
process_flr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), env, graph_version_, &function_library_,
optimizer_opts));
flr_ = process_flr_->GetFLR(device->name());
flr_ = process_flr_->GetFLR(device_mgr_->ListDevices()[0]->name());
}
}

View File

@ -600,6 +600,7 @@ tf_kernel_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:session_options",
"//tensorflow/core/kernels:ops_util",
"@com_google_absl//absl/memory",
],
)

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
@ -545,10 +546,9 @@ FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
// in its resource manager. The existing device will outlive the
// IteratorResource, because we are storing the IteratorResource
// in that device's resource manager.
Device* wrapped_device = RenamedDevice::NewRenamedDevice(
*device_mgr = absl::make_unique<DeviceMgr>(RenamedDevice::NewRenamedDevice(
ctx->device()->name(), down_cast<Device*>(ctx->device()),
false /* owns_underlying */, false /* isolate_session_state */);
device_mgr->reset(new DeviceMgr({wrapped_device}));
false /* owns_underlying */, false /* isolate_session_state */));
flib_def->reset(new FunctionLibraryDefinition(
*ctx->function_library()->GetFunctionLibraryDefinition()));
pflr->reset(new ProcessFunctionLibraryRuntime(

View File

@ -51,17 +51,17 @@ class ExecutorTest : public ::testing::Test {
// when the test completes.
CHECK(rendez_->Unref());
delete exec_;
delete device_;
}
// Resets executor_ with a new executor based on a graph 'gdef'.
void Create(std::unique_ptr<const Graph> graph) {
const int version = graph->versions().producer();
LocalExecutorParams params;
params.device = device_;
params.device = device_.get();
params.create_kernel = [this, version](const NodeDef& ndef,
OpKernel** kernel) {
return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel);
return CreateNonCachedKernel(device_.get(), nullptr, ndef, version,
kernel);
};
params.delete_kernel = [](OpKernel* kernel) {
DeleteNonCachedKernel(kernel);
@ -86,7 +86,7 @@ class ExecutorTest : public ::testing::Test {
return exec_->Run(args);
}
Device* device_ = nullptr;
std::unique_ptr<Device> device_;
Executor* exec_ = nullptr;
Executor::Args::Runner runner_;
Rendezvous* rendez_ = nullptr;

View File

@ -116,6 +116,7 @@ cc_library(
hdrs = ["delegate_data.h"],
deps = [
":buffer_map",
"@com_google_absl//absl/memory",
"//tensorflow/core/common_runtime/eager:context",
] + select({
"//tensorflow:android": [

View File

@ -14,20 +14,21 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/flex/delegate_data.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/lib/core/status.h"
namespace tflite {
namespace flex {
tensorflow::Status DelegateData::Create(std::unique_ptr<DelegateData>* data) {
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<tensorflow::Device>> devices;
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
tensorflow::SessionOptions(), "/job:localhost/replica:0/task:0",
&devices));
std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
new tensorflow::DeviceMgr(devices));
std::unique_ptr<tensorflow::DeviceMgr> device_mgr =
absl::make_unique<tensorflow::DeviceMgr>(std::move(devices));
// Note that Rendezvous is ref-counted so it will be automatically deleted.
tensorflow::Rendezvous* rendezvous =
new tensorflow::IntraProcessRendezvous(device_mgr.get());

View File

@ -2012,13 +2012,13 @@ bool InlineAllFunctions(GraphDef* graphdef) {
tensorflow::SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 1});
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<tensorflow::Device>> devices;
TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices));
tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(),
graphdef_copy.library());
tensorflow::DeviceMgr device_mgr(devices);
tensorflow::DeviceMgr device_mgr(std::move(devices));
tensorflow::OptimizerOptions o_opts;
tensorflow::ProcessFunctionLibraryRuntime pflr(
&device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld,

View File

@ -48,17 +48,14 @@ static std::vector<string> ListDevicesWithSessionConfig(
std::vector<string> output;
SessionOptions options;
options.config = config;
std::vector<Device*> devices;
std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::AddDevices(
options, "" /* name_prefix */, &devices);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
}
std::vector<std::unique_ptr<Device>> device_holder(devices.begin(),
devices.end());
for (const Device* device : devices) {
for (const std::unique_ptr<Device>& device : devices) {
const DeviceAttributes& attr = device->attributes();
string attr_serialized;
if (!attr.SerializeToString(&attr_serialized)) {

View File

@ -74,13 +74,13 @@ limitations under the License.
void DetectDevices(std::unordered_map<string, tensorflow::DeviceProperties>* device_map) {
tensorflow::SessionOptions options;
std::vector<tensorflow::Device*> devices;
std::vector<std::unique_ptr<tensorflow::Device>> devices;
tensorflow::Status status = tensorflow::DeviceFactory::AddDevices(options, "", &devices);
if (!status.ok()) {
return;
}
for (const tensorflow::Device* device : devices) {
for (const std::unique_ptr<tensorflow::Device>& device : devices) {
tensorflow::DeviceProperties& prop = (*device_map)[device->name()];
prop = tensorflow::grappler::GetDeviceInfo(device->parsed_name());
@ -88,7 +88,6 @@ void DetectDevices(std::unordered_map<string, tensorflow::DeviceProperties>* dev
// available device memory.
const tensorflow::DeviceAttributes& attr = device->attributes();
prop.set_memory_size(attr.memory_limit());
delete device;
}
}