Make DeviceMgr a pure virtual interface with StaticDeviceMgr as its only implementation.

This is a precursor to haoyuzhang@'s work on supporting dynamic device managers.

PiperOrigin-RevId: 266514057
This commit is contained in:
Derek Murray 2019-08-30 22:49:43 -07:00 committed by TensorFlower Gardener
parent f624f59710
commit 73e53d6343
41 changed files with 97 additions and 74 deletions

View File

@ -122,7 +122,7 @@ tensorflow::Status GetAllRemoteDevices(
n.WaitForNotification();
}
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
new tensorflow::DeviceMgr(std::move(remote_devices)));
new tensorflow::StaticDeviceMgr(std::move(remote_devices)));
TF_RETURN_IF_ERROR(status);
@ -385,7 +385,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
&devices);
if (!status->status.ok()) return nullptr;
std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
new tensorflow::DeviceMgr(std::move(devices)));
new tensorflow::StaticDeviceMgr(std::move(devices)));
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());

View File

@ -1193,7 +1193,7 @@ Status EncapsulateSubgraphsPass::Run(
}
std::unique_ptr<DeviceMgr> device_mgr =
absl::make_unique<DeviceMgr>(std::move(devices));
absl::make_unique<StaticDeviceMgr>(std::move(devices));
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(device_mgr.get(),

View File

@ -510,7 +510,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
TF_CHECK_OK(DeviceFactory::AddDevices(
session_options, "/job:localhost/replica:0/task:0", &devices));
OptimizerOptions opts;
auto device_mgr = absl::make_unique<DeviceMgr>(std::move(devices));
auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def.get(),
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);

View File

@ -232,7 +232,7 @@ class ExtractOutsideCompilationForFunctionTest : public ::testing::Test {
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
session_options, "/job:localhost/replica:0/task:0", &devices));
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
}
Status ExtractOutsideCompilationTest(

View File

@ -71,7 +71,7 @@ class XlaKernelCreatorTest : public ::testing::Test {
lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
OpRegistry::Global(), proto);
OptimizerOptions opts;
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);

View File

@ -3456,7 +3456,7 @@ int main(int argc, char** argv) {
std::vector<std::unique_ptr<tensorflow::Device>> devices;
TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
tensorflow::SessionOptions(), "", &devices));
tensorflow::DeviceMgr device_mgr(std::move(devices));
tensorflow::StaticDeviceMgr device_mgr(std::move(devices));
tensorflow::Device* ignored;
TF_QCHECK_OK(

View File

@ -474,7 +474,7 @@ class XlaCompiler {
int64 next_step_id_;
XlaCompilationDevice* device_; // Owned by device_mgr_
DeviceMgr device_mgr_;
StaticDeviceMgr device_mgr_;
// To avoid copying the client's function library, use a local function
// library and runtime for functions created as part of the functionalize

View File

@ -47,7 +47,7 @@ class BufRendezvousTest : public ::testing::Test {
const uint64 incarnation) {
std::vector<std::unique_ptr<Device>> devices;
devices.push_back(NewDevice(device, type, incarnation));
dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
br_ = absl::make_unique<BufRendezvous>(123, dev_mgr_.get());
}

View File

@ -40,7 +40,7 @@ class CollectiveExecutorMgrTest : public ::testing::Test {
device_count->insert({"CPU", NUM_DEVS});
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices)));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
std::unique_ptr<DeviceResolverInterface> drl(
new DeviceResolverLocal(device_mgr_.get()));
std::unique_ptr<ParamResolverInterface> prl(

View File

@ -39,7 +39,7 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
device_count->insert({"CPU", NUM_DEVS});
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices)));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
prl_.reset(new CollectiveParamResolverLocal(cp, device_mgr_.get(),
drl_.get(), task_name));

View File

@ -46,7 +46,7 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test {
device_count->insert({"CPU", NUM_DEVS});
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices));
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
drl_ = absl::make_unique<DeviceResolverLocal>(device_mgr_.get());
prl_ = absl::make_unique<CollectiveParamResolverLocal>(
cp, device_mgr_.get(), drl_.get(), kTaskName);

View File

@ -47,7 +47,7 @@ Status Dataset::FromGraph(Params params, const GraphDef& graph_def,
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
// Instantiate enough of the TF runtime to run `graph` on a single CPU device.
auto device_mgr = absl::make_unique<DeviceMgr>(DeviceFactory::NewDevice(
auto device_mgr = absl::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
"CPU", params.session_options, "/job:localhost/replica:0/task:0"));
Device* device = device_mgr->ListDevices()[0];
// Clone the `FunctionLibraryDefinition` so that its lifetime extends beyond

View File

@ -25,7 +25,9 @@ limitations under the License.
namespace tensorflow {
DeviceMgr::DeviceMgr(std::vector<std::unique_ptr<Device>> devices)
DeviceMgr::~DeviceMgr() {}
StaticDeviceMgr::StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices)
: devices_(std::move(devices)), name_backing_store_(128) {
for (auto& d : devices_) {
// Register under the (1) full name and (2) canonical name.
@ -42,14 +44,14 @@ DeviceMgr::DeviceMgr(std::vector<std::unique_ptr<Device>> devices)
}
}
DeviceMgr::DeviceMgr(std::unique_ptr<Device> device)
: DeviceMgr([&device] {
StaticDeviceMgr::StaticDeviceMgr(std::unique_ptr<Device> device)
: StaticDeviceMgr([&device] {
std::vector<std::unique_ptr<Device>> vector;
vector.push_back(std::move(device));
return vector;
}()) {}
DeviceMgr::~DeviceMgr() {
StaticDeviceMgr::~StaticDeviceMgr() {
// Release resources ahead of destroying the device manager as the resource
// destructors (e.g. ~IteratorResource) assume devices still exist.
for (auto& device : devices_) {
@ -57,14 +59,14 @@ DeviceMgr::~DeviceMgr() {
}
}
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
StringPiece StaticDeviceMgr::CopyToBackingStore(StringPiece s) {
size_t n = s.size();
char* space = name_backing_store_.Alloc(n);
memcpy(space, s.data(), n);
return StringPiece(space, n);
}
void DeviceMgr::ListDeviceAttributes(
void StaticDeviceMgr::ListDeviceAttributes(
std::vector<DeviceAttributes>* devices) const {
devices->reserve(devices_.size());
for (const auto& dev : devices_) {
@ -72,7 +74,7 @@ void DeviceMgr::ListDeviceAttributes(
}
}
std::vector<Device*> DeviceMgr::ListDevices() const {
std::vector<Device*> StaticDeviceMgr::ListDevices() const {
std::vector<Device*> devices(devices_.size());
for (size_t i = 0; i < devices_.size(); ++i) {
devices[i] = devices_[i].get();
@ -80,7 +82,7 @@ std::vector<Device*> DeviceMgr::ListDevices() const {
return devices;
}
string DeviceMgr::DebugString() const {
string StaticDeviceMgr::DebugString() const {
string out;
for (const auto& dev : devices_) {
strings::StrAppend(&out, dev->name(), "\n");
@ -88,7 +90,7 @@ string DeviceMgr::DebugString() const {
return out;
}
string DeviceMgr::DeviceMappingString() const {
string StaticDeviceMgr::DeviceMappingString() const {
string out;
for (const auto& dev : devices_) {
if (!dev->attributes().physical_device_desc().empty()) {
@ -99,7 +101,7 @@ string DeviceMgr::DeviceMappingString() const {
return out;
}
Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
Status StaticDeviceMgr::LookupDevice(StringPiece name, Device** device) const {
auto iter = device_map_.find(name);
if (iter == device_map_.end()) {
std::vector<StringPiece> device_names;
@ -114,7 +116,8 @@ Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
return Status::OK();
}
void DeviceMgr::ClearContainers(gtl::ArraySlice<string> containers) const {
void StaticDeviceMgr::ClearContainers(
gtl::ArraySlice<string> containers) const {
Status s;
for (const auto& dev : devices_) {
if (containers.empty()) {
@ -131,7 +134,7 @@ void DeviceMgr::ClearContainers(gtl::ArraySlice<string> containers) const {
}
}
int DeviceMgr::NumDeviceType(const string& type) const {
int StaticDeviceMgr::NumDeviceType(const string& type) const {
auto iter = device_type_counts_.find(type);
if (iter != device_type_counts_.end()) return iter->second;
return 0;

View File

@ -33,38 +33,57 @@ namespace tensorflow {
class DeviceAttributes;
// Represents a set of devices.
class DeviceMgr {
public:
// Constructs a DeviceMgr from a list of devices.
// TODO(zhifengc): Other initialization information.
explicit DeviceMgr(std::vector<std::unique_ptr<Device>> devices);
// Constructs a DeviceMgr managing a single device.
explicit DeviceMgr(std::unique_ptr<Device> device);
~DeviceMgr();
DeviceMgr() = default;
virtual ~DeviceMgr();
// Returns attributes of all devices.
void ListDeviceAttributes(std::vector<DeviceAttributes>* devices) const;
virtual void ListDeviceAttributes(
std::vector<DeviceAttributes>* devices) const = 0;
// Returns raw pointers to the underlying devices.
std::vector<Device*> ListDevices() const;
virtual std::vector<Device*> ListDevices() const = 0;
// Returns a string listing all devices.
string DebugString() const;
virtual string DebugString() const = 0;
// Returns a string of all the device mapping.
string DeviceMappingString() const;
virtual string DeviceMappingString() const = 0;
// Assigns *device with pointer to Device of the given name.
// Accepts either a full device name, or just the replica-local suffix.
Status LookupDevice(StringPiece name, Device** device) const;
virtual Status LookupDevice(StringPiece name, Device** device) const = 0;
// Clears given containers of all devices if 'container' is
// non-empty. Otherwise, clears default containers of all devices.
void ClearContainers(gtl::ArraySlice<string> containers) const;
virtual void ClearContainers(gtl::ArraySlice<string> containers) const = 0;
int NumDeviceType(const string& type) const;
virtual int NumDeviceType(const string& type) const = 0;
TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr);
};
// Represents a static set of devices.
class StaticDeviceMgr : public DeviceMgr {
public:
// Constructs a StaticDeviceMgr from a list of devices.
explicit StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices);
// Constructs a StaticDeviceMgr managing a single device.
explicit StaticDeviceMgr(std::unique_ptr<Device> device);
~StaticDeviceMgr() override;
void ListDeviceAttributes(
std::vector<DeviceAttributes>* devices) const override;
std::vector<Device*> ListDevices() const override;
string DebugString() const override;
string DeviceMappingString() const override;
Status LookupDevice(StringPiece name, Device** device) const override;
void ClearContainers(gtl::ArraySlice<string> containers) const override;
int NumDeviceType(const string& type) const override;
private:
const std::vector<std::unique_ptr<Device>> devices_;
@ -75,9 +94,8 @@ class DeviceMgr {
core::Arena name_backing_store_; // Storage for keys in device_map_
std::unordered_map<string, int> device_type_counts_;
TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr);
TF_DISALLOW_COPY_AND_ASSIGN(StaticDeviceMgr);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_

View File

@ -37,7 +37,7 @@ class DeviceResolverLocalTest : public ::testing::Test {
device_count->insert({"CPU", NUM_DEVS});
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices)));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
}

View File

@ -188,8 +188,8 @@ class DirectSessionFactory : public SessionFactory {
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices));
DirectSession* session =
new DirectSession(options, new DeviceMgr(std::move(devices)), this);
DirectSession* session = new DirectSession(
options, new StaticDeviceMgr(std::move(devices)), this);
{
mutex_lock l(sessions_lock_);
sessions_.push_back(session);

View File

@ -45,7 +45,7 @@ class TestEnv {
devices.push_back(
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
cpu_device_ = devices.back().get();
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
OptimizerOptions opts;
pflr_ = tensorflow::MakeUnique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, &flib_def_,

View File

@ -162,7 +162,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
OptimizerOptions opts;
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts));

View File

@ -62,7 +62,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
OptimizerOptions opts;
device_mgr_.reset(new DeviceMgr(std::move(devices)));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, default_thread_pool));

View File

@ -247,7 +247,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
}
}
if (!dev_mgr_ || device_type == DEVICE_CPU) {
dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(local_devices));
dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
}
if (!gpu_ring_order_) {
gpu_ring_order_ = absl::make_unique<string>();

View File

@ -47,7 +47,7 @@ class PartitioningUtilsTest : public ::testing::Test {
&devices));
device0_ = devices[0].get();
device1_ = devices[1].get();
device_mgr_.reset(new DeviceMgr(std::move(devices)));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
for (auto d : device_mgr_->ListDevices()) {
device_set_.AddDevice(d);

View File

@ -94,7 +94,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0",
&devices));
device_mgr_.reset(new DeviceMgr(std::move(devices)));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
TF_CHECK_OK(device_mgr_->LookupDevice(
"/job:a/replica:0/task:0/device:CPU:0", &device0_));
TF_CHECK_OK(device_mgr_->LookupDevice(

View File

@ -167,7 +167,7 @@ class RingGathererTest : public ::testing::Test {
if (!dev_mgr_ || device_type == DEVICE_CPU) {
LOG(ERROR) << "resetting dev_mgr for " << local_devices.size()
<< " devices: ";
dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(local_devices));
dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
}
if (!gpu_ring_order_) {
gpu_ring_order_ = absl::make_unique<string>();

View File

@ -189,7 +189,7 @@ class RingReducerTest : public ::testing::Test {
if (!dev_mgr_ || device_type == DEVICE_CPU) {
LOG(INFO) << "resetting dev_mgr for " << local_devices.size()
<< " devices: ";
dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(local_devices));
dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
}
if (!gpu_ring_order_) {
gpu_ring_order_ = absl::make_unique<string>();

View File

@ -165,7 +165,7 @@ class DeviceResDistTest : public ::testing::Test {
device_type,
strings::StrCat(worker_name, "/device:", device_type, ":", i)));
}
DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
device_mgrs_.push_back(dev_mgr);
std::vector<string>* dv = &dev_by_task_[worker_name];
for (auto* d : dev_mgr->ListDevices()) {

View File

@ -222,7 +222,7 @@ class CollRMADistTest : public ::testing::Test {
device_type,
strings::StrCat(worker_name, "/device:", device_type, ":", i)));
}
DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
device_mgrs_.push_back(dev_mgr);
std::vector<string>* dv = &dev_by_task_[worker_name];
dv->clear();

View File

@ -163,7 +163,7 @@ class DeviceResDistTest : public ::testing::Test {
strings::StrCat(worker_name, "/device:", device_type, ":", i), i,
device_incarnation_base + i));
}
DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
TestableDeviceResolverDistributed* dev_res =
new TestableDeviceResolverDistributed(dev_mgr, &wc_, worker_name);
resolvers_[worker_name] = dev_res;

View File

@ -85,7 +85,7 @@ class EagerServiceImplTest : public ::testing::Test {
worker_env_.rendezvous_mgr = &rendezvous_mgr_;
worker_env_.session_mgr = session_mgr_.get();
device_mgr_ = absl::make_unique<DeviceMgr>(DeviceFactory::NewDevice(
device_mgr_ = absl::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
"CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
worker_env_.local_devices = device_mgr_->ListDevices();
worker_env_.device_mgr = device_mgr_.get();

View File

@ -48,7 +48,7 @@ class RemoteMgrTest : public ::testing::Test {
devices.push_back(
DeviceFactory::NewDevice("CPU", {}, "/job:worker/replica:0/task:0"));
remote_device_ = devices.back().get();
auto device_mgr = absl::make_unique<DeviceMgr>(std::move(devices));
auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
context_id_ = random::New64();
tensorflow::Rendezvous* rendezvous =
new tensorflow::IntraProcessRendezvous(device_mgr.get());

View File

@ -163,7 +163,7 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
std::vector<std::unique_ptr<Device>> devices;
TF_RETURN_IF_ERROR(
DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
worker_env_.device_mgr = new DeviceMgr(std::move(devices));
worker_env_.device_mgr = new StaticDeviceMgr(std::move(devices));
master_env_.local_devices = worker_env_.device_mgr->ListDevices();
worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr

View File

@ -44,7 +44,7 @@ class RpcCollectiveExecutorMgrTest : public ::testing::Test {
device_count->insert({"CPU", NUM_DEVS});
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices)));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
std::unique_ptr<DeviceResolverDistributed> dr(new DeviceResolverDistributed(
device_mgr_.get(), worker_cache, task_name));
std::unique_ptr<CollectiveParamResolverDistributed> cpr(

View File

@ -101,7 +101,7 @@ Status SessionMgr::CreateSession(
renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
worker_name, d, false, isolate_session_state));
}
auto device_mgr = MakeUnique<DeviceMgr>(std::move(renamed_devices));
auto device_mgr = MakeUnique<StaticDeviceMgr>(std::move(renamed_devices));
LookupLocalDevice cb = [&device_mgr](StringPiece name, Device** device) {
return device_mgr->LookupDevice(name, device);
};
@ -109,7 +109,7 @@ Status SessionMgr::CreateSession(
&cluster_devices);
std::unique_ptr<DeviceMgr> remote_devices;
if (!cluster_device_attributes.empty())
remote_devices = MakeUnique<DeviceMgr>(std::move(cluster_devices));
remote_devices = MakeUnique<StaticDeviceMgr>(std::move(cluster_devices));
auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
worker_session.reset(
@ -122,7 +122,7 @@ Status SessionMgr::CreateSession(
&cluster_devices);
std::unique_ptr<DeviceMgr> remote_devices;
if (!cluster_device_attributes.empty())
remote_devices = MakeUnique<DeviceMgr>(std::move(cluster_devices));
remote_devices = MakeUnique<StaticDeviceMgr>(std::move(cluster_devices));
// Borrow the WorkerEnv's DeviceMgr for the WorkerSession, so
// that resources using it can use its devices after the
// WorkerSession has been deleted.

View File

@ -46,7 +46,7 @@ class SessionMgrTest : public ::testing::Test {
SessionMgrTest()
: mgr_(&env_, "/job:mnist/replica:0/task:0",
std::unique_ptr<WorkerCacheInterface>(), factory_) {
device_mgr_ = absl::make_unique<DeviceMgr>(
device_mgr_ = absl::make_unique<StaticDeviceMgr>(
FakeDevice::MakeCPU("/job:mnist/replica:0/task:0/device:fakecpu:0"));
env_.local_devices = device_mgr_->ListDevices();
env_.device_mgr = device_mgr_.get();

View File

@ -239,7 +239,7 @@ Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
options, "/job:localhost/replica:0/task:0", &devices));
Device* cpu_device = devices[0].get();
std::unique_ptr<DeviceMgr> dvc_mgr(new DeviceMgr(std::move(devices)));
auto dvc_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
FunctionLibraryDefinition function_library(OpRegistry::Global(),
graph_def.library());
Env* env = Env::Default();

View File

@ -117,7 +117,8 @@ class NcclTestBase : public ::testing::Test {
device_names.push_back(device->name());
VLOG(2) << device->name();
}
if (!dev_mgr_) dev_mgr_.reset(new DeviceMgr(std::move(local_devices)));
if (!dev_mgr_)
dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
col_exec_ = new BaseCollectiveExecutor(
&col_exec_mgr_, /*remote_access=*/nullptr, kStepId, dev_mgr_.get(),
/*gpu_ring_order=*/nullptr);

View File

@ -415,7 +415,7 @@ Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
std::vector<std::unique_ptr<Device>> devices;
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices));
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
resource_mgr_ = absl::make_unique<ResourceMgr>("default_container");
FunctionDefLibrary proto;

View File

@ -355,9 +355,10 @@ FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
// in its resource manager. The existing device will outlive the
// IteratorResource, because we are storing the IteratorResource
// in that device's resource manager.
*device_mgr = absl::make_unique<DeviceMgr>(RenamedDevice::NewRenamedDevice(
ctx->device()->name(), down_cast<Device*>(ctx->device()),
false /* owns_underlying */, false /* isolate_session_state */));
*device_mgr =
absl::make_unique<StaticDeviceMgr>(RenamedDevice::NewRenamedDevice(
ctx->device()->name(), down_cast<Device*>(ctx->device()),
false /* owns_underlying */, false /* isolate_session_state */));
*flib_def = absl::make_unique<FunctionLibraryDefinition>(
*ctx->function_library()->GetFunctionLibraryDefinition());
*pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(

View File

@ -27,7 +27,7 @@ void OpsTestBase::SetDevice(const DeviceType& device_type,
CHECK(device_) << "No device provided";
device_ = device.get();
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(device));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device));
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, flib_def_.get(),
OptimizerOptions());

View File

@ -78,7 +78,7 @@ class OpsTestBase : public ::testing::Test {
CHECK(device) << "Could not create CPU device";
device_ = device.get();
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(device));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device));
allocator_ = device_->GetAllocator(AllocatorAttributes());

View File

@ -39,7 +39,7 @@ tensorflow::Status DelegateData::Prepare(
session_options, "/job:localhost/replica:0/task:0", &devices));
auto device_mgr =
absl::make_unique<tensorflow::DeviceMgr>(std::move(devices));
absl::make_unique<tensorflow::StaticDeviceMgr>(std::move(devices));
// Note that Rendezvous is ref-counted so it will be automatically deleted.
tensorflow::Rendezvous* rendezvous =
new tensorflow::IntraProcessRendezvous(device_mgr.get());

View File

@ -2234,7 +2234,7 @@ bool InlineAllFunctions(GraphDef* graphdef) {
tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(),
graphdef_copy.library());
tensorflow::DeviceMgr device_mgr(std::move(devices));
tensorflow::StaticDeviceMgr device_mgr(std::move(devices));
tensorflow::OptimizerOptions o_opts;
tensorflow::ProcessFunctionLibraryRuntime pflr(
&device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld,