Support packed tensor inputs in ProcessFunctionLibraryRuntime.

- Expand packed _Arg nodes once the graph is ready for graph partitioning.
- Introduce an optional sub-index for function _Arg nodes in order to distinguish between two arguments that share the same "index". This situation arises after a packed _Arg node assigned to a CompositeDevice is replaced with multiple replica nodes (one per underlying device).

The "index" of an _Arg node is unique before expanding it. It's also unique within each subgraph after graph partition.

PiperOrigin-RevId: 309781835
Change-Id: Ic6e351f45b7523288b5dae30997ddf0dae86660b
Author: Yujing Zhang, 2020-05-04 11:15:05 -07:00 (committed by TensorFlower Gardener)
parent 4158c029ef
commit 8fcb130e92
16 changed files with 356 additions and 125 deletions
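
As an overview of how the pieces below fit together, here is a condensed, hedged sketch of the caller-side wiring, following the MultiDevice_CompositeDevice test added to process_function_library_runtime_test.cc in this change (the pflr pointer and the device names are illustrative):

    // Create a CompositeDevice backed by two physical devices and register it
    // with the ProcessFunctionLibraryRuntime's device set.
    Status s;
    std::unique_ptr<CompositeDevice> composite_device =
        CompositeDevice::MakeDevice({"/device:CPU:0", "/device:CPU:1"},
                                    /*unique_device_id=*/0, &s);
    TF_ASSERT_OK(s);
    pflr->AddCompositeDevice(composite_device.get());

    // Tell InstantiateMultiDevice which physical devices the composite device
    // stands for, so the packed _Arg can be replicated before partitioning.
    FunctionLibraryRuntime::InstantiateOptions inst_opts;
    inst_opts.composite_devices[composite_device->name()] =
        composite_device->underlying_devices();

At Run() time, the packed tensors are supplied through a FunctionArgsInterface whose GetLocalArg understands (index, sub_index) pairs; TestFunctionPackedArgs in the test below is one such implementation.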


@@ -631,6 +631,7 @@ cc_library(
     ],
     copts = tf_copts(),
     deps = [
+        ":composite_device",
         ":device",
         ":device_mgr",
         ":device_set",
@@ -651,6 +652,7 @@ cc_library(
         ":process_util",
         ":rendezvous_mgr",
         ":rendezvous_util",
+        ":replicate_per_replica_nodes",
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
         "//tensorflow/core:lib",
@@ -658,6 +660,7 @@ cc_library(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/profiler/lib:traceme",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
    ],


@@ -52,9 +52,14 @@ Status ExecuteNodeArgs::Init(
 #if !defined(IS_MOBILE_PLATFORM)
   if (has_remote_inputs_) {
     serialize_remote_handle_ =
-        [ctx, &op_inputs](const int i,
-                          eager::RemoteTensorHandle* handle) -> Status {
-          VariantDevice variant_device = op_inputs[i]->device();
+        [ctx, &op_inputs](const FunctionArgIndex& index,
+                          eager::RemoteTensorHandle* handle) -> Status {
+          if (index.sub_index >= 0) {
+            return errors::InvalidArgument("Got unexpected sub_index ",
+                                           index.sub_index, " for argument ",
+                                           index.index);
+          }
+          VariantDevice variant_device = op_inputs[index.index]->device();
           if (VariantDeviceIsCustom(variant_device)) {
             return errors::Internal(
                 "Custom devices and remote execution are currently not supported "
@@ -62,7 +67,7 @@ Status ExecuteNodeArgs::Init(
           }
           Device* device = absl::get<Device*>(variant_device);
           return ctx->RemoteMgr()->SerializeRemoteTensorHandle(
-              op_inputs[i], handle, device, device->name());
+              op_inputs[index.index], handle, device, device->name());
         };
   }
 #endif  // !IS_MOBILE_PLATFORM


@@ -54,10 +54,12 @@ class ExecuteNodeArgs : public EagerKernelArgs {
                   const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
                   const core::RefCountPtr<KernelAndDevice>& kernel);

-  bool HasRemoteInputs() const override { return has_remote_inputs_; };
+  bool HasRemoteOrPackedInputs() const override {
+    return has_remote_inputs_ || has_packed_inputs_;
+  };

 #if !defined(IS_MOBILE_PLATFORM)
-  Status GetRemoteArg(const int index,
+  Status GetRemoteArg(const FunctionArgIndex& index,
                       eager::RemoteTensorHandle* val) const override {
     return serialize_remote_handle_(index, val);
   }
@@ -65,8 +67,9 @@ class ExecuteNodeArgs : public EagerKernelArgs {
  private:
   bool has_remote_inputs_ = false;
+  bool has_packed_inputs_ = false;
 #if !defined(IS_MOBILE_PLATFORM)
-  std::function<Status(const int, eager::RemoteTensorHandle*)>
+  std::function<Status(const FunctionArgIndex&, eager::RemoteTensorHandle*)>
       serialize_remote_handle_;
 #endif  // IS_MOBILE_PLATFORM
 };


@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/profiler/lib/annotated_traceme.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
@@ -49,13 +50,18 @@ limitations under the License.
 namespace tensorflow {

-Status EagerKernelArgs::GetLocalArg(const int index, Tensor* val) const {
-  Tensor* arg = tensor_args_.at(index).tensor;
+Status EagerKernelArgs::GetLocalArg(const FunctionArgIndex& index,
+                                    Tensor* val) const {
+  if (index.sub_index >= 0) {
+    return errors::InvalidArgument("Got unexpected sub_index ", index.sub_index,
+                                   " for argument ", index.index);
+  }
+  Tensor* arg = tensor_args_.at(index.index).tensor;
   if (arg) {
     *val = *arg;
     return Status::OK();
   } else {
-    return errors::NotFound("Argument ", index, " has no local tensor.");
+    return errors::NotFound("Argument ", index.index, " has no local tensor.");
   }
 }


@@ -66,16 +66,16 @@ class EagerKernelArgs : public FunctionArgsInterface {
   ~EagerKernelArgs() override{};

-  bool HasRemoteInputs() const override { return false; };
+  bool HasRemoteOrPackedInputs() const override { return false; };

   TensorValue* MutableInput(int i) { return &tensor_args_[i]; }

-  Status GetLocalArg(const int index, Tensor* val) const override;
+  Status GetLocalArg(const FunctionArgIndex& index, Tensor* val) const override;

   std::vector<Tensor> GetLocalTensors() const override;

-  const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const override {
+  const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const {
     return &tensor_args_;
-  };
+  }

  protected:
   gtl::InlinedVector<TensorValue, 4> tensor_args_;


@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/partitioning_utils.h"

 #include <algorithm>
+#include <utility>

 #include "tensorflow/core/common_runtime/graph_constructor.h"
 #include "tensorflow/core/framework/function.h"
@@ -73,11 +74,11 @@ Status PartitionFunctionGraph(
 }

 Status UpdateArgAndRetvalMetadata(
-    Graph* subgraph, const string& device_type, std::vector<int>* arg_indices,
-    std::vector<int>* ret_indices,
+    Graph* subgraph, const string& device_type,
+    std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
     std::vector<AllocatorAttributes>* arg_alloc_attrs,
     std::vector<AllocatorAttributes>* ret_alloc_attrs) {
-  std::vector<std::pair<Node*, int>> arg_nodes;
+  std::vector<std::pair<Node*, FunctionArgIndex>> arg_nodes;
   std::vector<std::pair<Node*, int>> ret_nodes;
   const AttrValue* attr_value;
@@ -87,7 +88,11 @@ Status UpdateArgAndRetvalMetadata(
     if (node->IsArg()) {
       TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
       int index = static_cast<int>(attr_value->i());
-      arg_nodes.emplace_back(node, index);
+      int sub_index = -1;
+      if (node->attrs().Find("sub_index", &attr_value).ok()) {
+        sub_index = static_cast<int>(attr_value->i());
+      }
+      arg_nodes.emplace_back(node, FunctionArgIndex(index, sub_index));
     } else if (node->IsRetval()) {
       TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
       int index = static_cast<int>(attr_value->i());
@@ -99,11 +104,16 @@ Status UpdateArgAndRetvalMetadata(
   //
   // In particular, this enables calling a single-partition function with
   // the same signature as the original unpartitioned function.
-  auto comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) {
-    return a.second < b.second;
-  };
-  std::sort(arg_nodes.begin(), arg_nodes.end(), comparator);
-  std::sort(ret_nodes.begin(), ret_nodes.end(), comparator);
+  auto arg_comparator = [](std::pair<Node*, FunctionArgIndex> a,
+                           std::pair<Node*, FunctionArgIndex> b) {
+    return std::tie(a.second.index, a.second.sub_index) <
+           std::tie(b.second.index, b.second.sub_index);
+  };
+  std::sort(arg_nodes.begin(), arg_nodes.end(), arg_comparator);
+  auto ret_comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) {
+    return a.second < b.second;
+  };
+  std::sort(ret_nodes.begin(), ret_nodes.end(), ret_comparator);

   arg_indices->reserve(arg_nodes.size());
   for (const auto& pair : arg_nodes) arg_indices->push_back(pair.second);
@@ -144,16 +154,6 @@ Status UpdateArgAndRetvalMetadata(
   return Status::OK();
 }

-std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
-                                      gtl::ArraySlice<Tensor> arguments) {
-  std::vector<Tensor> args;
-  args.reserve(indices.size());
-  for (int i : indices) {
-    args.push_back(arguments[i]);
-  }
-  return args;
-}
-
 string FunctionNameGenerator::GetName() {
   while (true) {
     const string candidate = strings::StrCat(name_, "_", counter_++);


@@ -58,15 +58,11 @@ Status PartitionFunctionGraph(
 // (3) records which `Arg` and `Retval` nodes live in host memory in
 //     `*_alloc_attrs`.
 Status UpdateArgAndRetvalMetadata(
-    Graph* subgraph, const string& device_type, std::vector<int>* arg_indices,
-    std::vector<int>* ret_indices,
+    Graph* subgraph, const string& device_type,
+    std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
     std::vector<AllocatorAttributes>* arg_alloc_attrs,
     std::vector<AllocatorAttributes>* ret_alloc_attrs);

-// Extracts tensors at `indices` from `arguments`.
-std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
-                                      gtl::ArraySlice<Tensor> arguments);
-
 // Utility for generating function names not present in `flib_def`, using
 // given `name` as the base for the name.
 class FunctionNameGenerator {


@@ -158,14 +158,23 @@ TEST_F(PartitioningUtilsTest, TwoDevices) {
   ASSERT_EQ(3, part2->num_op_nodes());
 }

-void CheckIndices(const std::vector<int>& expected,
+void CheckRetIndices(const std::vector<int>& expected,
                   const std::vector<int>& actual) {
   ASSERT_EQ(expected.size(), actual.size());
   for (int i = 0; i < expected.size(); ++i) {
     ASSERT_EQ(expected[i], actual[i]) << " at index " << i;
   }
 }

+void CheckArgIndices(const std::vector<FunctionArgIndex>& expected,
+                     const std::vector<FunctionArgIndex>& actual) {
+  ASSERT_EQ(expected.size(), actual.size());
+  for (int i = 0; i < expected.size(); ++i) {
+    ASSERT_EQ(expected[i].index, actual[i].index) << " at index " << i;
+    ASSERT_EQ(expected[i].sub_index, actual[i].sub_index) << " at index " << i;
+  }
+}
+
 void CheckAlloc(const std::vector<bool>& expected,
                 const std::vector<AllocatorAttributes>& actual) {
   ASSERT_EQ(expected.size(), actual.size());
@@ -185,7 +194,7 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) {
   auto graph = absl::make_unique<Graph>(OpRegistry::Global());
   SubGraph(graph.get(), DT_FLOAT, {3}, {5});

-  std::vector<int> arg_indices;
+  std::vector<FunctionArgIndex> arg_indices;
   std::vector<int> ret_indices;
   std::vector<AllocatorAttributes> arg_alloc_attrs;
   std::vector<AllocatorAttributes> ret_alloc_attrs;
@@ -197,8 +206,8 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) {
                                       &ret_alloc_attrs);
   ASSERT_TRUE(status.ok()) << status.ToString();

-  CheckIndices({3}, arg_indices);
-  CheckIndices({5}, ret_indices);
+  CheckArgIndices({{3, -1}}, arg_indices);
+  CheckRetIndices({5}, ret_indices);
   CheckAlloc({false}, arg_alloc_attrs);
   CheckAlloc({false}, ret_alloc_attrs);
@@ -213,7 +222,18 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets_Order) {
   auto graph = absl::make_unique<Graph>(OpRegistry::Global());
   SubGraph(graph.get(), DT_FLOAT, {9, 7, 5, 3, 1}, {2, 4, 6, 8, 10});

-  std::vector<int> arg_indices;
+  const std::map<int, int> sub_indices = {
+      {7, 2}, {3, 1}, {1, 0}, {5, 2}, {9, 0}};
+  const AttrValue* attr_value;
+  for (Node* n : graph->op_nodes()) {
+    if (n->IsArg()) {
+      TF_ASSERT_OK(n->attrs().Find("index", &attr_value));
+      n->AddAttr("sub_index",
+                 sub_indices.at(static_cast<int>(attr_value->i())));
+    }
+  }
+
+  std::vector<FunctionArgIndex> arg_indices;
   std::vector<int> ret_indices;
   std::vector<AllocatorAttributes> arg_alloc_attrs;
   std::vector<AllocatorAttributes> ret_alloc_attrs;
@@ -225,8 +245,8 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets_Order) {
                                       &ret_alloc_attrs);
   ASSERT_TRUE(status.ok()) << status.ToString();

-  CheckIndices({1, 3, 5, 7, 9}, arg_indices);
-  CheckIndices({2, 4, 6, 8, 10}, ret_indices);
+  CheckArgIndices({{1, 0}, {3, 1}, {5, 2}, {7, 2}, {9, 0}}, arg_indices);
+  CheckRetIndices({2, 4, 6, 8, 10}, ret_indices);
   CheckAlloc({false, false, false, false, false}, arg_alloc_attrs);
   CheckAlloc({false, false, false, false, false}, ret_alloc_attrs);
 }


@@ -17,6 +17,7 @@ limitations under the License.
 #include <iterator>
 #include <utility>

+#include "absl/container/flat_hash_map.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/str_join.h"
 #include "tensorflow/core/common_runtime/device_set.h"
@@ -29,6 +30,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #include "tensorflow/core/common_runtime/rendezvous_util.h"
+#include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph_to_functiondef.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -301,7 +303,7 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice(
     // Replace the given handle with the handle for the single component
     // function.
-    handle = component_data.handle_;
+    handle = component_data.handle;
   }

   auto iter = function_data_.find(handle);
@@ -777,6 +779,14 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
       OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));

+  // Expand the nodes assigned to a CompositeDevice before graph partition to
+  // avoid generating a subgraph on a virtual device for execution.
+  // This transformation should happen as late as possible, in order to run as
+  // more graph optimization passes (e.g. PRE_PLACEMENT, PLACER,
+  // POST_PLACEMENT, POST_REWRITE_FOR_EXEC) on a smaller graph as possible.
+  TF_RETURN_IF_ERROR(ReplicatePerReplicaNodesInFunctionGraph(
+      options.composite_devices, graph.get()));
+
   if (options.graph_collector != nullptr) {
     GraphDef def;
     graph->ToGraphDef(&def);
@@ -869,9 +879,9 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
       Graph* subgraph = pair.second.get();

       status->Update(UpdateArgAndRetvalMetadata(
-          subgraph, device_type, &comp_data->arg_indices_,
-          &comp_data->ret_indices_, &comp_data->arg_alloc_attrs_,
-          &comp_data->ret_alloc_attrs_));
+          subgraph, device_type, &comp_data->arg_indices,
+          &comp_data->ret_indices, &comp_data->arg_alloc_attrs,
+          &comp_data->ret_alloc_attrs));
       if (!status->ok()) {
         counter.DecrementCount();
         return;
@@ -913,7 +923,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
           data->is_cross_process_ = true;
         }
       }
-      comp_data->handle_ = *component_handle;
+      comp_data->handle = *component_handle;
     }
     delete component_handle;
     counter.DecrementCount();
@@ -955,16 +965,16 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
   for (const auto& pair : data->glue_) {
     const ComponentFunctionData& comp_data = pair.second;
-    DCHECK(comp_data.ret_alloc_attrs_.size() == comp_data.ret_indices_.size());
+    DCHECK(comp_data.ret_alloc_attrs.size() == comp_data.ret_indices.size());
     const string& target = pair.first;

     FunctionLibraryRuntime* target_flr = GetFLR(target);
     if (target_flr == nullptr) {
-      if (!comp_data.ret_indices_.empty()) {
+      if (!comp_data.ret_indices.empty()) {
         return errors::Unimplemented(
             "Currently, outputting tensors on remote devices is not supported. "
             "The ",
-            comp_data.ret_indices_[0],
+            comp_data.ret_indices[0],
             "-th return value of the function outputs to target_device: ",
             target,
             " Please copy the tensor to local device explicitly using "
@@ -973,17 +983,17 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
       continue;
     }
     Device* target_device = target_flr->device();
-    const FunctionBody* fbody = target_flr->GetFunctionBody(comp_data.handle_);
+    const FunctionBody* fbody = target_flr->GetFunctionBody(comp_data.handle);
     DCHECK(fbody != nullptr);

     output_devices->resize(data->num_outputs_);
-    for (int j = 0; j < comp_data.ret_indices_.size(); ++j) {
-      int ret_index = comp_data.ret_indices_[j];
+    for (int j = 0; j < comp_data.ret_indices.size(); ++j) {
+      int ret_index = comp_data.ret_indices[j];
       if (fbody->ret_types[j] == DT_RESOURCE) {
         (*output_devices)[ret_index] = target_device;
       } else {
         (*output_devices)[ret_index] =
-            comp_data.ret_alloc_attrs_[j].on_host() ? nullptr : target_device;
+            comp_data.ret_alloc_attrs[j].on_host() ? nullptr : target_device;
       }
     }
   }
@@ -1013,9 +1023,8 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
   const MultiDeviceFunctionData* data = IsMultiDevice(handle);
   if (data == nullptr) {
-    done(
-        errors::InvalidArgument("Failed for find multi-device function handle ",
-                                handle, ". Was the function instantiated?"));
+    done(errors::NotFound("Multi-device function handle ", handle,
+                          "not found. Was the function instantiated?"));
     return;
   }
@@ -1046,10 +1055,10 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
   for (const auto& pair : data->glue_) {
     const string& target = pair.first;
     const ComponentFunctionData& comp_data = pair.second;
-    FunctionLibraryRuntime::Handle handle = pair.second.handle_;
+    FunctionLibraryRuntime::Handle handle = pair.second.handle;

-    opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs_;
-    opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs_;
+    opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs;
+    opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs;
     opts_copy.remote_execution = false;

     InternalArgs comp_args;
@@ -1086,7 +1095,7 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
                 Status(status.code(), function_and_msg));
           } else {
             for (int i = 0; i < comp_rets->size(); ++i) {
-              (*rets)[comp_data.ret_indices_[i]] = (*comp_rets)[i];
+              (*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i];
             }
           }
           delete comp_rets;
@@ -1108,7 +1117,7 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
             refcounted_done->UpdateStatus(status);
           } else {
             for (int i = 0; i < comp_rets->size(); ++i) {
-              (*rets)[comp_data.ret_indices_[i]] = (*comp_rets)[i];
+              (*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i];
             }
           }
           delete comp_rets;
@@ -1225,7 +1234,7 @@ Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle(
   Status overall_status;
   for (const auto& it : mdata->glue_) {
     const string& device = it.first;
-    FunctionLibraryRuntime::Handle flr_handle = it.second.handle_;
+    FunctionLibraryRuntime::Handle flr_handle = it.second.handle;
     FunctionLibraryRuntime* flr = GetFLR(device);
     if (flr == nullptr) {
       // TODO(nareshmodi): Implement DeregisterGraph call to remote device if
@@ -1297,6 +1306,19 @@ ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback(
   };
 }

+Status ProcessFunctionLibraryRuntime::CreateRendezvous(
+    const FunctionLibraryRuntime::Options& opts,
+    Rendezvous** created_rendezvous) const {
+  if (rendezvous_factory_) {
+    return rendezvous_factory_(opts.step_id, device_mgr_, created_rendezvous);
+  } else {
+    return errors::FailedPrecondition(
+        "The caller does not provide a rendezvous and "
+        "ProcessFunctionLibraryRuntime was created without a rendezvous "
+        "factory.");
+  }
+}
+
 void ProcessFunctionLibraryRuntime::Run(
     const FunctionLibraryRuntime::Options& opts,
     FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
@@ -1305,21 +1327,12 @@ void ProcessFunctionLibraryRuntime::Run(
   FunctionLibraryRuntime::Options new_opts = opts;
   Rendezvous* created_rendezvous = nullptr;
   if (!opts.rendezvous) {
-    if (rendezvous_factory_) {
-      Status s =
-          rendezvous_factory_(opts.step_id, device_mgr_, &created_rendezvous);
-      if (!s.ok()) {
-        done(s);
-        return;
-      }
-      new_opts.rendezvous = created_rendezvous;
-    } else {
-      done(
-          errors::FailedPrecondition("The caller does not provide a rendezvous "
-                                     "and ProcessFunctionLibraryRuntime was "
-                                     "created without a rendezvous factory."));
+    Status s = CreateRendezvous(opts, &created_rendezvous);
+    if (!s.ok()) {
+      done(s);
       return;
     }
+    new_opts.rendezvous = created_rendezvous;
     new_opts.create_rendezvous = false;
   }
@@ -1334,9 +1347,14 @@ void ProcessFunctionLibraryRuntime::Run(
   if (multi_device) {
     auto get_component_args = [&args](const ComponentFunctionData& comp_data,
                                       InternalArgs* comp_args) -> Status {
-      for (const auto& tensor :
-           GetArgsForIndices(comp_data.arg_indices_, args)) {
-        comp_args->args.push_back(tensor);
+      // "Index"s of _Arg nodes are unique when all arguments are local Tensors.
+      for (const auto& it : comp_data.arg_indices) {
+        if (it.sub_index >= 0) {
+          return errors::InvalidArgument("Got unexpected sub_index ",
+                                         it.sub_index, " for argument ",
+                                         it.index);
+        }
+        comp_args->args.push_back(args[it.index]);
       }
       return Status::OK();
     };
@@ -1520,11 +1538,23 @@ void ProcessFunctionLibraryRuntime::Run(
     FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
     std::vector<Tensor>* rets,
     FunctionLibraryRuntime::DoneCallback done) const {
-  if (!args.HasRemoteInputs()) {
+  if (!args.HasRemoteOrPackedInputs()) {
     const std::vector<Tensor> local_inputs = args.GetLocalTensors();
     return Run(opts, handle, local_inputs, rets, std::move(done));
   }

+  FunctionLibraryRuntime::Options new_opts = opts;
+  Rendezvous* created_rendezvous = nullptr;
+  if (!opts.rendezvous) {
+    Status s = CreateRendezvous(opts, &created_rendezvous);
+    if (!s.ok()) {
+      done(s);
+      return;
+    }
+    new_opts.rendezvous = created_rendezvous;
+    new_opts.create_rendezvous = false;
+  }
+
 #if defined(IS_MOBILE_PLATFORM)
   done(errors::Unimplemented(
       "Remote inputs are not available on mobile devices."));
@@ -1532,12 +1562,12 @@ void ProcessFunctionLibraryRuntime::Run(
 #else   // !IS_MOBILE_PLATFORM
   auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
   done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id,
-                                    /*rendezvous=*/nullptr);
+                                    created_rendezvous);
   auto get_component_args = [&args](const ComponentFunctionData& comp_data,
                                     InternalArgs* comp_args) -> Status {
-    for (int i = 0; i < comp_data.arg_indices_.size(); ++i) {
-      const int index = comp_data.arg_indices_.at(i);
+    for (int i = 0; i < comp_data.arg_indices.size(); ++i) {
+      const FunctionArgIndex index = comp_data.arg_indices.at(i);
       Tensor tensor;
       if (args.GetLocalArg(index, &tensor).ok()) {
         comp_args->args.push_back(std::move(tensor));
@@ -1552,7 +1582,7 @@ void ProcessFunctionLibraryRuntime::Run(
       }
     }
     return Status::OK();
   };
-  return RunMultiDevice(opts, handle, rets, cleanup_items, std::move(done),
+  return RunMultiDevice(new_opts, handle, rets, cleanup_items, std::move(done),
                         std::move(get_component_args));
 #endif  // !IS_MOBILE_PLATFORM


@@ -24,6 +24,7 @@ limitations under the License.
 #include "absl/types/optional.h"
 #include "absl/types/variant.h"
+#include "tensorflow/core/common_runtime/composite_device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/device_set.h"
 #include "tensorflow/core/framework/function.h"
@@ -40,16 +41,15 @@ class FunctionArgsInterface {
  public:
   virtual ~FunctionArgsInterface() {}

-  virtual bool HasRemoteInputs() const = 0;
+  virtual bool HasRemoteOrPackedInputs() const = 0;

-  virtual Status GetLocalArg(const int index, Tensor* val) const = 0;
+  virtual Status GetLocalArg(const FunctionArgIndex& index,
+                             Tensor* val) const = 0;

   virtual std::vector<Tensor> GetLocalTensors() const = 0;

-  virtual const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const = 0;
-
 #if !defined(IS_MOBILE_PLATFORM)
-  virtual Status GetRemoteArg(const int index,
+  virtual Status GetRemoteArg(const FunctionArgIndex& index,
                               eager::RemoteTensorHandle* val) const {
     return errors::Unimplemented(
         "Serializing a remote argument is not implemented.");
@@ -217,6 +217,12 @@ class ProcessFunctionLibraryRuntime {
     return lib_def_;
   }

+  // Add a CompositeDevice to `device_set_`
+  void AddCompositeDevice(CompositeDevice* d) TF_LOCKS_EXCLUDED(mu_) {
+    mutex_lock l(mu_);
+    device_set_->AddDevice(d);
+  }
+
  protected:
   friend class FunctionLibraryRuntimeImpl;
@@ -232,21 +238,21 @@ class ProcessFunctionLibraryRuntime {
   // piece of a multi-device function) fits into the multi-device function.
   struct ComponentFunctionData {
     // The handle for the instantiated component function.
-    FunctionLibraryRuntime::Handle handle_;
-    // arg_indices_.size() is the number of arguments to the component function.
+    FunctionLibraryRuntime::Handle handle;
+    // arg_indices.size() is the number of arguments to the component function.
     // The i-th argument of the component function comes from the
-    // `arg_indices_[i]`-th argument of the multi-device function.
-    std::vector<int> arg_indices_;
-    // ret_indices_.size() is the number of return values of the component
+    // `arg_indices[i]`-th argument of the multi-device function.
+    std::vector<FunctionArgIndex> arg_indices;
+    // ret_indices.size() is the number of return values of the component
     // function. The i-th return value of the component function goes to the
-    // `ret_indices_[i]`-th return value of the multi-device function.
-    std::vector<int> ret_indices_;
-    // arg_alloc_attrs_[i] are the allocator attributes of the i-th argument to
+    // `ret_indices[i]`-th return value of the multi-device function.
+    std::vector<int> ret_indices;
+    // arg_alloc_attrs[i] are the allocator attributes of the i-th argument to
     // the component function.
-    std::vector<AllocatorAttributes> arg_alloc_attrs_;
-    // ret_alloc_attrs_[i] are the allocator attributes of the i-th return value
+    std::vector<AllocatorAttributes> arg_alloc_attrs;
+    // ret_alloc_attrs[i] are the allocator attributes of the i-th return value
     // of the component function.
-    std::vector<AllocatorAttributes> ret_alloc_attrs_;
+    std::vector<AllocatorAttributes> ret_alloc_attrs;
   };

   // Data structure holding information for a single instantiated multi-device
@@ -304,6 +310,9 @@ class ProcessFunctionLibraryRuntime {
                           InternalArgs* args)>
           get_component_args) const;

+  Status CreateRendezvous(const FunctionLibraryRuntime::Options& opts,
+                          Rendezvous** created_rendezvous) const;
+
   FunctionLibraryRuntime::DoneCallback ApplyCleanUpToDoneCallback(
       std::vector<std::unique_ptr<CleanUpItem>>* items,
       FunctionLibraryRuntime::DoneCallback done, const int64 step_id,


@@ -18,6 +18,7 @@ limitations under the License.
 #include <unordered_map>
 #include <vector>

+#include "tensorflow/core/common_runtime/composite_device.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/function_testlib.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
@@ -143,6 +144,10 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
                              }}));
   }

+  void AddCompositeDevice(CompositeDevice* d) {
+    proc_flr_->AddCompositeDevice(d);
+  }
+
   Status Instantiate(
       const string& name, test::function::Attrs attrs,
       const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
@@ -187,11 +192,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   }

+  template <typename T>
   Status RunWithRuntime(
       const string& name, FunctionLibraryRuntime::Options opts,
       test::function::Attrs attrs,
       const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
-      const std::vector<Tensor>& args, std::vector<Tensor*> rets,
+      const T& args, std::vector<Tensor*> rets,
       ProcessFunctionLibraryRuntime* pflr) {
     FunctionLibraryRuntime::Handle handle;
     Status status = pflr->Instantiate(name, attrs, instantiate_opts, &handle);
@@ -248,9 +254,20 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
   Status Run(const string& name, FunctionLibraryRuntime::Options opts,
              test::function::Attrs attrs,
              const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
-             const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
-    return RunWithRuntime(name, opts, attrs, instantiate_opts, args, rets,
-                          proc_flr_.get());
+             const std::vector<Tensor>& args, std::vector<Tensor*> rets,
+             ProcessFunctionLibraryRuntime* pflr = nullptr) {
+    return RunWithRuntime<std::vector<Tensor>>(
+        name, opts, attrs, instantiate_opts, args, rets, proc_flr_.get());
+  }
+
+  Status RunWithPackedArgs(
+      const string& name, FunctionLibraryRuntime::Options opts,
+      test::function::Attrs attrs,
+      const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
+      const FunctionArgsInterface& args, std::vector<Tensor*> rets,
+      ProcessFunctionLibraryRuntime* pflr = nullptr) {
+    return RunWithRuntime<FunctionArgsInterface>(
+        name, opts, attrs, instantiate_opts, args, rets, proc_flr_.get());
   }

   Status RunInstantiated(FunctionLibraryRuntime::Handle handle,
@@ -719,6 +736,112 @@ Tensor GetResourceHandle(const string& var_name, const string& container,
   return tensor;
 }

+// Returns a function which adds two variables on different devices.
+FunctionDef AddVarAcrossDevices() {
+  return FunctionDefHelper::Create(
+      // Name
+      "AddVarAcrossDevices",
+      // Args
+      {"x: resource"},
+      // Return values
+      {"y: float"},
+      // Attr def
+      {},
+      // Nodes
+      {
+          {{"read0"},
+           "ReadVariableOp",
+           {"x"},
+           {{"dtype", DT_FLOAT}},
+           {},
+           "/device:CPU:0"},
+          {{"read1"},
+           "ReadVariableOp",
+           {"x"},
+           {{"dtype", DT_FLOAT}},
+           {},
+           "/device:CPU:1"},
+          {{"add"},
+           "Add",
+           {"read0:value:0", "read1:value:0"},
+           {{"T", DT_FLOAT}},
+           {},
+           "/device:CPU:0"},
+      },
+      {{"y", "add:z:0"}});
+}

+// An implementation of FunctionArgsInterface for packed inputs.
+class TestFunctionPackedArgs : public FunctionArgsInterface {
+ public:
+  TestFunctionPackedArgs(const int index,
+                         gtl::InlinedVector<TensorValue, 4>&& tensor_args) {
+    packed_args_.emplace(index, std::move(tensor_args));
+  }

+  ~TestFunctionPackedArgs() override{};

+  bool HasRemoteOrPackedInputs() const override { return true; };

+  Status GetLocalArg(const FunctionArgIndex& index,
+                     Tensor* val) const override {
+    *val = *packed_args_.at(index.index).at(index.sub_index).tensor;
+    return Status::OK();
+  };

+  std::vector<Tensor> GetLocalTensors() const override { return {}; }

+ private:
+  absl::flat_hash_map<int, gtl::InlinedVector<TensorValue, 4>> packed_args_;
+};

+TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) {
+  Init({AddVarAcrossDevices()});
+  // Create two variables on two devices.
+  const Tensor initial_resource_value0 = test::AsTensor<float>({10, 20});
+  Var* resource0 = new Var(DT_FLOAT);
+  *resource0->tensor() = initial_resource_value0;
+  resource0->is_initialized = true;
+  const Tensor initial_resource_value1 = test::AsTensor<float>({30, 40});
+  Var* resource1 = new Var(DT_FLOAT);
+  *resource1->tensor() = initial_resource_value1;
+  resource1->is_initialized = true;
+  ResourceMgr* mgr0 = device0_->resource_manager();
+  ResourceMgr* mgr1 = device1_->resource_manager();
+  TF_ASSERT_OK(mgr0->Create(mgr0->default_container(), "var", resource0));
+  TF_ASSERT_OK(mgr1->Create(mgr1->default_container(), "var", resource1));

+  Tensor resource_handle0 =
+      GetResourceHandle("var", mgr0->default_container(), device0_->name());
+  Tensor resource_handle1 =
+      GetResourceHandle("var", mgr1->default_container(), device1_->name());

+  // Create a CompositeDevice
+  Status s;
+  std::unique_ptr<CompositeDevice> composite_device =
+      CompositeDevice::MakeDevice({device0_->name(), device1_->name()},
+                                  /*unique_device_id=*/0, &s);
+  TF_ASSERT_OK(s);
+  AddCompositeDevice(composite_device.get());

+  FunctionLibraryRuntime::Options opts;
+  FunctionLibraryRuntime::InstantiateOptions inst_opts =
+      MakeOptions("CPU:0", {"COMPOSITE:0"}, {"CPU:0"});
+  inst_opts.composite_devices[composite_device->name()] =
+      composite_device->underlying_devices();
+  inst_opts.input_resource_dtypes_and_shapes[0] = {
+      initial_resource_value0.dtype(), initial_resource_value0.shape()};

+  gtl::InlinedVector<TensorValue, 4> handles;
+  handles.push_back(TensorValue(&resource_handle0));
+  handles.push_back(TensorValue(&resource_handle1));
+  TestFunctionPackedArgs args(0, std::move(handles));

+  Tensor ret;
+  TF_CHECK_OK(RunWithPackedArgs("AddVarAcrossDevices", opts, {{"T", DT_FLOAT}},
+                                inst_opts, args, {&ret}));
+  test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60}));
+}

 TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) {
   if (gpu_device_ == nullptr) {
     GTEST_SKIP() << "No GPUs available";
@@ -1025,9 +1148,9 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) {
   instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
   const auto x = test::AsTensor<int64>({17});
   Tensor y;
-  TF_CHECK_OK(RunWithRuntime("SessionMetadataReaderFn", opts, {},
-                             instantiate_opts, {x}, {&y},
-                             cloned_proc_flr.get()));
+  TF_CHECK_OK(RunWithRuntime<std::vector<Tensor>>(
+      "SessionMetadataReaderFn", opts, {}, instantiate_opts, {x}, {&y},
+      cloned_proc_flr.get()));
   SessionMetadata read_metadata;
   ASSERT_TRUE(protobuf::TextFormat::ParseFromString(y.scalar<tstring>()(),
                                                     &read_metadata));


@@ -42,6 +42,7 @@ class ReplicateHelper {
       Node* replicated_node = graph->AddNode(node_def, &status);
       TF_RETURN_IF_ERROR(status);
       replicated_node->set_assigned_device_name(device);
+      replicated_node->AddAttr("sub_index", i);
       replicated_nodes[i] = replicated_node;
     }
     replicated_nodes_map_.emplace(node, std::move(replicated_nodes));
@@ -180,7 +181,8 @@ Status ReplicateEdges(const ReplicateHelper& helper,
 }  // namespace

 Status ReplicatePerReplicaNodesInFunctionGraph(
-    const absl::flat_hash_map<string, std::vector<string>>& composite_devices,
+    const absl::flat_hash_map<string, const std::vector<string>*>&
+        composite_devices,
     Graph* graph) {
   std::set<string> composite_device_names;
   for (const auto& it : composite_devices) {
@@ -198,7 +200,8 @@ Status ReplicatePerReplicaNodesInFunctionGraph(
   }

   for (const auto& it : composite_device_to_cluster_nodes) {
-    const std::vector<string>& allowed_devices = composite_devices.at(it.first);
+    const std::vector<string>& allowed_devices =
+        *composite_devices.at(it.first);
     if (allowed_devices.empty()) {
       return errors::InvalidArgument("No allowed device of composite device: ",
                                      it.first);
@@ -208,6 +211,7 @@ Status ReplicatePerReplicaNodesInFunctionGraph(
       // Reuse the original nodes if there is only one allowed device.
       for (Node* n : cluster_nodes) {
         n->set_assigned_device_name(allowed_devices.at(0));
+        n->AddAttr("sub_index", 0);
       }
       continue;
     }


@@ -35,7 +35,8 @@ namespace tensorflow {
 // dependency.
 // TODO(b/145922293): Register it as a POST_REWRITE_FOR_EXEC pass.
 Status ReplicatePerReplicaNodesInFunctionGraph(
-    const absl::flat_hash_map<string, std::vector<string>>& composite_devices,
+    const absl::flat_hash_map<string, const std::vector<string>*>&
+        composite_devices,
     Graph* graph);

 }  // namespace tensorflow


@@ -75,8 +75,9 @@ TEST(ReplicatePerReplicaNodesTest, SingleCompositeDevice) {
   auto ret = ops::_Retval(
       scope.WithOpName("ret").WithControlDependencies({write}), read, 0);

-  const absl::flat_hash_map<string, std::vector<string>> composite_devices = {
-      {"TPU_COMPOSITE:0", {"TPU:0", "TPU:1"}}};
+  const std::vector<string> underlying_devices = {"TPU:0", "TPU:1"};
+  const absl::flat_hash_map<string, const std::vector<string>*>
+      composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};

   Graph graph(OpRegistry::Global());
   TF_ASSERT_OK(scope.ToGraph(&graph));
@@ -118,8 +119,9 @@ TEST(ReplicatePerReplicaNodesTest, SingleCompositeDeviceToSingleDevice) {
   auto read = ops::ReadVariableOp(scope.WithOpName("read"), arg, DT_INT32);
   auto ret = ops::_Retval(scope.WithOpName("ret"), read, 0);

-  const absl::flat_hash_map<string, std::vector<string>> composite_devices = {
-      {"TPU_COMPOSITE:0", {"TPU:0"}}};
+  const std::vector<string> underlying_devices = {"TPU:0"};
+  const absl::flat_hash_map<string, const std::vector<string>*>
+      composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};

   Graph graph(OpRegistry::Global());
   TF_ASSERT_OK(scope.ToGraph(&graph));
@@ -156,9 +158,11 @@ TEST(ReplicatePerReplicaNodesTest, MultipleCompositeDevices) {
   auto add = ops::Add(scope.WithOpName("add"), identity0, identity1);
   auto ret = ops::_Retval(scope.WithOpName("ret"), add, 0);

-  const absl::flat_hash_map<string, std::vector<string>> composite_devices = {
-      {"TPU_COMPOSITE:0", {"TPU:0", "TPU:1"}},
-      {"TPU_COMPOSITE:1", {"TPU:2", "TPU:3"}}};
+  const std::vector<string> underlying_devices_0 = {"TPU:0", "TPU:1"};
+  const std::vector<string> underlying_devices_1 = {"TPU:2", "TPU:3"};
+  const absl::flat_hash_map<string, const std::vector<string>*>
+      composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices_0},
+                           {"TPU_COMPOSITE:1", &underlying_devices_1}};

   Graph graph(OpRegistry::Global());
   TF_ASSERT_OK(scope.ToGraph(&graph));
@@ -204,8 +208,9 @@ TEST(ReplicatePerReplicaNodesTest, MultipleCompositeDevices) {
 }

 TEST(ReplicatePerReplicaNodesTest, NestedFunctions) {
-  const absl::flat_hash_map<string, std::vector<string>> composite_devices = {
-      {"TPU_COMPOSITE:0", {"TPU:0", "TPU:1"}}};
+  const std::vector<string> underlying_devices = {"TPU:0", "TPU:1"};
+  const absl::flat_hash_map<string, const std::vector<string>*>
+      composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};

   FunctionDefLibrary fdef_lib;
   FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);


@@ -500,11 +500,11 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
         : EagerKernelArgs(std::move(tensor_args)),
           serialize_remote_handle_(std::move(serialize_remote_handle)) {}

-    bool HasRemoteInputs() const override { return true; }
+    bool HasRemoteOrPackedInputs() const override { return true; }

-    Status GetRemoteArg(const int index,
+    Status GetRemoteArg(const FunctionArgIndex& index,
                         eager::RemoteTensorHandle* val) const override {
-      return serialize_remote_handle_(index, val);
+      return serialize_remote_handle_(index.index, val);
     }

    private:
@@ -562,7 +562,14 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
     eager_pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
         remote_device_mgr_.get(), Env::Default(), /*config=*/
         nullptr, TF_GRAPH_DEF_VERSION, &func_lib_def_, OptimizerOptions(),
-        /*thread_pool=*/nullptr, eager_cluster_flr_.get());
+        /*thread_pool=*/nullptr, eager_cluster_flr_.get(),
+        /*custom_kernel_creator=*/nullptr, /*session_metadata=*/nullptr,
+        Rendezvous::Factory{[this](const int64 step_id,
+                                   const DeviceMgr* device_mgr,
+                                   Rendezvous** r) {
+          *r = worker_env_.rendezvous_mgr->Find(step_id);
+          return Status::OK();
+        }});
   }

   void CheckOutputTensorAndClose(const Tensor& tensor) {


@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/platform/platform.h"
 // clang-format on

+#include "absl/container/flat_hash_map.h"
 #include "absl/types/optional.h"
 #include "absl/types/variant.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
@@ -525,6 +526,20 @@ class Device;
 // Forward declare. Defined in common_runtime/device_mgr.h
 class DeviceMgr;

+// Index of an _Arg node.
+struct FunctionArgIndex {
+  explicit FunctionArgIndex(const int index) : index(index) {}
+  FunctionArgIndex(const int index, const int sub_index)
+      : index(index), sub_index(sub_index) {}
+
+  // The value of the attribute "Index" of the _Arg node.
+  int index;
+  // Set only when the _Arg node represents multiple arguments (e.g. an _Arg
+  // node is replicated to multiple devices/subgraphs). Use sub-index to
+  // distinguish arguments with the same index.
+  int sub_index = -1;
+};
+
 class FunctionLibraryRuntime {
  public:
   virtual ~FunctionLibraryRuntime() {}
@@ -576,6 +591,10 @@ class FunctionLibraryRuntime {
     // infer correct device.
     std::vector<string> output_devices;

+    // Maps from a CompositeDevice name to a list of underlying physical
+    // devices.
+    absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
+
     // This interface is EXPERIMENTAL and subject to change.
     //
     // For multi-device functions, a mapping from _Arg node index to type and