[Core]: Use unique_ptr in DeviceMgr
To take advantage of the type system to help enforce ownership, this change refactors DeviceMgr to hold std::unique_ptr<Device> instead of raw Device* pointers, and updates all callers to the new types.

PiperOrigin-RevId: 222645861
parent a1532717be
commit 809ed3c835
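The pattern that recurs throughout the diff below: device factories now fill a std::vector<std::unique_ptr<Device>>, the vector is std::move()d into the DeviceMgr constructor, and code that needs non-owning handles asks the manager via ListDevices(). A minimal sketch of that flow, assuming the APIs as they appear in the hunks below (the MakeLocalDeviceMgr helper itself is illustrative and not part of this commit):

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/session_options.h"

// Illustrative helper (not part of this commit): builds a DeviceMgr the way
// the updated callers below do.
tensorflow::Status MakeLocalDeviceMgr(
    std::unique_ptr<tensorflow::DeviceMgr>* out) {
  // Factories now populate a vector of owning pointers instead of raw Device*.
  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
      tensorflow::SessionOptions(), "/job:localhost/replica:0/task:0",
      &devices));
  // DeviceMgr takes the vector by value, so ownership is transferred with
  // std::move; the manager deletes the devices when it is destroyed.
  out->reset(new tensorflow::DeviceMgr(std::move(devices)));
  return tensorflow::Status::OK();
}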
@@ -50,6 +50,7 @@ tf_cuda_library(
         ],
         "//conditions:default": [],
     }) + [
+        "@com_google_absl//absl/memory",
         "//tensorflow/core/common_runtime/eager:eager_operation",
         "//tensorflow/core/distributed_runtime/eager:eager_client",
         "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
@@ -21,6 +21,7 @@ limitations under the License.
 #include <string>
 #include <vector>
 
+#include "absl/memory/memory.h"
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/c/eager/c_api_internal.h"
@@ -80,7 +81,7 @@ tensorflow::Status GetAllRemoteDevices(
     const std::vector<string>& remote_workers,
     tensorflow::WorkerCacheInterface* worker_cache,
     std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) {
-  std::vector<tensorflow::Device*> remote_devices;
+  std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
   tensorflow::Status status;
   // TODO(nareshmodi) do this in parallel instead of serially.
   for (const string& remote_worker : remote_workers) {
@@ -93,7 +94,7 @@ tensorflow::Status GetAllRemoteDevices(
           status = s;
           if (s.ok()) {
             for (tensorflow::Device* d : *devices) {
-              remote_devices.push_back(d);
+              remote_devices.emplace_back(d);
             }
           }
           n.Notify();
@@ -101,7 +102,7 @@ tensorflow::Status GetAllRemoteDevices(
     n.WaitForNotification();
   }
   std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
-      new tensorflow::DeviceMgr(remote_devices));
+      new tensorflow::DeviceMgr(std::move(remote_devices)));
 
   TF_RETURN_IF_ERROR(status);
 
@@ -262,13 +263,13 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
 void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
 
 TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
-  std::vector<tensorflow::Device*> devices;
+  std::vector<std::unique_ptr<tensorflow::Device>> devices;
   status->status = tensorflow::DeviceFactory::AddDevices(
       opts->session_options.options, "/job:localhost/replica:0/task:0",
       &devices);
   if (!status->status.ok()) return nullptr;
   std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
-      new tensorflow::DeviceMgr(devices));
+      new tensorflow::DeviceMgr(std::move(devices)));
 
   tensorflow::Rendezvous* r =
       new tensorflow::IntraProcessRendezvous(device_mgr.get());
@@ -42,14 +42,8 @@ class BuildXlaOpsTest : public ::testing::Test {
             .ok());
   }
 
-  void TearDown() override {
-    for (Device* device : devices_) {
-      delete device;
-    }
-  }
-
  private:
-  std::vector<Device*> devices_;
+  std::vector<std::unique_ptr<Device>> devices_;
 };
 
 using ::tensorflow::testing::FindNodeByName;
@@ -59,8 +59,9 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
     SessionOptions options;
     auto* device_count = options.config.mutable_device_count();
     device_count->insert({"CPU", 1});
+    std::vector<std::unique_ptr<Device>> devices;
     TF_CHECK_OK(DeviceFactory::AddDevices(
-        options, "/job:localhost/replica:0/task:0", &devices_));
+        options, "/job:localhost/replica:0/task:0", &devices));
 
     FunctionDefLibrary proto;
     for (const auto& fdef : flib) {
@@ -69,7 +70,7 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
     lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
         OpRegistry::Global(), proto);
     OptimizerOptions opts;
-    device_mgr_ = absl::make_unique<DeviceMgr>(devices_);
+    device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
     pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
         device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
         opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
@@ -77,7 +78,6 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
   }
 
   FunctionLibraryRuntime* flr_;
-  std::vector<Device*> devices_;
   std::unique_ptr<DeviceMgr> device_mgr_;
   std::unique_ptr<FunctionLibraryDefinition> lib_def_;
   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
@@ -34,15 +34,9 @@ namespace tensorflow {
   //
   // It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to
   // make this more direct, but probably not worth it solely for this test.
-  std::vector<Device*> devices;
+  std::vector<std::unique_ptr<Device>> devices;
   TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices));
 
-  auto delete_devices = gtl::MakeCleanup([&] {
-    for (Device* d : devices) {
-      delete d;
-    }
-  });
-
   GraphOptimizationPassOptions opt_options;
   opt_options.graph = graph;
   opt_options.session_options = session_options;
@@ -386,7 +386,7 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
   TF_ASSERT_OK(s.ToGraph(graph.get()));
 
   // This is needed to register the XLA_GPU device.
-  std::vector<Device*> devices;
+  std::vector<std::unique_ptr<Device>> devices;
   TF_ASSERT_OK(DeviceFactory::AddDevices(
       SessionOptions(), "/job:localhost/replica:0/task:0", &devices));
 
@@ -400,10 +400,6 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
   TF_ASSERT_OK(PartiallyDecluster(&graph));
 
   EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
-
-  for (Device* d : devices) {
-    delete d;
-  }
 }
 
 TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) {
@@ -31,12 +31,12 @@ namespace tensorflow {
 class XlaCpuDeviceFactory : public DeviceFactory {
  public:
   Status CreateDevices(const SessionOptions& options, const string& name_prefix,
-                       std::vector<Device*>* devices) override;
+                       std::vector<std::unique_ptr<Device>>* devices) override;
 };
 
-Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
-                                          const string& name_prefix,
-                                          std::vector<Device*>* devices) {
+Status XlaCpuDeviceFactory::CreateDevices(
+    const SessionOptions& session_options, const string& name_prefix,
+    std::vector<std::unique_ptr<Device>>* devices) {
   XlaDeviceFlags* flags = GetXlaDeviceFlags();
   bool compile_on_demand = flags->tf_xla_compile_on_demand;
 
@@ -63,8 +63,7 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
   options.device_ordinal = 0;
   options.compilation_device_name = DEVICE_CPU_XLA_JIT;
   options.use_multiple_streams = false;
-  auto device = absl::make_unique<XlaDevice>(session_options, options);
-  devices->push_back(device.release());
+  devices->push_back(absl::make_unique<XlaDevice>(session_options, options));
   return Status::OK();
 }
 
@@ -29,12 +29,12 @@ namespace tensorflow {
 class XlaGpuDeviceFactory : public DeviceFactory {
  public:
   Status CreateDevices(const SessionOptions& options, const string& name_prefix,
-                       std::vector<Device*>* devices) override;
+                       std::vector<std::unique_ptr<Device>>* devices) override;
 };
 
-Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
-                                          const string& name_prefix,
-                                          std::vector<Device*>* devices) {
+Status XlaGpuDeviceFactory::CreateDevices(
+    const SessionOptions& session_options, const string& name_prefix,
+    std::vector<std::unique_ptr<Device>>* devices) {
   XlaOpRegistry::DeviceRegistration registration;
   registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
   registration.autoclustering_policy =
@@ -70,7 +70,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
       return status;
     }
 
-    devices->push_back(device.release());
+    devices->push_back(std::move(device));
   }
   return Status::OK();
 }
@@ -33,12 +33,12 @@ constexpr std::array<DataType, 9> kExecAllTypes = {
 class XlaInterpreterDeviceFactory : public DeviceFactory {
  public:
   Status CreateDevices(const SessionOptions& options, const string& name_prefix,
-                       std::vector<Device*>* devices) override;
+                       std::vector<std::unique_ptr<Device>>* devices) override;
 };
 
 Status XlaInterpreterDeviceFactory::CreateDevices(
     const SessionOptions& session_options, const string& name_prefix,
-    std::vector<Device*>* devices) {
+    std::vector<std::unique_ptr<Device>>* devices) {
   static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(
       DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT);
   (void)registrations;
@@ -61,8 +61,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
   options.device_ordinal = 0;
   options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
   options.use_multiple_streams = false;
-  auto device = absl::make_unique<XlaDevice>(session_options, options);
-  devices->push_back(device.release());
+  devices->push_back(absl::make_unique<XlaDevice>(session_options, options));
 
   return Status::OK();
 }
@@ -380,7 +380,7 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
       initialization_status_(Status::OK()),
       next_step_id_(1),
       device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
-      device_mgr_({device_}) {
+      device_mgr_(absl::WrapUnique(device_)) {
   CHECK(!options_.device_type.type_string().empty());
   if (options_.populate_resource_manager) {
     initialization_status_ =
@@ -2963,6 +2963,7 @@ tf_cuda_library(
         ":lib_internal",
         ":proto_text",
         ":protos_all_cc",
+        "@com_google_absl//absl/memory",
         "//third_party/eigen3",
         "//tensorflow/core/grappler:grappler_item",
     ] + mkl_deps(),
@@ -3816,6 +3817,7 @@ tf_cc_tests_gpu(
         ":test",
         ":test_main",
         ":testlib",
+        "@com_google_absl//absl/memory",
     ],
 )
 
@@ -3844,6 +3846,7 @@ tf_cc_tests_gpu(
         ":test",
         ":test_main",
         ":testlib",
+        "@com_google_absl//absl/memory",
     ],
 )
 
@@ -4411,6 +4414,7 @@ tf_cc_test(
         "//tensorflow/core/kernels:random_ops",
         "//tensorflow/core/kernels:shape_ops",
         "//third_party/eigen3",
+        "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -38,8 +38,9 @@ class CollectiveExecutorMgrTest : public ::testing::Test {
     auto* device_count = options.config.mutable_device_count();
     string task_name = "/job:localhost/replica:0/task:0";
     device_count->insert({"CPU", NUM_DEVS});
-    TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
-    device_mgr_.reset(new DeviceMgr(devices_));
+    std::vector<std::unique_ptr<Device>> devices;
+    TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
+    device_mgr_.reset(new DeviceMgr(std::move(devices)));
     std::unique_ptr<DeviceResolverInterface> drl(
         new DeviceResolverLocal(device_mgr_.get()));
     std::unique_ptr<ParamResolverInterface> prl(
@@ -50,7 +51,6 @@ class CollectiveExecutorMgrTest : public ::testing::Test {
   }
 
   std::unique_ptr<CollectiveExecutorMgr> cme_;
-  std::vector<Device*> devices_;
   std::unique_ptr<DeviceMgr> device_mgr_;
 };
 
@@ -37,8 +37,9 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
     string task_name = "/job:localhost/replica:0/task:0";
     auto* device_count = options.config.mutable_device_count();
     device_count->insert({"CPU", NUM_DEVS});
-    TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
-    device_mgr_.reset(new DeviceMgr(devices_));
+    std::vector<std::unique_ptr<Device>> devices;
+    TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
+    device_mgr_.reset(new DeviceMgr(std::move(devices)));
     drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
     prl_.reset(new CollectiveParamResolverLocal(device_mgr_.get(), drl_.get(),
                                                 task_name));
@@ -73,7 +74,6 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
     }
   }
 
-  std::vector<Device*> devices_;
   std::unique_ptr<DeviceMgr> device_mgr_;
   std::unique_ptr<DeviceResolverLocal> drl_;
   std::unique_ptr<CollectiveParamResolverLocal> prl_;
@@ -42,8 +42,9 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test {
     SessionOptions options;
     auto* device_count = options.config.mutable_device_count();
     device_count->insert({"CPU", NUM_DEVS});
-    TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices_));
-    device_mgr_.reset(new DeviceMgr(devices_));
+    std::vector<std::unique_ptr<Device>> devices;
+    TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices));
+    device_mgr_.reset(new DeviceMgr(std::move(devices)));
     drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
     prl_.reset(new CollectiveParamResolverLocal(device_mgr_.get(), drl_.get(),
                                                 kTaskName));
@@ -51,7 +52,6 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test {
                                      kStepId));
   }
 
-  std::vector<Device*> devices_;
   std::unique_ptr<DeviceMgr> device_mgr_;
   std::unique_ptr<DeviceResolverLocal> drl_;
   std::unique_ptr<CollectiveParamResolverLocal> prl_;
@@ -20,6 +20,7 @@ limitations under the License.
 #include <unordered_map>
 #include <vector>
 
+#include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"
@@ -89,9 +90,9 @@ DeviceFactory* DeviceFactory::GetFactory(const string& device_type) {
   return it->second.factory.get();
 }
 
-Status DeviceFactory::AddDevices(const SessionOptions& options,
-                                 const string& name_prefix,
-                                 std::vector<Device*>* devices) {
+Status DeviceFactory::AddDevices(
+    const SessionOptions& options, const string& name_prefix,
+    std::vector<std::unique_ptr<Device>>* devices) {
   // CPU first. A CPU device is required.
   auto cpu_factory = GetFactory("CPU");
   if (!cpu_factory) {
@@ -116,16 +117,16 @@ Status DeviceFactory::AddDevices(const SessionOptions& options,
   return Status::OK();
 }
 
-Device* DeviceFactory::NewDevice(const string& type,
+std::unique_ptr<Device> DeviceFactory::NewDevice(const string& type,
                                  const SessionOptions& options,
                                  const string& name_prefix) {
   auto device_factory = GetFactory(type);
   if (!device_factory) {
     return nullptr;
   }
   SessionOptions opt = options;
   (*opt.config.mutable_device_count())[type] = 1;
-  std::vector<Device*> devices;
+  std::vector<std::unique_ptr<Device>> devices;
   TF_CHECK_OK(device_factory->CreateDevices(opt, name_prefix, &devices));
   int expected_num_devices = 1;
   auto iter = options.config.device_count().find(type);
@@ -133,7 +134,7 @@ Device* DeviceFactory::NewDevice(const string& type,
     expected_num_devices = iter->second;
   }
   DCHECK_EQ(devices.size(), static_cast<size_t>(expected_num_devices));
-  return devices[0];
+  return std::move(devices[0]);
 }
 
 }  // namespace tensorflow
@@ -40,18 +40,19 @@ class DeviceFactory {
   // CPU devices are added first.
   static Status AddDevices(const SessionOptions& options,
                            const string& name_prefix,
-                           std::vector<Device*>* devices);
+                           std::vector<std::unique_ptr<Device>>* devices);
 
   // Helper for tests. Create a single device of type "type". The
   // returned device is always numbered zero, so if creating multiple
   // devices of the same type, supply distinct name_prefix arguments.
-  static Device* NewDevice(const string& type, const SessionOptions& options,
-                           const string& name_prefix);
+  static std::unique_ptr<Device> NewDevice(const string& type,
+                                           const SessionOptions& options,
+                                           const string& name_prefix);
 
   // Most clients should call AddDevices() instead.
-  virtual Status CreateDevices(const SessionOptions& options,
-                               const string& name_prefix,
-                               std::vector<Device*>* devices) = 0;
+  virtual Status CreateDevices(
+      const SessionOptions& options, const string& name_prefix,
+      std::vector<std::unique_ptr<Device>>* devices) = 0;
 
   // Return the device priority number for a "device_type" string.
   //
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/core/common_runtime/device_mgr.h"
 
+#include <memory>
 #include <vector>
 #include "tensorflow/core/common_runtime/local_device.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
@@ -24,32 +25,32 @@ limitations under the License.
 
 namespace tensorflow {
 
-DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
-    : name_backing_store_(128) {
-  for (Device* d : devices) {
+DeviceMgr::DeviceMgr(std::vector<std::unique_ptr<Device>> devices)
+    : devices_(std::move(devices)), name_backing_store_(128) {
+  for (auto& d : devices_) {
     CHECK(d->device_mgr_ == nullptr);
     d->device_mgr_ = this;
 
-    devices_.push_back(d);
-
     // Register under the (1) full name and (2) canonical name.
     for (const string& name :
          DeviceNameUtils::GetNamesForDeviceMappings(d->parsed_name())) {
-      device_map_[CopyToBackingStore(name)] = d;
+      device_map_[CopyToBackingStore(name)] = d.get();
     }
     // Register under the (3) local name and (4) legacy local name.
     for (const string& name :
         DeviceNameUtils::GetLocalNamesForDeviceMappings(d->parsed_name())) {
-      device_map_[CopyToBackingStore(name)] = d;
+      device_map_[CopyToBackingStore(name)] = d.get();
     }
     device_type_counts_[d->device_type()]++;
   }
 }
 
-DeviceMgr::~DeviceMgr() {
-  // TODO(b/37437134): Remove destructor after converting to std::unique_ptr.
-  for (Device* p : devices_) delete p;
-}
+DeviceMgr::DeviceMgr(std::unique_ptr<Device> device)
+    : DeviceMgr([&device] {
+        std::vector<std::unique_ptr<Device>> vector;
+        vector.push_back(std::move(device));
+        return vector;
+      }()) {}
 
 StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
   size_t n = s.size();
@@ -61,18 +62,22 @@ StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
 void DeviceMgr::ListDeviceAttributes(
     std::vector<DeviceAttributes>* devices) const {
   devices->reserve(devices_.size());
-  for (Device* dev : devices_) {
+  for (const auto& dev : devices_) {
     devices->emplace_back(dev->attributes());
   }
 }
 
 std::vector<Device*> DeviceMgr::ListDevices() const {
-  return std::vector<Device*>(devices_.begin(), devices_.end());
+  std::vector<Device*> devices(devices_.size());
+  for (size_t i = 0; i < devices_.size(); ++i) {
+    devices[i] = devices_[i].get();
+  }
+  return devices;
 }
 
 string DeviceMgr::DebugString() const {
   string out;
-  for (Device* dev : devices_) {
+  for (const auto& dev : devices_) {
     strings::StrAppend(&out, dev->name(), "\n");
   }
   return out;
@@ -80,7 +85,7 @@ string DeviceMgr::DebugString() const {
 
 string DeviceMgr::DeviceMappingString() const {
   string out;
-  for (Device* dev : devices_) {
+  for (const auto& dev : devices_) {
     if (!dev->attributes().physical_device_desc().empty()) {
       strings::StrAppend(&out, dev->name(), " -> ",
                          dev->attributes().physical_device_desc(), "\n");
@@ -107,7 +112,7 @@ Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
 
 void DeviceMgr::ClearContainers(gtl::ArraySlice<string> containers) const {
   Status s;
-  for (Device* dev : devices_) {
+  for (const auto& dev : devices_) {
     if (containers.empty()) {
       s.Update(dev->resource_manager()->Cleanup(
           dev->resource_manager()->default_container()));
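The single-device DeviceMgr constructor above delegates to the vector constructor through an immediately-invoked lambda, because a braced initializer list would have to copy the std::unique_ptr, which is not allowed. The same idiom in a stripped-down, self-contained form (Widget and WidgetGroup are illustrative names only, not part of this commit):

#include <memory>
#include <utility>
#include <vector>

struct Widget {};

class WidgetGroup {
 public:
  explicit WidgetGroup(std::vector<std::unique_ptr<Widget>> widgets)
      : widgets_(std::move(widgets)) {}

  // Delegates to the vector constructor. An initializer_list would copy the
  // unique_ptr, which is forbidden, so the vector is built inside a lambda
  // that is invoked immediately.
  explicit WidgetGroup(std::unique_ptr<Widget> widget)
      : WidgetGroup([&widget] {
          std::vector<std::unique_ptr<Widget>> v;
          v.push_back(std::move(widget));
          return v;
        }()) {}

 private:
  const std::vector<std::unique_ptr<Widget>> widgets_;
};

int main() {
  WidgetGroup one(std::make_unique<Widget>());
  (void)one;
  return 0;
}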
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
 #define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
 
+#include <memory>
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
@@ -34,15 +35,17 @@ class DeviceAttributes;
 
 class DeviceMgr {
  public:
-  // Takes ownership of each device in 'devices'.
+  // Constructs a DeviceMgr from a list of devices.
   // TODO(zhifengc): Other initialization information.
-  // TODO(b/37437134): Use std::unique_ptr's to track ownership.
-  explicit DeviceMgr(const std::vector<Device*>& devices);
-  ~DeviceMgr();
+  explicit DeviceMgr(std::vector<std::unique_ptr<Device>> devices);
+
+  // Constructs a DeviceMgr managing a single device.
+  explicit DeviceMgr(std::unique_ptr<Device> device);
 
   // Returns attributes of all devices.
   void ListDeviceAttributes(std::vector<DeviceAttributes>* devices) const;
 
+  // Returns raw pointers to the underlying devices.
   std::vector<Device*> ListDevices() const;
 
   // Returns a string listing all devices.
@@ -62,9 +65,7 @@ class DeviceMgr {
   int NumDeviceType(const string& type) const;
 
  private:
-  // TODO(b/37437134): Use std::unique_ptr's to track ownership.
-  typedef gtl::InlinedVector<Device*, 8> DeviceVec;
-  DeviceVec devices_;
+  const std::vector<std::unique_ptr<Device>> devices_;
 
   StringPiece CopyToBackingStore(StringPiece s);
 
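Note that ListDevices() still returns raw Device* pointers, but after this change they are borrowed views into the unique_ptrs owned by the DeviceMgr: callers must not delete them and must not let them outlive the manager. A small usage sketch under that assumption (PrintDeviceNames is illustrative, not part of this commit):

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/platform/logging.h"

// Illustrative only: the DeviceMgr keeps ownership; ListDevices() hands out
// non-owning pointers that stay valid only while `mgr` is alive.
void PrintDeviceNames(const tensorflow::DeviceMgr& mgr) {
  for (tensorflow::Device* d : mgr.ListDevices()) {
    LOG(INFO) << d->name();  // Never delete d; it is owned by mgr.
  }
}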
@@ -36,12 +36,12 @@ class DeviceResolverLocalTest : public ::testing::Test {
     string task_name = "/job:localhost/replica:0/task:0";
     auto* device_count = options.config.mutable_device_count();
     device_count->insert({"CPU", NUM_DEVS});
-    TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
-    device_mgr_.reset(new DeviceMgr(devices_));
+    std::vector<std::unique_ptr<Device>> devices;
+    TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
+    device_mgr_.reset(new DeviceMgr(std::move(devices)));
     drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
   }
 
-  std::vector<Device*> devices_;
   std::unique_ptr<DeviceMgr> device_mgr_;
   std::unique_ptr<DeviceResolverLocal> drl_;
 };
@@ -57,7 +57,7 @@ class DeviceSetTest : public ::testing::Test {
 class DummyFactory : public DeviceFactory {
  public:
   Status CreateDevices(const SessionOptions& options, const string& name_prefix,
-                       std::vector<Device*>* devices) override {
+                       std::vector<std::unique_ptr<Device>>* devices) override {
     return Status::OK();
   }
 };
@@ -155,12 +155,12 @@ class DirectSessionFactory : public SessionFactory {
     if (options.config.graph_options().build_cost_model() > 0) {
       EnableCPUAllocatorFullStats(true);
     }
-    std::vector<Device*> devices;
+    std::vector<std::unique_ptr<Device>> devices;
     TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
         options, "/job:localhost/replica:0/task:0", &devices));
 
     DirectSession* session =
-        new DirectSession(options, new DeviceMgr(devices), this);
+        new DirectSession(options, new DeviceMgr(std::move(devices)), this);
     {
       mutex_lock l(sessions_lock_);
       sessions_.push_back(session);
@@ -181,6 +181,7 @@ tf_cc_test(
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "@com_google_absl//absl/memory",
     ],
 )
 
@@ -18,6 +18,7 @@ limitations under the License.
 #include <memory>
 #include <vector>
 
+#include "absl/memory/memory.h"
 #include "tensorflow/cc/client/client_session.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/framework/scope.h"
@@ -37,12 +38,13 @@ namespace {
 class TestEnv {
  public:
   TestEnv() : flib_def_(OpRegistry::Global(), {}) {
-    Device* device =
-        DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
-    device_mgr_.reset(new DeviceMgr({device}));
-    flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(),
-                                              device, TF_GRAPH_DEF_VERSION,
-                                              &flib_def_, nullptr, {}, nullptr);
+    std::vector<std::unique_ptr<Device>> devices;
+    devices.push_back(
+        DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
+    device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
+    flib_runtime_ = NewFunctionLibraryRuntime(
+        device_mgr_.get(), Env::Default(), device_mgr_->ListDevices()[0],
+        TF_GRAPH_DEF_VERSION, &flib_def_, nullptr, {}, nullptr);
   }
 
   FunctionLibraryRuntime* function_library_runtime() const {
@@ -53,17 +53,17 @@ class ExecutorTest : public ::testing::Test {
     // when the test completes.
     CHECK(rendez_->Unref());
     delete exec_;
-    delete device_;
   }
 
   // Resets executor_ with a new executor based on a graph 'gdef'.
   void Create(std::unique_ptr<const Graph> graph) {
     const int version = graph->versions().producer();
     LocalExecutorParams params;
-    params.device = device_;
+    params.device = device_.get();
     params.create_kernel = [this, version](const NodeDef& ndef,
                                            OpKernel** kernel) {
-      return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel);
+      return CreateNonCachedKernel(device_.get(), nullptr, ndef, version,
+                                   kernel);
     };
     params.delete_kernel = [](OpKernel* kernel) {
       DeleteNonCachedKernel(kernel);
@@ -83,7 +83,7 @@ class ExecutorTest : public ::testing::Test {
   }
 
   thread::ThreadPool* thread_pool_ = nullptr;
-  Device* device_ = nullptr;
+  std::unique_ptr<Device> device_;
   Executor* exec_ = nullptr;
   StepStatsCollector step_stats_collector_;
   StepStats step_stats_;
@@ -18,6 +18,7 @@ limitations under the License.
 #include <atomic>
 #include <utility>
 
+#include "absl/memory/memory.h"
 #include "absl/strings/numbers.h"
 #include "absl/strings/str_split.h"
 #include "tensorflow/cc/ops/array_ops_internal.h"
@@ -147,14 +148,15 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
     SessionOptions options;
     auto* device_count = options.config.mutable_device_count();
     device_count->insert({"CPU", 3});
+    std::vector<std::unique_ptr<Device>> devices;
     TF_CHECK_OK(DeviceFactory::AddDevices(
-        options, "/job:localhost/replica:0/task:0", &devices_));
+        options, "/job:localhost/replica:0/task:0", &devices));
 
     FunctionDefLibrary proto;
     for (const auto& fdef : flib) *(proto.add_function()) = fdef;
     lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
     OptimizerOptions opts;
-    device_mgr_.reset(new DeviceMgr(devices_));
+    device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
     pflr_.reset(new ProcessFunctionLibraryRuntime(
         device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
         opts, default_thread_pool, nullptr /* cluster_flr */));
@@ -358,7 +360,6 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
   FunctionLibraryRuntime* flr0_;
   FunctionLibraryRuntime* flr1_;
   FunctionLibraryRuntime* flr2_;
-  std::vector<Device*> devices_;
   std::unique_ptr<DeviceMgr> device_mgr_;
   std::unique_ptr<FunctionLibraryDefinition> lib_def_;
   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
@@ -54,14 +54,15 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
     SessionOptions options;
     auto* device_count = options.config.mutable_device_count();
     device_count->insert({"CPU", 3});
+    std::vector<std::unique_ptr<Device>> devices;
     TF_CHECK_OK(DeviceFactory::AddDevices(
-        options, "/job:localhost/replica:0/task:0", &devices_));
+        options, "/job:localhost/replica:0/task:0", &devices));
 
     FunctionDefLibrary proto;
     for (const auto& fdef : flib) *(proto.add_function()) = fdef;
     lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
     OptimizerOptions opts;
-    device_mgr_.reset(new DeviceMgr(devices_));
+    device_mgr_.reset(new DeviceMgr(std::move(devices)));
     pflr_.reset(new ProcessFunctionLibraryRuntime(
         device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
         opts, default_thread_pool, nullptr /* cluster_flr */));
@@ -194,7 +195,6 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
   FunctionLibraryRuntime* flr0_;
   FunctionLibraryRuntime* flr1_;
   FunctionLibraryRuntime* flr2_;
-  std::vector<Device*> devices_;
   std::unique_ptr<DeviceMgr> device_mgr_;
   std::unique_ptr<FunctionLibraryDefinition> lib_def_;
   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
@@ -907,9 +907,9 @@ Allocator* BaseGPUDevice::GetScopedAllocator(AllocatorAttributes attr,
 const int BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength = 1000;
 const int BaseGPUDeviceFactory::InterconnectMap::kStreamExecutorStrength = 1;
 
-Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
-                                           const string& name_prefix,
-                                           std::vector<Device*>* devices) {
+Status BaseGPUDeviceFactory::CreateDevices(
+    const SessionOptions& options, const string& name_prefix,
+    std::vector<std::unique_ptr<Device>>* devices) {
   TF_RETURN_IF_ERROR(ValidateGPUMachineManager());
   se::Platform* gpu_manager = GPUMachineManager();
   if (gpu_manager == nullptr) {
@@ -1073,12 +1073,10 @@ static string GetShortDeviceDescription(PlatformGpuId platform_gpu_id,
 // LINT.ThenChange(//tensorflow/python/platform/test.py)
 }
 
-Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
-                                             const string& name_prefix,
-                                             TfGpuId tf_gpu_id,
-                                             int64 memory_limit,
-                                             const DeviceLocality& dev_locality,
-                                             std::vector<Device*>* devices) {
+Status BaseGPUDeviceFactory::CreateGPUDevice(
+    const SessionOptions& options, const string& name_prefix, TfGpuId tf_gpu_id,
+    int64 memory_limit, const DeviceLocality& dev_locality,
+    std::vector<std::unique_ptr<Device>>* devices) {
   CHECK_GE(tf_gpu_id.value(), 0);
   const string device_name =
       strings::StrCat(name_prefix, "/device:GPU:", tf_gpu_id.value());
@@ -1108,7 +1106,7 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
   // different (which should be an error).
   //
   // TODO(laigd): report error if memory_limit doesn't match stats.bytes_limit.
-  BaseGPUDevice* gpu_device = CreateGPUDevice(
+  std::unique_ptr<BaseGPUDevice> gpu_device = CreateGPUDevice(
       options, device_name, static_cast<Bytes>(stats.bytes_limit), dev_locality,
       tf_gpu_id, GetShortDeviceDescription(platform_gpu_id, desc),
       gpu_allocator, ProcessState::singleton()->GetCPUAllocator(numa_node));
@@ -1116,7 +1114,7 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
           << (stats.bytes_limit >> 20) << " MB memory) -> physical GPU ("
           << GetShortDeviceDescription(platform_gpu_id, desc) << ")";
   TF_RETURN_IF_ERROR(gpu_device->Init(options));
-  devices->push_back(gpu_device);
+  devices->push_back(std::move(gpu_device));
 
   return Status::OK();
 }
@@ -166,7 +166,7 @@ class BaseGPUDevice : public LocalDevice {
 class BaseGPUDeviceFactory : public DeviceFactory {
  public:
   Status CreateDevices(const SessionOptions& options, const string& name_prefix,
-                       std::vector<Device*>* devices) override;
+                       std::vector<std::unique_ptr<Device>>* devices) override;
 
   struct InterconnectMap {
     // Name of interconnect technology, if known.
@@ -207,15 +207,13 @@ class BaseGPUDeviceFactory : public DeviceFactory {
   Status CreateGPUDevice(const SessionOptions& options,
                          const string& name_prefix, TfGpuId tf_gpu_id,
                          int64 memory_limit, const DeviceLocality& dev_locality,
-                         std::vector<Device*>* devices);
+                         std::vector<std::unique_ptr<Device>>* devices);
 
-  virtual BaseGPUDevice* CreateGPUDevice(const SessionOptions& options,
-                                         const string& name, Bytes memory_limit,
-                                         const DeviceLocality& dev_locality,
-                                         TfGpuId tf_gpu_id,
-                                         const string& physical_device_desc,
-                                         Allocator* gpu_allocator,
-                                         Allocator* cpu_allocator) = 0;
+  virtual std::unique_ptr<BaseGPUDevice> CreateGPUDevice(
+      const SessionOptions& options, const string& name, Bytes memory_limit,
+      const DeviceLocality& dev_locality, TfGpuId tf_gpu_id,
+      const string& physical_device_desc, Allocator* gpu_allocator,
+      Allocator* cpu_allocator) = 0;
 
   // Returns into 'ids' the list of valid platform GPU ids, in the order that
   // they should map to TF GPU ids "/device:GPU:0", "/device:GPU:1", etc,
@@ -59,15 +59,14 @@ class GPUDevice : public BaseGPUDevice {
 
 class GPUDeviceFactory : public BaseGPUDeviceFactory {
  private:
-  BaseGPUDevice* CreateGPUDevice(const SessionOptions& options,
-                                 const string& name, Bytes memory_limit,
-                                 const DeviceLocality& locality,
-                                 TfGpuId tf_gpu_id,
-                                 const string& physical_device_desc,
-                                 Allocator* gpu_allocator,
-                                 Allocator* cpu_allocator) override {
-    return new GPUDevice(options, name, memory_limit, locality, tf_gpu_id,
-                         physical_device_desc, gpu_allocator, cpu_allocator);
+  std::unique_ptr<BaseGPUDevice> CreateGPUDevice(
+      const SessionOptions& options, const string& name, Bytes memory_limit,
+      const DeviceLocality& locality, TfGpuId tf_gpu_id,
+      const string& physical_device_desc, Allocator* gpu_allocator,
+      Allocator* cpu_allocator) override {
+    return absl::make_unique<GPUDevice>(options, name, memory_limit, locality,
+                                        tf_gpu_id, physical_device_desc,
+                                        gpu_allocator, cpu_allocator);
   }
 };
 
@@ -108,7 +107,7 @@ class GPUCompatibleCPUDevice : public ThreadPoolDevice {
 class GPUCompatibleCPUDeviceFactory : public DeviceFactory {
  public:
   Status CreateDevices(const SessionOptions& options, const string& name_prefix,
-                       std::vector<Device*>* devices) override {
+                       std::vector<std::unique_ptr<Device>>* devices) override {
     int n = 1;
     auto iter = options.config.device_count().find("CPU");
     if (iter != options.config.device_count().end()) {
@@ -116,7 +115,7 @@ class GPUCompatibleCPUDeviceFactory : public DeviceFactory {
     }
     for (int i = 0; i < n; i++) {
       string name = strings::StrCat(name_prefix, "/device:CPU:", i);
-      devices->push_back(new GPUCompatibleCPUDevice(
+      devices->push_back(absl::make_unique<GPUCompatibleCPUDevice>(
           options, name, Bytes(256 << 20), DeviceLocality(), cpu_allocator()));
     }
 
@@ -33,7 +33,7 @@ namespace {
 
 TEST(GPUDeviceOnNonGPUMachineTest, CreateGPUDevicesOnNonGPUMachine) {
   SessionOptions opts;
-  std::vector<tensorflow::Device*> devices;
+  std::vector<std::unique_ptr<tensorflow::Device>> devices;
   TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
       opts, "/job:localhost/replica:0/task:0", &devices));
   EXPECT_TRUE(devices.empty());
@@ -88,7 +88,7 @@ class GPUDeviceTest : public ::testing::Test {
 
 TEST_F(GPUDeviceTest, FailedToParseVisibleDeviceList) {
   SessionOptions opts = MakeSessionOptions("0,abc");
-  std::vector<tensorflow::Device*> devices;
+  std::vector<std::unique_ptr<Device>> devices;
   Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
       opts, kDeviceNamePrefix, &devices);
   EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@@ -97,7 +97,7 @@ TEST_F(GPUDeviceTest, FailedToParseVisibleDeviceList) {
 
 TEST_F(GPUDeviceTest, InvalidGpuId) {
   SessionOptions opts = MakeSessionOptions("100");
-  std::vector<tensorflow::Device*> devices;
+  std::vector<std::unique_ptr<Device>> devices;
   Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
       opts, kDeviceNamePrefix, &devices);
   EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@@ -107,7 +107,7 @@ TEST_F(GPUDeviceTest, InvalidGpuId) {
 
 TEST_F(GPUDeviceTest, DuplicateEntryInVisibleDeviceList) {
   SessionOptions opts = MakeSessionOptions("0,0");
-  std::vector<tensorflow::Device*> devices;
+  std::vector<std::unique_ptr<Device>> devices;
   Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
       opts, kDeviceNamePrefix, &devices);
   EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@@ -117,7 +117,7 @@ TEST_F(GPUDeviceTest, DuplicateEntryInVisibleDeviceList) {
 
 TEST_F(GPUDeviceTest, VirtualDeviceConfigConflictsWithMemoryFractionSettings) {
   SessionOptions opts = MakeSessionOptions("0", 0.1, 1, {{}});
-  std::vector<tensorflow::Device*> devices;
+  std::vector<std::unique_ptr<Device>> devices;
   Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
       opts, kDeviceNamePrefix, &devices);
   EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@@ -129,7 +129,7 @@ TEST_F(GPUDeviceTest, GpuDeviceCountTooSmall) {
   // device_count is 0, but with one entry in visible_device_list and one
   // (empty) VirtualDevices messages.
   SessionOptions opts = MakeSessionOptions("0", 0, 0, {{}});
-  std::vector<tensorflow::Device*> devices;
+  std::vector<std::unique_ptr<Device>> devices;
   Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
       opts, kDeviceNamePrefix, &devices);
   EXPECT_EQ(status.code(), error::UNKNOWN);
@@ -141,7 +141,7 @@ TEST_F(GPUDeviceTest, NotEnoughGpuInVisibleDeviceList) {
   // Single entry in visible_device_list with two (empty) VirtualDevices
   // messages.
   SessionOptions opts = MakeSessionOptions("0", 0, 8, {{}, {}});
-  std::vector<tensorflow::Device*> devices;
+  std::vector<std::unique_ptr<Device>> devices;
   Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
       opts, kDeviceNamePrefix, &devices);
   EXPECT_EQ(status.code(), error::UNKNOWN);
@@ -155,7 +155,7 @@ TEST_F(GPUDeviceTest, VirtualDeviceConfigConflictsWithVisibleDeviceList) {
   // Three entries in visible_device_list with two (empty) VirtualDevices
   // messages.
   SessionOptions opts = MakeSessionOptions("0,1", 0, 8, {{}});
-  std::vector<tensorflow::Device*> devices;
+  std::vector<std::unique_ptr<Device>> devices;
   Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
       opts, kDeviceNamePrefix, &devices);
   EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@@ -169,39 +169,36 @@ TEST_F(GPUDeviceTest, VirtualDeviceConfigConflictsWithVisibleDeviceList) {
 TEST_F(GPUDeviceTest, EmptyVirtualDeviceConfig) {
   // It'll create single virtual device when the virtual device config is empty.
   SessionOptions opts = MakeSessionOptions("0");
-  std::vector<tensorflow::Device*> devices;
+  std::vector<std::unique_ptr<Device>> devices;
   TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
       opts, kDeviceNamePrefix, &devices));
|
||||||
EXPECT_EQ(1, devices.size());
|
EXPECT_EQ(1, devices.size());
|
||||||
EXPECT_GE(devices[0]->attributes().memory_limit(), 0);
|
EXPECT_GE(devices[0]->attributes().memory_limit(), 0);
|
||||||
gtl::STLDeleteElements(&devices);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GPUDeviceTest, SingleVirtualDeviceWithNoMemoryLimit) {
|
TEST_F(GPUDeviceTest, SingleVirtualDeviceWithNoMemoryLimit) {
|
||||||
// It'll create single virtual device for the gpu in question when
|
// It'll create single virtual device for the gpu in question when
|
||||||
// memory_limit_mb is unset.
|
// memory_limit_mb is unset.
|
||||||
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{}});
|
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{}});
|
||||||
std::vector<tensorflow::Device*> devices;
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
|
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
|
||||||
opts, kDeviceNamePrefix, &devices));
|
opts, kDeviceNamePrefix, &devices));
|
||||||
EXPECT_EQ(1, devices.size());
|
EXPECT_EQ(1, devices.size());
|
||||||
EXPECT_GE(devices[0]->attributes().memory_limit(), 0);
|
EXPECT_GE(devices[0]->attributes().memory_limit(), 0);
|
||||||
gtl::STLDeleteElements(&devices);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GPUDeviceTest, SingleVirtualDeviceWithMemoryLimit) {
|
TEST_F(GPUDeviceTest, SingleVirtualDeviceWithMemoryLimit) {
|
||||||
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123}});
|
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123}});
|
||||||
std::vector<tensorflow::Device*> devices;
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
|
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
|
||||||
opts, kDeviceNamePrefix, &devices));
|
opts, kDeviceNamePrefix, &devices));
|
||||||
EXPECT_EQ(1, devices.size());
|
EXPECT_EQ(1, devices.size());
|
||||||
EXPECT_EQ(123 << 20, devices[0]->attributes().memory_limit());
|
EXPECT_EQ(123 << 20, devices[0]->attributes().memory_limit());
|
||||||
gtl::STLDeleteElements(&devices);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GPUDeviceTest, MultipleVirtualDevices) {
|
TEST_F(GPUDeviceTest, MultipleVirtualDevices) {
|
||||||
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}});
|
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}});
|
||||||
std::vector<tensorflow::Device*> devices;
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
|
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
|
||||||
opts, kDeviceNamePrefix, &devices));
|
opts, kDeviceNamePrefix, &devices));
|
||||||
EXPECT_EQ(2, devices.size());
|
EXPECT_EQ(2, devices.size());
|
||||||
@ -219,7 +216,6 @@ TEST_F(GPUDeviceTest, MultipleVirtualDevices) {
|
|||||||
devices[1]->attributes().locality().links().link(0).type());
|
devices[1]->attributes().locality().links().link(0).type());
|
||||||
EXPECT_EQ(BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength,
|
EXPECT_EQ(BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength,
|
||||||
devices[1]->attributes().locality().links().link(0).strength());
|
devices[1]->attributes().locality().links().link(0).strength());
|
||||||
gtl::STLDeleteElements(&devices);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enabling unified memory on pre-Pascal GPUs results in an initialization
|
// Enabling unified memory on pre-Pascal GPUs results in an initialization
|
||||||
@ -236,7 +232,7 @@ TEST_F(GPUDeviceTest, UnifiedMemoryUnavailableOnPrePascalGpus) {
|
|||||||
opts.config.mutable_gpu_options()
|
opts.config.mutable_gpu_options()
|
||||||
->mutable_experimental()
|
->mutable_experimental()
|
||||||
->set_use_unified_memory(true);
|
->set_use_unified_memory(true);
|
||||||
std::vector<tensorflow::Device*> devices;
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
|
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
|
||||||
opts, kDeviceNamePrefix, &devices);
|
opts, kDeviceNamePrefix, &devices);
|
||||||
EXPECT_EQ(status.code(), error::INTERNAL);
|
EXPECT_EQ(status.code(), error::INTERNAL);
|
||||||
@ -259,7 +255,7 @@ TEST_F(GPUDeviceTest, UnifiedMemoryAllocation) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
SessionOptions opts = MakeSessionOptions("0", kGpuMemoryFraction);
|
SessionOptions opts = MakeSessionOptions("0", kGpuMemoryFraction);
|
||||||
std::vector<tensorflow::Device*> devices;
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
TF_ASSERT_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
|
TF_ASSERT_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
|
||||||
opts, kDeviceNamePrefix, &devices));
|
opts, kDeviceNamePrefix, &devices));
|
||||||
ASSERT_EQ(1, devices.size());
|
ASSERT_EQ(1, devices.size());
|
||||||
@ -278,8 +274,6 @@ TEST_F(GPUDeviceTest, UnifiedMemoryAllocation) {
|
|||||||
(memory_limit >> 20) << 20);
|
(memory_limit >> 20) << 20);
|
||||||
EXPECT_NE(ptr, nullptr);
|
EXPECT_NE(ptr, nullptr);
|
||||||
allocator->DeallocateRaw(ptr);
|
allocator->DeallocateRaw(ptr);
|
||||||
|
|
||||||
gtl::STLDeleteElements(&devices);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
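The test changes above all follow one pattern: DeviceFactory::CreateDevices now fills a std::vector<std::unique_ptr<Device>>, so the manual gtl::STLDeleteElements cleanup at the end of each test can be dropped. A minimal, self-contained sketch of that ownership hand-off is below; FakeDevice and CreateDevices are illustrative stand-ins, not the real TensorFlow classes.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Illustrative stand-in for tensorflow::Device.
class FakeDevice {
 public:
  explicit FakeDevice(std::string name) : name_(std::move(name)) {}
  const std::string& name() const { return name_; }

 private:
  std::string name_;
};

// The factory appends owning pointers; the caller owns whatever comes back.
void CreateDevices(int n, std::vector<std::unique_ptr<FakeDevice>>* devices) {
  for (int i = 0; i < n; ++i) {
    devices->push_back(
        std::make_unique<FakeDevice>("/device:FAKE:" + std::to_string(i)));
  }
}

int main() {
  std::vector<std::unique_ptr<FakeDevice>> devices;
  CreateDevices(2, &devices);
  for (const auto& d : devices) std::cout << d->name() << "\n";
  // No explicit delete loop: the vector's destructor frees every device.
}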
@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"

 #include <algorithm>
+ #include "absl/memory/memory.h"
 #include "tensorflow/core/common_runtime/base_collective_executor.h"
 #include "tensorflow/core/common_runtime/collective_rma_local.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
@ -217,7 +218,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 << " num_devices_per_worker=" << num_devices_per_worker;
 int total_num_devices = num_workers * num_devices_per_worker;
 device_type_ = device_type;
- std::vector<Device*> local_devices;
+ std::vector<std::unique_ptr<Device>> local_devices;
 SessionOptions sess_opts;
 sess_opts.env = Env::Default();
 Bytes mem_limit(4 << 20);
@ -227,7 +228,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 if (device_type == DEVICE_CPU) {
 string dev_name = strings::StrCat("/job:worker/replica:0/task:", wi,
 "/device:CPU:", di);
- local_devices.push_back(new ThreadPoolDevice(
+ local_devices.push_back(absl::make_unique<ThreadPoolDevice>(
 sess_opts, dev_name, mem_limit, dev_locality, cpu_allocator()));
 } else if (device_type == DEVICE_GPU && !gpu_devices_.empty()) {
 int dev_idx = (wi * num_devices_per_worker) + di;
@ -235,7 +236,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 LOG(INFO) << "dev_mgr has access to limited GPUs, reusing for more "
 "than one ring node.";
 } else {
- local_devices.push_back(gpu_devices_[dev_idx]);
+ local_devices.push_back(std::move(gpu_devices_[dev_idx]));
 }
 } else {
 LOG(FATAL) << "Unsupported device_type " << device_type;
@ -243,7 +244,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 }
 }
 if (!dev_mgr_ || device_type == DEVICE_CPU) {
- dev_mgr_.reset(new DeviceMgr(local_devices));
+ dev_mgr_.reset(new DeviceMgr(std::move(local_devices)));
 }
 if (!gpu_ring_order_) gpu_ring_order_.reset(new string());
 dev_resolver_.reset(new DeviceResolverLocal(dev_mgr_.get()));
@ -714,7 +715,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 std::unique_ptr<DeviceResolverLocal> dev_resolver_;
 std::vector<DeviceInstance*> instances_;
 CollectiveParams col_params_;
- std::vector<tensorflow::Device*> gpu_devices_;
+ std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
 std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
 std::unique_ptr<string> gpu_ring_order_;
 mutex mu_;
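dev_mgr_.reset(new DeviceMgr(std::move(local_devices))) works because a vector of unique_ptr is move-only: after the move, the test fixture no longer owns the devices and the manager does. A small stand-alone model of that transfer, with a toy DeviceMgr rather than the real class:

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Device {
  explicit Device(std::string n) : name(std::move(n)) {}
  std::string name;
};

// Toy manager: takes ownership of every device it is given.
class DeviceMgr {
 public:
  explicit DeviceMgr(std::vector<std::unique_ptr<Device>> devices)
      : devices_(std::move(devices)) {}
  std::vector<Device*> ListDevices() const {
    std::vector<Device*> out;
    for (const auto& d : devices_) out.push_back(d.get());
    return out;
  }

 private:
  std::vector<std::unique_ptr<Device>> devices_;
};

int main() {
  std::vector<std::unique_ptr<Device>> local_devices;
  local_devices.push_back(std::make_unique<Device>("cpu:0"));
  local_devices.push_back(std::make_unique<Device>("cpu:1"));
  DeviceMgr mgr(std::move(local_devices));  // ownership moves into the manager
  // local_devices is now empty; pointers from ListDevices() are borrowed.
  for (Device* d : mgr.ListDevices()) std::cout << d->name << "\n";
}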
@ -75,12 +75,12 @@ Benchmark::Benchmark(const string& device, Graph* g,
 const int graph_def_version = g->versions().producer();

 LocalExecutorParams params;
- params.device = device_;
+ params.device = device_.get();
 params.function_library = nullptr;
 params.create_kernel = [this, graph_def_version](const NodeDef& ndef,
 OpKernel** kernel) {
- return CreateNonCachedKernel(device_, nullptr, ndef, graph_def_version,
- kernel);
+ return CreateNonCachedKernel(device_.get(), nullptr, ndef,
+ graph_def_version, kernel);
 };
 params.delete_kernel = [](OpKernel* kernel) {
 DeleteNonCachedKernel(kernel);
@ -107,7 +107,7 @@ Benchmark::~Benchmark() {
 // run kernel destructors that may attempt to access state borrowed from
 // `device_`, such as the resource manager.
 exec_.reset();
- delete device_;
+ device_.reset();
 delete pool_;
 }
 }
@ -55,7 +55,7 @@ class Benchmark {

 private:
 thread::ThreadPool* pool_ = nullptr;
- Device* device_ = nullptr;
+ std::unique_ptr<Device> device_ = nullptr;
 Rendezvous* rendez_ = nullptr;
 std::unique_ptr<Executor> exec_;

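Benchmark now stores device_ as a std::unique_ptr<Device>, passes device_.get() to APIs that only borrow the pointer, and replaces the destructor's delete device_ with device_.reset(). A hedged, stand-alone sketch of that member-ownership pattern (toy types, not the TensorFlow ones):

#include <iostream>
#include <memory>
#include <string>

struct Device {
  std::string name = "cpu:0";
};

// Borrowing API: takes a raw pointer, does not own it.
void RunKernelOn(Device* d) { std::cout << "running on " << d->name << "\n"; }

class Benchmark {
 public:
  Benchmark() : device_(std::make_unique<Device>()) {}
  void Run() { RunKernelOn(device_.get()); }  // borrow, don't transfer
  ~Benchmark() {
    // Members are destroyed in reverse declaration order anyway; an explicit
    // reset is only needed when something else must be torn down first.
    device_.reset();
  }

 private:
  std::unique_ptr<Device> device_;
};

int main() {
  Benchmark b;
  b.Run();
}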
@ -92,7 +92,7 @@ class FakeDevice : public Device {
 class DummyFactory : public DeviceFactory {
 public:
 Status CreateDevices(const SessionOptions& options, const string& name_prefix,
- std::vector<Device*>* devices) override {
+ std::vector<std::unique_ptr<Device>>* devices) override {
 return Status::OK();
 }
 };
@ -62,9 +62,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
 SessionOptions options;
 auto* device_count = options.config.mutable_device_count();
 device_count->insert({"CPU", 2});
+ std::vector<std::unique_ptr<Device>> devices;
 TF_CHECK_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0",
- &devices_));
- device_mgr_.reset(new DeviceMgr(devices_));
+ &devices));
+ device0_ = devices[0].get();
+ device1_ = devices[1].get();
+ device_mgr_.reset(new DeviceMgr(std::move(devices)));
 FunctionDefLibrary proto;
 for (const auto& fdef : flib) *(proto.add_function()) = fdef;
 lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
@ -138,8 +141,9 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
 return Status::OK();
 }

- std::vector<Device*> devices_;
 std::unique_ptr<DeviceMgr> device_mgr_;
+ Device* device0_ = nullptr; // Not owned. (Owned by device_mgr_.)
+ Device* device1_ = nullptr; // Not owned. (Owned by device_mgr_.)
 std::unique_ptr<FunctionLibraryDefinition> lib_def_;
 std::unique_ptr<TestClusterFLR> cluster_flr_;
 std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
@ -165,16 +169,16 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
 FunctionLibraryRuntime* flr =
 proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:0");
 EXPECT_NE(flr, nullptr);
- EXPECT_EQ(flr->device(), devices_[0]);
+ EXPECT_EQ(flr->device(), device0_);
 flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/device:CPU:0");
 EXPECT_NE(flr, nullptr);
- EXPECT_EQ(flr->device(), devices_[0]);
+ EXPECT_EQ(flr->device(), device0_);
 flr = proc_flr_->GetFLR("/device:CPU:0");
 EXPECT_NE(flr, nullptr);
- EXPECT_EQ(flr->device(), devices_[0]);
+ EXPECT_EQ(flr->device(), device0_);
 flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:1");
 EXPECT_NE(flr, nullptr);
- EXPECT_EQ(flr->device(), devices_[1]);
+ EXPECT_EQ(flr->device(), device1_);
 flr = proc_flr_->GetFLR("abc");
 EXPECT_EQ(flr, nullptr);
 rendezvous_->Unref();
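After the fixture moves its devices into device_mgr_, the tests still need to compare against individual devices, so the change caches plain Device* observers (device0_, device1_) before the move. A small model of that "single owner plus raw observers" split, with illustrative names:

#include <cassert>
#include <memory>
#include <utility>
#include <vector>

struct Device {
  int id;
};

class DeviceMgr {
 public:
  explicit DeviceMgr(std::vector<std::unique_ptr<Device>> d)
      : devices_(std::move(d)) {}

 private:
  std::vector<std::unique_ptr<Device>> devices_;
};

int main() {
  std::vector<std::unique_ptr<Device>> devices;
  devices.push_back(std::make_unique<Device>(Device{0}));
  devices.push_back(std::make_unique<Device>(Device{1}));

  // Grab non-owning views *before* handing the vector to the owner.
  Device* device0 = devices[0].get();
  Device* device1 = devices[1].get();

  DeviceMgr mgr(std::move(devices));  // mgr now owns both devices

  // The observers stay valid for as long as mgr is alive.
  assert(device0->id == 0);
  assert(device1->id == 1);
}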
@ -14,15 +14,14 @@ limitations under the License.
 ==============================================================================*/

 #include "tensorflow/core/common_runtime/renamed_device.h"
+ #include "absl/memory/memory.h"

 namespace tensorflow {

- // TODO(saeta): Convert to returning a std::unique_ptr?
 /* static */
- Device* RenamedDevice::NewRenamedDevice(const string& new_base,
- Device* underlying,
- bool owns_underlying,
- bool isolate_session_state) {
+ std::unique_ptr<Device> RenamedDevice::NewRenamedDevice(
+ const string& new_base, Device* underlying, bool owns_underlying,
+ bool isolate_session_state) {
 DeviceNameUtils::ParsedName parsed_name;
 CHECK(DeviceNameUtils::ParseFullName(new_base, &parsed_name));
 DeviceNameUtils::ParsedName underlying_parsed_name =
@ -36,8 +35,9 @@ Device* RenamedDevice::NewRenamedDevice(const string& new_base,
 parsed_name.id);
 DeviceAttributes attributes(underlying->attributes());
 attributes.set_name(name);
- return new RenamedDevice(underlying, attributes, owns_underlying,
- isolate_session_state);
+ // Call absl::WrapUnique to access private constructor.
+ return absl::WrapUnique(new RenamedDevice(
+ underlying, attributes, owns_underlying, isolate_session_state));
 }

 RenamedDevice::RenamedDevice(Device* underlying,
@ -28,9 +28,10 @@ namespace tensorflow {
 // session.
 class RenamedDevice : public Device {
 public:
- static Device* NewRenamedDevice(const string& new_base, Device* underlying,
- bool owns_underlying,
- bool isolate_session_state);
+ static std::unique_ptr<Device> NewRenamedDevice(const string& new_base,
+ Device* underlying,
+ bool owns_underlying,
+ bool isolate_session_state);

 ~RenamedDevice() override;

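NewRenamedDevice cannot use absl::make_unique because RenamedDevice's constructor is private; absl::WrapUnique adopts an already-constructed raw pointer instead. The same idea in a self-contained sketch using std::unique_ptr directly (WrapUnique is essentially a deduction helper around this); the Renamed class here is a toy:

#include <iostream>
#include <memory>
#include <string>

class Renamed {
 public:
  // The factory is the only way to create a Renamed; it returns an owner.
  static std::unique_ptr<Renamed> New(const std::string& new_name) {
    // std::make_unique cannot reach the private constructor, but a member
    // function can, so we wrap the raw `new` in a unique_ptr ourselves
    // (absl::WrapUnique does the same with type deduction).
    return std::unique_ptr<Renamed>(new Renamed(new_name));
  }
  const std::string& name() const { return name_; }

 private:
  explicit Renamed(std::string name) : name_(std::move(name)) {}
  std::string name_;
};

int main() {
  auto d = Renamed::New("/job:worker/replica:0/task:0/device:CPU:0");
  std::cout << d->name() << "\n";
}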
@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/ring_reducer.h"

 #include <algorithm>
+ #include "absl/memory/memory.h"
 #include "tensorflow/core/common_runtime/base_collective_executor.h"
 #include "tensorflow/core/common_runtime/collective_rma_local.h"
 #include "tensorflow/core/common_runtime/device.h"
@ -157,7 +158,7 @@ class RingReducerTest : public ::testing::Test {
 InitGPUDevices();
 #endif
 device_type_ = device_type;
- std::vector<Device*> local_devices;
+ std::vector<std::unique_ptr<Device>> local_devices;
 SessionOptions sess_opts;
 sess_opts.env = Env::Default();
 Bytes mem_limit(4 << 20);
@ -167,7 +168,7 @@ class RingReducerTest : public ::testing::Test {
 if (device_type == DEVICE_CPU) {
 string dev_name =
 strings::StrCat("/job:worker/replica:0/task:", wi, "/cpu:", di);
- local_devices.push_back(new ThreadPoolDevice(
+ local_devices.push_back(absl::make_unique<ThreadPoolDevice>(
 sess_opts, dev_name, mem_limit, dev_locality, cpu_allocator()));
 } else if (device_type == DEVICE_GPU && !gpu_devices_.empty()) {
 int dev_idx = (wi * num_devices) + di;
@ -175,7 +176,7 @@ class RingReducerTest : public ::testing::Test {
 LOG(INFO) << "dev_mgr has access to limited GPUs, reusing for more "
 "than one ring node.";
 } else {
- local_devices.push_back(gpu_devices_[dev_idx]);
+ local_devices.push_back(std::move(gpu_devices_[dev_idx]));
 }
 } else {
 LOG(FATAL) << "Unsupported device_type " << device_type;
@ -185,7 +186,7 @@ class RingReducerTest : public ::testing::Test {
 if (!dev_mgr_ || device_type == DEVICE_CPU) {
 LOG(ERROR) << "resetting dev_mgr for " << local_devices.size()
 << " devices: ";
- dev_mgr_.reset(new DeviceMgr(local_devices));
+ dev_mgr_.reset(new DeviceMgr(std::move(local_devices)));
 }
 if (!gpu_ring_order_) gpu_ring_order_.reset(new string());
 dev_resolver_.reset(new DeviceResolverLocal(dev_mgr_.get()));
@ -544,7 +545,7 @@ class RingReducerTest : public ::testing::Test {
 std::unique_ptr<DeviceResolverLocal> dev_resolver_;
 std::vector<DeviceInstance*> instances_;
 CollectiveParams col_params_;
- std::vector<tensorflow::Device*> gpu_devices_;
+ std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
 std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
 std::unique_ptr<string> gpu_ring_order_;
 mutex mu_;
@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

- // Register a factory that provides CPU devices.
- #include "tensorflow/core/common_runtime/threadpool_device.h"
 #include <vector>

+ // Register a factory that provides CPU devices.
+ #include "absl/memory/memory.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/process_state.h"
+ #include "tensorflow/core/common_runtime/threadpool_device.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/platform/numa.h"
 #include "tensorflow/core/public/session_options.h"
@ -29,7 +30,7 @@ namespace tensorflow {
 class ThreadPoolDeviceFactory : public DeviceFactory {
 public:
 Status CreateDevices(const SessionOptions& options, const string& name_prefix,
- std::vector<Device*>* devices) override {
+ std::vector<std::unique_ptr<Device>>* devices) override {
 int num_numa_nodes = port::NUMANumNodes();
 int n = 1;
 auto iter = options.config.device_count().find("CPU");
@ -38,7 +39,7 @@ class ThreadPoolDeviceFactory : public DeviceFactory {
 }
 for (int i = 0; i < n; i++) {
 string name = strings::StrCat(name_prefix, "/device:CPU:", i);
- ThreadPoolDevice* tpd = nullptr;
+ std::unique_ptr<ThreadPoolDevice> tpd;
 if (options.config.experimental().use_numa_affinity()) {
 int numa_node = i % num_numa_nodes;
 if (numa_node != i) {
@ -49,15 +50,15 @@ class ThreadPoolDeviceFactory : public DeviceFactory {
 }
 DeviceLocality dev_locality;
 dev_locality.set_numa_node(numa_node);
- tpd = new ThreadPoolDevice(
+ tpd = absl::make_unique<ThreadPoolDevice>(
 options, name, Bytes(256 << 20), dev_locality,
 ProcessState::singleton()->GetCPUAllocator(numa_node));
 } else {
- tpd = new ThreadPoolDevice(
+ tpd = absl::make_unique<ThreadPoolDevice>(
 options, name, Bytes(256 << 20), DeviceLocality(),
 ProcessState::singleton()->GetCPUAllocator(port::kNUMANoAffinity));
 }
- devices->push_back(tpd);
+ devices->push_back(std::move(tpd));
 }

 return Status::OK();
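In ThreadPoolDeviceFactory the temporary tpd is now a std::unique_ptr that is filled on either branch of the if/else and then moved into the output vector, so nothing leaks regardless of which branch runs. A stand-alone sketch of that shape; the device type and NUMA details here are stand-ins:

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Device {
  std::string name;
  int numa_node;
};

void CreateCpuDevices(int n, bool use_numa_affinity,
                      std::vector<std::unique_ptr<Device>>* devices) {
  for (int i = 0; i < n; ++i) {
    std::unique_ptr<Device> d;  // owns nothing yet
    if (use_numa_affinity) {
      d = std::make_unique<Device>(
          Device{"/device:CPU:" + std::to_string(i), i % 2});
    } else {
      d = std::make_unique<Device>(
          Device{"/device:CPU:" + std::to_string(i), -1});
    }
    devices->push_back(std::move(d));  // hand ownership to the caller's vector
  }
}

int main() {
  std::vector<std::unique_ptr<Device>> devices;
  CreateCpuDevices(2, /*use_numa_affinity=*/true, &devices);
  for (const auto& d : devices)
    std::cout << d->name << " numa=" << d->numa_node << "\n";
}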
@ -624,6 +624,7 @@ tf_cc_test(
 "//tensorflow/core:test",
 "//tensorflow/core:test_main",
 "//tensorflow/core:testlib",
+ "@com_google_absl//absl/memory",
 ],
 )

@ -29,7 +29,8 @@ limitations under the License.
 namespace tensorflow {
 namespace {

- static Device* NewDevice(const string& type, const string& name) {
+ static std::unique_ptr<Device> NewDevice(const string& type,
+ const string& name) {
 class FakeDevice : public Device {
 public:
 explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
@ -40,7 +41,7 @@ static Device* NewDevice(const string& type, const string& name) {
 attr.set_name(name);
 attr.set_device_type(type);
 attr.mutable_locality()->set_numa_node(3); // a non-default value
- return new FakeDevice(attr);
+ return absl::make_unique<FakeDevice>(attr);
 }

 class FakeWorker : public TestWorkerInterface {
@ -156,16 +157,16 @@ class DeviceResDistTest : public ::testing::Test {

 void DefineWorker(const ConfigProto& config, const string& worker_name,
 const string& device_type, int num_devices) {
- std::vector<Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
 for (int i = 0; i < num_devices; ++i) {
 devices.push_back(NewDevice(
 device_type,
 strings::StrCat(worker_name, "/device:", device_type, ":", i)));
 }
- DeviceMgr* dev_mgr = new DeviceMgr(devices);
+ DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
 device_mgrs_.push_back(dev_mgr);
 std::vector<string>* dv = &dev_by_task_[worker_name];
- for (auto d : devices) {
+ for (auto* d : dev_mgr->ListDevices()) {
 dv->push_back(d->name());
 }
 DeviceResolverDistributed* dev_res =
@ -41,7 +41,8 @@ limitations under the License.
 namespace tensorflow {
 namespace {

- static Device* NewDevice(const string& type, const string& name) {
+ static std::unique_ptr<Device> NewDevice(const string& type,
+ const string& name) {
 class FakeDevice : public Device {
 public:
 explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
@ -52,7 +53,7 @@ static Device* NewDevice(const string& type, const string& name) {
 attr.set_name(name);
 attr.set_device_type(type);
 attr.mutable_locality()->set_numa_node(3); // a non-default value
- return new FakeDevice(attr);
+ return absl::make_unique<FakeDevice>(attr);
 }

 static int64 kStepId = 123;
@ -211,16 +212,16 @@ class CollRMADistTest : public ::testing::Test {

 void DefineWorker(const ConfigProto& config, const string& worker_name,
 const string& device_type, int num_devices) {
- std::vector<Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
 for (int i = 0; i < num_devices; ++i) {
 devices.push_back(NewDevice(
 device_type,
 strings::StrCat(worker_name, "/device:", device_type, ":", i)));
 }
- DeviceMgr* dev_mgr = new DeviceMgr(devices);
+ DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
 device_mgrs_.push_back(dev_mgr);
 std::vector<string>* dv = &dev_by_task_[worker_name];
- for (auto d : devices) {
+ for (auto d : dev_mgr->ListDevices()) {
 dv->push_back(d->name());
 }
 DeviceResolverDistributed* dev_res =
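The loop change here matters beyond style: after std::move(devices) the local vector is empty, so iterating it would record no device names; the loop now asks the new owner via dev_mgr->ListDevices() instead. A compact illustration of the pitfall and the fix, again with toy types:

#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Device {
  std::string name;
};

class DeviceMgr {
 public:
  explicit DeviceMgr(std::vector<std::unique_ptr<Device>> d)
      : devices_(std::move(d)) {}
  std::vector<Device*> ListDevices() const {
    std::vector<Device*> out;
    for (const auto& d : devices_) out.push_back(d.get());
    return out;
  }

 private:
  std::vector<std::unique_ptr<Device>> devices_;
};

int main() {
  std::vector<std::unique_ptr<Device>> devices;
  devices.push_back(std::make_unique<Device>(Device{"fake:0"}));

  DeviceMgr mgr(std::move(devices));

  // Wrong source after the move: the local vector is now empty.
  assert(devices.empty());

  // Right source: ask the new owner for (borrowed) pointers.
  std::vector<std::string> names;
  for (Device* d : mgr.ListDevices()) names.push_back(d->name);
  assert(names.size() == 1 && names[0] == "fake:0");
}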
@ -15,6 +15,7 @@ limitations under the License.

 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"

+ #include "absl/memory/memory.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/distributed_runtime/test_utils.h"
 #include "tensorflow/core/lib/core/notification.h"
@ -41,8 +42,8 @@ class TestableDeviceResolverDistributed : public DeviceResolverDistributed {

 // Create a fake 'Device' whose only interesting attribute is a non-default
 // DeviceLocality.
- static Device* NewDevice(const string& type, const string& name,
+ static std::unique_ptr<Device> NewDevice(const string& type, const string& name,
 int numa_node) {
 class FakeDevice : public Device {
 public:
 explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
@ -53,7 +54,7 @@ static Device* NewDevice(const string& type, const string& name,
 attr.set_name(name);
 attr.set_device_type(type);
 attr.mutable_locality()->set_numa_node(numa_node);
- return new FakeDevice(attr);
+ return absl::make_unique<FakeDevice>(attr);
 }

 // Create a fake WorkerInterface that responds to requests without RPCs,
@ -151,19 +152,19 @@ class DeviceResDistTest : public ::testing::Test {

 void DefineWorker(const string& worker_name, const string& device_type,
 int num_devices) {
- std::vector<Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
 for (int i = 0; i < num_devices; ++i) {
 devices.push_back(NewDevice(
 device_type,
 strings::StrCat(worker_name, "/device:", device_type, ":", i), i));
 }
- DeviceMgr* dev_mgr = new DeviceMgr(devices);
+ DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
 TestableDeviceResolverDistributed* dev_res =
 new TestableDeviceResolverDistributed(dev_mgr, &wc_, worker_name);
 resolvers_[worker_name] = dev_res;
 device_mgrs_.push_back(dev_mgr);
 std::vector<string>* dv = &dev_by_task_[worker_name];
- for (auto d : devices) {
+ for (auto* d : dev_mgr->ListDevices()) {
 dv->push_back(d->name());
 }
 FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res);
@ -87,7 +87,7 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
 return tensorflow::errors::Internal(
 "invalid eager env_ or env_->rendezvous_mgr.");
 }
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<tensorflow::Device>> devices;

 TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
 // TODO(nareshmodi): Correctly set the SessionOptions.
@ -97,12 +97,12 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
 request->server_def().task_index()),
 &devices));
 response->mutable_device_attributes()->Reserve(devices.size());
- for (auto& d : devices) {
+ for (const auto& d : devices) {
 *response->add_device_attributes() = d->attributes();
 }

 std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
- new tensorflow::DeviceMgr(devices));
+ new tensorflow::DeviceMgr(std::move(devices)));

 auto* r = env_->rendezvous_mgr->Find(request->rendezvous_id());
 auto session_name = strings::StrCat("eager_", request->rendezvous_id());
@ -68,12 +68,9 @@ class EagerServiceImplTest : public ::testing::Test {
 worker_env_.rendezvous_mgr = &rendezvous_mgr_;
 worker_env_.session_mgr = session_mgr_.get();

- Device* device = DeviceFactory::NewDevice(
- "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0");
- worker_env_.local_devices = {device};
-
- device_mgr_.reset(new DeviceMgr(worker_env_.local_devices));
+ device_mgr_ = absl::make_unique<DeviceMgr>(DeviceFactory::NewDevice(
+ "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
+ worker_env_.local_devices = device_mgr_->ListDevices();
 worker_env_.device_mgr = device_mgr_.get();
 }

@ -18,6 +18,7 @@ limitations under the License.
 #include <cstring>
 #include <limits>
 #include <memory>
+ #include <vector>

 #include "grpc/support/alloc.h"
 #include "grpcpp/grpcpp.h"
@ -156,10 +157,12 @@ Status GrpcServer::Init(
 string name_prefix =
 strings::StrCat("/job:", server_def_.job_name(), "/replica:0",
 "/task:", server_def_.task_index());
- TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix,
- &master_env_.local_devices));
- worker_env_.local_devices = master_env_.local_devices;
- worker_env_.device_mgr = new DeviceMgr(worker_env_.local_devices);
+ std::vector<std::unique_ptr<Device>> devices;
+ TF_RETURN_IF_ERROR(
+ DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
+ worker_env_.device_mgr = new DeviceMgr(std::move(devices));
+ master_env_.local_devices = worker_env_.device_mgr->ListDevices();
+ worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
 worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr
 ? new RpcRendezvousMgr(&worker_env_)
 : rendezvous_mgr_func(&worker_env_);
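GrpcServer::Init now builds the owning DeviceMgr first and then points both master_env_.local_devices and worker_env_.local_devices at DeviceMgr::ListDevices(), so the two env structs hold borrowed views of a single owner. Roughly, under the same toy-type assumptions as the earlier sketches:

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Device {
  std::string name;
};

class DeviceMgr {
 public:
  explicit DeviceMgr(std::vector<std::unique_ptr<Device>> d)
      : devices_(std::move(d)) {}
  std::vector<Device*> ListDevices() const {
    std::vector<Device*> out;
    for (const auto& d : devices_) out.push_back(d.get());
    return out;
  }

 private:
  std::vector<std::unique_ptr<Device>> devices_;
};

// Non-owning views, mirroring the local_devices fields of the env structs.
struct MasterEnv {
  std::vector<Device*> local_devices;
};
struct WorkerEnv {
  DeviceMgr* device_mgr = nullptr;
  std::vector<Device*> local_devices;
};

int main() {
  std::vector<std::unique_ptr<Device>> devices;
  devices.push_back(std::make_unique<Device>(Device{"cpu:0"}));

  WorkerEnv worker_env;
  MasterEnv master_env;
  worker_env.device_mgr = new DeviceMgr(std::move(devices));  // single owner
  master_env.local_devices = worker_env.device_mgr->ListDevices();
  worker_env.local_devices = worker_env.device_mgr->ListDevices();

  std::cout << master_env.local_devices[0]->name << "\n";
  delete worker_env.device_mgr;  // frees the devices; the views dangle afterwards
}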
@ -42,8 +42,9 @@ class RpcCollectiveExecutorMgrTest : public ::testing::Test {
 WorkerCacheInterface* worker_cache = nullptr;
 auto* device_count = options.config.mutable_device_count();
 device_count->insert({"CPU", NUM_DEVS});
- TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
- device_mgr_.reset(new DeviceMgr(devices_));
+ std::vector<std::unique_ptr<Device>> devices;
+ TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
+ device_mgr_.reset(new DeviceMgr(std::move(devices)));
 std::unique_ptr<DeviceResolverDistributed> dr(new DeviceResolverDistributed(
 device_mgr_.get(), worker_cache, task_name));
 std::unique_ptr<CollectiveParamResolverDistributed> cpr(
@ -57,7 +58,6 @@ class RpcCollectiveExecutorMgrTest : public ::testing::Test {
 }

 std::unique_ptr<RpcCollectiveExecutorMgr> cme_;
- std::vector<Device*> devices_;
 std::unique_ptr<DeviceMgr> device_mgr_;
 };

@ -78,13 +78,13 @@ Status SessionMgr::CreateSession(const string& session,

 if (isolate_session_state) {
 // Create a private copy of the DeviceMgr for the WorkerSession.
- std::vector<Device*> renamed_devices;
+ std::vector<std::unique_ptr<Device>> renamed_devices;
 for (Device* d : worker_env_->local_devices) {
 renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
 worker_name, d, false, isolate_session_state));
 }

- auto device_mgr = MakeUnique<DeviceMgr>(renamed_devices);
+ auto device_mgr = MakeUnique<DeviceMgr>(std::move(renamed_devices));
 auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
 worker_session.reset(
 new WorkerSession(session, worker_name,
@ -46,11 +46,9 @@ class SessionMgrTest : public ::testing::Test {
 SessionMgrTest()
 : mgr_(&env_, "/job:mnist/replica:0/task:0",
 std::unique_ptr<WorkerCacheInterface>(), factory_) {
- Device* device =
- FakeDevice::MakeCPU("/job:mnist/replica:0/task:0/device:fakecpu:0")
- .release();
- env_.local_devices = {device};
- device_mgr_.reset(new DeviceMgr(env_.local_devices));
+ device_mgr_ = absl::make_unique<DeviceMgr>(
+ FakeDevice::MakeCPU("/job:mnist/replica:0/task:0/device:fakecpu:0"));
+ env_.local_devices = device_mgr_->ListDevices();
 env_.device_mgr = device_mgr_.get();
 }

@ -102,10 +102,11 @@ Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
 }

 // Instantiate all variables for function library runtime creation.
- std::vector<Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
 TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
 options, "/job:localhost/replica:0/task:0", &devices));
- std::unique_ptr<DeviceMgr> dvc_mgr(new DeviceMgr(devices));
+ Device* cpu_device = devices[0].get();
+ std::unique_ptr<DeviceMgr> dvc_mgr(new DeviceMgr(std::move(devices)));
 FunctionLibraryDefinition function_library(OpRegistry::Global(),
 graph_def.library());
 Env* env = Env::Default();
@ -124,7 +125,7 @@ Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
 new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env,
 graph_def.versions().producer(),
 &function_library, *optimizer_opts));
- FunctionLibraryRuntime* flr = pflr->GetFLR(devices[0]->name());
+ FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name());

 // Create the GraphOptimizer to optimize the graph def.
 GraphConstructorOptions graph_ctor_opts;
@ -137,7 +138,7 @@ Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,

 // Optimize the graph.
 ::tensorflow::GraphOptimizer optimizer(*optimizer_opts);
- optimizer.Optimize(flr, env, devices[0], &graphptr, /*shape_map=*/nullptr);
+ optimizer.Optimize(flr, env, cpu_device, &graphptr, /*shape_map=*/nullptr);
 graphptr->ToGraphDef(output_graph_def);

 // The default values of attributes might have been stripped by the optimizer.
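The OptimizeGraph change has to cache cpu_device = devices[0].get() before std::move(devices), because devices[0] is no longer usable once ownership has moved into dvc_mgr. A minimal demonstration of why the order matters (toy DeviceMgr, not the real one):

#include <cassert>
#include <memory>
#include <utility>
#include <vector>

struct Device {
  int id;
};

class DeviceMgr {
 public:
  explicit DeviceMgr(std::vector<std::unique_ptr<Device>> d)
      : devices_(std::move(d)) {}

 private:
  std::vector<std::unique_ptr<Device>> devices_;
};

int main() {
  std::vector<std::unique_ptr<Device>> devices;
  devices.push_back(std::make_unique<Device>(Device{0}));

  // Take the borrowed pointer while the vector still owns the element.
  Device* cpu_device = devices[0].get();

  DeviceMgr dvc_mgr(std::move(devices));
  // devices is now empty, so devices[0] would be out of bounds; cpu_device is
  // still valid because dvc_mgr keeps the underlying object alive.
  assert(cpu_device->id == 0);
}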
@ -142,7 +142,6 @@ cc_library(
 ":graph_optimizer",
 "//tensorflow/core:core_cpu_base",
 "//tensorflow/core:framework",
- "//tensorflow/core:lib",
 "//tensorflow/core:lib_internal",
 "//tensorflow/core:protos_all_cc",
 "//tensorflow/core/grappler:grappler_item",
@ -150,6 +149,7 @@ cc_library(
 "//tensorflow/core/grappler:op_types",
 "//tensorflow/core/grappler:utils",
 "//tensorflow/core/grappler/utils:functions",
+ "@com_google_absl//absl/memory",
 "@com_google_absl//absl/strings",
 ],
 )
@ -16,7 +16,9 @@ limitations under the License.
 #include "tensorflow/core/grappler/optimizers/function_optimizer.h"

 #include <unordered_map>
+ #include <vector>

+ #include "absl/memory/memory.h"
 #include "absl/strings/str_replace.h"
 #include "absl/strings/substitute.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
@ -343,14 +345,15 @@ class FunctionOptimizerContext {
 DeviceAttributes attr;
 attr.set_name("/device:CPU:0");
 attr.set_device_type("CPU");
- Device* device = new FakeCPUDevice(env, attr);
- device_mgr_.reset(new DeviceMgr({device}));
+ std::vector<std::unique_ptr<Device>> devices;
+ devices.push_back(absl::make_unique<FakeCPUDevice>(env, attr));
+ device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
 OptimizerOptions optimizer_opts;
 optimizer_opts.set_do_function_inlining(true);
 process_flr_.reset(new ProcessFunctionLibraryRuntime(
 device_mgr_.get(), env, graph_version_, &function_library_,
 optimizer_opts));
- flr_ = process_flr_->GetFLR(device->name());
+ flr_ = process_flr_->GetFLR(device_mgr_->ListDevices()[0]->name());
 }
 }

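The FunctionOptimizerContext change replaces `new DeviceMgr({device})` with an explicit vector that is filled by push_back and then moved: a braced initializer_list has to copy its elements, so it cannot carry move-only std::unique_ptr values the way it could carry raw pointers. A tiny stand-alone illustration:

#include <memory>
#include <utility>
#include <vector>

struct Device {
  int id;
};

int main() {
  // Does NOT compile: initializer_list elements must be copied, and
  // unique_ptr is move-only.
  //   std::vector<std::unique_ptr<Device>> bad{std::make_unique<Device>(Device{0})};

  // Works: build the vector explicitly, then move each element in.
  std::vector<std::unique_ptr<Device>> devices;
  devices.push_back(std::make_unique<Device>(Device{0}));

  // The filled vector can then be moved wholesale into whatever owns it next.
  std::vector<std::unique_ptr<Device>> owner = std::move(devices);
  return owner.size() == 1 ? 0 : 1;
}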
@ -600,6 +600,7 @@ tf_kernel_library(
 "//tensorflow/core:protos_all_cc",
 "//tensorflow/core:session_options",
 "//tensorflow/core/kernels:ops_util",
+ "@com_google_absl//absl/memory",
 ],
 )

@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/kernels/data/iterator_ops.h"

+ #include "absl/memory/memory.h"
 #include "tensorflow/core/common_runtime/graph_runner.h"
 #include "tensorflow/core/common_runtime/renamed_device.h"
 #include "tensorflow/core/common_runtime/threadpool_device.h"
@ -545,10 +546,9 @@ FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
 // in its resource manager. The existing device will outlive the
 // IteratorResource, because we are storing the IteratorResource
 // in that device's resource manager.
- Device* wrapped_device = RenamedDevice::NewRenamedDevice(
+ *device_mgr = absl::make_unique<DeviceMgr>(RenamedDevice::NewRenamedDevice(
 ctx->device()->name(), down_cast<Device*>(ctx->device()),
- false /* owns_underlying */, false /* isolate_session_state */);
- device_mgr->reset(new DeviceMgr({wrapped_device}));
+ false /* owns_underlying */, false /* isolate_session_state */));
 flib_def->reset(new FunctionLibraryDefinition(
 *ctx->function_library()->GetFunctionLibraryDefinition()));
 pflr->reset(new ProcessFunctionLibraryRuntime(
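IteratorHandleOp now builds its private DeviceMgr straight from the unique_ptr returned by RenamedDevice::NewRenamedDevice, which suggests the updated DeviceMgr also accepts a single std::unique_ptr<Device> alongside the vector form. A sketch of what such a convenience overload might look like; this is a toy, not the real class:

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Device {
  std::string name;
};

class DeviceMgr {
 public:
  explicit DeviceMgr(std::vector<std::unique_ptr<Device>> d)
      : devices_(std::move(d)) {}
  // Convenience overload for the common one-device case.
  explicit DeviceMgr(std::unique_ptr<Device> d) {
    devices_.push_back(std::move(d));
  }
  std::vector<Device*> ListDevices() const {
    std::vector<Device*> out;
    for (const auto& d : devices_) out.push_back(d.get());
    return out;
  }

 private:
  std::vector<std::unique_ptr<Device>> devices_;
};

// Stand-in for a factory that returns an owned, wrapped device.
std::unique_ptr<Device> NewRenamedDevice(const std::string& name) {
  return std::make_unique<Device>(Device{name});
}

int main() {
  auto mgr = std::make_unique<DeviceMgr>(NewRenamedDevice("/device:CPU:0"));
  std::cout << mgr->ListDevices()[0]->name << "\n";
}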
@ -51,17 +51,17 @@ class ExecutorTest : public ::testing::Test {
|
|||||||
// when the test completes.
|
// when the test completes.
|
||||||
CHECK(rendez_->Unref());
|
CHECK(rendez_->Unref());
|
||||||
delete exec_;
|
delete exec_;
|
||||||
delete device_;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resets executor_ with a new executor based on a graph 'gdef'.
|
// Resets executor_ with a new executor based on a graph 'gdef'.
|
||||||
void Create(std::unique_ptr<const Graph> graph) {
|
void Create(std::unique_ptr<const Graph> graph) {
|
||||||
const int version = graph->versions().producer();
|
const int version = graph->versions().producer();
|
||||||
LocalExecutorParams params;
|
LocalExecutorParams params;
|
||||||
params.device = device_;
|
params.device = device_.get();
|
||||||
params.create_kernel = [this, version](const NodeDef& ndef,
|
params.create_kernel = [this, version](const NodeDef& ndef,
|
||||||
OpKernel** kernel) {
|
OpKernel** kernel) {
|
||||||
return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel);
|
return CreateNonCachedKernel(device_.get(), nullptr, ndef, version,
|
||||||
|
kernel);
|
||||||
};
|
};
|
||||||
params.delete_kernel = [](OpKernel* kernel) {
|
params.delete_kernel = [](OpKernel* kernel) {
|
||||||
DeleteNonCachedKernel(kernel);
|
DeleteNonCachedKernel(kernel);
|
||||||
@ -86,7 +86,7 @@ class ExecutorTest : public ::testing::Test {
|
|||||||
return exec_->Run(args);
|
return exec_->Run(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
Device* device_ = nullptr;
|
std::unique_ptr<Device> device_;
|
||||||
Executor* exec_ = nullptr;
|
Executor* exec_ = nullptr;
|
||||||
Executor::Args::Runner runner_;
|
Executor::Args::Runner runner_;
|
||||||
Rendezvous* rendez_ = nullptr;
|
Rendezvous* rendez_ = nullptr;
|
||||||
|
@ -116,6 +116,7 @@ cc_library(
|
|||||||
hdrs = ["delegate_data.h"],
|
hdrs = ["delegate_data.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":buffer_map",
|
":buffer_map",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
"//tensorflow/core/common_runtime/eager:context",
|
"//tensorflow/core/common_runtime/eager:context",
|
||||||
] + select({
|
] + select({
|
||||||
"//tensorflow:android": [
|
"//tensorflow:android": [
|
||||||
|
@@ -14,20 +14,21 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/lite/delegates/flex/delegate_data.h"
 
+#include "absl/memory/memory.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/lib/core/status.h"
 
 namespace tflite {
 namespace flex {
 tensorflow::Status DelegateData::Create(std::unique_ptr<DelegateData>* data) {
-  std::vector<tensorflow::Device*> devices;
+  std::vector<std::unique_ptr<tensorflow::Device>> devices;
 
   TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
       tensorflow::SessionOptions(), "/job:localhost/replica:0/task:0",
       &devices));
 
-  std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
-      new tensorflow::DeviceMgr(devices));
+  std::unique_ptr<tensorflow::DeviceMgr> device_mgr =
+      absl::make_unique<tensorflow::DeviceMgr>(std::move(devices));
   // Note that Rendezvous is ref-counted so it will be automatically deleted.
   tensorflow::Rendezvous* rendezvous =
       new tensorflow::IntraProcessRendezvous(device_mgr.get());
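The `std::move(devices)` in this hunk is not optional: a `std::vector<std::unique_ptr<Device>>` is move-only, so passing it by value without a move would fail to compile. A short sketch of that constraint, using an illustrative `Consume` function rather than the DeviceMgr constructor itself.

    #include <memory>
    #include <utility>
    #include <vector>

    struct Device { virtual ~Device() = default; };

    // Takes the devices by value, i.e. takes ownership of them.
    void Consume(std::vector<std::unique_ptr<Device>> devices) { (void)devices; }

    void Example() {
      std::vector<std::unique_ptr<Device>> devices;
      devices.push_back(std::make_unique<Device>());
      // Consume(devices);          // would not compile: unique_ptr is not copyable
      Consume(std::move(devices));  // ownership moves into the callee
    }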
@@ -2012,13 +2012,13 @@ bool InlineAllFunctions(GraphDef* graphdef) {
   tensorflow::SessionOptions options;
   auto* device_count = options.config.mutable_device_count();
   device_count->insert({"CPU", 1});
-  std::vector<tensorflow::Device*> devices;
+  std::vector<std::unique_ptr<tensorflow::Device>> devices;
   TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
       options, "/job:localhost/replica:0/task:0", &devices));
 
   tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(),
                                             graphdef_copy.library());
-  tensorflow::DeviceMgr device_mgr(devices);
+  tensorflow::DeviceMgr device_mgr(std::move(devices));
   tensorflow::OptimizerOptions o_opts;
   tensorflow::ProcessFunctionLibraryRuntime pflr(
       &device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld,
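One caller-side consequence of this hunk is worth spelling out: after `std::move(devices)`, the local vector is moved-from, so anything the caller needs from it must be read before the move. A minimal sketch under illustrative names (`Mgr` is not the real DeviceMgr).

    #include <cassert>
    #include <memory>
    #include <utility>
    #include <vector>

    struct Device { virtual ~Device() = default; };

    struct Mgr {
      explicit Mgr(std::vector<std::unique_ptr<Device>> devices)
          : devices_(std::move(devices)) {}
      std::vector<std::unique_ptr<Device>> devices_;
    };

    int main() {
      std::vector<std::unique_ptr<Device>> devices;
      devices.push_back(std::make_unique<Device>());
      const size_t n = devices.size();  // read what you need before the move
      Mgr mgr(std::move(devices));      // ownership moves into the manager
      assert(mgr.devices_.size() == n);
      // `devices` is now moved-from (valid but unspecified); do not use it to
      // reach the Device objects, which are owned by `mgr`.
      return 0;
    }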
@@ -48,17 +48,14 @@ static std::vector<string> ListDevicesWithSessionConfig(
   std::vector<string> output;
   SessionOptions options;
   options.config = config;
-  std::vector<Device*> devices;
+  std::vector<std::unique_ptr<Device>> devices;
   Status status = DeviceFactory::AddDevices(
       options, "" /* name_prefix */, &devices);
   if (!status.ok()) {
     Set_TF_Status_from_Status(out_status, status);
   }
 
-  std::vector<std::unique_ptr<Device>> device_holder(devices.begin(),
-                                                     devices.end());
-
-  for (const Device* device : devices) {
+  for (const std::unique_ptr<Device>& device : devices) {
     const DeviceAttributes& attr = device->attributes();
     string attr_serialized;
     if (!attr.SerializeToString(&attr_serialized)) {
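With the devices owned by the vector itself, the temporary `device_holder` disappears and the loop simply binds each element by const reference, so nothing is copied and nothing needs an explicit delete. A small self-contained sketch of that iteration pattern; the `Device` type here is a stand-in with only a `name()` accessor.

    #include <iostream>
    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>

    struct Device {
      explicit Device(std::string name) : name_(std::move(name)) {}
      const std::string& name() const { return name_; }
      std::string name_;
    };

    void ListNames(const std::vector<std::unique_ptr<Device>>& devices) {
      for (const std::unique_ptr<Device>& device : devices) {
        std::cout << device->name() << "\n";
      }
      // No `delete device;` is needed: the vector frees every Device when it
      // goes out of scope in the caller.
    }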
@@ -74,13 +74,13 @@ limitations under the License.
 
 void DetectDevices(std::unordered_map<string, tensorflow::DeviceProperties>* device_map) {
   tensorflow::SessionOptions options;
-  std::vector<tensorflow::Device*> devices;
+  std::vector<std::unique_ptr<tensorflow::Device>> devices;
   tensorflow::Status status = tensorflow::DeviceFactory::AddDevices(options, "", &devices);
   if (!status.ok()) {
     return;
   }
 
-  for (const tensorflow::Device* device : devices) {
+  for (const std::unique_ptr<tensorflow::Device>& device : devices) {
     tensorflow::DeviceProperties& prop = (*device_map)[device->name()];
     prop = tensorflow::grappler::GetDeviceInfo(device->parsed_name());
 
@@ -88,7 +88,6 @@ void DetectDevices(std::unordered_map<string, tensorflow::DeviceProperties>* dev
     // available device memory.
     const tensorflow::DeviceAttributes& attr = device->attributes();
     prop.set_memory_size(attr.memory_limit());
-    delete device;
   }
 }
 