Support packed variables for tf.data captured functions.

If a packed variable is passed to a DatasetOp through a captured function:
1) function instantiation (graph expansion) inserts a PackOp between the per-replica arg nodes and the DatasetOp;
2) function execution feeds the unpacked ResourceHandles to the corresponding sub-functions.

PiperOrigin-RevId: 317752101
Change-Id: I53f7f1d6d075fd40a76a32f329f2dcbbf1db6494
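For orientation, here is a minimal sketch of the user-level scenario this change enables. It assumes a build from around this revision with two GPUs visible; the strategy, dataset, and distribution APIs below are standard TensorFlow, while whether handles are actually packed depends on the internal _enable_packed_variable_in_eager_mode flag referenced in the test change below.

    import tensorflow as tf

    # Sketch only: a tf.data function capturing a distributed variable whose
    # per-replica handles may be packed into a single resource tensor.
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    with strategy.scope():
      v = tf.Variable(2.0)

    def make_dataset(_):
      # Capturing `v` hands its (possibly packed) resource handle to the
      # DatasetOp; this change unpacks it per replica during execution.
      return tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0]).map(
          lambda x: x * v)

    dist_dataset = strategy.experimental_distribute_datasets_from_function(
        make_dataset)
    for batch in dist_dataset:
      print(strategy.experimental_local_results(batch))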
This commit is contained in:
parent afeac170f0
commit 26421826a0
@@ -1356,12 +1356,23 @@ void ProcessFunctionLibraryRuntime::Run(
      // "Index"es of _Arg nodes are unique when all arguments are local Tensors.
      for (const auto& it : comp_data.arg_indices) {
        if (it.sub_index >= 0) {
          const Tensor& t = args[it.index];
          if (t.dtype() != DT_RESOURCE) {
            return errors::InvalidArgument("Got unexpected sub_index ",
                                           it.sub_index, " for argument ",
                                           it.index);
          }
          const auto& handles = t.flat<ResourceHandle>();
          if (it.sub_index >= handles.size()) {
            return errors::InvalidArgument(
                "Sub_index ", it.sub_index, " is out of range [0, ",
                handles.size(), ") for argument ", it.index);
          }
          comp_args->args.push_back(Tensor(handles(it.sub_index)));
        } else {
          comp_args->args.push_back(args[it.index]);
        }
      }
      return Status::OK();
    };
    return RunMultiDevice(new_opts, handle, rets, cleanup_items,
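To make the routing above concrete, here is a toy illustration in plain Python (stand-in values, not TF API): a packed resource argument is a rank-1 DT_RESOURCE tensor holding one handle per device, and each component function receives handles(sub_index).

    # Toy model of the unpacking loop above. args[index] is the packed
    # argument; sub_index selects the per-device handle for one component
    # function, while sub_index < 0 means an ordinary local tensor.
    args = [["handle_on_device0", "handle_on_device1"]]
    arg_indices = [(0, 0), (0, 1)]  # (index, sub_index) per component function

    comp_args = []
    for index, sub_index in arg_indices:
        if sub_index >= 0:
            handles = args[index]
            assert sub_index < len(handles), "sub_index out of range"
            comp_args.append(handles[sub_index])  # unpacked per-device handle
        else:
            comp_args.append(args[index])  # local tensor passed through as-is
    print(comp_args)  # ['handle_on_device0', 'handle_on_device1']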
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
@@ -867,14 +868,29 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) {
  inst_opts.input_resource_dtypes_and_shapes[0] = {
      initial_resource_value0.dtype(), initial_resource_value0.shape()};

  // Packed TensorHandle
  {
    gtl::InlinedVector<TensorValue, 4> handles;
    handles.push_back(TensorValue(&resource_handle0));
    handles.push_back(TensorValue(&resource_handle1));
    TestFunctionPackedArgs args(0, std::move(handles));
    Tensor ret;
-  TF_CHECK_OK(RunWithPackedArgs("AddVarAcrossDevices", opts, {{"T", DT_FLOAT}},
-                                inst_opts, args, {&ret}));
+    TF_CHECK_OK(RunWithPackedArgs("AddVarAcrossDevices", opts,
+                                  {{"T", DT_FLOAT}}, inst_opts, args, {&ret}));
    test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60}));
  }

  // Packed Tensor
  {
    Tensor arg(DT_RESOURCE, TensorShape({2}));
    arg.flat<ResourceHandle>()(0) = resource_handle0.scalar<ResourceHandle>()();
    arg.flat<ResourceHandle>()(1) = resource_handle1.scalar<ResourceHandle>()();

    Tensor ret;
    TF_CHECK_OK(Run("AddVarAcrossDevices", opts, {{"T", DT_FLOAT}}, inst_opts,
                    {arg}, {&ret}));
    test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60}));
  }
}

TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) {
@@ -571,6 +571,9 @@ Status CapturedFunction::Instantiate(
  DCHECK(lib->device() != nullptr);
  inst_opts.target = lib->device()->name();

  // Maps from a CompositeDevice name to underlying physical device names.
  absl::flat_hash_map<string, std::vector<string>> composite_devices;

  if (inst_opts.is_multi_device_function) {
    // Compute devices of non-captured inputs.
    //
@@ -596,9 +599,29 @@ Status CapturedFunction::Instantiate(
      const auto& input = captured_inputs_[i];
      DataType dtype = input.dtype();
      if (dtype == DT_RESOURCE) {
-        const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
-        inst_opts.input_devices.push_back(handle.device());
-        const auto& dtypes_and_shapes = handle.dtypes_and_shapes();
+        const auto& handles = input.flat<ResourceHandle>();
+        const ResourceHandle& handle0 = handles(0);
+        string composite_device;
+        auto iter = fdef->arg_attr().find(num_non_captured_inputs + i);
+        if (iter != fdef->arg_attr().end()) {
+          auto arg_attr = iter->second.attr().find("_composite_device");
+          if (arg_attr != iter->second.attr().end()) {
+            composite_device = arg_attr->second.s();
+          }
+        }
+        if (!composite_device.empty()) {
+          if (composite_devices.find(composite_device) ==
+              composite_devices.end()) {
+            for (int i = 0; i < handles.size(); ++i) {
+              composite_devices[composite_device].push_back(
+                  handles(i).device());
+            }
+          }
+          inst_opts.input_devices.push_back(composite_device);
+        } else {
+          inst_opts.input_devices.push_back(handle0.device());
+        }
+        const auto& dtypes_and_shapes = handle0.dtypes_and_shapes();
        // Set dtypes and shapes for resource variable inputs.
        if (!dtypes_and_shapes.empty()) {
          input_resource_variable_dtypes_and_shapes[num_non_captured_inputs +
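For intuition, the composite_devices map assembled above is keyed by the CompositeDevice name taken from the "_composite_device" arg attr, and resolves to the underlying per-replica device names, for example (hypothetical device names, shown as a Python literal):

    # Hypothetical contents of the composite_devices map built above:
    composite_devices = {
        "/job:localhost/replica:0/task:0/device:COMPOSITE:0": [
            "/job:localhost/replica:0/task:0/device:GPU:0",
            "/job:localhost/replica:0/task:0/device:GPU:1",
        ],
    }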
@@ -613,6 +636,10 @@ Status CapturedFunction::Instantiate(
      }
    }

    for (const auto& it : composite_devices) {
      inst_opts.composite_devices[it.first] = &it.second;
    }

    for (size_t i = 0; i < fdef->signature().output_arg_size(); ++i) {
      inst_opts.output_devices.push_back(inst_opts.target);
    }
@@ -796,10 +796,6 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
          mode=["eager"]
      ))
  def testMultiDeviceDataCapturedFunction(self, distribution):
-    if getattr(distribution, "_enable_packed_variable_in_eager_mode", False):
-      self.skipTest(
-          "Dataset captured function doesn't support packed tensors yet "
-          "(b/145922293).")
    inputs = constant_op.constant([2., 3.])
    dataset = lambda _: dataset_ops.Dataset.from_tensor_slices(inputs).repeat(5)
    input_iterator = iter(
@@ -224,6 +224,13 @@ class FunctionTest(test.TestCase, parameterized.TestCase):

      return read0, read1, read2, read3

+    arg_attrs = read_var.get_concrete_function().function_def.arg_attr
+    self.assertLen(arg_attrs, 2)
+    self.assertEqual(arg_attrs[0].attr['_composite_device'].s,
+                     compat.as_bytes(packed_var_0.device))
+    self.assertEqual(arg_attrs[1].attr['_composite_device'].s,
+                     compat.as_bytes(packed_var_1.device))
+
    self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6))

  def testImplementsAttributeBasic(self):
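The assertions above read the "_composite_device" attr back out of the FunctionDef. Here is a hedged sketch of how such a packed capture can be produced and inspected, assuming a build from around this change; pack_eager_tensors and EagerTensor.is_packed are internal APIs, and the two logical CPU devices are configured purely for the example.

    import tensorflow as tf
    from tensorflow.python.framework import ops

    # Split the physical CPU into two logical devices so the two handles live
    # on different devices (assumption: run before any other TF op).
    tf.config.set_logical_device_configuration(
        tf.config.list_physical_devices("CPU")[0],
        [tf.config.LogicalDeviceConfiguration()] * 2)

    with tf.device("/cpu:0"):
      v0 = tf.Variable(1.0)
    with tf.device("/cpu:1"):
      v1 = tf.Variable(2.0)

    # Internal API: packs the per-device handles into one eager tensor that
    # is placed on a COMPOSITE device.
    packed = ops.pack_eager_tensors([v0.handle, v1.handle])
    print(packed.is_packed)  # True
    print(packed.device)     # e.g. .../device:COMPOSITE:0

    @tf.function
    def read_through_packed():
      # Using `packed` here creates the placeholder capture; the func_graph.py
      # change below stamps "_composite_device" on it at trace time.
      return tf.raw_ops.ReadVariableOp(resource=packed, dtype=tf.float32)

    concrete = read_through_packed.get_concrete_function()
    attr = concrete.function_def.arg_attr[0].attr["_composite_device"]
    print(attr.s)  # bytes of packed.device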
@@ -646,6 +646,13 @@ class FuncGraph(ops.Graph):
    if capture is None:
      placeholder = _create_substitute_placeholder(
          tensor, name=name, dtype=tensor.dtype, shape=shape)
+      # Record the composite device as an attribute on the placeholder.
+      # This attribute is propagated into the arg_attr of the FunctionDef.
+      # Currently, a packed eager tensor is always placed on a CompositeDevice.
+      if isinstance(tensor, ops.EagerTensor) and tensor.is_packed:
+        placeholder.op._set_attr(  # pylint: disable=protected-access
+            "_composite_device",
+            attr_value_pb2.AttrValue(s=compat.as_bytes(tensor.device)))
      self.add_capture(tensor, placeholder)
    else:
      placeholder = capture[1]