Support packed tensor inputs in ProcessFunctionLibraryRuntime.

- Expand the packed _Arg nodes when the graph is ready for graph partition.
- Introduce an optional sub-index to function Arg nodes, in order to distinguish between two arguments with the same "index". It happens after replacing a packed _Arg node which is assigned to a CompositeDevice with multiple replica nodes (one per device).

The "index" of an _Arg node is unique before expanding it. It's also unique within each subgraph after graph partition.

PiperOrigin-RevId: 309781835
Change-Id: Ic6e351f45b7523288b5dae30997ddf0dae86660b
This commit is contained in:
Yujing Zhang 2020-05-04 11:15:05 -07:00 committed by TensorFlower Gardener
parent 4158c029ef
commit 8fcb130e92
16 changed files with 356 additions and 125 deletions

View File

@ -631,6 +631,7 @@ cc_library(
],
copts = tf_copts(),
deps = [
":composite_device",
":device",
":device_mgr",
":device_set",
@ -651,6 +652,7 @@ cc_library(
":process_util",
":rendezvous_mgr",
":rendezvous_util",
":replicate_per_replica_nodes",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
@ -658,6 +660,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],

View File

@ -52,9 +52,14 @@ Status ExecuteNodeArgs::Init(
#if !defined(IS_MOBILE_PLATFORM)
if (has_remote_inputs_) {
serialize_remote_handle_ =
[ctx, &op_inputs](const int i,
[ctx, &op_inputs](const FunctionArgIndex& index,
eager::RemoteTensorHandle* handle) -> Status {
VariantDevice variant_device = op_inputs[i]->device();
if (index.sub_index >= 0) {
return errors::InvalidArgument("Got unexpected sub_index ",
index.sub_index, " for argument ",
index.index);
}
VariantDevice variant_device = op_inputs[index.index]->device();
if (VariantDeviceIsCustom(variant_device)) {
return errors::Internal(
"Custom devices and remote execution are currently not supported "
@ -62,7 +67,7 @@ Status ExecuteNodeArgs::Init(
}
Device* device = absl::get<Device*>(variant_device);
return ctx->RemoteMgr()->SerializeRemoteTensorHandle(
op_inputs[i], handle, device, device->name());
op_inputs[index.index], handle, device, device->name());
};
}
#endif // !IS_MOBILE_PLATFORM

View File

@ -54,10 +54,12 @@ class ExecuteNodeArgs : public EagerKernelArgs {
const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
const core::RefCountPtr<KernelAndDevice>& kernel);
bool HasRemoteInputs() const override { return has_remote_inputs_; };
bool HasRemoteOrPackedInputs() const override {
return has_remote_inputs_ || has_packed_inputs_;
};
#if !defined(IS_MOBILE_PLATFORM)
Status GetRemoteArg(const int index,
Status GetRemoteArg(const FunctionArgIndex& index,
eager::RemoteTensorHandle* val) const override {
return serialize_remote_handle_(index, val);
}
@ -65,8 +67,9 @@ class ExecuteNodeArgs : public EagerKernelArgs {
private:
bool has_remote_inputs_ = false;
bool has_packed_inputs_ = false;
#if !defined(IS_MOBILE_PLATFORM)
std::function<Status(const int, eager::RemoteTensorHandle*)>
std::function<Status(const FunctionArgIndex&, eager::RemoteTensorHandle*)>
serialize_remote_handle_;
#endif // IS_MOBILE_PLATFORM
};

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/profiler/lib/annotated_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
@ -49,13 +50,18 @@ limitations under the License.
namespace tensorflow {
Status EagerKernelArgs::GetLocalArg(const int index, Tensor* val) const {
Tensor* arg = tensor_args_.at(index).tensor;
Status EagerKernelArgs::GetLocalArg(const FunctionArgIndex& index,
Tensor* val) const {
if (index.sub_index >= 0) {
return errors::InvalidArgument("Got unexpected sub_index ", index.sub_index,
" for argument ", index.index);
}
Tensor* arg = tensor_args_.at(index.index).tensor;
if (arg) {
*val = *arg;
return Status::OK();
} else {
return errors::NotFound("Argument ", index, " has no local tensor.");
return errors::NotFound("Argument ", index.index, " has no local tensor.");
}
}

View File

@ -66,16 +66,16 @@ class EagerKernelArgs : public FunctionArgsInterface {
~EagerKernelArgs() override{};
bool HasRemoteInputs() const override { return false; };
bool HasRemoteOrPackedInputs() const override { return false; };
TensorValue* MutableInput(int i) { return &tensor_args_[i]; }
Status GetLocalArg(const int index, Tensor* val) const override;
Status GetLocalArg(const FunctionArgIndex& index, Tensor* val) const override;
std::vector<Tensor> GetLocalTensors() const override;
const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const override {
const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const {
return &tensor_args_;
};
}
protected:
gtl::InlinedVector<TensorValue, 4> tensor_args_;

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/partitioning_utils.h"
#include <algorithm>
#include <utility>
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
@ -73,11 +74,11 @@ Status PartitionFunctionGraph(
}
Status UpdateArgAndRetvalMetadata(
Graph* subgraph, const string& device_type, std::vector<int>* arg_indices,
std::vector<int>* ret_indices,
Graph* subgraph, const string& device_type,
std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
std::vector<AllocatorAttributes>* arg_alloc_attrs,
std::vector<AllocatorAttributes>* ret_alloc_attrs) {
std::vector<std::pair<Node*, int>> arg_nodes;
std::vector<std::pair<Node*, FunctionArgIndex>> arg_nodes;
std::vector<std::pair<Node*, int>> ret_nodes;
const AttrValue* attr_value;
@ -87,7 +88,11 @@ Status UpdateArgAndRetvalMetadata(
if (node->IsArg()) {
TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
int index = static_cast<int>(attr_value->i());
arg_nodes.emplace_back(node, index);
int sub_index = -1;
if (node->attrs().Find("sub_index", &attr_value).ok()) {
sub_index = static_cast<int>(attr_value->i());
}
arg_nodes.emplace_back(node, FunctionArgIndex(index, sub_index));
} else if (node->IsRetval()) {
TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
int index = static_cast<int>(attr_value->i());
@ -99,11 +104,16 @@ Status UpdateArgAndRetvalMetadata(
//
// In particular, this enables calling a single-partition function with
// the same signature as the original unpartitioned function.
auto comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) {
auto arg_comparator = [](std::pair<Node*, FunctionArgIndex> a,
std::pair<Node*, FunctionArgIndex> b) {
return std::tie(a.second.index, a.second.sub_index) <
std::tie(b.second.index, b.second.sub_index);
};
std::sort(arg_nodes.begin(), arg_nodes.end(), arg_comparator);
auto ret_comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) {
return a.second < b.second;
};
std::sort(arg_nodes.begin(), arg_nodes.end(), comparator);
std::sort(ret_nodes.begin(), ret_nodes.end(), comparator);
std::sort(ret_nodes.begin(), ret_nodes.end(), ret_comparator);
arg_indices->reserve(arg_nodes.size());
for (const auto& pair : arg_nodes) arg_indices->push_back(pair.second);
@ -144,16 +154,6 @@ Status UpdateArgAndRetvalMetadata(
return Status::OK();
}
std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
gtl::ArraySlice<Tensor> arguments) {
std::vector<Tensor> args;
args.reserve(indices.size());
for (int i : indices) {
args.push_back(arguments[i]);
}
return args;
}
string FunctionNameGenerator::GetName() {
while (true) {
const string candidate = strings::StrCat(name_, "_", counter_++);

View File

@ -58,15 +58,11 @@ Status PartitionFunctionGraph(
// (3) records which `Arg` and `Retval` nodes live in host memory in
// `*_alloc_attrs`.
Status UpdateArgAndRetvalMetadata(
Graph* subgraph, const string& device_type, std::vector<int>* arg_indices,
std::vector<int>* ret_indices,
Graph* subgraph, const string& device_type,
std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
std::vector<AllocatorAttributes>* arg_alloc_attrs,
std::vector<AllocatorAttributes>* ret_alloc_attrs);
// Extracts tensors at `indices` from `arguments`.
std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
gtl::ArraySlice<Tensor> arguments);
// Utility for generating function names not present in `flib_def`, using
// given `name` as the base for the name.
class FunctionNameGenerator {

View File

@ -158,14 +158,23 @@ TEST_F(PartitioningUtilsTest, TwoDevices) {
ASSERT_EQ(3, part2->num_op_nodes());
}
void CheckIndices(const std::vector<int>& expected,
const std::vector<int>& actual) {
void CheckRetIndices(const std::vector<int>& expected,
const std::vector<int>& actual) {
ASSERT_EQ(expected.size(), actual.size());
for (int i = 0; i < expected.size(); ++i) {
ASSERT_EQ(expected[i], actual[i]) << " at index " << i;
}
}
void CheckArgIndices(const std::vector<FunctionArgIndex>& expected,
const std::vector<FunctionArgIndex>& actual) {
ASSERT_EQ(expected.size(), actual.size());
for (int i = 0; i < expected.size(); ++i) {
ASSERT_EQ(expected[i].index, actual[i].index) << " at index " << i;
ASSERT_EQ(expected[i].sub_index, actual[i].sub_index) << " at index " << i;
}
}
void CheckAlloc(const std::vector<bool>& expected,
const std::vector<AllocatorAttributes>& actual) {
ASSERT_EQ(expected.size(), actual.size());
@ -185,7 +194,7 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) {
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
SubGraph(graph.get(), DT_FLOAT, {3}, {5});
std::vector<int> arg_indices;
std::vector<FunctionArgIndex> arg_indices;
std::vector<int> ret_indices;
std::vector<AllocatorAttributes> arg_alloc_attrs;
std::vector<AllocatorAttributes> ret_alloc_attrs;
@ -197,8 +206,8 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) {
&ret_alloc_attrs);
ASSERT_TRUE(status.ok()) << status.ToString();
CheckIndices({3}, arg_indices);
CheckIndices({5}, ret_indices);
CheckArgIndices({{3, -1}}, arg_indices);
CheckRetIndices({5}, ret_indices);
CheckAlloc({false}, arg_alloc_attrs);
CheckAlloc({false}, ret_alloc_attrs);
@ -213,7 +222,18 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets_Order) {
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
SubGraph(graph.get(), DT_FLOAT, {9, 7, 5, 3, 1}, {2, 4, 6, 8, 10});
std::vector<int> arg_indices;
const std::map<int, int> sub_indices = {
{7, 2}, {3, 1}, {1, 0}, {5, 2}, {9, 0}};
const AttrValue* attr_value;
for (Node* n : graph->op_nodes()) {
if (n->IsArg()) {
TF_ASSERT_OK(n->attrs().Find("index", &attr_value));
n->AddAttr("sub_index",
sub_indices.at(static_cast<int>(attr_value->i())));
}
}
std::vector<FunctionArgIndex> arg_indices;
std::vector<int> ret_indices;
std::vector<AllocatorAttributes> arg_alloc_attrs;
std::vector<AllocatorAttributes> ret_alloc_attrs;
@ -225,8 +245,8 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets_Order) {
&ret_alloc_attrs);
ASSERT_TRUE(status.ok()) << status.ToString();
CheckIndices({1, 3, 5, 7, 9}, arg_indices);
CheckIndices({2, 4, 6, 8, 10}, ret_indices);
CheckArgIndices({{1, 0}, {3, 1}, {5, 2}, {7, 2}, {9, 0}}, arg_indices);
CheckRetIndices({2, 4, 6, 8, 10}, ret_indices);
CheckAlloc({false, false, false, false, false}, arg_alloc_attrs);
CheckAlloc({false, false, false, false, false}, ret_alloc_attrs);
}

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <iterator>
#include <utility>
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device_set.h"
@ -29,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -301,7 +303,7 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice(
// Replace the given handle with the handle for the single component
// function.
handle = component_data.handle_;
handle = component_data.handle;
}
auto iter = function_data_.find(handle);
@ -777,6 +779,14 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
// Expand the nodes assigned to a CompositeDevice before graph partition to
// avoid generating a subgraph on a virtual device for execution.
// This transformation should happen as late as possible, in order to run as
// more graph optimization passes (e.g. PRE_PLACEMENT, PLACER,
// POST_PLACEMENT, POST_REWRITE_FOR_EXEC) on a smaller graph as possible.
TF_RETURN_IF_ERROR(ReplicatePerReplicaNodesInFunctionGraph(
options.composite_devices, graph.get()));
if (options.graph_collector != nullptr) {
GraphDef def;
graph->ToGraphDef(&def);
@ -869,9 +879,9 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
Graph* subgraph = pair.second.get();
status->Update(UpdateArgAndRetvalMetadata(
subgraph, device_type, &comp_data->arg_indices_,
&comp_data->ret_indices_, &comp_data->arg_alloc_attrs_,
&comp_data->ret_alloc_attrs_));
subgraph, device_type, &comp_data->arg_indices,
&comp_data->ret_indices, &comp_data->arg_alloc_attrs,
&comp_data->ret_alloc_attrs));
if (!status->ok()) {
counter.DecrementCount();
return;
@ -913,7 +923,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
data->is_cross_process_ = true;
}
}
comp_data->handle_ = *component_handle;
comp_data->handle = *component_handle;
}
delete component_handle;
counter.DecrementCount();
@ -955,16 +965,16 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
for (const auto& pair : data->glue_) {
const ComponentFunctionData& comp_data = pair.second;
DCHECK(comp_data.ret_alloc_attrs_.size() == comp_data.ret_indices_.size());
DCHECK(comp_data.ret_alloc_attrs.size() == comp_data.ret_indices.size());
const string& target = pair.first;
FunctionLibraryRuntime* target_flr = GetFLR(target);
if (target_flr == nullptr) {
if (!comp_data.ret_indices_.empty()) {
if (!comp_data.ret_indices.empty()) {
return errors::Unimplemented(
"Currently, outputting tensors on remote devices is not supported. "
"The ",
comp_data.ret_indices_[0],
comp_data.ret_indices[0],
"-th return value of the function outputs to target_device: ",
target,
" Please copy the tensor to local device explicitly using "
@ -973,17 +983,17 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
continue;
}
Device* target_device = target_flr->device();
const FunctionBody* fbody = target_flr->GetFunctionBody(comp_data.handle_);
const FunctionBody* fbody = target_flr->GetFunctionBody(comp_data.handle);
DCHECK(fbody != nullptr);
output_devices->resize(data->num_outputs_);
for (int j = 0; j < comp_data.ret_indices_.size(); ++j) {
int ret_index = comp_data.ret_indices_[j];
for (int j = 0; j < comp_data.ret_indices.size(); ++j) {
int ret_index = comp_data.ret_indices[j];
if (fbody->ret_types[j] == DT_RESOURCE) {
(*output_devices)[ret_index] = target_device;
} else {
(*output_devices)[ret_index] =
comp_data.ret_alloc_attrs_[j].on_host() ? nullptr : target_device;
comp_data.ret_alloc_attrs[j].on_host() ? nullptr : target_device;
}
}
}
@ -1013,9 +1023,8 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
const MultiDeviceFunctionData* data = IsMultiDevice(handle);
if (data == nullptr) {
done(
errors::InvalidArgument("Failed for find multi-device function handle ",
handle, ". Was the function instantiated?"));
done(errors::NotFound("Multi-device function handle ", handle,
"not found. Was the function instantiated?"));
return;
}
@ -1046,10 +1055,10 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
for (const auto& pair : data->glue_) {
const string& target = pair.first;
const ComponentFunctionData& comp_data = pair.second;
FunctionLibraryRuntime::Handle handle = pair.second.handle_;
FunctionLibraryRuntime::Handle handle = pair.second.handle;
opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs_;
opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs_;
opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs;
opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs;
opts_copy.remote_execution = false;
InternalArgs comp_args;
@ -1086,7 +1095,7 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
Status(status.code(), function_and_msg));
} else {
for (int i = 0; i < comp_rets->size(); ++i) {
(*rets)[comp_data.ret_indices_[i]] = (*comp_rets)[i];
(*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i];
}
}
delete comp_rets;
@ -1108,7 +1117,7 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
refcounted_done->UpdateStatus(status);
} else {
for (int i = 0; i < comp_rets->size(); ++i) {
(*rets)[comp_data.ret_indices_[i]] = (*comp_rets)[i];
(*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i];
}
}
delete comp_rets;
@ -1225,7 +1234,7 @@ Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle(
Status overall_status;
for (const auto& it : mdata->glue_) {
const string& device = it.first;
FunctionLibraryRuntime::Handle flr_handle = it.second.handle_;
FunctionLibraryRuntime::Handle flr_handle = it.second.handle;
FunctionLibraryRuntime* flr = GetFLR(device);
if (flr == nullptr) {
// TODO(nareshmodi): Implement DeregisterGraph call to remote device if
@ -1297,6 +1306,19 @@ ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback(
};
}
Status ProcessFunctionLibraryRuntime::CreateRendezvous(
const FunctionLibraryRuntime::Options& opts,
Rendezvous** created_rendezvous) const {
if (rendezvous_factory_) {
return rendezvous_factory_(opts.step_id, device_mgr_, created_rendezvous);
} else {
return errors::FailedPrecondition(
"The caller does not provide a rendezvous and "
"ProcessFunctionLibraryRuntime was created without a rendezvous "
"factory.");
}
}
void ProcessFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
@ -1305,21 +1327,12 @@ void ProcessFunctionLibraryRuntime::Run(
FunctionLibraryRuntime::Options new_opts = opts;
Rendezvous* created_rendezvous = nullptr;
if (!opts.rendezvous) {
if (rendezvous_factory_) {
Status s =
rendezvous_factory_(opts.step_id, device_mgr_, &created_rendezvous);
if (!s.ok()) {
done(s);
return;
}
new_opts.rendezvous = created_rendezvous;
} else {
done(
errors::FailedPrecondition("The caller does not provide a rendezvous "
"and ProcessFunctionLibraryRuntime was "
"created without a rendezvous factory."));
Status s = CreateRendezvous(opts, &created_rendezvous);
if (!s.ok()) {
done(s);
return;
}
new_opts.rendezvous = created_rendezvous;
new_opts.create_rendezvous = false;
}
@ -1334,9 +1347,14 @@ void ProcessFunctionLibraryRuntime::Run(
if (multi_device) {
auto get_component_args = [&args](const ComponentFunctionData& comp_data,
InternalArgs* comp_args) -> Status {
for (const auto& tensor :
GetArgsForIndices(comp_data.arg_indices_, args)) {
comp_args->args.push_back(tensor);
// "Index"s of _Arg nodes are unique when all arguments are local Tensors.
for (const auto& it : comp_data.arg_indices) {
if (it.sub_index >= 0) {
return errors::InvalidArgument("Got unexpected sub_index ",
it.sub_index, " for argument ",
it.index);
}
comp_args->args.push_back(args[it.index]);
}
return Status::OK();
};
@ -1520,11 +1538,23 @@ void ProcessFunctionLibraryRuntime::Run(
FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const {
if (!args.HasRemoteInputs()) {
if (!args.HasRemoteOrPackedInputs()) {
const std::vector<Tensor> local_inputs = args.GetLocalTensors();
return Run(opts, handle, local_inputs, rets, std::move(done));
}
FunctionLibraryRuntime::Options new_opts = opts;
Rendezvous* created_rendezvous = nullptr;
if (!opts.rendezvous) {
Status s = CreateRendezvous(opts, &created_rendezvous);
if (!s.ok()) {
done(s);
return;
}
new_opts.rendezvous = created_rendezvous;
new_opts.create_rendezvous = false;
}
#if defined(IS_MOBILE_PLATFORM)
done(errors::Unimplemented(
"Remote inputs are not available on mobile devices."));
@ -1532,12 +1562,12 @@ void ProcessFunctionLibraryRuntime::Run(
#else // !IS_MOBILE_PLATFORM
auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id,
/*rendezvous=*/nullptr);
created_rendezvous);
auto get_component_args = [&args](const ComponentFunctionData& comp_data,
InternalArgs* comp_args) -> Status {
for (int i = 0; i < comp_data.arg_indices_.size(); ++i) {
const int index = comp_data.arg_indices_.at(i);
for (int i = 0; i < comp_data.arg_indices.size(); ++i) {
const FunctionArgIndex index = comp_data.arg_indices.at(i);
Tensor tensor;
if (args.GetLocalArg(index, &tensor).ok()) {
comp_args->args.push_back(std::move(tensor));
@ -1552,7 +1582,7 @@ void ProcessFunctionLibraryRuntime::Run(
}
return Status::OK();
};
return RunMultiDevice(opts, handle, rets, cleanup_items, std::move(done),
return RunMultiDevice(new_opts, handle, rets, cleanup_items, std::move(done),
std::move(get_component_args));
#endif // !IS_MOBILE_PLATFORM
}

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/function.h"
@ -40,16 +41,15 @@ class FunctionArgsInterface {
public:
virtual ~FunctionArgsInterface() {}
virtual bool HasRemoteInputs() const = 0;
virtual bool HasRemoteOrPackedInputs() const = 0;
virtual Status GetLocalArg(const int index, Tensor* val) const = 0;
virtual Status GetLocalArg(const FunctionArgIndex& index,
Tensor* val) const = 0;
virtual std::vector<Tensor> GetLocalTensors() const = 0;
virtual const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const = 0;
#if !defined(IS_MOBILE_PLATFORM)
virtual Status GetRemoteArg(const int index,
virtual Status GetRemoteArg(const FunctionArgIndex& index,
eager::RemoteTensorHandle* val) const {
return errors::Unimplemented(
"Serializing a remote argument is not implemented.");
@ -217,6 +217,12 @@ class ProcessFunctionLibraryRuntime {
return lib_def_;
}
// Add a CompositeDevice to `device_set_`
void AddCompositeDevice(CompositeDevice* d) TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
device_set_->AddDevice(d);
}
protected:
friend class FunctionLibraryRuntimeImpl;
@ -232,21 +238,21 @@ class ProcessFunctionLibraryRuntime {
// piece of a multi-device function) fits into the multi-device function.
struct ComponentFunctionData {
// The handle for the instantiated component function.
FunctionLibraryRuntime::Handle handle_;
// arg_indices_.size() is the number of arguments to the component function.
FunctionLibraryRuntime::Handle handle;
// arg_indices.size() is the number of arguments to the component function.
// The i-th argument of the component function comes from the
// `arg_indices_[i]`-th argument of the multi-device function.
std::vector<int> arg_indices_;
// ret_indices_.size() is the number of return values of the component
// `arg_indices[i]`-th argument of the multi-device function.
std::vector<FunctionArgIndex> arg_indices;
// ret_indices.size() is the number of return values of the component
// function. The i-th return value of the component function goes to the
// `ret_indices_[i]`-th return value of the multi-device function.
std::vector<int> ret_indices_;
// arg_alloc_attrs_[i] are the allocator attributes of the i-th argument to
// `ret_indices[i]`-th return value of the multi-device function.
std::vector<int> ret_indices;
// arg_alloc_attrs[i] are the allocator attributes of the i-th argument to
// the component function.
std::vector<AllocatorAttributes> arg_alloc_attrs_;
// ret_alloc_attrs_[i] are the allocator attributes of the i-th return value
std::vector<AllocatorAttributes> arg_alloc_attrs;
// ret_alloc_attrs[i] are the allocator attributes of the i-th return value
// of the component function.
std::vector<AllocatorAttributes> ret_alloc_attrs_;
std::vector<AllocatorAttributes> ret_alloc_attrs;
};
// Data structure holding information for a single instantiated multi-device
@ -304,6 +310,9 @@ class ProcessFunctionLibraryRuntime {
InternalArgs* args)>
get_component_args) const;
Status CreateRendezvous(const FunctionLibraryRuntime::Options& opts,
Rendezvous** created_rendezvous) const;
FunctionLibraryRuntime::DoneCallback ApplyCleanUpToDoneCallback(
std::vector<std::unique_ptr<CleanUpItem>>* items,
FunctionLibraryRuntime::DoneCallback done, const int64 step_id,

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
@ -143,6 +144,10 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
}}));
}
void AddCompositeDevice(CompositeDevice* d) {
proc_flr_->AddCompositeDevice(d);
}
Status Instantiate(
const string& name, test::function::Attrs attrs,
const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
@ -187,11 +192,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}
template <typename T>
Status RunWithRuntime(
const string& name, FunctionLibraryRuntime::Options opts,
test::function::Attrs attrs,
const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
const std::vector<Tensor>& args, std::vector<Tensor*> rets,
const T& args, std::vector<Tensor*> rets,
ProcessFunctionLibraryRuntime* pflr) {
FunctionLibraryRuntime::Handle handle;
Status status = pflr->Instantiate(name, attrs, instantiate_opts, &handle);
@ -248,9 +254,20 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
Status Run(const string& name, FunctionLibraryRuntime::Options opts,
test::function::Attrs attrs,
const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
return RunWithRuntime(name, opts, attrs, instantiate_opts, args, rets,
proc_flr_.get());
const std::vector<Tensor>& args, std::vector<Tensor*> rets,
ProcessFunctionLibraryRuntime* pflr = nullptr) {
return RunWithRuntime<std::vector<Tensor>>(
name, opts, attrs, instantiate_opts, args, rets, proc_flr_.get());
}
Status RunWithPackedArgs(
const string& name, FunctionLibraryRuntime::Options opts,
test::function::Attrs attrs,
const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
const FunctionArgsInterface& args, std::vector<Tensor*> rets,
ProcessFunctionLibraryRuntime* pflr = nullptr) {
return RunWithRuntime<FunctionArgsInterface>(
name, opts, attrs, instantiate_opts, args, rets, proc_flr_.get());
}
Status RunInstantiated(FunctionLibraryRuntime::Handle handle,
@ -719,6 +736,112 @@ Tensor GetResourceHandle(const string& var_name, const string& container,
return tensor;
}
// Returns a function which adds two variables on different devices.
FunctionDef AddVarAcrossDevices() {
return FunctionDefHelper::Create(
// Name
"AddVarAcrossDevices",
// Args
{"x: resource"},
// Return values
{"y: float"},
// Attr def
{},
// Nodes
{
{{"read0"},
"ReadVariableOp",
{"x"},
{{"dtype", DT_FLOAT}},
{},
"/device:CPU:0"},
{{"read1"},
"ReadVariableOp",
{"x"},
{{"dtype", DT_FLOAT}},
{},
"/device:CPU:1"},
{{"add"},
"Add",
{"read0:value:0", "read1:value:0"},
{{"T", DT_FLOAT}},
{},
"/device:CPU:0"},
},
{{"y", "add:z:0"}});
}
// An implementation of FunctionArgsInterface for packed inputs.
class TestFunctionPackedArgs : public FunctionArgsInterface {
public:
TestFunctionPackedArgs(const int index,
gtl::InlinedVector<TensorValue, 4>&& tensor_args) {
packed_args_.emplace(index, std::move(tensor_args));
}
~TestFunctionPackedArgs() override{};
bool HasRemoteOrPackedInputs() const override { return true; };
Status GetLocalArg(const FunctionArgIndex& index,
Tensor* val) const override {
*val = *packed_args_.at(index.index).at(index.sub_index).tensor;
return Status::OK();
};
std::vector<Tensor> GetLocalTensors() const override { return {}; }
private:
absl::flat_hash_map<int, gtl::InlinedVector<TensorValue, 4>> packed_args_;
};
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) {
Init({AddVarAcrossDevices()});
// Create two variables on two devices.
const Tensor initial_resource_value0 = test::AsTensor<float>({10, 20});
Var* resource0 = new Var(DT_FLOAT);
*resource0->tensor() = initial_resource_value0;
resource0->is_initialized = true;
const Tensor initial_resource_value1 = test::AsTensor<float>({30, 40});
Var* resource1 = new Var(DT_FLOAT);
*resource1->tensor() = initial_resource_value1;
resource1->is_initialized = true;
ResourceMgr* mgr0 = device0_->resource_manager();
ResourceMgr* mgr1 = device1_->resource_manager();
TF_ASSERT_OK(mgr0->Create(mgr0->default_container(), "var", resource0));
TF_ASSERT_OK(mgr1->Create(mgr1->default_container(), "var", resource1));
Tensor resource_handle0 =
GetResourceHandle("var", mgr0->default_container(), device0_->name());
Tensor resource_handle1 =
GetResourceHandle("var", mgr1->default_container(), device1_->name());
// Create a CompositeDevice
Status s;
std::unique_ptr<CompositeDevice> composite_device =
CompositeDevice::MakeDevice({device0_->name(), device1_->name()},
/*unique_device_id=*/0, &s);
TF_ASSERT_OK(s);
AddCompositeDevice(composite_device.get());
FunctionLibraryRuntime::Options opts;
FunctionLibraryRuntime::InstantiateOptions inst_opts =
MakeOptions("CPU:0", {"COMPOSITE:0"}, {"CPU:0"});
inst_opts.composite_devices[composite_device->name()] =
composite_device->underlying_devices();
inst_opts.input_resource_dtypes_and_shapes[0] = {
initial_resource_value0.dtype(), initial_resource_value0.shape()};
gtl::InlinedVector<TensorValue, 4> handles;
handles.push_back(TensorValue(&resource_handle0));
handles.push_back(TensorValue(&resource_handle1));
TestFunctionPackedArgs args(0, std::move(handles));
Tensor ret;
TF_CHECK_OK(RunWithPackedArgs("AddVarAcrossDevices", opts, {{"T", DT_FLOAT}},
inst_opts, args, {&ret}));
test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60}));
}
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) {
if (gpu_device_ == nullptr) {
GTEST_SKIP() << "No GPUs available";
@ -1025,9 +1148,9 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) {
instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
const auto x = test::AsTensor<int64>({17});
Tensor y;
TF_CHECK_OK(RunWithRuntime("SessionMetadataReaderFn", opts, {},
instantiate_opts, {x}, {&y},
cloned_proc_flr.get()));
TF_CHECK_OK(RunWithRuntime<std::vector<Tensor>>(
"SessionMetadataReaderFn", opts, {}, instantiate_opts, {x}, {&y},
cloned_proc_flr.get()));
SessionMetadata read_metadata;
ASSERT_TRUE(protobuf::TextFormat::ParseFromString(y.scalar<tstring>()(),
&read_metadata));

View File

@ -42,6 +42,7 @@ class ReplicateHelper {
Node* replicated_node = graph->AddNode(node_def, &status);
TF_RETURN_IF_ERROR(status);
replicated_node->set_assigned_device_name(device);
replicated_node->AddAttr("sub_index", i);
replicated_nodes[i] = replicated_node;
}
replicated_nodes_map_.emplace(node, std::move(replicated_nodes));
@ -180,7 +181,8 @@ Status ReplicateEdges(const ReplicateHelper& helper,
} // namespace
Status ReplicatePerReplicaNodesInFunctionGraph(
const absl::flat_hash_map<string, std::vector<string>>& composite_devices,
const absl::flat_hash_map<string, const std::vector<string>*>&
composite_devices,
Graph* graph) {
std::set<string> composite_device_names;
for (const auto& it : composite_devices) {
@ -198,7 +200,8 @@ Status ReplicatePerReplicaNodesInFunctionGraph(
}
for (const auto& it : composite_device_to_cluster_nodes) {
const std::vector<string>& allowed_devices = composite_devices.at(it.first);
const std::vector<string>& allowed_devices =
*composite_devices.at(it.first);
if (allowed_devices.empty()) {
return errors::InvalidArgument("No allowed device of composite device: ",
it.first);
@ -208,6 +211,7 @@ Status ReplicatePerReplicaNodesInFunctionGraph(
// Reuse the original nodes if there is only one allowed device.
for (Node* n : cluster_nodes) {
n->set_assigned_device_name(allowed_devices.at(0));
n->AddAttr("sub_index", 0);
}
continue;
}

View File

@ -35,7 +35,8 @@ namespace tensorflow {
// dependency.
// TODO(b/145922293): Register it as a POST_REWRITE_FOR_EXEC pass.
Status ReplicatePerReplicaNodesInFunctionGraph(
const absl::flat_hash_map<string, std::vector<string>>& composite_devices,
const absl::flat_hash_map<string, const std::vector<string>*>&
composite_devices,
Graph* graph);
} // namespace tensorflow

View File

@ -75,8 +75,9 @@ TEST(ReplicatePerReplicaNodesTest, SingleCompositeDevice) {
auto ret = ops::_Retval(
scope.WithOpName("ret").WithControlDependencies({write}), read, 0);
const absl::flat_hash_map<string, std::vector<string>> composite_devices = {
{"TPU_COMPOSITE:0", {"TPU:0", "TPU:1"}}};
const std::vector<string> underlying_devices = {"TPU:0", "TPU:1"};
const absl::flat_hash_map<string, const std::vector<string>*>
composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};
Graph graph(OpRegistry::Global());
TF_ASSERT_OK(scope.ToGraph(&graph));
@ -118,8 +119,9 @@ TEST(ReplicatePerReplicaNodesTest, SingleCompositeDeviceToSingleDevice) {
auto read = ops::ReadVariableOp(scope.WithOpName("read"), arg, DT_INT32);
auto ret = ops::_Retval(scope.WithOpName("ret"), read, 0);
const absl::flat_hash_map<string, std::vector<string>> composite_devices = {
{"TPU_COMPOSITE:0", {"TPU:0"}}};
const std::vector<string> underlying_devices = {"TPU:0"};
const absl::flat_hash_map<string, const std::vector<string>*>
composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};
Graph graph(OpRegistry::Global());
TF_ASSERT_OK(scope.ToGraph(&graph));
@ -156,9 +158,11 @@ TEST(ReplicatePerReplicaNodesTest, MultipleCompositeDevices) {
auto add = ops::Add(scope.WithOpName("add"), identity0, identity1);
auto ret = ops::_Retval(scope.WithOpName("ret"), add, 0);
const absl::flat_hash_map<string, std::vector<string>> composite_devices = {
{"TPU_COMPOSITE:0", {"TPU:0", "TPU:1"}},
{"TPU_COMPOSITE:1", {"TPU:2", "TPU:3"}}};
const std::vector<string> underlying_devices_0 = {"TPU:0", "TPU:1"};
const std::vector<string> underlying_devices_1 = {"TPU:2", "TPU:3"};
const absl::flat_hash_map<string, const std::vector<string>*>
composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices_0},
{"TPU_COMPOSITE:1", &underlying_devices_1}};
Graph graph(OpRegistry::Global());
TF_ASSERT_OK(scope.ToGraph(&graph));
@ -204,8 +208,9 @@ TEST(ReplicatePerReplicaNodesTest, MultipleCompositeDevices) {
}
TEST(ReplicatePerReplicaNodesTest, NestedFunctions) {
const absl::flat_hash_map<string, std::vector<string>> composite_devices = {
{"TPU_COMPOSITE:0", {"TPU:0", "TPU:1"}}};
const std::vector<string> underlying_devices = {"TPU:0", "TPU:1"};
const absl::flat_hash_map<string, const std::vector<string>*>
composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};
FunctionDefLibrary fdef_lib;
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);

View File

@ -500,11 +500,11 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
: EagerKernelArgs(std::move(tensor_args)),
serialize_remote_handle_(std::move(serialize_remote_handle)) {}
bool HasRemoteInputs() const override { return true; }
bool HasRemoteOrPackedInputs() const override { return true; }
Status GetRemoteArg(const int index,
Status GetRemoteArg(const FunctionArgIndex& index,
eager::RemoteTensorHandle* val) const override {
return serialize_remote_handle_(index, val);
return serialize_remote_handle_(index.index, val);
}
private:
@ -562,7 +562,14 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
eager_pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
remote_device_mgr_.get(), Env::Default(), /*config=*/
nullptr, TF_GRAPH_DEF_VERSION, &func_lib_def_, OptimizerOptions(),
/*thread_pool=*/nullptr, eager_cluster_flr_.get());
/*thread_pool=*/nullptr, eager_cluster_flr_.get(),
/*custom_kernel_creator=*/nullptr, /*session_metadata=*/nullptr,
Rendezvous::Factory{[this](const int64 step_id,
const DeviceMgr* device_mgr,
Rendezvous** r) {
*r = worker_env_.rendezvous_mgr->Find(step_id);
return Status::OK();
}});
}
void CheckOutputTensorAndClose(const Tensor& tensor) {

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/core/framework/attr_value.pb.h"
@ -525,6 +526,20 @@ class Device;
// Forward declare. Defined in common_runtime/device_mgr.h
class DeviceMgr;
// Index of an _Arg node.
struct FunctionArgIndex {
explicit FunctionArgIndex(const int index) : index(index) {}
FunctionArgIndex(const int index, const int sub_index)
: index(index), sub_index(sub_index) {}
// The value of the attribute "Index" of the _Arg node.
int index;
// Set only when the _Arg node represents multiple arguments (e.g. an _Arg
// node is replicated to multiple devices/subgraphs). Use sub-index to
// distinguish arguments with the same index.
int sub_index = -1;
};
class FunctionLibraryRuntime {
public:
virtual ~FunctionLibraryRuntime() {}
@ -576,6 +591,10 @@ class FunctionLibraryRuntime {
// infer correct device.
std::vector<string> output_devices;
// Maps from a CompositeDevice name to a list of underlying physical
// devices.
absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
// This interface is EXPERIMENTAL and subject to change.
//
// For multi-device functions, a mapping from _Arg node index to type and