Instantiate multi-device function in parallel.
PiperOrigin-RevId: 265769016
This commit is contained in:
parent
7301b746f2
commit
0e7680ef8d
@ -141,8 +141,8 @@ std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
|
||||
}
|
||||
|
||||
string FunctionNameGenerator::GetName() {
|
||||
for (;; ++counter_) {
|
||||
const string candidate = strings::StrCat(name_, "_", counter_);
|
||||
while (true) {
|
||||
const string candidate = strings::StrCat(name_, "_", counter_++);
|
||||
if (flib_def_->Find(candidate) == nullptr) {
|
||||
return candidate;
|
||||
}
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/common_runtime/partitioning_utils.h"
|
||||
#include "tensorflow/core/common_runtime/placer.h"
|
||||
#include "tensorflow/core/common_runtime/process_util.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_util.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
@ -35,7 +36,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/graph_partition.h"
|
||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
@ -764,42 +768,71 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
int i = 0;
|
||||
// Generate a random function_name to prevent one function from reusing the
|
||||
// partition function instantiated by another function.
|
||||
FunctionLibraryDefinition* data_lib_def = &data->lib_def_;
|
||||
FunctionNameGenerator name_generator(
|
||||
&data->lib_def_, absl::StrCat(function_name, "_", random::New64()));
|
||||
data_lib_def, absl::StrCat(function_name, "_", random::New64()));
|
||||
auto subgraph_size = subgraphs.size();
|
||||
gtl::InlinedVector<Status, 4> instantiate_status(subgraph_size);
|
||||
BlockingCounter counter(static_cast<int>(subgraph_size));
|
||||
auto runner = [this, subgraph_size](std::function<void()> fn) {
|
||||
// NOTE: Only use the thread pool to instantiate sub-functions when there are
|
||||
// more than 8 of them. We want to avoid the cost of switching threads when
|
||||
// there are only a few sub-functions.
|
||||
if (default_thread_pool_ != nullptr && subgraph_size > 8) {
|
||||
default_thread_pool_->Schedule(fn);
|
||||
} else {
|
||||
fn();
|
||||
}
|
||||
};
|
||||
for (const auto& pair : subgraphs) {
|
||||
i += 1;
|
||||
const string& target = pair.first;
|
||||
|
||||
const string& device_type =
|
||||
device_set_.FindDeviceByName(target)->device_type();
|
||||
Graph* subgraph = pair.second.get();
|
||||
|
||||
ComponentFunctionData* comp_data = &data->glue_[target];
|
||||
TF_RETURN_IF_ERROR(UpdateArgAndRetvalMetadata(
|
||||
subgraph, device_type, &comp_data->arg_indices_,
|
||||
&comp_data->ret_indices_, &comp_data->arg_alloc_attrs_,
|
||||
&comp_data->ret_alloc_attrs_));
|
||||
FunctionDef shard;
|
||||
Status* status = &instantiate_status[i];
|
||||
string unique_name = name_generator.GetName();
|
||||
TF_RETURN_IF_ERROR(
|
||||
GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
|
||||
TF_RETURN_IF_ERROR(data->lib_def_.AddFunctionDef(shard));
|
||||
FunctionLibraryRuntime::InstantiateOptions opts;
|
||||
opts.executor_type = options.executor_type;
|
||||
opts.target = target;
|
||||
opts.lib_def = &data->lib_def_;
|
||||
opts.create_kernels_eagerly = options.create_kernels_eagerly;
|
||||
opts.state_handle = options.state_handle;
|
||||
FunctionLibraryRuntime::Handle component_handle;
|
||||
ComponentFunctionData* comp_data = &data->glue_[pair.first];
|
||||
runner([this, &pair, comp_data, unique_name, data_lib_def, &control_ret,
|
||||
&options, status, &counter] {
|
||||
auto cleanup = gtl::MakeCleanup([&counter] { counter.DecrementCount(); });
|
||||
const string& target = pair.first;
|
||||
|
||||
TF_RETURN_IF_ERROR(Instantiate(unique_name, AttrSlice(&shard.attr()), opts,
|
||||
&component_handle));
|
||||
VLOG(1) << "Instantiated component function " << unique_name
|
||||
<< " on device " << target << " with component handle "
|
||||
<< component_handle;
|
||||
VLOG(2) << DebugString(shard);
|
||||
comp_data->handle_ = component_handle;
|
||||
const string& device_type =
|
||||
device_set_.FindDeviceByName(target)->device_type();
|
||||
Graph* subgraph = pair.second.get();
|
||||
|
||||
status->Update(UpdateArgAndRetvalMetadata(
|
||||
subgraph, device_type, &comp_data->arg_indices_,
|
||||
&comp_data->ret_indices_, &comp_data->arg_alloc_attrs_,
|
||||
&comp_data->ret_alloc_attrs_));
|
||||
if (!status->ok()) return;
|
||||
FunctionDef shard;
|
||||
status->Update(
|
||||
GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
|
||||
if (!status->ok()) return;
|
||||
status->Update(data_lib_def->AddFunctionDef(shard));
|
||||
FunctionLibraryRuntime::InstantiateOptions opts;
|
||||
opts.executor_type = options.executor_type;
|
||||
opts.target = target;
|
||||
opts.lib_def = data_lib_def;
|
||||
opts.create_kernels_eagerly = options.create_kernels_eagerly;
|
||||
opts.state_handle = options.state_handle;
|
||||
FunctionLibraryRuntime::Handle component_handle;
|
||||
|
||||
// TODO(fishx): introduce an async version of this Instantiate method.
|
||||
status->Update(Instantiate(unique_name, AttrSlice(&shard.attr()), opts,
|
||||
&component_handle));
|
||||
if (!status->ok()) return;
|
||||
VLOG(1) << "Instantiated component function " << unique_name
|
||||
<< " on device " << target << " with component handle "
|
||||
<< component_handle;
|
||||
VLOG(2) << DebugString(shard);
|
||||
comp_data->handle_ = component_handle;
|
||||
});
|
||||
i += 1;
|
||||
}
|
||||
counter.Wait();
|
||||
StatusGroup group;
|
||||
for (auto& status : instantiate_status) {
|
||||
group.Update(status);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(group.as_summary_status());
|
||||
|
||||
*handle = AddMultiDeviceHandle(std::move(data), function_key);
|
||||
VLOG(2) << "Instantiated MultiDevice function \"" << function_name
|
||||
|
Loading…
Reference in New Issue
Block a user