Support packed tensor inputs in ProcessFunctionLibraryRuntime.
- Expand packed _Arg nodes once the graph is ready for graph partition.
- Introduce an optional sub-index on function _Arg nodes to distinguish two arguments that share the same "index". This situation arises after a packed _Arg node assigned to a CompositeDevice is replaced with multiple replica nodes (one per underlying device). The "index" of an _Arg node is unique before expansion, and it remains unique within each subgraph after graph partition.

PiperOrigin-RevId: 309781835
Change-Id: Ic6e351f45b7523288b5dae30997ddf0dae86660b
commit 8fcb130e92
parent 4158c029ef
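The change pivots on the new FunctionArgIndex struct added to the function library header in the diff below: a packed argument keeps its original "index", and once its _Arg node has been replicated onto N underlying devices each replica is tagged with a sub_index in [0, N). A minimal illustration of the pair follows; the struct body mirrors the diff, while the example values are purely hypothetical.

// Added by this change: an _Arg is addressed by (index, sub_index) rather
// than a bare int.
struct FunctionArgIndex {
  explicit FunctionArgIndex(const int index) : index(index) {}
  FunctionArgIndex(const int index, const int sub_index)
      : index(index), sub_index(sub_index) {}
  int index;           // Value of the _Arg node's "index" attribute.
  int sub_index = -1;  // >= 0 only after a packed _Arg has been replicated.
};

// Hypothetical values: argument 0 packed across two devices expands into two
// _Arg nodes that share index 0 and differ only in sub_index.
const FunctionArgIndex replica_on_device0(/*index=*/0, /*sub_index=*/0);
const FunctionArgIndex replica_on_device1(/*index=*/0, /*sub_index=*/1);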
@@ -631,6 +631,7 @@ cc_library(
 ],
 copts = tf_copts(),
 deps = [
+":composite_device",
 ":device",
 ":device_mgr",
 ":device_set",
@@ -651,6 +652,7 @@ cc_library(
 ":process_util",
 ":rendezvous_mgr",
 ":rendezvous_util",
+":replicate_per_replica_nodes",
 "//tensorflow/core:framework",
 "//tensorflow/core:graph",
 "//tensorflow/core:lib",
@@ -658,6 +660,7 @@ cc_library(
 "//tensorflow/core:protos_all_cc",
 "//tensorflow/core/profiler/lib:traceme",
 "@com_google_absl//absl/algorithm:container",
+"@com_google_absl//absl/container:flat_hash_map",
 "@com_google_absl//absl/memory",
 "@com_google_absl//absl/strings",
 ],
@@ -52,9 +52,14 @@ Status ExecuteNodeArgs::Init(
 #if !defined(IS_MOBILE_PLATFORM)
 if (has_remote_inputs_) {
 serialize_remote_handle_ =
-[ctx, &op_inputs](const int i,
+[ctx, &op_inputs](const FunctionArgIndex& index,
 eager::RemoteTensorHandle* handle) -> Status {
-VariantDevice variant_device = op_inputs[i]->device();
+if (index.sub_index >= 0) {
+return errors::InvalidArgument("Got unexpected sub_index ",
+index.sub_index, " for argument ",
+index.index);
+}
+VariantDevice variant_device = op_inputs[index.index]->device();
 if (VariantDeviceIsCustom(variant_device)) {
 return errors::Internal(
 "Custom devices and remote execution are currently not supported "
@@ -62,7 +67,7 @@ Status ExecuteNodeArgs::Init(
 }
 Device* device = absl::get<Device*>(variant_device);
 return ctx->RemoteMgr()->SerializeRemoteTensorHandle(
-op_inputs[i], handle, device, device->name());
+op_inputs[index.index], handle, device, device->name());
 };
 }
 #endif // !IS_MOBILE_PLATFORM
@@ -54,10 +54,12 @@ class ExecuteNodeArgs : public EagerKernelArgs {
 const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
 const core::RefCountPtr<KernelAndDevice>& kernel);

-bool HasRemoteInputs() const override { return has_remote_inputs_; };
+bool HasRemoteOrPackedInputs() const override {
+return has_remote_inputs_ || has_packed_inputs_;
+};

 #if !defined(IS_MOBILE_PLATFORM)
-Status GetRemoteArg(const int index,
+Status GetRemoteArg(const FunctionArgIndex& index,
 eager::RemoteTensorHandle* val) const override {
 return serialize_remote_handle_(index, val);
 }
@@ -65,8 +67,9 @@ class ExecuteNodeArgs : public EagerKernelArgs {

 private:
 bool has_remote_inputs_ = false;
+bool has_packed_inputs_ = false;
 #if !defined(IS_MOBILE_PLATFORM)
-std::function<Status(const int, eager::RemoteTensorHandle*)>
+std::function<Status(const FunctionArgIndex&, eager::RemoteTensorHandle*)>
 serialize_remote_handle_;
 #endif // IS_MOBILE_PLATFORM
 };
@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/profiler/lib/annotated_traceme.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
@@ -49,13 +50,18 @@ limitations under the License.

 namespace tensorflow {

-Status EagerKernelArgs::GetLocalArg(const int index, Tensor* val) const {
+Status EagerKernelArgs::GetLocalArg(const FunctionArgIndex& index,
-Tensor* arg = tensor_args_.at(index).tensor;
+Tensor* val) const {
+if (index.sub_index >= 0) {
+return errors::InvalidArgument("Got unexpected sub_index ", index.sub_index,
+" for argument ", index.index);
+}
+Tensor* arg = tensor_args_.at(index.index).tensor;
 if (arg) {
 *val = *arg;
 return Status::OK();
 } else {
-return errors::NotFound("Argument ", index, " has no local tensor.");
+return errors::NotFound("Argument ", index.index, " has no local tensor.");
 }
 }

@@ -66,16 +66,16 @@ class EagerKernelArgs : public FunctionArgsInterface {

 ~EagerKernelArgs() override{};

-bool HasRemoteInputs() const override { return false; };
+bool HasRemoteOrPackedInputs() const override { return false; };
 TensorValue* MutableInput(int i) { return &tensor_args_[i]; }

-Status GetLocalArg(const int index, Tensor* val) const override;
+Status GetLocalArg(const FunctionArgIndex& index, Tensor* val) const override;

 std::vector<Tensor> GetLocalTensors() const override;

-const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const override {
+const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const {
 return &tensor_args_;
-};
+}

 protected:
 gtl::InlinedVector<TensorValue, 4> tensor_args_;
@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/partitioning_utils.h"

 #include <algorithm>
+#include <utility>

 #include "tensorflow/core/common_runtime/graph_constructor.h"
 #include "tensorflow/core/framework/function.h"
@@ -73,11 +74,11 @@ Status PartitionFunctionGraph(
 }

 Status UpdateArgAndRetvalMetadata(
-Graph* subgraph, const string& device_type, std::vector<int>* arg_indices,
+Graph* subgraph, const string& device_type,
-std::vector<int>* ret_indices,
+std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
 std::vector<AllocatorAttributes>* arg_alloc_attrs,
 std::vector<AllocatorAttributes>* ret_alloc_attrs) {
-std::vector<std::pair<Node*, int>> arg_nodes;
+std::vector<std::pair<Node*, FunctionArgIndex>> arg_nodes;
 std::vector<std::pair<Node*, int>> ret_nodes;
 const AttrValue* attr_value;

@@ -87,7 +88,11 @@ Status UpdateArgAndRetvalMetadata(
 if (node->IsArg()) {
 TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
 int index = static_cast<int>(attr_value->i());
-arg_nodes.emplace_back(node, index);
+int sub_index = -1;
+if (node->attrs().Find("sub_index", &attr_value).ok()) {
+sub_index = static_cast<int>(attr_value->i());
+}
+arg_nodes.emplace_back(node, FunctionArgIndex(index, sub_index));
 } else if (node->IsRetval()) {
 TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
 int index = static_cast<int>(attr_value->i());
@@ -99,11 +104,16 @@ Status UpdateArgAndRetvalMetadata(
 //
 // In particular, this enables calling a single-partition function with
 // the same signature as the original unpartitioned function.
-auto comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) {
+auto arg_comparator = [](std::pair<Node*, FunctionArgIndex> a,
+std::pair<Node*, FunctionArgIndex> b) {
+return std::tie(a.second.index, a.second.sub_index) <
+std::tie(b.second.index, b.second.sub_index);
+};
+std::sort(arg_nodes.begin(), arg_nodes.end(), arg_comparator);
+auto ret_comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) {
 return a.second < b.second;
 };
-std::sort(arg_nodes.begin(), arg_nodes.end(), comparator);
+std::sort(ret_nodes.begin(), ret_nodes.end(), ret_comparator);
-std::sort(ret_nodes.begin(), ret_nodes.end(), comparator);

 arg_indices->reserve(arg_nodes.size());
 for (const auto& pair : arg_nodes) arg_indices->push_back(pair.second);
@@ -144,16 +154,6 @@ Status UpdateArgAndRetvalMetadata(
 return Status::OK();
 }

-std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
-gtl::ArraySlice<Tensor> arguments) {
-std::vector<Tensor> args;
-args.reserve(indices.size());
-for (int i : indices) {
-args.push_back(arguments[i]);
-}
-return args;
-}
-
 string FunctionNameGenerator::GetName() {
 while (true) {
 const string candidate = strings::StrCat(name_, "_", counter_++);
@@ -58,15 +58,11 @@ Status PartitionFunctionGraph(
 // (3) records which `Arg` and `Retval` nodes live in host memory in
 // `*_alloc_attrs`.
 Status UpdateArgAndRetvalMetadata(
-Graph* subgraph, const string& device_type, std::vector<int>* arg_indices,
+Graph* subgraph, const string& device_type,
-std::vector<int>* ret_indices,
+std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
 std::vector<AllocatorAttributes>* arg_alloc_attrs,
 std::vector<AllocatorAttributes>* ret_alloc_attrs);

-// Extracts tensors at `indices` from `arguments`.
-std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
-gtl::ArraySlice<Tensor> arguments);
-
 // Utility for generating function names not present in `flib_def`, using
 // given `name` as the base for the name.
 class FunctionNameGenerator {
@@ -158,14 +158,23 @@ TEST_F(PartitioningUtilsTest, TwoDevices) {
 ASSERT_EQ(3, part2->num_op_nodes());
 }

-void CheckIndices(const std::vector<int>& expected,
+void CheckRetIndices(const std::vector<int>& expected,
 const std::vector<int>& actual) {
 ASSERT_EQ(expected.size(), actual.size());
 for (int i = 0; i < expected.size(); ++i) {
 ASSERT_EQ(expected[i], actual[i]) << " at index " << i;
 }
 }

+void CheckArgIndices(const std::vector<FunctionArgIndex>& expected,
+const std::vector<FunctionArgIndex>& actual) {
+ASSERT_EQ(expected.size(), actual.size());
+for (int i = 0; i < expected.size(); ++i) {
+ASSERT_EQ(expected[i].index, actual[i].index) << " at index " << i;
+ASSERT_EQ(expected[i].sub_index, actual[i].sub_index) << " at index " << i;
+}
+}
+
 void CheckAlloc(const std::vector<bool>& expected,
 const std::vector<AllocatorAttributes>& actual) {
 ASSERT_EQ(expected.size(), actual.size());
@@ -185,7 +194,7 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) {
 auto graph = absl::make_unique<Graph>(OpRegistry::Global());
 SubGraph(graph.get(), DT_FLOAT, {3}, {5});

-std::vector<int> arg_indices;
+std::vector<FunctionArgIndex> arg_indices;
 std::vector<int> ret_indices;
 std::vector<AllocatorAttributes> arg_alloc_attrs;
 std::vector<AllocatorAttributes> ret_alloc_attrs;
@@ -197,8 +206,8 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) {
 &ret_alloc_attrs);
 ASSERT_TRUE(status.ok()) << status.ToString();

-CheckIndices({3}, arg_indices);
+CheckArgIndices({{3, -1}}, arg_indices);
-CheckIndices({5}, ret_indices);
+CheckRetIndices({5}, ret_indices);
 CheckAlloc({false}, arg_alloc_attrs);
 CheckAlloc({false}, ret_alloc_attrs);

@@ -213,7 +222,18 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets_Order) {
 auto graph = absl::make_unique<Graph>(OpRegistry::Global());
 SubGraph(graph.get(), DT_FLOAT, {9, 7, 5, 3, 1}, {2, 4, 6, 8, 10});

-std::vector<int> arg_indices;
+const std::map<int, int> sub_indices = {
+{7, 2}, {3, 1}, {1, 0}, {5, 2}, {9, 0}};
+const AttrValue* attr_value;
+for (Node* n : graph->op_nodes()) {
+if (n->IsArg()) {
+TF_ASSERT_OK(n->attrs().Find("index", &attr_value));
+n->AddAttr("sub_index",
+sub_indices.at(static_cast<int>(attr_value->i())));
+}
+}
+
+std::vector<FunctionArgIndex> arg_indices;
 std::vector<int> ret_indices;
 std::vector<AllocatorAttributes> arg_alloc_attrs;
 std::vector<AllocatorAttributes> ret_alloc_attrs;
@@ -225,8 +245,8 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets_Order) {
 &ret_alloc_attrs);
 ASSERT_TRUE(status.ok()) << status.ToString();

-CheckIndices({1, 3, 5, 7, 9}, arg_indices);
+CheckArgIndices({{1, 0}, {3, 1}, {5, 2}, {7, 2}, {9, 0}}, arg_indices);
-CheckIndices({2, 4, 6, 8, 10}, ret_indices);
+CheckRetIndices({2, 4, 6, 8, 10}, ret_indices);
 CheckAlloc({false, false, false, false, false}, arg_alloc_attrs);
 CheckAlloc({false, false, false, false, false}, ret_alloc_attrs);
 }
@@ -17,6 +17,7 @@ limitations under the License.
 #include <iterator>
 #include <utility>

+#include "absl/container/flat_hash_map.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/str_join.h"
 #include "tensorflow/core/common_runtime/device_set.h"
@@ -29,6 +30,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #include "tensorflow/core/common_runtime/rendezvous_util.h"
+#include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph_to_functiondef.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -301,7 +303,7 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice(

 // Replace the given handle with the handle for the single component
 // function.
-handle = component_data.handle_;
+handle = component_data.handle;
 }

 auto iter = function_data_.find(handle);
@@ -777,6 +779,14 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
 OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));

+// Expand the nodes assigned to a CompositeDevice before graph partition to
+// avoid generating a subgraph on a virtual device for execution.
+// This transformation should happen as late as possible, in order to run as
+// more graph optimization passes (e.g. PRE_PLACEMENT, PLACER,
+// POST_PLACEMENT, POST_REWRITE_FOR_EXEC) on a smaller graph as possible.
+TF_RETURN_IF_ERROR(ReplicatePerReplicaNodesInFunctionGraph(
+options.composite_devices, graph.get()));
+
 if (options.graph_collector != nullptr) {
 GraphDef def;
 graph->ToGraphDef(&def);
@@ -869,9 +879,9 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
 Graph* subgraph = pair.second.get();

 status->Update(UpdateArgAndRetvalMetadata(
-subgraph, device_type, &comp_data->arg_indices_,
+subgraph, device_type, &comp_data->arg_indices,
-&comp_data->ret_indices_, &comp_data->arg_alloc_attrs_,
+&comp_data->ret_indices, &comp_data->arg_alloc_attrs,
-&comp_data->ret_alloc_attrs_));
+&comp_data->ret_alloc_attrs));
 if (!status->ok()) {
 counter.DecrementCount();
 return;
@@ -913,7 +923,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
 data->is_cross_process_ = true;
 }
 }
-comp_data->handle_ = *component_handle;
+comp_data->handle = *component_handle;
 }
 delete component_handle;
 counter.DecrementCount();
@@ -955,16 +965,16 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(

 for (const auto& pair : data->glue_) {
 const ComponentFunctionData& comp_data = pair.second;
-DCHECK(comp_data.ret_alloc_attrs_.size() == comp_data.ret_indices_.size());
+DCHECK(comp_data.ret_alloc_attrs.size() == comp_data.ret_indices.size());

 const string& target = pair.first;
 FunctionLibraryRuntime* target_flr = GetFLR(target);
 if (target_flr == nullptr) {
-if (!comp_data.ret_indices_.empty()) {
+if (!comp_data.ret_indices.empty()) {
 return errors::Unimplemented(
 "Currently, outputting tensors on remote devices is not supported. "
 "The ",
-comp_data.ret_indices_[0],
+comp_data.ret_indices[0],
 "-th return value of the function outputs to target_device: ",
 target,
 " Please copy the tensor to local device explicitly using "
@@ -973,17 +983,17 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
 continue;
 }
 Device* target_device = target_flr->device();
-const FunctionBody* fbody = target_flr->GetFunctionBody(comp_data.handle_);
+const FunctionBody* fbody = target_flr->GetFunctionBody(comp_data.handle);
 DCHECK(fbody != nullptr);

 output_devices->resize(data->num_outputs_);
-for (int j = 0; j < comp_data.ret_indices_.size(); ++j) {
+for (int j = 0; j < comp_data.ret_indices.size(); ++j) {
-int ret_index = comp_data.ret_indices_[j];
+int ret_index = comp_data.ret_indices[j];
 if (fbody->ret_types[j] == DT_RESOURCE) {
 (*output_devices)[ret_index] = target_device;
 } else {
 (*output_devices)[ret_index] =
-comp_data.ret_alloc_attrs_[j].on_host() ? nullptr : target_device;
+comp_data.ret_alloc_attrs[j].on_host() ? nullptr : target_device;
 }
 }
 }
@@ -1013,9 +1023,8 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(

 const MultiDeviceFunctionData* data = IsMultiDevice(handle);
 if (data == nullptr) {
-done(
+done(errors::NotFound("Multi-device function handle ", handle,
-errors::InvalidArgument("Failed for find multi-device function handle ",
+"not found. Was the function instantiated?"));
-handle, ". Was the function instantiated?"));
 return;
 }

@@ -1046,10 +1055,10 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
 for (const auto& pair : data->glue_) {
 const string& target = pair.first;
 const ComponentFunctionData& comp_data = pair.second;
-FunctionLibraryRuntime::Handle handle = pair.second.handle_;
+FunctionLibraryRuntime::Handle handle = pair.second.handle;

-opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs_;
+opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs;
-opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs_;
+opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs;
 opts_copy.remote_execution = false;

 InternalArgs comp_args;
@@ -1086,7 +1095,7 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
 Status(status.code(), function_and_msg));
 } else {
 for (int i = 0; i < comp_rets->size(); ++i) {
-(*rets)[comp_data.ret_indices_[i]] = (*comp_rets)[i];
+(*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i];
 }
 }
 delete comp_rets;
@@ -1108,7 +1117,7 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
 refcounted_done->UpdateStatus(status);
 } else {
 for (int i = 0; i < comp_rets->size(); ++i) {
-(*rets)[comp_data.ret_indices_[i]] = (*comp_rets)[i];
+(*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i];
 }
 }
 delete comp_rets;
@@ -1225,7 +1234,7 @@ Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle(
 Status overall_status;
 for (const auto& it : mdata->glue_) {
 const string& device = it.first;
-FunctionLibraryRuntime::Handle flr_handle = it.second.handle_;
+FunctionLibraryRuntime::Handle flr_handle = it.second.handle;
 FunctionLibraryRuntime* flr = GetFLR(device);
 if (flr == nullptr) {
 // TODO(nareshmodi): Implement DeregisterGraph call to remote device if
@@ -1297,6 +1306,19 @@ ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback(
 };
 }

+Status ProcessFunctionLibraryRuntime::CreateRendezvous(
+const FunctionLibraryRuntime::Options& opts,
+Rendezvous** created_rendezvous) const {
+if (rendezvous_factory_) {
+return rendezvous_factory_(opts.step_id, device_mgr_, created_rendezvous);
+} else {
+return errors::FailedPrecondition(
+"The caller does not provide a rendezvous and "
+"ProcessFunctionLibraryRuntime was created without a rendezvous "
+"factory.");
+}
+}
+
 void ProcessFunctionLibraryRuntime::Run(
 const FunctionLibraryRuntime::Options& opts,
 FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
@@ -1305,21 +1327,12 @@ void ProcessFunctionLibraryRuntime::Run(
 FunctionLibraryRuntime::Options new_opts = opts;
 Rendezvous* created_rendezvous = nullptr;
 if (!opts.rendezvous) {
-if (rendezvous_factory_) {
+Status s = CreateRendezvous(opts, &created_rendezvous);
-Status s =
+if (!s.ok()) {
-rendezvous_factory_(opts.step_id, device_mgr_, &created_rendezvous);
+done(s);
-if (!s.ok()) {
-done(s);
-return;
-}
-new_opts.rendezvous = created_rendezvous;
-} else {
-done(
-errors::FailedPrecondition("The caller does not provide a rendezvous "
-"and ProcessFunctionLibraryRuntime was "
-"created without a rendezvous factory."));
 return;
 }
+new_opts.rendezvous = created_rendezvous;
 new_opts.create_rendezvous = false;
 }

@@ -1334,9 +1347,14 @@ void ProcessFunctionLibraryRuntime::Run(
 if (multi_device) {
 auto get_component_args = [&args](const ComponentFunctionData& comp_data,
 InternalArgs* comp_args) -> Status {
-for (const auto& tensor :
+// "Index"s of _Arg nodes are unique when all arguments are local Tensors.
-GetArgsForIndices(comp_data.arg_indices_, args)) {
+for (const auto& it : comp_data.arg_indices) {
-comp_args->args.push_back(tensor);
+if (it.sub_index >= 0) {
+return errors::InvalidArgument("Got unexpected sub_index ",
+it.sub_index, " for argument ",
+it.index);
+}
+comp_args->args.push_back(args[it.index]);
 }
 return Status::OK();
 };
@@ -1520,11 +1538,23 @@ void ProcessFunctionLibraryRuntime::Run(
 FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
 std::vector<Tensor>* rets,
 FunctionLibraryRuntime::DoneCallback done) const {
-if (!args.HasRemoteInputs()) {
+if (!args.HasRemoteOrPackedInputs()) {
 const std::vector<Tensor> local_inputs = args.GetLocalTensors();
 return Run(opts, handle, local_inputs, rets, std::move(done));
 }

+FunctionLibraryRuntime::Options new_opts = opts;
+Rendezvous* created_rendezvous = nullptr;
+if (!opts.rendezvous) {
+Status s = CreateRendezvous(opts, &created_rendezvous);
+if (!s.ok()) {
+done(s);
+return;
+}
+new_opts.rendezvous = created_rendezvous;
+new_opts.create_rendezvous = false;
+}
+
 #if defined(IS_MOBILE_PLATFORM)
 done(errors::Unimplemented(
 "Remote inputs are not available on mobile devices."));
@@ -1532,12 +1562,12 @@ void ProcessFunctionLibraryRuntime::Run(
 #else // !IS_MOBILE_PLATFORM
 auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
 done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id,
-/*rendezvous=*/nullptr);
+created_rendezvous);

 auto get_component_args = [&args](const ComponentFunctionData& comp_data,
 InternalArgs* comp_args) -> Status {
-for (int i = 0; i < comp_data.arg_indices_.size(); ++i) {
+for (int i = 0; i < comp_data.arg_indices.size(); ++i) {
-const int index = comp_data.arg_indices_.at(i);
+const FunctionArgIndex index = comp_data.arg_indices.at(i);
 Tensor tensor;
 if (args.GetLocalArg(index, &tensor).ok()) {
 comp_args->args.push_back(std::move(tensor));
@@ -1552,7 +1582,7 @@ void ProcessFunctionLibraryRuntime::Run(
 }
 return Status::OK();
 };
-return RunMultiDevice(opts, handle, rets, cleanup_items, std::move(done),
+return RunMultiDevice(new_opts, handle, rets, cleanup_items, std::move(done),
 std::move(get_component_args));
 #endif // !IS_MOBILE_PLATFORM
 }
@@ -24,6 +24,7 @@ limitations under the License.

 #include "absl/types/optional.h"
 #include "absl/types/variant.h"
+#include "tensorflow/core/common_runtime/composite_device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/device_set.h"
 #include "tensorflow/core/framework/function.h"
@@ -40,16 +41,15 @@ class FunctionArgsInterface {
 public:
 virtual ~FunctionArgsInterface() {}

-virtual bool HasRemoteInputs() const = 0;
+virtual bool HasRemoteOrPackedInputs() const = 0;

-virtual Status GetLocalArg(const int index, Tensor* val) const = 0;
+virtual Status GetLocalArg(const FunctionArgIndex& index,
+Tensor* val) const = 0;

 virtual std::vector<Tensor> GetLocalTensors() const = 0;

-virtual const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const = 0;
-
 #if !defined(IS_MOBILE_PLATFORM)
-virtual Status GetRemoteArg(const int index,
+virtual Status GetRemoteArg(const FunctionArgIndex& index,
 eager::RemoteTensorHandle* val) const {
 return errors::Unimplemented(
 "Serializing a remote argument is not implemented.");
@@ -217,6 +217,12 @@ class ProcessFunctionLibraryRuntime {
 return lib_def_;
 }

+// Add a CompositeDevice to `device_set_`
+void AddCompositeDevice(CompositeDevice* d) TF_LOCKS_EXCLUDED(mu_) {
+mutex_lock l(mu_);
+device_set_->AddDevice(d);
+}
+
 protected:
 friend class FunctionLibraryRuntimeImpl;

@@ -232,21 +238,21 @@ class ProcessFunctionLibraryRuntime {
 // piece of a multi-device function) fits into the multi-device function.
 struct ComponentFunctionData {
 // The handle for the instantiated component function.
-FunctionLibraryRuntime::Handle handle_;
+FunctionLibraryRuntime::Handle handle;
-// arg_indices_.size() is the number of arguments to the component function.
+// arg_indices.size() is the number of arguments to the component function.
 // The i-th argument of the component function comes from the
-// `arg_indices_[i]`-th argument of the multi-device function.
+// `arg_indices[i]`-th argument of the multi-device function.
-std::vector<int> arg_indices_;
+std::vector<FunctionArgIndex> arg_indices;
-// ret_indices_.size() is the number of return values of the component
+// ret_indices.size() is the number of return values of the component
 // function. The i-th return value of the component function goes to the
-// `ret_indices_[i]`-th return value of the multi-device function.
+// `ret_indices[i]`-th return value of the multi-device function.
-std::vector<int> ret_indices_;
+std::vector<int> ret_indices;
-// arg_alloc_attrs_[i] are the allocator attributes of the i-th argument to
+// arg_alloc_attrs[i] are the allocator attributes of the i-th argument to
 // the component function.
-std::vector<AllocatorAttributes> arg_alloc_attrs_;
+std::vector<AllocatorAttributes> arg_alloc_attrs;
-// ret_alloc_attrs_[i] are the allocator attributes of the i-th return value
+// ret_alloc_attrs[i] are the allocator attributes of the i-th return value
 // of the component function.
-std::vector<AllocatorAttributes> ret_alloc_attrs_;
+std::vector<AllocatorAttributes> ret_alloc_attrs;
 };

 // Data structure holding information for a single instantiated multi-device
@@ -304,6 +310,9 @@ class ProcessFunctionLibraryRuntime {
 InternalArgs* args)>
 get_component_args) const;

+Status CreateRendezvous(const FunctionLibraryRuntime::Options& opts,
+Rendezvous** created_rendezvous) const;
+
 FunctionLibraryRuntime::DoneCallback ApplyCleanUpToDoneCallback(
 std::vector<std::unique_ptr<CleanUpItem>>* items,
 FunctionLibraryRuntime::DoneCallback done, const int64 step_id,
@@ -18,6 +18,7 @@ limitations under the License.
 #include <unordered_map>
 #include <vector>

+#include "tensorflow/core/common_runtime/composite_device.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/function_testlib.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
@@ -143,6 +144,10 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
 }}));
 }

+void AddCompositeDevice(CompositeDevice* d) {
+proc_flr_->AddCompositeDevice(d);
+}
+
 Status Instantiate(
 const string& name, test::function::Attrs attrs,
 const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
@@ -187,11 +192,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 }

+template <typename T>
 Status RunWithRuntime(
 const string& name, FunctionLibraryRuntime::Options opts,
 test::function::Attrs attrs,
 const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
-const std::vector<Tensor>& args, std::vector<Tensor*> rets,
+const T& args, std::vector<Tensor*> rets,
 ProcessFunctionLibraryRuntime* pflr) {
 FunctionLibraryRuntime::Handle handle;
 Status status = pflr->Instantiate(name, attrs, instantiate_opts, &handle);
@@ -248,9 +254,20 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
 Status Run(const string& name, FunctionLibraryRuntime::Options opts,
 test::function::Attrs attrs,
 const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
-const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
+const std::vector<Tensor>& args, std::vector<Tensor*> rets,
-return RunWithRuntime(name, opts, attrs, instantiate_opts, args, rets,
+ProcessFunctionLibraryRuntime* pflr = nullptr) {
-proc_flr_.get());
+return RunWithRuntime<std::vector<Tensor>>(
+name, opts, attrs, instantiate_opts, args, rets, proc_flr_.get());
+}
+
+Status RunWithPackedArgs(
+const string& name, FunctionLibraryRuntime::Options opts,
+test::function::Attrs attrs,
+const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
+const FunctionArgsInterface& args, std::vector<Tensor*> rets,
+ProcessFunctionLibraryRuntime* pflr = nullptr) {
+return RunWithRuntime<FunctionArgsInterface>(
+name, opts, attrs, instantiate_opts, args, rets, proc_flr_.get());
 }

 Status RunInstantiated(FunctionLibraryRuntime::Handle handle,
@@ -719,6 +736,112 @@ Tensor GetResourceHandle(const string& var_name, const string& container,
 return tensor;
 }

+// Returns a function which adds two variables on different devices.
+FunctionDef AddVarAcrossDevices() {
+return FunctionDefHelper::Create(
+// Name
+"AddVarAcrossDevices",
+// Args
+{"x: resource"},
+// Return values
+{"y: float"},
+// Attr def
+{},
+// Nodes
+{
+{{"read0"},
+"ReadVariableOp",
+{"x"},
+{{"dtype", DT_FLOAT}},
+{},
+"/device:CPU:0"},
+{{"read1"},
+"ReadVariableOp",
+{"x"},
+{{"dtype", DT_FLOAT}},
+{},
+"/device:CPU:1"},
+{{"add"},
+"Add",
+{"read0:value:0", "read1:value:0"},
+{{"T", DT_FLOAT}},
+{},
+"/device:CPU:0"},
+},
+{{"y", "add:z:0"}});
+}
+
+// An implementation of FunctionArgsInterface for packed inputs.
+class TestFunctionPackedArgs : public FunctionArgsInterface {
+public:
+TestFunctionPackedArgs(const int index,
+gtl::InlinedVector<TensorValue, 4>&& tensor_args) {
+packed_args_.emplace(index, std::move(tensor_args));
+}
+
+~TestFunctionPackedArgs() override{};
+
+bool HasRemoteOrPackedInputs() const override { return true; };
+
+Status GetLocalArg(const FunctionArgIndex& index,
+Tensor* val) const override {
+*val = *packed_args_.at(index.index).at(index.sub_index).tensor;
+return Status::OK();
+};
+
+std::vector<Tensor> GetLocalTensors() const override { return {}; }
+
+private:
+absl::flat_hash_map<int, gtl::InlinedVector<TensorValue, 4>> packed_args_;
+};
+
+TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) {
+Init({AddVarAcrossDevices()});
+// Create two variables on two devices.
+const Tensor initial_resource_value0 = test::AsTensor<float>({10, 20});
+Var* resource0 = new Var(DT_FLOAT);
+*resource0->tensor() = initial_resource_value0;
+resource0->is_initialized = true;
+const Tensor initial_resource_value1 = test::AsTensor<float>({30, 40});
+Var* resource1 = new Var(DT_FLOAT);
+*resource1->tensor() = initial_resource_value1;
+resource1->is_initialized = true;
+ResourceMgr* mgr0 = device0_->resource_manager();
+ResourceMgr* mgr1 = device1_->resource_manager();
+TF_ASSERT_OK(mgr0->Create(mgr0->default_container(), "var", resource0));
+TF_ASSERT_OK(mgr1->Create(mgr1->default_container(), "var", resource1));
+
+Tensor resource_handle0 =
+GetResourceHandle("var", mgr0->default_container(), device0_->name());
+Tensor resource_handle1 =
+GetResourceHandle("var", mgr1->default_container(), device1_->name());
+
+// Create a CompositeDevice
+Status s;
+std::unique_ptr<CompositeDevice> composite_device =
+CompositeDevice::MakeDevice({device0_->name(), device1_->name()},
+/*unique_device_id=*/0, &s);
+TF_ASSERT_OK(s);
+AddCompositeDevice(composite_device.get());
+
+FunctionLibraryRuntime::Options opts;
+FunctionLibraryRuntime::InstantiateOptions inst_opts =
+MakeOptions("CPU:0", {"COMPOSITE:0"}, {"CPU:0"});
+inst_opts.composite_devices[composite_device->name()] =
+composite_device->underlying_devices();
+inst_opts.input_resource_dtypes_and_shapes[0] = {
+initial_resource_value0.dtype(), initial_resource_value0.shape()};
+
+gtl::InlinedVector<TensorValue, 4> handles;
+handles.push_back(TensorValue(&resource_handle0));
+handles.push_back(TensorValue(&resource_handle1));
+TestFunctionPackedArgs args(0, std::move(handles));
+Tensor ret;
+TF_CHECK_OK(RunWithPackedArgs("AddVarAcrossDevices", opts, {{"T", DT_FLOAT}},
+inst_opts, args, {&ret}));
+test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60}));
+}
+
 TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) {
 if (gpu_device_ == nullptr) {
 GTEST_SKIP() << "No GPUs available";
@@ -1025,9 +1148,9 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) {
 instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
 const auto x = test::AsTensor<int64>({17});
 Tensor y;
-TF_CHECK_OK(RunWithRuntime("SessionMetadataReaderFn", opts, {},
+TF_CHECK_OK(RunWithRuntime<std::vector<Tensor>>(
-instantiate_opts, {x}, {&y},
+"SessionMetadataReaderFn", opts, {}, instantiate_opts, {x}, {&y},
 cloned_proc_flr.get()));
 SessionMetadata read_metadata;
 ASSERT_TRUE(protobuf::TextFormat::ParseFromString(y.scalar<tstring>()(),
 &read_metadata));
@@ -42,6 +42,7 @@ class ReplicateHelper {
 Node* replicated_node = graph->AddNode(node_def, &status);
 TF_RETURN_IF_ERROR(status);
 replicated_node->set_assigned_device_name(device);
+replicated_node->AddAttr("sub_index", i);
 replicated_nodes[i] = replicated_node;
 }
 replicated_nodes_map_.emplace(node, std::move(replicated_nodes));
@@ -180,7 +181,8 @@ Status ReplicateEdges(const ReplicateHelper& helper,
 }  // namespace

 Status ReplicatePerReplicaNodesInFunctionGraph(
-const absl::flat_hash_map<string, std::vector<string>>& composite_devices,
+const absl::flat_hash_map<string, const std::vector<string>*>&
+composite_devices,
 Graph* graph) {
 std::set<string> composite_device_names;
 for (const auto& it : composite_devices) {
@@ -198,7 +200,8 @@ Status ReplicatePerReplicaNodesInFunctionGraph(
 }

 for (const auto& it : composite_device_to_cluster_nodes) {
-const std::vector<string>& allowed_devices = composite_devices.at(it.first);
+const std::vector<string>& allowed_devices =
+*composite_devices.at(it.first);
 if (allowed_devices.empty()) {
 return errors::InvalidArgument("No allowed device of composite device: ",
 it.first);
@@ -208,6 +211,7 @@ Status ReplicatePerReplicaNodesInFunctionGraph(
 // Reuse the original nodes if there is only one allowed device.
 for (Node* n : cluster_nodes) {
 n->set_assigned_device_name(allowed_devices.at(0));
+n->AddAttr("sub_index", 0);
 }
 continue;
 }
@@ -35,7 +35,8 @@ namespace tensorflow {
 // dependency.
 // TODO(b/145922293): Register it as a POST_REWRITE_FOR_EXEC pass.
 Status ReplicatePerReplicaNodesInFunctionGraph(
-const absl::flat_hash_map<string, std::vector<string>>& composite_devices,
+const absl::flat_hash_map<string, const std::vector<string>*>&
+composite_devices,
 Graph* graph);

 }  // namespace tensorflow
@@ -75,8 +75,9 @@ TEST(ReplicatePerReplicaNodesTest, SingleCompositeDevice) {
 auto ret = ops::_Retval(
 scope.WithOpName("ret").WithControlDependencies({write}), read, 0);

-const absl::flat_hash_map<string, std::vector<string>> composite_devices = {
+const std::vector<string> underlying_devices = {"TPU:0", "TPU:1"};
-{"TPU_COMPOSITE:0", {"TPU:0", "TPU:1"}}};
+const absl::flat_hash_map<string, const std::vector<string>*>
+composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};

 Graph graph(OpRegistry::Global());
 TF_ASSERT_OK(scope.ToGraph(&graph));
@@ -118,8 +119,9 @@ TEST(ReplicatePerReplicaNodesTest, SingleCompositeDeviceToSingleDevice) {
 auto read = ops::ReadVariableOp(scope.WithOpName("read"), arg, DT_INT32);
 auto ret = ops::_Retval(scope.WithOpName("ret"), read, 0);

-const absl::flat_hash_map<string, std::vector<string>> composite_devices = {
+const std::vector<string> underlying_devices = {"TPU:0"};
-{"TPU_COMPOSITE:0", {"TPU:0"}}};
+const absl::flat_hash_map<string, const std::vector<string>*>
+composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};

 Graph graph(OpRegistry::Global());
 TF_ASSERT_OK(scope.ToGraph(&graph));
@@ -156,9 +158,11 @@ TEST(ReplicatePerReplicaNodesTest, MultipleCompositeDevices) {
 auto add = ops::Add(scope.WithOpName("add"), identity0, identity1);
 auto ret = ops::_Retval(scope.WithOpName("ret"), add, 0);

-const absl::flat_hash_map<string, std::vector<string>> composite_devices = {
+const std::vector<string> underlying_devices_0 = {"TPU:0", "TPU:1"};
-{"TPU_COMPOSITE:0", {"TPU:0", "TPU:1"}},
+const std::vector<string> underlying_devices_1 = {"TPU:2", "TPU:3"};
-{"TPU_COMPOSITE:1", {"TPU:2", "TPU:3"}}};
+const absl::flat_hash_map<string, const std::vector<string>*>
+composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices_0},
+{"TPU_COMPOSITE:1", &underlying_devices_1}};

 Graph graph(OpRegistry::Global());
 TF_ASSERT_OK(scope.ToGraph(&graph));
@@ -204,8 +208,9 @@ TEST(ReplicatePerReplicaNodesTest, MultipleCompositeDevices) {
 }

 TEST(ReplicatePerReplicaNodesTest, NestedFunctions) {
-const absl::flat_hash_map<string, std::vector<string>> composite_devices = {
+const std::vector<string> underlying_devices = {"TPU:0", "TPU:1"};
-{"TPU_COMPOSITE:0", {"TPU:0", "TPU:1"}}};
+const absl::flat_hash_map<string, const std::vector<string>*>
+composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};

 FunctionDefLibrary fdef_lib;
 FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
@@ -500,11 +500,11 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
 : EagerKernelArgs(std::move(tensor_args)),
 serialize_remote_handle_(std::move(serialize_remote_handle)) {}

-bool HasRemoteInputs() const override { return true; }
+bool HasRemoteOrPackedInputs() const override { return true; }

-Status GetRemoteArg(const int index,
+Status GetRemoteArg(const FunctionArgIndex& index,
 eager::RemoteTensorHandle* val) const override {
-return serialize_remote_handle_(index, val);
+return serialize_remote_handle_(index.index, val);
 }

 private:
@@ -562,7 +562,14 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
 eager_pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
 remote_device_mgr_.get(), Env::Default(), /*config=*/
 nullptr, TF_GRAPH_DEF_VERSION, &func_lib_def_, OptimizerOptions(),
-/*thread_pool=*/nullptr, eager_cluster_flr_.get());
+/*thread_pool=*/nullptr, eager_cluster_flr_.get(),
+/*custom_kernel_creator=*/nullptr, /*session_metadata=*/nullptr,
+Rendezvous::Factory{[this](const int64 step_id,
+const DeviceMgr* device_mgr,
+Rendezvous** r) {
+*r = worker_env_.rendezvous_mgr->Find(step_id);
+return Status::OK();
+}});
 }

 void CheckOutputTensorAndClose(const Tensor& tensor) {
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/platform/platform.h"
 // clang-format on

+#include "absl/container/flat_hash_map.h"
 #include "absl/types/optional.h"
 #include "absl/types/variant.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
@@ -525,6 +526,20 @@ class Device;
 // Forward declare. Defined in common_runtime/device_mgr.h
 class DeviceMgr;

+// Index of an _Arg node.
+struct FunctionArgIndex {
+explicit FunctionArgIndex(const int index) : index(index) {}
+FunctionArgIndex(const int index, const int sub_index)
+: index(index), sub_index(sub_index) {}
+
+// The value of the attribute "Index" of the _Arg node.
+int index;
+// Set only when the _Arg node represents multiple arguments (e.g. an _Arg
+// node is replicated to multiple devices/subgraphs). Use sub-index to
+// distinguish arguments with the same index.
+int sub_index = -1;
+};
+
 class FunctionLibraryRuntime {
 public:
 virtual ~FunctionLibraryRuntime() {}
@@ -576,6 +591,10 @@ class FunctionLibraryRuntime {
 // infer correct device.
 std::vector<string> output_devices;

+// Maps from a CompositeDevice name to a list of underlying physical
+// devices.
+absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
+
 // This interface is EXPERIMENTAL and subject to change.
 //
 // For multi-device functions, a mapping from _Arg node index to type and
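The composite_devices map added to FunctionLibraryRuntime::InstantiateOptions in the last hunk is what InstantiateMultiDevice hands to ReplicatePerReplicaNodesInFunctionGraph before the graph is partitioned. A minimal sketch of populating it follows, modeled on the MultiDevice_CompositeDevice test above; the helper name and device strings are hypothetical, and note the map stores raw pointers, so the underlying device list must outlive instantiation.

#include <vector>
#include "tensorflow/core/framework/function.h"

namespace tensorflow {
// Sketch only (hypothetical helper): wire a composite device's underlying
// devices into the instantiate options, the way the new test does through
// CompositeDevice::MakeDevice and underlying_devices().
FunctionLibraryRuntime::InstantiateOptions MakeCompositeInstantiateOptions(
    const std::vector<string>& underlying_devices) {
  FunctionLibraryRuntime::InstantiateOptions inst_opts;
  // The map holds a pointer; `underlying_devices` must outlive the
  // Instantiate() call that consumes these options.
  inst_opts.composite_devices["/device:COMPOSITE:0"] = &underlying_devices;
  // During instantiation, every _Arg assigned to the composite device is
  // replicated once per underlying device and tagged with a "sub_index"
  // attribute before partitioning.
  return inst_opts;
}
}  // namespace tensorflow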