Instantiate multi-device function in parallel.

PiperOrigin-RevId: 265769016
Xiao Yu, 2019-08-27 14:35:43 -07:00, committed by TensorFlower Gardener
parent 7301b746f2
commit 0e7680ef8d
2 changed files with 66 additions and 33 deletions
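
A minimal, self-contained sketch of the fan-out pattern this change introduces, written against the standard library rather than TensorFlow's internal BlockingCounter, gtl::MakeCleanup, and StatusGroup helpers. All names below (SimpleBlockingCounter, FakeStatus, InstantiateOne, InstantiateAll, kInlineThreshold) are illustrative stand-ins, not TensorFlow APIs: each sub-function gets its own task and its own pre-allocated status slot, a counter lets the caller block until every task has finished, and per-task statuses are merged at the end instead of returning early from inside a worker.

#include <condition_variable>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

struct FakeStatus {
  bool ok = true;
  std::string message;
};

class SimpleBlockingCounter {
 public:
  explicit SimpleBlockingCounter(int count) : count_(count) {}
  void DecrementCount() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--count_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

// Stand-in for instantiating one component function on one device.
FakeStatus InstantiateOne(const std::string& device) {
  if (device.empty()) return {false, "empty device name"};
  return {};
}

FakeStatus InstantiateAll(const std::vector<std::string>& devices) {
  const size_t n = devices.size();
  constexpr size_t kInlineThreshold = 8;  // mirrors the "> 8 sub-functions" heuristic
  std::vector<FakeStatus> statuses(n);    // one slot per task, so results need no locking
  SimpleBlockingCounter counter(static_cast<int>(n));
  std::vector<std::thread> workers;
  for (size_t i = 0; i < n; ++i) {
    auto task = [&, i] {
      statuses[i] = InstantiateOne(devices[i]);
      counter.DecrementCount();  // the real change does this through a cleanup guard
    };
    // Only pay for extra threads when there are many sub-functions;
    // otherwise run the task inline on the calling thread.
    if (n > kInlineThreshold) {
      workers.emplace_back(task);
    } else {
      task();
    }
  }
  counter.Wait();                   // block until every task has reported its status
  for (auto& t : workers) t.join();
  for (const auto& s : statuses) {  // aggregate; first failure wins in this sketch
    if (!s.ok) return s;
  }
  return {};
}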

(changed file 1 of 2)

@@ -141,8 +141,8 @@ std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
 }
 
 string FunctionNameGenerator::GetName() {
-  for (;; ++counter_) {
-    const string candidate = strings::StrCat(name_, "_", counter_);
+  while (true) {
+    const string candidate = strings::StrCat(name_, "_", counter_++);
     if (flib_def_->Find(candidate) == nullptr) {
       return candidate;
     }

(changed file 2 of 2)

@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/common_runtime/partitioning_utils.h"
 #include "tensorflow/core/common_runtime/placer.h"
+#include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #include "tensorflow/core/common_runtime/rendezvous_util.h"
 #include "tensorflow/core/framework/function.h"
@@ -35,7 +36,10 @@ limitations under the License.
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/graph_partition.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/util/device_name_utils.h"
@@ -764,42 +768,71 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
   int i = 0;
   // Generate a random function_name to avoid one function reuse the partition
   // function instantiated by another function.
+  FunctionLibraryDefinition* data_lib_def = &data->lib_def_;
   FunctionNameGenerator name_generator(
-      &data->lib_def_, absl::StrCat(function_name, "_", random::New64()));
+      data_lib_def, absl::StrCat(function_name, "_", random::New64()));
+  auto subgraph_size = subgraphs.size();
+  gtl::InlinedVector<Status, 4> instantiate_status(subgraph_size);
+  BlockingCounter counter(static_cast<int>(subgraph_size));
+  auto runner = [this, subgraph_size](std::function<void()> fn) {
+    // NOTE: Only use thread pool to instantiate sub-function when there are
+    // more than 8 sub-functions. We want to avoid cost of switching thread when
+    // there are only a few sub-functions.
+    if (default_thread_pool_ != nullptr && subgraph_size > 8) {
+      default_thread_pool_->Schedule(fn);
+    } else {
+      fn();
+    }
+  };
   for (const auto& pair : subgraphs) {
-    i += 1;
-    const string& target = pair.first;
-    const string& device_type =
-        device_set_.FindDeviceByName(target)->device_type();
-    Graph* subgraph = pair.second.get();
-    ComponentFunctionData* comp_data = &data->glue_[target];
-    TF_RETURN_IF_ERROR(UpdateArgAndRetvalMetadata(
-        subgraph, device_type, &comp_data->arg_indices_,
-        &comp_data->ret_indices_, &comp_data->arg_alloc_attrs_,
-        &comp_data->ret_alloc_attrs_));
-    FunctionDef shard;
+    Status* status = &instantiate_status[i];
     string unique_name = name_generator.GetName();
-    TF_RETURN_IF_ERROR(
-        GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
-    TF_RETURN_IF_ERROR(data->lib_def_.AddFunctionDef(shard));
-    FunctionLibraryRuntime::InstantiateOptions opts;
-    opts.executor_type = options.executor_type;
-    opts.target = target;
-    opts.lib_def = &data->lib_def_;
-    opts.create_kernels_eagerly = options.create_kernels_eagerly;
-    opts.state_handle = options.state_handle;
-    FunctionLibraryRuntime::Handle component_handle;
-    TF_RETURN_IF_ERROR(Instantiate(unique_name, AttrSlice(&shard.attr()), opts,
-                                   &component_handle));
-    VLOG(1) << "Instantiated component function " << unique_name
-            << " on device " << target << " with component handle "
-            << component_handle;
-    VLOG(2) << DebugString(shard);
-    comp_data->handle_ = component_handle;
+    ComponentFunctionData* comp_data = &data->glue_[pair.first];
+    runner([this, &pair, comp_data, unique_name, data_lib_def, &control_ret,
+            &options, status, &counter] {
+      auto cleanup = gtl::MakeCleanup([&counter] { counter.DecrementCount(); });
+      const string& target = pair.first;
+      const string& device_type =
+          device_set_.FindDeviceByName(target)->device_type();
+      Graph* subgraph = pair.second.get();
+      status->Update(UpdateArgAndRetvalMetadata(
+          subgraph, device_type, &comp_data->arg_indices_,
+          &comp_data->ret_indices_, &comp_data->arg_alloc_attrs_,
+          &comp_data->ret_alloc_attrs_));
+      if (!status->ok()) return;
+      FunctionDef shard;
+      status->Update(
+          GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
+      if (!status->ok()) return;
+      status->Update(data_lib_def->AddFunctionDef(shard));
+      FunctionLibraryRuntime::InstantiateOptions opts;
+      opts.executor_type = options.executor_type;
+      opts.target = target;
+      opts.lib_def = data_lib_def;
+      opts.create_kernels_eagerly = options.create_kernels_eagerly;
+      opts.state_handle = options.state_handle;
+      FunctionLibraryRuntime::Handle component_handle;
+      // TODO(fishx): introduce an async version of this Instantiate method.
+      status->Update(Instantiate(unique_name, AttrSlice(&shard.attr()), opts,
+                                 &component_handle));
+      if (!status->ok()) return;
+      VLOG(1) << "Instantiated component function " << unique_name
+              << " on device " << target << " with component handle "
+              << component_handle;
+      VLOG(2) << DebugString(shard);
+      comp_data->handle_ = component_handle;
+    });
+    i += 1;
   }
+  counter.Wait();
+  StatusGroup group;
+  for (auto& status : instantiate_status) {
+    group.Update(status);
+  }
+  TF_RETURN_IF_ERROR(group.as_summary_status());
   *handle = AddMultiDeviceHandle(std::move(data), function_key);
   VLOG(2) << "Instantiated MultiDevice function \"" << function_name