Support packed variables for tf.data captured functions.

If a packed variable is passed to a DatasetOp through a captured function:
1) function instantiation (graph expansion) inserts a PackOp between the per-replica arg nodes and the DatasetOp;
2) function execution feeds the unpacked ResourceHandles to the corresponding sub-functions.

PiperOrigin-RevId: 317752101
Change-Id: I53f7f1d6d075fd40a76a32f329f2dcbbf1db6494
Author: Yujing Zhang, 2020-06-22 15:52:37 -07:00 (committed by TensorFlower Gardener)
parent afeac170f0
commit 26421826a0
6 changed files with 83 additions and 19 deletions
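For orientation, here is a minimal sketch of the user-facing scenario this commit enables, modeled on the distribute test updated below. The two-device MirroredStrategy setup, device names, and API spellings (TF 2.2/2.3 era) are illustrative assumptions, not part of the change:

    import tensorflow as tf

    # Illustrative setup: MirroredStrategy over two devices creates one
    # variable component per replica; with packed variables enabled in eager
    # mode, the per-replica resource handles are packed into a single tensor.
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    with strategy.scope():
      v = tf.Variable(5.0)

    def make_dataset(_):
      # The dataset function captures `v`. With this commit, a packed `v`
      # can be passed to the DatasetOp through the captured function.
      return tf.data.Dataset.from_tensor_slices([2., 3.]).map(lambda x: x + v)

    dist_dataset = strategy.experimental_distribute_datasets_from_function(
        make_dataset)
    for x in dist_dataset:
      print(strategy.experimental_local_results(x))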


@@ -1356,11 +1356,22 @@ void ProcessFunctionLibraryRuntime::Run(
     // "Index"s of _Arg nodes are unique when all arguments are local Tensors.
     for (const auto& it : comp_data.arg_indices) {
       if (it.sub_index >= 0) {
-        return errors::InvalidArgument("Got unexpected sub_index ",
-                                       it.sub_index, " for argument ",
-                                       it.index);
+        const Tensor& t = args[it.index];
+        if (t.dtype() != DT_RESOURCE) {
+          return errors::InvalidArgument("Got unexpected sub_index ",
+                                         it.sub_index, " for argument ",
+                                         it.index);
+        }
+        const auto& handles = t.flat<ResourceHandle>();
+        if (it.sub_index >= handles.size()) {
+          return errors::InvalidArgument(
+              "Sub_index ", it.sub_index, " is out of range [0, ",
+              handles.size(), ") for argument ", it.index);
+        }
+        comp_args->args.push_back(Tensor(handles(it.sub_index)));
+      } else {
+        comp_args->args.push_back(args[it.index]);
       }
-      comp_args->args.push_back(args[it.index]);
     }
     return Status::OK();
   };
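In Python terms, the argument dispatch implemented above does roughly the following. This is a hedged sketch with hypothetical names (`select_component_args` is not a real TensorFlow function), using lists where the C++ code uses DT_RESOURCE Tensors of ResourceHandles:

    def select_component_args(args, arg_indices):
      # Each component function names its inputs as (index, sub_index) pairs.
      # A negative sub_index takes args[index] whole; a non-negative
      # sub_index selects one per-replica handle out of a packed argument.
      component_args = []
      for index, sub_index in arg_indices:
        if sub_index >= 0:
          handles = args[index]  # packed: one handle per underlying device
          if sub_index >= len(handles):
            raise ValueError(
                f"Sub_index {sub_index} is out of range [0, {len(handles)}) "
                f"for argument {index}")
          component_args.append(handles[sub_index])
        else:
          component_args.append(args[index])
      return component_args

    # Example: argument 0 is packed across two replicas; the component
    # function for replica 1 gets the second handle plus shared argument 1.
    assert select_component_args(
        [["h0", "h1"], "t"], [(0, 1), (1, -1)]) == ["h1", "t"]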


@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/core/framework/resource_var.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/framework/type_index.h"
+#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/core/threadpool.h"
@@ -867,14 +868,29 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) {
   inst_opts.input_resource_dtypes_and_shapes[0] = {
       initial_resource_value0.dtype(), initial_resource_value0.shape()};
 
-  gtl::InlinedVector<TensorValue, 4> handles;
-  handles.push_back(TensorValue(&resource_handle0));
-  handles.push_back(TensorValue(&resource_handle1));
-  TestFunctionPackedArgs args(0, std::move(handles));
-  Tensor ret;
-  TF_CHECK_OK(RunWithPackedArgs("AddVarAcrossDevices", opts, {{"T", DT_FLOAT}},
-                                inst_opts, args, {&ret}));
-  test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60}));
+  // Packed TensorHandle
+  {
+    gtl::InlinedVector<TensorValue, 4> handles;
+    handles.push_back(TensorValue(&resource_handle0));
+    handles.push_back(TensorValue(&resource_handle1));
+    TestFunctionPackedArgs args(0, std::move(handles));
+    Tensor ret;
+    TF_CHECK_OK(RunWithPackedArgs("AddVarAcrossDevices", opts,
+                                  {{"T", DT_FLOAT}}, inst_opts, args, {&ret}));
+    test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60}));
+  }
+
+  // Packed Tensor
+  {
+    Tensor arg(DT_RESOURCE, TensorShape({2}));
+    arg.flat<ResourceHandle>()(0) = resource_handle0.scalar<ResourceHandle>()();
+    arg.flat<ResourceHandle>()(1) = resource_handle1.scalar<ResourceHandle>()();
+
+    Tensor ret;
+    TF_CHECK_OK(Run("AddVarAcrossDevices", opts, {{"T", DT_FLOAT}}, inst_opts,
+                    {arg}, {&ret}));
+    test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60}));
+  }
 }
 
 TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) {


@@ -571,6 +571,9 @@ Status CapturedFunction::Instantiate(
   DCHECK(lib->device() != nullptr);
   inst_opts.target = lib->device()->name();
 
+  // Maps from a CompositeDevice name to underlying physical device names.
+  absl::flat_hash_map<string, std::vector<string>> composite_devices;
+
   if (inst_opts.is_multi_device_function) {
     // Compute devices of non-captured inputs.
     //
@@ -596,9 +599,29 @@ Status CapturedFunction::Instantiate(
       const auto& input = captured_inputs_[i];
       DataType dtype = input.dtype();
       if (dtype == DT_RESOURCE) {
-        const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
-        inst_opts.input_devices.push_back(handle.device());
-        const auto& dtypes_and_shapes = handle.dtypes_and_shapes();
+        const auto& handles = input.flat<ResourceHandle>();
+        const ResourceHandle& handle0 = handles(0);
+        string composite_device;
+        auto iter = fdef->arg_attr().find(num_non_captured_inputs + i);
+        if (iter != fdef->arg_attr().end()) {
+          auto arg_attr = iter->second.attr().find("_composite_device");
+          if (arg_attr != iter->second.attr().end()) {
+            composite_device = arg_attr->second.s();
+          }
+        }
+        if (!composite_device.empty()) {
+          if (composite_devices.find(composite_device) ==
+              composite_devices.end()) {
+            for (int i = 0; i < handles.size(); ++i) {
+              composite_devices[composite_device].push_back(
+                  handles(i).device());
+            }
+          }
+          inst_opts.input_devices.push_back(composite_device);
+        } else {
+          inst_opts.input_devices.push_back(handle0.device());
+        }
+        const auto& dtypes_and_shapes = handle0.dtypes_and_shapes();
         // Set dtypes and shapes for resource variable inputs.
         if (!dtypes_and_shapes.empty()) {
           input_resource_variable_dtypes_and_shapes[num_non_captured_inputs +
@@ -613,6 +636,10 @@ Status CapturedFunction::Instantiate(
       }
     }
 
+  for (const auto& it : composite_devices) {
+    inst_opts.composite_devices[it.first] = &it.second;
+  }
+
   for (size_t i = 0; i < fdef->signature().output_arg_size(); ++i) {
     inst_opts.output_devices.push_back(inst_opts.target);
   }
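To make the new composite-device bookkeeping concrete: after the loop above, `composite_devices` (handed to `inst_opts.composite_devices`) maps each CompositeDevice name found in an `_composite_device` arg attribute to the devices of that argument's packed handles. A hypothetical example, written as a Python dict purely for illustration (device names assumed):

    composite_devices = {
        # One entry per composite device seen among the captured packed
        # inputs; values are the devices of the per-replica ResourceHandles.
        "/job:localhost/replica:0/task:0/device:COMPOSITE:0": [
            "/job:localhost/replica:0/task:0/device:GPU:0",
            "/job:localhost/replica:0/task:0/device:GPU:1",
        ],
    }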


@@ -796,10 +796,6 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
           mode=["eager"]
       ))
   def testMultiDeviceDataCapturedFunction(self, distribution):
-    if getattr(distribution, "_enable_packed_variable_in_eager_mode", False):
-      self.skipTest(
-          "Dataset captured function doesn't support packed tensors yet "
-          "(b/145922293).")
     inputs = constant_op.constant([2., 3.])
     dataset = lambda _: dataset_ops.Dataset.from_tensor_slices(inputs).repeat(5)
     input_iterator = iter(


@@ -224,6 +224,13 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
       return read0, read1, read2, read3
 
+    arg_attrs = read_var.get_concrete_function().function_def.arg_attr
+    self.assertLen(arg_attrs, 2)
+    self.assertEqual(arg_attrs[0].attr['_composite_device'].s,
+                     compat.as_bytes(packed_var_0.device))
+    self.assertEqual(arg_attrs[1].attr['_composite_device'].s,
+                     compat.as_bytes(packed_var_1.device))
+
     self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6))
 
   def testImplementsAttributeBasic(self):


@@ -646,6 +646,13 @@ class FuncGraph(ops.Graph):
     if capture is None:
       placeholder = _create_substitute_placeholder(
           tensor, name=name, dtype=tensor.dtype, shape=shape)
+      # Record the composite device as an attribute to the placeholder.
+      # This attribute would be propagated into the arg_attr of the FunctionDef.
+      # Currently, a packed eager tensor is always placed on a CompositeDevice.
+      if isinstance(tensor, ops.EagerTensor) and tensor.is_packed:
+        placeholder.op._set_attr(  # pylint: disable=protected-access
+            "_composite_device",
+            attr_value_pb2.AttrValue(s=compat.as_bytes(tensor.device)))
       self.add_capture(tensor, placeholder)
     else:
       placeholder = capture[1]
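The attribute set here is what the FunctionTest change above asserts on: it propagates into the FunctionDef's arg_attr, where it can be read back from Python. A minimal sketch, assuming a `read_var` tf.function capturing packed variables `packed_var_0` and `packed_var_1` set up as in that test:

    # Assumes `read_var`, `packed_var_0`, and `packed_var_1` exist as in the
    # FunctionTest hunk above; only the inspection step is shown here.
    from tensorflow.python.util import compat

    concrete = read_var.get_concrete_function()
    arg_attrs = concrete.function_def.arg_attr
    assert (arg_attrs[0].attr["_composite_device"].s ==
            compat.as_bytes(packed_var_0.device))
    assert (arg_attrs[1].attr["_composite_device"].s ==
            compat.as_bytes(packed_var_1.device))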