Modify tf.assert_shapes to take list of pairs

To support the tensor equality changes. We need to assume that tensors
are unhashable. Thus we need change tf.assert_shapes to take a list of
pairs instead of a dict.

PiperOrigin-RevId: 258892939
This commit is contained in:
Gaurav Jain 2019-07-18 19:39:23 -07:00 committed by TensorFlower Gardener
parent 7b8cd8dce9
commit 47e650a119
2 changed files with 181 additions and 174 deletions

View File

@ -1463,10 +1463,10 @@ class AssertShapesTest(test.TestCase):
def test_raise_static_shape_mismatch(self):
x = array_ops.ones([3, 2], name="x")
y = array_ops.ones([2, 3], name="y")
shapes = {
x: ("N", "Q"),
y: ("N", "D"),
}
shapes = [
(x, ("N", "Q")),
(y, ("N", "D")),
]
regex = (r"Specified by tensor .* dimension 0. "
r"Tensor .* dimension 0 must have size 3. "
r"Received size 2")
@ -1476,10 +1476,10 @@ class AssertShapesTest(test.TestCase):
with ops.Graph().as_default():
x = array_ops.placeholder(dtypes.float32, [None, 2], name="x")
y = array_ops.placeholder(dtypes.float32, [None, 3], name="y")
shapes = {
x: ("N", "Q"),
y: ("N", "D"),
}
shapes = [
(x, ("N", "Q")),
(y, ("N", "D")),
]
regex = (r"\[Specified by tensor x.* dimension 0\] "
r"\[Tensor y.* dimension\] \[0\] \[must have size\] \[3\]")
feed_dict = {x: np.ones([3, 2]), y: np.ones([2, 3])}
@ -1489,10 +1489,10 @@ class AssertShapesTest(test.TestCase):
def test_raise_static_shape_explicit_mismatch(self):
x = array_ops.ones([3, 2], name="x")
y = array_ops.ones([2, 3], name="y")
shapes = {
x: (3, "Q"),
y: (3, "D"),
}
shapes = [
(x, (3, "Q")),
(y, (3, "D")),
]
regex = (r"Specified explicitly. "
r"Tensor .* dimension 0 must have size 3. "
r"Received size 2")
@ -1502,21 +1502,21 @@ class AssertShapesTest(test.TestCase):
def test_rank_zero_rank_one_size_one_equivalence(self):
rank_one_size_one = array_ops.ones([1], name="rank_one_size_one")
rank_zero = array_ops.constant(5, name="rank_zero")
check_ops.assert_shapes({
rank_one_size_one: (),
rank_zero: (),
})
check_ops.assert_shapes({
rank_one_size_one: (1,),
rank_zero: (1,),
})
check_ops.assert_shapes([
(rank_one_size_one, ()),
(rank_zero, ()),
])
check_ops.assert_shapes([
(rank_one_size_one, (1,)),
(rank_zero, (1,)),
])
@test_util.run_in_graph_and_eager_modes
def test_raise_static_rank_1_size_not_1_mismatch_scalar(self):
x = array_ops.constant([2, 2], name="x")
shapes = {
x: (),
}
shapes = [
(x, ()),
]
regex = (r"Specified explicitly. "
r"Tensor .* dimension 0 must have size 1. "
r"Received size 2")
@ -1525,9 +1525,9 @@ class AssertShapesTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_raise_static_scalar_mismatch_rank_1_size_not_1(self):
x = array_ops.constant(2, name="x")
shapes = {
x: (2,),
}
shapes = [
(x, (2,)),
]
regex = (r"Specified explicitly. "
r"Tensor .* dimension 0 must have size 2. "
r"Received size 1")
@ -1537,7 +1537,7 @@ class AssertShapesTest(test.TestCase):
def test_scalar_implies_size_one(self):
scalar = array_ops.constant(5, name="rank_zero")
x = array_ops.ones([2, 2], name="x")
shapes = {scalar: ("a",), x: ("a", 2)}
shapes = [(scalar, ("a",)), (x, ("a", 2))]
regex = (r"Specified by tensor .* dimension 0. "
r"Tensor .* dimension 0 must have size 1. "
r"Received size 2")
@ -1546,7 +1546,7 @@ class AssertShapesTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_raise_not_iterable(self):
x = array_ops.constant([1, 2], name="x")
shapes = {x: 2}
shapes = [(x, 2)]
regex = (r"Tensor .*. "
r"Specified shape must be an iterable. "
r"An iterable has the attribute `__iter__` or `__getitem__`. "
@ -1557,10 +1557,10 @@ class AssertShapesTest(test.TestCase):
with ops.Graph().as_default():
x = array_ops.placeholder(dtypes.float32, [None, 2], name="xa")
y = array_ops.placeholder(dtypes.float32, [None, 3], name="y")
shapes = {
x: (3, "Q"),
y: (3, "D"),
}
shapes = [
(x, (3, "Q")),
(y, (3, "D")),
]
regex = (r"\[Specified explicitly\] "
r"\[Tensor y.* dimension\] \[0\] \[must have size\] \[3\]")
feed_dict = {x: np.ones([3, 2]), y: np.ones([2, 3])}
@ -1569,7 +1569,7 @@ class AssertShapesTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_no_op_when_specified_as_unknown(self):
x = array_ops.constant([1, 1], name="x")
assertion = check_ops.assert_shapes({x: None})
assertion = check_ops.assert_shapes([(x, None)])
with ops.control_dependencies([assertion]):
out = array_ops.identity(x)
self.evaluate(out)
@ -1616,7 +1616,7 @@ class AssertShapesTest(test.TestCase):
for shape in rank_two_shapes:
regex = r"Tensor .* must have rank\] \[2\]"
self.raises_dynamic_error(
shapes={x: shape}, regex=regex, feed_dict={x: x_value})
shapes=[(x, shape)], regex=regex, feed_dict={x: x_value})
@test_util.run_in_graph_and_eager_modes
def test_correctly_matching(self):
@ -1626,25 +1626,25 @@ class AssertShapesTest(test.TestCase):
x = array_ops.ones([1, 2, 3], name="x")
y = array_ops.ones([3, 1, 2], name="y")
z = array_ops.ones([2, 3, 1], name="z")
assertion = check_ops.assert_shapes({
x: ("a", "b", "c"),
y: ("c", "a", "b"),
z: ("b", "c", "a"),
v: ("a", "b"),
w: ("c",),
u: "a"
})
assertion = check_ops.assert_shapes([
(x, ("a", "b", "c")),
(y, ("c", "a", "b")),
(z, ("b", "c", "a")),
(v, ("a", "b")),
(w, ("c",)),
(u, "a")
])
with ops.control_dependencies([assertion]):
out = array_ops.identity(x)
self.evaluate(out)
assertion = check_ops.assert_shapes({
x: (1, "b", "c"),
y: ("c", "a", 2),
z: ("b", 3, "a"),
v: ("a", 2),
w: (3,),
u: ()
})
assertion = check_ops.assert_shapes([
(x, (1, "b", "c")),
(y, ("c", "a", 2)),
(z, ("b", 3, "a")),
(v, ("a", 2)),
(w, (3,)),
(u, ())
])
with ops.control_dependencies([assertion]):
out = array_ops.identity(x)
self.evaluate(out)
@ -1653,10 +1653,10 @@ class AssertShapesTest(test.TestCase):
def test_variable_length_symbols(self):
x = array_ops.ones([4, 1], name="x")
y = array_ops.ones([4, 2], name="y")
assertion = check_ops.assert_shapes({
x: ("num_observations", "input_dim"),
y: ("num_observations", "output_dim"),
})
assertion = check_ops.assert_shapes([
(x, ("num_observations", "input_dim")),
(y, ("num_observations", "output_dim")),
])
with ops.control_dependencies([assertion]):
out = array_ops.identity(x)
self.evaluate(out)
@ -1665,22 +1665,22 @@ class AssertShapesTest(test.TestCase):
def test_raise_implicit_mismatch_using_iterable_alternatives(self):
x = array_ops.ones([2, 2], name="x")
y = array_ops.ones([1, 3], name="y")
styles = [{
x: ("A", "B"),
y: ("A", "C"),
}, {
x: "AB",
y: "AC"
}, {
x: ["A", "B"],
y: ["A", "C"],
}, {
x: np.array(["A", "B"]),
y: np.array(["A", "C"])
}, {
x: ("A", "B"),
y: "AC"
}]
styles = [[
(x, ("A", "B")),
(y, ("A", "C")),
], [
(x, "AB"),
(y, "AC")
], [
(x, ["A", "B"]),
(y, ["A", "C"]),
], [
(x, np.array(["A", "B"])),
(y, np.array(["A", "C"]))
], [
(x, ("A", "B")),
(y, "AC")
]]
for shapes in styles:
self.raises_static_error(
shapes=shapes,
@ -1692,22 +1692,22 @@ class AssertShapesTest(test.TestCase):
def test_raise_explicit_mismatch_using_iterable_alternatives(self):
x = array_ops.ones([2, 2], name="x")
y = array_ops.ones([1, 3], name="y")
styles = [{
x: (2, 2),
y: (2, 3),
}, {
x: "22",
y: "23"
}, {
x: [2, 2],
y: [2, 3],
}, {
x: np.array([2, 2]),
y: np.array([2, 3])
}, {
x: (2, 2),
y: "23"
}]
styles = [[
(x, (2, 2)),
(y, (2, 3)),
], [
(x, "22"),
(y, "23")
], [
(x, [2, 2]),
(y, [2, 3]),
], [
(x, np.array([2, 2])),
(y, np.array([2, 3]))
], [
(x, (2, 2)),
(y, "23")
]]
for shapes in styles:
self.raises_static_error(
shapes=shapes,
@ -1720,16 +1720,16 @@ class AssertShapesTest(test.TestCase):
x = array_ops.ones([1, 2, 3], name="x")
y = array_ops.ones([2, 1], name="y")
a1 = check_ops.assert_shapes({
x: (None, 2, None),
y: (None, 1),
(x, (None, 2, None)),
(y, (None, 1)),
})
a2 = check_ops.assert_shapes({
x: (".", 2, "."),
y: (".", 1),
(x, (".", 2, ".")),
(y, (".", 1)),
})
a3 = check_ops.assert_shapes({
x: ".2.",
y: ".1",
(x, ".2."),
(y, ".1"),
})
with ops.control_dependencies([a1, a2, a3]):
out = array_ops.identity(x)
@ -1739,14 +1739,14 @@ class AssertShapesTest(test.TestCase):
def test_raise_static_shape_explicit_mismatch_innermost_dims(self):
x = array_ops.ones([3, 2], name="x")
y = array_ops.ones([2, 3], name="y")
s1 = {
x: (3, "Q"),
y: (Ellipsis, 3, "D"),
}
s2 = {
x: "3Q",
y: "*3D",
}
s1 = [
(x, (3, "Q")),
(y, (Ellipsis, 3, "D")),
]
s2 = [
(x, "3Q"),
(y, "*3D"),
]
regex = (r"Specified explicitly. "
r"Tensor .* dimension -2 must have size 3. "
r"Received size 2")
@ -1757,14 +1757,14 @@ class AssertShapesTest(test.TestCase):
def test_correctly_matching_innermost_dims(self):
x = array_ops.ones([1, 2, 3, 2], name="x")
y = array_ops.ones([2, 3, 3], name="y")
a1 = check_ops.assert_shapes({
x: (Ellipsis, "N", "Q"),
y: (Ellipsis, "N", "D"),
})
a2 = check_ops.assert_shapes({
x: "*NQ",
y: "*ND",
})
a1 = check_ops.assert_shapes([
(x, (Ellipsis, "N", "Q")),
(y, (Ellipsis, "N", "D")),
])
a2 = check_ops.assert_shapes([
(x, "*NQ"),
(y, "*ND"),
])
with ops.control_dependencies([a1, a2]):
out = array_ops.identity(x)
self.evaluate(out)
@ -1772,12 +1772,12 @@ class AssertShapesTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_raise_variable_num_outer_dims_prefix_misuse(self):
x = array_ops.ones([1, 2], name="x")
s1 = {
x: ("N", Ellipsis, "Q"),
}
s2 = {
x: "N*Q",
}
s1 = [
(x, ("N", Ellipsis, "Q")),
]
s2 = [
(x, "N*Q"),
]
regex = (r"Tensor .* specified shape index .*. "
r"Symbol `...` or `\*` for a variable number of "
r"unspecified dimensions is only allowed as the first entry")
@ -1786,7 +1786,7 @@ class AssertShapesTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_empty_shapes_dict_no_op(self):
assertion = check_ops.assert_shapes({})
assertion = check_ops.assert_shapes([])
with ops.control_dependencies([assertion]):
out = array_ops.identity(0)
self.evaluate(out)

View File

@ -19,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
from tensorflow.python.eager import context
@ -1663,8 +1664,10 @@ def _dimension_sizes(x):
def _symbolic_dimension_sizes(symbolic_shape):
if len(symbolic_shape) == 0:
# If len(symbolic_shape) == 0 construct a tuple
if not symbolic_shape:
return tuple([1])
return symbolic_shape
@ -1682,8 +1685,9 @@ def _is_symbol_for_any_size(symbol):
return symbol in [None, '.']
def _is_symbol_for_unspecified_dims(symbol):
return symbol in [Ellipsis, '*']
_TensorDimSizes = collections.namedtuple(
'_TensorDimSizes',
['x', 'unspecified_dim', 'actual_sizes', 'symbolic_sizes'])
@tf_export('debugging.assert_shapes', v1=[])
@ -1697,12 +1701,12 @@ def assert_shapes_v2(shapes, data=None, summarize=None, message=None,
Example:
```python
tf.assert_shapes({
x: ('N', 'Q'),
y: ('N', 'D'),
param: ('Q',),
scalar: ()
})
tf.assert_shapes([
(x: ('N', 'Q')),
(y: ('N', 'D')),
(param: ('Q',)),
(scalar: ()),
])
```
If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
@ -1750,10 +1754,10 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
```python
tf.assert_shapes({
x: ('N', 'Q'),
y: ('N', 'D'),
param: ('Q',),
scalar: ()
(x, ('N', 'Q')),
(y, ('N', 'D')),
(param, ('Q',)),
(scalar, ())
})
```
@ -1800,13 +1804,17 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
Raises:
ValueError: If static checks determine any shape constraint is violated.
"""
# If the user manages to assemble a dict containing tensors (possible in
# Graph mode only), make sure we still accept that.
if isinstance(shapes, dict):
shapes = shapes.items()
message = message or ''
with ops.name_scope(name, 'assert_shapes', [shapes, data]):
# Shape specified as None implies no constraint
shapes = {x: shapes[x] for x in shapes if shapes[x] is not None}
shapes = {ops.convert_to_tensor(x): shapes[x] for x in shapes}
shape_constraints = [
(ops.convert_to_tensor(x), s) for x, s in shapes if s is not None
]
executing_eagerly = context.executing_eagerly()
@ -1815,8 +1823,8 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
return _shape_and_dtype_str(x)
return x.name
for x in shapes:
symbolic_shape = shapes[x]
tensor_dim_sizes = []
for tensor, symbolic_shape in shape_constraints:
is_iterable = (
hasattr(symbolic_shape, '__iter__') or
hasattr(symbolic_shape, '__getitem__') # For Python 2 compat.
@ -1827,44 +1835,46 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
'Tensor %s. Specified shape must be an iterable. '
'An iterable has the attribute `__iter__` or `__getitem__`. '
'Received specified shape: %s' %
(message, tensor_name(x), symbolic_shape))
shapes[x] = tuple(shapes[x])
(message, tensor_name(tensor), symbolic_shape))
tensors_specified_innermost = set()
for x in shapes:
symbolic_shape = shapes[x]
for i, symbol in enumerate(symbolic_shape):
if not _is_symbol_for_unspecified_dims(symbol):
# We convert this into a tuple to handle strings, lists and numpy arrays
symbolic_shape_tuple = tuple(symbolic_shape)
tensors_specified_innermost = False
for i, symbol in enumerate(symbolic_shape_tuple):
if symbol not in [Ellipsis, '*']:
continue
if i != 0:
raise ValueError(
'%s. '
'Tensor %s specified shape index %d. '
'Symbol `...` or `*` for a variable number of '
'unspecified dimensions is only allowed as the first entry' %
(message, tensor_name(x), i))
tensors_specified_innermost.add(x)
(message, tensor_name(tensor), i))
actual_sizes_by_tensor = {x: _dimension_sizes(x) for x in shapes}
specified_sizes_by_tensor = {
x: _symbolic_dimension_sizes(
# Ignoring innermost prefix
shapes[x][1:] if x in tensors_specified_innermost else shapes[x])
for x in shapes
}
tensors_specified_innermost = True
# Only include the size of the specified dimensions since the 0th symbol
# is either ellipsis or *
tensor_dim_sizes.append(
_TensorDimSizes(
tensor, tensors_specified_innermost, _dimension_sizes(tensor),
_symbolic_dimension_sizes(
symbolic_shape_tuple[1:]
if tensors_specified_innermost else symbolic_shape_tuple)))
rank_assertions = []
for x in shapes.keys():
symbolic_sizes = specified_sizes_by_tensor[x]
rank = len(symbolic_sizes)
for sizes in tensor_dim_sizes:
rank = len(sizes.symbolic_sizes)
rank_zero_or_one = rank in [0, 1]
if x in tensors_specified_innermost:
if sizes.unspecified_dim:
if rank_zero_or_one:
# No assertion of rank needed as `x` only need to have rank at least 0.
# See elif rank_zero_or_one case comment.
# No assertion of rank needed as `x` only need to have rank at least
# 0. See elif rank_zero_or_one case comment.
continue
assertion = assert_rank_at_least(
x=x,
x=sizes.x,
rank=rank,
data=data,
summarize=summarize,
@ -1875,7 +1885,7 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
# no distinction between the two in terms of rank.
# See _dimension_sizes.
assertion = assert_rank_in(
x=x,
x=sizes.x,
ranks=[0, 1],
data=data,
summarize=summarize,
@ -1883,7 +1893,7 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
name=name)
else:
assertion = assert_rank(
x=x,
x=sizes.x,
rank=rank,
data=data,
summarize=summarize,
@ -1893,19 +1903,15 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
size_assertions = []
size_specifications = {}
for x in shapes.keys():
actual_sizes = actual_sizes_by_tensor[x]
symbolic_sizes = specified_sizes_by_tensor[x]
innermost_dims = x in tensors_specified_innermost
for i, size_symbol in enumerate(symbolic_sizes):
for sizes in tensor_dim_sizes:
for i, size_symbol in enumerate(sizes.symbolic_sizes):
if _is_symbol_for_any_size(size_symbol):
# Size specified as any implies no constraint
continue
if innermost_dims:
tensor_dim = i - len(symbolic_sizes)
if sizes.unspecified_dim:
tensor_dim = i - len(sizes.symbolic_sizes)
else:
tensor_dim = i
@ -1920,14 +1926,15 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
'Specified by tensor %s dimension %d' %
(tensor_name(specified_by_y), specified_at_dim))
actual_size = actual_sizes[tensor_dim]
actual_size = sizes.actual_sizes[tensor_dim]
if _has_known_value(actual_size) and _has_known_value(specified_size):
if int(actual_size) != int(specified_size):
raise ValueError(
'%s. %s. Tensor %s dimension %s must have size %d. '
'Received size %d, shape %s' %
(message, size_check_message, tensor_name(x), tensor_dim,
specified_size, actual_size, x.get_shape()))
(message, size_check_message, tensor_name(sizes.x),
tensor_dim, specified_size, actual_size,
sizes.x.get_shape()))
# No dynamic assertion needed
continue
@ -1938,15 +1945,15 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
if data is None:
data_ = [
message, size_check_message,
'Tensor %s dimension' % tensor_name(x), tensor_dim,
'Tensor %s dimension' % tensor_name(sizes.x), tensor_dim,
'must have size', specified_size, 'Received shape: ',
array_ops.shape(x)
array_ops.shape(sizes.x)
]
size_assertions.append(
control_flow_ops.Assert(condition, data_, summarize=summarize))
else:
size = actual_sizes[tensor_dim]
size_specifications[size_symbol] = (size, x, tensor_dim)
size = sizes.actual_sizes[tensor_dim]
size_specifications[size_symbol] = (size, sizes.x, tensor_dim)
with ops.control_dependencies(rank_assertions):
shapes_assertion = control_flow_ops.group(size_assertions)