Fix bug where shape information that was supposed to be masked was visible inside the body of a tf.function for composite tensors such as RaggedTensor and SparseTensor.

PiperOrigin-RevId: 273735181
This commit is contained in:
Edward Loper 2019-10-09 06:38:19 -07:00 committed by TensorFlower Gardener
parent 79d23cb293
commit 8ef501423b
5 changed files with 36 additions and 4 deletions

View File

@@ -764,7 +764,7 @@ class IteratorSpec(type_spec.TypeSpec):
def _component_specs(self):
return (
tensor_spec.TensorSpec([], dtypes.resource),
tensor_spec.TensorSpec([], dtypes.scalar),
tensor_spec.TensorSpec([], dtypes.variant),
)
def _to_components(self, value):

View File

@@ -470,7 +470,8 @@ class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
@property
def _type_spec(self):
value_specs = [type_spec.type_spec_from_value(v) for v in self._values]
value_specs = nest.map_structure(type_spec.type_spec_from_value,
self._values)
return PerReplicaSpec(value_specs, self._device_map, self._logical_device)

View File

@@ -1153,6 +1153,23 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
for (input_component, output_component) in zip(input_flat, output_flat):
self.assertAllEqual(input_component, output_component)
def testTracedCompositeDiscardsShapeInfo(self):
  """Checks tf.function placeholders mask value-count shape info.

  SparseTensorSpec intentionally excludes info about the number of
  elements in a sparse tensor (recorded as st.indices.shape[0] and
  st.values.shape[0]).  Similarly, RaggedTensorSpec intentionally
  excludes info about the total number of values in a RaggedTensor
  (stored as rt.values.shape[0]).  This test checks that the
  placeholders created by tf.function() properly mask this shape info.
  """
  @def_function.function
  def f(rt, st):
    # The RaggedTensor's flat-value count must be masked as well; the
    # docstring promises both cases, but the original assertions only
    # covered the SparseTensor.
    self.assertEqual(rt.values.shape.as_list(), [None])
    self.assertEqual(st.indices.shape.as_list()[:1], [None])
    self.assertEqual(st.values.shape.as_list(), [None])
    return (rt, st)

  rt = ragged_factory_ops.constant([[1, 2], [3]])
  st = sparse_tensor.SparseTensor([[0]], [0], [10])
  f(rt, st)
@test_util.run_gpu_only
def testFunctionOnDevice(self):
x = constant_op.constant([1.]).gpu()

View File

@@ -1083,6 +1083,12 @@ def _get_defun_inputs_from_args(args, names, flat_shapes=None):
args, names, structure=args, flat_shapes=flat_shapes)
def _get_composite_tensor_spec(x):
  """Returns `x._type_spec` when x is a composite tensor, else x unchanged."""
  if isinstance(x, composite_tensor.CompositeTensor):
    return x._type_spec  # pylint: disable=protected-access
  return x
def _get_defun_inputs(args, names, structure, flat_shapes=None):
"""Maps python function args to graph-construction inputs.
@@ -1126,6 +1132,12 @@ def _get_defun_inputs(args, names, structure, flat_shapes=None):
flat_shapes))
shapes_iter = iter(flat_shapes)
for arg_value, name in zip(args, names):
# Replace any composite tensors with their TypeSpecs. This is important
# for ensuring that shape information that's not preserved by the TypeSpec
# (such as the number of values in a SparseTensor) gets properly masked.
arg_value = nest.map_structure(_get_composite_tensor_spec, arg_value)
flattened = nest.flatten(arg_value, expand_composites=True)
tensor_specs = [
arg for arg in flattened if isinstance(arg, tensor_spec.TensorSpec)

View File

@@ -324,6 +324,8 @@ class TypeSpec(object):
])
if isinstance(value, tuple):
return tuple([self.__make_cmp_key(v) for v in value])
if isinstance(value, list):
return (list, tuple([self.__make_cmp_key(v) for v in value]))
if isinstance(value, tensor_shape.TensorShape):
if value.ndims is None:
# Note: we include a type object in the tuple, to ensure we can't get
@@ -349,7 +351,7 @@ class TypeSpec(object):
"""Returns true if the given type serializations are compatible."""
if type(a) is not type(b):
return False
if isinstance(a, tuple):
if isinstance(a, (list, tuple)):
return (len(a) == len(b) and
all(TypeSpec.__is_compatible(x, y) for (x, y) in zip(a, b)))
if isinstance(a, dict):
@@ -390,7 +392,7 @@ class TypeSpec(object):
"""
if type(a) is not type(b):
raise ValueError("Types are not compatible: %r vs %r" % (a, b))
if isinstance(a, tuple):
if isinstance(a, (list, tuple)):
if len(a) != len(b):
raise ValueError("Types are not compatible: %r vs %r" % (a, b))
return tuple(TypeSpec.__most_specific_compatible_type_serialization(x, y)