Support packed variables in tf.data captured functions.
If a packed variable is passed to a DatasetOp through a captured function: 1) function instantiation (graph expansion) inserts a PackOp between the per-replica arg nodes and the DatasetOp; 2) function execution feeds the unpacked ResourceHandles to the corresponding sub-functions. PiperOrigin-RevId: 317752101 Change-Id: I53f7f1d6d075fd40a76a32f329f2dcbbf1db6494
This commit is contained in:
parent
afeac170f0
commit
26421826a0
@ -1356,11 +1356,22 @@ void ProcessFunctionLibraryRuntime::Run(
|
|||||||
// "Index"s of _Arg nodes are unique when all arguments are local Tensors.
|
// "Index"s of _Arg nodes are unique when all arguments are local Tensors.
|
||||||
for (const auto& it : comp_data.arg_indices) {
|
for (const auto& it : comp_data.arg_indices) {
|
||||||
if (it.sub_index >= 0) {
|
if (it.sub_index >= 0) {
|
||||||
return errors::InvalidArgument("Got unexpected sub_index ",
|
const Tensor& t = args[it.index];
|
||||||
it.sub_index, " for argument ",
|
if (t.dtype() != DT_RESOURCE) {
|
||||||
it.index);
|
return errors::InvalidArgument("Got unexpected sub_index ",
|
||||||
|
it.sub_index, " for argument ",
|
||||||
|
it.index);
|
||||||
|
}
|
||||||
|
const auto& handles = t.flat<ResourceHandle>();
|
||||||
|
if (it.sub_index >= handles.size()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Sub_index ", it.sub_index, "is out of range [0,",
|
||||||
|
handles.size(), ") for argument ", it.index);
|
||||||
|
}
|
||||||
|
comp_args->args.push_back(Tensor(handles(it.sub_index)));
|
||||||
|
} else {
|
||||||
|
comp_args->args.push_back(args[it.index]);
|
||||||
}
|
}
|
||||||
comp_args->args.push_back(args[it.index]);
|
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
};
|
||||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/resource_var.h"
|
#include "tensorflow/core/framework/resource_var.h"
|
||||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
#include "tensorflow/core/framework/type_index.h"
|
#include "tensorflow/core/framework/type_index.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
@ -867,14 +868,29 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) {
|
|||||||
inst_opts.input_resource_dtypes_and_shapes[0] = {
|
inst_opts.input_resource_dtypes_and_shapes[0] = {
|
||||||
initial_resource_value0.dtype(), initial_resource_value0.shape()};
|
initial_resource_value0.dtype(), initial_resource_value0.shape()};
|
||||||
|
|
||||||
gtl::InlinedVector<TensorValue, 4> handles;
|
// Packed TensorHandle
|
||||||
handles.push_back(TensorValue(&resource_handle0));
|
{
|
||||||
handles.push_back(TensorValue(&resource_handle1));
|
gtl::InlinedVector<TensorValue, 4> handles;
|
||||||
TestFunctionPackedArgs args(0, std::move(handles));
|
handles.push_back(TensorValue(&resource_handle0));
|
||||||
Tensor ret;
|
handles.push_back(TensorValue(&resource_handle1));
|
||||||
TF_CHECK_OK(RunWithPackedArgs("AddVarAcrossDevices", opts, {{"T", DT_FLOAT}},
|
TestFunctionPackedArgs args(0, std::move(handles));
|
||||||
inst_opts, args, {&ret}));
|
Tensor ret;
|
||||||
test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60}));
|
TF_CHECK_OK(RunWithPackedArgs("AddVarAcrossDevices", opts,
|
||||||
|
{{"T", DT_FLOAT}}, inst_opts, args, {&ret}));
|
||||||
|
test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Packed Tensor
|
||||||
|
{
|
||||||
|
Tensor arg(DT_RESOURCE, TensorShape({2}));
|
||||||
|
arg.flat<ResourceHandle>()(0) = resource_handle0.scalar<ResourceHandle>()();
|
||||||
|
arg.flat<ResourceHandle>()(1) = resource_handle1.scalar<ResourceHandle>()();
|
||||||
|
|
||||||
|
Tensor ret;
|
||||||
|
TF_CHECK_OK(Run("AddVarAcrossDevices", opts, {{"T", DT_FLOAT}}, inst_opts,
|
||||||
|
{arg}, {&ret}));
|
||||||
|
test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60}));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) {
|
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) {
|
||||||
|
@ -571,6 +571,9 @@ Status CapturedFunction::Instantiate(
|
|||||||
DCHECK(lib->device() != nullptr);
|
DCHECK(lib->device() != nullptr);
|
||||||
inst_opts.target = lib->device()->name();
|
inst_opts.target = lib->device()->name();
|
||||||
|
|
||||||
|
// Maps from a CompositeDevice name to underlying physical device names.
|
||||||
|
absl::flat_hash_map<string, std::vector<string>> composite_devices;
|
||||||
|
|
||||||
if (inst_opts.is_multi_device_function) {
|
if (inst_opts.is_multi_device_function) {
|
||||||
// Compute devices of non-captured inputs.
|
// Compute devices of non-captured inputs.
|
||||||
//
|
//
|
||||||
@ -596,9 +599,29 @@ Status CapturedFunction::Instantiate(
|
|||||||
const auto& input = captured_inputs_[i];
|
const auto& input = captured_inputs_[i];
|
||||||
DataType dtype = input.dtype();
|
DataType dtype = input.dtype();
|
||||||
if (dtype == DT_RESOURCE) {
|
if (dtype == DT_RESOURCE) {
|
||||||
const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
|
const auto& handles = input.flat<ResourceHandle>();
|
||||||
inst_opts.input_devices.push_back(handle.device());
|
const ResourceHandle& handle0 = handles(0);
|
||||||
const auto& dtypes_and_shapes = handle.dtypes_and_shapes();
|
string composite_device;
|
||||||
|
auto iter = fdef->arg_attr().find(num_non_captured_inputs + i);
|
||||||
|
if (iter != fdef->arg_attr().end()) {
|
||||||
|
auto arg_attr = iter->second.attr().find("_composite_device");
|
||||||
|
if (arg_attr != iter->second.attr().end()) {
|
||||||
|
composite_device = arg_attr->second.s();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!composite_device.empty()) {
|
||||||
|
if (composite_devices.find(composite_device) ==
|
||||||
|
composite_devices.end()) {
|
||||||
|
for (int i = 0; i < handles.size(); ++i) {
|
||||||
|
composite_devices[composite_device].push_back(
|
||||||
|
handles(i).device());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inst_opts.input_devices.push_back(composite_device);
|
||||||
|
} else {
|
||||||
|
inst_opts.input_devices.push_back(handle0.device());
|
||||||
|
}
|
||||||
|
const auto& dtypes_and_shapes = handle0.dtypes_and_shapes();
|
||||||
// Set dtypes and shapes for resource variable inputs.
|
// Set dtypes and shapes for resource variable inputs.
|
||||||
if (!dtypes_and_shapes.empty()) {
|
if (!dtypes_and_shapes.empty()) {
|
||||||
input_resource_variable_dtypes_and_shapes[num_non_captured_inputs +
|
input_resource_variable_dtypes_and_shapes[num_non_captured_inputs +
|
||||||
@ -613,6 +636,10 @@ Status CapturedFunction::Instantiate(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (const auto& it : composite_devices) {
|
||||||
|
inst_opts.composite_devices[it.first] = &it.second;
|
||||||
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < fdef->signature().output_arg_size(); ++i) {
|
for (size_t i = 0; i < fdef->signature().output_arg_size(); ++i) {
|
||||||
inst_opts.output_devices.push_back(inst_opts.target);
|
inst_opts.output_devices.push_back(inst_opts.target);
|
||||||
}
|
}
|
||||||
|
@ -796,10 +796,6 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
|
|||||||
mode=["eager"]
|
mode=["eager"]
|
||||||
))
|
))
|
||||||
def testMultiDeviceDataCapturedFunction(self, distribution):
|
def testMultiDeviceDataCapturedFunction(self, distribution):
|
||||||
if getattr(distribution, "_enable_packed_variable_in_eager_mode", False):
|
|
||||||
self.skipTest(
|
|
||||||
"Dataset captured function doesn't support packed tensors yet "
|
|
||||||
"(b/145922293).")
|
|
||||||
inputs = constant_op.constant([2., 3.])
|
inputs = constant_op.constant([2., 3.])
|
||||||
dataset = lambda _: dataset_ops.Dataset.from_tensor_slices(inputs).repeat(5)
|
dataset = lambda _: dataset_ops.Dataset.from_tensor_slices(inputs).repeat(5)
|
||||||
input_iterator = iter(
|
input_iterator = iter(
|
||||||
|
@ -224,6 +224,13 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
return read0, read1, read2, read3
|
return read0, read1, read2, read3
|
||||||
|
|
||||||
|
arg_attrs = read_var.get_concrete_function().function_def.arg_attr
|
||||||
|
self.assertLen(arg_attrs, 2)
|
||||||
|
self.assertEqual(arg_attrs[0].attr['_composite_device'].s,
|
||||||
|
compat.as_bytes(packed_var_0.device))
|
||||||
|
self.assertEqual(arg_attrs[1].attr['_composite_device'].s,
|
||||||
|
compat.as_bytes(packed_var_1.device))
|
||||||
|
|
||||||
self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6))
|
self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6))
|
||||||
|
|
||||||
def testImplementsAttributeBasic(self):
|
def testImplementsAttributeBasic(self):
|
||||||
|
@ -646,6 +646,13 @@ class FuncGraph(ops.Graph):
|
|||||||
if capture is None:
|
if capture is None:
|
||||||
placeholder = _create_substitute_placeholder(
|
placeholder = _create_substitute_placeholder(
|
||||||
tensor, name=name, dtype=tensor.dtype, shape=shape)
|
tensor, name=name, dtype=tensor.dtype, shape=shape)
|
||||||
|
# Record the composite device as an attribute to the placeholder.
|
||||||
|
# This attribute would be propagated into the arg_attr of the FunctionDef.
|
||||||
|
# Currently, a packed eager tensor is always placed on a CompositeDevice.
|
||||||
|
if isinstance(tensor, ops.EagerTensor) and tensor.is_packed:
|
||||||
|
placeholder.op._set_attr( # pylint: disable=protected-access
|
||||||
|
"_composite_device",
|
||||||
|
attr_value_pb2.AttrValue(s=compat.as_bytes(tensor.device)))
|
||||||
self.add_capture(tensor, placeholder)
|
self.add_capture(tensor, placeholder)
|
||||||
else:
|
else:
|
||||||
placeholder = capture[1]
|
placeholder = capture[1]
|
||||||
|
Loading…
Reference in New Issue
Block a user