From 26421826a0b43d6148ece23d2fc6d0a4d0442fa0 Mon Sep 17 00:00:00 2001 From: Yujing Zhang Date: Mon, 22 Jun 2020 15:52:37 -0700 Subject: [PATCH] Support packed variable for tf data captured function. If a packed variable is passed to a DatasetOp through a captured function, 1) function instantiation (graph expansion) would insert a PackOp between per-replica arg nodes and the DatasetOp. 2) function execution would feed unpacked ResourceHandles to corresponding sub-functions. PiperOrigin-RevId: 317752101 Change-Id: I53f7f1d6d075fd40a76a32f329f2dcbbf1db6494 --- .../process_function_library_runtime.cc | 19 ++++++++--- .../process_function_library_runtime_test.cc | 32 +++++++++++++----- .../core/kernels/data/captured_function.cc | 33 +++++++++++++++++-- .../custom_training_loop_input_test.py | 4 --- tensorflow/python/eager/function_test.py | 7 ++++ tensorflow/python/framework/func_graph.py | 7 ++++ 6 files changed, 83 insertions(+), 19 deletions(-) diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 160b9dd88f4..5ee6546f6be 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -1356,11 +1356,22 @@ void ProcessFunctionLibraryRuntime::Run( // "Index"s of _Arg nodes are unique when all arguments are local Tensors. 
for (const auto& it : comp_data.arg_indices) { if (it.sub_index >= 0) { - return errors::InvalidArgument("Got unexpected sub_index ", - it.sub_index, " for argument ", - it.index); + const Tensor& t = args[it.index]; + if (t.dtype() != DT_RESOURCE) { + return errors::InvalidArgument("Got unexpected sub_index ", + it.sub_index, " for argument ", + it.index); + } + const auto& handles = t.flat<ResourceHandle>(); + if (it.sub_index >= handles.size()) { + return errors::InvalidArgument( + "Sub_index ", it.sub_index, " is out of range [0,", + handles.size(), ") for argument ", it.index); + } + comp_args->args.push_back(Tensor(handles(it.sub_index))); + } else { + comp_args->args.push_back(args[it.index]); } - comp_args->args.push_back(args[it.index]); } return Status::OK(); }; diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index a007501fc82..6e17cdf4316 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "tensorflow/core/framework/resource_var.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/type_index.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -867,14 +868,29 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) { inst_opts.input_resource_dtypes_and_shapes[0] = { initial_resource_value0.dtype(), initial_resource_value0.shape()}; - gtl::InlinedVector<TensorValue, 4> handles; - handles.push_back(TensorValue(&resource_handle0)); - handles.push_back(TensorValue(&resource_handle1)); - TestFunctionPackedArgs args(0, std::move(handles)); - Tensor ret; - TF_CHECK_OK(RunWithPackedArgs("AddVarAcrossDevices", opts, {{"T", DT_FLOAT}}, - inst_opts, args, {&ret})); - test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60})); + // Packed TensorHandle + { + gtl::InlinedVector<TensorValue, 4> handles; + handles.push_back(TensorValue(&resource_handle0)); + handles.push_back(TensorValue(&resource_handle1)); + TestFunctionPackedArgs args(0, std::move(handles)); + Tensor ret; + TF_CHECK_OK(RunWithPackedArgs("AddVarAcrossDevices", opts, + {{"T", DT_FLOAT}}, inst_opts, args, {&ret})); + test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60})); + } + + // Packed Tensor + { + Tensor arg(DT_RESOURCE, TensorShape({2})); + arg.flat<ResourceHandle>()(0) = resource_handle0.scalar<ResourceHandle>()(); + arg.flat<ResourceHandle>()(1) = resource_handle1.scalar<ResourceHandle>()(); + + Tensor ret; + TF_CHECK_OK(Run("AddVarAcrossDevices", opts, {{"T", DT_FLOAT}}, inst_opts, + {arg}, {&ret})); + test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60})); + } } TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) { diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index f740d7ff1ad..07e5a5b1273 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ 
b/tensorflow/core/kernels/data/captured_function.cc @@ -571,6 +571,9 @@ Status CapturedFunction::Instantiate( DCHECK(lib->device() != nullptr); inst_opts.target = lib->device()->name(); + // Maps from a CompositeDevice name to underlying physical device names. + absl::flat_hash_map<string, std::vector<string>> composite_devices; + if (inst_opts.is_multi_device_function) { // Compute devices of non-captured inputs. // @@ -596,9 +599,29 @@ Status CapturedFunction::Instantiate( const auto& input = captured_inputs_[i]; DataType dtype = input.dtype(); if (dtype == DT_RESOURCE) { - const ResourceHandle& handle = input.flat<ResourceHandle>()(0); - inst_opts.input_devices.push_back(handle.device()); - const auto& dtypes_and_shapes = handle.dtypes_and_shapes(); + const auto& handles = input.flat<ResourceHandle>(); + const ResourceHandle& handle0 = handles(0); + string composite_device; + auto iter = fdef->arg_attr().find(num_non_captured_inputs + i); + if (iter != fdef->arg_attr().end()) { + auto arg_attr = iter->second.attr().find("_composite_device"); + if (arg_attr != iter->second.attr().end()) { + composite_device = arg_attr->second.s(); + } + } + if (!composite_device.empty()) { + if (composite_devices.find(composite_device) == + composite_devices.end()) { + for (int i = 0; i < handles.size(); ++i) { + composite_devices[composite_device].push_back( + handles(i).device()); + } + } + inst_opts.input_devices.push_back(composite_device); + } else { + inst_opts.input_devices.push_back(handle0.device()); + } + const auto& dtypes_and_shapes = handle0.dtypes_and_shapes(); // Set dtypes and shapes for resource variable inputs. 
if (!dtypes_and_shapes.empty()) { input_resource_variable_dtypes_and_shapes[num_non_captured_inputs + @@ -613,6 +636,10 @@ Status CapturedFunction::Instantiate( } } + for (const auto& it : composite_devices) { + inst_opts.composite_devices[it.first] = &it.second; + } + for (size_t i = 0; i < fdef->signature().output_arg_size(); ++i) { inst_opts.output_devices.push_back(inst_opts.target); } diff --git a/tensorflow/python/distribute/custom_training_loop_input_test.py b/tensorflow/python/distribute/custom_training_loop_input_test.py index 748cb7834fc..e2c4076f3f1 100644 --- a/tensorflow/python/distribute/custom_training_loop_input_test.py +++ b/tensorflow/python/distribute/custom_training_loop_input_test.py @@ -796,10 +796,6 @@ class InputIterationTest(test.TestCase, parameterized.TestCase, mode=["eager"] )) def testMultiDeviceDataCapturedFunction(self, distribution): - if getattr(distribution, "_enable_packed_variable_in_eager_mode", False): - self.skipTest( - "Dataset captured function doesn't support packed tensors yet " - "(b/145922293).") inputs = constant_op.constant([2., 3.]) dataset = lambda _: dataset_ops.Dataset.from_tensor_slices(inputs).repeat(5) input_iterator = iter( diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 2c49795ba8a..3c42d95e437 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -224,6 +224,13 @@ class FunctionTest(test.TestCase, parameterized.TestCase): return read0, read1, read2, read3 + arg_attrs = read_var.get_concrete_function().function_def.arg_attr + self.assertLen(arg_attrs, 2) + self.assertEqual(arg_attrs[0].attr['_composite_device'].s, + compat.as_bytes(packed_var_0.device)) + self.assertEqual(arg_attrs[1].attr['_composite_device'].s, + compat.as_bytes(packed_var_1.device)) + self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6)) def testImplementsAttributeBasic(self): diff --git 
a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index b0f8821b17f..e8e8fcbf081 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -646,6 +646,13 @@ class FuncGraph(ops.Graph): if capture is None: placeholder = _create_substitute_placeholder( tensor, name=name, dtype=tensor.dtype, shape=shape) + # Record the composite device as an attribute to the placeholder. + # This attribute would be propagated into the arg_attr of the FunctionDef. + # Currently, a packed eager tensor is always placed on a CompositeDevice. + if isinstance(tensor, ops.EagerTensor) and tensor.is_packed: + placeholder.op._set_attr( # pylint: disable=protected-access + "_composite_device", + attr_value_pb2.AttrValue(s=compat.as_bytes(tensor.device))) self.add_capture(tensor, placeholder) else: placeholder = capture[1]