[FunctionLibraryRuntime] Optimize single-component "multi-device" function dispatch.

This change enables the (single-device) `FunctionLibraryRuntimeImpl` to dispatch a multi-device function directly, if (i) it has a single component, and (ii) that component is local to the `FunctionLibraryRuntimeImpl` instance. This avoids the (microsecond-scale) overhead of preparing and remapping the inputs to, and results from, a multi-device function, which is important for clients (like tf.data) that invoke many fine-grained functions.

PiperOrigin-RevId: 307090208
Change-Id: I21820aa5b84360c2595b49c2e22d7a2b037819c4
Derek Murray 2020-04-17 12:14:49 -07:00 committed by TensorFlower Gardener
parent c301f1e9b3
commit e5c6881c77
7 changed files with 96 additions and 20 deletions
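Before the diff, here is a minimal, self-contained sketch of the dispatch pattern this change enables. The `FakeRuntime` type and the hard-coded handle values below are invented for illustration and are not the TensorFlow API; the real logic lives in `FunctionLibraryRuntimeImpl::Run` and `ProcessFunctionLibraryRuntime::GetHandleOnDevice`, shown in the hunks that follow.

#include <cstdio>
#include <string>

// Simplified stand-ins for the runtime's handle types.
using Handle = int;
using LocalHandle = int;
constexpr LocalHandle kInvalidLocalHandle = -1;

struct FakeRuntime {
  // Pretend handle 7 is a single-component multi-device function whose only
  // component is placed on the caller's device; nothing else resolves locally.
  LocalHandle GetHandleOnDevice(const std::string& /*device_name*/,
                                Handle handle,
                                bool include_multi_device = false) const {
    return (include_multi_device && handle == 7) ? 42 : kInvalidLocalHandle;
  }
  void RunMultiDevice(Handle h) { std::printf("multi-device path for %d\n", h); }
  void RunLocal(LocalHandle h) { std::printf("fast local path for %d\n", h); }
};

// Dispatch pattern after this change: try to resolve a local handle (now also
// covering single-component multi-device functions) and fall back to the
// multi-device path, with its argument/result remapping overhead, only when
// resolution fails.
void Dispatch(FakeRuntime& rt, const std::string& device, Handle handle) {
  LocalHandle local =
      rt.GetHandleOnDevice(device, handle, /*include_multi_device=*/true);
  if (local == kInvalidLocalHandle) {
    rt.RunMultiDevice(handle);
    return;
  }
  rt.RunLocal(local);
}

int main() {
  FakeRuntime rt;
  Dispatch(rt, "/job:a/replica:0/task:0/device:CPU:0", /*handle=*/7);  // fast path
  Dispatch(rt, "/job:a/replica:0/task:0/device:CPU:0", /*handle=*/8);  // fallback
}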


@@ -1203,7 +1203,8 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
};
}
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
LocalHandle local_handle = parent_->GetHandleOnDevice(
device_name_, handle, /*include_multi_device=*/true);
if (local_handle == kInvalidLocalHandle) {
parent_->Run(run_opts, handle, frame, done);
return;


@@ -394,6 +394,19 @@ TEST_F(FunctionLibraryRuntimeTest, XTimesTwo) {
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
}
TEST_F(FunctionLibraryRuntimeTest, XTimesTwo_MultiDeviceBacked) {
Init({test::function::XTimesTwo()});
auto x = test::AsTensor<float>({1, 2, 3, 4});
Tensor y;
FunctionLibraryRuntime::InstantiateOptions options;
options.is_multi_device_function = true;
TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
{x}, {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
}
TEST_F(FunctionLibraryRuntimeTest, XTimesN) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});


@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/partitioning_utils.h"
#include <algorithm>
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
@@ -82,20 +84,32 @@ Status UpdateArgAndRetvalMetadata(
// Find the Arg and Retval nodes, along with their corresponding indices
// in the original function.
for (Node* node : subgraph->op_nodes()) {
string node_type = node->type_string();
if (node->IsArg()) {
TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
int index = static_cast<int>(attr_value->i());
arg_indices->push_back(index);
arg_nodes.push_back(std::make_pair(node, index));
arg_nodes.emplace_back(node, index);
} else if (node->IsRetval()) {
TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
int index = static_cast<int>(attr_value->i());
ret_indices->push_back(index);
ret_nodes.push_back(std::make_pair(node, index));
ret_nodes.emplace_back(node, index);
}
}
// Sort the nodes by index so that the order is stable.
//
// In particular, this enables calling a single-partition function with
// the same signature as the original unpartitioned function.
auto comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) {
return a.second < b.second;
};
std::sort(arg_nodes.begin(), arg_nodes.end(), comparator);
std::sort(ret_nodes.begin(), ret_nodes.end(), comparator);
arg_indices->reserve(arg_nodes.size());
for (const auto& pair : arg_nodes) arg_indices->push_back(pair.second);
ret_indices->reserve(ret_nodes.size());
for (const auto& pair : ret_nodes) ret_indices->push_back(pair.second);
for (int i = 0; i < arg_nodes.size(); ++i) {
Node* arg = arg_nodes[i].first;
arg->AddAttr("index", i);


@@ -41,17 +41,18 @@ Status PartitionFunctionGraph(
//
// More specifically, this function
// (1) rewrites the indices of the `Arg` and `Retval` nodes placed
// on a particular device. When a function is partitioned each
// partition, `subgraph`, get a subset of the arguments and
// on a particular device. When a function is partitioned, each
// partition `subgraph` gets a subset of the arguments and
// return values. The `index` attributes of these _Arg and _Retval
// nodes reflect the indices of these parameters in the original
// function. To convert `subgraph` to a function, we need to replace
// their original indices with 0, 1, 2, ... .
//
// The argument and return value order in the partitioned function is
// determined by the node iteration order in `subgraph`. This order
// is also used in UpdateArgAndRetvalMetadata. This is fine because the
// node iteration order is deterministic - it follows the node ids.
// determined by the argument and return value order in the original
// function. This stability is important because it enables us to treat
// a single-partition function as having the same signature as the
// subgraph.
// (2) records the subsets of `Arg` and `Retval` nodes assigned to the
// device in `*_indices`, and
// (3) records which `Arg` and `Retval` nodes live in host memory in


@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
@@ -91,12 +92,18 @@ class PartitioningUtilsTest : public ::testing::Test {
// Fills subgraph with an identity function arg->identity->ret
// where each node has type `dtype` and arg/ret nodes have
// indices `arg_index` and `ret_index`.
void SubGraph(Graph* subgraph, DataType dtype, int arg_index, int ret_index) {
void SubGraph(Graph* subgraph, DataType dtype,
gtl::ArraySlice<int> arg_indices,
gtl::ArraySlice<int> ret_indices) {
Scope s = Scope::NewRootScope();
Scope s1 = s.WithDevice("/job:a/replica:0/task:0/device:CPU:0");
auto x = ops::_Arg(s1.WithOpName("x"), dtype, arg_index);
auto id_x = ops::Identity(s1.WithOpName("id_x"), x);
auto dx_retval = ops::_Retval(s1.WithOpName("retval1"), id_x, ret_index);
CHECK_EQ(arg_indices.size(), ret_indices.size());
for (size_t i = 0; i < arg_indices.size(); ++i) {
auto x = ops::_Arg(s1.WithOpName("x"), dtype, arg_indices[i]);
auto id_x = ops::Identity(s1.WithOpName("id_x"), x);
auto dx_retval =
ops::_Retval(s1.WithOpName("retval1"), id_x, ret_indices[i]);
}
TF_ASSERT_OK(s.ToGraph(subgraph));
Placer placer(subgraph, "", &device_set_, device0_);
TF_ASSERT_OK(placer.Run());
@@ -175,8 +182,8 @@ void CheckIndex(const Node& node, int expected_index) {
}
TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) {
std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
SubGraph(graph.get(), DT_FLOAT, 3, 5);
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
SubGraph(graph.get(), DT_FLOAT, {3}, {5});
std::vector<int> arg_indices;
std::vector<int> ret_indices;
@@ -202,5 +209,27 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) {
CheckIndex(*nodes["retval1"], 0);
}
TEST_F(PartitioningUtilsTest, UpdateArgsAndRets_Order) {
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
SubGraph(graph.get(), DT_FLOAT, {9, 7, 5, 3, 1}, {2, 4, 6, 8, 10});
std::vector<int> arg_indices;
std::vector<int> ret_indices;
std::vector<AllocatorAttributes> arg_alloc_attrs;
std::vector<AllocatorAttributes> ret_alloc_attrs;
string device_type = "CPU";
Status status = UpdateArgAndRetvalMetadata(
graph.get(), device_type, &arg_indices, &ret_indices, &arg_alloc_attrs,
&ret_alloc_attrs);
ASSERT_TRUE(status.ok()) << status.ToString();
CheckIndices({1, 3, 5, 7, 9}, arg_indices);
CheckIndices({2, 4, 6, 8, 10}, ret_indices);
CheckAlloc({false, false, false, false, false}, arg_alloc_attrs);
CheckAlloc({false, false, false, false, false}, ret_alloc_attrs);
}
} // anonymous namespace
} // namespace tensorflow


@@ -282,12 +282,25 @@ bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice(
FunctionLibraryRuntime::LocalHandle
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
const string& device_name, FunctionLibraryRuntime::Handle handle) const {
const string& device_name, FunctionLibraryRuntime::Handle handle,
bool include_multi_device) const {
tf_shared_lock l(mu_);
auto miter = mdevice_data_.find(handle);
if (miter != mdevice_data_.end()) {
return kInvalidLocalHandle;
if (!include_multi_device) return kInvalidLocalHandle;
const MultiDeviceFunctionData& data = *miter->second;
if (data.glue_.size() != 1) return kInvalidLocalHandle;
const auto& pair = *data.glue_.begin();
const string& func_device_name = pair.first;
const ComponentFunctionData& component_data = pair.second;
if (func_device_name != device_name) return kInvalidLocalHandle;
// Replace the given handle with the handle for the single component
// function.
handle = component_data.handle_;
}
auto iter = function_data_.find(handle);
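
A compact sketch of the single-component check added above, with simplified stand-in types; `glue` here plays the role of `MultiDeviceFunctionData::glue_`, which maps each target device to its component function's data. This is an illustration, not the actual runtime code.

#include <map>
#include <string>

using LocalHandle = int;
constexpr LocalHandle kInvalidLocalHandle = -1;

struct ComponentFunctionData {
  LocalHandle handle = kInvalidLocalHandle;
};

struct MultiDeviceFunctionData {
  // Maps target device name -> data for the component placed on that device.
  std::map<std::string, ComponentFunctionData> glue;
};

// Returns the component's local handle only when the multi-device function
// has exactly one component and that component lives on `device_name`;
// otherwise the caller must keep using the multi-device path.
LocalHandle ResolveSingleComponent(const MultiDeviceFunctionData& data,
                                   const std::string& device_name) {
  if (data.glue.size() != 1) return kInvalidLocalHandle;
  const auto& [component_device, component] = *data.glue.begin();
  if (component_device != device_name) return kInvalidLocalHandle;
  return component.handle;
}

Resolving to the component's local handle here is what lets `FunctionLibraryRuntimeImpl::Run` (first hunk) skip the multi-device argument remapping entirely.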


@@ -137,8 +137,13 @@ class ProcessFunctionLibraryRuntime {
// index of instantiation of that function. If the function was not
// instantiated on `device_name` or the function is multi-device,
// returns kInvalidLocalHandle.
//
// If `include_multi_device` is true and `handle` is a multi-device function
// with a single component that is placed on `device_name`, then this method
// will return the local handle for that component.
FunctionLibraryRuntime::LocalHandle GetHandleOnDevice(
const string& device_name, FunctionLibraryRuntime::Handle handle) const;
const string& device_name, FunctionLibraryRuntime::Handle handle,
bool include_multi_device = false) const;
// Fills `output_devices` with the devices on which the results will
// be produced. If some output is produced on CPU, the corresponding Device*