Re-initialize device and type priority list in ProcessFLR when updating cluster.

PiperOrigin-RevId: 303420309
Change-Id: Ia9afe83873043f15e34312a93979726933c880d5
Author: Haoyu Zhang, 2020-03-27 15:33:49 -07:00 (committed by TensorFlower Gardener)
Parent: 966bbe7dc2
Commit: 006fd832b1
3 changed files with 26 additions and 17 deletions
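
The pattern behind this change: ProcessFunctionLibraryRuntime now owns its DeviceSet through a std::unique_ptr and rebuilds it from the current local and remote device managers on demand, so the set can be refreshed after a cluster update instead of keeping entries for devices that no longer exist. Below is a minimal, self-contained sketch of that pattern; Device, DeviceMgr, DeviceSet, and Runtime here are simplified stand-ins, not TensorFlow's real classes.

#include <memory>
#include <string>
#include <vector>

// Simplified stand-ins for TensorFlow's Device, DeviceMgr, and DeviceSet.
struct Device {
  std::string name;
};

struct DeviceMgr {
  std::vector<Device*> devices;
  const std::vector<Device*>& ListDevices() const { return devices; }
};

struct DeviceSet {
  std::vector<Device*> devices;
  void AddDevice(Device* d) { devices.push_back(d); }
};

class Runtime {
 public:
  Runtime(const DeviceMgr* local, const DeviceMgr* remote)
      : local_(local), remote_(remote) {
    InitializeDeviceSet();
  }

  // Rebuild the device set from scratch; safe to call again after the
  // remote device manager has been replaced during a cluster update.
  void InitializeDeviceSet() {
    const DeviceMgr* all = (remote_ != nullptr) ? remote_ : local_;
    device_set_ = std::make_unique<DeviceSet>();
    for (Device* d : all->ListDevices()) device_set_->AddDevice(d);
  }

  // Analogous to a cluster update: swap the remote view, then drop and
  // rebuild the device set so no stale device pointers survive.
  void UpdateCluster(const DeviceMgr* new_remote) {
    remote_ = new_remote;
    InitializeDeviceSet();
  }

  // Callers borrow the set; the runtime keeps ownership.
  const DeviceSet* device_set() const { return device_set_.get(); }

 private:
  const DeviceMgr* local_;
  const DeviceMgr* remote_;
  std::unique_ptr<DeviceSet> device_set_;
};

In the diff below, the rebuild lives in ProcessFunctionLibraryRuntime::InitializeDeviceSet(), which both the constructor and EagerContext::UpdateRemoteMaster now call.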

tensorflow/core/common_runtime/eager/context.cc

@@ -1102,6 +1102,7 @@ Status EagerContext::UpdateRemoteMaster(
     if (rendezvous_ != nullptr) rendezvous_->Unref();
     rendezvous_ = r;
     remote_eager_workers_ = std::move(remote_eager_workers);
+    pflr_->InitializeDeviceSet();
     InitPrioritizedDeviceTypeList();

     default_executor_.ClearError();

tensorflow/core/common_runtime/process_function_library_runtime.cc

@@ -110,14 +110,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
         session_metadata_, this);
   }
-  DeviceMgr const* all_devices = device_mgr_;
-  if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
-    all_devices = parent_->remote_device_mgr();
-  }
-  for (auto d : all_devices->ListDevices()) {
-    device_set_.AddDevice(d);
-  }
+  InitializeDeviceSet();
 }

 /* static */
@@ -214,6 +207,18 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext(
       "function executions");
 }

+void ProcessFunctionLibraryRuntime::InitializeDeviceSet() {
+  DeviceMgr const* all_devices = device_mgr_;
+  if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
+    all_devices = parent_->remote_device_mgr();
+  }
+
+  device_set_.reset(new DeviceSet);
+  for (auto d : all_devices->ListDevices()) {
+    device_set_->AddDevice(d);
+  }
+}
+
 FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
     const string& device_name) const {
   Device* device = nullptr;
@@ -678,7 +683,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
   TF_RETURN_IF_ERROR(
       SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes));
   TF_RETURN_IF_ERROR(PinArgsAndRets(
-      options.input_devices, options.output_devices, device_set_, arg_nodes,
+      options.input_devices, options.output_devices, *device_set_, arg_nodes,
       ret_nodes,
       options.config_proto.allow_soft_placement() ? default_device : nullptr));
@@ -691,7 +696,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
   bool control_rets_updated = false;
   TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
-      device_set_, options.config_proto, &graph, &data->lib_def_,
+      *device_set_, options.config_proto, &graph, &data->lib_def_,
       &control_ret_node_names, &control_rets_updated));

   if (control_rets_updated) {
@@ -714,7 +719,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
   optimization_options.session_options = &session_options;
   optimization_options.graph = &graph;
   optimization_options.flib_def = &data->lib_def_;
-  optimization_options.device_set = &device_set_;
+  optimization_options.device_set = device_set_.get();
   optimization_options.is_function_graph = true;

   DumpGraph("Before running PRE_PLACEMENT passes", graph.get());
@@ -725,7 +730,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
   // exceptions/warnings in case where nested function call options are ignored.
   DumpGraph("Before calling Placer", graph.get());
   Placer placer(graph.get(), function_name, optimization_options.flib_def,
-                &device_set_, default_device,
+                device_set_.get(), default_device,
                 options.config_proto.allow_soft_placement(),
                 options.config_proto.log_device_placement());
   TF_RETURN_IF_ERROR(placer.Run());
@@ -741,7 +746,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
   DumpGraph("Before running graph optimization fn", graph.get());
   Status status = options.optimize_graph_fn(
       std::move(ret_node_names), std::move(control_ret_node_names),
-      &data->lib_def_, device_set_, cpu_device, &graph);
+      &data->lib_def_, *device_set_, cpu_device, &graph);
   if (!status.ok()) {
     LOG(WARNING) << "Ignoring multi-device function optimization failure: "
                  << status.ToString();
@@ -765,7 +770,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
   std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
   TF_RETURN_IF_ERROR(
-      PartitionFunctionGraph(device_set_, std::move(graph), &subgraphs));
+      PartitionFunctionGraph(*device_set_, std::move(graph), &subgraphs));

   for (const auto& pair : subgraphs) {
     DumpGraph(strings::StrCat("Before running POST_PARTITIONING passes (",
@@ -841,7 +846,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
     const string& target = pair.first;
     const string& device_type =
-        device_set_.FindDeviceByName(target)->device_type();
+        device_set_->FindDeviceByName(target)->device_type();
     Graph* subgraph = pair.second.get();
     status->Update(UpdateArgAndRetvalMetadata(
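
The call-site hunks above follow mechanically from device_set_ changing from a DeviceSet member to a std::unique_ptr<DeviceSet>: parameters taking a const DeviceSet& now receive *device_set_, and parameters taking a DeviceSet* now receive device_set_.get(). A tiny standalone illustration of the idiom (the callee names are hypothetical, not TensorFlow APIs):

#include <memory>

struct DeviceSet {};

// Hypothetical callees standing in for functions like PinArgsAndRets or Placer.
void TakesConstRef(const DeviceSet&) {}
void TakesRawPtr(const DeviceSet*) {}

int main() {
  // Before: a DeviceSet value member was passed as device_set_ or &device_set_.
  // After: the set lives behind a unique_ptr, so call sites dereference it
  // or hand out the raw pointer without transferring ownership.
  std::unique_ptr<DeviceSet> device_set = std::make_unique<DeviceSet>();
  TakesConstRef(*device_set);      // was: TakesConstRef(device_set_)
  TakesRawPtr(device_set.get());   // was: TakesRawPtr(&device_set_)
  return 0;
}

In both forms the runtime keeps ownership of the set; callees only borrow it for the duration of the call.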

tensorflow/core/common_runtime/process_function_library_runtime.h

@@ -191,7 +191,10 @@ class ProcessFunctionLibraryRuntime {
   const DeviceMgr* device_mgr() { return device_mgr_; }

-  const DeviceSet* device_set() { return &device_set_; }
+  const DeviceSet* device_set() { return device_set_.get(); }
+
+  // Initialize the set of local and remote devices for op device selection.
+  void InitializeDeviceSet();

   const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; }
@@ -422,7 +425,7 @@ class ProcessFunctionLibraryRuntime {
   Env* const env_;
   const absl::optional<const ConfigProto> config_;
   const DeviceMgr* const device_mgr_;
-  DeviceSet device_set_;
+  std::unique_ptr<DeviceSet> device_set_;
   const FunctionLibraryDefinition* lib_def_;
   thread::ThreadPool* default_thread_pool_;