Support packed tensor inputs in ProcessFunctionLibraryRuntime.
- Expand the packed _Arg nodes when the graph is ready for graph partition. - Introduce an optional sub-index to function Arg nodes, in order to distinguish between two arguments with the same "index". It happens after replacing a packed _Arg node which is assigned to a CompositeDevice with multiple replica nodes (one per device). The "index" of an _Arg node is unique before expanding it. It's also unique within each subgraph after graph partition. PiperOrigin-RevId: 309781835 Change-Id: Ic6e351f45b7523288b5dae30997ddf0dae86660b
This commit is contained in:
parent
4158c029ef
commit
8fcb130e92
@ -631,6 +631,7 @@ cc_library(
|
||||
],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":composite_device",
|
||||
":device",
|
||||
":device_mgr",
|
||||
":device_set",
|
||||
@ -651,6 +652,7 @@ cc_library(
|
||||
":process_util",
|
||||
":rendezvous_mgr",
|
||||
":rendezvous_util",
|
||||
":replicate_per_replica_nodes",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
@ -658,6 +660,7 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
|
@ -52,9 +52,14 @@ Status ExecuteNodeArgs::Init(
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
if (has_remote_inputs_) {
|
||||
serialize_remote_handle_ =
|
||||
[ctx, &op_inputs](const int i,
|
||||
[ctx, &op_inputs](const FunctionArgIndex& index,
|
||||
eager::RemoteTensorHandle* handle) -> Status {
|
||||
VariantDevice variant_device = op_inputs[i]->device();
|
||||
if (index.sub_index >= 0) {
|
||||
return errors::InvalidArgument("Got unexpected sub_index ",
|
||||
index.sub_index, " for argument ",
|
||||
index.index);
|
||||
}
|
||||
VariantDevice variant_device = op_inputs[index.index]->device();
|
||||
if (VariantDeviceIsCustom(variant_device)) {
|
||||
return errors::Internal(
|
||||
"Custom devices and remote execution are currently not supported "
|
||||
@ -62,7 +67,7 @@ Status ExecuteNodeArgs::Init(
|
||||
}
|
||||
Device* device = absl::get<Device*>(variant_device);
|
||||
return ctx->RemoteMgr()->SerializeRemoteTensorHandle(
|
||||
op_inputs[i], handle, device, device->name());
|
||||
op_inputs[index.index], handle, device, device->name());
|
||||
};
|
||||
}
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
@ -54,10 +54,12 @@ class ExecuteNodeArgs : public EagerKernelArgs {
|
||||
const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
|
||||
const core::RefCountPtr<KernelAndDevice>& kernel);
|
||||
|
||||
bool HasRemoteInputs() const override { return has_remote_inputs_; };
|
||||
bool HasRemoteOrPackedInputs() const override {
|
||||
return has_remote_inputs_ || has_packed_inputs_;
|
||||
};
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
Status GetRemoteArg(const int index,
|
||||
Status GetRemoteArg(const FunctionArgIndex& index,
|
||||
eager::RemoteTensorHandle* val) const override {
|
||||
return serialize_remote_handle_(index, val);
|
||||
}
|
||||
@ -65,8 +67,9 @@ class ExecuteNodeArgs : public EagerKernelArgs {
|
||||
|
||||
private:
|
||||
bool has_remote_inputs_ = false;
|
||||
bool has_packed_inputs_ = false;
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
std::function<Status(const int, eager::RemoteTensorHandle*)>
|
||||
std::function<Status(const FunctionArgIndex&, eager::RemoteTensorHandle*)>
|
||||
serialize_remote_handle_;
|
||||
#endif // IS_MOBILE_PLATFORM
|
||||
};
|
||||
|
@ -35,6 +35,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
#include "tensorflow/core/profiler/lib/annotated_traceme.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
@ -49,13 +50,18 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status EagerKernelArgs::GetLocalArg(const int index, Tensor* val) const {
|
||||
Tensor* arg = tensor_args_.at(index).tensor;
|
||||
Status EagerKernelArgs::GetLocalArg(const FunctionArgIndex& index,
|
||||
Tensor* val) const {
|
||||
if (index.sub_index >= 0) {
|
||||
return errors::InvalidArgument("Got unexpected sub_index ", index.sub_index,
|
||||
" for argument ", index.index);
|
||||
}
|
||||
Tensor* arg = tensor_args_.at(index.index).tensor;
|
||||
if (arg) {
|
||||
*val = *arg;
|
||||
return Status::OK();
|
||||
} else {
|
||||
return errors::NotFound("Argument ", index, " has no local tensor.");
|
||||
return errors::NotFound("Argument ", index.index, " has no local tensor.");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -66,16 +66,16 @@ class EagerKernelArgs : public FunctionArgsInterface {
|
||||
|
||||
~EagerKernelArgs() override{};
|
||||
|
||||
bool HasRemoteInputs() const override { return false; };
|
||||
bool HasRemoteOrPackedInputs() const override { return false; };
|
||||
TensorValue* MutableInput(int i) { return &tensor_args_[i]; }
|
||||
|
||||
Status GetLocalArg(const int index, Tensor* val) const override;
|
||||
Status GetLocalArg(const FunctionArgIndex& index, Tensor* val) const override;
|
||||
|
||||
std::vector<Tensor> GetLocalTensors() const override;
|
||||
|
||||
const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const override {
|
||||
const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const {
|
||||
return &tensor_args_;
|
||||
};
|
||||
}
|
||||
|
||||
protected:
|
||||
gtl::InlinedVector<TensorValue, 4> tensor_args_;
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/partitioning_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
@ -73,11 +74,11 @@ Status PartitionFunctionGraph(
|
||||
}
|
||||
|
||||
Status UpdateArgAndRetvalMetadata(
|
||||
Graph* subgraph, const string& device_type, std::vector<int>* arg_indices,
|
||||
std::vector<int>* ret_indices,
|
||||
Graph* subgraph, const string& device_type,
|
||||
std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
|
||||
std::vector<AllocatorAttributes>* arg_alloc_attrs,
|
||||
std::vector<AllocatorAttributes>* ret_alloc_attrs) {
|
||||
std::vector<std::pair<Node*, int>> arg_nodes;
|
||||
std::vector<std::pair<Node*, FunctionArgIndex>> arg_nodes;
|
||||
std::vector<std::pair<Node*, int>> ret_nodes;
|
||||
const AttrValue* attr_value;
|
||||
|
||||
@ -87,7 +88,11 @@ Status UpdateArgAndRetvalMetadata(
|
||||
if (node->IsArg()) {
|
||||
TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
|
||||
int index = static_cast<int>(attr_value->i());
|
||||
arg_nodes.emplace_back(node, index);
|
||||
int sub_index = -1;
|
||||
if (node->attrs().Find("sub_index", &attr_value).ok()) {
|
||||
sub_index = static_cast<int>(attr_value->i());
|
||||
}
|
||||
arg_nodes.emplace_back(node, FunctionArgIndex(index, sub_index));
|
||||
} else if (node->IsRetval()) {
|
||||
TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
|
||||
int index = static_cast<int>(attr_value->i());
|
||||
@ -99,11 +104,16 @@ Status UpdateArgAndRetvalMetadata(
|
||||
//
|
||||
// In particular, this enables calling a single-partition function with
|
||||
// the same signature as the original unpartitioned function.
|
||||
auto comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) {
|
||||
auto arg_comparator = [](std::pair<Node*, FunctionArgIndex> a,
|
||||
std::pair<Node*, FunctionArgIndex> b) {
|
||||
return std::tie(a.second.index, a.second.sub_index) <
|
||||
std::tie(b.second.index, b.second.sub_index);
|
||||
};
|
||||
std::sort(arg_nodes.begin(), arg_nodes.end(), arg_comparator);
|
||||
auto ret_comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) {
|
||||
return a.second < b.second;
|
||||
};
|
||||
std::sort(arg_nodes.begin(), arg_nodes.end(), comparator);
|
||||
std::sort(ret_nodes.begin(), ret_nodes.end(), comparator);
|
||||
std::sort(ret_nodes.begin(), ret_nodes.end(), ret_comparator);
|
||||
|
||||
arg_indices->reserve(arg_nodes.size());
|
||||
for (const auto& pair : arg_nodes) arg_indices->push_back(pair.second);
|
||||
@ -144,16 +154,6 @@ Status UpdateArgAndRetvalMetadata(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
|
||||
gtl::ArraySlice<Tensor> arguments) {
|
||||
std::vector<Tensor> args;
|
||||
args.reserve(indices.size());
|
||||
for (int i : indices) {
|
||||
args.push_back(arguments[i]);
|
||||
}
|
||||
return args;
|
||||
}
|
||||
|
||||
string FunctionNameGenerator::GetName() {
|
||||
while (true) {
|
||||
const string candidate = strings::StrCat(name_, "_", counter_++);
|
||||
|
@ -58,15 +58,11 @@ Status PartitionFunctionGraph(
|
||||
// (3) records which `Arg` and `Retval` nodes live in host memory in
|
||||
// `*_alloc_attrs`.
|
||||
Status UpdateArgAndRetvalMetadata(
|
||||
Graph* subgraph, const string& device_type, std::vector<int>* arg_indices,
|
||||
std::vector<int>* ret_indices,
|
||||
Graph* subgraph, const string& device_type,
|
||||
std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
|
||||
std::vector<AllocatorAttributes>* arg_alloc_attrs,
|
||||
std::vector<AllocatorAttributes>* ret_alloc_attrs);
|
||||
|
||||
// Extracts tensors at `indices` from `arguments`.
|
||||
std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
|
||||
gtl::ArraySlice<Tensor> arguments);
|
||||
|
||||
// Utility for generating function names not present in `flib_def`, using
|
||||
// given `name` as the base for the name.
|
||||
class FunctionNameGenerator {
|
||||
|
@ -158,14 +158,23 @@ TEST_F(PartitioningUtilsTest, TwoDevices) {
|
||||
ASSERT_EQ(3, part2->num_op_nodes());
|
||||
}
|
||||
|
||||
void CheckIndices(const std::vector<int>& expected,
|
||||
const std::vector<int>& actual) {
|
||||
void CheckRetIndices(const std::vector<int>& expected,
|
||||
const std::vector<int>& actual) {
|
||||
ASSERT_EQ(expected.size(), actual.size());
|
||||
for (int i = 0; i < expected.size(); ++i) {
|
||||
ASSERT_EQ(expected[i], actual[i]) << " at index " << i;
|
||||
}
|
||||
}
|
||||
|
||||
void CheckArgIndices(const std::vector<FunctionArgIndex>& expected,
|
||||
const std::vector<FunctionArgIndex>& actual) {
|
||||
ASSERT_EQ(expected.size(), actual.size());
|
||||
for (int i = 0; i < expected.size(); ++i) {
|
||||
ASSERT_EQ(expected[i].index, actual[i].index) << " at index " << i;
|
||||
ASSERT_EQ(expected[i].sub_index, actual[i].sub_index) << " at index " << i;
|
||||
}
|
||||
}
|
||||
|
||||
void CheckAlloc(const std::vector<bool>& expected,
|
||||
const std::vector<AllocatorAttributes>& actual) {
|
||||
ASSERT_EQ(expected.size(), actual.size());
|
||||
@ -185,7 +194,7 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) {
|
||||
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
|
||||
SubGraph(graph.get(), DT_FLOAT, {3}, {5});
|
||||
|
||||
std::vector<int> arg_indices;
|
||||
std::vector<FunctionArgIndex> arg_indices;
|
||||
std::vector<int> ret_indices;
|
||||
std::vector<AllocatorAttributes> arg_alloc_attrs;
|
||||
std::vector<AllocatorAttributes> ret_alloc_attrs;
|
||||
@ -197,8 +206,8 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) {
|
||||
&ret_alloc_attrs);
|
||||
ASSERT_TRUE(status.ok()) << status.ToString();
|
||||
|
||||
CheckIndices({3}, arg_indices);
|
||||
CheckIndices({5}, ret_indices);
|
||||
CheckArgIndices({{3, -1}}, arg_indices);
|
||||
CheckRetIndices({5}, ret_indices);
|
||||
CheckAlloc({false}, arg_alloc_attrs);
|
||||
CheckAlloc({false}, ret_alloc_attrs);
|
||||
|
||||
@ -213,7 +222,18 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets_Order) {
|
||||
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
|
||||
SubGraph(graph.get(), DT_FLOAT, {9, 7, 5, 3, 1}, {2, 4, 6, 8, 10});
|
||||
|
||||
std::vector<int> arg_indices;
|
||||
const std::map<int, int> sub_indices = {
|
||||
{7, 2}, {3, 1}, {1, 0}, {5, 2}, {9, 0}};
|
||||
const AttrValue* attr_value;
|
||||
for (Node* n : graph->op_nodes()) {
|
||||
if (n->IsArg()) {
|
||||
TF_ASSERT_OK(n->attrs().Find("index", &attr_value));
|
||||
n->AddAttr("sub_index",
|
||||
sub_indices.at(static_cast<int>(attr_value->i())));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<FunctionArgIndex> arg_indices;
|
||||
std::vector<int> ret_indices;
|
||||
std::vector<AllocatorAttributes> arg_alloc_attrs;
|
||||
std::vector<AllocatorAttributes> ret_alloc_attrs;
|
||||
@ -225,8 +245,8 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets_Order) {
|
||||
&ret_alloc_attrs);
|
||||
ASSERT_TRUE(status.ok()) << status.ToString();
|
||||
|
||||
CheckIndices({1, 3, 5, 7, 9}, arg_indices);
|
||||
CheckIndices({2, 4, 6, 8, 10}, ret_indices);
|
||||
CheckArgIndices({{1, 0}, {3, 1}, {5, 2}, {7, 2}, {9, 0}}, arg_indices);
|
||||
CheckRetIndices({2, 4, 6, 8, 10}, ret_indices);
|
||||
CheckAlloc({false, false, false, false, false}, arg_alloc_attrs);
|
||||
CheckAlloc({false, false, false, false, false}, ret_alloc_attrs);
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#include <iterator>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
@ -29,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/process_util.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_util.h"
|
||||
#include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
@ -301,7 +303,7 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice(
|
||||
|
||||
// Replace the given handle with the handle for the single component
|
||||
// function.
|
||||
handle = component_data.handle_;
|
||||
handle = component_data.handle;
|
||||
}
|
||||
|
||||
auto iter = function_data_.find(handle);
|
||||
@ -777,6 +779,14 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
|
||||
OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
|
||||
|
||||
// Expand the nodes assigned to a CompositeDevice before graph partition to
|
||||
// avoid generating a subgraph on a virtual device for execution.
|
||||
// This transformation should happen as late as possible, in order to run as
|
||||
// more graph optimization passes (e.g. PRE_PLACEMENT, PLACER,
|
||||
// POST_PLACEMENT, POST_REWRITE_FOR_EXEC) on a smaller graph as possible.
|
||||
TF_RETURN_IF_ERROR(ReplicatePerReplicaNodesInFunctionGraph(
|
||||
options.composite_devices, graph.get()));
|
||||
|
||||
if (options.graph_collector != nullptr) {
|
||||
GraphDef def;
|
||||
graph->ToGraphDef(&def);
|
||||
@ -869,9 +879,9 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
Graph* subgraph = pair.second.get();
|
||||
|
||||
status->Update(UpdateArgAndRetvalMetadata(
|
||||
subgraph, device_type, &comp_data->arg_indices_,
|
||||
&comp_data->ret_indices_, &comp_data->arg_alloc_attrs_,
|
||||
&comp_data->ret_alloc_attrs_));
|
||||
subgraph, device_type, &comp_data->arg_indices,
|
||||
&comp_data->ret_indices, &comp_data->arg_alloc_attrs,
|
||||
&comp_data->ret_alloc_attrs));
|
||||
if (!status->ok()) {
|
||||
counter.DecrementCount();
|
||||
return;
|
||||
@ -913,7 +923,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
data->is_cross_process_ = true;
|
||||
}
|
||||
}
|
||||
comp_data->handle_ = *component_handle;
|
||||
comp_data->handle = *component_handle;
|
||||
}
|
||||
delete component_handle;
|
||||
counter.DecrementCount();
|
||||
@ -955,16 +965,16 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
|
||||
|
||||
for (const auto& pair : data->glue_) {
|
||||
const ComponentFunctionData& comp_data = pair.second;
|
||||
DCHECK(comp_data.ret_alloc_attrs_.size() == comp_data.ret_indices_.size());
|
||||
DCHECK(comp_data.ret_alloc_attrs.size() == comp_data.ret_indices.size());
|
||||
|
||||
const string& target = pair.first;
|
||||
FunctionLibraryRuntime* target_flr = GetFLR(target);
|
||||
if (target_flr == nullptr) {
|
||||
if (!comp_data.ret_indices_.empty()) {
|
||||
if (!comp_data.ret_indices.empty()) {
|
||||
return errors::Unimplemented(
|
||||
"Currently, outputting tensors on remote devices is not supported. "
|
||||
"The ",
|
||||
comp_data.ret_indices_[0],
|
||||
comp_data.ret_indices[0],
|
||||
"-th return value of the function outputs to target_device: ",
|
||||
target,
|
||||
" Please copy the tensor to local device explicitly using "
|
||||
@ -973,17 +983,17 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
|
||||
continue;
|
||||
}
|
||||
Device* target_device = target_flr->device();
|
||||
const FunctionBody* fbody = target_flr->GetFunctionBody(comp_data.handle_);
|
||||
const FunctionBody* fbody = target_flr->GetFunctionBody(comp_data.handle);
|
||||
DCHECK(fbody != nullptr);
|
||||
|
||||
output_devices->resize(data->num_outputs_);
|
||||
for (int j = 0; j < comp_data.ret_indices_.size(); ++j) {
|
||||
int ret_index = comp_data.ret_indices_[j];
|
||||
for (int j = 0; j < comp_data.ret_indices.size(); ++j) {
|
||||
int ret_index = comp_data.ret_indices[j];
|
||||
if (fbody->ret_types[j] == DT_RESOURCE) {
|
||||
(*output_devices)[ret_index] = target_device;
|
||||
} else {
|
||||
(*output_devices)[ret_index] =
|
||||
comp_data.ret_alloc_attrs_[j].on_host() ? nullptr : target_device;
|
||||
comp_data.ret_alloc_attrs[j].on_host() ? nullptr : target_device;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1013,9 +1023,8 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
|
||||
|
||||
const MultiDeviceFunctionData* data = IsMultiDevice(handle);
|
||||
if (data == nullptr) {
|
||||
done(
|
||||
errors::InvalidArgument("Failed for find multi-device function handle ",
|
||||
handle, ". Was the function instantiated?"));
|
||||
done(errors::NotFound("Multi-device function handle ", handle,
|
||||
"not found. Was the function instantiated?"));
|
||||
return;
|
||||
}
|
||||
|
||||
@ -1046,10 +1055,10 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
|
||||
for (const auto& pair : data->glue_) {
|
||||
const string& target = pair.first;
|
||||
const ComponentFunctionData& comp_data = pair.second;
|
||||
FunctionLibraryRuntime::Handle handle = pair.second.handle_;
|
||||
FunctionLibraryRuntime::Handle handle = pair.second.handle;
|
||||
|
||||
opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs_;
|
||||
opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs_;
|
||||
opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs;
|
||||
opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs;
|
||||
opts_copy.remote_execution = false;
|
||||
|
||||
InternalArgs comp_args;
|
||||
@ -1086,7 +1095,7 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
|
||||
Status(status.code(), function_and_msg));
|
||||
} else {
|
||||
for (int i = 0; i < comp_rets->size(); ++i) {
|
||||
(*rets)[comp_data.ret_indices_[i]] = (*comp_rets)[i];
|
||||
(*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i];
|
||||
}
|
||||
}
|
||||
delete comp_rets;
|
||||
@ -1108,7 +1117,7 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
|
||||
refcounted_done->UpdateStatus(status);
|
||||
} else {
|
||||
for (int i = 0; i < comp_rets->size(); ++i) {
|
||||
(*rets)[comp_data.ret_indices_[i]] = (*comp_rets)[i];
|
||||
(*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i];
|
||||
}
|
||||
}
|
||||
delete comp_rets;
|
||||
@ -1225,7 +1234,7 @@ Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle(
|
||||
Status overall_status;
|
||||
for (const auto& it : mdata->glue_) {
|
||||
const string& device = it.first;
|
||||
FunctionLibraryRuntime::Handle flr_handle = it.second.handle_;
|
||||
FunctionLibraryRuntime::Handle flr_handle = it.second.handle;
|
||||
FunctionLibraryRuntime* flr = GetFLR(device);
|
||||
if (flr == nullptr) {
|
||||
// TODO(nareshmodi): Implement DeregisterGraph call to remote device if
|
||||
@ -1297,6 +1306,19 @@ ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback(
|
||||
};
|
||||
}
|
||||
|
||||
Status ProcessFunctionLibraryRuntime::CreateRendezvous(
|
||||
const FunctionLibraryRuntime::Options& opts,
|
||||
Rendezvous** created_rendezvous) const {
|
||||
if (rendezvous_factory_) {
|
||||
return rendezvous_factory_(opts.step_id, device_mgr_, created_rendezvous);
|
||||
} else {
|
||||
return errors::FailedPrecondition(
|
||||
"The caller does not provide a rendezvous and "
|
||||
"ProcessFunctionLibraryRuntime was created without a rendezvous "
|
||||
"factory.");
|
||||
}
|
||||
}
|
||||
|
||||
void ProcessFunctionLibraryRuntime::Run(
|
||||
const FunctionLibraryRuntime::Options& opts,
|
||||
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
|
||||
@ -1305,21 +1327,12 @@ void ProcessFunctionLibraryRuntime::Run(
|
||||
FunctionLibraryRuntime::Options new_opts = opts;
|
||||
Rendezvous* created_rendezvous = nullptr;
|
||||
if (!opts.rendezvous) {
|
||||
if (rendezvous_factory_) {
|
||||
Status s =
|
||||
rendezvous_factory_(opts.step_id, device_mgr_, &created_rendezvous);
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
new_opts.rendezvous = created_rendezvous;
|
||||
} else {
|
||||
done(
|
||||
errors::FailedPrecondition("The caller does not provide a rendezvous "
|
||||
"and ProcessFunctionLibraryRuntime was "
|
||||
"created without a rendezvous factory."));
|
||||
Status s = CreateRendezvous(opts, &created_rendezvous);
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
new_opts.rendezvous = created_rendezvous;
|
||||
new_opts.create_rendezvous = false;
|
||||
}
|
||||
|
||||
@ -1334,9 +1347,14 @@ void ProcessFunctionLibraryRuntime::Run(
|
||||
if (multi_device) {
|
||||
auto get_component_args = [&args](const ComponentFunctionData& comp_data,
|
||||
InternalArgs* comp_args) -> Status {
|
||||
for (const auto& tensor :
|
||||
GetArgsForIndices(comp_data.arg_indices_, args)) {
|
||||
comp_args->args.push_back(tensor);
|
||||
// "Index"s of _Arg nodes are unique when all arguments are local Tensors.
|
||||
for (const auto& it : comp_data.arg_indices) {
|
||||
if (it.sub_index >= 0) {
|
||||
return errors::InvalidArgument("Got unexpected sub_index ",
|
||||
it.sub_index, " for argument ",
|
||||
it.index);
|
||||
}
|
||||
comp_args->args.push_back(args[it.index]);
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
@ -1520,11 +1538,23 @@ void ProcessFunctionLibraryRuntime::Run(
|
||||
FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
|
||||
std::vector<Tensor>* rets,
|
||||
FunctionLibraryRuntime::DoneCallback done) const {
|
||||
if (!args.HasRemoteInputs()) {
|
||||
if (!args.HasRemoteOrPackedInputs()) {
|
||||
const std::vector<Tensor> local_inputs = args.GetLocalTensors();
|
||||
return Run(opts, handle, local_inputs, rets, std::move(done));
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime::Options new_opts = opts;
|
||||
Rendezvous* created_rendezvous = nullptr;
|
||||
if (!opts.rendezvous) {
|
||||
Status s = CreateRendezvous(opts, &created_rendezvous);
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
new_opts.rendezvous = created_rendezvous;
|
||||
new_opts.create_rendezvous = false;
|
||||
}
|
||||
|
||||
#if defined(IS_MOBILE_PLATFORM)
|
||||
done(errors::Unimplemented(
|
||||
"Remote inputs are not available on mobile devices."));
|
||||
@ -1532,12 +1562,12 @@ void ProcessFunctionLibraryRuntime::Run(
|
||||
#else // !IS_MOBILE_PLATFORM
|
||||
auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
|
||||
done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id,
|
||||
/*rendezvous=*/nullptr);
|
||||
created_rendezvous);
|
||||
|
||||
auto get_component_args = [&args](const ComponentFunctionData& comp_data,
|
||||
InternalArgs* comp_args) -> Status {
|
||||
for (int i = 0; i < comp_data.arg_indices_.size(); ++i) {
|
||||
const int index = comp_data.arg_indices_.at(i);
|
||||
for (int i = 0; i < comp_data.arg_indices.size(); ++i) {
|
||||
const FunctionArgIndex index = comp_data.arg_indices.at(i);
|
||||
Tensor tensor;
|
||||
if (args.GetLocalArg(index, &tensor).ok()) {
|
||||
comp_args->args.push_back(std::move(tensor));
|
||||
@ -1552,7 +1582,7 @@ void ProcessFunctionLibraryRuntime::Run(
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
return RunMultiDevice(opts, handle, rets, cleanup_items, std::move(done),
|
||||
return RunMultiDevice(new_opts, handle, rets, cleanup_items, std::move(done),
|
||||
std::move(get_component_args));
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
@ -40,16 +41,15 @@ class FunctionArgsInterface {
|
||||
public:
|
||||
virtual ~FunctionArgsInterface() {}
|
||||
|
||||
virtual bool HasRemoteInputs() const = 0;
|
||||
virtual bool HasRemoteOrPackedInputs() const = 0;
|
||||
|
||||
virtual Status GetLocalArg(const int index, Tensor* val) const = 0;
|
||||
virtual Status GetLocalArg(const FunctionArgIndex& index,
|
||||
Tensor* val) const = 0;
|
||||
|
||||
virtual std::vector<Tensor> GetLocalTensors() const = 0;
|
||||
|
||||
virtual const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const = 0;
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
virtual Status GetRemoteArg(const int index,
|
||||
virtual Status GetRemoteArg(const FunctionArgIndex& index,
|
||||
eager::RemoteTensorHandle* val) const {
|
||||
return errors::Unimplemented(
|
||||
"Serializing a remote argument is not implemented.");
|
||||
@ -217,6 +217,12 @@ class ProcessFunctionLibraryRuntime {
|
||||
return lib_def_;
|
||||
}
|
||||
|
||||
// Add a CompositeDevice to `device_set_`
|
||||
void AddCompositeDevice(CompositeDevice* d) TF_LOCKS_EXCLUDED(mu_) {
|
||||
mutex_lock l(mu_);
|
||||
device_set_->AddDevice(d);
|
||||
}
|
||||
|
||||
protected:
|
||||
friend class FunctionLibraryRuntimeImpl;
|
||||
|
||||
@ -232,21 +238,21 @@ class ProcessFunctionLibraryRuntime {
|
||||
// piece of a multi-device function) fits into the multi-device function.
|
||||
struct ComponentFunctionData {
|
||||
// The handle for the instantiated component function.
|
||||
FunctionLibraryRuntime::Handle handle_;
|
||||
// arg_indices_.size() is the number of arguments to the component function.
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
// arg_indices.size() is the number of arguments to the component function.
|
||||
// The i-th argument of the component function comes from the
|
||||
// `arg_indices_[i]`-th argument of the multi-device function.
|
||||
std::vector<int> arg_indices_;
|
||||
// ret_indices_.size() is the number of return values of the component
|
||||
// `arg_indices[i]`-th argument of the multi-device function.
|
||||
std::vector<FunctionArgIndex> arg_indices;
|
||||
// ret_indices.size() is the number of return values of the component
|
||||
// function. The i-th return value of the component function goes to the
|
||||
// `ret_indices_[i]`-th return value of the multi-device function.
|
||||
std::vector<int> ret_indices_;
|
||||
// arg_alloc_attrs_[i] are the allocator attributes of the i-th argument to
|
||||
// `ret_indices[i]`-th return value of the multi-device function.
|
||||
std::vector<int> ret_indices;
|
||||
// arg_alloc_attrs[i] are the allocator attributes of the i-th argument to
|
||||
// the component function.
|
||||
std::vector<AllocatorAttributes> arg_alloc_attrs_;
|
||||
// ret_alloc_attrs_[i] are the allocator attributes of the i-th return value
|
||||
std::vector<AllocatorAttributes> arg_alloc_attrs;
|
||||
// ret_alloc_attrs[i] are the allocator attributes of the i-th return value
|
||||
// of the component function.
|
||||
std::vector<AllocatorAttributes> ret_alloc_attrs_;
|
||||
std::vector<AllocatorAttributes> ret_alloc_attrs;
|
||||
};
|
||||
|
||||
// Data structure holding information for a single instantiated multi-device
|
||||
@ -304,6 +310,9 @@ class ProcessFunctionLibraryRuntime {
|
||||
InternalArgs* args)>
|
||||
get_component_args) const;
|
||||
|
||||
Status CreateRendezvous(const FunctionLibraryRuntime::Options& opts,
|
||||
Rendezvous** created_rendezvous) const;
|
||||
|
||||
FunctionLibraryRuntime::DoneCallback ApplyCleanUpToDoneCallback(
|
||||
std::vector<std::unique_ptr<CleanUpItem>>* items,
|
||||
FunctionLibraryRuntime::DoneCallback done, const int64 step_id,
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/function_testlib.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
@ -143,6 +144,10 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
}}));
|
||||
}
|
||||
|
||||
void AddCompositeDevice(CompositeDevice* d) {
|
||||
proc_flr_->AddCompositeDevice(d);
|
||||
}
|
||||
|
||||
Status Instantiate(
|
||||
const string& name, test::function::Attrs attrs,
|
||||
const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
|
||||
@ -187,11 +192,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status RunWithRuntime(
|
||||
const string& name, FunctionLibraryRuntime::Options opts,
|
||||
test::function::Attrs attrs,
|
||||
const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
|
||||
const std::vector<Tensor>& args, std::vector<Tensor*> rets,
|
||||
const T& args, std::vector<Tensor*> rets,
|
||||
ProcessFunctionLibraryRuntime* pflr) {
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
Status status = pflr->Instantiate(name, attrs, instantiate_opts, &handle);
|
||||
@ -248,9 +254,20 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
Status Run(const string& name, FunctionLibraryRuntime::Options opts,
|
||||
test::function::Attrs attrs,
|
||||
const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
|
||||
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
|
||||
return RunWithRuntime(name, opts, attrs, instantiate_opts, args, rets,
|
||||
proc_flr_.get());
|
||||
const std::vector<Tensor>& args, std::vector<Tensor*> rets,
|
||||
ProcessFunctionLibraryRuntime* pflr = nullptr) {
|
||||
return RunWithRuntime<std::vector<Tensor>>(
|
||||
name, opts, attrs, instantiate_opts, args, rets, proc_flr_.get());
|
||||
}
|
||||
|
||||
Status RunWithPackedArgs(
|
||||
const string& name, FunctionLibraryRuntime::Options opts,
|
||||
test::function::Attrs attrs,
|
||||
const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
|
||||
const FunctionArgsInterface& args, std::vector<Tensor*> rets,
|
||||
ProcessFunctionLibraryRuntime* pflr = nullptr) {
|
||||
return RunWithRuntime<FunctionArgsInterface>(
|
||||
name, opts, attrs, instantiate_opts, args, rets, proc_flr_.get());
|
||||
}
|
||||
|
||||
Status RunInstantiated(FunctionLibraryRuntime::Handle handle,
|
||||
@ -719,6 +736,112 @@ Tensor GetResourceHandle(const string& var_name, const string& container,
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// Returns a function which adds two variables on different devices.
|
||||
FunctionDef AddVarAcrossDevices() {
|
||||
return FunctionDefHelper::Create(
|
||||
// Name
|
||||
"AddVarAcrossDevices",
|
||||
// Args
|
||||
{"x: resource"},
|
||||
// Return values
|
||||
{"y: float"},
|
||||
// Attr def
|
||||
{},
|
||||
// Nodes
|
||||
{
|
||||
{{"read0"},
|
||||
"ReadVariableOp",
|
||||
{"x"},
|
||||
{{"dtype", DT_FLOAT}},
|
||||
{},
|
||||
"/device:CPU:0"},
|
||||
{{"read1"},
|
||||
"ReadVariableOp",
|
||||
{"x"},
|
||||
{{"dtype", DT_FLOAT}},
|
||||
{},
|
||||
"/device:CPU:1"},
|
||||
{{"add"},
|
||||
"Add",
|
||||
{"read0:value:0", "read1:value:0"},
|
||||
{{"T", DT_FLOAT}},
|
||||
{},
|
||||
"/device:CPU:0"},
|
||||
},
|
||||
{{"y", "add:z:0"}});
|
||||
}
|
||||
|
||||
// An implementation of FunctionArgsInterface for packed inputs.
|
||||
class TestFunctionPackedArgs : public FunctionArgsInterface {
|
||||
public:
|
||||
TestFunctionPackedArgs(const int index,
|
||||
gtl::InlinedVector<TensorValue, 4>&& tensor_args) {
|
||||
packed_args_.emplace(index, std::move(tensor_args));
|
||||
}
|
||||
|
||||
~TestFunctionPackedArgs() override{};
|
||||
|
||||
bool HasRemoteOrPackedInputs() const override { return true; };
|
||||
|
||||
Status GetLocalArg(const FunctionArgIndex& index,
|
||||
Tensor* val) const override {
|
||||
*val = *packed_args_.at(index.index).at(index.sub_index).tensor;
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
std::vector<Tensor> GetLocalTensors() const override { return {}; }
|
||||
|
||||
private:
|
||||
absl::flat_hash_map<int, gtl::InlinedVector<TensorValue, 4>> packed_args_;
|
||||
};
|
||||
|
||||
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) {
|
||||
Init({AddVarAcrossDevices()});
|
||||
// Create two variables on two devices.
|
||||
const Tensor initial_resource_value0 = test::AsTensor<float>({10, 20});
|
||||
Var* resource0 = new Var(DT_FLOAT);
|
||||
*resource0->tensor() = initial_resource_value0;
|
||||
resource0->is_initialized = true;
|
||||
const Tensor initial_resource_value1 = test::AsTensor<float>({30, 40});
|
||||
Var* resource1 = new Var(DT_FLOAT);
|
||||
*resource1->tensor() = initial_resource_value1;
|
||||
resource1->is_initialized = true;
|
||||
ResourceMgr* mgr0 = device0_->resource_manager();
|
||||
ResourceMgr* mgr1 = device1_->resource_manager();
|
||||
TF_ASSERT_OK(mgr0->Create(mgr0->default_container(), "var", resource0));
|
||||
TF_ASSERT_OK(mgr1->Create(mgr1->default_container(), "var", resource1));
|
||||
|
||||
Tensor resource_handle0 =
|
||||
GetResourceHandle("var", mgr0->default_container(), device0_->name());
|
||||
Tensor resource_handle1 =
|
||||
GetResourceHandle("var", mgr1->default_container(), device1_->name());
|
||||
|
||||
// Create a CompositeDevice
|
||||
Status s;
|
||||
std::unique_ptr<CompositeDevice> composite_device =
|
||||
CompositeDevice::MakeDevice({device0_->name(), device1_->name()},
|
||||
/*unique_device_id=*/0, &s);
|
||||
TF_ASSERT_OK(s);
|
||||
AddCompositeDevice(composite_device.get());
|
||||
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
FunctionLibraryRuntime::InstantiateOptions inst_opts =
|
||||
MakeOptions("CPU:0", {"COMPOSITE:0"}, {"CPU:0"});
|
||||
inst_opts.composite_devices[composite_device->name()] =
|
||||
composite_device->underlying_devices();
|
||||
inst_opts.input_resource_dtypes_and_shapes[0] = {
|
||||
initial_resource_value0.dtype(), initial_resource_value0.shape()};
|
||||
|
||||
gtl::InlinedVector<TensorValue, 4> handles;
|
||||
handles.push_back(TensorValue(&resource_handle0));
|
||||
handles.push_back(TensorValue(&resource_handle1));
|
||||
TestFunctionPackedArgs args(0, std::move(handles));
|
||||
Tensor ret;
|
||||
TF_CHECK_OK(RunWithPackedArgs("AddVarAcrossDevices", opts, {{"T", DT_FLOAT}},
|
||||
inst_opts, args, {&ret}));
|
||||
test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60}));
|
||||
}
|
||||
|
||||
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) {
|
||||
if (gpu_device_ == nullptr) {
|
||||
GTEST_SKIP() << "No GPUs available";
|
||||
@ -1025,9 +1148,9 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) {
|
||||
instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
|
||||
const auto x = test::AsTensor<int64>({17});
|
||||
Tensor y;
|
||||
TF_CHECK_OK(RunWithRuntime("SessionMetadataReaderFn", opts, {},
|
||||
instantiate_opts, {x}, {&y},
|
||||
cloned_proc_flr.get()));
|
||||
TF_CHECK_OK(RunWithRuntime<std::vector<Tensor>>(
|
||||
"SessionMetadataReaderFn", opts, {}, instantiate_opts, {x}, {&y},
|
||||
cloned_proc_flr.get()));
|
||||
SessionMetadata read_metadata;
|
||||
ASSERT_TRUE(protobuf::TextFormat::ParseFromString(y.scalar<tstring>()(),
|
||||
&read_metadata));
|
||||
|
@ -42,6 +42,7 @@ class ReplicateHelper {
|
||||
Node* replicated_node = graph->AddNode(node_def, &status);
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
replicated_node->set_assigned_device_name(device);
|
||||
replicated_node->AddAttr("sub_index", i);
|
||||
replicated_nodes[i] = replicated_node;
|
||||
}
|
||||
replicated_nodes_map_.emplace(node, std::move(replicated_nodes));
|
||||
@ -180,7 +181,8 @@ Status ReplicateEdges(const ReplicateHelper& helper,
|
||||
} // namespace
|
||||
|
||||
Status ReplicatePerReplicaNodesInFunctionGraph(
|
||||
const absl::flat_hash_map<string, std::vector<string>>& composite_devices,
|
||||
const absl::flat_hash_map<string, const std::vector<string>*>&
|
||||
composite_devices,
|
||||
Graph* graph) {
|
||||
std::set<string> composite_device_names;
|
||||
for (const auto& it : composite_devices) {
|
||||
@ -198,7 +200,8 @@ Status ReplicatePerReplicaNodesInFunctionGraph(
|
||||
}
|
||||
|
||||
for (const auto& it : composite_device_to_cluster_nodes) {
|
||||
const std::vector<string>& allowed_devices = composite_devices.at(it.first);
|
||||
const std::vector<string>& allowed_devices =
|
||||
*composite_devices.at(it.first);
|
||||
if (allowed_devices.empty()) {
|
||||
return errors::InvalidArgument("No allowed device of composite device: ",
|
||||
it.first);
|
||||
@ -208,6 +211,7 @@ Status ReplicatePerReplicaNodesInFunctionGraph(
|
||||
// Reuse the original nodes if there is only one allowed device.
|
||||
for (Node* n : cluster_nodes) {
|
||||
n->set_assigned_device_name(allowed_devices.at(0));
|
||||
n->AddAttr("sub_index", 0);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
@ -35,7 +35,8 @@ namespace tensorflow {
|
||||
// dependency.
|
||||
// TODO(b/145922293): Register it as a POST_REWRITE_FOR_EXEC pass.
|
||||
Status ReplicatePerReplicaNodesInFunctionGraph(
|
||||
const absl::flat_hash_map<string, std::vector<string>>& composite_devices,
|
||||
const absl::flat_hash_map<string, const std::vector<string>*>&
|
||||
composite_devices,
|
||||
Graph* graph);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -75,8 +75,9 @@ TEST(ReplicatePerReplicaNodesTest, SingleCompositeDevice) {
|
||||
auto ret = ops::_Retval(
|
||||
scope.WithOpName("ret").WithControlDependencies({write}), read, 0);
|
||||
|
||||
const absl::flat_hash_map<string, std::vector<string>> composite_devices = {
|
||||
{"TPU_COMPOSITE:0", {"TPU:0", "TPU:1"}}};
|
||||
const std::vector<string> underlying_devices = {"TPU:0", "TPU:1"};
|
||||
const absl::flat_hash_map<string, const std::vector<string>*>
|
||||
composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};
|
||||
|
||||
Graph graph(OpRegistry::Global());
|
||||
TF_ASSERT_OK(scope.ToGraph(&graph));
|
||||
@ -118,8 +119,9 @@ TEST(ReplicatePerReplicaNodesTest, SingleCompositeDeviceToSingleDevice) {
|
||||
auto read = ops::ReadVariableOp(scope.WithOpName("read"), arg, DT_INT32);
|
||||
auto ret = ops::_Retval(scope.WithOpName("ret"), read, 0);
|
||||
|
||||
const absl::flat_hash_map<string, std::vector<string>> composite_devices = {
|
||||
{"TPU_COMPOSITE:0", {"TPU:0"}}};
|
||||
const std::vector<string> underlying_devices = {"TPU:0"};
|
||||
const absl::flat_hash_map<string, const std::vector<string>*>
|
||||
composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};
|
||||
|
||||
Graph graph(OpRegistry::Global());
|
||||
TF_ASSERT_OK(scope.ToGraph(&graph));
|
||||
@ -156,9 +158,11 @@ TEST(ReplicatePerReplicaNodesTest, MultipleCompositeDevices) {
|
||||
auto add = ops::Add(scope.WithOpName("add"), identity0, identity1);
|
||||
auto ret = ops::_Retval(scope.WithOpName("ret"), add, 0);
|
||||
|
||||
const absl::flat_hash_map<string, std::vector<string>> composite_devices = {
|
||||
{"TPU_COMPOSITE:0", {"TPU:0", "TPU:1"}},
|
||||
{"TPU_COMPOSITE:1", {"TPU:2", "TPU:3"}}};
|
||||
const std::vector<string> underlying_devices_0 = {"TPU:0", "TPU:1"};
|
||||
const std::vector<string> underlying_devices_1 = {"TPU:2", "TPU:3"};
|
||||
const absl::flat_hash_map<string, const std::vector<string>*>
|
||||
composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices_0},
|
||||
{"TPU_COMPOSITE:1", &underlying_devices_1}};
|
||||
|
||||
Graph graph(OpRegistry::Global());
|
||||
TF_ASSERT_OK(scope.ToGraph(&graph));
|
||||
@ -204,8 +208,9 @@ TEST(ReplicatePerReplicaNodesTest, MultipleCompositeDevices) {
|
||||
}
|
||||
|
||||
TEST(ReplicatePerReplicaNodesTest, NestedFunctions) {
|
||||
const absl::flat_hash_map<string, std::vector<string>> composite_devices = {
|
||||
{"TPU_COMPOSITE:0", {"TPU:0", "TPU:1"}}};
|
||||
const std::vector<string> underlying_devices = {"TPU:0", "TPU:1"};
|
||||
const absl::flat_hash_map<string, const std::vector<string>*>
|
||||
composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};
|
||||
|
||||
FunctionDefLibrary fdef_lib;
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
|
||||
|
@ -500,11 +500,11 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
|
||||
: EagerKernelArgs(std::move(tensor_args)),
|
||||
serialize_remote_handle_(std::move(serialize_remote_handle)) {}
|
||||
|
||||
bool HasRemoteInputs() const override { return true; }
|
||||
bool HasRemoteOrPackedInputs() const override { return true; }
|
||||
|
||||
Status GetRemoteArg(const int index,
|
||||
Status GetRemoteArg(const FunctionArgIndex& index,
|
||||
eager::RemoteTensorHandle* val) const override {
|
||||
return serialize_remote_handle_(index, val);
|
||||
return serialize_remote_handle_(index.index, val);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -562,7 +562,14 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
|
||||
eager_pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
remote_device_mgr_.get(), Env::Default(), /*config=*/
|
||||
nullptr, TF_GRAPH_DEF_VERSION, &func_lib_def_, OptimizerOptions(),
|
||||
/*thread_pool=*/nullptr, eager_cluster_flr_.get());
|
||||
/*thread_pool=*/nullptr, eager_cluster_flr_.get(),
|
||||
/*custom_kernel_creator=*/nullptr, /*session_metadata=*/nullptr,
|
||||
Rendezvous::Factory{[this](const int64 step_id,
|
||||
const DeviceMgr* device_mgr,
|
||||
Rendezvous** r) {
|
||||
*r = worker_env_.rendezvous_mgr->Find(step_id);
|
||||
return Status::OK();
|
||||
}});
|
||||
}
|
||||
|
||||
void CheckOutputTensorAndClose(const Tensor& tensor) {
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
// clang-format on
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
@ -525,6 +526,20 @@ class Device;
|
||||
// Forward declare. Defined in common_runtime/device_mgr.h
|
||||
class DeviceMgr;
|
||||
|
||||
// Index of an _Arg node.
|
||||
struct FunctionArgIndex {
|
||||
explicit FunctionArgIndex(const int index) : index(index) {}
|
||||
FunctionArgIndex(const int index, const int sub_index)
|
||||
: index(index), sub_index(sub_index) {}
|
||||
|
||||
// The value of the attribute "Index" of the _Arg node.
|
||||
int index;
|
||||
// Set only when the _Arg node represents multiple arguments (e.g. an _Arg
|
||||
// node is replicated to multiple devices/subgraphs). Use sub-index to
|
||||
// distinguish arguments with the same index.
|
||||
int sub_index = -1;
|
||||
};
|
||||
|
||||
class FunctionLibraryRuntime {
|
||||
public:
|
||||
virtual ~FunctionLibraryRuntime() {}
|
||||
@ -576,6 +591,10 @@ class FunctionLibraryRuntime {
|
||||
// infer correct device.
|
||||
std::vector<string> output_devices;
|
||||
|
||||
// Maps from a CompositeDevice name to a list of underlying physical
|
||||
// devices.
|
||||
absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
|
||||
|
||||
// This interface is EXPERIMENTAL and subject to change.
|
||||
//
|
||||
// For multi-device functions, a mapping from _Arg node index to type and
|
||||
|
Loading…
Reference in New Issue
Block a user