Fix tf.function's shape relaxation check.

TensorShape.is_compatible_with is symmetric; we need an asymmetric test here.

PiperOrigin-RevId: 236390944
Allen Lavoie 2019-03-01 15:37:09 -08:00 committed by TensorFlower Gardener
parent 05ccaff25c
commit 4d100c1d6b
2 changed files with 69 additions and 17 deletions
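
Loosely illustrating the point in the commit message (this sketch is not part of the commit): `TensorShape.is_compatible_with` answers "could these two shapes describe the same tensor?", which holds in both directions, so it cannot say whether a relaxed cached signature covers every shape the new arguments might take. A minimal demonstration, assuming a recent TensorFlow install:

import tensorflow as tf

relaxed = tf.TensorShape([None, 2])   # a relaxed (cached) signature
to_check = tf.TensorShape([3, 2])     # shapes of the incoming arguments

# `is_compatible_with` holds in both directions, so it cannot distinguish
# "the relaxed signature covers the arguments" from the reverse.
assert relaxed.is_compatible_with(to_check)
assert to_check.is_compatible_with(relaxed)

# The failure mode the fix targets: a cached shape with defined dimensions,
# e.g. [3, 2], also looks "compatible" with arguments of shape [None, 2],
# even though a trace specialized to [3, 2] may have baked that shape into
# its graph and cannot serve arguments whose dimensions are unknown.
assert tf.TensorShape([3, 2]).is_compatible_with(tf.TensorShape([None, 2]))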

@@ -84,47 +84,80 @@ def _flat_shape_list(*params):
           for x in nest.flatten(params)]
 
 
-def _compatible_shapes(flat_x, flat_y):
-  """Check if lists of TensorShapes contain compatible shapes.
+def _shape_less_specific_than(relaxed, to_check):
+  """Checks if `relaxed` is less specific than `to_check`.
+
+  This is an asymmetric check, unlike `TensorShape.is_compatible_with`. If
+  `to_check` has a dimension with an undefined shape, `relaxed` must also have
+  an undefined shape for that dimension.
 
   Args:
-    flat_x: List of TensorShape or None.
-    flat_y: List of TensorShape or None.
+    relaxed: A `TensorShape` to check against.
+    to_check: A second `TensorShape`.
+
+  Returns:
+    True if `to_check` represents a set of shapes which is a subset of
+    `relaxed`'s shapes and False otherwise.
+  """
+  if to_check.dims is not None and relaxed.dims is not None:
+    if to_check.rank != relaxed.rank:
+      return False
+    for check_dim, relaxed_dim in zip(to_check.dims, relaxed.dims):
+      if check_dim.value is None and relaxed_dim.value is not None:
+        return False
+      if not relaxed_dim.is_compatible_with(check_dim):
+        return False
+  return True
+
+
+def _compatible_shapes(flat_relaxed, flat_to_check):
+  """Check if lists of TensorShapes contain compatible shapes.
+
+  Checks that each `flat_relaxed` shape covers a superset of the shapes of the
+  corresponding `flat_to_check` shape.
+
+  Args:
+    flat_relaxed: List of TensorShape or None.
+    flat_to_check: List of TensorShape or None.
 
   Returns:
     A python bool.
 
   Raises:
-    RuntimeError: if `len(flat_x) != len(flat_y)`.
-    RuntimeError: if `flat_x[i] is None != flat_y[i] is None` for any `i`.
+    RuntimeError:
+      if `len(flat_relaxed) != len(flat_to_check)`.
+    RuntimeError:
+      if `flat_relaxed[i] is None != flat_to_check[i] is None` for any `i`.
   """
-  if len(flat_x) != len(flat_y):
+  if len(flat_relaxed) != len(flat_to_check):
     raise RuntimeError("Expected shape lists of identical lengths, but saw: "
-                       "%s and %s" % (flat_x, flat_y))
+                       "%s and %s" % (flat_relaxed, flat_to_check))
 
-  def is_compatible(x, y):
+  def is_compatible(relaxed, to_check):
     """Internal help function.
 
     Args:
-      x: TensorShape or None.
-      y: TensorShape or None.
+      relaxed: TensorShape or None.
+      to_check: TensorShape or None.
 
     Returns:
       Python bool.
 
     Raises:
-      RuntimeError: If `x is None != y is None`.
+      RuntimeError: If `relaxed is None != to_check is None`.
     """
     # If both x and y are None, there is no shape to compare.  Otherwise check
     # if they are compatible with each other. Either way, both input signatures
     # must have have Tensors in the same entries. If not, raise an assertion
     # error.
-    if x is None != y is None:
+    if relaxed is None != to_check is None:
       raise RuntimeError(
           "Expected signature type matches between flattened input shapes "
           "%s and %s; but saw that (%s is None) != (%s is None)"
-          % (flat_x, flat_y, x, y))
-    return x is None or x.is_compatible_with(y)
-  return all(is_compatible(x, y) for x, y in zip(flat_x, flat_y))
+          % (flat_relaxed, flat_to_check, relaxed, to_check))
+    return relaxed is None or _shape_less_specific_than(relaxed, to_check)
+  return all(is_compatible(relaxed, to_check)
+             for relaxed, to_check in zip(flat_relaxed, flat_to_check))
 
 
 def _common_shape(x, y):
@@ -1558,7 +1591,8 @@ class Function(object):
           rank_only_cache_key, None)
       if (relaxed_arg_function is not None
-          and _compatible_shapes(relaxed_arg_shapes, arg_shapes)):
+          and _compatible_shapes(flat_relaxed=relaxed_arg_shapes,
+                                 flat_to_check=arg_shapes)):
         return relaxed_arg_function, args, kwargs
       if relaxed_arg_shapes is None:
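
For reference, a hedged standalone sketch of the asymmetric check introduced above, using only the public `tf.TensorShape` API (the real helper is a private TensorFlow internal and may differ between versions); the two trailing calls mirror the order-sensitive keyword arguments now passed at the cache-lookup site:

import tensorflow as tf

def shape_less_specific_than(relaxed, to_check):
  """True if every shape matching `to_check` also matches `relaxed`."""
  if to_check.dims is not None and relaxed.dims is not None:
    if to_check.rank != relaxed.rank:
      return False
    for check_dim, relaxed_dim in zip(to_check.dims, relaxed.dims):
      # A defined relaxed dimension cannot cover an undefined requested one.
      if check_dim.value is None and relaxed_dim.value is not None:
        return False
      if not relaxed_dim.is_compatible_with(check_dim):
        return False
  return True

# Order matters: the relaxed (cached) shapes go first, the incoming argument
# shapes second.
print(shape_less_specific_than(tf.TensorShape([None, 2]),
                               tf.TensorShape([3, 2])))    # True: reuse is safe
print(shape_less_specific_than(tf.TensorShape([3, 2]),
                               tf.TensorShape([None, 2]))) # False: must retrace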

@@ -2535,6 +2535,24 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
       x = func()
     self.assertRegexpMatches(x.device, 'GPU')
 
+  @test_util.run_in_graph_and_eager_modes
+  def testShapeCaching(self):
+
+    @function.defun
+    def func(x):
+      return array_ops.shape(x)
+
+    @function.defun(
+        input_signature=[tensor_spec.TensorSpec([None, None], dtypes.float32)])
+    def calls_func(x):
+      return func(x)
+
+    self.assertAllEqual([1, 1], self.evaluate(func(array_ops.zeros([1, 1]))))
+    self.assertAllEqual([2, 2], self.evaluate(func(array_ops.zeros([2, 2]))))
+    self.assertAllEqual(
+        [3, 3],
+        self.evaluate(calls_func(array_ops.zeros([3, 3]))))
+
 
 class MultiDeviceTest(test.TestCase, parameterized.TestCase):
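
A rough, hedged translation of the new test into the modern `tf.function` API (not part of the commit; assumes TensorFlow 2.x with eager execution), showing the behavior the fix protects: the relaxed trace of `func` must not be reused for the fully unknown `[None, None]` signature, or `calls_func` could return shapes baked in from an earlier call:

import tensorflow as tf

@tf.function
def func(x):
  return tf.shape(x)

@tf.function(input_signature=[tf.TensorSpec([None, None], tf.float32)])
def calls_func(x):
  return func(x)

print(func(tf.zeros([1, 1])).numpy())        # [1 1]
print(func(tf.zeros([2, 2])).numpy())        # [2 2]
print(calls_func(tf.zeros([3, 3])).numpy())  # [3 3], not a stale [2 2]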