Add pfor converters and basic type information to optionals
Like TensorLists, pfor vectorizes the content of an optional rather than the
variant-dtype tensor itself. Adding type information to optionals requires a
slight tangent to have cond combine it correctly across branches (since cond
can pack any two tensors together in one optional return).

Fixes #44502.

PiperOrigin-RevId: 343304828
Change-Id: I2817dee286e452c5b83f232cade67260e69d23b6
commit 4f2a50496d
parent 05dc9e104b
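
A minimal sketch of the use case this change enables (adapted from the
test_cond_func_grad_jacobian test added below); it assumes a TensorFlow build
that includes this commit. Taking a Jacobian of a gradient through a tf.cond
with experimental_use_pfor=True previously failed because pfor had no
converters for the OptionalFromValue/OptionalGetValue ops that cond_v2
gradients use to forward branch intermediates:

    import tensorflow as tf

    @tf.function
    def f(x):
      # cond_v2 gradients forward branch intermediates through optionals.
      return tf.cond(x > 0., lambda: x**3., lambda: x**2.)

    with tf.GradientTape(persistent=True) as tape:
      x = tf.constant(1.)
      tape.watch(x)
      y = f(x)
      grad = tape.gradient(y, x)   # 3*x**2 at x=1 -> 3.0

    # The pfor (vectorized) path now matches the while_loop fallback path.
    jacobian = tape.jacobian(grad, x, experimental_use_pfor=True)
    print(float(jacobian))  # ~6.0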
@@ -84,4 +84,6 @@ enum SpecializedType {
   ST_INVALID = 0;
   // "tensorflow::TensorList" in the variant type registry.
   ST_TENSOR_LIST = 1;
+  // "tensorflow::data::Optional" in the variant type registry.
+  ST_OPTIONAL = 2;
 }
@@ -853,7 +853,18 @@ REGISTER_OP("OptionalFromValue")
     .Input("components: Toutput_types")
     .Output("optional: variant")
     .Attr("Toutput_types: list(type) >= 1")
-    .SetShapeFn(shape_inference::ScalarShape);
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      std::vector<DataType> dtypes;
+      TF_RETURN_IF_ERROR(c->GetAttr("Toutput_types", &dtypes));
+      c->set_output(0, c->Scalar());
+      std::vector<shape_inference::ShapeAndType> shapes_and_types;
+      shapes_and_types.reserve(c->num_inputs());
+      for (int i = 0; i < c->num_inputs(); ++i) {
+        shapes_and_types.emplace_back(c->input(i), dtypes[i], ST_OPTIONAL);
+      }
+      c->set_output_handle_shapes_and_types(0, shapes_and_types);
+      return Status::OK();
+    });
 
 REGISTER_OP("OptionalNone")
     .Output("optional: variant")
@@ -1844,6 +1844,24 @@ class JacobianTest(test.TestCase):
     self.assertAllClose(compute_jacobian(use_pfor=True),
                         compute_jacobian(use_pfor=False))
 
+  def test_cond_func_grad_jacobian(self):
+
+    @def_function.function
+    def f(x):
+      y = control_flow_ops.cond(x > 0., lambda: x**3., lambda: x**2.)
+      return y
+
+    with backprop.GradientTape(persistent=True) as tape:
+      x = constant_op.constant(1.)
+      tape.watch(x)
+      y = f(x)
+      grad = tape.gradient(y, x)
+    self.assertAllClose(3., grad)
+    jacobian = tape.jacobian(grad, x, experimental_use_pfor=False)
+    self.assertAllClose(6., jacobian)
+    jacobian_pfor = tape.jacobian(grad, x, experimental_use_pfor=True)
+    self.assertAllClose(6., jacobian_pfor)
+
 
 @test_util.run_all_in_graph_and_eager_modes
 class BatchJacobianTest(test.TestCase, parameterized.TestCase):
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.core.framework import types_pb2
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
@@ -38,8 +39,12 @@ def _DTypeFromTensor(tensor):
         and handle_data.is_set
         and handle_data.shape_and_type):
       first_type = handle_data.shape_and_type[0].dtype
-      if all(shape_and_type.dtype == first_type
-             for shape_and_type in handle_data.shape_and_type):
+      # Some variants have statically unknown dtypes; we can't make inferences
+      # about trainability, so we conservatively assume they're trainable
+      # (which may waste memory passing zeros around, but will be correct).
+      if (first_type != types_pb2.DT_INVALID
+          and all(shape_and_type.dtype == first_type
+                  for shape_and_type in handle_data.shape_and_type)):
         return first_type
   return dtype
 
@@ -33,6 +33,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import cond_v2
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
@@ -1296,6 +1297,24 @@ class CondV2Test(test.TestCase):
     i = f(constant_op.constant(False))
     self.assertEqual(self.evaluate(i), 2.0)
 
+  def testGradientOfMixedOptionals(self):
+
+    @def_function.function
+    def f(c):
+      x = constant_op.constant(1., name="x")
+
+      def then_branch():
+        return x ** 2., gen_dataset_ops.optional_from_value(
+            [constant_op.constant(1)])
+
+      def else_branch():
+        return x ** 3., gen_dataset_ops.optional_from_value(
+            [constant_op.constant(1.)])
+
+      y, _ = cond_v2.cond_v2(c, then_branch, else_branch)
+      return gradients_impl.gradients(y, x)
+
+    self.assertAllClose([2.], f(constant_op.constant(True)))
 
 
 class CondV2CollectionTest(test.TestCase):
@@ -25,6 +25,7 @@ from __future__ import print_function
 
 import collections
 
+from tensorflow.core.framework import types_pb2
 from tensorflow.python.eager import backprop_util
 from tensorflow.python.framework import auto_control_deps
 from tensorflow.python.framework import auto_control_deps_utils as acd
@@ -829,14 +830,22 @@ def _copy_handle_data(external_tensors, *branch_graph_outputs):
       internal_handle_data.append(handle_data)
     else:  # There is handle data, so we need to combine it.
       combined_shape = tensor_shape.TensorShape(None)
+      combined_dtype = None
       for handle_data in internal_handle_data:
         handle_shape = tensor_shape.TensorShape(
             handle_data.shape_and_type[0].shape)
         combined_shape = combined_shape.most_specific_compatible_shape(
             handle_shape)
+        if combined_dtype is None:
+          combined_dtype = handle_data.shape_and_type[0].dtype
+        elif handle_data.shape_and_type[0].dtype != combined_dtype:
+          # Variants from different branches have different dtypes. The
+          # combined variant has no static dtype.
+          combined_dtype = types_pb2.DT_INVALID
       combined_handle_data = internal_handle_data[0]
       combined_handle_data.shape_and_type[0].shape.CopyFrom(
           combined_shape.as_proto())
+      combined_handle_data.shape_and_type[0].dtype = combined_dtype
       handle_data_util.set_handle_data(external, combined_handle_data)
 
 
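
For illustration, a standalone sketch of the combination rule that
_copy_handle_data now applies to variant handle data across cond branches
(combine_branch_handle_data is a hypothetical helper, not code from this
change): shapes are merged to the most specific compatible shape, and the
dtype is kept only if every branch agrees, otherwise it is recorded as
DT_INVALID (statically unknown).

    from tensorflow.core.framework import types_pb2
    from tensorflow.python.framework import tensor_shape

    def combine_branch_handle_data(entries):
      """entries: (shape, dtype enum) pairs, one per cond branch."""
      combined_shape = tensor_shape.TensorShape(None)
      combined_dtype = None
      for shape, dtype in entries:
        combined_shape = combined_shape.most_specific_compatible_shape(
            tensor_shape.TensorShape(shape))
        if combined_dtype is None:
          combined_dtype = dtype
        elif dtype != combined_dtype:
          # Branches disagree, so the combined variant has no static dtype.
          combined_dtype = types_pb2.DT_INVALID
      return combined_shape, combined_dtype

    # Optionals holding int32 in one branch and float32 in the other (as in
    # testGradientOfMixedOptionals above) combine to DT_INVALID.
    print(combine_branch_handle_data(
        [([], types_pb2.DT_INT32), ([], types_pb2.DT_FLOAT)]))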
@@ -45,6 +45,7 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gen_list_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import gradient_checker_v2
@@ -1159,6 +1160,21 @@ class TensorListTest(PForTestCase):
     self._test_loop_fn(loop_fn, 2)
 
 
+class OptionalTest(PForTestCase):
+
+  def test_optional_from_value(self):
+
+    def loop_fn(i):
+      o = gen_dataset_ops.optional_from_value(
+          [i, i + 1, constant_op.constant(3)])
+      gen_dataset_ops.optional_none()
+      return gen_dataset_ops.optional_get_value(
+          o, [dtypes.int32, dtypes.int32, dtypes.int32],
+          [[], [], []])
+
+    self._test_loop_fn(loop_fn, 2)
+
+
 class StackTest(PForTestCase):
 
   @test_util.run_v1_only("b/122612051")
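
Roughly the same pattern as the OptionalTest above, but through the public
tf.vectorized_map API. This is a sketch, not part of the change: it assumes a
build with the pfor converters added below, and gen_dataset_ops is
TensorFlow's internal generated op wrapper module, used here exactly as in the
test. Each iteration packs per-iteration values into an Optional and reads
them back; pfor keeps the variant scalar while stacking its contents.

    import tensorflow as tf
    from tensorflow.python.ops import gen_dataset_ops

    def loop_fn(i):
      o = gen_dataset_ops.optional_from_value([i, i + 1, tf.constant(3)])
      return gen_dataset_ops.optional_get_value(
          o, [tf.int32, tf.int32, tf.int32], [[], [], []])

    # Each output is stacked across iterations, giving three tensors of
    # shape [2].
    out = tf.vectorized_map(loop_fn, tf.range(2))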
@@ -46,6 +46,7 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gen_image_ops
 from tensorflow.python.ops import gen_linalg_ops
 from tensorflow.python.ops import gen_list_ops
@@ -83,22 +84,45 @@ def _variant_handle_data(t):
   handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
   if not handle_data.is_set:
     return None
-  if len(handle_data.shape_and_type) != 1:
-    raise ValueError("Expected handle data of length 1, got {!r} of length {}"
-                     .format(handle_data, len(handle_data.shape_and_type)))
-  return handle_data.shape_and_type[0]
+  return handle_data.shape_and_type
 
 
-def _is_tensor_list(t):
-  """True if `t` is a TensorList, False if it isn't, None if unknown."""
+def _is_variant_with_internal_stacking(t):
+  """Identifies variant tensors which pfor always maintains as scalars.
+
+  For these, the pfor tensor is recorded as "stacked" if the content of the
+  variant tensor (e.g. the elements of a TensorList) are all stacked.
+
+  Args:
+    t: A tensor to identify.
+  Returns:
+    True if `t` is a TensorList/Optional, False not, None if unknown.
+  """
   if t.dtype != dtypes.variant:
     return False
-  shape_and_type = _variant_handle_data(t)
-  if shape_and_type is None:
-    # TODO(b/169968286): Identify all variant tensors (e.g. optionals) and we
-    # can make this an error instead of assuming TensorLists have handle data.
-    return None  # Presumed not a TensorList
-  return shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST
+  shapes_and_types = _variant_handle_data(t)
+  if shapes_and_types is None or not shapes_and_types:
+    # TODO(b/169968286): Identify all variant tensors (e.g. maps) and we can
+    # make this an error instead of assuming TensorLists have handle data.
+    return None  # Presumed not a TensorList/Optional
+  return (shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST or
+          shapes_and_types[0].specialized_type == types_pb2.ST_OPTIONAL)
+
+
+def _parse_variant_shapes_and_types(t):
+  """Extracts shape and dtype information from a variant tensor `t`."""
+  shapes_and_types = _variant_handle_data(t)
+  if shapes_and_types is None or not shapes_and_types:
+    raise ValueError("Required handle data not set for {!r}".format(t))
+  if shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST:
+    return shapes_and_types
+  else:
+    if shapes_and_types[0].specialized_type != types_pb2.ST_INVALID:
+      return shapes_and_types
+    else:
+      raise ValueError(
+          "Attempted to stack a variant-dtype tensor with no type set ({!r})"
          .format(t))
 
 
 def _stack(t, length):
@@ -109,23 +133,19 @@ def _stack(t, length):
   # suitable since operations on stacked handles may expect a vectorized version
   # of the variant.
   if t.dtype == dtypes.variant:
-    shape_and_type = _variant_handle_data(t)
-    if shape_and_type is None:
-      raise ValueError("Required handle data not set for {!r}".format(t))
-    if shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST:
+    shapes_and_types = _parse_variant_shapes_and_types(t)
+    if shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST:
+      if len(shapes_and_types) != 1:
+        raise ValueError(
+            "Expected handle data of length 1, got {!r} of length {}"
+            .format(shapes_and_types, len(shapes_and_types)))
       return wrap(
-          _stack_tensor_list(t, shape_and_type.dtype, length),
+          _stack_tensor_list(t, shapes_and_types[0].dtype, length),
           True)
     else:
-      if shape_and_type.specialized_type != types_pb2.ST_INVALID:
-        raise ValueError(
-            ("Attempted to stack an unhandled variant-dtype tensor of "
-             "type {!r} ({!r})").format(
-                 shape_and_type.specialized_type, t))
-      else:
-        raise ValueError(
-            "Attempted to stack a variant-dtype tensor with no type set ({!r})"
-            .format(t))
+      raise ValueError(
+          ("Attempted to stack an unhandled variant-dtype tensor of "
+           "type {!r} ({!r})").format(shapes_and_types[0].specialized_type, t))
   ones = array_ops.ones_like(array_ops.shape(t))
   ones = array_ops.reshape(ones, [-1])
   length = array_ops.reshape(length, [-1])
@@ -1629,7 +1649,7 @@ class PFor(object):
         else:
           batch_dim = tensor_shape.TensorShape(loop_len)
           output_shape = batch_dim.concatenate(output_shape)
-        if _is_tensor_list(new_output.t):
+        if _is_variant_with_internal_stacking(new_output.t):
          new_output.t.set_shape([])
        else:
          new_output.t.set_shape(output_shape)
@@ -3602,7 +3622,7 @@ def _stack_tensor_list_shape(shape, first_dim):
 
 def _tile_variant_with_length(t, length):
   """stacks `t` `length` times."""
-  if _is_tensor_list(t):
+  if _is_variant_with_internal_stacking(t):
     # The content of TensorLists is vectorized, not the variant itself.
     return t
   original_tensor = t
@@ -3622,16 +3642,41 @@ def _tile_variant(t, pfor_input):
 
 
 def _untile_variant(t):
-  if _is_tensor_list(t):
+  if _is_variant_with_internal_stacking(t):
     # The content of TensorLists is vectorized, not the variant itself.
     if not t.shape.is_compatible_with([]):
       raise AssertionError(
-          "Unexpectedly saw a TensorList with non-scalar shape: {!r}"
-          .format(t))
+          ("Unexpectedly saw a vectorized variant (e.g. TensorList) with "
+           "non-scalar shape: {!r}").format(t))
     return t
   return array_ops.gather(t, 0)
 
 
+@RegisterPFor("OptionalFromValue")
+def _convert_optional_from_value(pfor_input):
+  pfor_input.stack_inputs()
+  return wrap(
+      gen_dataset_ops.optional_from_value([x.t for x in pfor_input.inputs]),
+      True)
+
+
+@RegisterPFor("OptionalGetValue")
+def _convert_optional_get_value(pfor_input):
+  handle = pfor_input.stacked_input(0)
+  output_types = pfor_input.get_attr("output_types")
+  original_output_shapes = pfor_input.get_attr("output_shapes")
+  output_shapes = []
+  for shape in original_output_shapes:
+    shape = tensor_shape.TensorShape(shape)
+    loop_len_shape = tensor_shape.TensorShape(
+        [tensor_util.constant_value(pfor_input.pfor.loop_len_vector)])
+    shape = loop_len_shape.concatenate(shape)
+    output_shapes.append(shape.as_proto())
+  results = gen_dataset_ops.optional_get_value(handle, output_types,
+                                               output_shapes)
+  return [wrap(t, True) for t in results]
+
+
 @RegisterPFor("TensorListReserve")
 def _convert_tensor_list_reserve(pfor_input):
   element_shape = pfor_input.unstacked_input(0)
@@ -4275,7 +4320,7 @@ class WhileV2(object):
         shape = shape.merge_with(output_shapes[i])
         pfor_input = self._pfor_input.input(i)
         if pfor_input.is_stacked:
-          if _is_tensor_list(pfor_input.t):
+          if _is_variant_with_internal_stacking(pfor_input.t):
            shape = tensor_shape.TensorShape([]).concatenate(shape)
          else:
            shape = tensor_shape.TensorShape([None]).concatenate(shape)