diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index aa059c9e3b0..53dd065f135 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -1463,10 +1463,10 @@ class AssertShapesTest(test.TestCase): def test_raise_static_shape_mismatch(self): x = array_ops.ones([3, 2], name="x") y = array_ops.ones([2, 3], name="y") - shapes = { - x: ("N", "Q"), - y: ("N", "D"), - } + shapes = [ + (x, ("N", "Q")), + (y, ("N", "D")), + ] regex = (r"Specified by tensor .* dimension 0. " r"Tensor .* dimension 0 must have size 3. " r"Received size 2") @@ -1476,10 +1476,10 @@ class AssertShapesTest(test.TestCase): with ops.Graph().as_default(): x = array_ops.placeholder(dtypes.float32, [None, 2], name="x") y = array_ops.placeholder(dtypes.float32, [None, 3], name="y") - shapes = { - x: ("N", "Q"), - y: ("N", "D"), - } + shapes = [ + (x, ("N", "Q")), + (y, ("N", "D")), + ] regex = (r"\[Specified by tensor x.* dimension 0\] " r"\[Tensor y.* dimension\] \[0\] \[must have size\] \[3\]") feed_dict = {x: np.ones([3, 2]), y: np.ones([2, 3])} @@ -1489,10 +1489,10 @@ class AssertShapesTest(test.TestCase): def test_raise_static_shape_explicit_mismatch(self): x = array_ops.ones([3, 2], name="x") y = array_ops.ones([2, 3], name="y") - shapes = { - x: (3, "Q"), - y: (3, "D"), - } + shapes = [ + (x, (3, "Q")), + (y, (3, "D")), + ] regex = (r"Specified explicitly. " r"Tensor .* dimension 0 must have size 3. " r"Received size 2") @@ -1502,21 +1502,21 @@ class AssertShapesTest(test.TestCase): def test_rank_zero_rank_one_size_one_equivalence(self): rank_one_size_one = array_ops.ones([1], name="rank_one_size_one") rank_zero = array_ops.constant(5, name="rank_zero") - check_ops.assert_shapes({ - rank_one_size_one: (), - rank_zero: (), - }) - check_ops.assert_shapes({ - rank_one_size_one: (1,), - rank_zero: (1,), - }) + check_ops.assert_shapes([ + (rank_one_size_one, ()), + (rank_zero, ()), + ]) + check_ops.assert_shapes([ + (rank_one_size_one, (1,)), + (rank_zero, (1,)), + ]) @test_util.run_in_graph_and_eager_modes def test_raise_static_rank_1_size_not_1_mismatch_scalar(self): x = array_ops.constant([2, 2], name="x") - shapes = { - x: (), - } + shapes = [ + (x, ()), + ] regex = (r"Specified explicitly. " r"Tensor .* dimension 0 must have size 1. " r"Received size 2") @@ -1525,9 +1525,9 @@ class AssertShapesTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def test_raise_static_scalar_mismatch_rank_1_size_not_1(self): x = array_ops.constant(2, name="x") - shapes = { - x: (2,), - } + shapes = [ + (x, (2,)), + ] regex = (r"Specified explicitly. " r"Tensor .* dimension 0 must have size 2. " r"Received size 1") @@ -1537,7 +1537,7 @@ class AssertShapesTest(test.TestCase): def test_scalar_implies_size_one(self): scalar = array_ops.constant(5, name="rank_zero") x = array_ops.ones([2, 2], name="x") - shapes = {scalar: ("a",), x: ("a", 2)} + shapes = [(scalar, ("a",)), (x, ("a", 2))] regex = (r"Specified by tensor .* dimension 0. " r"Tensor .* dimension 0 must have size 1. " r"Received size 2") @@ -1546,7 +1546,7 @@ class AssertShapesTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def test_raise_not_iterable(self): x = array_ops.constant([1, 2], name="x") - shapes = {x: 2} + shapes = [(x, 2)] regex = (r"Tensor .*. " r"Specified shape must be an iterable. " r"An iterable has the attribute `__iter__` or `__getitem__`. " @@ -1557,10 +1557,10 @@ class AssertShapesTest(test.TestCase): with ops.Graph().as_default(): x = array_ops.placeholder(dtypes.float32, [None, 2], name="xa") y = array_ops.placeholder(dtypes.float32, [None, 3], name="y") - shapes = { - x: (3, "Q"), - y: (3, "D"), - } + shapes = [ + (x, (3, "Q")), + (y, (3, "D")), + ] regex = (r"\[Specified explicitly\] " r"\[Tensor y.* dimension\] \[0\] \[must have size\] \[3\]") feed_dict = {x: np.ones([3, 2]), y: np.ones([2, 3])} @@ -1569,7 +1569,7 @@ class AssertShapesTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def test_no_op_when_specified_as_unknown(self): x = array_ops.constant([1, 1], name="x") - assertion = check_ops.assert_shapes({x: None}) + assertion = check_ops.assert_shapes([(x, None)]) with ops.control_dependencies([assertion]): out = array_ops.identity(x) self.evaluate(out) @@ -1616,7 +1616,7 @@ class AssertShapesTest(test.TestCase): for shape in rank_two_shapes: regex = r"Tensor .* must have rank\] \[2\]" self.raises_dynamic_error( - shapes={x: shape}, regex=regex, feed_dict={x: x_value}) + shapes=[(x, shape)], regex=regex, feed_dict={x: x_value}) @test_util.run_in_graph_and_eager_modes def test_correctly_matching(self): @@ -1626,25 +1626,25 @@ class AssertShapesTest(test.TestCase): x = array_ops.ones([1, 2, 3], name="x") y = array_ops.ones([3, 1, 2], name="y") z = array_ops.ones([2, 3, 1], name="z") - assertion = check_ops.assert_shapes({ - x: ("a", "b", "c"), - y: ("c", "a", "b"), - z: ("b", "c", "a"), - v: ("a", "b"), - w: ("c",), - u: "a" - }) + assertion = check_ops.assert_shapes([ + (x, ("a", "b", "c")), + (y, ("c", "a", "b")), + (z, ("b", "c", "a")), + (v, ("a", "b")), + (w, ("c",)), + (u, "a") + ]) with ops.control_dependencies([assertion]): out = array_ops.identity(x) self.evaluate(out) - assertion = check_ops.assert_shapes({ - x: (1, "b", "c"), - y: ("c", "a", 2), - z: ("b", 3, "a"), - v: ("a", 2), - w: (3,), - u: () - }) + assertion = check_ops.assert_shapes([ + (x, (1, "b", "c")), + (y, ("c", "a", 2)), + (z, ("b", 3, "a")), + (v, ("a", 2)), + (w, (3,)), + (u, ()) + ]) with ops.control_dependencies([assertion]): out = array_ops.identity(x) self.evaluate(out) @@ -1653,10 +1653,10 @@ class AssertShapesTest(test.TestCase): def test_variable_length_symbols(self): x = array_ops.ones([4, 1], name="x") y = array_ops.ones([4, 2], name="y") - assertion = check_ops.assert_shapes({ - x: ("num_observations", "input_dim"), - y: ("num_observations", "output_dim"), - }) + assertion = check_ops.assert_shapes([ + (x, ("num_observations", "input_dim")), + (y, ("num_observations", "output_dim")), + ]) with ops.control_dependencies([assertion]): out = array_ops.identity(x) self.evaluate(out) @@ -1665,22 +1665,22 @@ class AssertShapesTest(test.TestCase): def test_raise_implicit_mismatch_using_iterable_alternatives(self): x = array_ops.ones([2, 2], name="x") y = array_ops.ones([1, 3], name="y") - styles = [{ - x: ("A", "B"), - y: ("A", "C"), - }, { - x: "AB", - y: "AC" - }, { - x: ["A", "B"], - y: ["A", "C"], - }, { - x: np.array(["A", "B"]), - y: np.array(["A", "C"]) - }, { - x: ("A", "B"), - y: "AC" - }] + styles = [[ + (x, ("A", "B")), + (y, ("A", "C")), + ], [ + (x, "AB"), + (y, "AC") + ], [ + (x, ["A", "B"]), + (y, ["A", "C"]), + ], [ + (x, np.array(["A", "B"])), + (y, np.array(["A", "C"])) + ], [ + (x, ("A", "B")), + (y, "AC") + ]] for shapes in styles: self.raises_static_error( shapes=shapes, @@ -1692,22 +1692,22 @@ class AssertShapesTest(test.TestCase): def test_raise_explicit_mismatch_using_iterable_alternatives(self): x = array_ops.ones([2, 2], name="x") y = array_ops.ones([1, 3], name="y") - styles = [{ - x: (2, 2), - y: (2, 3), - }, { - x: "22", - y: "23" - }, { - x: [2, 2], - y: [2, 3], - }, { - x: np.array([2, 2]), - y: np.array([2, 3]) - }, { - x: (2, 2), - y: "23" - }] + styles = [[ + (x, (2, 2)), + (y, (2, 3)), + ], [ + (x, "22"), + (y, "23") + ], [ + (x, [2, 2]), + (y, [2, 3]), + ], [ + (x, np.array([2, 2])), + (y, np.array([2, 3])) + ], [ + (x, (2, 2)), + (y, "23") + ]] for shapes in styles: self.raises_static_error( shapes=shapes, @@ -1720,16 +1720,16 @@ class AssertShapesTest(test.TestCase): x = array_ops.ones([1, 2, 3], name="x") y = array_ops.ones([2, 1], name="y") a1 = check_ops.assert_shapes({ - x: (None, 2, None), - y: (None, 1), + (x, (None, 2, None)), + (y, (None, 1)), }) a2 = check_ops.assert_shapes({ - x: (".", 2, "."), - y: (".", 1), + (x, (".", 2, ".")), + (y, (".", 1)), }) a3 = check_ops.assert_shapes({ - x: ".2.", - y: ".1", + (x, ".2."), + (y, ".1"), }) with ops.control_dependencies([a1, a2, a3]): out = array_ops.identity(x) @@ -1739,14 +1739,14 @@ class AssertShapesTest(test.TestCase): def test_raise_static_shape_explicit_mismatch_innermost_dims(self): x = array_ops.ones([3, 2], name="x") y = array_ops.ones([2, 3], name="y") - s1 = { - x: (3, "Q"), - y: (Ellipsis, 3, "D"), - } - s2 = { - x: "3Q", - y: "*3D", - } + s1 = [ + (x, (3, "Q")), + (y, (Ellipsis, 3, "D")), + ] + s2 = [ + (x, "3Q"), + (y, "*3D"), + ] regex = (r"Specified explicitly. " r"Tensor .* dimension -2 must have size 3. " r"Received size 2") @@ -1757,14 +1757,14 @@ class AssertShapesTest(test.TestCase): def test_correctly_matching_innermost_dims(self): x = array_ops.ones([1, 2, 3, 2], name="x") y = array_ops.ones([2, 3, 3], name="y") - a1 = check_ops.assert_shapes({ - x: (Ellipsis, "N", "Q"), - y: (Ellipsis, "N", "D"), - }) - a2 = check_ops.assert_shapes({ - x: "*NQ", - y: "*ND", - }) + a1 = check_ops.assert_shapes([ + (x, (Ellipsis, "N", "Q")), + (y, (Ellipsis, "N", "D")), + ]) + a2 = check_ops.assert_shapes([ + (x, "*NQ"), + (y, "*ND"), + ]) with ops.control_dependencies([a1, a2]): out = array_ops.identity(x) self.evaluate(out) @@ -1772,12 +1772,12 @@ class AssertShapesTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def test_raise_variable_num_outer_dims_prefix_misuse(self): x = array_ops.ones([1, 2], name="x") - s1 = { - x: ("N", Ellipsis, "Q"), - } - s2 = { - x: "N*Q", - } + s1 = [ + (x, ("N", Ellipsis, "Q")), + ] + s2 = [ + (x, "N*Q"), + ] regex = (r"Tensor .* specified shape index .*. " r"Symbol `...` or `\*` for a variable number of " r"unspecified dimensions is only allowed as the first entry") @@ -1786,7 +1786,7 @@ class AssertShapesTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def test_empty_shapes_dict_no_op(self): - assertion = check_ops.assert_shapes({}) + assertion = check_ops.assert_shapes([]) with ops.control_dependencies([assertion]): out = array_ops.identity(0) self.evaluate(out) diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 2bb1c4b6b61..3997c401dc3 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import numpy as np from tensorflow.python.eager import context @@ -1663,8 +1664,10 @@ def _dimension_sizes(x): def _symbolic_dimension_sizes(symbolic_shape): - if len(symbolic_shape) == 0: + # If len(symbolic_shape) == 0 construct a tuple + if not symbolic_shape: return tuple([1]) + return symbolic_shape @@ -1682,8 +1685,9 @@ def _is_symbol_for_any_size(symbol): return symbol in [None, '.'] -def _is_symbol_for_unspecified_dims(symbol): - return symbol in [Ellipsis, '*'] +_TensorDimSizes = collections.namedtuple( + '_TensorDimSizes', + ['x', 'unspecified_dim', 'actual_sizes', 'symbolic_sizes']) @tf_export('debugging.assert_shapes', v1=[]) @@ -1697,12 +1701,12 @@ def assert_shapes_v2(shapes, data=None, summarize=None, message=None, Example: ```python - tf.assert_shapes({ - x: ('N', 'Q'), - y: ('N', 'D'), - param: ('Q',), - scalar: () - }) + tf.assert_shapes([ + (x: ('N', 'Q')), + (y: ('N', 'D')), + (param: ('Q',)), + (scalar: ()), + ]) ``` If `x`, `y`, `param` or `scalar` does not have a shape that satisfies @@ -1750,10 +1754,10 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None): ```python tf.assert_shapes({ - x: ('N', 'Q'), - y: ('N', 'D'), - param: ('Q',), - scalar: () + (x, ('N', 'Q')), + (y, ('N', 'D')), + (param, ('Q',)), + (scalar, ()) }) ``` @@ -1800,13 +1804,17 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None): Raises: ValueError: If static checks determine any shape constraint is violated. """ + # If the user manages to assemble a dict containing tensors (possible in + # Graph mode only), make sure we still accept that. + if isinstance(shapes, dict): + shapes = shapes.items() + message = message or '' with ops.name_scope(name, 'assert_shapes', [shapes, data]): - # Shape specified as None implies no constraint - shapes = {x: shapes[x] for x in shapes if shapes[x] is not None} - - shapes = {ops.convert_to_tensor(x): shapes[x] for x in shapes} + shape_constraints = [ + (ops.convert_to_tensor(x), s) for x, s in shapes if s is not None + ] executing_eagerly = context.executing_eagerly() @@ -1815,8 +1823,8 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None): return _shape_and_dtype_str(x) return x.name - for x in shapes: - symbolic_shape = shapes[x] + tensor_dim_sizes = [] + for tensor, symbolic_shape in shape_constraints: is_iterable = ( hasattr(symbolic_shape, '__iter__') or hasattr(symbolic_shape, '__getitem__') # For Python 2 compat. @@ -1827,44 +1835,46 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None): 'Tensor %s. Specified shape must be an iterable. ' 'An iterable has the attribute `__iter__` or `__getitem__`. ' 'Received specified shape: %s' % - (message, tensor_name(x), symbolic_shape)) - shapes[x] = tuple(shapes[x]) + (message, tensor_name(tensor), symbolic_shape)) - tensors_specified_innermost = set() - for x in shapes: - symbolic_shape = shapes[x] - for i, symbol in enumerate(symbolic_shape): - if not _is_symbol_for_unspecified_dims(symbol): + # We convert this into a tuple to handle strings, lists and numpy arrays + symbolic_shape_tuple = tuple(symbolic_shape) + + tensors_specified_innermost = False + for i, symbol in enumerate(symbolic_shape_tuple): + if symbol not in [Ellipsis, '*']: continue + if i != 0: raise ValueError( '%s. ' 'Tensor %s specified shape index %d. ' 'Symbol `...` or `*` for a variable number of ' 'unspecified dimensions is only allowed as the first entry' % - (message, tensor_name(x), i)) - tensors_specified_innermost.add(x) + (message, tensor_name(tensor), i)) - actual_sizes_by_tensor = {x: _dimension_sizes(x) for x in shapes} - specified_sizes_by_tensor = { - x: _symbolic_dimension_sizes( - # Ignoring innermost prefix - shapes[x][1:] if x in tensors_specified_innermost else shapes[x]) - for x in shapes - } + tensors_specified_innermost = True + + # Only include the size of the specified dimensions since the 0th symbol + # is either ellipsis or * + tensor_dim_sizes.append( + _TensorDimSizes( + tensor, tensors_specified_innermost, _dimension_sizes(tensor), + _symbolic_dimension_sizes( + symbolic_shape_tuple[1:] + if tensors_specified_innermost else symbolic_shape_tuple))) rank_assertions = [] - for x in shapes.keys(): - symbolic_sizes = specified_sizes_by_tensor[x] - rank = len(symbolic_sizes) + for sizes in tensor_dim_sizes: + rank = len(sizes.symbolic_sizes) rank_zero_or_one = rank in [0, 1] - if x in tensors_specified_innermost: + if sizes.unspecified_dim: if rank_zero_or_one: - # No assertion of rank needed as `x` only need to have rank at least 0. - # See elif rank_zero_or_one case comment. + # No assertion of rank needed as `x` only need to have rank at least + # 0. See elif rank_zero_or_one case comment. continue assertion = assert_rank_at_least( - x=x, + x=sizes.x, rank=rank, data=data, summarize=summarize, @@ -1875,7 +1885,7 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None): # no distinction between the two in terms of rank. # See _dimension_sizes. assertion = assert_rank_in( - x=x, + x=sizes.x, ranks=[0, 1], data=data, summarize=summarize, @@ -1883,7 +1893,7 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None): name=name) else: assertion = assert_rank( - x=x, + x=sizes.x, rank=rank, data=data, summarize=summarize, @@ -1893,19 +1903,15 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None): size_assertions = [] size_specifications = {} - for x in shapes.keys(): - actual_sizes = actual_sizes_by_tensor[x] - symbolic_sizes = specified_sizes_by_tensor[x] - innermost_dims = x in tensors_specified_innermost - - for i, size_symbol in enumerate(symbolic_sizes): + for sizes in tensor_dim_sizes: + for i, size_symbol in enumerate(sizes.symbolic_sizes): if _is_symbol_for_any_size(size_symbol): # Size specified as any implies no constraint continue - if innermost_dims: - tensor_dim = i - len(symbolic_sizes) + if sizes.unspecified_dim: + tensor_dim = i - len(sizes.symbolic_sizes) else: tensor_dim = i @@ -1920,14 +1926,15 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None): 'Specified by tensor %s dimension %d' % (tensor_name(specified_by_y), specified_at_dim)) - actual_size = actual_sizes[tensor_dim] + actual_size = sizes.actual_sizes[tensor_dim] if _has_known_value(actual_size) and _has_known_value(specified_size): if int(actual_size) != int(specified_size): raise ValueError( '%s. %s. Tensor %s dimension %s must have size %d. ' 'Received size %d, shape %s' % - (message, size_check_message, tensor_name(x), tensor_dim, - specified_size, actual_size, x.get_shape())) + (message, size_check_message, tensor_name(sizes.x), + tensor_dim, specified_size, actual_size, + sizes.x.get_shape())) # No dynamic assertion needed continue @@ -1938,15 +1945,15 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None): if data is None: data_ = [ message, size_check_message, - 'Tensor %s dimension' % tensor_name(x), tensor_dim, + 'Tensor %s dimension' % tensor_name(sizes.x), tensor_dim, 'must have size', specified_size, 'Received shape: ', - array_ops.shape(x) + array_ops.shape(sizes.x) ] size_assertions.append( control_flow_ops.Assert(condition, data_, summarize=summarize)) else: - size = actual_sizes[tensor_dim] - size_specifications[size_symbol] = (size, x, tensor_dim) + size = sizes.actual_sizes[tensor_dim] + size_specifications[size_symbol] = (size, sizes.x, tensor_dim) with ops.control_dependencies(rank_assertions): shapes_assertion = control_flow_ops.group(size_assertions)