Re-initialize the ProcessFLR device set and the prioritized device type list when updating the cluster.
PiperOrigin-RevId: 303420309 Change-Id: Ia9afe83873043f15e34312a93979726933c880d5
This commit is contained in:
parent
966bbe7dc2
commit
006fd832b1
@ -1102,6 +1102,7 @@ Status EagerContext::UpdateRemoteMaster(
|
||||
if (rendezvous_ != nullptr) rendezvous_->Unref();
|
||||
rendezvous_ = r;
|
||||
remote_eager_workers_ = std::move(remote_eager_workers);
|
||||
pflr_->InitializeDeviceSet();
|
||||
InitPrioritizedDeviceTypeList();
|
||||
|
||||
default_executor_.ClearError();
|
||||
|
@ -110,14 +110,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
|
||||
session_metadata_, this);
|
||||
}
|
||||
|
||||
DeviceMgr const* all_devices = device_mgr_;
|
||||
if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
|
||||
all_devices = parent_->remote_device_mgr();
|
||||
}
|
||||
|
||||
for (auto d : all_devices->ListDevices()) {
|
||||
device_set_.AddDevice(d);
|
||||
}
|
||||
InitializeDeviceSet();
|
||||
}
|
||||
|
||||
/* static */
|
||||
@ -214,6 +207,18 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext(
|
||||
"function executions");
|
||||
}
|
||||
|
||||
void ProcessFunctionLibraryRuntime::InitializeDeviceSet() {
|
||||
DeviceMgr const* all_devices = device_mgr_;
|
||||
if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
|
||||
all_devices = parent_->remote_device_mgr();
|
||||
}
|
||||
|
||||
device_set_.reset(new DeviceSet);
|
||||
for (auto d : all_devices->ListDevices()) {
|
||||
device_set_->AddDevice(d);
|
||||
}
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
|
||||
const string& device_name) const {
|
||||
Device* device = nullptr;
|
||||
@ -678,7 +683,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
TF_RETURN_IF_ERROR(
|
||||
SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes));
|
||||
TF_RETURN_IF_ERROR(PinArgsAndRets(
|
||||
options.input_devices, options.output_devices, device_set_, arg_nodes,
|
||||
options.input_devices, options.output_devices, *device_set_, arg_nodes,
|
||||
ret_nodes,
|
||||
options.config_proto.allow_soft_placement() ? default_device : nullptr));
|
||||
|
||||
@ -691,7 +696,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
|
||||
bool control_rets_updated = false;
|
||||
TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
|
||||
device_set_, options.config_proto, &graph, &data->lib_def_,
|
||||
*device_set_, options.config_proto, &graph, &data->lib_def_,
|
||||
&control_ret_node_names, &control_rets_updated));
|
||||
|
||||
if (control_rets_updated) {
|
||||
@ -714,7 +719,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
optimization_options.session_options = &session_options;
|
||||
optimization_options.graph = &graph;
|
||||
optimization_options.flib_def = &data->lib_def_;
|
||||
optimization_options.device_set = &device_set_;
|
||||
optimization_options.device_set = device_set_.get();
|
||||
optimization_options.is_function_graph = true;
|
||||
|
||||
DumpGraph("Before running PRE_PLACEMENT passes", graph.get());
|
||||
@ -725,7 +730,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
// exceptions/warnings in case where nested function call options are ignored.
|
||||
DumpGraph("Before calling Placer", graph.get());
|
||||
Placer placer(graph.get(), function_name, optimization_options.flib_def,
|
||||
&device_set_, default_device,
|
||||
device_set_.get(), default_device,
|
||||
options.config_proto.allow_soft_placement(),
|
||||
options.config_proto.log_device_placement());
|
||||
TF_RETURN_IF_ERROR(placer.Run());
|
||||
@ -741,7 +746,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
DumpGraph("Before running graph optimization fn", graph.get());
|
||||
Status status = options.optimize_graph_fn(
|
||||
std::move(ret_node_names), std::move(control_ret_node_names),
|
||||
&data->lib_def_, device_set_, cpu_device, &graph);
|
||||
&data->lib_def_, *device_set_, cpu_device, &graph);
|
||||
if (!status.ok()) {
|
||||
LOG(WARNING) << "Ignoring multi-device function optimization failure: "
|
||||
<< status.ToString();
|
||||
@ -765,7 +770,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
|
||||
std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
|
||||
TF_RETURN_IF_ERROR(
|
||||
PartitionFunctionGraph(device_set_, std::move(graph), &subgraphs));
|
||||
PartitionFunctionGraph(*device_set_, std::move(graph), &subgraphs));
|
||||
|
||||
for (const auto& pair : subgraphs) {
|
||||
DumpGraph(strings::StrCat("Before running POST_PARTITIONING passes (",
|
||||
@ -841,7 +846,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
const string& target = pair.first;
|
||||
|
||||
const string& device_type =
|
||||
device_set_.FindDeviceByName(target)->device_type();
|
||||
device_set_->FindDeviceByName(target)->device_type();
|
||||
Graph* subgraph = pair.second.get();
|
||||
|
||||
status->Update(UpdateArgAndRetvalMetadata(
|
||||
|
@ -191,7 +191,10 @@ class ProcessFunctionLibraryRuntime {
|
||||
|
||||
// Returns the DeviceMgr pointer stored in device_mgr_.
const DeviceMgr* device_mgr() { return device_mgr_; }
|
||||
|
||||
const DeviceSet* device_set() { return &device_set_; }
|
||||
// Returns a non-owning pointer to the DeviceSet currently held by
// device_set_. InitializeDeviceSet() replaces that object, so a pointer
// obtained here is only valid until the next InitializeDeviceSet() call.
const DeviceSet* device_set() { return device_set_.get(); }
|
||||
|
||||
// Initialize the set of local and remote devices for op device selection.
|
||||
void InitializeDeviceSet();
|
||||
|
||||
// Returns a pointer to the ConfigProto held in config_, or nullptr when
// config_ holds no value.
const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; }
|
||||
|
||||
@ -422,7 +425,7 @@ class ProcessFunctionLibraryRuntime {
|
||||
Env* const env_;
|
||||
const absl::optional<const ConfigProto> config_;
|
||||
const DeviceMgr* const device_mgr_;
|
||||
DeviceSet device_set_;
|
||||
std::unique_ptr<DeviceSet> device_set_;
|
||||
const FunctionLibraryDefinition* lib_def_;
|
||||
thread::ThreadPool* default_thread_pool_;
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user