Make DeviceMgr a pure virtual interface with StaticDeviceMgr as its only implementation.
This is a precursor to haoyuzhang@'s work on supporting dynamic device managers. PiperOrigin-RevId: 266514057
This commit is contained in:
parent
f624f59710
commit
73e53d6343
tensorflow
c/eager
compiler
jit
encapsulate_subgraphs_pass.ccencapsulate_subgraphs_pass_test.ccextract_outside_compilation_pass_test.ccxla_kernel_creator_test.cc
tests
tf2xla
core
common_runtime
buf_rendezvous_test.cccollective_executor_mgr_test.cccollective_param_resolver_local_test.cccollective_rma_local_test.cc
data
device_mgr.ccdevice_mgr.hdevice_resolver_local_test.ccdirect_session.cceager
function_test.ccfunction_threadpool_test.cchierarchical_tree_broadcaster_test.ccpartitioning_utils_test.ccprocess_function_library_runtime_test.ccring_gatherer_test.ccring_reducer_test.ccdistributed_runtime
collective_param_resolver_distributed_test.cccollective_rma_distributed_test.ccdevice_resolver_distributed_test.cc
eager
rpc
rpc_collective_executor_mgr_test.ccsession_mgr.ccsession_mgr_test.ccgrappler
kernels
lite
@ -122,7 +122,7 @@ tensorflow::Status GetAllRemoteDevices(
|
||||
n.WaitForNotification();
|
||||
}
|
||||
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
|
||||
new tensorflow::DeviceMgr(std::move(remote_devices)));
|
||||
new tensorflow::StaticDeviceMgr(std::move(remote_devices)));
|
||||
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
|
||||
@ -385,7 +385,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
&devices);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
|
||||
new tensorflow::DeviceMgr(std::move(devices)));
|
||||
new tensorflow::StaticDeviceMgr(std::move(devices)));
|
||||
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||
|
@ -1193,7 +1193,7 @@ Status EncapsulateSubgraphsPass::Run(
|
||||
}
|
||||
|
||||
std::unique_ptr<DeviceMgr> device_mgr =
|
||||
absl::make_unique<DeviceMgr>(std::move(devices));
|
||||
absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
OptimizerOptions opts;
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
|
||||
new ProcessFunctionLibraryRuntime(device_mgr.get(),
|
||||
|
@ -510,7 +510,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
|
||||
TF_CHECK_OK(DeviceFactory::AddDevices(
|
||||
session_options, "/job:localhost/replica:0/task:0", &devices));
|
||||
OptimizerOptions opts;
|
||||
auto device_mgr = absl::make_unique<DeviceMgr>(std::move(devices));
|
||||
auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def.get(),
|
||||
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
|
||||
|
@ -232,7 +232,7 @@ class ExtractOutsideCompilationForFunctionTest : public ::testing::Test {
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
TF_CHECK_OK(DeviceFactory::AddDevices(
|
||||
session_options, "/job:localhost/replica:0/task:0", &devices));
|
||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
}
|
||||
|
||||
Status ExtractOutsideCompilationTest(
|
||||
|
@ -71,7 +71,7 @@ class XlaKernelCreatorTest : public ::testing::Test {
|
||||
lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
|
||||
OpRegistry::Global(), proto);
|
||||
OptimizerOptions opts;
|
||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
||||
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
|
||||
|
@ -3456,7 +3456,7 @@ int main(int argc, char** argv) {
|
||||
std::vector<std::unique_ptr<tensorflow::Device>> devices;
|
||||
TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
|
||||
tensorflow::SessionOptions(), "", &devices));
|
||||
tensorflow::DeviceMgr device_mgr(std::move(devices));
|
||||
tensorflow::StaticDeviceMgr device_mgr(std::move(devices));
|
||||
|
||||
tensorflow::Device* ignored;
|
||||
TF_QCHECK_OK(
|
||||
|
@ -474,7 +474,7 @@ class XlaCompiler {
|
||||
int64 next_step_id_;
|
||||
|
||||
XlaCompilationDevice* device_; // Owned by device_mgr_
|
||||
DeviceMgr device_mgr_;
|
||||
StaticDeviceMgr device_mgr_;
|
||||
|
||||
// To avoid copying the client's function library, use a local function
|
||||
// library and runtime for functions created as part of the functionalize
|
||||
|
@ -47,7 +47,7 @@ class BufRendezvousTest : public ::testing::Test {
|
||||
const uint64 incarnation) {
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
devices.push_back(NewDevice(device, type, incarnation));
|
||||
dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
|
||||
dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
br_ = absl::make_unique<BufRendezvous>(123, dev_mgr_.get());
|
||||
}
|
||||
|
||||
|
@ -40,7 +40,7 @@ class CollectiveExecutorMgrTest : public ::testing::Test {
|
||||
device_count->insert({"CPU", NUM_DEVS});
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
|
||||
device_mgr_.reset(new DeviceMgr(std::move(devices)));
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
std::unique_ptr<DeviceResolverInterface> drl(
|
||||
new DeviceResolverLocal(device_mgr_.get()));
|
||||
std::unique_ptr<ParamResolverInterface> prl(
|
||||
|
@ -39,7 +39,7 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
|
||||
device_count->insert({"CPU", NUM_DEVS});
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
|
||||
device_mgr_.reset(new DeviceMgr(std::move(devices)));
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
|
||||
prl_.reset(new CollectiveParamResolverLocal(cp, device_mgr_.get(),
|
||||
drl_.get(), task_name));
|
||||
|
@ -46,7 +46,7 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test {
|
||||
device_count->insert({"CPU", NUM_DEVS});
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices));
|
||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
drl_ = absl::make_unique<DeviceResolverLocal>(device_mgr_.get());
|
||||
prl_ = absl::make_unique<CollectiveParamResolverLocal>(
|
||||
cp, device_mgr_.get(), drl_.get(), kTaskName);
|
||||
|
@ -47,7 +47,7 @@ Status Dataset::FromGraph(Params params, const GraphDef& graph_def,
|
||||
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
|
||||
|
||||
// Instantiate enough of the TF runtime to run `graph` on a single CPU device.
|
||||
auto device_mgr = absl::make_unique<DeviceMgr>(DeviceFactory::NewDevice(
|
||||
auto device_mgr = absl::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
|
||||
"CPU", params.session_options, "/job:localhost/replica:0/task:0"));
|
||||
Device* device = device_mgr->ListDevices()[0];
|
||||
// Clone the `FunctionLibraryDefinition` so that its lifetime extends beyond
|
||||
|
@ -25,7 +25,9 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DeviceMgr::DeviceMgr(std::vector<std::unique_ptr<Device>> devices)
|
||||
DeviceMgr::~DeviceMgr() {}
|
||||
|
||||
StaticDeviceMgr::StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices)
|
||||
: devices_(std::move(devices)), name_backing_store_(128) {
|
||||
for (auto& d : devices_) {
|
||||
// Register under the (1) full name and (2) canonical name.
|
||||
@ -42,14 +44,14 @@ DeviceMgr::DeviceMgr(std::vector<std::unique_ptr<Device>> devices)
|
||||
}
|
||||
}
|
||||
|
||||
DeviceMgr::DeviceMgr(std::unique_ptr<Device> device)
|
||||
: DeviceMgr([&device] {
|
||||
StaticDeviceMgr::StaticDeviceMgr(std::unique_ptr<Device> device)
|
||||
: StaticDeviceMgr([&device] {
|
||||
std::vector<std::unique_ptr<Device>> vector;
|
||||
vector.push_back(std::move(device));
|
||||
return vector;
|
||||
}()) {}
|
||||
|
||||
DeviceMgr::~DeviceMgr() {
|
||||
StaticDeviceMgr::~StaticDeviceMgr() {
|
||||
// Release resources ahead of destroying the device manager as the resource
|
||||
// destructors (e.g. ~IteratorResource) assume devices still exist.
|
||||
for (auto& device : devices_) {
|
||||
@ -57,14 +59,14 @@ DeviceMgr::~DeviceMgr() {
|
||||
}
|
||||
}
|
||||
|
||||
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
|
||||
StringPiece StaticDeviceMgr::CopyToBackingStore(StringPiece s) {
|
||||
size_t n = s.size();
|
||||
char* space = name_backing_store_.Alloc(n);
|
||||
memcpy(space, s.data(), n);
|
||||
return StringPiece(space, n);
|
||||
}
|
||||
|
||||
void DeviceMgr::ListDeviceAttributes(
|
||||
void StaticDeviceMgr::ListDeviceAttributes(
|
||||
std::vector<DeviceAttributes>* devices) const {
|
||||
devices->reserve(devices_.size());
|
||||
for (const auto& dev : devices_) {
|
||||
@ -72,7 +74,7 @@ void DeviceMgr::ListDeviceAttributes(
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Device*> DeviceMgr::ListDevices() const {
|
||||
std::vector<Device*> StaticDeviceMgr::ListDevices() const {
|
||||
std::vector<Device*> devices(devices_.size());
|
||||
for (size_t i = 0; i < devices_.size(); ++i) {
|
||||
devices[i] = devices_[i].get();
|
||||
@ -80,7 +82,7 @@ std::vector<Device*> DeviceMgr::ListDevices() const {
|
||||
return devices;
|
||||
}
|
||||
|
||||
string DeviceMgr::DebugString() const {
|
||||
string StaticDeviceMgr::DebugString() const {
|
||||
string out;
|
||||
for (const auto& dev : devices_) {
|
||||
strings::StrAppend(&out, dev->name(), "\n");
|
||||
@ -88,7 +90,7 @@ string DeviceMgr::DebugString() const {
|
||||
return out;
|
||||
}
|
||||
|
||||
string DeviceMgr::DeviceMappingString() const {
|
||||
string StaticDeviceMgr::DeviceMappingString() const {
|
||||
string out;
|
||||
for (const auto& dev : devices_) {
|
||||
if (!dev->attributes().physical_device_desc().empty()) {
|
||||
@ -99,7 +101,7 @@ string DeviceMgr::DeviceMappingString() const {
|
||||
return out;
|
||||
}
|
||||
|
||||
Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
|
||||
Status StaticDeviceMgr::LookupDevice(StringPiece name, Device** device) const {
|
||||
auto iter = device_map_.find(name);
|
||||
if (iter == device_map_.end()) {
|
||||
std::vector<StringPiece> device_names;
|
||||
@ -114,7 +116,8 @@ Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void DeviceMgr::ClearContainers(gtl::ArraySlice<string> containers) const {
|
||||
void StaticDeviceMgr::ClearContainers(
|
||||
gtl::ArraySlice<string> containers) const {
|
||||
Status s;
|
||||
for (const auto& dev : devices_) {
|
||||
if (containers.empty()) {
|
||||
@ -131,7 +134,7 @@ void DeviceMgr::ClearContainers(gtl::ArraySlice<string> containers) const {
|
||||
}
|
||||
}
|
||||
|
||||
int DeviceMgr::NumDeviceType(const string& type) const {
|
||||
int StaticDeviceMgr::NumDeviceType(const string& type) const {
|
||||
auto iter = device_type_counts_.find(type);
|
||||
if (iter != device_type_counts_.end()) return iter->second;
|
||||
return 0;
|
||||
|
@ -33,38 +33,57 @@ namespace tensorflow {
|
||||
|
||||
class DeviceAttributes;
|
||||
|
||||
// Represents a set of devices.
|
||||
class DeviceMgr {
|
||||
public:
|
||||
// Constructs a DeviceMgr from a list of devices.
|
||||
// TODO(zhifengc): Other initialization information.
|
||||
explicit DeviceMgr(std::vector<std::unique_ptr<Device>> devices);
|
||||
|
||||
// Constructs a DeviceMgr managing a single device.
|
||||
explicit DeviceMgr(std::unique_ptr<Device> device);
|
||||
|
||||
~DeviceMgr();
|
||||
DeviceMgr() = default;
|
||||
virtual ~DeviceMgr();
|
||||
|
||||
// Returns attributes of all devices.
|
||||
void ListDeviceAttributes(std::vector<DeviceAttributes>* devices) const;
|
||||
virtual void ListDeviceAttributes(
|
||||
std::vector<DeviceAttributes>* devices) const = 0;
|
||||
|
||||
// Returns raw pointers to the underlying devices.
|
||||
std::vector<Device*> ListDevices() const;
|
||||
virtual std::vector<Device*> ListDevices() const = 0;
|
||||
|
||||
// Returns a string listing all devices.
|
||||
string DebugString() const;
|
||||
virtual string DebugString() const = 0;
|
||||
|
||||
// Returns a string describing all of the device mappings.
|
||||
string DeviceMappingString() const;
|
||||
virtual string DeviceMappingString() const = 0;
|
||||
|
||||
// Assigns *device with pointer to Device of the given name.
|
||||
// Accepts either a full device name, or just the replica-local suffix.
|
||||
Status LookupDevice(StringPiece name, Device** device) const;
|
||||
virtual Status LookupDevice(StringPiece name, Device** device) const = 0;
|
||||
|
||||
// Clears given containers of all devices if 'containers' is
|
||||
// non-empty. Otherwise, clears default containers of all devices.
|
||||
void ClearContainers(gtl::ArraySlice<string> containers) const;
|
||||
virtual void ClearContainers(gtl::ArraySlice<string> containers) const = 0;
|
||||
|
||||
int NumDeviceType(const string& type) const;
|
||||
virtual int NumDeviceType(const string& type) const = 0;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr);
|
||||
};
|
||||
|
||||
// Represents a static set of devices.
|
||||
class StaticDeviceMgr : public DeviceMgr {
|
||||
public:
|
||||
// Constructs a StaticDeviceMgr from a list of devices.
|
||||
explicit StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices);
|
||||
|
||||
// Constructs a StaticDeviceMgr managing a single device.
|
||||
explicit StaticDeviceMgr(std::unique_ptr<Device> device);
|
||||
|
||||
~StaticDeviceMgr() override;
|
||||
|
||||
void ListDeviceAttributes(
|
||||
std::vector<DeviceAttributes>* devices) const override;
|
||||
std::vector<Device*> ListDevices() const override;
|
||||
string DebugString() const override;
|
||||
string DeviceMappingString() const override;
|
||||
Status LookupDevice(StringPiece name, Device** device) const override;
|
||||
void ClearContainers(gtl::ArraySlice<string> containers) const override;
|
||||
int NumDeviceType(const string& type) const override;
|
||||
|
||||
private:
|
||||
const std::vector<std::unique_ptr<Device>> devices_;
|
||||
@ -75,9 +94,8 @@ class DeviceMgr {
|
||||
core::Arena name_backing_store_; // Storage for keys in device_map_
|
||||
std::unordered_map<string, int> device_type_counts_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr);
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StaticDeviceMgr);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
|
||||
|
@ -37,7 +37,7 @@ class DeviceResolverLocalTest : public ::testing::Test {
|
||||
device_count->insert({"CPU", NUM_DEVS});
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
|
||||
device_mgr_.reset(new DeviceMgr(std::move(devices)));
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
|
||||
}
|
||||
|
||||
|
@ -188,8 +188,8 @@ class DirectSessionFactory : public SessionFactory {
|
||||
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
|
||||
options, "/job:localhost/replica:0/task:0", &devices));
|
||||
|
||||
DirectSession* session =
|
||||
new DirectSession(options, new DeviceMgr(std::move(devices)), this);
|
||||
DirectSession* session = new DirectSession(
|
||||
options, new StaticDeviceMgr(std::move(devices)), this);
|
||||
{
|
||||
mutex_lock l(sessions_lock_);
|
||||
sessions_.push_back(session);
|
||||
|
@ -45,7 +45,7 @@ class TestEnv {
|
||||
devices.push_back(
|
||||
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
|
||||
cpu_device_ = devices.back().get();
|
||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
OptimizerOptions opts;
|
||||
pflr_ = tensorflow::MakeUnique<ProcessFunctionLibraryRuntime>(
|
||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, &flib_def_,
|
||||
|
@ -162,7 +162,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
|
||||
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
|
||||
OptimizerOptions opts;
|
||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
pflr_.reset(new ProcessFunctionLibraryRuntime(
|
||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
||||
opts));
|
||||
|
@ -62,7 +62,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
|
||||
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
|
||||
OptimizerOptions opts;
|
||||
device_mgr_.reset(new DeviceMgr(std::move(devices)));
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
pflr_.reset(new ProcessFunctionLibraryRuntime(
|
||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
||||
opts, default_thread_pool));
|
||||
|
@ -247,7 +247,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
|
||||
}
|
||||
}
|
||||
if (!dev_mgr_ || device_type == DEVICE_CPU) {
|
||||
dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(local_devices));
|
||||
dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
|
||||
}
|
||||
if (!gpu_ring_order_) {
|
||||
gpu_ring_order_ = absl::make_unique<string>();
|
||||
|
@ -47,7 +47,7 @@ class PartitioningUtilsTest : public ::testing::Test {
|
||||
&devices));
|
||||
device0_ = devices[0].get();
|
||||
device1_ = devices[1].get();
|
||||
device_mgr_.reset(new DeviceMgr(std::move(devices)));
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
|
||||
for (auto d : device_mgr_->ListDevices()) {
|
||||
device_set_.AddDevice(d);
|
||||
|
@ -94,7 +94,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
TF_CHECK_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0",
|
||||
&devices));
|
||||
device_mgr_.reset(new DeviceMgr(std::move(devices)));
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
TF_CHECK_OK(device_mgr_->LookupDevice(
|
||||
"/job:a/replica:0/task:0/device:CPU:0", &device0_));
|
||||
TF_CHECK_OK(device_mgr_->LookupDevice(
|
||||
|
@ -167,7 +167,7 @@ class RingGathererTest : public ::testing::Test {
|
||||
if (!dev_mgr_ || device_type == DEVICE_CPU) {
|
||||
LOG(ERROR) << "resetting dev_mgr for " << local_devices.size()
|
||||
<< " devices: ";
|
||||
dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(local_devices));
|
||||
dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
|
||||
}
|
||||
if (!gpu_ring_order_) {
|
||||
gpu_ring_order_ = absl::make_unique<string>();
|
||||
|
@ -189,7 +189,7 @@ class RingReducerTest : public ::testing::Test {
|
||||
if (!dev_mgr_ || device_type == DEVICE_CPU) {
|
||||
LOG(INFO) << "resetting dev_mgr for " << local_devices.size()
|
||||
<< " devices: ";
|
||||
dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(local_devices));
|
||||
dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
|
||||
}
|
||||
if (!gpu_ring_order_) {
|
||||
gpu_ring_order_ = absl::make_unique<string>();
|
||||
|
@ -165,7 +165,7 @@ class DeviceResDistTest : public ::testing::Test {
|
||||
device_type,
|
||||
strings::StrCat(worker_name, "/device:", device_type, ":", i)));
|
||||
}
|
||||
DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
|
||||
DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
|
||||
device_mgrs_.push_back(dev_mgr);
|
||||
std::vector<string>* dv = &dev_by_task_[worker_name];
|
||||
for (auto* d : dev_mgr->ListDevices()) {
|
||||
|
@ -222,7 +222,7 @@ class CollRMADistTest : public ::testing::Test {
|
||||
device_type,
|
||||
strings::StrCat(worker_name, "/device:", device_type, ":", i)));
|
||||
}
|
||||
DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
|
||||
DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
|
||||
device_mgrs_.push_back(dev_mgr);
|
||||
std::vector<string>* dv = &dev_by_task_[worker_name];
|
||||
dv->clear();
|
||||
|
@ -163,7 +163,7 @@ class DeviceResDistTest : public ::testing::Test {
|
||||
strings::StrCat(worker_name, "/device:", device_type, ":", i), i,
|
||||
device_incarnation_base + i));
|
||||
}
|
||||
DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
|
||||
DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
|
||||
TestableDeviceResolverDistributed* dev_res =
|
||||
new TestableDeviceResolverDistributed(dev_mgr, &wc_, worker_name);
|
||||
resolvers_[worker_name] = dev_res;
|
||||
|
@ -85,7 +85,7 @@ class EagerServiceImplTest : public ::testing::Test {
|
||||
worker_env_.rendezvous_mgr = &rendezvous_mgr_;
|
||||
worker_env_.session_mgr = session_mgr_.get();
|
||||
|
||||
device_mgr_ = absl::make_unique<DeviceMgr>(DeviceFactory::NewDevice(
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
|
||||
"CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
|
||||
worker_env_.local_devices = device_mgr_->ListDevices();
|
||||
worker_env_.device_mgr = device_mgr_.get();
|
||||
|
@ -48,7 +48,7 @@ class RemoteMgrTest : public ::testing::Test {
|
||||
devices.push_back(
|
||||
DeviceFactory::NewDevice("CPU", {}, "/job:worker/replica:0/task:0"));
|
||||
remote_device_ = devices.back().get();
|
||||
auto device_mgr = absl::make_unique<DeviceMgr>(std::move(devices));
|
||||
auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
context_id_ = random::New64();
|
||||
tensorflow::Rendezvous* rendezvous =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||
|
@ -163,7 +163,7 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
TF_RETURN_IF_ERROR(
|
||||
DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
|
||||
worker_env_.device_mgr = new DeviceMgr(std::move(devices));
|
||||
worker_env_.device_mgr = new StaticDeviceMgr(std::move(devices));
|
||||
master_env_.local_devices = worker_env_.device_mgr->ListDevices();
|
||||
worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
|
||||
worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr
|
||||
|
@ -44,7 +44,7 @@ class RpcCollectiveExecutorMgrTest : public ::testing::Test {
|
||||
device_count->insert({"CPU", NUM_DEVS});
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
|
||||
device_mgr_.reset(new DeviceMgr(std::move(devices)));
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
std::unique_ptr<DeviceResolverDistributed> dr(new DeviceResolverDistributed(
|
||||
device_mgr_.get(), worker_cache, task_name));
|
||||
std::unique_ptr<CollectiveParamResolverDistributed> cpr(
|
||||
|
@ -101,7 +101,7 @@ Status SessionMgr::CreateSession(
|
||||
renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
|
||||
worker_name, d, false, isolate_session_state));
|
||||
}
|
||||
auto device_mgr = MakeUnique<DeviceMgr>(std::move(renamed_devices));
|
||||
auto device_mgr = MakeUnique<StaticDeviceMgr>(std::move(renamed_devices));
|
||||
LookupLocalDevice cb = [&device_mgr](StringPiece name, Device** device) {
|
||||
return device_mgr->LookupDevice(name, device);
|
||||
};
|
||||
@ -109,7 +109,7 @@ Status SessionMgr::CreateSession(
|
||||
&cluster_devices);
|
||||
std::unique_ptr<DeviceMgr> remote_devices;
|
||||
if (!cluster_device_attributes.empty())
|
||||
remote_devices = MakeUnique<DeviceMgr>(std::move(cluster_devices));
|
||||
remote_devices = MakeUnique<StaticDeviceMgr>(std::move(cluster_devices));
|
||||
|
||||
auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
|
||||
worker_session.reset(
|
||||
@ -122,7 +122,7 @@ Status SessionMgr::CreateSession(
|
||||
&cluster_devices);
|
||||
std::unique_ptr<DeviceMgr> remote_devices;
|
||||
if (!cluster_device_attributes.empty())
|
||||
remote_devices = MakeUnique<DeviceMgr>(std::move(cluster_devices));
|
||||
remote_devices = MakeUnique<StaticDeviceMgr>(std::move(cluster_devices));
|
||||
// Borrow the WorkerEnv's DeviceMgr for the WorkerSession, so
|
||||
// that resources using it can use its devices after the
|
||||
// WorkerSession has been deleted.
|
||||
|
@ -46,7 +46,7 @@ class SessionMgrTest : public ::testing::Test {
|
||||
SessionMgrTest()
|
||||
: mgr_(&env_, "/job:mnist/replica:0/task:0",
|
||||
std::unique_ptr<WorkerCacheInterface>(), factory_) {
|
||||
device_mgr_ = absl::make_unique<DeviceMgr>(
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(
|
||||
FakeDevice::MakeCPU("/job:mnist/replica:0/task:0/device:fakecpu:0"));
|
||||
env_.local_devices = device_mgr_->ListDevices();
|
||||
env_.device_mgr = device_mgr_.get();
|
||||
|
@ -239,7 +239,7 @@ Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
|
||||
TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
|
||||
options, "/job:localhost/replica:0/task:0", &devices));
|
||||
Device* cpu_device = devices[0].get();
|
||||
std::unique_ptr<DeviceMgr> dvc_mgr(new DeviceMgr(std::move(devices)));
|
||||
auto dvc_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
FunctionLibraryDefinition function_library(OpRegistry::Global(),
|
||||
graph_def.library());
|
||||
Env* env = Env::Default();
|
||||
|
@ -117,7 +117,8 @@ class NcclTestBase : public ::testing::Test {
|
||||
device_names.push_back(device->name());
|
||||
VLOG(2) << device->name();
|
||||
}
|
||||
if (!dev_mgr_) dev_mgr_.reset(new DeviceMgr(std::move(local_devices)));
|
||||
if (!dev_mgr_)
|
||||
dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
|
||||
col_exec_ = new BaseCollectiveExecutor(
|
||||
&col_exec_mgr_, /*remote_access=*/nullptr, kStepId, dev_mgr_.get(),
|
||||
/*gpu_ring_order=*/nullptr);
|
||||
|
@ -415,7 +415,7 @@ Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
|
||||
options, "/job:localhost/replica:0/task:0", &devices));
|
||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
resource_mgr_ = absl::make_unique<ResourceMgr>("default_container");
|
||||
|
||||
FunctionDefLibrary proto;
|
||||
|
@ -355,9 +355,10 @@ FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
|
||||
// in its resource manager. The existing device will outlive the
|
||||
// IteratorResource, because we are storing the IteratorResource
|
||||
// in that device's resource manager.
|
||||
*device_mgr = absl::make_unique<DeviceMgr>(RenamedDevice::NewRenamedDevice(
|
||||
ctx->device()->name(), down_cast<Device*>(ctx->device()),
|
||||
false /* owns_underlying */, false /* isolate_session_state */));
|
||||
*device_mgr =
|
||||
absl::make_unique<StaticDeviceMgr>(RenamedDevice::NewRenamedDevice(
|
||||
ctx->device()->name(), down_cast<Device*>(ctx->device()),
|
||||
false /* owns_underlying */, false /* isolate_session_state */));
|
||||
*flib_def = absl::make_unique<FunctionLibraryDefinition>(
|
||||
*ctx->function_library()->GetFunctionLibraryDefinition());
|
||||
*pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
|
@ -27,7 +27,7 @@ void OpsTestBase::SetDevice(const DeviceType& device_type,
|
||||
CHECK(device_) << "No device provided";
|
||||
|
||||
device_ = device.get();
|
||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(device));
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device));
|
||||
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, flib_def_.get(),
|
||||
OptimizerOptions());
|
||||
|
@ -78,7 +78,7 @@ class OpsTestBase : public ::testing::Test {
|
||||
CHECK(device) << "Could not create CPU device";
|
||||
|
||||
device_ = device.get();
|
||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(device));
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device));
|
||||
|
||||
allocator_ = device_->GetAllocator(AllocatorAttributes());
|
||||
|
||||
|
@ -39,7 +39,7 @@ tensorflow::Status DelegateData::Prepare(
|
||||
session_options, "/job:localhost/replica:0/task:0", &devices));
|
||||
|
||||
auto device_mgr =
|
||||
absl::make_unique<tensorflow::DeviceMgr>(std::move(devices));
|
||||
absl::make_unique<tensorflow::StaticDeviceMgr>(std::move(devices));
|
||||
// Note that Rendezvous is ref-counted so it will be automatically deleted.
|
||||
tensorflow::Rendezvous* rendezvous =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||
|
@ -2234,7 +2234,7 @@ bool InlineAllFunctions(GraphDef* graphdef) {
|
||||
|
||||
tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(),
|
||||
graphdef_copy.library());
|
||||
tensorflow::DeviceMgr device_mgr(std::move(devices));
|
||||
tensorflow::StaticDeviceMgr device_mgr(std::move(devices));
|
||||
tensorflow::OptimizerOptions o_opts;
|
||||
tensorflow::ProcessFunctionLibraryRuntime pflr(
|
||||
&device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld,
|
||||
|
Loading…
Reference in New Issue
Block a user