Make DeviceMgr a pure virtual interface with StaticDeviceMgr as its only implementation.
This is a precursor to haoyuzhang@'s work on supporting dynamic device managers. PiperOrigin-RevId: 266514057
This commit is contained in:
parent
f624f59710
commit
73e53d6343
@ -122,7 +122,7 @@ tensorflow::Status GetAllRemoteDevices(
|
|||||||
n.WaitForNotification();
|
n.WaitForNotification();
|
||||||
}
|
}
|
||||||
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
|
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
|
||||||
new tensorflow::DeviceMgr(std::move(remote_devices)));
|
new tensorflow::StaticDeviceMgr(std::move(remote_devices)));
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(status);
|
TF_RETURN_IF_ERROR(status);
|
||||||
|
|
||||||
@ -385,7 +385,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
|||||||
&devices);
|
&devices);
|
||||||
if (!status->status.ok()) return nullptr;
|
if (!status->status.ok()) return nullptr;
|
||||||
std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
|
std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
|
||||||
new tensorflow::DeviceMgr(std::move(devices)));
|
new tensorflow::StaticDeviceMgr(std::move(devices)));
|
||||||
|
|
||||||
tensorflow::Rendezvous* r =
|
tensorflow::Rendezvous* r =
|
||||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||||
|
@ -1193,7 +1193,7 @@ Status EncapsulateSubgraphsPass::Run(
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<DeviceMgr> device_mgr =
|
std::unique_ptr<DeviceMgr> device_mgr =
|
||||||
absl::make_unique<DeviceMgr>(std::move(devices));
|
absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||||
OptimizerOptions opts;
|
OptimizerOptions opts;
|
||||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
|
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
|
||||||
new ProcessFunctionLibraryRuntime(device_mgr.get(),
|
new ProcessFunctionLibraryRuntime(device_mgr.get(),
|
||||||
|
@ -510,7 +510,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
|
|||||||
TF_CHECK_OK(DeviceFactory::AddDevices(
|
TF_CHECK_OK(DeviceFactory::AddDevices(
|
||||||
session_options, "/job:localhost/replica:0/task:0", &devices));
|
session_options, "/job:localhost/replica:0/task:0", &devices));
|
||||||
OptimizerOptions opts;
|
OptimizerOptions opts;
|
||||||
auto device_mgr = absl::make_unique<DeviceMgr>(std::move(devices));
|
auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||||
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||||
device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def.get(),
|
device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def.get(),
|
||||||
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
|
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
|
||||||
|
@ -232,7 +232,7 @@ class ExtractOutsideCompilationForFunctionTest : public ::testing::Test {
|
|||||||
std::vector<std::unique_ptr<Device>> devices;
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
TF_CHECK_OK(DeviceFactory::AddDevices(
|
TF_CHECK_OK(DeviceFactory::AddDevices(
|
||||||
session_options, "/job:localhost/replica:0/task:0", &devices));
|
session_options, "/job:localhost/replica:0/task:0", &devices));
|
||||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
|
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ExtractOutsideCompilationTest(
|
Status ExtractOutsideCompilationTest(
|
||||||
|
@ -71,7 +71,7 @@ class XlaKernelCreatorTest : public ::testing::Test {
|
|||||||
lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
|
lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
|
||||||
OpRegistry::Global(), proto);
|
OpRegistry::Global(), proto);
|
||||||
OptimizerOptions opts;
|
OptimizerOptions opts;
|
||||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
|
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||||
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
||||||
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
|
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
|
||||||
|
@ -3456,7 +3456,7 @@ int main(int argc, char** argv) {
|
|||||||
std::vector<std::unique_ptr<tensorflow::Device>> devices;
|
std::vector<std::unique_ptr<tensorflow::Device>> devices;
|
||||||
TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
|
TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
|
||||||
tensorflow::SessionOptions(), "", &devices));
|
tensorflow::SessionOptions(), "", &devices));
|
||||||
tensorflow::DeviceMgr device_mgr(std::move(devices));
|
tensorflow::StaticDeviceMgr device_mgr(std::move(devices));
|
||||||
|
|
||||||
tensorflow::Device* ignored;
|
tensorflow::Device* ignored;
|
||||||
TF_QCHECK_OK(
|
TF_QCHECK_OK(
|
||||||
|
@ -474,7 +474,7 @@ class XlaCompiler {
|
|||||||
int64 next_step_id_;
|
int64 next_step_id_;
|
||||||
|
|
||||||
XlaCompilationDevice* device_; // Owned by device_mgr_
|
XlaCompilationDevice* device_; // Owned by device_mgr_
|
||||||
DeviceMgr device_mgr_;
|
StaticDeviceMgr device_mgr_;
|
||||||
|
|
||||||
// To avoid copying the client's function library, use a local function
|
// To avoid copying the client's function library, use a local function
|
||||||
// library and runtime for functions created as part of the functionalize
|
// library and runtime for functions created as part of the functionalize
|
||||||
|
@ -47,7 +47,7 @@ class BufRendezvousTest : public ::testing::Test {
|
|||||||
const uint64 incarnation) {
|
const uint64 incarnation) {
|
||||||
std::vector<std::unique_ptr<Device>> devices;
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
devices.push_back(NewDevice(device, type, incarnation));
|
devices.push_back(NewDevice(device, type, incarnation));
|
||||||
dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
|
dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||||
br_ = absl::make_unique<BufRendezvous>(123, dev_mgr_.get());
|
br_ = absl::make_unique<BufRendezvous>(123, dev_mgr_.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ class CollectiveExecutorMgrTest : public ::testing::Test {
|
|||||||
device_count->insert({"CPU", NUM_DEVS});
|
device_count->insert({"CPU", NUM_DEVS});
|
||||||
std::vector<std::unique_ptr<Device>> devices;
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
|
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
|
||||||
device_mgr_.reset(new DeviceMgr(std::move(devices)));
|
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||||
std::unique_ptr<DeviceResolverInterface> drl(
|
std::unique_ptr<DeviceResolverInterface> drl(
|
||||||
new DeviceResolverLocal(device_mgr_.get()));
|
new DeviceResolverLocal(device_mgr_.get()));
|
||||||
std::unique_ptr<ParamResolverInterface> prl(
|
std::unique_ptr<ParamResolverInterface> prl(
|
||||||
|
@ -39,7 +39,7 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
|
|||||||
device_count->insert({"CPU", NUM_DEVS});
|
device_count->insert({"CPU", NUM_DEVS});
|
||||||
std::vector<std::unique_ptr<Device>> devices;
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
|
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
|
||||||
device_mgr_.reset(new DeviceMgr(std::move(devices)));
|
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||||
drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
|
drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
|
||||||
prl_.reset(new CollectiveParamResolverLocal(cp, device_mgr_.get(),
|
prl_.reset(new CollectiveParamResolverLocal(cp, device_mgr_.get(),
|
||||||
drl_.get(), task_name));
|
drl_.get(), task_name));
|
||||||
|
@ -46,7 +46,7 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test {
|
|||||||
device_count->insert({"CPU", NUM_DEVS});
|
device_count->insert({"CPU", NUM_DEVS});
|
||||||
std::vector<std::unique_ptr<Device>> devices;
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices));
|
TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices));
|
||||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
|
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||||
drl_ = absl::make_unique<DeviceResolverLocal>(device_mgr_.get());
|
drl_ = absl::make_unique<DeviceResolverLocal>(device_mgr_.get());
|
||||||
prl_ = absl::make_unique<CollectiveParamResolverLocal>(
|
prl_ = absl::make_unique<CollectiveParamResolverLocal>(
|
||||||
cp, device_mgr_.get(), drl_.get(), kTaskName);
|
cp, device_mgr_.get(), drl_.get(), kTaskName);
|
||||||
|
@ -47,7 +47,7 @@ Status Dataset::FromGraph(Params params, const GraphDef& graph_def,
|
|||||||
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
|
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
|
||||||
|
|
||||||
// Instantiate enough of the TF runtime to run `graph` on a single CPU device.
|
// Instantiate enough of the TF runtime to run `graph` on a single CPU device.
|
||||||
auto device_mgr = absl::make_unique<DeviceMgr>(DeviceFactory::NewDevice(
|
auto device_mgr = absl::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
|
||||||
"CPU", params.session_options, "/job:localhost/replica:0/task:0"));
|
"CPU", params.session_options, "/job:localhost/replica:0/task:0"));
|
||||||
Device* device = device_mgr->ListDevices()[0];
|
Device* device = device_mgr->ListDevices()[0];
|
||||||
// Clone the `FunctionLibraryDefinition` to extend its lifetime extends beyond
|
// Clone the `FunctionLibraryDefinition` to extend its lifetime extends beyond
|
||||||
|
@ -25,7 +25,9 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
DeviceMgr::DeviceMgr(std::vector<std::unique_ptr<Device>> devices)
|
DeviceMgr::~DeviceMgr() {}
|
||||||
|
|
||||||
|
StaticDeviceMgr::StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices)
|
||||||
: devices_(std::move(devices)), name_backing_store_(128) {
|
: devices_(std::move(devices)), name_backing_store_(128) {
|
||||||
for (auto& d : devices_) {
|
for (auto& d : devices_) {
|
||||||
// Register under the (1) full name and (2) canonical name.
|
// Register under the (1) full name and (2) canonical name.
|
||||||
@ -42,14 +44,14 @@ DeviceMgr::DeviceMgr(std::vector<std::unique_ptr<Device>> devices)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
DeviceMgr::DeviceMgr(std::unique_ptr<Device> device)
|
StaticDeviceMgr::StaticDeviceMgr(std::unique_ptr<Device> device)
|
||||||
: DeviceMgr([&device] {
|
: StaticDeviceMgr([&device] {
|
||||||
std::vector<std::unique_ptr<Device>> vector;
|
std::vector<std::unique_ptr<Device>> vector;
|
||||||
vector.push_back(std::move(device));
|
vector.push_back(std::move(device));
|
||||||
return vector;
|
return vector;
|
||||||
}()) {}
|
}()) {}
|
||||||
|
|
||||||
DeviceMgr::~DeviceMgr() {
|
StaticDeviceMgr::~StaticDeviceMgr() {
|
||||||
// Release resources ahead of destroying the device manager as the resource
|
// Release resources ahead of destroying the device manager as the resource
|
||||||
// destructors (e.g. ~IteratorResource) assume devices still exist.
|
// destructors (e.g. ~IteratorResource) assume devices still exist.
|
||||||
for (auto& device : devices_) {
|
for (auto& device : devices_) {
|
||||||
@ -57,14 +59,14 @@ DeviceMgr::~DeviceMgr() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
|
StringPiece StaticDeviceMgr::CopyToBackingStore(StringPiece s) {
|
||||||
size_t n = s.size();
|
size_t n = s.size();
|
||||||
char* space = name_backing_store_.Alloc(n);
|
char* space = name_backing_store_.Alloc(n);
|
||||||
memcpy(space, s.data(), n);
|
memcpy(space, s.data(), n);
|
||||||
return StringPiece(space, n);
|
return StringPiece(space, n);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DeviceMgr::ListDeviceAttributes(
|
void StaticDeviceMgr::ListDeviceAttributes(
|
||||||
std::vector<DeviceAttributes>* devices) const {
|
std::vector<DeviceAttributes>* devices) const {
|
||||||
devices->reserve(devices_.size());
|
devices->reserve(devices_.size());
|
||||||
for (const auto& dev : devices_) {
|
for (const auto& dev : devices_) {
|
||||||
@ -72,7 +74,7 @@ void DeviceMgr::ListDeviceAttributes(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Device*> DeviceMgr::ListDevices() const {
|
std::vector<Device*> StaticDeviceMgr::ListDevices() const {
|
||||||
std::vector<Device*> devices(devices_.size());
|
std::vector<Device*> devices(devices_.size());
|
||||||
for (size_t i = 0; i < devices_.size(); ++i) {
|
for (size_t i = 0; i < devices_.size(); ++i) {
|
||||||
devices[i] = devices_[i].get();
|
devices[i] = devices_[i].get();
|
||||||
@ -80,7 +82,7 @@ std::vector<Device*> DeviceMgr::ListDevices() const {
|
|||||||
return devices;
|
return devices;
|
||||||
}
|
}
|
||||||
|
|
||||||
string DeviceMgr::DebugString() const {
|
string StaticDeviceMgr::DebugString() const {
|
||||||
string out;
|
string out;
|
||||||
for (const auto& dev : devices_) {
|
for (const auto& dev : devices_) {
|
||||||
strings::StrAppend(&out, dev->name(), "\n");
|
strings::StrAppend(&out, dev->name(), "\n");
|
||||||
@ -88,7 +90,7 @@ string DeviceMgr::DebugString() const {
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
string DeviceMgr::DeviceMappingString() const {
|
string StaticDeviceMgr::DeviceMappingString() const {
|
||||||
string out;
|
string out;
|
||||||
for (const auto& dev : devices_) {
|
for (const auto& dev : devices_) {
|
||||||
if (!dev->attributes().physical_device_desc().empty()) {
|
if (!dev->attributes().physical_device_desc().empty()) {
|
||||||
@ -99,7 +101,7 @@ string DeviceMgr::DeviceMappingString() const {
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
|
Status StaticDeviceMgr::LookupDevice(StringPiece name, Device** device) const {
|
||||||
auto iter = device_map_.find(name);
|
auto iter = device_map_.find(name);
|
||||||
if (iter == device_map_.end()) {
|
if (iter == device_map_.end()) {
|
||||||
std::vector<StringPiece> device_names;
|
std::vector<StringPiece> device_names;
|
||||||
@ -114,7 +116,8 @@ Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
void DeviceMgr::ClearContainers(gtl::ArraySlice<string> containers) const {
|
void StaticDeviceMgr::ClearContainers(
|
||||||
|
gtl::ArraySlice<string> containers) const {
|
||||||
Status s;
|
Status s;
|
||||||
for (const auto& dev : devices_) {
|
for (const auto& dev : devices_) {
|
||||||
if (containers.empty()) {
|
if (containers.empty()) {
|
||||||
@ -131,7 +134,7 @@ void DeviceMgr::ClearContainers(gtl::ArraySlice<string> containers) const {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int DeviceMgr::NumDeviceType(const string& type) const {
|
int StaticDeviceMgr::NumDeviceType(const string& type) const {
|
||||||
auto iter = device_type_counts_.find(type);
|
auto iter = device_type_counts_.find(type);
|
||||||
if (iter != device_type_counts_.end()) return iter->second;
|
if (iter != device_type_counts_.end()) return iter->second;
|
||||||
return 0;
|
return 0;
|
||||||
|
@ -33,38 +33,57 @@ namespace tensorflow {
|
|||||||
|
|
||||||
class DeviceAttributes;
|
class DeviceAttributes;
|
||||||
|
|
||||||
|
// Represents a set of devices.
|
||||||
class DeviceMgr {
|
class DeviceMgr {
|
||||||
public:
|
public:
|
||||||
// Constructs a DeviceMgr from a list of devices.
|
DeviceMgr() = default;
|
||||||
// TODO(zhifengc): Other initialization information.
|
virtual ~DeviceMgr();
|
||||||
explicit DeviceMgr(std::vector<std::unique_ptr<Device>> devices);
|
|
||||||
|
|
||||||
// Constructs a DeviceMgr managing a single device.
|
|
||||||
explicit DeviceMgr(std::unique_ptr<Device> device);
|
|
||||||
|
|
||||||
~DeviceMgr();
|
|
||||||
|
|
||||||
// Returns attributes of all devices.
|
// Returns attributes of all devices.
|
||||||
void ListDeviceAttributes(std::vector<DeviceAttributes>* devices) const;
|
virtual void ListDeviceAttributes(
|
||||||
|
std::vector<DeviceAttributes>* devices) const = 0;
|
||||||
|
|
||||||
// Returns raw pointers to the underlying devices.
|
// Returns raw pointers to the underlying devices.
|
||||||
std::vector<Device*> ListDevices() const;
|
virtual std::vector<Device*> ListDevices() const = 0;
|
||||||
|
|
||||||
// Returns a string listing all devices.
|
// Returns a string listing all devices.
|
||||||
string DebugString() const;
|
virtual string DebugString() const = 0;
|
||||||
|
|
||||||
// Returns a string of all the device mapping.
|
// Returns a string of all the device mapping.
|
||||||
string DeviceMappingString() const;
|
virtual string DeviceMappingString() const = 0;
|
||||||
|
|
||||||
// Assigns *device with pointer to Device of the given name.
|
// Assigns *device with pointer to Device of the given name.
|
||||||
// Accepts either a full device name, or just the replica-local suffix.
|
// Accepts either a full device name, or just the replica-local suffix.
|
||||||
Status LookupDevice(StringPiece name, Device** device) const;
|
virtual Status LookupDevice(StringPiece name, Device** device) const = 0;
|
||||||
|
|
||||||
// Clears given containers of all devices if 'container' is
|
// Clears given containers of all devices if 'container' is
|
||||||
// non-empty. Otherwise, clears default containers of all devices.
|
// non-empty. Otherwise, clears default containers of all devices.
|
||||||
void ClearContainers(gtl::ArraySlice<string> containers) const;
|
virtual void ClearContainers(gtl::ArraySlice<string> containers) const = 0;
|
||||||
|
|
||||||
int NumDeviceType(const string& type) const;
|
virtual int NumDeviceType(const string& type) const = 0;
|
||||||
|
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Represents a static set of devices.
|
||||||
|
class StaticDeviceMgr : public DeviceMgr {
|
||||||
|
public:
|
||||||
|
// Constructs a StaticDeviceMgr from a list of devices.
|
||||||
|
explicit StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices);
|
||||||
|
|
||||||
|
// Constructs a StaticDeviceMgr managing a single device.
|
||||||
|
explicit StaticDeviceMgr(std::unique_ptr<Device> device);
|
||||||
|
|
||||||
|
~StaticDeviceMgr() override;
|
||||||
|
|
||||||
|
void ListDeviceAttributes(
|
||||||
|
std::vector<DeviceAttributes>* devices) const override;
|
||||||
|
std::vector<Device*> ListDevices() const override;
|
||||||
|
string DebugString() const override;
|
||||||
|
string DeviceMappingString() const override;
|
||||||
|
Status LookupDevice(StringPiece name, Device** device) const override;
|
||||||
|
void ClearContainers(gtl::ArraySlice<string> containers) const override;
|
||||||
|
int NumDeviceType(const string& type) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const std::vector<std::unique_ptr<Device>> devices_;
|
const std::vector<std::unique_ptr<Device>> devices_;
|
||||||
@ -75,9 +94,8 @@ class DeviceMgr {
|
|||||||
core::Arena name_backing_store_; // Storage for keys in device_map_
|
core::Arena name_backing_store_; // Storage for keys in device_map_
|
||||||
std::unordered_map<string, int> device_type_counts_;
|
std::unordered_map<string, int> device_type_counts_;
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr);
|
TF_DISALLOW_COPY_AND_ASSIGN(StaticDeviceMgr);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
|
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
|
||||||
|
@ -37,7 +37,7 @@ class DeviceResolverLocalTest : public ::testing::Test {
|
|||||||
device_count->insert({"CPU", NUM_DEVS});
|
device_count->insert({"CPU", NUM_DEVS});
|
||||||
std::vector<std::unique_ptr<Device>> devices;
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
|
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
|
||||||
device_mgr_.reset(new DeviceMgr(std::move(devices)));
|
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||||
drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
|
drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -188,8 +188,8 @@ class DirectSessionFactory : public SessionFactory {
|
|||||||
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
|
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
|
||||||
options, "/job:localhost/replica:0/task:0", &devices));
|
options, "/job:localhost/replica:0/task:0", &devices));
|
||||||
|
|
||||||
DirectSession* session =
|
DirectSession* session = new DirectSession(
|
||||||
new DirectSession(options, new DeviceMgr(std::move(devices)), this);
|
options, new StaticDeviceMgr(std::move(devices)), this);
|
||||||
{
|
{
|
||||||
mutex_lock l(sessions_lock_);
|
mutex_lock l(sessions_lock_);
|
||||||
sessions_.push_back(session);
|
sessions_.push_back(session);
|
||||||
|
@ -45,7 +45,7 @@ class TestEnv {
|
|||||||
devices.push_back(
|
devices.push_back(
|
||||||
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
|
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
|
||||||
cpu_device_ = devices.back().get();
|
cpu_device_ = devices.back().get();
|
||||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
|
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||||
OptimizerOptions opts;
|
OptimizerOptions opts;
|
||||||
pflr_ = tensorflow::MakeUnique<ProcessFunctionLibraryRuntime>(
|
pflr_ = tensorflow::MakeUnique<ProcessFunctionLibraryRuntime>(
|
||||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, &flib_def_,
|
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, &flib_def_,
|
||||||
|
@ -162,7 +162,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
|
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
|
||||||
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
|
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
|
||||||
OptimizerOptions opts;
|
OptimizerOptions opts;
|
||||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
|
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||||
pflr_.reset(new ProcessFunctionLibraryRuntime(
|
pflr_.reset(new ProcessFunctionLibraryRuntime(
|
||||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
||||||
opts));
|
opts));
|
||||||
|
@ -62,7 +62,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
|
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
|
||||||
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
|
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
|
||||||
OptimizerOptions opts;
|
OptimizerOptions opts;
|
||||||
device_mgr_.reset(new DeviceMgr(std::move(devices)));
|
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||||
pflr_.reset(new ProcessFunctionLibraryRuntime(
|
pflr_.reset(new ProcessFunctionLibraryRuntime(
|
||||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
||||||
opts, default_thread_pool));
|
opts, default_thread_pool));
|
||||||
|
@ -247,7 +247,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!dev_mgr_ || device_type == DEVICE_CPU) {
|
if (!dev_mgr_ || device_type == DEVICE_CPU) {
|
||||||
dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(local_devices));
|
dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
|
||||||
}
|
}
|
||||||
if (!gpu_ring_order_) {
|
if (!gpu_ring_order_) {
|
||||||
gpu_ring_order_ = absl::make_unique<string>();
|
gpu_ring_order_ = absl::make_unique<string>();
|
||||||
|
@ -47,7 +47,7 @@ class PartitioningUtilsTest : public ::testing::Test {
|
|||||||
&devices));
|
&devices));
|
||||||
device0_ = devices[0].get();
|
device0_ = devices[0].get();
|
||||||
device1_ = devices[1].get();
|
device1_ = devices[1].get();
|
||||||
device_mgr_.reset(new DeviceMgr(std::move(devices)));
|
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||||
|
|
||||||
for (auto d : device_mgr_->ListDevices()) {
|
for (auto d : device_mgr_->ListDevices()) {
|
||||||
device_set_.AddDevice(d);
|
device_set_.AddDevice(d);
|
||||||
|
@ -94,7 +94,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
std::vector<std::unique_ptr<Device>> devices;
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
TF_CHECK_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0",
|
TF_CHECK_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0",
|
||||||
&devices));
|
&devices));
|
||||||
device_mgr_.reset(new DeviceMgr(std::move(devices)));
|
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||||
TF_CHECK_OK(device_mgr_->LookupDevice(
|
TF_CHECK_OK(device_mgr_->LookupDevice(
|
||||||
"/job:a/replica:0/task:0/device:CPU:0", &device0_));
|
"/job:a/replica:0/task:0/device:CPU:0", &device0_));
|
||||||
TF_CHECK_OK(device_mgr_->LookupDevice(
|
TF_CHECK_OK(device_mgr_->LookupDevice(
|
||||||
|
@ -167,7 +167,7 @@ class RingGathererTest : public ::testing::Test {
|
|||||||
if (!dev_mgr_ || device_type == DEVICE_CPU) {
|
if (!dev_mgr_ || device_type == DEVICE_CPU) {
|
||||||
LOG(ERROR) << "resetting dev_mgr for " << local_devices.size()
|
LOG(ERROR) << "resetting dev_mgr for " << local_devices.size()
|
||||||
<< " devices: ";
|
<< " devices: ";
|
||||||
dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(local_devices));
|
dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
|
||||||
}
|
}
|
||||||
if (!gpu_ring_order_) {
|
if (!gpu_ring_order_) {
|
||||||
gpu_ring_order_ = absl::make_unique<string>();
|
gpu_ring_order_ = absl::make_unique<string>();
|
||||||
|
@ -189,7 +189,7 @@ class RingReducerTest : public ::testing::Test {
|
|||||||
if (!dev_mgr_ || device_type == DEVICE_CPU) {
|
if (!dev_mgr_ || device_type == DEVICE_CPU) {
|
||||||
LOG(INFO) << "resetting dev_mgr for " << local_devices.size()
|
LOG(INFO) << "resetting dev_mgr for " << local_devices.size()
|
||||||
<< " devices: ";
|
<< " devices: ";
|
||||||
dev_mgr_ = absl::make_unique<DeviceMgr>(std::move(local_devices));
|
dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
|
||||||
}
|
}
|
||||||
if (!gpu_ring_order_) {
|
if (!gpu_ring_order_) {
|
||||||
gpu_ring_order_ = absl::make_unique<string>();
|
gpu_ring_order_ = absl::make_unique<string>();
|
||||||
|
@ -165,7 +165,7 @@ class DeviceResDistTest : public ::testing::Test {
|
|||||||
device_type,
|
device_type,
|
||||||
strings::StrCat(worker_name, "/device:", device_type, ":", i)));
|
strings::StrCat(worker_name, "/device:", device_type, ":", i)));
|
||||||
}
|
}
|
||||||
DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
|
DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
|
||||||
device_mgrs_.push_back(dev_mgr);
|
device_mgrs_.push_back(dev_mgr);
|
||||||
std::vector<string>* dv = &dev_by_task_[worker_name];
|
std::vector<string>* dv = &dev_by_task_[worker_name];
|
||||||
for (auto* d : dev_mgr->ListDevices()) {
|
for (auto* d : dev_mgr->ListDevices()) {
|
||||||
|
@ -222,7 +222,7 @@ class CollRMADistTest : public ::testing::Test {
|
|||||||
device_type,
|
device_type,
|
||||||
strings::StrCat(worker_name, "/device:", device_type, ":", i)));
|
strings::StrCat(worker_name, "/device:", device_type, ":", i)));
|
||||||
}
|
}
|
||||||
DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
|
DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
|
||||||
device_mgrs_.push_back(dev_mgr);
|
device_mgrs_.push_back(dev_mgr);
|
||||||
std::vector<string>* dv = &dev_by_task_[worker_name];
|
std::vector<string>* dv = &dev_by_task_[worker_name];
|
||||||
dv->clear();
|
dv->clear();
|
||||||
|
@ -163,7 +163,7 @@ class DeviceResDistTest : public ::testing::Test {
|
|||||||
strings::StrCat(worker_name, "/device:", device_type, ":", i), i,
|
strings::StrCat(worker_name, "/device:", device_type, ":", i), i,
|
||||||
device_incarnation_base + i));
|
device_incarnation_base + i));
|
||||||
}
|
}
|
||||||
DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
|
DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
|
||||||
TestableDeviceResolverDistributed* dev_res =
|
TestableDeviceResolverDistributed* dev_res =
|
||||||
new TestableDeviceResolverDistributed(dev_mgr, &wc_, worker_name);
|
new TestableDeviceResolverDistributed(dev_mgr, &wc_, worker_name);
|
||||||
resolvers_[worker_name] = dev_res;
|
resolvers_[worker_name] = dev_res;
|
||||||
|
@ -85,7 +85,7 @@ class EagerServiceImplTest : public ::testing::Test {
|
|||||||
worker_env_.rendezvous_mgr = &rendezvous_mgr_;
|
worker_env_.rendezvous_mgr = &rendezvous_mgr_;
|
||||||
worker_env_.session_mgr = session_mgr_.get();
|
worker_env_.session_mgr = session_mgr_.get();
|
||||||
|
|
||||||
device_mgr_ = absl::make_unique<DeviceMgr>(DeviceFactory::NewDevice(
|
device_mgr_ = absl::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
|
||||||
"CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
|
"CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
|
||||||
worker_env_.local_devices = device_mgr_->ListDevices();
|
worker_env_.local_devices = device_mgr_->ListDevices();
|
||||||
worker_env_.device_mgr = device_mgr_.get();
|
worker_env_.device_mgr = device_mgr_.get();
|
||||||
|
@ -48,7 +48,7 @@ class RemoteMgrTest : public ::testing::Test {
|
|||||||
devices.push_back(
|
devices.push_back(
|
||||||
DeviceFactory::NewDevice("CPU", {}, "/job:worker/replica:0/task:0"));
|
DeviceFactory::NewDevice("CPU", {}, "/job:worker/replica:0/task:0"));
|
||||||
remote_device_ = devices.back().get();
|
remote_device_ = devices.back().get();
|
||||||
auto device_mgr = absl::make_unique<DeviceMgr>(std::move(devices));
|
auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||||
context_id_ = random::New64();
|
context_id_ = random::New64();
|
||||||
tensorflow::Rendezvous* rendezvous =
|
tensorflow::Rendezvous* rendezvous =
|
||||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||||
|
@ -163,7 +163,7 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
|
|||||||
std::vector<std::unique_ptr<Device>> devices;
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
|
DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
|
||||||
worker_env_.device_mgr = new DeviceMgr(std::move(devices));
|
worker_env_.device_mgr = new StaticDeviceMgr(std::move(devices));
|
||||||
master_env_.local_devices = worker_env_.device_mgr->ListDevices();
|
master_env_.local_devices = worker_env_.device_mgr->ListDevices();
|
||||||
worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
|
worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
|
||||||
worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr
|
worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr
|
||||||
|
@ -44,7 +44,7 @@ class RpcCollectiveExecutorMgrTest : public ::testing::Test {
|
|||||||
device_count->insert({"CPU", NUM_DEVS});
|
device_count->insert({"CPU", NUM_DEVS});
|
||||||
std::vector<std::unique_ptr<Device>> devices;
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
|
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
|
||||||
device_mgr_.reset(new DeviceMgr(std::move(devices)));
|
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||||
std::unique_ptr<DeviceResolverDistributed> dr(new DeviceResolverDistributed(
|
std::unique_ptr<DeviceResolverDistributed> dr(new DeviceResolverDistributed(
|
||||||
device_mgr_.get(), worker_cache, task_name));
|
device_mgr_.get(), worker_cache, task_name));
|
||||||
std::unique_ptr<CollectiveParamResolverDistributed> cpr(
|
std::unique_ptr<CollectiveParamResolverDistributed> cpr(
|
||||||
|
@ -101,7 +101,7 @@ Status SessionMgr::CreateSession(
|
|||||||
renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
|
renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
|
||||||
worker_name, d, false, isolate_session_state));
|
worker_name, d, false, isolate_session_state));
|
||||||
}
|
}
|
||||||
auto device_mgr = MakeUnique<DeviceMgr>(std::move(renamed_devices));
|
auto device_mgr = MakeUnique<StaticDeviceMgr>(std::move(renamed_devices));
|
||||||
LookupLocalDevice cb = [&device_mgr](StringPiece name, Device** device) {
|
LookupLocalDevice cb = [&device_mgr](StringPiece name, Device** device) {
|
||||||
return device_mgr->LookupDevice(name, device);
|
return device_mgr->LookupDevice(name, device);
|
||||||
};
|
};
|
||||||
@ -109,7 +109,7 @@ Status SessionMgr::CreateSession(
|
|||||||
&cluster_devices);
|
&cluster_devices);
|
||||||
std::unique_ptr<DeviceMgr> remote_devices;
|
std::unique_ptr<DeviceMgr> remote_devices;
|
||||||
if (!cluster_device_attributes.empty())
|
if (!cluster_device_attributes.empty())
|
||||||
remote_devices = MakeUnique<DeviceMgr>(std::move(cluster_devices));
|
remote_devices = MakeUnique<StaticDeviceMgr>(std::move(cluster_devices));
|
||||||
|
|
||||||
auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
|
auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
|
||||||
worker_session.reset(
|
worker_session.reset(
|
||||||
@ -122,7 +122,7 @@ Status SessionMgr::CreateSession(
|
|||||||
&cluster_devices);
|
&cluster_devices);
|
||||||
std::unique_ptr<DeviceMgr> remote_devices;
|
std::unique_ptr<DeviceMgr> remote_devices;
|
||||||
if (!cluster_device_attributes.empty())
|
if (!cluster_device_attributes.empty())
|
||||||
remote_devices = MakeUnique<DeviceMgr>(std::move(cluster_devices));
|
remote_devices = MakeUnique<StaticDeviceMgr>(std::move(cluster_devices));
|
||||||
// Borrow the WorkerEnv's DeviceMgr for the WorkerSession, so
|
// Borrow the WorkerEnv's DeviceMgr for the WorkerSession, so
|
||||||
// that resources using it can use its devices after the
|
// that resources using it can use its devices after the
|
||||||
// WorkerSession has been deleted.
|
// WorkerSession has been deleted.
|
||||||
|
@ -46,7 +46,7 @@ class SessionMgrTest : public ::testing::Test {
|
|||||||
SessionMgrTest()
|
SessionMgrTest()
|
||||||
: mgr_(&env_, "/job:mnist/replica:0/task:0",
|
: mgr_(&env_, "/job:mnist/replica:0/task:0",
|
||||||
std::unique_ptr<WorkerCacheInterface>(), factory_) {
|
std::unique_ptr<WorkerCacheInterface>(), factory_) {
|
||||||
device_mgr_ = absl::make_unique<DeviceMgr>(
|
device_mgr_ = absl::make_unique<StaticDeviceMgr>(
|
||||||
FakeDevice::MakeCPU("/job:mnist/replica:0/task:0/device:fakecpu:0"));
|
FakeDevice::MakeCPU("/job:mnist/replica:0/task:0/device:fakecpu:0"));
|
||||||
env_.local_devices = device_mgr_->ListDevices();
|
env_.local_devices = device_mgr_->ListDevices();
|
||||||
env_.device_mgr = device_mgr_.get();
|
env_.device_mgr = device_mgr_.get();
|
||||||
|
@ -239,7 +239,7 @@ Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
|
|||||||
TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
|
TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
|
||||||
options, "/job:localhost/replica:0/task:0", &devices));
|
options, "/job:localhost/replica:0/task:0", &devices));
|
||||||
Device* cpu_device = devices[0].get();
|
Device* cpu_device = devices[0].get();
|
||||||
std::unique_ptr<DeviceMgr> dvc_mgr(new DeviceMgr(std::move(devices)));
|
auto dvc_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||||
FunctionLibraryDefinition function_library(OpRegistry::Global(),
|
FunctionLibraryDefinition function_library(OpRegistry::Global(),
|
||||||
graph_def.library());
|
graph_def.library());
|
||||||
Env* env = Env::Default();
|
Env* env = Env::Default();
|
||||||
|
@ -117,7 +117,8 @@ class NcclTestBase : public ::testing::Test {
|
|||||||
device_names.push_back(device->name());
|
device_names.push_back(device->name());
|
||||||
VLOG(2) << device->name();
|
VLOG(2) << device->name();
|
||||||
}
|
}
|
||||||
if (!dev_mgr_) dev_mgr_.reset(new DeviceMgr(std::move(local_devices)));
|
if (!dev_mgr_)
|
||||||
|
dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
|
||||||
col_exec_ = new BaseCollectiveExecutor(
|
col_exec_ = new BaseCollectiveExecutor(
|
||||||
&col_exec_mgr_, /*remote_access=*/nullptr, kStepId, dev_mgr_.get(),
|
&col_exec_mgr_, /*remote_access=*/nullptr, kStepId, dev_mgr_.get(),
|
||||||
/*gpu_ring_order=*/nullptr);
|
/*gpu_ring_order=*/nullptr);
|
||||||
|
@ -415,7 +415,7 @@ Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
|
|||||||
std::vector<std::unique_ptr<Device>> devices;
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
|
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
|
||||||
options, "/job:localhost/replica:0/task:0", &devices));
|
options, "/job:localhost/replica:0/task:0", &devices));
|
||||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
|
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||||
resource_mgr_ = absl::make_unique<ResourceMgr>("default_container");
|
resource_mgr_ = absl::make_unique<ResourceMgr>("default_container");
|
||||||
|
|
||||||
FunctionDefLibrary proto;
|
FunctionDefLibrary proto;
|
||||||
|
@ -355,9 +355,10 @@ FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
|
|||||||
// in its resource manager. The existing device will outlive the
|
// in its resource manager. The existing device will outlive the
|
||||||
// IteratorResource, because we are storing the IteratorResource
|
// IteratorResource, because we are storing the IteratorResource
|
||||||
// in that device's resource manager.
|
// in that device's resource manager.
|
||||||
*device_mgr = absl::make_unique<DeviceMgr>(RenamedDevice::NewRenamedDevice(
|
*device_mgr =
|
||||||
ctx->device()->name(), down_cast<Device*>(ctx->device()),
|
absl::make_unique<StaticDeviceMgr>(RenamedDevice::NewRenamedDevice(
|
||||||
false /* owns_underlying */, false /* isolate_session_state */));
|
ctx->device()->name(), down_cast<Device*>(ctx->device()),
|
||||||
|
false /* owns_underlying */, false /* isolate_session_state */));
|
||||||
*flib_def = absl::make_unique<FunctionLibraryDefinition>(
|
*flib_def = absl::make_unique<FunctionLibraryDefinition>(
|
||||||
*ctx->function_library()->GetFunctionLibraryDefinition());
|
*ctx->function_library()->GetFunctionLibraryDefinition());
|
||||||
*pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
*pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||||
|
@ -27,7 +27,7 @@ void OpsTestBase::SetDevice(const DeviceType& device_type,
|
|||||||
CHECK(device_) << "No device provided";
|
CHECK(device_) << "No device provided";
|
||||||
|
|
||||||
device_ = device.get();
|
device_ = device.get();
|
||||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(device));
|
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device));
|
||||||
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, flib_def_.get(),
|
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, flib_def_.get(),
|
||||||
OptimizerOptions());
|
OptimizerOptions());
|
||||||
|
@ -78,7 +78,7 @@ class OpsTestBase : public ::testing::Test {
|
|||||||
CHECK(device) << "Could not create CPU device";
|
CHECK(device) << "Could not create CPU device";
|
||||||
|
|
||||||
device_ = device.get();
|
device_ = device.get();
|
||||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(device));
|
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device));
|
||||||
|
|
||||||
allocator_ = device_->GetAllocator(AllocatorAttributes());
|
allocator_ = device_->GetAllocator(AllocatorAttributes());
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ tensorflow::Status DelegateData::Prepare(
|
|||||||
session_options, "/job:localhost/replica:0/task:0", &devices));
|
session_options, "/job:localhost/replica:0/task:0", &devices));
|
||||||
|
|
||||||
auto device_mgr =
|
auto device_mgr =
|
||||||
absl::make_unique<tensorflow::DeviceMgr>(std::move(devices));
|
absl::make_unique<tensorflow::StaticDeviceMgr>(std::move(devices));
|
||||||
// Note that Rendezvous is ref-counted so it will be automatically deleted.
|
// Note that Rendezvous is ref-counted so it will be automatically deleted.
|
||||||
tensorflow::Rendezvous* rendezvous =
|
tensorflow::Rendezvous* rendezvous =
|
||||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||||
|
@ -2234,7 +2234,7 @@ bool InlineAllFunctions(GraphDef* graphdef) {
|
|||||||
|
|
||||||
tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(),
|
tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(),
|
||||||
graphdef_copy.library());
|
graphdef_copy.library());
|
||||||
tensorflow::DeviceMgr device_mgr(std::move(devices));
|
tensorflow::StaticDeviceMgr device_mgr(std::move(devices));
|
||||||
tensorflow::OptimizerOptions o_opts;
|
tensorflow::OptimizerOptions o_opts;
|
||||||
tensorflow::ProcessFunctionLibraryRuntime pflr(
|
tensorflow::ProcessFunctionLibraryRuntime pflr(
|
||||||
&device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld,
|
&device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld,
|
||||||
|
Loading…
Reference in New Issue
Block a user