Support packed variables for tf.data captured functions.
If a packed variable is passed to a DatasetOp through a captured function:
1) Function instantiation (graph expansion) inserts a PackOp between the per-replica arg nodes and the DatasetOp.
2) Function execution feeds the unpacked ResourceHandles to the corresponding sub-functions.

PiperOrigin-RevId: 317752101
Change-Id: I53f7f1d6d075fd40a76a32f329f2dcbbf1db6494
parent afeac170f0
commit 26421826a0
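For orientation, here is a minimal sketch of the user-level scenario this change targets, in Python. The strategy choice and dataset contents are illustrative assumptions; the only tie to this commit is that a strategy which packs per-replica variable handles in eager mode (toggled by the private _enable_packed_variable_in_eager_mode flag referenced in the distribute test below) hands the captured dataset function a single packed ResourceHandle.

import tensorflow as tf

# Illustrative setup: any strategy that packs per-replica variable handles
# in eager mode exercises the code paths changed below.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  v = tf.Variable(2.0)

def make_dataset(_):
  # `v` is captured by the dataset function. With packed variables, the
  # capture is one packed ResourceHandle covering all replicas, which
  # instantiation and execution must unpack per replica.
  return tf.data.Dataset.range(5).map(lambda x: tf.cast(x, tf.float32) * v)

# API name of this TF era; newer releases expose it without the prefix.
dist_dataset = strategy.experimental_distribute_datasets_from_function(
    make_dataset)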
@@ -1356,11 +1356,22 @@ void ProcessFunctionLibraryRuntime::Run(
       // "Index"s of _Arg nodes are unique when all arguments are local Tensors.
       for (const auto& it : comp_data.arg_indices) {
         if (it.sub_index >= 0) {
-          return errors::InvalidArgument("Got unexpected sub_index ",
-                                         it.sub_index, " for argument ",
-                                         it.index);
+          const Tensor& t = args[it.index];
+          if (t.dtype() != DT_RESOURCE) {
+            return errors::InvalidArgument("Got unexpected sub_index ",
+                                           it.sub_index, " for argument ",
+                                           it.index);
+          }
+          const auto& handles = t.flat<ResourceHandle>();
+          if (it.sub_index >= handles.size()) {
+            return errors::InvalidArgument(
+                "Sub_index ", it.sub_index, " is out of range [0, ",
+                handles.size(), ") for argument ", it.index);
+          }
+          comp_args->args.push_back(Tensor(handles(it.sub_index)));
+        } else {
+          comp_args->args.push_back(args[it.index]);
         }
-        comp_args->args.push_back(args[it.index]);
       }
       return Status::OK();
     };
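To make the new branch in the hunk above easier to scan, here is the same per-component argument selection rendered as plain Python; it.index and it.sub_index mirror the C++ fields, and the helper itself is illustrative, not a TensorFlow API.

def component_args(arg_indices, args):
  # Build the argument list for one component function.
  out = []
  for it in arg_indices:
    if it.sub_index >= 0:
      # Packed argument: args[it.index] holds one resource handle per
      # device; select the component that belongs to this sub-function.
      handles = args[it.index]
      if it.sub_index >= len(handles):
        raise ValueError("sub_index %d is out of range [0, %d) for argument %d"
                         % (it.sub_index, len(handles), it.index))
      out.append(handles[it.sub_index])
    else:
      # Regular argument: forwarded unchanged.
      out.append(args[it.index])
  return out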
@@ -28,6 +28,7 @@ limitations under the License.
+#include "tensorflow/core/framework/resource_var.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/framework/type_index.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/core/threadpool.h"
@@ -867,14 +868,29 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) {
   inst_opts.input_resource_dtypes_and_shapes[0] = {
       initial_resource_value0.dtype(), initial_resource_value0.shape()};
 
-  gtl::InlinedVector<TensorValue, 4> handles;
-  handles.push_back(TensorValue(&resource_handle0));
-  handles.push_back(TensorValue(&resource_handle1));
-  TestFunctionPackedArgs args(0, std::move(handles));
-  Tensor ret;
-  TF_CHECK_OK(RunWithPackedArgs("AddVarAcrossDevices", opts, {{"T", DT_FLOAT}},
-                                inst_opts, args, {&ret}));
-  test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60}));
+  // Packed TensorHandle
+  {
+    gtl::InlinedVector<TensorValue, 4> handles;
+    handles.push_back(TensorValue(&resource_handle0));
+    handles.push_back(TensorValue(&resource_handle1));
+    TestFunctionPackedArgs args(0, std::move(handles));
+    Tensor ret;
+    TF_CHECK_OK(RunWithPackedArgs("AddVarAcrossDevices", opts,
+                                  {{"T", DT_FLOAT}}, inst_opts, args, {&ret}));
+    test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60}));
+  }
+
+  // Packed Tensor
+  {
+    Tensor arg(DT_RESOURCE, TensorShape({2}));
+    arg.flat<ResourceHandle>()(0) = resource_handle0.scalar<ResourceHandle>()();
+    arg.flat<ResourceHandle>()(1) = resource_handle1.scalar<ResourceHandle>()();
+
+    Tensor ret;
+    TF_CHECK_OK(Run("AddVarAcrossDevices", opts, {{"T", DT_FLOAT}}, inst_opts,
+                    {arg}, {&ret}));
+    test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60}));
+  }
 }
 
 TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) {
@@ -571,6 +571,9 @@ Status CapturedFunction::Instantiate(
   DCHECK(lib->device() != nullptr);
   inst_opts.target = lib->device()->name();
 
+  // Maps from a CompositeDevice name to underlying physical device names.
+  absl::flat_hash_map<string, std::vector<string>> composite_devices;
+
   if (inst_opts.is_multi_device_function) {
     // Compute devices of non-captured inputs.
     //
@@ -596,9 +599,29 @@ Status CapturedFunction::Instantiate(
     const auto& input = captured_inputs_[i];
     DataType dtype = input.dtype();
     if (dtype == DT_RESOURCE) {
-      const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
-      inst_opts.input_devices.push_back(handle.device());
-      const auto& dtypes_and_shapes = handle.dtypes_and_shapes();
+      const auto& handles = input.flat<ResourceHandle>();
+      const ResourceHandle& handle0 = handles(0);
+      string composite_device;
+      auto iter = fdef->arg_attr().find(num_non_captured_inputs + i);
+      if (iter != fdef->arg_attr().end()) {
+        auto arg_attr = iter->second.attr().find("_composite_device");
+        if (arg_attr != iter->second.attr().end()) {
+          composite_device = arg_attr->second.s();
+        }
+      }
+      if (!composite_device.empty()) {
+        if (composite_devices.find(composite_device) ==
+            composite_devices.end()) {
+          for (int i = 0; i < handles.size(); ++i) {
+            composite_devices[composite_device].push_back(
+                handles(i).device());
+          }
+        }
+        inst_opts.input_devices.push_back(composite_device);
+      } else {
+        inst_opts.input_devices.push_back(handle0.device());
+      }
+      const auto& dtypes_and_shapes = handle0.dtypes_and_shapes();
       // Set dtypes and shapes for resource variable inputs.
       if (!dtypes_and_shapes.empty()) {
         input_resource_variable_dtypes_and_shapes[num_non_captured_inputs +
@@ -613,6 +636,10 @@ Status CapturedFunction::Instantiate(
     }
   }
 
+  for (const auto& it : composite_devices) {
+    inst_opts.composite_devices[it.first] = &it.second;
+  }
+
   for (size_t i = 0; i < fdef->signature().output_arg_size(); ++i) {
     inst_opts.output_devices.push_back(inst_opts.target);
   }
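For reference, a sketch of what the composite_devices map built above could contain for a single variable packed across two local CPUs; the device names, including the COMPOSITE device string, are illustrative assumptions.

# Keyed by composite device name; each value lists the physical devices of
# the packed handle's components, in order.
composite_devices = {
    "/job:localhost/replica:0/task:0/device:COMPOSITE:0": [
        "/job:localhost/replica:0/task:0/device:CPU:0",
        "/job:localhost/replica:0/task:0/device:CPU:1",
    ],
}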
@@ -796,10 +796,6 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
           mode=["eager"]
       ))
   def testMultiDeviceDataCapturedFunction(self, distribution):
-    if getattr(distribution, "_enable_packed_variable_in_eager_mode", False):
-      self.skipTest(
-          "Dataset captured function doesn't support packed tensors yet "
-          "(b/145922293).")
     inputs = constant_op.constant([2., 3.])
     dataset = lambda _: dataset_ops.Dataset.from_tensor_slices(inputs).repeat(5)
     input_iterator = iter(
@@ -224,6 +224,13 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
 
       return read0, read1, read2, read3
 
+    arg_attrs = read_var.get_concrete_function().function_def.arg_attr
+    self.assertLen(arg_attrs, 2)
+    self.assertEqual(arg_attrs[0].attr['_composite_device'].s,
+                     compat.as_bytes(packed_var_0.device))
+    self.assertEqual(arg_attrs[1].attr['_composite_device'].s,
+                     compat.as_bytes(packed_var_1.device))
+
     self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6))
 
   def testImplementsAttributeBasic(self):
@@ -646,6 +646,13 @@ class FuncGraph(ops.Graph):
     if capture is None:
       placeholder = _create_substitute_placeholder(
           tensor, name=name, dtype=tensor.dtype, shape=shape)
+      # Record the composite device as an attribute of the placeholder.
+      # This attribute would be propagated into the arg_attr of the FunctionDef.
+      # Currently, a packed eager tensor is always placed on a CompositeDevice.
+      if isinstance(tensor, ops.EagerTensor) and tensor.is_packed:
+        placeholder.op._set_attr(  # pylint: disable=protected-access
+            "_composite_device",
+            attr_value_pb2.AttrValue(s=compat.as_bytes(tensor.device)))
       self.add_capture(tensor, placeholder)
     else:
       placeholder = capture[1]