Fix bug where shape information that was supposed to be masked was visible inside the body of a tf.function for composite tensors such as RaggedTensor and SparseTensor.
PiperOrigin-RevId: 273735181
This commit is contained in:
parent
79d23cb293
commit
8ef501423b
@ -764,7 +764,7 @@ class IteratorSpec(type_spec.TypeSpec):
|
||||
def _component_specs(self):
|
||||
return (
|
||||
tensor_spec.TensorSpec([], dtypes.resource),
|
||||
tensor_spec.TensorSpec([], dtypes.scalar),
|
||||
tensor_spec.TensorSpec([], dtypes.variant),
|
||||
)
|
||||
|
||||
def _to_components(self, value):
|
||||
|
@ -470,7 +470,8 @@ class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
|
||||
|
||||
@property
|
||||
def _type_spec(self):
|
||||
value_specs = [type_spec.type_spec_from_value(v) for v in self._values]
|
||||
value_specs = nest.map_structure(type_spec.type_spec_from_value,
|
||||
self._values)
|
||||
return PerReplicaSpec(value_specs, self._device_map, self._logical_device)
|
||||
|
||||
|
||||
|
@ -1153,6 +1153,23 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
for (input_component, output_component) in zip(input_flat, output_flat):
|
||||
self.assertAllEqual(input_component, output_component)
|
||||
|
||||
def testTracedCompositeDiscardsShapeInfo(self):
|
||||
# SparseTensorSpec intentionally excludes info about the number of elements
|
||||
# that are in a sparse tensor (which is recorded as st.indices.shape[0] and
|
||||
# st.values.shape[0]). Similarly, RaggedTensorSpec intentionally excludes
|
||||
# info about the total number of values in a RaggedTensor (stored as
|
||||
# rt.values.shape[0]). This test checks that the placeholders created by
|
||||
# tf.function() properly mask this shape info.
|
||||
@def_function.function
|
||||
def f(rt, st):
|
||||
self.assertEqual(st.indices.shape.as_list()[:1], [None])
|
||||
self.assertEqual(st.values.shape.as_list(), [None])
|
||||
return (rt, st)
|
||||
|
||||
rt = ragged_factory_ops.constant([[1, 2], [3]])
|
||||
st = sparse_tensor.SparseTensor([[0]], [0], [10])
|
||||
f(rt, st)
|
||||
|
||||
@test_util.run_gpu_only
|
||||
def testFunctionOnDevice(self):
|
||||
x = constant_op.constant([1.]).gpu()
|
||||
|
@ -1083,6 +1083,12 @@ def _get_defun_inputs_from_args(args, names, flat_shapes=None):
|
||||
args, names, structure=args, flat_shapes=flat_shapes)
|
||||
|
||||
|
||||
def _get_composite_tensor_spec(x):
|
||||
"""Returns the TypeSpec for x if it's a composite tensor, or x otherwise."""
|
||||
return (x._type_spec # pylint: disable=protected-access
|
||||
if isinstance(x, composite_tensor.CompositeTensor) else x)
|
||||
|
||||
|
||||
def _get_defun_inputs(args, names, structure, flat_shapes=None):
|
||||
"""Maps python function args to graph-construction inputs.
|
||||
|
||||
@ -1126,6 +1132,12 @@ def _get_defun_inputs(args, names, structure, flat_shapes=None):
|
||||
flat_shapes))
|
||||
shapes_iter = iter(flat_shapes)
|
||||
for arg_value, name in zip(args, names):
|
||||
|
||||
# Replace any composite tensors with their TypeSpecs. This is important
|
||||
# for ensuring that shape information that's not preserved by the TypeSpec
|
||||
# (such as the number of values in a SparseTensor) gets properly masked.
|
||||
arg_value = nest.map_structure(_get_composite_tensor_spec, arg_value)
|
||||
|
||||
flattened = nest.flatten(arg_value, expand_composites=True)
|
||||
tensor_specs = [
|
||||
arg for arg in flattened if isinstance(arg, tensor_spec.TensorSpec)
|
||||
|
@ -324,6 +324,8 @@ class TypeSpec(object):
|
||||
])
|
||||
if isinstance(value, tuple):
|
||||
return tuple([self.__make_cmp_key(v) for v in value])
|
||||
if isinstance(value, list):
|
||||
return (list, tuple([self.__make_cmp_key(v) for v in value]))
|
||||
if isinstance(value, tensor_shape.TensorShape):
|
||||
if value.ndims is None:
|
||||
# Note: we include a type object in the tuple, to ensure we can't get
|
||||
@ -349,7 +351,7 @@ class TypeSpec(object):
|
||||
"""Returns true if the given type serializations compatible."""
|
||||
if type(a) is not type(b):
|
||||
return False
|
||||
if isinstance(a, tuple):
|
||||
if isinstance(a, (list, tuple)):
|
||||
return (len(a) == len(b) and
|
||||
all(TypeSpec.__is_compatible(x, y) for (x, y) in zip(a, b)))
|
||||
if isinstance(a, dict):
|
||||
@ -390,7 +392,7 @@ class TypeSpec(object):
|
||||
"""
|
||||
if type(a) is not type(b):
|
||||
raise ValueError("Types are not compatible: %r vs %r" % (a, b))
|
||||
if isinstance(a, tuple):
|
||||
if isinstance(a, (list, tuple)):
|
||||
if len(a) != len(b):
|
||||
raise ValueError("Types are not compatible: %r vs %r" % (a, b))
|
||||
return tuple(TypeSpec.__most_specific_compatible_type_serialization(x, y)
|
||||
|
Loading…
Reference in New Issue
Block a user