Make DeviceMgr a pure virtual interface with StaticDeviceMgr as its only implementation.

This is a precursor to haoyuzhang@'s work on supporting dynamic device managers.

PiperOrigin-RevId: 266514057
This commit is contained in:
Derek Murray 2019-08-30 22:49:43 -07:00 committed by TensorFlower Gardener
parent f624f59710
commit 73e53d6343
41 changed files with 97 additions and 74 deletions

View File

@ -122,7 +122,7 @@ tensorflow::Status GetAllRemoteDevices(
n.WaitForNotification(); n.WaitForNotification();
} }
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr( std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
new tensorflow::DeviceMgr(std::move(remote_devices))); new tensorflow::StaticDeviceMgr(std::move(remote_devices)));
TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(status);
@ -385,7 +385,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
&devices); &devices);
if (!status->status.ok()) return nullptr; if (!status->status.ok()) return nullptr;
std::unique_ptr<tensorflow::DeviceMgr> device_mgr( std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
new tensorflow::DeviceMgr(std::move(devices))); new tensorflow::StaticDeviceMgr(std::move(devices)));
tensorflow::Rendezvous* r = tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get()); new tensorflow::IntraProcessRendezvous(device_mgr.get());

View File

@ -1193,7 +1193,7 @@ Status EncapsulateSubgraphsPass::Run(
} }
std::unique_ptr<DeviceMgr> device_mgr = std::unique_ptr<DeviceMgr> device_mgr =
absl::make_unique<DeviceMgr>(std::move(devices)); absl::make_unique<StaticDeviceMgr>(std::move(devices));
OptimizerOptions opts; OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr( std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(device_mgr.get(), new ProcessFunctionLibraryRuntime(device_mgr.get(),

View File

@ -510,7 +510,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
TF_CHECK_OK(DeviceFactory::AddDevices( TF_CHECK_OK(DeviceFactory::AddDevices(
session_options, "/job:localhost/replica:0/task:0", &devices)); session_options, "/job:localhost/replica:0/task:0", &devices));
OptimizerOptions opts; OptimizerOptions opts;
auto device_mgr = absl::make_unique<DeviceMgr>(std::move(devices)); auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>( auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def.get(), device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def.get(),
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);

View File

@ -232,7 +232,7 @@ class ExtractOutsideCompilationForFunctionTest : public ::testing::Test {
std::vector<std::unique_ptr<Device>> devices; std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices( TF_CHECK_OK(DeviceFactory::AddDevices(
session_options, "/job:localhost/replica:0/task:0", &devices)); session_options, "/job:localhost/replica:0/task:0", &devices));
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices)); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
} }
Status ExtractOutsideCompilationTest( Status ExtractOutsideCompilationTest(

View File

@ -71,7 +71,7 @@ class XlaKernelCreatorTest : public ::testing::Test {
lib_def_ = absl::make_unique<FunctionLibraryDefinition>( lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
OpRegistry::Global(), proto); OpRegistry::Global(), proto);
OptimizerOptions opts; OptimizerOptions opts;
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices)); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>( pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);

View File

@ -3456,7 +3456,7 @@ int main(int argc, char** argv) {
std::vector<std::unique_ptr<tensorflow::Device>> devices; std::vector<std::unique_ptr<tensorflow::Device>> devices;
TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices( TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
tensorflow::SessionOptions(), "", &devices)); tensorflow::SessionOptions(), "", &devices));
tensorflow::DeviceMgr device_mgr(std::move(devices)); tensorflow::StaticDeviceMgr device_mgr(std::move(devices));
tensorflow::Device* ignored; tensorflow::Device* ignored;
TF_QCHECK_OK( TF_QCHECK_OK(

View File

@ -474,7 +474,7 @@ class XlaCompiler {
int64 next_step_id_; int64 next_step_id_;
XlaCompilationDevice* device_; // Owned by device_mgr_ XlaCompilationDevice* device_; // Owned by device_mgr_
DeviceMgr device_mgr_; StaticDeviceMgr device_mgr_;
// To avoid copying the client's function library, use a local function // To avoid copying the client's function library, use a local function
// library and runtime for functions created as part of the functionalize // library and runtime for functions created as part of the functionalize

View File

@ -47,7 +47,7 @@ class BufRendezvousTest : public ::testing::Test {
const uint64 incarnation) { const uint64 incarnation) {
std::vector<std::unique_ptr<Device>> devices; std::vector<std::unique_ptr<Device>> devices;
devices.push_back(NewDevice(device, type, incarnation)); devices.push_back(NewDevice(device, type, incarnation));
dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices)); dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
br_ = absl::make_unique<BufRendezvous>(123, dev_mgr_.get()); br_ = absl::make_unique<BufRendezvous>(123, dev_mgr_.get());
} }

View File

@ -40,7 +40,7 @@ class CollectiveExecutorMgrTest : public ::testing::Test {
device_count->insert({"CPU", NUM_DEVS}); device_count->insert({"CPU", NUM_DEVS});
std::vector<std::unique_ptr<Device>> devices; std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices)); TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices))); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
std::unique_ptr<DeviceResolverInterface> drl( std::unique_ptr<DeviceResolverInterface> drl(
new DeviceResolverLocal(device_mgr_.get())); new DeviceResolverLocal(device_mgr_.get()));
std::unique_ptr<ParamResolverInterface> prl( std::unique_ptr<ParamResolverInterface> prl(

View File

@ -39,7 +39,7 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
device_count->insert({"CPU", NUM_DEVS}); device_count->insert({"CPU", NUM_DEVS});
std::vector<std::unique_ptr<Device>> devices; std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices)); TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices))); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
drl_.reset(new DeviceResolverLocal(device_mgr_.get())); drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
prl_.reset(new CollectiveParamResolverLocal(cp, device_mgr_.get(), prl_.reset(new CollectiveParamResolverLocal(cp, device_mgr_.get(),
drl_.get(), task_name)); drl_.get(), task_name));

View File

@ -46,7 +46,7 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test {
device_count->insert({"CPU", NUM_DEVS}); device_count->insert({"CPU", NUM_DEVS});
std::vector<std::unique_ptr<Device>> devices; std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices)); TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices));
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices)); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
drl_ = absl::make_unique<DeviceResolverLocal>(device_mgr_.get()); drl_ = absl::make_unique<DeviceResolverLocal>(device_mgr_.get());
prl_ = absl::make_unique<CollectiveParamResolverLocal>( prl_ = absl::make_unique<CollectiveParamResolverLocal>(
cp, device_mgr_.get(), drl_.get(), kTaskName); cp, device_mgr_.get(), drl_.get(), kTaskName);

View File

@ -47,7 +47,7 @@ Status Dataset::FromGraph(Params params, const GraphDef& graph_def,
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr)); TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
// Instantiate enough of the TF runtime to run `graph` on a single CPU device. // Instantiate enough of the TF runtime to run `graph` on a single CPU device.
auto device_mgr = absl::make_unique<DeviceMgr>(DeviceFactory::NewDevice( auto device_mgr = absl::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
"CPU", params.session_options, "/job:localhost/replica:0/task:0")); "CPU", params.session_options, "/job:localhost/replica:0/task:0"));
Device* device = device_mgr->ListDevices()[0]; Device* device = device_mgr->ListDevices()[0];
// Clone the `FunctionLibraryDefinition` to extend its lifetime extends beyond // Clone the `FunctionLibraryDefinition` to extend its lifetime extends beyond

View File

@ -25,7 +25,9 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
DeviceMgr::DeviceMgr(std::vector<std::unique_ptr<Device>> devices) DeviceMgr::~DeviceMgr() {}
StaticDeviceMgr::StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices)
: devices_(std::move(devices)), name_backing_store_(128) { : devices_(std::move(devices)), name_backing_store_(128) {
for (auto& d : devices_) { for (auto& d : devices_) {
// Register under the (1) full name and (2) canonical name. // Register under the (1) full name and (2) canonical name.
@ -42,14 +44,14 @@ DeviceMgr::DeviceMgr(std::vector<std::unique_ptr<Device>> devices)
} }
} }
DeviceMgr::DeviceMgr(std::unique_ptr<Device> device) StaticDeviceMgr::StaticDeviceMgr(std::unique_ptr<Device> device)
: DeviceMgr([&device] { : StaticDeviceMgr([&device] {
std::vector<std::unique_ptr<Device>> vector; std::vector<std::unique_ptr<Device>> vector;
vector.push_back(std::move(device)); vector.push_back(std::move(device));
return vector; return vector;
}()) {} }()) {}
DeviceMgr::~DeviceMgr() { StaticDeviceMgr::~StaticDeviceMgr() {
// Release resources ahead of destroying the device manager as the resource // Release resources ahead of destroying the device manager as the resource
// destructors (e.g. ~IteratorResource) assume devices still exist. // destructors (e.g. ~IteratorResource) assume devices still exist.
for (auto& device : devices_) { for (auto& device : devices_) {
@ -57,14 +59,14 @@ DeviceMgr::~DeviceMgr() {
} }
} }
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) { StringPiece StaticDeviceMgr::CopyToBackingStore(StringPiece s) {
size_t n = s.size(); size_t n = s.size();
char* space = name_backing_store_.Alloc(n); char* space = name_backing_store_.Alloc(n);
memcpy(space, s.data(), n); memcpy(space, s.data(), n);
return StringPiece(space, n); return StringPiece(space, n);
} }
void DeviceMgr::ListDeviceAttributes( void StaticDeviceMgr::ListDeviceAttributes(
std::vector<DeviceAttributes>* devices) const { std::vector<DeviceAttributes>* devices) const {
devices->reserve(devices_.size()); devices->reserve(devices_.size());
for (const auto& dev : devices_) { for (const auto& dev : devices_) {
@ -72,7 +74,7 @@ void DeviceMgr::ListDeviceAttributes(
} }
} }
std::vector<Device*> DeviceMgr::ListDevices() const { std::vector<Device*> StaticDeviceMgr::ListDevices() const {
std::vector<Device*> devices(devices_.size()); std::vector<Device*> devices(devices_.size());
for (size_t i = 0; i < devices_.size(); ++i) { for (size_t i = 0; i < devices_.size(); ++i) {
devices[i] = devices_[i].get(); devices[i] = devices_[i].get();
@ -80,7 +82,7 @@ std::vector<Device*> DeviceMgr::ListDevices() const {
return devices; return devices;
} }
string DeviceMgr::DebugString() const { string StaticDeviceMgr::DebugString() const {
string out; string out;
for (const auto& dev : devices_) { for (const auto& dev : devices_) {
strings::StrAppend(&out, dev->name(), "\n"); strings::StrAppend(&out, dev->name(), "\n");
@ -88,7 +90,7 @@ string DeviceMgr::DebugString() const {
return out; return out;
} }
string DeviceMgr::DeviceMappingString() const { string StaticDeviceMgr::DeviceMappingString() const {
string out; string out;
for (const auto& dev : devices_) { for (const auto& dev : devices_) {
if (!dev->attributes().physical_device_desc().empty()) { if (!dev->attributes().physical_device_desc().empty()) {
@ -99,7 +101,7 @@ string DeviceMgr::DeviceMappingString() const {
return out; return out;
} }
Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const { Status StaticDeviceMgr::LookupDevice(StringPiece name, Device** device) const {
auto iter = device_map_.find(name); auto iter = device_map_.find(name);
if (iter == device_map_.end()) { if (iter == device_map_.end()) {
std::vector<StringPiece> device_names; std::vector<StringPiece> device_names;
@ -114,7 +116,8 @@ Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
return Status::OK(); return Status::OK();
} }
void DeviceMgr::ClearContainers(gtl::ArraySlice<string> containers) const { void StaticDeviceMgr::ClearContainers(
gtl::ArraySlice<string> containers) const {
Status s; Status s;
for (const auto& dev : devices_) { for (const auto& dev : devices_) {
if (containers.empty()) { if (containers.empty()) {
@ -131,7 +134,7 @@ void DeviceMgr::ClearContainers(gtl::ArraySlice<string> containers) const {
} }
} }
int DeviceMgr::NumDeviceType(const string& type) const { int StaticDeviceMgr::NumDeviceType(const string& type) const {
auto iter = device_type_counts_.find(type); auto iter = device_type_counts_.find(type);
if (iter != device_type_counts_.end()) return iter->second; if (iter != device_type_counts_.end()) return iter->second;
return 0; return 0;

View File

@ -33,38 +33,57 @@ namespace tensorflow {
class DeviceAttributes; class DeviceAttributes;
// Represents a set of devices.
class DeviceMgr { class DeviceMgr {
public: public:
// Constructs a DeviceMgr from a list of devices. DeviceMgr() = default;
// TODO(zhifengc): Other initialization information. virtual ~DeviceMgr();
explicit DeviceMgr(std::vector<std::unique_ptr<Device>> devices);
// Constructs a DeviceMgr managing a single device.
explicit DeviceMgr(std::unique_ptr<Device> device);
~DeviceMgr();
// Returns attributes of all devices. // Returns attributes of all devices.
void ListDeviceAttributes(std::vector<DeviceAttributes>* devices) const; virtual void ListDeviceAttributes(
std::vector<DeviceAttributes>* devices) const = 0;
// Returns raw pointers to the underlying devices. // Returns raw pointers to the underlying devices.
std::vector<Device*> ListDevices() const; virtual std::vector<Device*> ListDevices() const = 0;
// Returns a string listing all devices. // Returns a string listing all devices.
string DebugString() const; virtual string DebugString() const = 0;
// Returns a string of all the device mapping. // Returns a string of all the device mapping.
string DeviceMappingString() const; virtual string DeviceMappingString() const = 0;
// Assigns *device with pointer to Device of the given name. // Assigns *device with pointer to Device of the given name.
// Accepts either a full device name, or just the replica-local suffix. // Accepts either a full device name, or just the replica-local suffix.
Status LookupDevice(StringPiece name, Device** device) const; virtual Status LookupDevice(StringPiece name, Device** device) const = 0;
// Clears given containers of all devices if 'container' is // Clears given containers of all devices if 'container' is
// non-empty. Otherwise, clears default containers of all devices. // non-empty. Otherwise, clears default containers of all devices.
void ClearContainers(gtl::ArraySlice<string> containers) const; virtual void ClearContainers(gtl::ArraySlice<string> containers) const = 0;
int NumDeviceType(const string& type) const; virtual int NumDeviceType(const string& type) const = 0;
TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr);
};
// Represents a static set of devices.
class StaticDeviceMgr : public DeviceMgr {
public:
// Constructs a StaticDeviceMgr from a list of devices.
explicit StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices);
// Constructs a StaticDeviceMgr managing a single device.
explicit StaticDeviceMgr(std::unique_ptr<Device> device);
~StaticDeviceMgr() override;
void ListDeviceAttributes(
std::vector<DeviceAttributes>* devices) const override;
std::vector<Device*> ListDevices() const override;
string DebugString() const override;
string DeviceMappingString() const override;
Status LookupDevice(StringPiece name, Device** device) const override;
void ClearContainers(gtl::ArraySlice<string> containers) const override;
int NumDeviceType(const string& type) const override;
private: private:
const std::vector<std::unique_ptr<Device>> devices_; const std::vector<std::unique_ptr<Device>> devices_;
@ -75,9 +94,8 @@ class DeviceMgr {
core::Arena name_backing_store_; // Storage for keys in device_map_ core::Arena name_backing_store_; // Storage for keys in device_map_
std::unordered_map<string, int> device_type_counts_; std::unordered_map<string, int> device_type_counts_;
TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr); TF_DISALLOW_COPY_AND_ASSIGN(StaticDeviceMgr);
}; };
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_ #endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_

View File

@ -37,7 +37,7 @@ class DeviceResolverLocalTest : public ::testing::Test {
device_count->insert({"CPU", NUM_DEVS}); device_count->insert({"CPU", NUM_DEVS});
std::vector<std::unique_ptr<Device>> devices; std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices)); TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices))); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
drl_.reset(new DeviceResolverLocal(device_mgr_.get())); drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
} }

View File

@ -188,8 +188,8 @@ class DirectSessionFactory : public SessionFactory {
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices( TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices)); options, "/job:localhost/replica:0/task:0", &devices));
DirectSession* session = DirectSession* session = new DirectSession(
new DirectSession(options, new DeviceMgr(std::move(devices)), this); options, new StaticDeviceMgr(std::move(devices)), this);
{ {
mutex_lock l(sessions_lock_); mutex_lock l(sessions_lock_);
sessions_.push_back(session); sessions_.push_back(session);

View File

@ -45,7 +45,7 @@ class TestEnv {
devices.push_back( devices.push_back(
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
cpu_device_ = devices.back().get(); cpu_device_ = devices.back().get();
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices)); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
OptimizerOptions opts; OptimizerOptions opts;
pflr_ = tensorflow::MakeUnique<ProcessFunctionLibraryRuntime>( pflr_ = tensorflow::MakeUnique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, &flib_def_, device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, &flib_def_,

View File

@ -162,7 +162,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
for (const auto& fdef : flib) *(proto.add_function()) = fdef; for (const auto& fdef : flib) *(proto.add_function()) = fdef;
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto)); lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
OptimizerOptions opts; OptimizerOptions opts;
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices)); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
pflr_.reset(new ProcessFunctionLibraryRuntime( pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts)); opts));

View File

@ -62,7 +62,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
for (const auto& fdef : flib) *(proto.add_function()) = fdef; for (const auto& fdef : flib) *(proto.add_function()) = fdef;
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto)); lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
OptimizerOptions opts; OptimizerOptions opts;
device_mgr_.reset(new DeviceMgr(std::move(devices))); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
pflr_.reset(new ProcessFunctionLibraryRuntime( pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, default_thread_pool)); opts, default_thread_pool));

View File

@ -247,7 +247,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
} }
} }
if (!dev_mgr_ || device_type == DEVICE_CPU) { if (!dev_mgr_ || device_type == DEVICE_CPU) {
dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(local_devices)); dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
} }
if (!gpu_ring_order_) { if (!gpu_ring_order_) {
gpu_ring_order_ = absl::make_unique<string>(); gpu_ring_order_ = absl::make_unique<string>();

View File

@ -47,7 +47,7 @@ class PartitioningUtilsTest : public ::testing::Test {
&devices)); &devices));
device0_ = devices[0].get(); device0_ = devices[0].get();
device1_ = devices[1].get(); device1_ = devices[1].get();
device_mgr_.reset(new DeviceMgr(std::move(devices))); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
for (auto d : device_mgr_->ListDevices()) { for (auto d : device_mgr_->ListDevices()) {
device_set_.AddDevice(d); device_set_.AddDevice(d);

View File

@ -94,7 +94,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
std::vector<std::unique_ptr<Device>> devices; std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0", TF_CHECK_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0",
&devices)); &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices))); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
TF_CHECK_OK(device_mgr_->LookupDevice( TF_CHECK_OK(device_mgr_->LookupDevice(
"/job:a/replica:0/task:0/device:CPU:0", &device0_)); "/job:a/replica:0/task:0/device:CPU:0", &device0_));
TF_CHECK_OK(device_mgr_->LookupDevice( TF_CHECK_OK(device_mgr_->LookupDevice(

View File

@ -167,7 +167,7 @@ class RingGathererTest : public ::testing::Test {
if (!dev_mgr_ || device_type == DEVICE_CPU) { if (!dev_mgr_ || device_type == DEVICE_CPU) {
LOG(ERROR) << "resetting dev_mgr for " << local_devices.size() LOG(ERROR) << "resetting dev_mgr for " << local_devices.size()
<< " devices: "; << " devices: ";
dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(local_devices)); dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
} }
if (!gpu_ring_order_) { if (!gpu_ring_order_) {
gpu_ring_order_ = absl::make_unique<string>(); gpu_ring_order_ = absl::make_unique<string>();

View File

@ -189,7 +189,7 @@ class RingReducerTest : public ::testing::Test {
if (!dev_mgr_ || device_type == DEVICE_CPU) { if (!dev_mgr_ || device_type == DEVICE_CPU) {
LOG(INFO) << "resetting dev_mgr for " << local_devices.size() LOG(INFO) << "resetting dev_mgr for " << local_devices.size()
<< " devices: "; << " devices: ";
dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(local_devices)); dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
} }
if (!gpu_ring_order_) { if (!gpu_ring_order_) {
gpu_ring_order_ = absl::make_unique<string>(); gpu_ring_order_ = absl::make_unique<string>();

View File

@ -165,7 +165,7 @@ class DeviceResDistTest : public ::testing::Test {
device_type, device_type,
strings::StrCat(worker_name, "/device:", device_type, ":", i))); strings::StrCat(worker_name, "/device:", device_type, ":", i)));
} }
DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices)); DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
device_mgrs_.push_back(dev_mgr); device_mgrs_.push_back(dev_mgr);
std::vector<string>* dv = &dev_by_task_[worker_name]; std::vector<string>* dv = &dev_by_task_[worker_name];
for (auto* d : dev_mgr->ListDevices()) { for (auto* d : dev_mgr->ListDevices()) {

View File

@ -222,7 +222,7 @@ class CollRMADistTest : public ::testing::Test {
device_type, device_type,
strings::StrCat(worker_name, "/device:", device_type, ":", i))); strings::StrCat(worker_name, "/device:", device_type, ":", i)));
} }
DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices)); DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
device_mgrs_.push_back(dev_mgr); device_mgrs_.push_back(dev_mgr);
std::vector<string>* dv = &dev_by_task_[worker_name]; std::vector<string>* dv = &dev_by_task_[worker_name];
dv->clear(); dv->clear();

View File

@ -163,7 +163,7 @@ class DeviceResDistTest : public ::testing::Test {
strings::StrCat(worker_name, "/device:", device_type, ":", i), i, strings::StrCat(worker_name, "/device:", device_type, ":", i), i,
device_incarnation_base + i)); device_incarnation_base + i));
} }
DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices)); DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
TestableDeviceResolverDistributed* dev_res = TestableDeviceResolverDistributed* dev_res =
new TestableDeviceResolverDistributed(dev_mgr, &wc_, worker_name); new TestableDeviceResolverDistributed(dev_mgr, &wc_, worker_name);
resolvers_[worker_name] = dev_res; resolvers_[worker_name] = dev_res;

View File

@ -85,7 +85,7 @@ class EagerServiceImplTest : public ::testing::Test {
worker_env_.rendezvous_mgr = &rendezvous_mgr_; worker_env_.rendezvous_mgr = &rendezvous_mgr_;
worker_env_.session_mgr = session_mgr_.get(); worker_env_.session_mgr = session_mgr_.get();
device_mgr_ = absl::make_unique<DeviceMgr>(DeviceFactory::NewDevice( device_mgr_ = absl::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
"CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0")); "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
worker_env_.local_devices = device_mgr_->ListDevices(); worker_env_.local_devices = device_mgr_->ListDevices();
worker_env_.device_mgr = device_mgr_.get(); worker_env_.device_mgr = device_mgr_.get();

View File

@ -48,7 +48,7 @@ class RemoteMgrTest : public ::testing::Test {
devices.push_back( devices.push_back(
DeviceFactory::NewDevice("CPU", {}, "/job:worker/replica:0/task:0")); DeviceFactory::NewDevice("CPU", {}, "/job:worker/replica:0/task:0"));
remote_device_ = devices.back().get(); remote_device_ = devices.back().get();
auto device_mgr = absl::make_unique<DeviceMgr>(std::move(devices)); auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
context_id_ = random::New64(); context_id_ = random::New64();
tensorflow::Rendezvous* rendezvous = tensorflow::Rendezvous* rendezvous =
new tensorflow::IntraProcessRendezvous(device_mgr.get()); new tensorflow::IntraProcessRendezvous(device_mgr.get());

View File

@ -163,7 +163,7 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
std::vector<std::unique_ptr<Device>> devices; std::vector<std::unique_ptr<Device>> devices;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
DeviceFactory::AddDevices(sess_opts, name_prefix, &devices)); DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
worker_env_.device_mgr = new DeviceMgr(std::move(devices)); worker_env_.device_mgr = new StaticDeviceMgr(std::move(devices));
master_env_.local_devices = worker_env_.device_mgr->ListDevices(); master_env_.local_devices = worker_env_.device_mgr->ListDevices();
worker_env_.local_devices = worker_env_.device_mgr->ListDevices(); worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr

View File

@ -44,7 +44,7 @@ class RpcCollectiveExecutorMgrTest : public ::testing::Test {
device_count->insert({"CPU", NUM_DEVS}); device_count->insert({"CPU", NUM_DEVS});
std::vector<std::unique_ptr<Device>> devices; std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices)); TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices))); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
std::unique_ptr<DeviceResolverDistributed> dr(new DeviceResolverDistributed( std::unique_ptr<DeviceResolverDistributed> dr(new DeviceResolverDistributed(
device_mgr_.get(), worker_cache, task_name)); device_mgr_.get(), worker_cache, task_name));
std::unique_ptr<CollectiveParamResolverDistributed> cpr( std::unique_ptr<CollectiveParamResolverDistributed> cpr(

View File

@ -101,7 +101,7 @@ Status SessionMgr::CreateSession(
renamed_devices.push_back(RenamedDevice::NewRenamedDevice( renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
worker_name, d, false, isolate_session_state)); worker_name, d, false, isolate_session_state));
} }
auto device_mgr = MakeUnique<DeviceMgr>(std::move(renamed_devices)); auto device_mgr = MakeUnique<StaticDeviceMgr>(std::move(renamed_devices));
LookupLocalDevice cb = [&device_mgr](StringPiece name, Device** device) { LookupLocalDevice cb = [&device_mgr](StringPiece name, Device** device) {
return device_mgr->LookupDevice(name, device); return device_mgr->LookupDevice(name, device);
}; };
@ -109,7 +109,7 @@ Status SessionMgr::CreateSession(
&cluster_devices); &cluster_devices);
std::unique_ptr<DeviceMgr> remote_devices; std::unique_ptr<DeviceMgr> remote_devices;
if (!cluster_device_attributes.empty()) if (!cluster_device_attributes.empty())
remote_devices = MakeUnique<DeviceMgr>(std::move(cluster_devices)); remote_devices = MakeUnique<StaticDeviceMgr>(std::move(cluster_devices));
auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get()); auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
worker_session.reset( worker_session.reset(
@ -122,7 +122,7 @@ Status SessionMgr::CreateSession(
&cluster_devices); &cluster_devices);
std::unique_ptr<DeviceMgr> remote_devices; std::unique_ptr<DeviceMgr> remote_devices;
if (!cluster_device_attributes.empty()) if (!cluster_device_attributes.empty())
remote_devices = MakeUnique<DeviceMgr>(std::move(cluster_devices)); remote_devices = MakeUnique<StaticDeviceMgr>(std::move(cluster_devices));
// Borrow the WorkerEnv's DeviceMgr for the WorkerSession, so // Borrow the WorkerEnv's DeviceMgr for the WorkerSession, so
// that resources using it can use its devices after the // that resources using it can use its devices after the
// WorkerSession has been deleted. // WorkerSession has been deleted.

View File

@ -46,7 +46,7 @@ class SessionMgrTest : public ::testing::Test {
SessionMgrTest() SessionMgrTest()
: mgr_(&env_, "/job:mnist/replica:0/task:0", : mgr_(&env_, "/job:mnist/replica:0/task:0",
std::unique_ptr<WorkerCacheInterface>(), factory_) { std::unique_ptr<WorkerCacheInterface>(), factory_) {
device_mgr_ = absl::make_unique<DeviceMgr>( device_mgr_ = absl::make_unique<StaticDeviceMgr>(
FakeDevice::MakeCPU("/job:mnist/replica:0/task:0/device:fakecpu:0")); FakeDevice::MakeCPU("/job:mnist/replica:0/task:0/device:fakecpu:0"));
env_.local_devices = device_mgr_->ListDevices(); env_.local_devices = device_mgr_->ListDevices();
env_.device_mgr = device_mgr_.get(); env_.device_mgr = device_mgr_.get();

View File

@ -239,7 +239,7 @@ Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
TF_RETURN_IF_ERROR(cpu_factory->CreateDevices( TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
options, "/job:localhost/replica:0/task:0", &devices)); options, "/job:localhost/replica:0/task:0", &devices));
Device* cpu_device = devices[0].get(); Device* cpu_device = devices[0].get();
std::unique_ptr<DeviceMgr> dvc_mgr(new DeviceMgr(std::move(devices))); auto dvc_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
FunctionLibraryDefinition function_library(OpRegistry::Global(), FunctionLibraryDefinition function_library(OpRegistry::Global(),
graph_def.library()); graph_def.library());
Env* env = Env::Default(); Env* env = Env::Default();

View File

@ -117,7 +117,8 @@ class NcclTestBase : public ::testing::Test {
device_names.push_back(device->name()); device_names.push_back(device->name());
VLOG(2) << device->name(); VLOG(2) << device->name();
} }
if (!dev_mgr_) dev_mgr_.reset(new DeviceMgr(std::move(local_devices))); if (!dev_mgr_)
dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
col_exec_ = new BaseCollectiveExecutor( col_exec_ = new BaseCollectiveExecutor(
&col_exec_mgr_, /*remote_access=*/nullptr, kStepId, dev_mgr_.get(), &col_exec_mgr_, /*remote_access=*/nullptr, kStepId, dev_mgr_.get(),
/*gpu_ring_order=*/nullptr); /*gpu_ring_order=*/nullptr);

View File

@ -415,7 +415,7 @@ Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
std::vector<std::unique_ptr<Device>> devices; std::vector<std::unique_ptr<Device>> devices;
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices( TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices)); options, "/job:localhost/replica:0/task:0", &devices));
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices)); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
resource_mgr_ = absl::make_unique<ResourceMgr>("default_container"); resource_mgr_ = absl::make_unique<ResourceMgr>("default_container");
FunctionDefLibrary proto; FunctionDefLibrary proto;

View File

@ -355,9 +355,10 @@ FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
// in its resource manager. The existing device will outlive the // in its resource manager. The existing device will outlive the
// IteratorResource, because we are storing the IteratorResource // IteratorResource, because we are storing the IteratorResource
// in that device's resource manager. // in that device's resource manager.
*device_mgr = absl::make_unique<DeviceMgr>(RenamedDevice::NewRenamedDevice( *device_mgr =
ctx->device()->name(), down_cast<Device*>(ctx->device()), absl::make_unique<StaticDeviceMgr>(RenamedDevice::NewRenamedDevice(
false /* owns_underlying */, false /* isolate_session_state */)); ctx->device()->name(), down_cast<Device*>(ctx->device()),
false /* owns_underlying */, false /* isolate_session_state */));
*flib_def = absl::make_unique<FunctionLibraryDefinition>( *flib_def = absl::make_unique<FunctionLibraryDefinition>(
*ctx->function_library()->GetFunctionLibraryDefinition()); *ctx->function_library()->GetFunctionLibraryDefinition());
*pflr = absl::make_unique<ProcessFunctionLibraryRuntime>( *pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(

View File

@ -27,7 +27,7 @@ void OpsTestBase::SetDevice(const DeviceType& device_type,
CHECK(device_) << "No device provided"; CHECK(device_) << "No device provided";
device_ = device.get(); device_ = device.get();
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(device)); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device));
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>( pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, flib_def_.get(), device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, flib_def_.get(),
OptimizerOptions()); OptimizerOptions());

View File

@ -78,7 +78,7 @@ class OpsTestBase : public ::testing::Test {
CHECK(device) << "Could not create CPU device"; CHECK(device) << "Could not create CPU device";
device_ = device.get(); device_ = device.get();
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(device)); device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device));
allocator_ = device_->GetAllocator(AllocatorAttributes()); allocator_ = device_->GetAllocator(AllocatorAttributes());

View File

@ -39,7 +39,7 @@ tensorflow::Status DelegateData::Prepare(
session_options, "/job:localhost/replica:0/task:0", &devices)); session_options, "/job:localhost/replica:0/task:0", &devices));
auto device_mgr = auto device_mgr =
absl::make_unique<tensorflow::DeviceMgr>(std::move(devices)); absl::make_unique<tensorflow::StaticDeviceMgr>(std::move(devices));
// Note that Rendezvous is ref-counted so it will be automatically deleted. // Note that Rendezvous is ref-counted so it will be automatically deleted.
tensorflow::Rendezvous* rendezvous = tensorflow::Rendezvous* rendezvous =
new tensorflow::IntraProcessRendezvous(device_mgr.get()); new tensorflow::IntraProcessRendezvous(device_mgr.get());

View File

@ -2234,7 +2234,7 @@ bool InlineAllFunctions(GraphDef* graphdef) {
tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(), tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(),
graphdef_copy.library()); graphdef_copy.library());
tensorflow::DeviceMgr device_mgr(std::move(devices)); tensorflow::StaticDeviceMgr device_mgr(std::move(devices));
tensorflow::OptimizerOptions o_opts; tensorflow::OptimizerOptions o_opts;
tensorflow::ProcessFunctionLibraryRuntime pflr( tensorflow::ProcessFunctionLibraryRuntime pflr(
&device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld, &device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld,