From 0e7680ef8d97598923e16f503141056a469642ac Mon Sep 17 00:00:00 2001
From: Xiao Yu
Date: Tue, 27 Aug 2019 14:35:43 -0700
Subject: [PATCH] Instantiate multi-device function in parallel.

PiperOrigin-RevId: 265769016
---
 .../core/common_runtime/partitioning_utils.cc |  4 +-
 .../process_function_library_runtime.cc       | 95 +++++++++++++------
 2 files changed, 66 insertions(+), 33 deletions(-)

diff --git a/tensorflow/core/common_runtime/partitioning_utils.cc b/tensorflow/core/common_runtime/partitioning_utils.cc
index 8f9583cd028..f8194e6c4ba 100644
--- a/tensorflow/core/common_runtime/partitioning_utils.cc
+++ b/tensorflow/core/common_runtime/partitioning_utils.cc
@@ -141,8 +141,8 @@ std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
 }
 
 string FunctionNameGenerator::GetName() {
-  for (;; ++counter_) {
-    const string candidate = strings::StrCat(name_, "_", counter_);
+  while (true) {
+    const string candidate = strings::StrCat(name_, "_", counter_++);
     if (flib_def_->Find(candidate) == nullptr) {
       return candidate;
     }
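Side note on the hunk above: the move to `counter_++` is load-bearing for the parallel instantiation introduced in the next file. The old `for (;; ++counter_)` advanced the counter only after a collision, so two consecutive GetName() calls could return the same candidate whenever the first name had not yet been registered via AddFunctionDef. That was harmless while instantiation was serial, but once AddFunctionDef runs on pool threads (see the second file below), registration is deferred and duplicate names become possible. A minimal standalone sketch of the fixed behavior follows; it is illustrative only, not the TensorFlow class, with a std::set standing in for FunctionLibraryDefinition::Find:

#include <cassert>
#include <set>
#include <string>

// Stand-in registry: the real class consults FunctionLibraryDefinition::Find.
class FunctionNameGenerator {
 public:
  FunctionNameGenerator(const std::set<std::string>* taken, std::string name)
      : taken_(taken), name_(std::move(name)) {}

  std::string GetName() {
    while (true) {
      // Post-increment: the counter advances past every candidate we hand
      // out, so uniqueness no longer depends on the caller registering the
      // previous name before the next GetName() call.
      const std::string candidate = name_ + "_" + std::to_string(counter_++);
      if (taken_->count(candidate) == 0) return candidate;
    }
  }

 private:
  const std::set<std::string>* taken_;
  std::string name_;
  int counter_ = 0;
};

int main() {
  std::set<std::string> registry;  // nothing registered yet
  FunctionNameGenerator gen(&registry, "fn");
  // With the old `for (;; ++counter_)`, both calls would return "fn_0",
  // because nothing registers the first name in between.
  assert(gen.GetName() == "fn_0");
  assert(gen.GetName() == "fn_1");
  return 0;
}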
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 2bfed3e02af..36ddfd568c8 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/common_runtime/partitioning_utils.h"
 #include "tensorflow/core/common_runtime/placer.h"
+#include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #include "tensorflow/core/common_runtime/rendezvous_util.h"
 #include "tensorflow/core/framework/function.h"
@@ -35,7 +36,10 @@ limitations under the License.
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/graph_partition.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/util/device_name_utils.h"
@@ -764,42 +768,71 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
   int i = 0;
   // Generate a random function_name to avoid one function reuse the partition
   // function instantiated by another function.
+  FunctionLibraryDefinition* data_lib_def = &data->lib_def_;
   FunctionNameGenerator name_generator(
-      &data->lib_def_, absl::StrCat(function_name, "_", random::New64()));
+      data_lib_def, absl::StrCat(function_name, "_", random::New64()));
+  auto subgraph_size = subgraphs.size();
+  gtl::InlinedVector<Status, 4> instantiate_status(subgraph_size);
+  BlockingCounter counter(static_cast<int>(subgraph_size));
+  auto runner = [this, subgraph_size](std::function<void()> fn) {
+    // NOTE: Only use thread pool to instantiate sub-function when there are
+    // more than 8 sub-functions. We want to avoid cost of switching thread when
+    // there are only a few sub-functions.
+    if (default_thread_pool_ != nullptr && subgraph_size > 8) {
+      default_thread_pool_->Schedule(fn);
+    } else {
+      fn();
+    }
+  };
   for (const auto& pair : subgraphs) {
-    i += 1;
-    const string& target = pair.first;
-
-    const string& device_type =
-        device_set_.FindDeviceByName(target)->device_type();
-    Graph* subgraph = pair.second.get();
-
-    ComponentFunctionData* comp_data = &data->glue_[target];
-    TF_RETURN_IF_ERROR(UpdateArgAndRetvalMetadata(
-        subgraph, device_type, &comp_data->arg_indices_,
-        &comp_data->ret_indices_, &comp_data->arg_alloc_attrs_,
-        &comp_data->ret_alloc_attrs_));
-    FunctionDef shard;
+    Status* status = &instantiate_status[i];
     string unique_name = name_generator.GetName();
-    TF_RETURN_IF_ERROR(
-        GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
-    TF_RETURN_IF_ERROR(data->lib_def_.AddFunctionDef(shard));
-    FunctionLibraryRuntime::InstantiateOptions opts;
-    opts.executor_type = options.executor_type;
-    opts.target = target;
-    opts.lib_def = &data->lib_def_;
-    opts.create_kernels_eagerly = options.create_kernels_eagerly;
-    opts.state_handle = options.state_handle;
-    FunctionLibraryRuntime::Handle component_handle;
+    ComponentFunctionData* comp_data = &data->glue_[pair.first];
+    runner([this, &pair, comp_data, unique_name, data_lib_def, &control_ret,
+            &options, status, &counter] {
+      auto cleanup = gtl::MakeCleanup([&counter] { counter.DecrementCount(); });
+      const string& target = pair.first;
 
-    TF_RETURN_IF_ERROR(Instantiate(unique_name, AttrSlice(&shard.attr()), opts,
-                                   &component_handle));
-    VLOG(1) << "Instantiated component function " << unique_name
-            << " on device " << target << " with component handle "
-            << component_handle;
-    VLOG(2) << DebugString(shard);
-    comp_data->handle_ = component_handle;
+      const string& device_type =
+          device_set_.FindDeviceByName(target)->device_type();
+      Graph* subgraph = pair.second.get();
+
+      status->Update(UpdateArgAndRetvalMetadata(
+          subgraph, device_type, &comp_data->arg_indices_,
+          &comp_data->ret_indices_, &comp_data->arg_alloc_attrs_,
+          &comp_data->ret_alloc_attrs_));
+      if (!status->ok()) return;
+      FunctionDef shard;
+      status->Update(
+          GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
+      if (!status->ok()) return;
+      status->Update(data_lib_def->AddFunctionDef(shard));
+      FunctionLibraryRuntime::InstantiateOptions opts;
+      opts.executor_type = options.executor_type;
+      opts.target = target;
+      opts.lib_def = data_lib_def;
+      opts.create_kernels_eagerly = options.create_kernels_eagerly;
+      opts.state_handle = options.state_handle;
+      FunctionLibraryRuntime::Handle component_handle;
+
+      // TODO(fishx): introduce an async version of this Instantiate method.
+      status->Update(Instantiate(unique_name, AttrSlice(&shard.attr()), opts,
+                                 &component_handle));
+      if (!status->ok()) return;
+      VLOG(1) << "Instantiated component function " << unique_name
+              << " on device " << target << " with component handle "
+              << component_handle;
+      VLOG(2) << DebugString(shard);
+      comp_data->handle_ = component_handle;
+    });
+    i += 1;
   }
+  counter.Wait();
+  StatusGroup group;
+  for (auto& status : instantiate_status) {
+    group.Update(status);
+  }
+  TF_RETURN_IF_ERROR(group.as_summary_status());
   *handle = AddMultiDeviceHandle(std::move(data), function_key);
   VLOG(2) << "Instantiated MultiDevice function \"" << function_name
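For reference, the concurrency pattern this hunk introduces reduces to the following standalone sketch: one pre-allocated status slot per task so workers never share mutable state, a blocking counter that is always decremented (the patch uses gtl::MakeCleanup for this), inline execution when there are only a few tasks, and a final aggregation pass over all statuses. This is an illustrative sketch only, not TensorFlow code: the BlockingCounter and Status types below are minimal stand-ins for the library classes, and kParallelThreshold mirrors the `subgraph_size > 8` heuristic.

#include <condition_variable>
#include <cstdio>
#include <functional>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

// Minimal stand-in for tensorflow::BlockingCounter.
class BlockingCounter {
 public:
  explicit BlockingCounter(int n) : count_(n) {}
  void DecrementCount() {
    std::lock_guard<std::mutex> l(mu_);
    if (--count_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

// Minimal stand-in for tensorflow::Status.
struct Status {
  bool ok = true;
  std::string message;
};

int main() {
  const size_t kParallelThreshold = 8;  // mirrors `subgraph_size > 8`

  // One hypothetical task per subgraph; the second one fails.
  std::vector<std::function<Status()>> tasks = {
      [] { return Status{}; },
      [] { return Status{false, "placement failed"}; },
  };

  // One pre-allocated status slot per task, so workers never contend.
  std::vector<Status> statuses(tasks.size());
  BlockingCounter counter(static_cast<int>(tasks.size()));
  std::vector<std::thread> threads;

  for (size_t i = 0; i < tasks.size(); ++i) {
    auto work = [&, i] {
      statuses[i] = tasks[i]();  // each worker writes only its own slot
      counter.DecrementCount();  // always runs, like gtl::MakeCleanup
    };
    // Small batches run inline to avoid the cost of switching threads.
    if (tasks.size() > kParallelThreshold) {
      threads.emplace_back(work);
    } else {
      work();
    }
  }
  counter.Wait();  // block until every task has reported
  for (auto& t : threads) t.join();

  // Aggregate per-task results, as StatusGroup::as_summary_status() does.
  for (const auto& s : statuses) {
    if (!s.ok) {
      std::printf("instantiation failed: %s\n", s.message.c_str());
      return 1;
    }
  }
  return 0;
}

Giving each worker its own pre-sized slot avoids any locking around the status list, and decrementing the counter unconditionally (the cleanup object in the real patch) keeps Wait() from hanging when a worker returns early on error.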