Fix tf.function's shape relaxation check.
TensorShape.is_compatible_with is symmetric; we need an asymmetric test here. PiperOrigin-RevId: 236390944
This commit is contained in:
parent
05ccaff25c
commit
4d100c1d6b
@ -84,47 +84,80 @@ def _flat_shape_list(*params):
|
||||
for x in nest.flatten(params)]
|
||||
|
||||
|
||||
def _compatible_shapes(flat_x, flat_y):
|
||||
"""Check if lists of TensorShapes contain compatible shapes.
|
||||
def _shape_less_specific_than(relaxed, to_check):
|
||||
"""Checks if `relaxed` is less specific than `to_check`.
|
||||
|
||||
This is an asymmetric check, unlike `TensorShape.is_compatible_with`. If
|
||||
`to_check` has a dimension with an undefined shape, `relaxed` must also have
|
||||
an undefined shape for that dimension.
|
||||
|
||||
Args:
|
||||
flat_x: List of TensorShape or None.
|
||||
flat_y: List of TensorShape or None.
|
||||
relaxed: A `TensorShape` to check against.
|
||||
to_check: A second `TensorShape`.
|
||||
|
||||
Returns:
|
||||
True if `to_check` represents a set of shapes which is a subset of
|
||||
`relaxed`'s shapes and False otherwise.
|
||||
"""
|
||||
if to_check.dims is not None and relaxed.dims is not None:
|
||||
if to_check.rank != relaxed.rank:
|
||||
return False
|
||||
for check_dim, relaxed_dim in zip(to_check.dims, relaxed.dims):
|
||||
if check_dim.value is None and relaxed_dim.value is not None:
|
||||
return False
|
||||
if not relaxed_dim.is_compatible_with(check_dim):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _compatible_shapes(flat_relaxed, flat_to_check):
|
||||
"""Check if lists of TensorShapes contain compatible shapes.
|
||||
|
||||
Checks that each `flat_relaxed` shape covers a superset of the shapes of the
|
||||
corresponding `flat_to_check` shape.
|
||||
|
||||
Args:
|
||||
flat_relaxed: List of TensorShape or None.
|
||||
flat_to_check: List of TensorShape or None.
|
||||
|
||||
Returns:
|
||||
A python bool.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if `len(flat_x) != len(flat_y)`.
|
||||
RuntimeError: if `flat_x[i] is None != flat_y[i] is None` for any `i`.
|
||||
RuntimeError:
|
||||
if `len(flat_relaxed) != len(flat_to_check)`.
|
||||
RuntimeError:
|
||||
if `flat_relaxed[i] is None != flat_to_check[i] is None` for any `i`.
|
||||
"""
|
||||
if len(flat_x) != len(flat_y):
|
||||
|
||||
if len(flat_relaxed) != len(flat_to_check):
|
||||
raise RuntimeError("Expected shape lists of identical lengths, but saw: "
|
||||
"%s and %s" % (flat_x, flat_y))
|
||||
def is_compatible(x, y):
|
||||
"%s and %s" % (flat_relaxed, flat_to_check))
|
||||
def is_compatible(relaxed, to_check):
|
||||
"""Internal help function.
|
||||
|
||||
Args:
|
||||
x: TensorShape or None.
|
||||
y: TensorShape or None.
|
||||
relaxed: TensorShape or None.
|
||||
to_check: TensorShape or None.
|
||||
|
||||
Returns:
|
||||
Python bool.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `x is None != y is None`.
|
||||
RuntimeError: If `relaxed is None != to_check is None`.
|
||||
"""
|
||||
# If both x and y are None, there is no shape to compare. Otherwise check
|
||||
# if they are compatible with each other. Either way, both input signatures
|
||||
# must have have Tensors in the same entries. If not, raise an assertion
|
||||
# error.
|
||||
if x is None != y is None:
|
||||
if relaxed is None != to_check is None:
|
||||
raise RuntimeError(
|
||||
"Expected signature type matches between flattened input shapes "
|
||||
"%s and %s; but saw that (%s is None) != (%s is None)"
|
||||
% (flat_x, flat_y, x, y))
|
||||
return x is None or x.is_compatible_with(y)
|
||||
return all(is_compatible(x, y) for x, y in zip(flat_x, flat_y))
|
||||
% (flat_relaxed, flat_to_check, relaxed, to_check))
|
||||
return relaxed is None or _shape_less_specific_than(relaxed, to_check)
|
||||
return all(is_compatible(relaxed, to_check)
|
||||
for relaxed, to_check in zip(flat_relaxed, flat_to_check))
|
||||
|
||||
|
||||
def _common_shape(x, y):
|
||||
@ -1558,7 +1591,8 @@ class Function(object):
|
||||
rank_only_cache_key, None)
|
||||
|
||||
if (relaxed_arg_function is not None
|
||||
and _compatible_shapes(relaxed_arg_shapes, arg_shapes)):
|
||||
and _compatible_shapes(flat_relaxed=relaxed_arg_shapes,
|
||||
flat_to_check=arg_shapes)):
|
||||
return relaxed_arg_function, args, kwargs
|
||||
|
||||
if relaxed_arg_shapes is None:
|
||||
|
@ -2535,6 +2535,24 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
x = func()
|
||||
self.assertRegexpMatches(x.device, 'GPU')
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testShapeCaching(self):
|
||||
|
||||
@function.defun
|
||||
def func(x):
|
||||
return array_ops.shape(x)
|
||||
|
||||
@function.defun(
|
||||
input_signature=[tensor_spec.TensorSpec([None, None], dtypes.float32)])
|
||||
def calls_func(x):
|
||||
return func(x)
|
||||
|
||||
self.assertAllEqual([1, 1], self.evaluate(func(array_ops.zeros([1, 1]))))
|
||||
self.assertAllEqual([2, 2], self.evaluate(func(array_ops.zeros([2, 2]))))
|
||||
self.assertAllEqual(
|
||||
[3, 3],
|
||||
self.evaluate(calls_func(array_ops.zeros([3, 3]))))
|
||||
|
||||
|
||||
class MultiDeviceTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user