Add pfor converters and basic type information to optionals

As with TensorLists, pfor vectorizes an optional's contents rather than the variant-dtype tensor itself.
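As an illustrative sketch (modeled on the OptionalTest case added below; the internal gen_dataset_ops and parallel_for modules used here are assumptions based on that test, not part of this change):

  from tensorflow.python.framework import dtypes
  from tensorflow.python.ops import gen_dataset_ops
  from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_ops

  def loop_fn(i):
    # The optional's contents (i, i + 1) are what get stacked across
    # iterations; the variant handle itself stays a scalar.
    o = gen_dataset_ops.optional_from_value([i, i + 1])
    return gen_dataset_ops.optional_get_value(
        o, [dtypes.int32, dtypes.int32], [[], []])

  # Runs loop_fn for i = 0, 1 and stacks the per-iteration results.
  stacked = pfor_ops.pfor(loop_fn, 2)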

Adding type information to optionals requires a short detour into cond so that it combines this information correctly across branches (the two branches of a cond may each pack a different tensor into the optional they return).
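The new CondV2Test case below exercises exactly this; a condensed version for reference (a sketch using the same internal modules as that test, so treat the exact imports as assumptions):

  from tensorflow.python.eager import def_function
  from tensorflow.python.framework import constant_op
  from tensorflow.python.ops import cond_v2
  from tensorflow.python.ops import gen_dataset_ops
  from tensorflow.python.ops import gradients_impl

  @def_function.function
  def f(c):
    x = constant_op.constant(1., name="x")

    def then_branch():
      # Optional wrapping an int32 tensor.
      return x ** 2., gen_dataset_ops.optional_from_value(
          [constant_op.constant(1)])

    def else_branch():
      # Optional wrapping a float32 tensor; the branches disagree on the
      # optional's element dtype, so the combined handle data falls back to
      # DT_INVALID rather than recording a single element dtype.
      return x ** 3., gen_dataset_ops.optional_from_value(
          [constant_op.constant(1.)])

    y, _ = cond_v2.cond_v2(c, then_branch, else_branch)
    return gradients_impl.gradients(y, x)

  f(constant_op.constant(True))  # [2.]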

Fixes #44502.

PiperOrigin-RevId: 343304828
Change-Id: I2817dee286e452c5b83f232cade67260e69d23b6
Allen Lavoie 2020-11-19 09:21:40 -08:00 committed by TensorFlower Gardener
parent 05dc9e104b
commit 4f2a50496d
8 changed files with 160 additions and 35 deletions

View File

@@ -84,4 +84,6 @@ enum SpecializedType {
  ST_INVALID = 0;
  // "tensorflow::TensorList" in the variant type registry.
  ST_TENSOR_LIST = 1;
  // "tensorflow::data::Optional" in the variant type registry.
  ST_OPTIONAL = 2;
}

View File

@@ -853,7 +853,18 @@ REGISTER_OP("OptionalFromValue")
    .Input("components: Toutput_types")
    .Output("optional: variant")
    .Attr("Toutput_types: list(type) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      std::vector<DataType> dtypes;
      TF_RETURN_IF_ERROR(c->GetAttr("Toutput_types", &dtypes));
      c->set_output(0, c->Scalar());
      std::vector<shape_inference::ShapeAndType> shapes_and_types;
      shapes_and_types.reserve(c->num_inputs());
      for (int i = 0; i < c->num_inputs(); ++i) {
        shapes_and_types.emplace_back(c->input(i), dtypes[i], ST_OPTIONAL);
      }
      c->set_output_handle_shapes_and_types(0, shapes_and_types);
      return Status::OK();
    });

REGISTER_OP("OptionalNone")
    .Output("optional: variant")

View File

@@ -1844,6 +1844,24 @@ class JacobianTest(test.TestCase):
    self.assertAllClose(compute_jacobian(use_pfor=True),
                        compute_jacobian(use_pfor=False))

  def test_cond_func_grad_jacobian(self):

    @def_function.function
    def f(x):
      y = control_flow_ops.cond(x > 0., lambda: x**3., lambda: x**2.)
      return y

    with backprop.GradientTape(persistent=True) as tape:
      x = constant_op.constant(1.)
      tape.watch(x)
      y = f(x)
      grad = tape.gradient(y, x)
    self.assertAllClose(3., grad)
    jacobian = tape.jacobian(grad, x, experimental_use_pfor=False)
    self.assertAllClose(6., jacobian)
    jacobian_pfor = tape.jacobian(grad, x, experimental_use_pfor=True)
    self.assertAllClose(6., jacobian_pfor)


@test_util.run_all_in_graph_and_eager_modes
class BatchJacobianTest(test.TestCase, parameterized.TestCase):

View File

@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
@@ -38,8 +39,12 @@ def _DTypeFromTensor(tensor):
        and handle_data.is_set
        and handle_data.shape_and_type):
      first_type = handle_data.shape_and_type[0].dtype
      if all(shape_and_type.dtype == first_type
             for shape_and_type in handle_data.shape_and_type):
      # Some variants have statically unknown dtypes; we can't make inferences
      # about trainability, so we conservatively assume they're trainable
      # (which may waste memory passing zeros around, but will be correct).
      if (first_type != types_pb2.DT_INVALID
          and all(shape_and_type.dtype == first_type
                  for shape_and_type in handle_data.shape_and_type)):
        return first_type
  return dtype

View File

@@ -33,6 +33,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
@@ -1296,6 +1297,24 @@ class CondV2Test(test.TestCase):
    i = f(constant_op.constant(False))
    self.assertEqual(self.evaluate(i), 2.0)

  def testGradientOfMixedOptionals(self):

    @def_function.function
    def f(c):
      x = constant_op.constant(1., name="x")

      def then_branch():
        return x ** 2., gen_dataset_ops.optional_from_value(
            [constant_op.constant(1)])

      def else_branch():
        return x ** 3., gen_dataset_ops.optional_from_value(
            [constant_op.constant(1.)])

      y, _ = cond_v2.cond_v2(c, then_branch, else_branch)
      return gradients_impl.gradients(y, x)

    self.assertAllClose([2.], f(constant_op.constant(True)))


class CondV2CollectionTest(test.TestCase):

View File

@@ -25,6 +25,7 @@ from __future__ import print_function
import collections
from tensorflow.core.framework import types_pb2
from tensorflow.python.eager import backprop_util
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import auto_control_deps_utils as acd
@@ -829,14 +830,22 @@ def _copy_handle_data(external_tensors, *branch_graph_outputs):
      internal_handle_data.append(handle_data)
    else:  # There is handle data, so we need to combine it.
      combined_shape = tensor_shape.TensorShape(None)
      combined_dtype = None
      for handle_data in internal_handle_data:
        handle_shape = tensor_shape.TensorShape(
            handle_data.shape_and_type[0].shape)
        combined_shape = combined_shape.most_specific_compatible_shape(
            handle_shape)
        if combined_dtype is None:
          combined_dtype = handle_data.shape_and_type[0].dtype
        elif handle_data.shape_and_type[0].dtype != combined_dtype:
          # Variants from different branches have different dtypes. The
          # combined variant has no static dtype.
          combined_dtype = types_pb2.DT_INVALID
      combined_handle_data = internal_handle_data[0]
      combined_handle_data.shape_and_type[0].shape.CopyFrom(
          combined_shape.as_proto())
      combined_handle_data.shape_and_type[0].dtype = combined_dtype
      handle_data_util.set_handle_data(external, combined_handle_data)

View File

@@ -45,6 +45,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker_v2
@@ -1159,6 +1160,21 @@ class TensorListTest(PForTestCase):
    self._test_loop_fn(loop_fn, 2)


class OptionalTest(PForTestCase):

  def test_optional_from_value(self):

    def loop_fn(i):
      o = gen_dataset_ops.optional_from_value(
          [i, i + 1, constant_op.constant(3)])
      gen_dataset_ops.optional_none()
      return gen_dataset_ops.optional_get_value(
          o, [dtypes.int32, dtypes.int32, dtypes.int32],
          [[], [], []])

    self._test_loop_fn(loop_fn, 2)


class StackTest(PForTestCase):

  @test_util.run_v1_only("b/122612051")

View File

@@ -46,6 +46,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gen_list_ops
@@ -83,22 +84,45 @@ def _variant_handle_data(t):
  handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
  if not handle_data.is_set:
    return None
  if len(handle_data.shape_and_type) != 1:
    raise ValueError("Expected handle data of length 1, got {!r} of length {}"
                     .format(handle_data, len(handle_data.shape_and_type)))
  return handle_data.shape_and_type[0]
  return handle_data.shape_and_type


def _is_tensor_list(t):
  """True if `t` is a TensorList, False if it isn't, None if unknown."""
def _is_variant_with_internal_stacking(t):
  """Identifies variant tensors which pfor always maintains as scalars.

  For these, the pfor tensor is recorded as "stacked" if the contents of the
  variant tensor (e.g. the elements of a TensorList) are all stacked.

  Args:
    t: A tensor to identify.

  Returns:
    True if `t` is a TensorList/Optional, False if not, None if unknown.
  """
  if t.dtype != dtypes.variant:
    return False
  shape_and_type = _variant_handle_data(t)
  if shape_and_type is None:
    # TODO(b/169968286): Identify all variant tensors (e.g. optionals) and we
    # can make this an error instead of assuming TensorLists have handle data.
    return None  # Presumed not a TensorList
  return shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST
  shapes_and_types = _variant_handle_data(t)
  if shapes_and_types is None or not shapes_and_types:
    # TODO(b/169968286): Identify all variant tensors (e.g. maps) and we can
    # make this an error instead of assuming TensorLists have handle data.
    return None  # Presumed not a TensorList/Optional
  return (shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST or
          shapes_and_types[0].specialized_type == types_pb2.ST_OPTIONAL)


def _parse_variant_shapes_and_types(t):
  """Extracts shape and dtype information from a variant tensor `t`."""
  shapes_and_types = _variant_handle_data(t)
  if shapes_and_types is None or not shapes_and_types:
    raise ValueError("Required handle data not set for {!r}".format(t))
  if shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST:
    return shapes_and_types
  else:
    if shapes_and_types[0].specialized_type != types_pb2.ST_INVALID:
      return shapes_and_types
    else:
      raise ValueError(
          "Attempted to stack a variant-dtype tensor with no type set ({!r})"
          .format(t))

def _stack(t, length):
@@ -109,23 +133,19 @@ def _stack(t, length):
  # suitable since operations on stacked handles may expect a vectorized version
  # of the variant.
  if t.dtype == dtypes.variant:
    shape_and_type = _variant_handle_data(t)
    if shape_and_type is None:
      raise ValueError("Required handle data not set for {!r}".format(t))
    if shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST:
    shapes_and_types = _parse_variant_shapes_and_types(t)
    if shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST:
      if len(shapes_and_types) != 1:
        raise ValueError(
            "Expected handle data of length 1, got {!r} of length {}"
            .format(shapes_and_types, len(shapes_and_types)))
      return wrap(
          _stack_tensor_list(t, shape_and_type.dtype, length),
          _stack_tensor_list(t, shapes_and_types[0].dtype, length),
          True)
    else:
      if shape_and_type.specialized_type != types_pb2.ST_INVALID:
        raise ValueError(
            ("Attempted to stack an unhandled variant-dtype tensor of "
             "type {!r} ({!r})").format(
                 shape_and_type.specialized_type, t))
      else:
        raise ValueError(
            "Attempted to stack a variant-dtype tensor with no type set ({!r})"
            .format(t))
      raise ValueError(
          ("Attempted to stack an unhandled variant-dtype tensor of "
           "type {!r} ({!r})").format(shapes_and_types[0].specialized_type, t))
  ones = array_ops.ones_like(array_ops.shape(t))
  ones = array_ops.reshape(ones, [-1])
  length = array_ops.reshape(length, [-1])
@@ -1629,7 +1649,7 @@
      else:
        batch_dim = tensor_shape.TensorShape(loop_len)
        output_shape = batch_dim.concatenate(output_shape)
      if _is_tensor_list(new_output.t):
      if _is_variant_with_internal_stacking(new_output.t):
        new_output.t.set_shape([])
      else:
        new_output.t.set_shape(output_shape)
@@ -3602,7 +3622,7 @@ def _stack_tensor_list_shape(shape, first_dim):

def _tile_variant_with_length(t, length):
  """stacks `t` `length` times."""
  if _is_tensor_list(t):
  if _is_variant_with_internal_stacking(t):
    # The content of TensorLists is vectorized, not the variant itself.
    return t
  original_tensor = t
@@ -3622,16 +3642,41 @@ def _tile_variant(t, pfor_input):

def _untile_variant(t):
  if _is_tensor_list(t):
  if _is_variant_with_internal_stacking(t):
    # The content of TensorLists is vectorized, not the variant itself.
    if not t.shape.is_compatible_with([]):
      raise AssertionError(
          "Unexpectedly saw a TensorList with non-scalar shape: {!r}"
          .format(t))
          ("Unexpectedly saw a vectorized variant (e.g. TensorList) with "
           "non-scalar shape: {!r}").format(t))
    return t
  return array_ops.gather(t, 0)


@RegisterPFor("OptionalFromValue")
def _convert_optional_from_value(pfor_input):
  pfor_input.stack_inputs()
  return wrap(
      gen_dataset_ops.optional_from_value([x.t for x in pfor_input.inputs]),
      True)


@RegisterPFor("OptionalGetValue")
def _convert_optional_get_value(pfor_input):
  handle = pfor_input.stacked_input(0)
  output_types = pfor_input.get_attr("output_types")
  original_output_shapes = pfor_input.get_attr("output_shapes")
  output_shapes = []
  for shape in original_output_shapes:
    shape = tensor_shape.TensorShape(shape)
    loop_len_shape = tensor_shape.TensorShape(
        [tensor_util.constant_value(pfor_input.pfor.loop_len_vector)])
    shape = loop_len_shape.concatenate(shape)
    output_shapes.append(shape.as_proto())
  results = gen_dataset_ops.optional_get_value(handle, output_types,
                                               output_shapes)
  return [wrap(t, True) for t in results]


@RegisterPFor("TensorListReserve")
def _convert_tensor_list_reserve(pfor_input):
  element_shape = pfor_input.unstacked_input(0)
@@ -4275,7 +4320,7 @@ class WhileV2(object):
      shape = shape.merge_with(output_shapes[i])
      pfor_input = self._pfor_input.input(i)
      if pfor_input.is_stacked:
        if _is_tensor_list(pfor_input.t):
        if _is_variant_with_internal_stacking(pfor_input.t):
          shape = tensor_shape.TensorShape([]).concatenate(shape)
        else:
          shape = tensor_shape.TensorShape([None]).concatenate(shape)