[FunctionLibraryRuntime] Optimize single-component "multi-device" function dispatch.
This change enables the (single-device) `FunctionLibraryRuntimeImpl` to
dispatch a multi-device function directly if (i) it has a single component,
and (ii) that component is local to the `FunctionLibraryRuntimeImpl` instance.
This avoids the (microsecond-scale) overhead of preparing and remapping the
inputs to and results from a multi-device function, which is important for
clients (like tf.data) that invoke many fine-grained functions.

PiperOrigin-RevId: 307090208
Change-Id: I21820aa5b84360c2595b49c2e22d7a2b037819c4
Parent: c301f1e9b3
Commit: e5c6881c77
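The fast path described in the commit message reduces to one question at dispatch time: can this multi-device handle be resolved to a single component that lives on the local device? The sketch below is not TensorFlow code; it is a minimal standalone model of that decision, with hypothetical stand-ins for the two dispatch paths (the real ones are FunctionLibraryRuntimeImpl::Run and ProcessFunctionLibraryRuntime::Run).

#include <functional>
#include <iostream>

constexpr int kInvalidLocalHandle = -1;

// Hypothetical stand-in for running a component function directly on the
// local device, with no input packing or output remapping.
void RunLocally(int local_handle) {
  std::cout << "direct dispatch of component with local handle "
            << local_handle << "\n";
}

// Hypothetical stand-in for the general multi-device path, which repackages
// inputs per component and remaps the results (microsecond-scale overhead).
void RunViaMultiDevicePath(int handle) {
  std::cout << "multi-device dispatch of handle " << handle << "\n";
}

// Sketch of the decision: ask the parent runtime for a local handle,
// including single-component multi-device functions; fall back if it has none.
void Run(int handle, const std::function<int(int)>& get_local_handle) {
  int local_handle = get_local_handle(handle);  // include_multi_device=true
  if (local_handle == kInvalidLocalHandle) {
    RunViaMultiDevicePath(handle);
    return;
  }
  RunLocally(local_handle);
}

int main() {
  Run(7, [](int) { return 3; });                    // eligible: fast path
  Run(8, [](int) { return kInvalidLocalHandle; });  // not eligible: fallback
}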
@@ -1203,7 +1203,8 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
     };
   }

-  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
+  LocalHandle local_handle = parent_->GetHandleOnDevice(
+      device_name_, handle, /*include_multi_device=*/true);
   if (local_handle == kInvalidLocalHandle) {
     parent_->Run(run_opts, handle, frame, done);
     return;
@@ -394,6 +394,19 @@ TEST_F(FunctionLibraryRuntimeTest, XTimesTwo) {
   test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
 }

+TEST_F(FunctionLibraryRuntimeTest, XTimesTwo_MultiDeviceBacked) {
+  Init({test::function::XTimesTwo()});
+  auto x = test::AsTensor<float>({1, 2, 3, 4});
+  Tensor y;
+
+  FunctionLibraryRuntime::InstantiateOptions options;
+  options.is_multi_device_function = true;
+
+  TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
+                                {x}, {&y}));
+  test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
+}
+
 TEST_F(FunctionLibraryRuntimeTest, XTimesN) {
   Init({test::function::XTimesTwo(), test::function::XTimesFour(),
         test::function::XTimes16()});
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/common_runtime/partitioning_utils.h"

+#include <algorithm>
+
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/graph/graph.h"
@@ -82,20 +84,32 @@ Status UpdateArgAndRetvalMetadata(
   // Find the Arg and Retval nodes, along with their corresponding indices
   // in the original function.
   for (Node* node : subgraph->op_nodes()) {
     string node_type = node->type_string();
     if (node->IsArg()) {
       TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
       int index = static_cast<int>(attr_value->i());
-      arg_indices->push_back(index);
-      arg_nodes.push_back(std::make_pair(node, index));
+      arg_nodes.emplace_back(node, index);
     } else if (node->IsRetval()) {
       TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
       int index = static_cast<int>(attr_value->i());
-      ret_indices->push_back(index);
-      ret_nodes.push_back(std::make_pair(node, index));
+      ret_nodes.emplace_back(node, index);
     }
   }

+  // Sort the nodes by index so that the order is stable.
+  //
+  // In particular, this enables calling a single-partition function with
+  // the same signature as the original unpartitioned function.
+  auto comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) {
+    return a.second < b.second;
+  };
+  std::sort(arg_nodes.begin(), arg_nodes.end(), comparator);
+  std::sort(ret_nodes.begin(), ret_nodes.end(), comparator);
+
+  arg_indices->reserve(arg_nodes.size());
+  for (const auto& pair : arg_nodes) arg_indices->push_back(pair.second);
+  ret_indices->reserve(ret_nodes.size());
+  for (const auto& pair : ret_nodes) ret_indices->push_back(pair.second);
+
   for (int i = 0; i < arg_nodes.size(); ++i) {
     Node* arg = arg_nodes[i].first;
     arg->AddAttr("index", i);
@@ -41,17 +41,18 @@ Status PartitionFunctionGraph(
 //
 // More specifically, this function
 // (1) rewrites the indices of the `Arg` and `Retval` nodes placed
-//     on a particular device. When a function is partitioned each
-//     partition, `subgraph`, get a subset of the arguments and
+//     on a particular device. When a function is partitioned, each
+//     partition `subgraph` gets a subset of the arguments and
 //     return values. The `index` attributes of these _Arg and _Retval
 //     nodes reflect the indices of these parameters in the original
 //     function. To convert `subgraph` to a function, we need to replace
 //     there original indices with 0, 1, 2, ... .
 //
 //     The argument and return value order in the partitioned function is
-//     determined by the node iteration order in `subgraph`. This order
-//     is also used in UpdateArgAndRetvalMetadata. This is fine because the
-//     node iteration order is deterministic - it follows the node ids.
+//     determined by the argument and return value order in the original
+//     function. This stability is important because it enables us to treat
+//     a single-partition function as having the same signature as the
+//     subgraph.
 // (2) records the subsets of `Arg` and `Retval` nodes assigned to the
 //     device in `*_indices`, and
 // (3) records which `Arg` and `Retval` nodes live in host memory in
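The index rewrite described in the comment above can be illustrated with a small standalone sketch. This is a simplified model, not the real implementation: plain (name, index) pairs stand in for the `_Arg` nodes and their "index" attributes. Each partition's arguments carry indices from the original function signature; they are sorted for stability, the original indices are recorded, and each node is renumbered to a dense 0, 1, 2, ... sequence.

#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Hypothetical partition: three _Arg nodes whose "index" attributes refer
  // to positions in the *original* function signature.
  std::vector<std::pair<std::string, int>> arg_nodes = {
      {"arg_c", 9}, {"arg_b", 3}, {"arg_a", 1}};

  // Sort by original index so the rewritten order is stable and matches the
  // argument order of the original function.
  std::sort(arg_nodes.begin(), arg_nodes.end(),
            [](const auto& a, const auto& b) { return a.second < b.second; });

  // Record the original indices (the `arg_indices` output) and rewrite each
  // node's index to its position in the partitioned function.
  std::vector<int> arg_indices;
  for (int i = 0; i < static_cast<int>(arg_nodes.size()); ++i) {
    arg_indices.push_back(arg_nodes[i].second);
    std::cout << arg_nodes[i].first << ": original index "
              << arg_nodes[i].second << " -> new index " << i << "\n";
  }
}

Because the new numbering follows the original argument order, a function that ends up with a single partition keeps the same call signature as the unpartitioned function, which is exactly what the direct-dispatch fast path relies on.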
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/public/session_options.h"

@@ -91,12 +92,18 @@ class PartitioningUtilsTest : public ::testing::Test {
   // Fills subgraph with an identify function arg->identity->ret
   // where each node has type `dtype` and arg/ret nodes have
   // indices `arg_index` and `ret_index`.
-  void SubGraph(Graph* subgraph, DataType dtype, int arg_index, int ret_index) {
+  void SubGraph(Graph* subgraph, DataType dtype,
+                gtl::ArraySlice<int> arg_indices,
+                gtl::ArraySlice<int> ret_indices) {
     Scope s = Scope::NewRootScope();
     Scope s1 = s.WithDevice("/job:a/replica:0/task:0/device:CPU:0");
-    auto x = ops::_Arg(s1.WithOpName("x"), dtype, arg_index);
-    auto id_x = ops::Identity(s1.WithOpName("id_x"), x);
-    auto dx_retval = ops::_Retval(s1.WithOpName("retval1"), id_x, ret_index);
+    CHECK_EQ(arg_indices.size(), ret_indices.size());
+    for (size_t i = 0; i < arg_indices.size(); ++i) {
+      auto x = ops::_Arg(s1.WithOpName("x"), dtype, arg_indices[i]);
+      auto id_x = ops::Identity(s1.WithOpName("id_x"), x);
+      auto dx_retval =
+          ops::_Retval(s1.WithOpName("retval1"), id_x, ret_indices[i]);
+    }
     TF_ASSERT_OK(s.ToGraph(subgraph));
     Placer placer(subgraph, "", &device_set_, device0_);
     TF_ASSERT_OK(placer.Run());
@@ -175,8 +182,8 @@ void CheckIndex(const Node& node, int expected_index) {
 }

 TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) {
-  std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
-  SubGraph(graph.get(), DT_FLOAT, 3, 5);
+  auto graph = absl::make_unique<Graph>(OpRegistry::Global());
+  SubGraph(graph.get(), DT_FLOAT, {3}, {5});

   std::vector<int> arg_indices;
   std::vector<int> ret_indices;
@@ -202,5 +209,27 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) {
   CheckIndex(*nodes["retval1"], 0);
 }

+TEST_F(PartitioningUtilsTest, UpdateArgsAndRets_Order) {
+  auto graph = absl::make_unique<Graph>(OpRegistry::Global());
+  SubGraph(graph.get(), DT_FLOAT, {9, 7, 5, 3, 1}, {2, 4, 6, 8, 10});
+
+  std::vector<int> arg_indices;
+  std::vector<int> ret_indices;
+  std::vector<AllocatorAttributes> arg_alloc_attrs;
+  std::vector<AllocatorAttributes> ret_alloc_attrs;
+
+  string device_type = "CPU";
+
+  Status status = UpdateArgAndRetvalMetadata(
+      graph.get(), device_type, &arg_indices, &ret_indices, &arg_alloc_attrs,
+      &ret_alloc_attrs);
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  CheckIndices({1, 3, 5, 7, 9}, arg_indices);
+  CheckIndices({2, 4, 6, 8, 10}, ret_indices);
+  CheckAlloc({false, false, false, false, false}, arg_alloc_attrs);
+  CheckAlloc({false, false, false, false, false}, ret_alloc_attrs);
+}
+
 }  // anonymous namespace
 }  // namespace tensorflow
@@ -282,12 +282,25 @@ bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice(

 FunctionLibraryRuntime::LocalHandle
 ProcessFunctionLibraryRuntime::GetHandleOnDevice(
-    const string& device_name, FunctionLibraryRuntime::Handle handle) const {
+    const string& device_name, FunctionLibraryRuntime::Handle handle,
+    bool include_multi_device) const {
   tf_shared_lock l(mu_);

   auto miter = mdevice_data_.find(handle);
   if (miter != mdevice_data_.end()) {
-    return kInvalidLocalHandle;
+    if (!include_multi_device) return kInvalidLocalHandle;
+
+    const MultiDeviceFunctionData& data = *miter->second;
+    if (data.glue_.size() != 1) return kInvalidLocalHandle;
+
+    const auto& pair = *data.glue_.begin();
+    const string& func_device_name = pair.first;
+    const ComponentFunctionData& component_data = pair.second;
+    if (func_device_name != device_name) return kInvalidLocalHandle;
+
+    // Replace the given handle with the handle for the single component
+    // function.
+    handle = component_data.handle_;
   }

   auto iter = function_data_.find(handle);
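The lookup added above can be modeled in isolation. The following sketch uses hypothetical, simplified stand-ins for the two handle tables (`mdevice_data_` and `function_data_`); the real ProcessFunctionLibraryRuntime holds richer per-component state and reads the tables under a shared lock. It shows the documented behavior: a single-component multi-device handle resolves to its component's local handle only when `include_multi_device` is set and the component is placed on the requested device.

#include <iostream>
#include <map>
#include <string>

constexpr int kInvalidLocalHandle = -1;

// Hypothetical stand-in for the single component of a multi-device function.
struct Component { std::string device; int local_handle; };

struct Runtime {
  // handle -> single-component multi-device functions (multi-component
  // functions are omitted from this sketch; they never take the fast path).
  std::map<int, Component> multi_device;
  // handle -> local handle for plain single-device instantiations.
  std::map<int, int> local;

  int GetHandleOnDevice(const std::string& device_name, int handle,
                        bool include_multi_device = false) const {
    auto miter = multi_device.find(handle);
    if (miter != multi_device.end()) {
      if (!include_multi_device) return kInvalidLocalHandle;
      if (miter->second.device != device_name) return kInvalidLocalHandle;
      return miter->second.local_handle;
    }
    auto iter = local.find(handle);
    return iter == local.end() ? kInvalidLocalHandle : iter->second;
  }
};

int main() {
  Runtime rt;
  rt.multi_device[7] = {"/device:CPU:0", 3};
  std::cout << rt.GetHandleOnDevice("/device:CPU:0", 7) << "\n";        // -1
  std::cout << rt.GetHandleOnDevice("/device:CPU:0", 7, true) << "\n";  // 3
  std::cout << rt.GetHandleOnDevice("/device:GPU:0", 7, true) << "\n";  // -1
}

Defaulting `include_multi_device` to false keeps existing callers unchanged: they continue to get kInvalidLocalHandle for multi-device handles unless they opt in.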
@@ -137,8 +137,13 @@ class ProcessFunctionLibraryRuntime {
   // index of instantiation of that function. If the function was not
   // instantiated on `device_name` or the function is multi-device,
   // returns kInvalidLocalHandle.
+  //
+  // If `include_multi_device` is true and `handle` is a multi-device function
+  // with a single component that is placed on `device_name`, then this method
+  // will return the local handle for that component.
   FunctionLibraryRuntime::LocalHandle GetHandleOnDevice(
-      const string& device_name, FunctionLibraryRuntime::Handle handle) const;
+      const string& device_name, FunctionLibraryRuntime::Handle handle,
+      bool include_multi_device = false) const;

   // Fills `output_devices` with the devices on which the results will
   // be produced. If some output is produced on CPU, the corresponding Device*