diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index d1127fd0f45..199671ac6a0 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -84,47 +84,80 @@ def _flat_shape_list(*params): for x in nest.flatten(params)] -def _compatible_shapes(flat_x, flat_y): - """Check if lists of TensorShapes contain compatible shapes. +def _shape_less_specific_than(relaxed, to_check): + """Checks if `relaxed` is less specific than `to_check`. + + This is an asymmetric check, unlike `TensorShape.is_compatible_with`. If + `to_check` has a dimension with an undefined shape, `relaxed` must also have + an undefined shape for that dimension. Args: - flat_x: List of TensorShape or None. - flat_y: List of TensorShape or None. + relaxed: A `TensorShape` to check against. + to_check: A second `TensorShape`. + + Returns: + True if `to_check` represents a set of shapes which is a subset of + `relaxed`'s shapes and False otherwise. + """ + if to_check.dims is not None and relaxed.dims is not None: + if to_check.rank != relaxed.rank: + return False + for check_dim, relaxed_dim in zip(to_check.dims, relaxed.dims): + if check_dim.value is None and relaxed_dim.value is not None: + return False + if not relaxed_dim.is_compatible_with(check_dim): + return False + return True + + +def _compatible_shapes(flat_relaxed, flat_to_check): + """Check if lists of TensorShapes contain compatible shapes. + + Checks that each `flat_relaxed` shape covers a superset of the shapes of the + corresponding `flat_to_check` shape. + + Args: + flat_relaxed: List of TensorShape or None. + flat_to_check: List of TensorShape or None. Returns: A python bool. Raises: - RuntimeError: if `len(flat_x) != len(flat_y)`. - RuntimeError: if `flat_x[i] is None != flat_y[i] is None` for any `i`. + RuntimeError: + if `len(flat_relaxed) != len(flat_to_check)`. + RuntimeError: + if `flat_relaxed[i] is None != flat_to_check[i] is None` for any `i`. """ - if len(flat_x) != len(flat_y): + + if len(flat_relaxed) != len(flat_to_check): raise RuntimeError("Expected shape lists of identical lengths, but saw: " - "%s and %s" % (flat_x, flat_y)) - def is_compatible(x, y): + "%s and %s" % (flat_relaxed, flat_to_check)) + def is_compatible(relaxed, to_check): """Internal help function. Args: - x: TensorShape or None. - y: TensorShape or None. + relaxed: TensorShape or None. + to_check: TensorShape or None. Returns: Python bool. Raises: - RuntimeError: If `x is None != y is None`. + RuntimeError: If `relaxed is None != to_check is None`. """ # If both x and y are None, there is no shape to compare. Otherwise check # if they are compatible with each other. Either way, both input signatures # must have have Tensors in the same entries. If not, raise an assertion # error. - if x is None != y is None: + if relaxed is None != to_check is None: raise RuntimeError( "Expected signature type matches between flattened input shapes " "%s and %s; but saw that (%s is None) != (%s is None)" - % (flat_x, flat_y, x, y)) - return x is None or x.is_compatible_with(y) - return all(is_compatible(x, y) for x, y in zip(flat_x, flat_y)) + % (flat_relaxed, flat_to_check, relaxed, to_check)) + return relaxed is None or _shape_less_specific_than(relaxed, to_check) + return all(is_compatible(relaxed, to_check) + for relaxed, to_check in zip(flat_relaxed, flat_to_check)) def _common_shape(x, y): @@ -1558,7 +1591,8 @@ class Function(object): rank_only_cache_key, None) if (relaxed_arg_function is not None - and _compatible_shapes(relaxed_arg_shapes, arg_shapes)): + and _compatible_shapes(flat_relaxed=relaxed_arg_shapes, + flat_to_check=arg_shapes)): return relaxed_arg_function, args, kwargs if relaxed_arg_shapes is None: diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index d86c1001e6c..3b80064db20 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -2535,6 +2535,24 @@ class FunctionTest(test.TestCase, parameterized.TestCase): x = func() self.assertRegexpMatches(x.device, 'GPU') + @test_util.run_in_graph_and_eager_modes + def testShapeCaching(self): + + @function.defun + def func(x): + return array_ops.shape(x) + + @function.defun( + input_signature=[tensor_spec.TensorSpec([None, None], dtypes.float32)]) + def calls_func(x): + return func(x) + + self.assertAllEqual([1, 1], self.evaluate(func(array_ops.zeros([1, 1])))) + self.assertAllEqual([2, 2], self.evaluate(func(array_ops.zeros([2, 2])))) + self.assertAllEqual( + [3, 3], + self.evaluate(calls_func(array_ops.zeros([3, 3])))) + class MultiDeviceTest(test.TestCase, parameterized.TestCase):