Fix tf.function's shape relaxation check.
TensorShape.is_compatible_with is symmetric; we need an asymmetric test here. PiperOrigin-RevId: 236390944
This commit is contained in:
parent
05ccaff25c
commit
4d100c1d6b
@ -84,47 +84,80 @@ def _flat_shape_list(*params):
|
|||||||
for x in nest.flatten(params)]
|
for x in nest.flatten(params)]
|
||||||
|
|
||||||
|
|
||||||
def _compatible_shapes(flat_x, flat_y):
|
def _shape_less_specific_than(relaxed, to_check):
|
||||||
"""Check if lists of TensorShapes contain compatible shapes.
|
"""Checks if `relaxed` is less specific than `to_check`.
|
||||||
|
|
||||||
|
This is an asymmetric check, unlike `TensorShape.is_compatible_with`. If
|
||||||
|
`to_check` has a dimension with an undefined shape, `relaxed` must also have
|
||||||
|
an undefined shape for that dimension.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
flat_x: List of TensorShape or None.
|
relaxed: A `TensorShape` to check against.
|
||||||
flat_y: List of TensorShape or None.
|
to_check: A second `TensorShape`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if `to_check` represents a set of shapes which is a subset of
|
||||||
|
`relaxed`'s shapes and False otherwise.
|
||||||
|
"""
|
||||||
|
if to_check.dims is not None and relaxed.dims is not None:
|
||||||
|
if to_check.rank != relaxed.rank:
|
||||||
|
return False
|
||||||
|
for check_dim, relaxed_dim in zip(to_check.dims, relaxed.dims):
|
||||||
|
if check_dim.value is None and relaxed_dim.value is not None:
|
||||||
|
return False
|
||||||
|
if not relaxed_dim.is_compatible_with(check_dim):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _compatible_shapes(flat_relaxed, flat_to_check):
|
||||||
|
"""Check if lists of TensorShapes contain compatible shapes.
|
||||||
|
|
||||||
|
Checks that each `flat_relaxed` shape covers a superset of the shapes of the
|
||||||
|
corresponding `flat_to_check` shape.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
flat_relaxed: List of TensorShape or None.
|
||||||
|
flat_to_check: List of TensorShape or None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A python bool.
|
A python bool.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: if `len(flat_x) != len(flat_y)`.
|
RuntimeError:
|
||||||
RuntimeError: if `flat_x[i] is None != flat_y[i] is None` for any `i`.
|
if `len(flat_relaxed) != len(flat_to_check)`.
|
||||||
|
RuntimeError:
|
||||||
|
if `flat_relaxed[i] is None != flat_to_check[i] is None` for any `i`.
|
||||||
"""
|
"""
|
||||||
if len(flat_x) != len(flat_y):
|
|
||||||
|
if len(flat_relaxed) != len(flat_to_check):
|
||||||
raise RuntimeError("Expected shape lists of identical lengths, but saw: "
|
raise RuntimeError("Expected shape lists of identical lengths, but saw: "
|
||||||
"%s and %s" % (flat_x, flat_y))
|
"%s and %s" % (flat_relaxed, flat_to_check))
|
||||||
def is_compatible(x, y):
|
def is_compatible(relaxed, to_check):
|
||||||
"""Internal help function.
|
"""Internal help function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: TensorShape or None.
|
relaxed: TensorShape or None.
|
||||||
y: TensorShape or None.
|
to_check: TensorShape or None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Python bool.
|
Python bool.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If `x is None != y is None`.
|
RuntimeError: If `relaxed is None != to_check is None`.
|
||||||
"""
|
"""
|
||||||
# If both x and y are None, there is no shape to compare. Otherwise check
|
# If both x and y are None, there is no shape to compare. Otherwise check
|
||||||
# if they are compatible with each other. Either way, both input signatures
|
# if they are compatible with each other. Either way, both input signatures
|
||||||
# must have have Tensors in the same entries. If not, raise an assertion
|
# must have have Tensors in the same entries. If not, raise an assertion
|
||||||
# error.
|
# error.
|
||||||
if x is None != y is None:
|
if relaxed is None != to_check is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Expected signature type matches between flattened input shapes "
|
"Expected signature type matches between flattened input shapes "
|
||||||
"%s and %s; but saw that (%s is None) != (%s is None)"
|
"%s and %s; but saw that (%s is None) != (%s is None)"
|
||||||
% (flat_x, flat_y, x, y))
|
% (flat_relaxed, flat_to_check, relaxed, to_check))
|
||||||
return x is None or x.is_compatible_with(y)
|
return relaxed is None or _shape_less_specific_than(relaxed, to_check)
|
||||||
return all(is_compatible(x, y) for x, y in zip(flat_x, flat_y))
|
return all(is_compatible(relaxed, to_check)
|
||||||
|
for relaxed, to_check in zip(flat_relaxed, flat_to_check))
|
||||||
|
|
||||||
|
|
||||||
def _common_shape(x, y):
|
def _common_shape(x, y):
|
||||||
@ -1558,7 +1591,8 @@ class Function(object):
|
|||||||
rank_only_cache_key, None)
|
rank_only_cache_key, None)
|
||||||
|
|
||||||
if (relaxed_arg_function is not None
|
if (relaxed_arg_function is not None
|
||||||
and _compatible_shapes(relaxed_arg_shapes, arg_shapes)):
|
and _compatible_shapes(flat_relaxed=relaxed_arg_shapes,
|
||||||
|
flat_to_check=arg_shapes)):
|
||||||
return relaxed_arg_function, args, kwargs
|
return relaxed_arg_function, args, kwargs
|
||||||
|
|
||||||
if relaxed_arg_shapes is None:
|
if relaxed_arg_shapes is None:
|
||||||
|
@ -2535,6 +2535,24 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
x = func()
|
x = func()
|
||||||
self.assertRegexpMatches(x.device, 'GPU')
|
self.assertRegexpMatches(x.device, 'GPU')
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testShapeCaching(self):
|
||||||
|
|
||||||
|
@function.defun
|
||||||
|
def func(x):
|
||||||
|
return array_ops.shape(x)
|
||||||
|
|
||||||
|
@function.defun(
|
||||||
|
input_signature=[tensor_spec.TensorSpec([None, None], dtypes.float32)])
|
||||||
|
def calls_func(x):
|
||||||
|
return func(x)
|
||||||
|
|
||||||
|
self.assertAllEqual([1, 1], self.evaluate(func(array_ops.zeros([1, 1]))))
|
||||||
|
self.assertAllEqual([2, 2], self.evaluate(func(array_ops.zeros([2, 2]))))
|
||||||
|
self.assertAllEqual(
|
||||||
|
[3, 3],
|
||||||
|
self.evaluate(calls_func(array_ops.zeros([3, 3]))))
|
||||||
|
|
||||||
|
|
||||||
class MultiDeviceTest(test.TestCase, parameterized.TestCase):
|
class MultiDeviceTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user