From 006fd832b1e0d8bf48396f32bdd3396cff3b99fc Mon Sep 17 00:00:00 2001 From: Haoyu Zhang Date: Fri, 27 Mar 2020 15:33:49 -0700 Subject: [PATCH] Re-initialize device and type priority list in ProcessFLR when updating cluster. PiperOrigin-RevId: 303420309 Change-Id: Ia9afe83873043f15e34312a93979726933c880d5 --- .../core/common_runtime/eager/context.cc | 1 + .../process_function_library_runtime.cc | 35 +++++++++++-------- .../process_function_library_runtime.h | 7 ++-- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index a8513b9e613..33221e51218 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -1102,6 +1102,7 @@ Status EagerContext::UpdateRemoteMaster( if (rendezvous_ != nullptr) rendezvous_->Unref(); rendezvous_ = r; remote_eager_workers_ = std::move(remote_eager_workers); + pflr_->InitializeDeviceSet(); InitPrioritizedDeviceTypeList(); default_executor_.ClearError(); diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 76c5f34f3ac..1543a341c2a 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -110,14 +110,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( session_metadata_, this); } - DeviceMgr const* all_devices = device_mgr_; - if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) { - all_devices = parent_->remote_device_mgr(); - } - - for (auto d : all_devices->ListDevices()) { - device_set_.AddDevice(d); - } + InitializeDeviceSet(); } /* static */ @@ -214,6 +207,18 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext( "function executions"); } +void ProcessFunctionLibraryRuntime::InitializeDeviceSet() { + DeviceMgr const* all_devices = 
device_mgr_; + if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) { + all_devices = parent_->remote_device_mgr(); + } + + device_set_.reset(new DeviceSet); + for (auto d : all_devices->ListDevices()) { + device_set_->AddDevice(d); + } +} + FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR( const string& device_name) const { Device* device = nullptr; @@ -678,7 +683,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( TF_RETURN_IF_ERROR( SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes)); TF_RETURN_IF_ERROR(PinArgsAndRets( - options.input_devices, options.output_devices, device_set_, arg_nodes, + options.input_devices, options.output_devices, *device_set_, arg_nodes, ret_nodes, options.config_proto.allow_soft_placement() ? default_device : nullptr)); @@ -691,7 +696,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( bool control_rets_updated = false; TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run( - device_set_, options.config_proto, &graph, &data->lib_def_, + *device_set_, options.config_proto, &graph, &data->lib_def_, &control_ret_node_names, &control_rets_updated)); if (control_rets_updated) { @@ -714,7 +719,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( optimization_options.session_options = &session_options; optimization_options.graph = &graph; optimization_options.flib_def = &data->lib_def_; - optimization_options.device_set = &device_set_; + optimization_options.device_set = device_set_.get(); optimization_options.is_function_graph = true; DumpGraph("Before running PRE_PLACEMENT passes", graph.get()); @@ -725,7 +730,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( // exceptions/warnings in case where nested function call options are ignored. 
DumpGraph("Before calling Placer", graph.get()); Placer placer(graph.get(), function_name, optimization_options.flib_def, - &device_set_, default_device, + device_set_.get(), default_device, options.config_proto.allow_soft_placement(), options.config_proto.log_device_placement()); TF_RETURN_IF_ERROR(placer.Run()); @@ -741,7 +746,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( DumpGraph("Before running graph optimization fn", graph.get()); Status status = options.optimize_graph_fn( std::move(ret_node_names), std::move(control_ret_node_names), - &data->lib_def_, device_set_, cpu_device, &graph); + &data->lib_def_, *device_set_, cpu_device, &graph); if (!status.ok()) { LOG(WARNING) << "Ignoring multi-device function optimization failure: " << status.ToString(); @@ -765,7 +770,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( std::unordered_map<string, std::unique_ptr<Graph>> subgraphs; TF_RETURN_IF_ERROR( - PartitionFunctionGraph(device_set_, std::move(graph), &subgraphs)); + PartitionFunctionGraph(*device_set_, std::move(graph), &subgraphs)); for (const auto& pair : subgraphs) { DumpGraph(strings::StrCat("Before running POST_PARTITIONING passes (", @@ -841,7 +846,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( const string& target = pair.first; const string& device_type = - device_set_.FindDeviceByName(target)->device_type(); + device_set_->FindDeviceByName(target)->device_type(); Graph* subgraph = pair.second.get(); status->Update(UpdateArgAndRetvalMetadata( diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 1d7708fbdc1..f8550fd8bea 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -191,7 +191,10 @@ class ProcessFunctionLibraryRuntime { const DeviceMgr* device_mgr() { return device_mgr_; } - const DeviceSet* device_set() { return 
&device_set_; } + const DeviceSet* device_set() { return device_set_.get(); } + + // Initialize the set of local and remote devices for op device selection. + void InitializeDeviceSet(); const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; } @@ -422,7 +425,7 @@ class ProcessFunctionLibraryRuntime { Env* const env_; const absl::optional<const ConfigProto> config_; const DeviceMgr* const device_mgr_; - DeviceSet device_set_; + std::unique_ptr<DeviceSet> device_set_; const FunctionLibraryDefinition* lib_def_; thread::ThreadPool* default_thread_pool_;