Modify tf.assert_shapes to take list of pairs
To support the tensor equality changes. We need to assume that tensors are unhashable. Thus we need change tf.assert_shapes to take a list of pairs instead of a dict. PiperOrigin-RevId: 258892939
This commit is contained in:
parent
7b8cd8dce9
commit
47e650a119
@ -1463,10 +1463,10 @@ class AssertShapesTest(test.TestCase):
|
||||
def test_raise_static_shape_mismatch(self):
|
||||
x = array_ops.ones([3, 2], name="x")
|
||||
y = array_ops.ones([2, 3], name="y")
|
||||
shapes = {
|
||||
x: ("N", "Q"),
|
||||
y: ("N", "D"),
|
||||
}
|
||||
shapes = [
|
||||
(x, ("N", "Q")),
|
||||
(y, ("N", "D")),
|
||||
]
|
||||
regex = (r"Specified by tensor .* dimension 0. "
|
||||
r"Tensor .* dimension 0 must have size 3. "
|
||||
r"Received size 2")
|
||||
@ -1476,10 +1476,10 @@ class AssertShapesTest(test.TestCase):
|
||||
with ops.Graph().as_default():
|
||||
x = array_ops.placeholder(dtypes.float32, [None, 2], name="x")
|
||||
y = array_ops.placeholder(dtypes.float32, [None, 3], name="y")
|
||||
shapes = {
|
||||
x: ("N", "Q"),
|
||||
y: ("N", "D"),
|
||||
}
|
||||
shapes = [
|
||||
(x, ("N", "Q")),
|
||||
(y, ("N", "D")),
|
||||
]
|
||||
regex = (r"\[Specified by tensor x.* dimension 0\] "
|
||||
r"\[Tensor y.* dimension\] \[0\] \[must have size\] \[3\]")
|
||||
feed_dict = {x: np.ones([3, 2]), y: np.ones([2, 3])}
|
||||
@ -1489,10 +1489,10 @@ class AssertShapesTest(test.TestCase):
|
||||
def test_raise_static_shape_explicit_mismatch(self):
|
||||
x = array_ops.ones([3, 2], name="x")
|
||||
y = array_ops.ones([2, 3], name="y")
|
||||
shapes = {
|
||||
x: (3, "Q"),
|
||||
y: (3, "D"),
|
||||
}
|
||||
shapes = [
|
||||
(x, (3, "Q")),
|
||||
(y, (3, "D")),
|
||||
]
|
||||
regex = (r"Specified explicitly. "
|
||||
r"Tensor .* dimension 0 must have size 3. "
|
||||
r"Received size 2")
|
||||
@ -1502,21 +1502,21 @@ class AssertShapesTest(test.TestCase):
|
||||
def test_rank_zero_rank_one_size_one_equivalence(self):
|
||||
rank_one_size_one = array_ops.ones([1], name="rank_one_size_one")
|
||||
rank_zero = array_ops.constant(5, name="rank_zero")
|
||||
check_ops.assert_shapes({
|
||||
rank_one_size_one: (),
|
||||
rank_zero: (),
|
||||
})
|
||||
check_ops.assert_shapes({
|
||||
rank_one_size_one: (1,),
|
||||
rank_zero: (1,),
|
||||
})
|
||||
check_ops.assert_shapes([
|
||||
(rank_one_size_one, ()),
|
||||
(rank_zero, ()),
|
||||
])
|
||||
check_ops.assert_shapes([
|
||||
(rank_one_size_one, (1,)),
|
||||
(rank_zero, (1,)),
|
||||
])
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_raise_static_rank_1_size_not_1_mismatch_scalar(self):
|
||||
x = array_ops.constant([2, 2], name="x")
|
||||
shapes = {
|
||||
x: (),
|
||||
}
|
||||
shapes = [
|
||||
(x, ()),
|
||||
]
|
||||
regex = (r"Specified explicitly. "
|
||||
r"Tensor .* dimension 0 must have size 1. "
|
||||
r"Received size 2")
|
||||
@ -1525,9 +1525,9 @@ class AssertShapesTest(test.TestCase):
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_raise_static_scalar_mismatch_rank_1_size_not_1(self):
|
||||
x = array_ops.constant(2, name="x")
|
||||
shapes = {
|
||||
x: (2,),
|
||||
}
|
||||
shapes = [
|
||||
(x, (2,)),
|
||||
]
|
||||
regex = (r"Specified explicitly. "
|
||||
r"Tensor .* dimension 0 must have size 2. "
|
||||
r"Received size 1")
|
||||
@ -1537,7 +1537,7 @@ class AssertShapesTest(test.TestCase):
|
||||
def test_scalar_implies_size_one(self):
|
||||
scalar = array_ops.constant(5, name="rank_zero")
|
||||
x = array_ops.ones([2, 2], name="x")
|
||||
shapes = {scalar: ("a",), x: ("a", 2)}
|
||||
shapes = [(scalar, ("a",)), (x, ("a", 2))]
|
||||
regex = (r"Specified by tensor .* dimension 0. "
|
||||
r"Tensor .* dimension 0 must have size 1. "
|
||||
r"Received size 2")
|
||||
@ -1546,7 +1546,7 @@ class AssertShapesTest(test.TestCase):
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_raise_not_iterable(self):
|
||||
x = array_ops.constant([1, 2], name="x")
|
||||
shapes = {x: 2}
|
||||
shapes = [(x, 2)]
|
||||
regex = (r"Tensor .*. "
|
||||
r"Specified shape must be an iterable. "
|
||||
r"An iterable has the attribute `__iter__` or `__getitem__`. "
|
||||
@ -1557,10 +1557,10 @@ class AssertShapesTest(test.TestCase):
|
||||
with ops.Graph().as_default():
|
||||
x = array_ops.placeholder(dtypes.float32, [None, 2], name="xa")
|
||||
y = array_ops.placeholder(dtypes.float32, [None, 3], name="y")
|
||||
shapes = {
|
||||
x: (3, "Q"),
|
||||
y: (3, "D"),
|
||||
}
|
||||
shapes = [
|
||||
(x, (3, "Q")),
|
||||
(y, (3, "D")),
|
||||
]
|
||||
regex = (r"\[Specified explicitly\] "
|
||||
r"\[Tensor y.* dimension\] \[0\] \[must have size\] \[3\]")
|
||||
feed_dict = {x: np.ones([3, 2]), y: np.ones([2, 3])}
|
||||
@ -1569,7 +1569,7 @@ class AssertShapesTest(test.TestCase):
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_no_op_when_specified_as_unknown(self):
|
||||
x = array_ops.constant([1, 1], name="x")
|
||||
assertion = check_ops.assert_shapes({x: None})
|
||||
assertion = check_ops.assert_shapes([(x, None)])
|
||||
with ops.control_dependencies([assertion]):
|
||||
out = array_ops.identity(x)
|
||||
self.evaluate(out)
|
||||
@ -1616,7 +1616,7 @@ class AssertShapesTest(test.TestCase):
|
||||
for shape in rank_two_shapes:
|
||||
regex = r"Tensor .* must have rank\] \[2\]"
|
||||
self.raises_dynamic_error(
|
||||
shapes={x: shape}, regex=regex, feed_dict={x: x_value})
|
||||
shapes=[(x, shape)], regex=regex, feed_dict={x: x_value})
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_correctly_matching(self):
|
||||
@ -1626,25 +1626,25 @@ class AssertShapesTest(test.TestCase):
|
||||
x = array_ops.ones([1, 2, 3], name="x")
|
||||
y = array_ops.ones([3, 1, 2], name="y")
|
||||
z = array_ops.ones([2, 3, 1], name="z")
|
||||
assertion = check_ops.assert_shapes({
|
||||
x: ("a", "b", "c"),
|
||||
y: ("c", "a", "b"),
|
||||
z: ("b", "c", "a"),
|
||||
v: ("a", "b"),
|
||||
w: ("c",),
|
||||
u: "a"
|
||||
})
|
||||
assertion = check_ops.assert_shapes([
|
||||
(x, ("a", "b", "c")),
|
||||
(y, ("c", "a", "b")),
|
||||
(z, ("b", "c", "a")),
|
||||
(v, ("a", "b")),
|
||||
(w, ("c",)),
|
||||
(u, "a")
|
||||
])
|
||||
with ops.control_dependencies([assertion]):
|
||||
out = array_ops.identity(x)
|
||||
self.evaluate(out)
|
||||
assertion = check_ops.assert_shapes({
|
||||
x: (1, "b", "c"),
|
||||
y: ("c", "a", 2),
|
||||
z: ("b", 3, "a"),
|
||||
v: ("a", 2),
|
||||
w: (3,),
|
||||
u: ()
|
||||
})
|
||||
assertion = check_ops.assert_shapes([
|
||||
(x, (1, "b", "c")),
|
||||
(y, ("c", "a", 2)),
|
||||
(z, ("b", 3, "a")),
|
||||
(v, ("a", 2)),
|
||||
(w, (3,)),
|
||||
(u, ())
|
||||
])
|
||||
with ops.control_dependencies([assertion]):
|
||||
out = array_ops.identity(x)
|
||||
self.evaluate(out)
|
||||
@ -1653,10 +1653,10 @@ class AssertShapesTest(test.TestCase):
|
||||
def test_variable_length_symbols(self):
|
||||
x = array_ops.ones([4, 1], name="x")
|
||||
y = array_ops.ones([4, 2], name="y")
|
||||
assertion = check_ops.assert_shapes({
|
||||
x: ("num_observations", "input_dim"),
|
||||
y: ("num_observations", "output_dim"),
|
||||
})
|
||||
assertion = check_ops.assert_shapes([
|
||||
(x, ("num_observations", "input_dim")),
|
||||
(y, ("num_observations", "output_dim")),
|
||||
])
|
||||
with ops.control_dependencies([assertion]):
|
||||
out = array_ops.identity(x)
|
||||
self.evaluate(out)
|
||||
@ -1665,22 +1665,22 @@ class AssertShapesTest(test.TestCase):
|
||||
def test_raise_implicit_mismatch_using_iterable_alternatives(self):
|
||||
x = array_ops.ones([2, 2], name="x")
|
||||
y = array_ops.ones([1, 3], name="y")
|
||||
styles = [{
|
||||
x: ("A", "B"),
|
||||
y: ("A", "C"),
|
||||
}, {
|
||||
x: "AB",
|
||||
y: "AC"
|
||||
}, {
|
||||
x: ["A", "B"],
|
||||
y: ["A", "C"],
|
||||
}, {
|
||||
x: np.array(["A", "B"]),
|
||||
y: np.array(["A", "C"])
|
||||
}, {
|
||||
x: ("A", "B"),
|
||||
y: "AC"
|
||||
}]
|
||||
styles = [[
|
||||
(x, ("A", "B")),
|
||||
(y, ("A", "C")),
|
||||
], [
|
||||
(x, "AB"),
|
||||
(y, "AC")
|
||||
], [
|
||||
(x, ["A", "B"]),
|
||||
(y, ["A", "C"]),
|
||||
], [
|
||||
(x, np.array(["A", "B"])),
|
||||
(y, np.array(["A", "C"]))
|
||||
], [
|
||||
(x, ("A", "B")),
|
||||
(y, "AC")
|
||||
]]
|
||||
for shapes in styles:
|
||||
self.raises_static_error(
|
||||
shapes=shapes,
|
||||
@ -1692,22 +1692,22 @@ class AssertShapesTest(test.TestCase):
|
||||
def test_raise_explicit_mismatch_using_iterable_alternatives(self):
|
||||
x = array_ops.ones([2, 2], name="x")
|
||||
y = array_ops.ones([1, 3], name="y")
|
||||
styles = [{
|
||||
x: (2, 2),
|
||||
y: (2, 3),
|
||||
}, {
|
||||
x: "22",
|
||||
y: "23"
|
||||
}, {
|
||||
x: [2, 2],
|
||||
y: [2, 3],
|
||||
}, {
|
||||
x: np.array([2, 2]),
|
||||
y: np.array([2, 3])
|
||||
}, {
|
||||
x: (2, 2),
|
||||
y: "23"
|
||||
}]
|
||||
styles = [[
|
||||
(x, (2, 2)),
|
||||
(y, (2, 3)),
|
||||
], [
|
||||
(x, "22"),
|
||||
(y, "23")
|
||||
], [
|
||||
(x, [2, 2]),
|
||||
(y, [2, 3]),
|
||||
], [
|
||||
(x, np.array([2, 2])),
|
||||
(y, np.array([2, 3]))
|
||||
], [
|
||||
(x, (2, 2)),
|
||||
(y, "23")
|
||||
]]
|
||||
for shapes in styles:
|
||||
self.raises_static_error(
|
||||
shapes=shapes,
|
||||
@ -1720,16 +1720,16 @@ class AssertShapesTest(test.TestCase):
|
||||
x = array_ops.ones([1, 2, 3], name="x")
|
||||
y = array_ops.ones([2, 1], name="y")
|
||||
a1 = check_ops.assert_shapes({
|
||||
x: (None, 2, None),
|
||||
y: (None, 1),
|
||||
(x, (None, 2, None)),
|
||||
(y, (None, 1)),
|
||||
})
|
||||
a2 = check_ops.assert_shapes({
|
||||
x: (".", 2, "."),
|
||||
y: (".", 1),
|
||||
(x, (".", 2, ".")),
|
||||
(y, (".", 1)),
|
||||
})
|
||||
a3 = check_ops.assert_shapes({
|
||||
x: ".2.",
|
||||
y: ".1",
|
||||
(x, ".2."),
|
||||
(y, ".1"),
|
||||
})
|
||||
with ops.control_dependencies([a1, a2, a3]):
|
||||
out = array_ops.identity(x)
|
||||
@ -1739,14 +1739,14 @@ class AssertShapesTest(test.TestCase):
|
||||
def test_raise_static_shape_explicit_mismatch_innermost_dims(self):
|
||||
x = array_ops.ones([3, 2], name="x")
|
||||
y = array_ops.ones([2, 3], name="y")
|
||||
s1 = {
|
||||
x: (3, "Q"),
|
||||
y: (Ellipsis, 3, "D"),
|
||||
}
|
||||
s2 = {
|
||||
x: "3Q",
|
||||
y: "*3D",
|
||||
}
|
||||
s1 = [
|
||||
(x, (3, "Q")),
|
||||
(y, (Ellipsis, 3, "D")),
|
||||
]
|
||||
s2 = [
|
||||
(x, "3Q"),
|
||||
(y, "*3D"),
|
||||
]
|
||||
regex = (r"Specified explicitly. "
|
||||
r"Tensor .* dimension -2 must have size 3. "
|
||||
r"Received size 2")
|
||||
@ -1757,14 +1757,14 @@ class AssertShapesTest(test.TestCase):
|
||||
def test_correctly_matching_innermost_dims(self):
|
||||
x = array_ops.ones([1, 2, 3, 2], name="x")
|
||||
y = array_ops.ones([2, 3, 3], name="y")
|
||||
a1 = check_ops.assert_shapes({
|
||||
x: (Ellipsis, "N", "Q"),
|
||||
y: (Ellipsis, "N", "D"),
|
||||
})
|
||||
a2 = check_ops.assert_shapes({
|
||||
x: "*NQ",
|
||||
y: "*ND",
|
||||
})
|
||||
a1 = check_ops.assert_shapes([
|
||||
(x, (Ellipsis, "N", "Q")),
|
||||
(y, (Ellipsis, "N", "D")),
|
||||
])
|
||||
a2 = check_ops.assert_shapes([
|
||||
(x, "*NQ"),
|
||||
(y, "*ND"),
|
||||
])
|
||||
with ops.control_dependencies([a1, a2]):
|
||||
out = array_ops.identity(x)
|
||||
self.evaluate(out)
|
||||
@ -1772,12 +1772,12 @@ class AssertShapesTest(test.TestCase):
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_raise_variable_num_outer_dims_prefix_misuse(self):
|
||||
x = array_ops.ones([1, 2], name="x")
|
||||
s1 = {
|
||||
x: ("N", Ellipsis, "Q"),
|
||||
}
|
||||
s2 = {
|
||||
x: "N*Q",
|
||||
}
|
||||
s1 = [
|
||||
(x, ("N", Ellipsis, "Q")),
|
||||
]
|
||||
s2 = [
|
||||
(x, "N*Q"),
|
||||
]
|
||||
regex = (r"Tensor .* specified shape index .*. "
|
||||
r"Symbol `...` or `\*` for a variable number of "
|
||||
r"unspecified dimensions is only allowed as the first entry")
|
||||
@ -1786,7 +1786,7 @@ class AssertShapesTest(test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_empty_shapes_dict_no_op(self):
|
||||
assertion = check_ops.assert_shapes({})
|
||||
assertion = check_ops.assert_shapes([])
|
||||
with ops.control_dependencies([assertion]):
|
||||
out = array_ops.identity(0)
|
||||
self.evaluate(out)
|
||||
|
@ -19,6 +19,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
@ -1663,8 +1664,10 @@ def _dimension_sizes(x):
|
||||
|
||||
|
||||
def _symbolic_dimension_sizes(symbolic_shape):
|
||||
if len(symbolic_shape) == 0:
|
||||
# If len(symbolic_shape) == 0 construct a tuple
|
||||
if not symbolic_shape:
|
||||
return tuple([1])
|
||||
|
||||
return symbolic_shape
|
||||
|
||||
|
||||
@ -1682,8 +1685,9 @@ def _is_symbol_for_any_size(symbol):
|
||||
return symbol in [None, '.']
|
||||
|
||||
|
||||
def _is_symbol_for_unspecified_dims(symbol):
|
||||
return symbol in [Ellipsis, '*']
|
||||
_TensorDimSizes = collections.namedtuple(
|
||||
'_TensorDimSizes',
|
||||
['x', 'unspecified_dim', 'actual_sizes', 'symbolic_sizes'])
|
||||
|
||||
|
||||
@tf_export('debugging.assert_shapes', v1=[])
|
||||
@ -1697,12 +1701,12 @@ def assert_shapes_v2(shapes, data=None, summarize=None, message=None,
|
||||
Example:
|
||||
|
||||
```python
|
||||
tf.assert_shapes({
|
||||
x: ('N', 'Q'),
|
||||
y: ('N', 'D'),
|
||||
param: ('Q',),
|
||||
scalar: ()
|
||||
})
|
||||
tf.assert_shapes([
|
||||
(x: ('N', 'Q')),
|
||||
(y: ('N', 'D')),
|
||||
(param: ('Q',)),
|
||||
(scalar: ()),
|
||||
])
|
||||
```
|
||||
|
||||
If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
|
||||
@ -1750,10 +1754,10 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
|
||||
|
||||
```python
|
||||
tf.assert_shapes({
|
||||
x: ('N', 'Q'),
|
||||
y: ('N', 'D'),
|
||||
param: ('Q',),
|
||||
scalar: ()
|
||||
(x, ('N', 'Q')),
|
||||
(y, ('N', 'D')),
|
||||
(param, ('Q',)),
|
||||
(scalar, ())
|
||||
})
|
||||
```
|
||||
|
||||
@ -1800,13 +1804,17 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
|
||||
Raises:
|
||||
ValueError: If static checks determine any shape constraint is violated.
|
||||
"""
|
||||
# If the user manages to assemble a dict containing tensors (possible in
|
||||
# Graph mode only), make sure we still accept that.
|
||||
if isinstance(shapes, dict):
|
||||
shapes = shapes.items()
|
||||
|
||||
message = message or ''
|
||||
with ops.name_scope(name, 'assert_shapes', [shapes, data]):
|
||||
|
||||
# Shape specified as None implies no constraint
|
||||
shapes = {x: shapes[x] for x in shapes if shapes[x] is not None}
|
||||
|
||||
shapes = {ops.convert_to_tensor(x): shapes[x] for x in shapes}
|
||||
shape_constraints = [
|
||||
(ops.convert_to_tensor(x), s) for x, s in shapes if s is not None
|
||||
]
|
||||
|
||||
executing_eagerly = context.executing_eagerly()
|
||||
|
||||
@ -1815,8 +1823,8 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
|
||||
return _shape_and_dtype_str(x)
|
||||
return x.name
|
||||
|
||||
for x in shapes:
|
||||
symbolic_shape = shapes[x]
|
||||
tensor_dim_sizes = []
|
||||
for tensor, symbolic_shape in shape_constraints:
|
||||
is_iterable = (
|
||||
hasattr(symbolic_shape, '__iter__') or
|
||||
hasattr(symbolic_shape, '__getitem__') # For Python 2 compat.
|
||||
@ -1827,44 +1835,46 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
|
||||
'Tensor %s. Specified shape must be an iterable. '
|
||||
'An iterable has the attribute `__iter__` or `__getitem__`. '
|
||||
'Received specified shape: %s' %
|
||||
(message, tensor_name(x), symbolic_shape))
|
||||
shapes[x] = tuple(shapes[x])
|
||||
(message, tensor_name(tensor), symbolic_shape))
|
||||
|
||||
tensors_specified_innermost = set()
|
||||
for x in shapes:
|
||||
symbolic_shape = shapes[x]
|
||||
for i, symbol in enumerate(symbolic_shape):
|
||||
if not _is_symbol_for_unspecified_dims(symbol):
|
||||
# We convert this into a tuple to handle strings, lists and numpy arrays
|
||||
symbolic_shape_tuple = tuple(symbolic_shape)
|
||||
|
||||
tensors_specified_innermost = False
|
||||
for i, symbol in enumerate(symbolic_shape_tuple):
|
||||
if symbol not in [Ellipsis, '*']:
|
||||
continue
|
||||
|
||||
if i != 0:
|
||||
raise ValueError(
|
||||
'%s. '
|
||||
'Tensor %s specified shape index %d. '
|
||||
'Symbol `...` or `*` for a variable number of '
|
||||
'unspecified dimensions is only allowed as the first entry' %
|
||||
(message, tensor_name(x), i))
|
||||
tensors_specified_innermost.add(x)
|
||||
(message, tensor_name(tensor), i))
|
||||
|
||||
actual_sizes_by_tensor = {x: _dimension_sizes(x) for x in shapes}
|
||||
specified_sizes_by_tensor = {
|
||||
x: _symbolic_dimension_sizes(
|
||||
# Ignoring innermost prefix
|
||||
shapes[x][1:] if x in tensors_specified_innermost else shapes[x])
|
||||
for x in shapes
|
||||
}
|
||||
tensors_specified_innermost = True
|
||||
|
||||
# Only include the size of the specified dimensions since the 0th symbol
|
||||
# is either ellipsis or *
|
||||
tensor_dim_sizes.append(
|
||||
_TensorDimSizes(
|
||||
tensor, tensors_specified_innermost, _dimension_sizes(tensor),
|
||||
_symbolic_dimension_sizes(
|
||||
symbolic_shape_tuple[1:]
|
||||
if tensors_specified_innermost else symbolic_shape_tuple)))
|
||||
|
||||
rank_assertions = []
|
||||
for x in shapes.keys():
|
||||
symbolic_sizes = specified_sizes_by_tensor[x]
|
||||
rank = len(symbolic_sizes)
|
||||
for sizes in tensor_dim_sizes:
|
||||
rank = len(sizes.symbolic_sizes)
|
||||
rank_zero_or_one = rank in [0, 1]
|
||||
if x in tensors_specified_innermost:
|
||||
if sizes.unspecified_dim:
|
||||
if rank_zero_or_one:
|
||||
# No assertion of rank needed as `x` only need to have rank at least 0.
|
||||
# See elif rank_zero_or_one case comment.
|
||||
# No assertion of rank needed as `x` only need to have rank at least
|
||||
# 0. See elif rank_zero_or_one case comment.
|
||||
continue
|
||||
assertion = assert_rank_at_least(
|
||||
x=x,
|
||||
x=sizes.x,
|
||||
rank=rank,
|
||||
data=data,
|
||||
summarize=summarize,
|
||||
@ -1875,7 +1885,7 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
|
||||
# no distinction between the two in terms of rank.
|
||||
# See _dimension_sizes.
|
||||
assertion = assert_rank_in(
|
||||
x=x,
|
||||
x=sizes.x,
|
||||
ranks=[0, 1],
|
||||
data=data,
|
||||
summarize=summarize,
|
||||
@ -1883,7 +1893,7 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
|
||||
name=name)
|
||||
else:
|
||||
assertion = assert_rank(
|
||||
x=x,
|
||||
x=sizes.x,
|
||||
rank=rank,
|
||||
data=data,
|
||||
summarize=summarize,
|
||||
@ -1893,19 +1903,15 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
|
||||
|
||||
size_assertions = []
|
||||
size_specifications = {}
|
||||
for x in shapes.keys():
|
||||
actual_sizes = actual_sizes_by_tensor[x]
|
||||
symbolic_sizes = specified_sizes_by_tensor[x]
|
||||
innermost_dims = x in tensors_specified_innermost
|
||||
|
||||
for i, size_symbol in enumerate(symbolic_sizes):
|
||||
for sizes in tensor_dim_sizes:
|
||||
for i, size_symbol in enumerate(sizes.symbolic_sizes):
|
||||
|
||||
if _is_symbol_for_any_size(size_symbol):
|
||||
# Size specified as any implies no constraint
|
||||
continue
|
||||
|
||||
if innermost_dims:
|
||||
tensor_dim = i - len(symbolic_sizes)
|
||||
if sizes.unspecified_dim:
|
||||
tensor_dim = i - len(sizes.symbolic_sizes)
|
||||
else:
|
||||
tensor_dim = i
|
||||
|
||||
@ -1920,14 +1926,15 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
|
||||
'Specified by tensor %s dimension %d' %
|
||||
(tensor_name(specified_by_y), specified_at_dim))
|
||||
|
||||
actual_size = actual_sizes[tensor_dim]
|
||||
actual_size = sizes.actual_sizes[tensor_dim]
|
||||
if _has_known_value(actual_size) and _has_known_value(specified_size):
|
||||
if int(actual_size) != int(specified_size):
|
||||
raise ValueError(
|
||||
'%s. %s. Tensor %s dimension %s must have size %d. '
|
||||
'Received size %d, shape %s' %
|
||||
(message, size_check_message, tensor_name(x), tensor_dim,
|
||||
specified_size, actual_size, x.get_shape()))
|
||||
(message, size_check_message, tensor_name(sizes.x),
|
||||
tensor_dim, specified_size, actual_size,
|
||||
sizes.x.get_shape()))
|
||||
# No dynamic assertion needed
|
||||
continue
|
||||
|
||||
@ -1938,15 +1945,15 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
|
||||
if data is None:
|
||||
data_ = [
|
||||
message, size_check_message,
|
||||
'Tensor %s dimension' % tensor_name(x), tensor_dim,
|
||||
'Tensor %s dimension' % tensor_name(sizes.x), tensor_dim,
|
||||
'must have size', specified_size, 'Received shape: ',
|
||||
array_ops.shape(x)
|
||||
array_ops.shape(sizes.x)
|
||||
]
|
||||
size_assertions.append(
|
||||
control_flow_ops.Assert(condition, data_, summarize=summarize))
|
||||
else:
|
||||
size = actual_sizes[tensor_dim]
|
||||
size_specifications[size_symbol] = (size, x, tensor_dim)
|
||||
size = sizes.actual_sizes[tensor_dim]
|
||||
size_specifications[size_symbol] = (size, sizes.x, tensor_dim)
|
||||
|
||||
with ops.control_dependencies(rank_assertions):
|
||||
shapes_assertion = control_flow_ops.group(size_assertions)
|
||||
|
Loading…
x
Reference in New Issue
Block a user