Instantiate multi-device function in parallel.

PiperOrigin-RevId: 265769016

commit 0e7680ef8d (parent 7301b746f2)
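What the diff below does, in brief: the per-subgraph instantiation loop in ProcessFunctionLibraryRuntime::InstantiateMultiDevice no longer runs each TF_RETURN_IF_ERROR step inline. Every partition shard gets its own Status slot in a gtl::InlinedVector, work is dispatched through a small runner lambda that schedules on default_thread_pool_ only when there are more than 8 subgraphs (to avoid thread-switch overhead for small functions), a BlockingCounter joins the workers, and a StatusGroup folds the per-shard errors into one summary status. A rough standalone illustration of that fan-out/join shape in plain C++ (hypothetical names such as RunShards; a sketch, not the TensorFlow internals):

    #include <condition_variable>
    #include <functional>
    #include <mutex>
    #include <string>
    #include <thread>
    #include <vector>

    // Minimal stand-in for tensorflow::BlockingCounter: Wait() returns once
    // every worker has called DecrementCount().
    class BlockingCounter {
     public:
      explicit BlockingCounter(int n) : count_(n) {}
      void DecrementCount() {
        std::lock_guard<std::mutex> lock(mu_);
        if (--count_ == 0) cv_.notify_all();
      }
      void Wait() {
        std::unique_lock<std::mutex> lock(mu_);
        cv_.wait(lock, [this] { return count_ == 0; });
      }

     private:
      std::mutex mu_;
      std::condition_variable cv_;
      int count_;
    };

    // Fan out one task per shard. Each task writes only its own slot of
    // `errors`, so the vector needs no lock; the counter is the join point.
    std::vector<std::string> RunShards(
        const std::vector<std::function<std::string()>>& shards) {
      std::vector<std::string> errors(shards.size());
      BlockingCounter counter(static_cast<int>(shards.size()));
      std::vector<std::thread> pool;  // the real change reuses a shared pool
      for (size_t i = 0; i < shards.size(); ++i) {
        pool.emplace_back([&, i] {
          errors[i] = shards[i]();   // per-shard result/error slot
          counter.DecrementCount();  // signal completion, success or not
        });
      }
      counter.Wait();  // all slots written; safe to aggregate afterwards
      for (std::thread& t : pool) t.join();
      return errors;
    }

Because every worker writes only its own slot, no mutex guards the results; the counter alone orders the writes before aggregation.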
@@ -141,8 +141,8 @@ std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
 }
 
 string FunctionNameGenerator::GetName() {
-  for (;; ++counter_) {
-    const string candidate = strings::StrCat(name_, "_", counter_);
+  while (true) {
+    const string candidate = strings::StrCat(name_, "_", counter_++);
     if (flib_def_->Find(candidate) == nullptr) {
       return candidate;
     }
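The GetName() change above is load-bearing for the parallelization later in this commit. Previously the counter advanced only when a candidate collided, so two back-to-back GetName() calls could hand out the same name unless the first name had already been registered in the library; with registration (AddFunctionDef) now happening asynchronously on worker threads, each call must consume its suffix unconditionally, hence counter_++. A standalone analogue of the fixed generator (hypothetical NameGenerator, not the TensorFlow class):

    #include <string>
    #include <unordered_set>

    // Hands out "base_N" names, skipping any that are already taken.
    // Incrementing counter_ on every candidate consumes the suffix even when
    // the caller has not yet registered the name, which the parallelized
    // instantiation loop below relies on.
    class NameGenerator {
     public:
      NameGenerator(const std::unordered_set<std::string>* taken,
                    std::string base)
          : taken_(taken), base_(std::move(base)) {}

      std::string GetName() {
        while (true) {
          std::string candidate = base_ + "_" + std::to_string(counter_++);
          if (taken_->count(candidate) == 0) return candidate;
        }
      }

     private:
      const std::unordered_set<std::string>* taken_;
      std::string base_;
      int counter_ = 0;
    };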
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/common_runtime/partitioning_utils.h"
 #include "tensorflow/core/common_runtime/placer.h"
+#include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #include "tensorflow/core/common_runtime/rendezvous_util.h"
 #include "tensorflow/core/framework/function.h"
@@ -35,7 +36,10 @@ limitations under the License.
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/graph_partition.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/util/device_name_utils.h"
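The headers added in the two hunks above back the new parallel path: blocking_counter.h supplies the join primitive, gtl/cleanup.h the scope-exit guard that decrements it, gtl/inlined_vector.h the per-shard status slots, and process_util.h is presumably pulled in for the thread-pool plumbing.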
@@ -764,42 +768,71 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
   int i = 0;
   // Generate a random function_name to avoid one function reuse the partition
   // function instantiated by another function.
+  FunctionLibraryDefinition* data_lib_def = &data->lib_def_;
   FunctionNameGenerator name_generator(
-      &data->lib_def_, absl::StrCat(function_name, "_", random::New64()));
+      data_lib_def, absl::StrCat(function_name, "_", random::New64()));
+  auto subgraph_size = subgraphs.size();
+  gtl::InlinedVector<Status, 4> instantiate_status(subgraph_size);
+  BlockingCounter counter(static_cast<int>(subgraph_size));
+  auto runner = [this, subgraph_size](std::function<void()> fn) {
+    // NOTE: Only use thread pool to instantiate sub-function when there are
+    // more than 8 sub-functions. We want to avoid cost of switching thread when
+    // there are only a few sub-functions.
+    if (default_thread_pool_ != nullptr && subgraph_size > 8) {
+      default_thread_pool_->Schedule(fn);
+    } else {
+      fn();
+    }
+  };
   for (const auto& pair : subgraphs) {
-    i += 1;
-    const string& target = pair.first;
-
-    const string& device_type =
-        device_set_.FindDeviceByName(target)->device_type();
-    Graph* subgraph = pair.second.get();
-
-    ComponentFunctionData* comp_data = &data->glue_[target];
-    TF_RETURN_IF_ERROR(UpdateArgAndRetvalMetadata(
-        subgraph, device_type, &comp_data->arg_indices_,
-        &comp_data->ret_indices_, &comp_data->arg_alloc_attrs_,
-        &comp_data->ret_alloc_attrs_));
-    FunctionDef shard;
+    Status* status = &instantiate_status[i];
     string unique_name = name_generator.GetName();
-    TF_RETURN_IF_ERROR(
-        GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
-    TF_RETURN_IF_ERROR(data->lib_def_.AddFunctionDef(shard));
-    FunctionLibraryRuntime::InstantiateOptions opts;
-    opts.executor_type = options.executor_type;
-    opts.target = target;
-    opts.lib_def = &data->lib_def_;
-    opts.create_kernels_eagerly = options.create_kernels_eagerly;
-    opts.state_handle = options.state_handle;
-    FunctionLibraryRuntime::Handle component_handle;
-
-    TF_RETURN_IF_ERROR(Instantiate(unique_name, AttrSlice(&shard.attr()), opts,
-                                   &component_handle));
-    VLOG(1) << "Instantiated component function " << unique_name
-            << " on device " << target << " with component handle "
-            << component_handle;
-    VLOG(2) << DebugString(shard);
-    comp_data->handle_ = component_handle;
+    ComponentFunctionData* comp_data = &data->glue_[pair.first];
+    runner([this, &pair, comp_data, unique_name, data_lib_def, &control_ret,
+            &options, status, &counter] {
+      auto cleanup = gtl::MakeCleanup([&counter] { counter.DecrementCount(); });
+      const string& target = pair.first;
+
+      const string& device_type =
+          device_set_.FindDeviceByName(target)->device_type();
+      Graph* subgraph = pair.second.get();
+
+      status->Update(UpdateArgAndRetvalMetadata(
+          subgraph, device_type, &comp_data->arg_indices_,
+          &comp_data->ret_indices_, &comp_data->arg_alloc_attrs_,
+          &comp_data->ret_alloc_attrs_));
+      if (!status->ok()) return;
+      FunctionDef shard;
+      status->Update(
+          GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
+      if (!status->ok()) return;
+      status->Update(data_lib_def->AddFunctionDef(shard));
+      FunctionLibraryRuntime::InstantiateOptions opts;
+      opts.executor_type = options.executor_type;
+      opts.target = target;
+      opts.lib_def = data_lib_def;
+      opts.create_kernels_eagerly = options.create_kernels_eagerly;
+      opts.state_handle = options.state_handle;
+      FunctionLibraryRuntime::Handle component_handle;
+
+      // TODO(fishx): introduce an async version of this Instantiate method.
+      status->Update(Instantiate(unique_name, AttrSlice(&shard.attr()), opts,
+                                 &component_handle));
+      if (!status->ok()) return;
+      VLOG(1) << "Instantiated component function " << unique_name
+              << " on device " << target << " with component handle "
+              << component_handle;
+      VLOG(2) << DebugString(shard);
+      comp_data->handle_ = component_handle;
+    });
+    i += 1;
   }
+  counter.Wait();
+  StatusGroup group;
+  for (auto& status : instantiate_status) {
+    group.Update(status);
+  }
+  TF_RETURN_IF_ERROR(group.as_summary_status());
 
   *handle = AddMultiDeviceHandle(std::move(data), function_key);
   VLOG(2) << "Instantiated MultiDevice function \"" << function_name
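Two details of the hunk above are easy to miss. First, errors no longer unwind the caller from inside the loop: each worker records failures into its own Status slot and exits with a bare return, and the main thread aggregates them after counter.Wait() via StatusGroup. Second, the gtl::MakeCleanup guard at the top of the worker lambda is what makes those early returns safe, since every exit path must still call DecrementCount() or Wait() would deadlock. A minimal scope-exit guard in the same spirit (a sketch, not the gtl implementation):

    #include <utility>

    // Runs `f` when the guard leaves scope -- on normal exit and on every
    // early return, mirroring what gtl::MakeCleanup provides in the diff.
    template <typename F>
    class ScopeExit {
     public:
      explicit ScopeExit(F f) : f_(std::move(f)) {}
      ~ScopeExit() { f_(); }

     private:
      F f_;
    };

    template <typename F>
    ScopeExit<F> MakeScopeExit(F f) {
      return ScopeExit<F>(std::move(f));
    }

Usage mirrors the first statement of the worker lambda: auto cleanup = MakeScopeExit([&counter] { counter.DecrementCount(); });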