[Core]: Use unique_ptr in DeviceMgr

In order to take advantage of the type system to help enforce ownership, this
change refactors DeviceMgr to use std::unique_ptr<Device> instead of Device*'s.
It also updates all callers to use the new types.

PiperOrigin-RevId: 222645861
This commit is contained in:
Brennan Saeta 2018-11-23 14:36:05 -08:00 committed by TensorFlower Gardener
parent a1532717be
commit 809ed3c835
61 changed files with 272 additions and 270 deletions

View File

@ -50,6 +50,7 @@ tf_cuda_library(
], ],
"//conditions:default": [], "//conditions:default": [],
}) + [ }) + [
"@com_google_absl//absl/memory",
"//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h"
@ -80,7 +81,7 @@ tensorflow::Status GetAllRemoteDevices(
const std::vector<string>& remote_workers, const std::vector<string>& remote_workers,
tensorflow::WorkerCacheInterface* worker_cache, tensorflow::WorkerCacheInterface* worker_cache,
std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) { std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) {
std::vector<tensorflow::Device*> remote_devices; std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
tensorflow::Status status; tensorflow::Status status;
// TODO(nareshmodi) do this in parallel instead of serially. // TODO(nareshmodi) do this in parallel instead of serially.
for (const string& remote_worker : remote_workers) { for (const string& remote_worker : remote_workers) {
@ -93,7 +94,7 @@ tensorflow::Status GetAllRemoteDevices(
status = s; status = s;
if (s.ok()) { if (s.ok()) {
for (tensorflow::Device* d : *devices) { for (tensorflow::Device* d : *devices) {
remote_devices.push_back(d); remote_devices.emplace_back(d);
} }
} }
n.Notify(); n.Notify();
@ -101,7 +102,7 @@ tensorflow::Status GetAllRemoteDevices(
n.WaitForNotification(); n.WaitForNotification();
} }
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr( std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
new tensorflow::DeviceMgr(remote_devices)); new tensorflow::DeviceMgr(std::move(remote_devices)));
TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(status);
@ -262,13 +263,13 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<tensorflow::Device>> devices;
status->status = tensorflow::DeviceFactory::AddDevices( status->status = tensorflow::DeviceFactory::AddDevices(
opts->session_options.options, "/job:localhost/replica:0/task:0", opts->session_options.options, "/job:localhost/replica:0/task:0",
&devices); &devices);
if (!status->status.ok()) return nullptr; if (!status->status.ok()) return nullptr;
std::unique_ptr<tensorflow::DeviceMgr> device_mgr( std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
new tensorflow::DeviceMgr(devices)); new tensorflow::DeviceMgr(std::move(devices)));
tensorflow::Rendezvous* r = tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get()); new tensorflow::IntraProcessRendezvous(device_mgr.get());

View File

@ -42,14 +42,8 @@ class BuildXlaOpsTest : public ::testing::Test {
.ok()); .ok());
} }
void TearDown() override {
for (Device* device : devices_) {
delete device;
}
}
private: private:
std::vector<Device*> devices_; std::vector<std::unique_ptr<Device>> devices_;
}; };
using ::tensorflow::testing::FindNodeByName; using ::tensorflow::testing::FindNodeByName;

View File

@ -59,8 +59,9 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
SessionOptions options; SessionOptions options;
auto* device_count = options.config.mutable_device_count(); auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 1}); device_count->insert({"CPU", 1});
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices( TF_CHECK_OK(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices_)); options, "/job:localhost/replica:0/task:0", &devices));
FunctionDefLibrary proto; FunctionDefLibrary proto;
for (const auto& fdef : flib) { for (const auto& fdef : flib) {
@ -69,7 +70,7 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
lib_def_ = absl::make_unique<FunctionLibraryDefinition>( lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
OpRegistry::Global(), proto); OpRegistry::Global(), proto);
OptimizerOptions opts; OptimizerOptions opts;
device_mgr_ = absl::make_unique<DeviceMgr>(devices_); device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>( pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
@ -77,7 +78,6 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
} }
FunctionLibraryRuntime* flr_; FunctionLibraryRuntime* flr_;
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_; std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> lib_def_; std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;

View File

@ -34,15 +34,9 @@ namespace tensorflow {
// //
// It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to // It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to
// make this more direct, but probably not worth it solely for this test. // make this more direct, but probably not worth it solely for this test.
std::vector<Device*> devices; std::vector<std::unique_ptr<Device>> devices;
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices)); TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices));
auto delete_devices = gtl::MakeCleanup([&] {
for (Device* d : devices) {
delete d;
}
});
GraphOptimizationPassOptions opt_options; GraphOptimizationPassOptions opt_options;
opt_options.graph = graph; opt_options.graph = graph;
opt_options.session_options = session_options; opt_options.session_options = session_options;

View File

@ -386,7 +386,7 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
TF_ASSERT_OK(s.ToGraph(graph.get())); TF_ASSERT_OK(s.ToGraph(graph.get()));
// This is needed to register the XLA_GPU device. // This is needed to register the XLA_GPU device.
std::vector<Device*> devices; std::vector<std::unique_ptr<Device>> devices;
TF_ASSERT_OK(DeviceFactory::AddDevices( TF_ASSERT_OK(DeviceFactory::AddDevices(
SessionOptions(), "/job:localhost/replica:0/task:0", &devices)); SessionOptions(), "/job:localhost/replica:0/task:0", &devices));
@ -400,10 +400,6 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
TF_ASSERT_OK(PartiallyDecluster(&graph)); TF_ASSERT_OK(PartiallyDecluster(&graph));
EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
for (Device* d : devices) {
delete d;
}
} }
TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) { TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) {

View File

@ -31,12 +31,12 @@ namespace tensorflow {
class XlaCpuDeviceFactory : public DeviceFactory { class XlaCpuDeviceFactory : public DeviceFactory {
public: public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix, Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override; std::vector<std::unique_ptr<Device>>* devices) override;
}; };
Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options, Status XlaCpuDeviceFactory::CreateDevices(
const string& name_prefix, const SessionOptions& session_options, const string& name_prefix,
std::vector<Device*>* devices) { std::vector<std::unique_ptr<Device>>* devices) {
XlaDeviceFlags* flags = GetXlaDeviceFlags(); XlaDeviceFlags* flags = GetXlaDeviceFlags();
bool compile_on_demand = flags->tf_xla_compile_on_demand; bool compile_on_demand = flags->tf_xla_compile_on_demand;
@ -63,8 +63,7 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
options.device_ordinal = 0; options.device_ordinal = 0;
options.compilation_device_name = DEVICE_CPU_XLA_JIT; options.compilation_device_name = DEVICE_CPU_XLA_JIT;
options.use_multiple_streams = false; options.use_multiple_streams = false;
auto device = absl::make_unique<XlaDevice>(session_options, options); devices->push_back(absl::make_unique<XlaDevice>(session_options, options));
devices->push_back(device.release());
return Status::OK(); return Status::OK();
} }

View File

@ -29,12 +29,12 @@ namespace tensorflow {
class XlaGpuDeviceFactory : public DeviceFactory { class XlaGpuDeviceFactory : public DeviceFactory {
public: public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix, Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override; std::vector<std::unique_ptr<Device>>* devices) override;
}; };
Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options, Status XlaGpuDeviceFactory::CreateDevices(
const string& name_prefix, const SessionOptions& session_options, const string& name_prefix,
std::vector<Device*>* devices) { std::vector<std::unique_ptr<Device>>* devices) {
XlaOpRegistry::DeviceRegistration registration; XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_GPU_XLA_JIT; registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
registration.autoclustering_policy = registration.autoclustering_policy =
@ -70,7 +70,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
return status; return status;
} }
devices->push_back(device.release()); devices->push_back(std::move(device));
} }
return Status::OK(); return Status::OK();
} }

View File

@ -33,12 +33,12 @@ constexpr std::array<DataType, 9> kExecAllTypes = {
class XlaInterpreterDeviceFactory : public DeviceFactory { class XlaInterpreterDeviceFactory : public DeviceFactory {
public: public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix, Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override; std::vector<std::unique_ptr<Device>>* devices) override;
}; };
Status XlaInterpreterDeviceFactory::CreateDevices( Status XlaInterpreterDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix, const SessionOptions& session_options, const string& name_prefix,
std::vector<Device*>* devices) { std::vector<std::unique_ptr<Device>>* devices) {
static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels( static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(
DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT); DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT);
(void)registrations; (void)registrations;
@ -61,8 +61,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
options.device_ordinal = 0; options.device_ordinal = 0;
options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
options.use_multiple_streams = false; options.use_multiple_streams = false;
auto device = absl::make_unique<XlaDevice>(session_options, options); devices->push_back(absl::make_unique<XlaDevice>(session_options, options));
devices->push_back(device.release());
return Status::OK(); return Status::OK();
} }

View File

@ -380,7 +380,7 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
initialization_status_(Status::OK()), initialization_status_(Status::OK()),
next_step_id_(1), next_step_id_(1),
device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)), device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
device_mgr_({device_}) { device_mgr_(absl::WrapUnique(device_)) {
CHECK(!options_.device_type.type_string().empty()); CHECK(!options_.device_type.type_string().empty());
if (options_.populate_resource_manager) { if (options_.populate_resource_manager) {
initialization_status_ = initialization_status_ =

View File

@ -2963,6 +2963,7 @@ tf_cuda_library(
":lib_internal", ":lib_internal",
":proto_text", ":proto_text",
":protos_all_cc", ":protos_all_cc",
"@com_google_absl//absl/memory",
"//third_party/eigen3", "//third_party/eigen3",
"//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:grappler_item",
] + mkl_deps(), ] + mkl_deps(),
@ -3816,6 +3817,7 @@ tf_cc_tests_gpu(
":test", ":test",
":test_main", ":test_main",
":testlib", ":testlib",
"@com_google_absl//absl/memory",
], ],
) )
@ -3844,6 +3846,7 @@ tf_cc_tests_gpu(
":test", ":test",
":test_main", ":test_main",
":testlib", ":testlib",
"@com_google_absl//absl/memory",
], ],
) )
@ -4411,6 +4414,7 @@ tf_cc_test(
"//tensorflow/core/kernels:random_ops", "//tensorflow/core/kernels:random_ops",
"//tensorflow/core/kernels:shape_ops", "//tensorflow/core/kernels:shape_ops",
"//third_party/eigen3", "//third_party/eigen3",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )

View File

@ -38,8 +38,9 @@ class CollectiveExecutorMgrTest : public ::testing::Test {
auto* device_count = options.config.mutable_device_count(); auto* device_count = options.config.mutable_device_count();
string task_name = "/job:localhost/replica:0/task:0"; string task_name = "/job:localhost/replica:0/task:0";
device_count->insert({"CPU", NUM_DEVS}); device_count->insert({"CPU", NUM_DEVS});
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_)); std::vector<std::unique_ptr<Device>> devices;
device_mgr_.reset(new DeviceMgr(devices_)); TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices)));
std::unique_ptr<DeviceResolverInterface> drl( std::unique_ptr<DeviceResolverInterface> drl(
new DeviceResolverLocal(device_mgr_.get())); new DeviceResolverLocal(device_mgr_.get()));
std::unique_ptr<ParamResolverInterface> prl( std::unique_ptr<ParamResolverInterface> prl(
@ -50,7 +51,6 @@ class CollectiveExecutorMgrTest : public ::testing::Test {
} }
std::unique_ptr<CollectiveExecutorMgr> cme_; std::unique_ptr<CollectiveExecutorMgr> cme_;
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_; std::unique_ptr<DeviceMgr> device_mgr_;
}; };

View File

@ -37,8 +37,9 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
string task_name = "/job:localhost/replica:0/task:0"; string task_name = "/job:localhost/replica:0/task:0";
auto* device_count = options.config.mutable_device_count(); auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", NUM_DEVS}); device_count->insert({"CPU", NUM_DEVS});
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_)); std::vector<std::unique_ptr<Device>> devices;
device_mgr_.reset(new DeviceMgr(devices_)); TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices)));
drl_.reset(new DeviceResolverLocal(device_mgr_.get())); drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
prl_.reset(new CollectiveParamResolverLocal(device_mgr_.get(), drl_.get(), prl_.reset(new CollectiveParamResolverLocal(device_mgr_.get(), drl_.get(),
task_name)); task_name));
@ -73,7 +74,6 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
} }
} }
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_; std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<DeviceResolverLocal> drl_; std::unique_ptr<DeviceResolverLocal> drl_;
std::unique_ptr<CollectiveParamResolverLocal> prl_; std::unique_ptr<CollectiveParamResolverLocal> prl_;

View File

@ -42,8 +42,9 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test {
SessionOptions options; SessionOptions options;
auto* device_count = options.config.mutable_device_count(); auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", NUM_DEVS}); device_count->insert({"CPU", NUM_DEVS});
TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices_)); std::vector<std::unique_ptr<Device>> devices;
device_mgr_.reset(new DeviceMgr(devices_)); TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices)));
drl_.reset(new DeviceResolverLocal(device_mgr_.get())); drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
prl_.reset(new CollectiveParamResolverLocal(device_mgr_.get(), drl_.get(), prl_.reset(new CollectiveParamResolverLocal(device_mgr_.get(), drl_.get(),
kTaskName)); kTaskName));
@ -51,7 +52,6 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test {
kStepId)); kStepId));
} }
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_; std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<DeviceResolverLocal> drl_; std::unique_ptr<DeviceResolverLocal> drl_;
std::unique_ptr<CollectiveParamResolverLocal> prl_; std::unique_ptr<CollectiveParamResolverLocal> prl_;

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -89,9 +90,9 @@ DeviceFactory* DeviceFactory::GetFactory(const string& device_type) {
return it->second.factory.get(); return it->second.factory.get();
} }
Status DeviceFactory::AddDevices(const SessionOptions& options, Status DeviceFactory::AddDevices(
const string& name_prefix, const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) { std::vector<std::unique_ptr<Device>>* devices) {
// CPU first. A CPU device is required. // CPU first. A CPU device is required.
auto cpu_factory = GetFactory("CPU"); auto cpu_factory = GetFactory("CPU");
if (!cpu_factory) { if (!cpu_factory) {
@ -116,16 +117,16 @@ Status DeviceFactory::AddDevices(const SessionOptions& options,
return Status::OK(); return Status::OK();
} }
Device* DeviceFactory::NewDevice(const string& type, std::unique_ptr<Device> DeviceFactory::NewDevice(const string& type,
const SessionOptions& options, const SessionOptions& options,
const string& name_prefix) { const string& name_prefix) {
auto device_factory = GetFactory(type); auto device_factory = GetFactory(type);
if (!device_factory) { if (!device_factory) {
return nullptr; return nullptr;
} }
SessionOptions opt = options; SessionOptions opt = options;
(*opt.config.mutable_device_count())[type] = 1; (*opt.config.mutable_device_count())[type] = 1;
std::vector<Device*> devices; std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(device_factory->CreateDevices(opt, name_prefix, &devices)); TF_CHECK_OK(device_factory->CreateDevices(opt, name_prefix, &devices));
int expected_num_devices = 1; int expected_num_devices = 1;
auto iter = options.config.device_count().find(type); auto iter = options.config.device_count().find(type);
@ -133,7 +134,7 @@ Device* DeviceFactory::NewDevice(const string& type,
expected_num_devices = iter->second; expected_num_devices = iter->second;
} }
DCHECK_EQ(devices.size(), static_cast<size_t>(expected_num_devices)); DCHECK_EQ(devices.size(), static_cast<size_t>(expected_num_devices));
return devices[0]; return std::move(devices[0]);
} }
} // namespace tensorflow } // namespace tensorflow

View File

@ -40,18 +40,19 @@ class DeviceFactory {
// CPU devices are added first. // CPU devices are added first.
static Status AddDevices(const SessionOptions& options, static Status AddDevices(const SessionOptions& options,
const string& name_prefix, const string& name_prefix,
std::vector<Device*>* devices); std::vector<std::unique_ptr<Device>>* devices);
// Helper for tests. Create a single device of type "type". The // Helper for tests. Create a single device of type "type". The
// returned device is always numbered zero, so if creating multiple // returned device is always numbered zero, so if creating multiple
// devices of the same type, supply distinct name_prefix arguments. // devices of the same type, supply distinct name_prefix arguments.
static Device* NewDevice(const string& type, const SessionOptions& options, static std::unique_ptr<Device> NewDevice(const string& type,
const string& name_prefix); const SessionOptions& options,
const string& name_prefix);
// Most clients should call AddDevices() instead. // Most clients should call AddDevices() instead.
virtual Status CreateDevices(const SessionOptions& options, virtual Status CreateDevices(
const string& name_prefix, const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) = 0; std::vector<std::unique_ptr<Device>>* devices) = 0;
// Return the device priority number for a "device_type" string. // Return the device priority number for a "device_type" string.
// //

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_mgr.h"
#include <memory>
#include <vector> #include <vector>
#include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/device_attributes.pb.h"
@ -24,32 +25,32 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
DeviceMgr::DeviceMgr(const std::vector<Device*>& devices) DeviceMgr::DeviceMgr(std::vector<std::unique_ptr<Device>> devices)
: name_backing_store_(128) { : devices_(std::move(devices)), name_backing_store_(128) {
for (Device* d : devices) { for (auto& d : devices_) {
CHECK(d->device_mgr_ == nullptr); CHECK(d->device_mgr_ == nullptr);
d->device_mgr_ = this; d->device_mgr_ = this;
devices_.push_back(d);
// Register under the (1) full name and (2) canonical name. // Register under the (1) full name and (2) canonical name.
for (const string& name : for (const string& name :
DeviceNameUtils::GetNamesForDeviceMappings(d->parsed_name())) { DeviceNameUtils::GetNamesForDeviceMappings(d->parsed_name())) {
device_map_[CopyToBackingStore(name)] = d; device_map_[CopyToBackingStore(name)] = d.get();
} }
// Register under the (3) local name and (4) legacy local name. // Register under the (3) local name and (4) legacy local name.
for (const string& name : for (const string& name :
DeviceNameUtils::GetLocalNamesForDeviceMappings(d->parsed_name())) { DeviceNameUtils::GetLocalNamesForDeviceMappings(d->parsed_name())) {
device_map_[CopyToBackingStore(name)] = d; device_map_[CopyToBackingStore(name)] = d.get();
} }
device_type_counts_[d->device_type()]++; device_type_counts_[d->device_type()]++;
} }
} }
DeviceMgr::~DeviceMgr() { DeviceMgr::DeviceMgr(std::unique_ptr<Device> device)
// TODO(b/37437134): Remove destructor after converting to std::unique_ptr. : DeviceMgr([&device] {
for (Device* p : devices_) delete p; std::vector<std::unique_ptr<Device>> vector;
} vector.push_back(std::move(device));
return vector;
}()) {}
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) { StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
size_t n = s.size(); size_t n = s.size();
@ -61,18 +62,22 @@ StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
void DeviceMgr::ListDeviceAttributes( void DeviceMgr::ListDeviceAttributes(
std::vector<DeviceAttributes>* devices) const { std::vector<DeviceAttributes>* devices) const {
devices->reserve(devices_.size()); devices->reserve(devices_.size());
for (Device* dev : devices_) { for (const auto& dev : devices_) {
devices->emplace_back(dev->attributes()); devices->emplace_back(dev->attributes());
} }
} }
std::vector<Device*> DeviceMgr::ListDevices() const { std::vector<Device*> DeviceMgr::ListDevices() const {
return std::vector<Device*>(devices_.begin(), devices_.end()); std::vector<Device*> devices(devices_.size());
for (size_t i = 0; i < devices_.size(); ++i) {
devices[i] = devices_[i].get();
}
return devices;
} }
string DeviceMgr::DebugString() const { string DeviceMgr::DebugString() const {
string out; string out;
for (Device* dev : devices_) { for (const auto& dev : devices_) {
strings::StrAppend(&out, dev->name(), "\n"); strings::StrAppend(&out, dev->name(), "\n");
} }
return out; return out;
@ -80,7 +85,7 @@ string DeviceMgr::DebugString() const {
string DeviceMgr::DeviceMappingString() const { string DeviceMgr::DeviceMappingString() const {
string out; string out;
for (Device* dev : devices_) { for (const auto& dev : devices_) {
if (!dev->attributes().physical_device_desc().empty()) { if (!dev->attributes().physical_device_desc().empty()) {
strings::StrAppend(&out, dev->name(), " -> ", strings::StrAppend(&out, dev->name(), " -> ",
dev->attributes().physical_device_desc(), "\n"); dev->attributes().physical_device_desc(), "\n");
@ -107,7 +112,7 @@ Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
void DeviceMgr::ClearContainers(gtl::ArraySlice<string> containers) const { void DeviceMgr::ClearContainers(gtl::ArraySlice<string> containers) const {
Status s; Status s;
for (Device* dev : devices_) { for (const auto& dev : devices_) {
if (containers.empty()) { if (containers.empty()) {
s.Update(dev->resource_manager()->Cleanup( s.Update(dev->resource_manager()->Cleanup(
dev->resource_manager()->default_container())); dev->resource_manager()->default_container()));

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_ #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
#include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
@ -34,15 +35,17 @@ class DeviceAttributes;
class DeviceMgr { class DeviceMgr {
public: public:
// Takes ownership of each device in 'devices'. // Constructs a DeviceMgr from a list of devices.
// TODO(zhifengc): Other initialization information. // TODO(zhifengc): Other initialization information.
// TODO(b/37437134): Use std::unique_ptr's to track ownership. explicit DeviceMgr(std::vector<std::unique_ptr<Device>> devices);
explicit DeviceMgr(const std::vector<Device*>& devices);
~DeviceMgr(); // Constructs a DeviceMgr managing a single device.
explicit DeviceMgr(std::unique_ptr<Device> device);
// Returns attributes of all devices. // Returns attributes of all devices.
void ListDeviceAttributes(std::vector<DeviceAttributes>* devices) const; void ListDeviceAttributes(std::vector<DeviceAttributes>* devices) const;
// Returns raw pointers to the underlying devices.
std::vector<Device*> ListDevices() const; std::vector<Device*> ListDevices() const;
// Returns a string listing all devices. // Returns a string listing all devices.
@ -62,9 +65,7 @@ class DeviceMgr {
int NumDeviceType(const string& type) const; int NumDeviceType(const string& type) const;
private: private:
// TODO(b/37437134): Use std::unique_ptr's to track ownership. const std::vector<std::unique_ptr<Device>> devices_;
typedef gtl::InlinedVector<Device*, 8> DeviceVec;
DeviceVec devices_;
StringPiece CopyToBackingStore(StringPiece s); StringPiece CopyToBackingStore(StringPiece s);

View File

@ -36,12 +36,12 @@ class DeviceResolverLocalTest : public ::testing::Test {
string task_name = "/job:localhost/replica:0/task:0"; string task_name = "/job:localhost/replica:0/task:0";
auto* device_count = options.config.mutable_device_count(); auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", NUM_DEVS}); device_count->insert({"CPU", NUM_DEVS});
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_)); std::vector<std::unique_ptr<Device>> devices;
device_mgr_.reset(new DeviceMgr(devices_)); TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices)));
drl_.reset(new DeviceResolverLocal(device_mgr_.get())); drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
} }
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_; std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<DeviceResolverLocal> drl_; std::unique_ptr<DeviceResolverLocal> drl_;
}; };

View File

@ -57,7 +57,7 @@ class DeviceSetTest : public ::testing::Test {
class DummyFactory : public DeviceFactory { class DummyFactory : public DeviceFactory {
public: public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix, Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override { std::vector<std::unique_ptr<Device>>* devices) override {
return Status::OK(); return Status::OK();
} }
}; };

View File

@ -155,12 +155,12 @@ class DirectSessionFactory : public SessionFactory {
if (options.config.graph_options().build_cost_model() > 0) { if (options.config.graph_options().build_cost_model() > 0) {
EnableCPUAllocatorFullStats(true); EnableCPUAllocatorFullStats(true);
} }
std::vector<Device*> devices; std::vector<std::unique_ptr<Device>> devices;
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices( TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices)); options, "/job:localhost/replica:0/task:0", &devices));
DirectSession* session = DirectSession* session =
new DirectSession(options, new DeviceMgr(devices), this); new DirectSession(options, new DeviceMgr(std::move(devices)), this);
{ {
mutex_lock l(sessions_lock_); mutex_lock l(sessions_lock_);
sessions_.push_back(session); sessions_.push_back(session);

View File

@ -181,6 +181,7 @@ tf_cc_test(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"@com_google_absl//absl/memory",
], ],
) )

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "absl/memory/memory.h"
#include "tensorflow/cc/client/client_session.h" #include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/framework/scope.h"
@ -37,12 +38,13 @@ namespace {
class TestEnv { class TestEnv {
public: public:
TestEnv() : flib_def_(OpRegistry::Global(), {}) { TestEnv() : flib_def_(OpRegistry::Global(), {}) {
Device* device = std::vector<std::unique_ptr<Device>> devices;
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"); devices.push_back(
device_mgr_.reset(new DeviceMgr({device})); DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(), device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
device, TF_GRAPH_DEF_VERSION, flib_runtime_ = NewFunctionLibraryRuntime(
&flib_def_, nullptr, {}, nullptr); device_mgr_.get(), Env::Default(), device_mgr_->ListDevices()[0],
TF_GRAPH_DEF_VERSION, &flib_def_, nullptr, {}, nullptr);
} }
FunctionLibraryRuntime* function_library_runtime() const { FunctionLibraryRuntime* function_library_runtime() const {

View File

@ -53,17 +53,17 @@ class ExecutorTest : public ::testing::Test {
// when the test completes. // when the test completes.
CHECK(rendez_->Unref()); CHECK(rendez_->Unref());
delete exec_; delete exec_;
delete device_;
} }
// Resets executor_ with a new executor based on a graph 'gdef'. // Resets executor_ with a new executor based on a graph 'gdef'.
void Create(std::unique_ptr<const Graph> graph) { void Create(std::unique_ptr<const Graph> graph) {
const int version = graph->versions().producer(); const int version = graph->versions().producer();
LocalExecutorParams params; LocalExecutorParams params;
params.device = device_; params.device = device_.get();
params.create_kernel = [this, version](const NodeDef& ndef, params.create_kernel = [this, version](const NodeDef& ndef,
OpKernel** kernel) { OpKernel** kernel) {
return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel); return CreateNonCachedKernel(device_.get(), nullptr, ndef, version,
kernel);
}; };
params.delete_kernel = [](OpKernel* kernel) { params.delete_kernel = [](OpKernel* kernel) {
DeleteNonCachedKernel(kernel); DeleteNonCachedKernel(kernel);
@ -83,7 +83,7 @@ class ExecutorTest : public ::testing::Test {
} }
thread::ThreadPool* thread_pool_ = nullptr; thread::ThreadPool* thread_pool_ = nullptr;
Device* device_ = nullptr; std::unique_ptr<Device> device_;
Executor* exec_ = nullptr; Executor* exec_ = nullptr;
StepStatsCollector step_stats_collector_; StepStatsCollector step_stats_collector_;
StepStats step_stats_; StepStats step_stats_;

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <atomic> #include <atomic>
#include <utility> #include <utility>
#include "absl/memory/memory.h"
#include "absl/strings/numbers.h" #include "absl/strings/numbers.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
#include "tensorflow/cc/ops/array_ops_internal.h" #include "tensorflow/cc/ops/array_ops_internal.h"
@ -147,14 +148,15 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
SessionOptions options; SessionOptions options;
auto* device_count = options.config.mutable_device_count(); auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 3}); device_count->insert({"CPU", 3});
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices( TF_CHECK_OK(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices_)); options, "/job:localhost/replica:0/task:0", &devices));
FunctionDefLibrary proto; FunctionDefLibrary proto;
for (const auto& fdef : flib) *(proto.add_function()) = fdef; for (const auto& fdef : flib) *(proto.add_function()) = fdef;
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto)); lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
OptimizerOptions opts; OptimizerOptions opts;
device_mgr_.reset(new DeviceMgr(devices_)); device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
pflr_.reset(new ProcessFunctionLibraryRuntime( pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, default_thread_pool, nullptr /* cluster_flr */)); opts, default_thread_pool, nullptr /* cluster_flr */));
@ -358,7 +360,6 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
FunctionLibraryRuntime* flr0_; FunctionLibraryRuntime* flr0_;
FunctionLibraryRuntime* flr1_; FunctionLibraryRuntime* flr1_;
FunctionLibraryRuntime* flr2_; FunctionLibraryRuntime* flr2_;
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_; std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> lib_def_; std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;

View File

@ -54,14 +54,15 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
SessionOptions options; SessionOptions options;
auto* device_count = options.config.mutable_device_count(); auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 3}); device_count->insert({"CPU", 3});
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices( TF_CHECK_OK(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices_)); options, "/job:localhost/replica:0/task:0", &devices));
FunctionDefLibrary proto; FunctionDefLibrary proto;
for (const auto& fdef : flib) *(proto.add_function()) = fdef; for (const auto& fdef : flib) *(proto.add_function()) = fdef;
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto)); lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
OptimizerOptions opts; OptimizerOptions opts;
device_mgr_.reset(new DeviceMgr(devices_)); device_mgr_.reset(new DeviceMgr(std::move(devices)));
pflr_.reset(new ProcessFunctionLibraryRuntime( pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, default_thread_pool, nullptr /* cluster_flr */)); opts, default_thread_pool, nullptr /* cluster_flr */));
@ -194,7 +195,6 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
FunctionLibraryRuntime* flr0_; FunctionLibraryRuntime* flr0_;
FunctionLibraryRuntime* flr1_; FunctionLibraryRuntime* flr1_;
FunctionLibraryRuntime* flr2_; FunctionLibraryRuntime* flr2_;
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_; std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> lib_def_; std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;

View File

@ -907,9 +907,9 @@ Allocator* BaseGPUDevice::GetScopedAllocator(AllocatorAttributes attr,
const int BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength = 1000; const int BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength = 1000;
const int BaseGPUDeviceFactory::InterconnectMap::kStreamExecutorStrength = 1; const int BaseGPUDeviceFactory::InterconnectMap::kStreamExecutorStrength = 1;
Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options, Status BaseGPUDeviceFactory::CreateDevices(
const string& name_prefix, const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) { std::vector<std::unique_ptr<Device>>* devices) {
TF_RETURN_IF_ERROR(ValidateGPUMachineManager()); TF_RETURN_IF_ERROR(ValidateGPUMachineManager());
se::Platform* gpu_manager = GPUMachineManager(); se::Platform* gpu_manager = GPUMachineManager();
if (gpu_manager == nullptr) { if (gpu_manager == nullptr) {
@ -1073,12 +1073,10 @@ static string GetShortDeviceDescription(PlatformGpuId platform_gpu_id,
// LINT.ThenChange(//tensorflow/python/platform/test.py) // LINT.ThenChange(//tensorflow/python/platform/test.py)
} }
Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options, Status BaseGPUDeviceFactory::CreateGPUDevice(
const string& name_prefix, const SessionOptions& options, const string& name_prefix, TfGpuId tf_gpu_id,
TfGpuId tf_gpu_id, int64 memory_limit, const DeviceLocality& dev_locality,
int64 memory_limit, std::vector<std::unique_ptr<Device>>* devices) {
const DeviceLocality& dev_locality,
std::vector<Device*>* devices) {
CHECK_GE(tf_gpu_id.value(), 0); CHECK_GE(tf_gpu_id.value(), 0);
const string device_name = const string device_name =
strings::StrCat(name_prefix, "/device:GPU:", tf_gpu_id.value()); strings::StrCat(name_prefix, "/device:GPU:", tf_gpu_id.value());
@ -1108,7 +1106,7 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
// different (which should be an error). // different (which should be an error).
// //
// TODO(laigd): report error if memory_limit doesn't match stats.bytes_limit. // TODO(laigd): report error if memory_limit doesn't match stats.bytes_limit.
BaseGPUDevice* gpu_device = CreateGPUDevice( std::unique_ptr<BaseGPUDevice> gpu_device = CreateGPUDevice(
options, device_name, static_cast<Bytes>(stats.bytes_limit), dev_locality, options, device_name, static_cast<Bytes>(stats.bytes_limit), dev_locality,
tf_gpu_id, GetShortDeviceDescription(platform_gpu_id, desc), tf_gpu_id, GetShortDeviceDescription(platform_gpu_id, desc),
gpu_allocator, ProcessState::singleton()->GetCPUAllocator(numa_node)); gpu_allocator, ProcessState::singleton()->GetCPUAllocator(numa_node));
@ -1116,7 +1114,7 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
<< (stats.bytes_limit >> 20) << " MB memory) -> physical GPU (" << (stats.bytes_limit >> 20) << " MB memory) -> physical GPU ("
<< GetShortDeviceDescription(platform_gpu_id, desc) << ")"; << GetShortDeviceDescription(platform_gpu_id, desc) << ")";
TF_RETURN_IF_ERROR(gpu_device->Init(options)); TF_RETURN_IF_ERROR(gpu_device->Init(options));
devices->push_back(gpu_device); devices->push_back(std::move(gpu_device));
return Status::OK(); return Status::OK();
} }

View File

@ -166,7 +166,7 @@ class BaseGPUDevice : public LocalDevice {
class BaseGPUDeviceFactory : public DeviceFactory { class BaseGPUDeviceFactory : public DeviceFactory {
public: public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix, Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override; std::vector<std::unique_ptr<Device>>* devices) override;
struct InterconnectMap { struct InterconnectMap {
// Name of interconnect technology, if known. // Name of interconnect technology, if known.
@ -207,15 +207,13 @@ class BaseGPUDeviceFactory : public DeviceFactory {
Status CreateGPUDevice(const SessionOptions& options, Status CreateGPUDevice(const SessionOptions& options,
const string& name_prefix, TfGpuId tf_gpu_id, const string& name_prefix, TfGpuId tf_gpu_id,
int64 memory_limit, const DeviceLocality& dev_locality, int64 memory_limit, const DeviceLocality& dev_locality,
std::vector<Device*>* devices); std::vector<std::unique_ptr<Device>>* devices);
virtual BaseGPUDevice* CreateGPUDevice(const SessionOptions& options, virtual std::unique_ptr<BaseGPUDevice> CreateGPUDevice(
const string& name, Bytes memory_limit, const SessionOptions& options, const string& name, Bytes memory_limit,
const DeviceLocality& dev_locality, const DeviceLocality& dev_locality, TfGpuId tf_gpu_id,
TfGpuId tf_gpu_id, const string& physical_device_desc, Allocator* gpu_allocator,
const string& physical_device_desc, Allocator* cpu_allocator) = 0;
Allocator* gpu_allocator,
Allocator* cpu_allocator) = 0;
// Returns into 'ids' the list of valid platform GPU ids, in the order that // Returns into 'ids' the list of valid platform GPU ids, in the order that
// they should map to TF GPU ids "/device:GPU:0", "/device:GPU:1", etc, // they should map to TF GPU ids "/device:GPU:0", "/device:GPU:1", etc,

View File

@ -59,15 +59,14 @@ class GPUDevice : public BaseGPUDevice {
class GPUDeviceFactory : public BaseGPUDeviceFactory { class GPUDeviceFactory : public BaseGPUDeviceFactory {
private: private:
BaseGPUDevice* CreateGPUDevice(const SessionOptions& options, std::unique_ptr<BaseGPUDevice> CreateGPUDevice(
const string& name, Bytes memory_limit, const SessionOptions& options, const string& name, Bytes memory_limit,
const DeviceLocality& locality, const DeviceLocality& locality, TfGpuId tf_gpu_id,
TfGpuId tf_gpu_id, const string& physical_device_desc, Allocator* gpu_allocator,
const string& physical_device_desc, Allocator* cpu_allocator) override {
Allocator* gpu_allocator, return absl::make_unique<GPUDevice>(options, name, memory_limit, locality,
Allocator* cpu_allocator) override { tf_gpu_id, physical_device_desc,
return new GPUDevice(options, name, memory_limit, locality, tf_gpu_id, gpu_allocator, cpu_allocator);
physical_device_desc, gpu_allocator, cpu_allocator);
} }
}; };
@ -108,7 +107,7 @@ class GPUCompatibleCPUDevice : public ThreadPoolDevice {
class GPUCompatibleCPUDeviceFactory : public DeviceFactory { class GPUCompatibleCPUDeviceFactory : public DeviceFactory {
public: public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix, Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override { std::vector<std::unique_ptr<Device>>* devices) override {
int n = 1; int n = 1;
auto iter = options.config.device_count().find("CPU"); auto iter = options.config.device_count().find("CPU");
if (iter != options.config.device_count().end()) { if (iter != options.config.device_count().end()) {
@ -116,7 +115,7 @@ class GPUCompatibleCPUDeviceFactory : public DeviceFactory {
} }
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
string name = strings::StrCat(name_prefix, "/device:CPU:", i); string name = strings::StrCat(name_prefix, "/device:CPU:", i);
devices->push_back(new GPUCompatibleCPUDevice( devices->push_back(absl::make_unique<GPUCompatibleCPUDevice>(
options, name, Bytes(256 << 20), DeviceLocality(), cpu_allocator())); options, name, Bytes(256 << 20), DeviceLocality(), cpu_allocator()));
} }

View File

@ -33,7 +33,7 @@ namespace {
TEST(GPUDeviceOnNonGPUMachineTest, CreateGPUDevicesOnNonGPUMachine) { TEST(GPUDeviceOnNonGPUMachineTest, CreateGPUDevicesOnNonGPUMachine) {
SessionOptions opts; SessionOptions opts;
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<tensorflow::Device>> devices;
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices( TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, "/job:localhost/replica:0/task:0", &devices)); opts, "/job:localhost/replica:0/task:0", &devices));
EXPECT_TRUE(devices.empty()); EXPECT_TRUE(devices.empty());

View File

@ -88,7 +88,7 @@ class GPUDeviceTest : public ::testing::Test {
TEST_F(GPUDeviceTest, FailedToParseVisibleDeviceList) { TEST_F(GPUDeviceTest, FailedToParseVisibleDeviceList) {
SessionOptions opts = MakeSessionOptions("0,abc"); SessionOptions opts = MakeSessionOptions("0,abc");
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices( Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices); opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@ -97,7 +97,7 @@ TEST_F(GPUDeviceTest, FailedToParseVisibleDeviceList) {
TEST_F(GPUDeviceTest, InvalidGpuId) { TEST_F(GPUDeviceTest, InvalidGpuId) {
SessionOptions opts = MakeSessionOptions("100"); SessionOptions opts = MakeSessionOptions("100");
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices( Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices); opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@ -107,7 +107,7 @@ TEST_F(GPUDeviceTest, InvalidGpuId) {
TEST_F(GPUDeviceTest, DuplicateEntryInVisibleDeviceList) { TEST_F(GPUDeviceTest, DuplicateEntryInVisibleDeviceList) {
SessionOptions opts = MakeSessionOptions("0,0"); SessionOptions opts = MakeSessionOptions("0,0");
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices( Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices); opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@ -117,7 +117,7 @@ TEST_F(GPUDeviceTest, DuplicateEntryInVisibleDeviceList) {
TEST_F(GPUDeviceTest, VirtualDeviceConfigConflictsWithMemoryFractionSettings) { TEST_F(GPUDeviceTest, VirtualDeviceConfigConflictsWithMemoryFractionSettings) {
SessionOptions opts = MakeSessionOptions("0", 0.1, 1, {{}}); SessionOptions opts = MakeSessionOptions("0", 0.1, 1, {{}});
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices( Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices); opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@ -129,7 +129,7 @@ TEST_F(GPUDeviceTest, GpuDeviceCountTooSmall) {
// device_count is 0, but with one entry in visible_device_list and one // device_count is 0, but with one entry in visible_device_list and one
// (empty) VirtualDevices messages. // (empty) VirtualDevices messages.
SessionOptions opts = MakeSessionOptions("0", 0, 0, {{}}); SessionOptions opts = MakeSessionOptions("0", 0, 0, {{}});
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices( Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices); opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::UNKNOWN); EXPECT_EQ(status.code(), error::UNKNOWN);
@ -141,7 +141,7 @@ TEST_F(GPUDeviceTest, NotEnoughGpuInVisibleDeviceList) {
// Single entry in visible_device_list with two (empty) VirtualDevices // Single entry in visible_device_list with two (empty) VirtualDevices
// messages. // messages.
SessionOptions opts = MakeSessionOptions("0", 0, 8, {{}, {}}); SessionOptions opts = MakeSessionOptions("0", 0, 8, {{}, {}});
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices( Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices); opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::UNKNOWN); EXPECT_EQ(status.code(), error::UNKNOWN);
@ -155,7 +155,7 @@ TEST_F(GPUDeviceTest, VirtualDeviceConfigConflictsWithVisibleDeviceList) {
// Three entries in visible_device_list with two (empty) VirtualDevices // Three entries in visible_device_list with two (empty) VirtualDevices
// messages. // messages.
SessionOptions opts = MakeSessionOptions("0,1", 0, 8, {{}}); SessionOptions opts = MakeSessionOptions("0,1", 0, 8, {{}});
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices( Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices); opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@ -169,39 +169,36 @@ TEST_F(GPUDeviceTest, VirtualDeviceConfigConflictsWithVisibleDeviceList) {
TEST_F(GPUDeviceTest, EmptyVirtualDeviceConfig) { TEST_F(GPUDeviceTest, EmptyVirtualDeviceConfig) {
// It'll create single virtual device when the virtual device config is empty. // It'll create single virtual device when the virtual device config is empty.
SessionOptions opts = MakeSessionOptions("0"); SessionOptions opts = MakeSessionOptions("0");
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices( TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices)); opts, kDeviceNamePrefix, &devices));
EXPECT_EQ(1, devices.size()); EXPECT_EQ(1, devices.size());
EXPECT_GE(devices[0]->attributes().memory_limit(), 0); EXPECT_GE(devices[0]->attributes().memory_limit(), 0);
gtl::STLDeleteElements(&devices);
} }
TEST_F(GPUDeviceTest, SingleVirtualDeviceWithNoMemoryLimit) { TEST_F(GPUDeviceTest, SingleVirtualDeviceWithNoMemoryLimit) {
// It'll create single virtual device for the gpu in question when // It'll create single virtual device for the gpu in question when
// memory_limit_mb is unset. // memory_limit_mb is unset.
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{}}); SessionOptions opts = MakeSessionOptions("0", 0, 1, {{}});
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices( TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices)); opts, kDeviceNamePrefix, &devices));
EXPECT_EQ(1, devices.size()); EXPECT_EQ(1, devices.size());
EXPECT_GE(devices[0]->attributes().memory_limit(), 0); EXPECT_GE(devices[0]->attributes().memory_limit(), 0);
gtl::STLDeleteElements(&devices);
} }
TEST_F(GPUDeviceTest, SingleVirtualDeviceWithMemoryLimit) { TEST_F(GPUDeviceTest, SingleVirtualDeviceWithMemoryLimit) {
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123}}); SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123}});
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices( TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices)); opts, kDeviceNamePrefix, &devices));
EXPECT_EQ(1, devices.size()); EXPECT_EQ(1, devices.size());
EXPECT_EQ(123 << 20, devices[0]->attributes().memory_limit()); EXPECT_EQ(123 << 20, devices[0]->attributes().memory_limit());
gtl::STLDeleteElements(&devices);
} }
TEST_F(GPUDeviceTest, MultipleVirtualDevices) { TEST_F(GPUDeviceTest, MultipleVirtualDevices) {
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}); SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}});
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices( TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices)); opts, kDeviceNamePrefix, &devices));
EXPECT_EQ(2, devices.size()); EXPECT_EQ(2, devices.size());
@ -219,7 +216,6 @@ TEST_F(GPUDeviceTest, MultipleVirtualDevices) {
devices[1]->attributes().locality().links().link(0).type()); devices[1]->attributes().locality().links().link(0).type());
EXPECT_EQ(BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength, EXPECT_EQ(BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength,
devices[1]->attributes().locality().links().link(0).strength()); devices[1]->attributes().locality().links().link(0).strength());
gtl::STLDeleteElements(&devices);
} }
// Enabling unified memory on pre-Pascal GPUs results in an initialization // Enabling unified memory on pre-Pascal GPUs results in an initialization
@ -236,7 +232,7 @@ TEST_F(GPUDeviceTest, UnifiedMemoryUnavailableOnPrePascalGpus) {
opts.config.mutable_gpu_options() opts.config.mutable_gpu_options()
->mutable_experimental() ->mutable_experimental()
->set_use_unified_memory(true); ->set_use_unified_memory(true);
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices( Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices); opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INTERNAL); EXPECT_EQ(status.code(), error::INTERNAL);
@ -259,7 +255,7 @@ TEST_F(GPUDeviceTest, UnifiedMemoryAllocation) {
} }
SessionOptions opts = MakeSessionOptions("0", kGpuMemoryFraction); SessionOptions opts = MakeSessionOptions("0", kGpuMemoryFraction);
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<Device>> devices;
TF_ASSERT_OK(DeviceFactory::GetFactory("GPU")->CreateDevices( TF_ASSERT_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices)); opts, kDeviceNamePrefix, &devices));
ASSERT_EQ(1, devices.size()); ASSERT_EQ(1, devices.size());
@ -278,8 +274,6 @@ TEST_F(GPUDeviceTest, UnifiedMemoryAllocation) {
(memory_limit >> 20) << 20); (memory_limit >> 20) << 20);
EXPECT_NE(ptr, nullptr); EXPECT_NE(ptr, nullptr);
allocator->DeallocateRaw(ptr); allocator->DeallocateRaw(ptr);
gtl::STLDeleteElements(&devices);
} }
} // namespace tensorflow } // namespace tensorflow

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h" #include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
#include <algorithm> #include <algorithm>
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/base_collective_executor.h" #include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h" #include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_mgr.h"
@ -217,7 +218,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
<< " num_devices_per_worker=" << num_devices_per_worker; << " num_devices_per_worker=" << num_devices_per_worker;
int total_num_devices = num_workers * num_devices_per_worker; int total_num_devices = num_workers * num_devices_per_worker;
device_type_ = device_type; device_type_ = device_type;
std::vector<Device*> local_devices; std::vector<std::unique_ptr<Device>> local_devices;
SessionOptions sess_opts; SessionOptions sess_opts;
sess_opts.env = Env::Default(); sess_opts.env = Env::Default();
Bytes mem_limit(4 << 20); Bytes mem_limit(4 << 20);
@ -227,7 +228,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
if (device_type == DEVICE_CPU) { if (device_type == DEVICE_CPU) {
string dev_name = strings::StrCat("/job:worker/replica:0/task:", wi, string dev_name = strings::StrCat("/job:worker/replica:0/task:", wi,
"/device:CPU:", di); "/device:CPU:", di);
local_devices.push_back(new ThreadPoolDevice( local_devices.push_back(absl::make_unique<ThreadPoolDevice>(
sess_opts, dev_name, mem_limit, dev_locality, cpu_allocator())); sess_opts, dev_name, mem_limit, dev_locality, cpu_allocator()));
} else if (device_type == DEVICE_GPU && !gpu_devices_.empty()) { } else if (device_type == DEVICE_GPU && !gpu_devices_.empty()) {
int dev_idx = (wi * num_devices_per_worker) + di; int dev_idx = (wi * num_devices_per_worker) + di;
@ -235,7 +236,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
LOG(INFO) << "dev_mgr has access to limited GPUs, reusing for more " LOG(INFO) << "dev_mgr has access to limited GPUs, reusing for more "
"than one ring node."; "than one ring node.";
} else { } else {
local_devices.push_back(gpu_devices_[dev_idx]); local_devices.push_back(std::move(gpu_devices_[dev_idx]));
} }
} else { } else {
LOG(FATAL) << "Unsupported device_type " << device_type; LOG(FATAL) << "Unsupported device_type " << device_type;
@ -243,7 +244,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
} }
} }
if (!dev_mgr_ || device_type == DEVICE_CPU) { if (!dev_mgr_ || device_type == DEVICE_CPU) {
dev_mgr_.reset(new DeviceMgr(local_devices)); dev_mgr_.reset(new DeviceMgr(std::move(local_devices)));
} }
if (!gpu_ring_order_) gpu_ring_order_.reset(new string()); if (!gpu_ring_order_) gpu_ring_order_.reset(new string());
dev_resolver_.reset(new DeviceResolverLocal(dev_mgr_.get())); dev_resolver_.reset(new DeviceResolverLocal(dev_mgr_.get()));
@ -714,7 +715,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
std::unique_ptr<DeviceResolverLocal> dev_resolver_; std::unique_ptr<DeviceResolverLocal> dev_resolver_;
std::vector<DeviceInstance*> instances_; std::vector<DeviceInstance*> instances_;
CollectiveParams col_params_; CollectiveParams col_params_;
std::vector<tensorflow::Device*> gpu_devices_; std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_; std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
std::unique_ptr<string> gpu_ring_order_; std::unique_ptr<string> gpu_ring_order_;
mutex mu_; mutex mu_;

View File

@ -75,12 +75,12 @@ Benchmark::Benchmark(const string& device, Graph* g,
const int graph_def_version = g->versions().producer(); const int graph_def_version = g->versions().producer();
LocalExecutorParams params; LocalExecutorParams params;
params.device = device_; params.device = device_.get();
params.function_library = nullptr; params.function_library = nullptr;
params.create_kernel = [this, graph_def_version](const NodeDef& ndef, params.create_kernel = [this, graph_def_version](const NodeDef& ndef,
OpKernel** kernel) { OpKernel** kernel) {
return CreateNonCachedKernel(device_, nullptr, ndef, graph_def_version, return CreateNonCachedKernel(device_.get(), nullptr, ndef,
kernel); graph_def_version, kernel);
}; };
params.delete_kernel = [](OpKernel* kernel) { params.delete_kernel = [](OpKernel* kernel) {
DeleteNonCachedKernel(kernel); DeleteNonCachedKernel(kernel);
@ -107,7 +107,7 @@ Benchmark::~Benchmark() {
// run kernel destructors that may attempt to access state borrowed from // run kernel destructors that may attempt to access state borrowed from
// `device_`, such as the resource manager. // `device_`, such as the resource manager.
exec_.reset(); exec_.reset();
delete device_; device_.reset();
delete pool_; delete pool_;
} }
} }

View File

@ -55,7 +55,7 @@ class Benchmark {
private: private:
thread::ThreadPool* pool_ = nullptr; thread::ThreadPool* pool_ = nullptr;
Device* device_ = nullptr; std::unique_ptr<Device> device_ = nullptr;
Rendezvous* rendez_ = nullptr; Rendezvous* rendez_ = nullptr;
std::unique_ptr<Executor> exec_; std::unique_ptr<Executor> exec_;

View File

@ -92,7 +92,7 @@ class FakeDevice : public Device {
class DummyFactory : public DeviceFactory { class DummyFactory : public DeviceFactory {
public: public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix, Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override { std::vector<std::unique_ptr<Device>>* devices) override {
return Status::OK(); return Status::OK();
} }
}; };

View File

@ -62,9 +62,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
SessionOptions options; SessionOptions options;
auto* device_count = options.config.mutable_device_count(); auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 2}); device_count->insert({"CPU", 2});
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0", TF_CHECK_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0",
&devices_)); &devices));
device_mgr_.reset(new DeviceMgr(devices_)); device0_ = devices[0].get();
device1_ = devices[1].get();
device_mgr_.reset(new DeviceMgr(std::move(devices)));
FunctionDefLibrary proto; FunctionDefLibrary proto;
for (const auto& fdef : flib) *(proto.add_function()) = fdef; for (const auto& fdef : flib) *(proto.add_function()) = fdef;
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto)); lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
@ -138,8 +141,9 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
return Status::OK(); return Status::OK();
} }
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_; std::unique_ptr<DeviceMgr> device_mgr_;
Device* device0_ = nullptr; // Not owned. (Owned by device_mgr_.)
Device* device1_ = nullptr; // Not owned. (Owned by device_mgr_.)
std::unique_ptr<FunctionLibraryDefinition> lib_def_; std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<TestClusterFLR> cluster_flr_; std::unique_ptr<TestClusterFLR> cluster_flr_;
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_; std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
@ -165,16 +169,16 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
FunctionLibraryRuntime* flr = FunctionLibraryRuntime* flr =
proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:0"); proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:0");
EXPECT_NE(flr, nullptr); EXPECT_NE(flr, nullptr);
EXPECT_EQ(flr->device(), devices_[0]); EXPECT_EQ(flr->device(), device0_);
flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/device:CPU:0"); flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/device:CPU:0");
EXPECT_NE(flr, nullptr); EXPECT_NE(flr, nullptr);
EXPECT_EQ(flr->device(), devices_[0]); EXPECT_EQ(flr->device(), device0_);
flr = proc_flr_->GetFLR("/device:CPU:0"); flr = proc_flr_->GetFLR("/device:CPU:0");
EXPECT_NE(flr, nullptr); EXPECT_NE(flr, nullptr);
EXPECT_EQ(flr->device(), devices_[0]); EXPECT_EQ(flr->device(), device0_);
flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:1"); flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:1");
EXPECT_NE(flr, nullptr); EXPECT_NE(flr, nullptr);
EXPECT_EQ(flr->device(), devices_[1]); EXPECT_EQ(flr->device(), device1_);
flr = proc_flr_->GetFLR("abc"); flr = proc_flr_->GetFLR("abc");
EXPECT_EQ(flr, nullptr); EXPECT_EQ(flr, nullptr);
rendezvous_->Unref(); rendezvous_->Unref();

View File

@ -14,15 +14,14 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/common_runtime/renamed_device.h" #include "tensorflow/core/common_runtime/renamed_device.h"
#include "absl/memory/memory.h"
namespace tensorflow { namespace tensorflow {
// TODO(saeta): Convert to returning a std::unique_ptr?
/* static */ /* static */
Device* RenamedDevice::NewRenamedDevice(const string& new_base, std::unique_ptr<Device> RenamedDevice::NewRenamedDevice(
Device* underlying, const string& new_base, Device* underlying, bool owns_underlying,
bool owns_underlying, bool isolate_session_state) {
bool isolate_session_state) {
DeviceNameUtils::ParsedName parsed_name; DeviceNameUtils::ParsedName parsed_name;
CHECK(DeviceNameUtils::ParseFullName(new_base, &parsed_name)); CHECK(DeviceNameUtils::ParseFullName(new_base, &parsed_name));
DeviceNameUtils::ParsedName underlying_parsed_name = DeviceNameUtils::ParsedName underlying_parsed_name =
@ -36,8 +35,9 @@ Device* RenamedDevice::NewRenamedDevice(const string& new_base,
parsed_name.id); parsed_name.id);
DeviceAttributes attributes(underlying->attributes()); DeviceAttributes attributes(underlying->attributes());
attributes.set_name(name); attributes.set_name(name);
return new RenamedDevice(underlying, attributes, owns_underlying, // Call absl::WrapUnique to access private constructor.
isolate_session_state); return absl::WrapUnique(new RenamedDevice(
underlying, attributes, owns_underlying, isolate_session_state));
} }
RenamedDevice::RenamedDevice(Device* underlying, RenamedDevice::RenamedDevice(Device* underlying,

View File

@ -28,9 +28,10 @@ namespace tensorflow {
// session. // session.
class RenamedDevice : public Device { class RenamedDevice : public Device {
public: public:
static Device* NewRenamedDevice(const string& new_base, Device* underlying, static std::unique_ptr<Device> NewRenamedDevice(const string& new_base,
bool owns_underlying, Device* underlying,
bool isolate_session_state); bool owns_underlying,
bool isolate_session_state);
~RenamedDevice() override; ~RenamedDevice() override;

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/ring_reducer.h" #include "tensorflow/core/common_runtime/ring_reducer.h"
#include <algorithm> #include <algorithm>
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/base_collective_executor.h" #include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h" #include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device.h"
@ -157,7 +158,7 @@ class RingReducerTest : public ::testing::Test {
InitGPUDevices(); InitGPUDevices();
#endif #endif
device_type_ = device_type; device_type_ = device_type;
std::vector<Device*> local_devices; std::vector<std::unique_ptr<Device>> local_devices;
SessionOptions sess_opts; SessionOptions sess_opts;
sess_opts.env = Env::Default(); sess_opts.env = Env::Default();
Bytes mem_limit(4 << 20); Bytes mem_limit(4 << 20);
@ -167,7 +168,7 @@ class RingReducerTest : public ::testing::Test {
if (device_type == DEVICE_CPU) { if (device_type == DEVICE_CPU) {
string dev_name = string dev_name =
strings::StrCat("/job:worker/replica:0/task:", wi, "/cpu:", di); strings::StrCat("/job:worker/replica:0/task:", wi, "/cpu:", di);
local_devices.push_back(new ThreadPoolDevice( local_devices.push_back(absl::make_unique<ThreadPoolDevice>(
sess_opts, dev_name, mem_limit, dev_locality, cpu_allocator())); sess_opts, dev_name, mem_limit, dev_locality, cpu_allocator()));
} else if (device_type == DEVICE_GPU && !gpu_devices_.empty()) { } else if (device_type == DEVICE_GPU && !gpu_devices_.empty()) {
int dev_idx = (wi * num_devices) + di; int dev_idx = (wi * num_devices) + di;
@ -175,7 +176,7 @@ class RingReducerTest : public ::testing::Test {
LOG(INFO) << "dev_mgr has access to limited GPUs, reusing for more " LOG(INFO) << "dev_mgr has access to limited GPUs, reusing for more "
"than one ring node."; "than one ring node.";
} else { } else {
local_devices.push_back(gpu_devices_[dev_idx]); local_devices.push_back(std::move(gpu_devices_[dev_idx]));
} }
} else { } else {
LOG(FATAL) << "Unsupported device_type " << device_type; LOG(FATAL) << "Unsupported device_type " << device_type;
@ -185,7 +186,7 @@ class RingReducerTest : public ::testing::Test {
if (!dev_mgr_ || device_type == DEVICE_CPU) { if (!dev_mgr_ || device_type == DEVICE_CPU) {
LOG(ERROR) << "resetting dev_mgr for " << local_devices.size() LOG(ERROR) << "resetting dev_mgr for " << local_devices.size()
<< " devices: "; << " devices: ";
dev_mgr_.reset(new DeviceMgr(local_devices)); dev_mgr_.reset(new DeviceMgr(std::move(local_devices)));
} }
if (!gpu_ring_order_) gpu_ring_order_.reset(new string()); if (!gpu_ring_order_) gpu_ring_order_.reset(new string());
dev_resolver_.reset(new DeviceResolverLocal(dev_mgr_.get())); dev_resolver_.reset(new DeviceResolverLocal(dev_mgr_.get()));
@ -544,7 +545,7 @@ class RingReducerTest : public ::testing::Test {
std::unique_ptr<DeviceResolverLocal> dev_resolver_; std::unique_ptr<DeviceResolverLocal> dev_resolver_;
std::vector<DeviceInstance*> instances_; std::vector<DeviceInstance*> instances_;
CollectiveParams col_params_; CollectiveParams col_params_;
std::vector<tensorflow::Device*> gpu_devices_; std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_; std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
std::unique_ptr<string> gpu_ring_order_; std::unique_ptr<string> gpu_ring_order_;
mutex mu_; mutex mu_;

View File

@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
// Register a factory that provides CPU devices.
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include <vector> #include <vector>
// Register a factory that provides CPU devices.
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/process_state.h" #include "tensorflow/core/common_runtime/process_state.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/numa.h" #include "tensorflow/core/platform/numa.h"
#include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/session_options.h"
@ -29,7 +30,7 @@ namespace tensorflow {
class ThreadPoolDeviceFactory : public DeviceFactory { class ThreadPoolDeviceFactory : public DeviceFactory {
public: public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix, Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override { std::vector<std::unique_ptr<Device>>* devices) override {
int num_numa_nodes = port::NUMANumNodes(); int num_numa_nodes = port::NUMANumNodes();
int n = 1; int n = 1;
auto iter = options.config.device_count().find("CPU"); auto iter = options.config.device_count().find("CPU");
@ -38,7 +39,7 @@ class ThreadPoolDeviceFactory : public DeviceFactory {
} }
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
string name = strings::StrCat(name_prefix, "/device:CPU:", i); string name = strings::StrCat(name_prefix, "/device:CPU:", i);
ThreadPoolDevice* tpd = nullptr; std::unique_ptr<ThreadPoolDevice> tpd;
if (options.config.experimental().use_numa_affinity()) { if (options.config.experimental().use_numa_affinity()) {
int numa_node = i % num_numa_nodes; int numa_node = i % num_numa_nodes;
if (numa_node != i) { if (numa_node != i) {
@ -49,15 +50,15 @@ class ThreadPoolDeviceFactory : public DeviceFactory {
} }
DeviceLocality dev_locality; DeviceLocality dev_locality;
dev_locality.set_numa_node(numa_node); dev_locality.set_numa_node(numa_node);
tpd = new ThreadPoolDevice( tpd = absl::make_unique<ThreadPoolDevice>(
options, name, Bytes(256 << 20), dev_locality, options, name, Bytes(256 << 20), dev_locality,
ProcessState::singleton()->GetCPUAllocator(numa_node)); ProcessState::singleton()->GetCPUAllocator(numa_node));
} else { } else {
tpd = new ThreadPoolDevice( tpd = absl::make_unique<ThreadPoolDevice>(
options, name, Bytes(256 << 20), DeviceLocality(), options, name, Bytes(256 << 20), DeviceLocality(),
ProcessState::singleton()->GetCPUAllocator(port::kNUMANoAffinity)); ProcessState::singleton()->GetCPUAllocator(port::kNUMANoAffinity));
} }
devices->push_back(tpd); devices->push_back(std::move(tpd));
} }
return Status::OK(); return Status::OK();

View File

@ -624,6 +624,7 @@ tf_cc_test(
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core:testlib", "//tensorflow/core:testlib",
"@com_google_absl//absl/memory",
], ],
) )

View File

@ -29,7 +29,8 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace { namespace {
static Device* NewDevice(const string& type, const string& name) { static std::unique_ptr<Device> NewDevice(const string& type,
const string& name) {
class FakeDevice : public Device { class FakeDevice : public Device {
public: public:
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
@ -40,7 +41,7 @@ static Device* NewDevice(const string& type, const string& name) {
attr.set_name(name); attr.set_name(name);
attr.set_device_type(type); attr.set_device_type(type);
attr.mutable_locality()->set_numa_node(3); // a non-default value attr.mutable_locality()->set_numa_node(3); // a non-default value
return new FakeDevice(attr); return absl::make_unique<FakeDevice>(attr);
} }
class FakeWorker : public TestWorkerInterface { class FakeWorker : public TestWorkerInterface {
@ -156,16 +157,16 @@ class DeviceResDistTest : public ::testing::Test {
void DefineWorker(const ConfigProto& config, const string& worker_name, void DefineWorker(const ConfigProto& config, const string& worker_name,
const string& device_type, int num_devices) { const string& device_type, int num_devices) {
std::vector<Device*> devices; std::vector<std::unique_ptr<Device>> devices;
for (int i = 0; i < num_devices; ++i) { for (int i = 0; i < num_devices; ++i) {
devices.push_back(NewDevice( devices.push_back(NewDevice(
device_type, device_type,
strings::StrCat(worker_name, "/device:", device_type, ":", i))); strings::StrCat(worker_name, "/device:", device_type, ":", i)));
} }
DeviceMgr* dev_mgr = new DeviceMgr(devices); DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
device_mgrs_.push_back(dev_mgr); device_mgrs_.push_back(dev_mgr);
std::vector<string>* dv = &dev_by_task_[worker_name]; std::vector<string>* dv = &dev_by_task_[worker_name];
for (auto d : devices) { for (auto* d : dev_mgr->ListDevices()) {
dv->push_back(d->name()); dv->push_back(d->name());
} }
DeviceResolverDistributed* dev_res = DeviceResolverDistributed* dev_res =

View File

@ -41,7 +41,8 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace { namespace {
static Device* NewDevice(const string& type, const string& name) { static std::unique_ptr<Device> NewDevice(const string& type,
const string& name) {
class FakeDevice : public Device { class FakeDevice : public Device {
public: public:
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
@ -52,7 +53,7 @@ static Device* NewDevice(const string& type, const string& name) {
attr.set_name(name); attr.set_name(name);
attr.set_device_type(type); attr.set_device_type(type);
attr.mutable_locality()->set_numa_node(3); // a non-default value attr.mutable_locality()->set_numa_node(3); // a non-default value
return new FakeDevice(attr); return absl::make_unique<FakeDevice>(attr);
} }
static int64 kStepId = 123; static int64 kStepId = 123;
@ -211,16 +212,16 @@ class CollRMADistTest : public ::testing::Test {
void DefineWorker(const ConfigProto& config, const string& worker_name, void DefineWorker(const ConfigProto& config, const string& worker_name,
const string& device_type, int num_devices) { const string& device_type, int num_devices) {
std::vector<Device*> devices; std::vector<std::unique_ptr<Device>> devices;
for (int i = 0; i < num_devices; ++i) { for (int i = 0; i < num_devices; ++i) {
devices.push_back(NewDevice( devices.push_back(NewDevice(
device_type, device_type,
strings::StrCat(worker_name, "/device:", device_type, ":", i))); strings::StrCat(worker_name, "/device:", device_type, ":", i)));
} }
DeviceMgr* dev_mgr = new DeviceMgr(devices); DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
device_mgrs_.push_back(dev_mgr); device_mgrs_.push_back(dev_mgr);
std::vector<string>* dv = &dev_by_task_[worker_name]; std::vector<string>* dv = &dev_by_task_[worker_name];
for (auto d : devices) { for (auto d : dev_mgr->ListDevices()) {
dv->push_back(d->name()); dv->push_back(d->name());
} }
DeviceResolverDistributed* dev_res = DeviceResolverDistributed* dev_res =

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/distributed_runtime/test_utils.h" #include "tensorflow/core/distributed_runtime/test_utils.h"
#include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/notification.h"
@ -41,8 +42,8 @@ class TestableDeviceResolverDistributed : public DeviceResolverDistributed {
// Create a fake 'Device' whose only interesting attribute is a non-default // Create a fake 'Device' whose only interesting attribute is a non-default
// DeviceLocality. // DeviceLocality.
static Device* NewDevice(const string& type, const string& name, static std::unique_ptr<Device> NewDevice(const string& type, const string& name,
int numa_node) { int numa_node) {
class FakeDevice : public Device { class FakeDevice : public Device {
public: public:
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
@ -53,7 +54,7 @@ static Device* NewDevice(const string& type, const string& name,
attr.set_name(name); attr.set_name(name);
attr.set_device_type(type); attr.set_device_type(type);
attr.mutable_locality()->set_numa_node(numa_node); attr.mutable_locality()->set_numa_node(numa_node);
return new FakeDevice(attr); return absl::make_unique<FakeDevice>(attr);
} }
// Create a fake WorkerInterface that responds to requests without RPCs, // Create a fake WorkerInterface that responds to requests without RPCs,
@ -151,19 +152,19 @@ class DeviceResDistTest : public ::testing::Test {
void DefineWorker(const string& worker_name, const string& device_type, void DefineWorker(const string& worker_name, const string& device_type,
int num_devices) { int num_devices) {
std::vector<Device*> devices; std::vector<std::unique_ptr<Device>> devices;
for (int i = 0; i < num_devices; ++i) { for (int i = 0; i < num_devices; ++i) {
devices.push_back(NewDevice( devices.push_back(NewDevice(
device_type, device_type,
strings::StrCat(worker_name, "/device:", device_type, ":", i), i)); strings::StrCat(worker_name, "/device:", device_type, ":", i), i));
} }
DeviceMgr* dev_mgr = new DeviceMgr(devices); DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
TestableDeviceResolverDistributed* dev_res = TestableDeviceResolverDistributed* dev_res =
new TestableDeviceResolverDistributed(dev_mgr, &wc_, worker_name); new TestableDeviceResolverDistributed(dev_mgr, &wc_, worker_name);
resolvers_[worker_name] = dev_res; resolvers_[worker_name] = dev_res;
device_mgrs_.push_back(dev_mgr); device_mgrs_.push_back(dev_mgr);
std::vector<string>* dv = &dev_by_task_[worker_name]; std::vector<string>* dv = &dev_by_task_[worker_name];
for (auto d : devices) { for (auto* d : dev_mgr->ListDevices()) {
dv->push_back(d->name()); dv->push_back(d->name());
} }
FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res); FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res);

View File

@ -87,7 +87,7 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
return tensorflow::errors::Internal( return tensorflow::errors::Internal(
"invalid eager env_ or env_->rendezvous_mgr."); "invalid eager env_ or env_->rendezvous_mgr.");
} }
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<tensorflow::Device>> devices;
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices( TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
// TODO(nareshmodi): Correctly set the SessionOptions. // TODO(nareshmodi): Correctly set the SessionOptions.
@ -97,12 +97,12 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
request->server_def().task_index()), request->server_def().task_index()),
&devices)); &devices));
response->mutable_device_attributes()->Reserve(devices.size()); response->mutable_device_attributes()->Reserve(devices.size());
for (auto& d : devices) { for (const auto& d : devices) {
*response->add_device_attributes() = d->attributes(); *response->add_device_attributes() = d->attributes();
} }
std::unique_ptr<tensorflow::DeviceMgr> device_mgr( std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
new tensorflow::DeviceMgr(devices)); new tensorflow::DeviceMgr(std::move(devices)));
auto* r = env_->rendezvous_mgr->Find(request->rendezvous_id()); auto* r = env_->rendezvous_mgr->Find(request->rendezvous_id());
auto session_name = strings::StrCat("eager_", request->rendezvous_id()); auto session_name = strings::StrCat("eager_", request->rendezvous_id());

View File

@ -68,12 +68,9 @@ class EagerServiceImplTest : public ::testing::Test {
worker_env_.rendezvous_mgr = &rendezvous_mgr_; worker_env_.rendezvous_mgr = &rendezvous_mgr_;
worker_env_.session_mgr = session_mgr_.get(); worker_env_.session_mgr = session_mgr_.get();
Device* device = DeviceFactory::NewDevice( device_mgr_ = absl::make_unique<DeviceMgr>(DeviceFactory::NewDevice(
"CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"); "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
worker_env_.local_devices = device_mgr_->ListDevices();
worker_env_.local_devices = {device};
device_mgr_.reset(new DeviceMgr(worker_env_.local_devices));
worker_env_.device_mgr = device_mgr_.get(); worker_env_.device_mgr = device_mgr_.get();
} }

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <cstring> #include <cstring>
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <vector>
#include "grpc/support/alloc.h" #include "grpc/support/alloc.h"
#include "grpcpp/grpcpp.h" #include "grpcpp/grpcpp.h"
@ -156,10 +157,12 @@ Status GrpcServer::Init(
string name_prefix = string name_prefix =
strings::StrCat("/job:", server_def_.job_name(), "/replica:0", strings::StrCat("/job:", server_def_.job_name(), "/replica:0",
"/task:", server_def_.task_index()); "/task:", server_def_.task_index());
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix, std::vector<std::unique_ptr<Device>> devices;
&master_env_.local_devices)); TF_RETURN_IF_ERROR(
worker_env_.local_devices = master_env_.local_devices; DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
worker_env_.device_mgr = new DeviceMgr(worker_env_.local_devices); worker_env_.device_mgr = new DeviceMgr(std::move(devices));
master_env_.local_devices = worker_env_.device_mgr->ListDevices();
worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr
? new RpcRendezvousMgr(&worker_env_) ? new RpcRendezvousMgr(&worker_env_)
: rendezvous_mgr_func(&worker_env_); : rendezvous_mgr_func(&worker_env_);

View File

@ -42,8 +42,9 @@ class RpcCollectiveExecutorMgrTest : public ::testing::Test {
WorkerCacheInterface* worker_cache = nullptr; WorkerCacheInterface* worker_cache = nullptr;
auto* device_count = options.config.mutable_device_count(); auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", NUM_DEVS}); device_count->insert({"CPU", NUM_DEVS});
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_)); std::vector<std::unique_ptr<Device>> devices;
device_mgr_.reset(new DeviceMgr(devices_)); TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices)));
std::unique_ptr<DeviceResolverDistributed> dr(new DeviceResolverDistributed( std::unique_ptr<DeviceResolverDistributed> dr(new DeviceResolverDistributed(
device_mgr_.get(), worker_cache, task_name)); device_mgr_.get(), worker_cache, task_name));
std::unique_ptr<CollectiveParamResolverDistributed> cpr( std::unique_ptr<CollectiveParamResolverDistributed> cpr(
@ -57,7 +58,6 @@ class RpcCollectiveExecutorMgrTest : public ::testing::Test {
} }
std::unique_ptr<RpcCollectiveExecutorMgr> cme_; std::unique_ptr<RpcCollectiveExecutorMgr> cme_;
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_; std::unique_ptr<DeviceMgr> device_mgr_;
}; };

View File

@ -78,13 +78,13 @@ Status SessionMgr::CreateSession(const string& session,
if (isolate_session_state) { if (isolate_session_state) {
// Create a private copy of the DeviceMgr for the WorkerSession. // Create a private copy of the DeviceMgr for the WorkerSession.
std::vector<Device*> renamed_devices; std::vector<std::unique_ptr<Device>> renamed_devices;
for (Device* d : worker_env_->local_devices) { for (Device* d : worker_env_->local_devices) {
renamed_devices.push_back(RenamedDevice::NewRenamedDevice( renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
worker_name, d, false, isolate_session_state)); worker_name, d, false, isolate_session_state));
} }
auto device_mgr = MakeUnique<DeviceMgr>(renamed_devices); auto device_mgr = MakeUnique<DeviceMgr>(std::move(renamed_devices));
auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get()); auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
worker_session.reset( worker_session.reset(
new WorkerSession(session, worker_name, new WorkerSession(session, worker_name,

View File

@ -46,11 +46,9 @@ class SessionMgrTest : public ::testing::Test {
SessionMgrTest() SessionMgrTest()
: mgr_(&env_, "/job:mnist/replica:0/task:0", : mgr_(&env_, "/job:mnist/replica:0/task:0",
std::unique_ptr<WorkerCacheInterface>(), factory_) { std::unique_ptr<WorkerCacheInterface>(), factory_) {
Device* device = device_mgr_ = absl::make_unique<DeviceMgr>(
FakeDevice::MakeCPU("/job:mnist/replica:0/task:0/device:fakecpu:0") FakeDevice::MakeCPU("/job:mnist/replica:0/task:0/device:fakecpu:0"));
.release(); env_.local_devices = device_mgr_->ListDevices();
env_.local_devices = {device};
device_mgr_.reset(new DeviceMgr(env_.local_devices));
env_.device_mgr = device_mgr_.get(); env_.device_mgr = device_mgr_.get();
} }

View File

@ -102,10 +102,11 @@ Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
} }
// Instantiate all variables for function library runtime creation. // Instantiate all variables for function library runtime creation.
std::vector<Device*> devices; std::vector<std::unique_ptr<Device>> devices;
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices( TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices)); options, "/job:localhost/replica:0/task:0", &devices));
std::unique_ptr<DeviceMgr> dvc_mgr(new DeviceMgr(devices)); Device* cpu_device = devices[0].get();
std::unique_ptr<DeviceMgr> dvc_mgr(new DeviceMgr(std::move(devices)));
FunctionLibraryDefinition function_library(OpRegistry::Global(), FunctionLibraryDefinition function_library(OpRegistry::Global(),
graph_def.library()); graph_def.library());
Env* env = Env::Default(); Env* env = Env::Default();
@ -124,7 +125,7 @@ Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env, new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env,
graph_def.versions().producer(), graph_def.versions().producer(),
&function_library, *optimizer_opts)); &function_library, *optimizer_opts));
FunctionLibraryRuntime* flr = pflr->GetFLR(devices[0]->name()); FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name());
// Create the GraphOptimizer to optimize the graph def. // Create the GraphOptimizer to optimize the graph def.
GraphConstructorOptions graph_ctor_opts; GraphConstructorOptions graph_ctor_opts;
@ -137,7 +138,7 @@ Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
// Optimize the graph. // Optimize the graph.
::tensorflow::GraphOptimizer optimizer(*optimizer_opts); ::tensorflow::GraphOptimizer optimizer(*optimizer_opts);
optimizer.Optimize(flr, env, devices[0], &graphptr, /*shape_map=*/nullptr); optimizer.Optimize(flr, env, cpu_device, &graphptr, /*shape_map=*/nullptr);
graphptr->ToGraphDef(output_graph_def); graphptr->ToGraphDef(output_graph_def);
// The default values of attributes might have been stripped by the optimizer. // The default values of attributes might have been stripped by the optimizer.

View File

@ -142,7 +142,6 @@ cc_library(
":graph_optimizer", ":graph_optimizer",
"//tensorflow/core:core_cpu_base", "//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:grappler_item",
@ -150,6 +149,7 @@ cc_library(
"//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils", "//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/utils:functions", "//tensorflow/core/grappler/utils:functions",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )

View File

@ -16,7 +16,9 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/function_optimizer.h" #include "tensorflow/core/grappler/optimizers/function_optimizer.h"
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/str_replace.h" #include "absl/strings/str_replace.h"
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_mgr.h"
@ -343,14 +345,15 @@ class FunctionOptimizerContext {
DeviceAttributes attr; DeviceAttributes attr;
attr.set_name("/device:CPU:0"); attr.set_name("/device:CPU:0");
attr.set_device_type("CPU"); attr.set_device_type("CPU");
Device* device = new FakeCPUDevice(env, attr); std::vector<std::unique_ptr<Device>> devices;
device_mgr_.reset(new DeviceMgr({device})); devices.push_back(absl::make_unique<FakeCPUDevice>(env, attr));
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
OptimizerOptions optimizer_opts; OptimizerOptions optimizer_opts;
optimizer_opts.set_do_function_inlining(true); optimizer_opts.set_do_function_inlining(true);
process_flr_.reset(new ProcessFunctionLibraryRuntime( process_flr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), env, graph_version_, &function_library_, device_mgr_.get(), env, graph_version_, &function_library_,
optimizer_opts)); optimizer_opts));
flr_ = process_flr_->GetFLR(device->name()); flr_ = process_flr_->GetFLR(device_mgr_->ListDevices()[0]->name());
} }
} }

View File

@ -600,6 +600,7 @@ tf_kernel_library(
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:session_options", "//tensorflow/core:session_options",
"//tensorflow/core/kernels:ops_util", "//tensorflow/core/kernels:ops_util",
"@com_google_absl//absl/memory",
], ],
) )

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/kernels/data/iterator_ops.h" #include "tensorflow/core/kernels/data/iterator_ops.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/renamed_device.h" #include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/threadpool_device.h" #include "tensorflow/core/common_runtime/threadpool_device.h"
@ -545,10 +546,9 @@ FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
// in its resource manager. The existing device will outlive the // in its resource manager. The existing device will outlive the
// IteratorResource, because we are storing the IteratorResource // IteratorResource, because we are storing the IteratorResource
// in that device's resource manager. // in that device's resource manager.
Device* wrapped_device = RenamedDevice::NewRenamedDevice( *device_mgr = absl::make_unique<DeviceMgr>(RenamedDevice::NewRenamedDevice(
ctx->device()->name(), down_cast<Device*>(ctx->device()), ctx->device()->name(), down_cast<Device*>(ctx->device()),
false /* owns_underlying */, false /* isolate_session_state */); false /* owns_underlying */, false /* isolate_session_state */));
device_mgr->reset(new DeviceMgr({wrapped_device}));
flib_def->reset(new FunctionLibraryDefinition( flib_def->reset(new FunctionLibraryDefinition(
*ctx->function_library()->GetFunctionLibraryDefinition())); *ctx->function_library()->GetFunctionLibraryDefinition()));
pflr->reset(new ProcessFunctionLibraryRuntime( pflr->reset(new ProcessFunctionLibraryRuntime(

View File

@ -51,17 +51,17 @@ class ExecutorTest : public ::testing::Test {
// when the test completes. // when the test completes.
CHECK(rendez_->Unref()); CHECK(rendez_->Unref());
delete exec_; delete exec_;
delete device_;
} }
// Resets executor_ with a new executor based on a graph 'gdef'. // Resets executor_ with a new executor based on a graph 'gdef'.
void Create(std::unique_ptr<const Graph> graph) { void Create(std::unique_ptr<const Graph> graph) {
const int version = graph->versions().producer(); const int version = graph->versions().producer();
LocalExecutorParams params; LocalExecutorParams params;
params.device = device_; params.device = device_.get();
params.create_kernel = [this, version](const NodeDef& ndef, params.create_kernel = [this, version](const NodeDef& ndef,
OpKernel** kernel) { OpKernel** kernel) {
return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel); return CreateNonCachedKernel(device_.get(), nullptr, ndef, version,
kernel);
}; };
params.delete_kernel = [](OpKernel* kernel) { params.delete_kernel = [](OpKernel* kernel) {
DeleteNonCachedKernel(kernel); DeleteNonCachedKernel(kernel);
@ -86,7 +86,7 @@ class ExecutorTest : public ::testing::Test {
return exec_->Run(args); return exec_->Run(args);
} }
Device* device_ = nullptr; std::unique_ptr<Device> device_;
Executor* exec_ = nullptr; Executor* exec_ = nullptr;
Executor::Args::Runner runner_; Executor::Args::Runner runner_;
Rendezvous* rendez_ = nullptr; Rendezvous* rendez_ = nullptr;

View File

@ -116,6 +116,7 @@ cc_library(
hdrs = ["delegate_data.h"], hdrs = ["delegate_data.h"],
deps = [ deps = [
":buffer_map", ":buffer_map",
"@com_google_absl//absl/memory",
"//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:context",
] + select({ ] + select({
"//tensorflow:android": [ "//tensorflow:android": [

View File

@ -14,20 +14,21 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/lite/delegates/flex/delegate_data.h" #include "tensorflow/lite/delegates/flex/delegate_data.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
namespace tflite { namespace tflite {
namespace flex { namespace flex {
tensorflow::Status DelegateData::Create(std::unique_ptr<DelegateData>* data) { tensorflow::Status DelegateData::Create(std::unique_ptr<DelegateData>* data) {
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<tensorflow::Device>> devices;
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices( TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
tensorflow::SessionOptions(), "/job:localhost/replica:0/task:0", tensorflow::SessionOptions(), "/job:localhost/replica:0/task:0",
&devices)); &devices));
std::unique_ptr<tensorflow::DeviceMgr> device_mgr( std::unique_ptr<tensorflow::DeviceMgr> device_mgr =
new tensorflow::DeviceMgr(devices)); absl::make_unique<tensorflow::DeviceMgr>(std::move(devices));
// Note that Rendezvous is ref-counted so it will be automatically deleted. // Note that Rendezvous is ref-counted so it will be automatically deleted.
tensorflow::Rendezvous* rendezvous = tensorflow::Rendezvous* rendezvous =
new tensorflow::IntraProcessRendezvous(device_mgr.get()); new tensorflow::IntraProcessRendezvous(device_mgr.get());

View File

@ -2012,13 +2012,13 @@ bool InlineAllFunctions(GraphDef* graphdef) {
tensorflow::SessionOptions options; tensorflow::SessionOptions options;
auto* device_count = options.config.mutable_device_count(); auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 1}); device_count->insert({"CPU", 1});
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<tensorflow::Device>> devices;
TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices( TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices)); options, "/job:localhost/replica:0/task:0", &devices));
tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(), tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(),
graphdef_copy.library()); graphdef_copy.library());
tensorflow::DeviceMgr device_mgr(devices); tensorflow::DeviceMgr device_mgr(std::move(devices));
tensorflow::OptimizerOptions o_opts; tensorflow::OptimizerOptions o_opts;
tensorflow::ProcessFunctionLibraryRuntime pflr( tensorflow::ProcessFunctionLibraryRuntime pflr(
&device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld, &device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld,

View File

@ -48,17 +48,14 @@ static std::vector<string> ListDevicesWithSessionConfig(
std::vector<string> output; std::vector<string> output;
SessionOptions options; SessionOptions options;
options.config = config; options.config = config;
std::vector<Device*> devices; std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::AddDevices( Status status = DeviceFactory::AddDevices(
options, "" /* name_prefix */, &devices); options, "" /* name_prefix */, &devices);
if (!status.ok()) { if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status); Set_TF_Status_from_Status(out_status, status);
} }
std::vector<std::unique_ptr<Device>> device_holder(devices.begin(), for (const std::unique_ptr<Device>& device : devices) {
devices.end());
for (const Device* device : devices) {
const DeviceAttributes& attr = device->attributes(); const DeviceAttributes& attr = device->attributes();
string attr_serialized; string attr_serialized;
if (!attr.SerializeToString(&attr_serialized)) { if (!attr.SerializeToString(&attr_serialized)) {

View File

@ -74,13 +74,13 @@ limitations under the License.
void DetectDevices(std::unordered_map<string, tensorflow::DeviceProperties>* device_map) { void DetectDevices(std::unordered_map<string, tensorflow::DeviceProperties>* device_map) {
tensorflow::SessionOptions options; tensorflow::SessionOptions options;
std::vector<tensorflow::Device*> devices; std::vector<std::unique_ptr<tensorflow::Device>> devices;
tensorflow::Status status = tensorflow::DeviceFactory::AddDevices(options, "", &devices); tensorflow::Status status = tensorflow::DeviceFactory::AddDevices(options, "", &devices);
if (!status.ok()) { if (!status.ok()) {
return; return;
} }
for (const tensorflow::Device* device : devices) { for (const std::unique_ptr<tensorflow::Device>& device : devices) {
tensorflow::DeviceProperties& prop = (*device_map)[device->name()]; tensorflow::DeviceProperties& prop = (*device_map)[device->name()];
prop = tensorflow::grappler::GetDeviceInfo(device->parsed_name()); prop = tensorflow::grappler::GetDeviceInfo(device->parsed_name());
@ -88,7 +88,6 @@ void DetectDevices(std::unordered_map<string, tensorflow::DeviceProperties>* dev
// available device memory. // available device memory.
const tensorflow::DeviceAttributes& attr = device->attributes(); const tensorflow::DeviceAttributes& attr = device->attributes();
prop.set_memory_size(attr.memory_limit()); prop.set_memory_size(attr.memory_limit());
delete device;
} }
} }