Improve the documentation of tf.debugging.assert_shapes()

- Specify the first argument as a list of (tensor, shape) tuples.
  We no longer want to publicly support hashable Tensors, as in TF2
  Tensors are not hashable.
- Better explain the shape iterables.

PiperOrigin-RevId: 317581098
Change-Id: I6bb2a114c884ee5b282f085865a5e6e30eb6ae15
This commit is contained in:
Shanqing Cai 2020-06-21 19:49:53 -07:00 committed by TensorFlower Gardener
parent ebf57bdfc7
commit 8e9a05117b

View File

@ -146,7 +146,7 @@ def _unary_assert_doc(sym, sym_name):
Raises:
InvalidArgumentError: if the check can be performed immediately and
`x {sym}` is False. The check can be performed immediately during
`x {sym}` is False. The check can be performed immediately during
eager execution or if `x` is statically known.
""".format(
sym=sym, sym_name=cap_sym_name, opname=opname)
@ -209,7 +209,7 @@ def _binary_assert_doc(sym):
Raises:
InvalidArgumentError: if the check can be performed immediately and
`x {sym} y` is False. The check can be performed immediately during
`x {sym} y` is False. The check can be performed immediately during
eager execution or if `x` and `y` are statically known.
""".format(
sym=sym, opname=opname)
@ -1634,7 +1634,7 @@ def assert_shapes_v2(shapes, data=None, summarize=None, message=None,
>>> n = 10
>>> q = 3
>>> d = 7
>>> x = tf.zeros([n,q])
>>> x = tf.zeros([n,q])
>>> y = tf.ones([n,d])
>>> param = tf.Variable([1.0, 2.0, 3.0])
>>> scalar = 1.0
@ -1644,9 +1644,9 @@ def assert_shapes_v2(shapes, data=None, summarize=None, message=None,
... (param, ('Q',)),
... (scalar, ()),
... ])
>>> tf.debugging.assert_shapes([
... (x, ('N', 'D')),
... (x, ('N', 'D')),
... (y, ('N', 'D'))
... ])
Traceback (most recent call last):
@ -1745,8 +1745,23 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
prefix) are both treated as having a single dimension of size one.
Args:
shapes: dictionary with (`Tensor` to shape) items, or a list of
(`Tensor`, shape) tuples. A shape must be an iterable.
shapes: A list of (`Tensor`, `shape`) tuples, wherein `shape` is the
expected shape of `Tensor`. See the example code above. The `shape` must
be an iterable. Each element of the iterable can be either a concrete
integer value or a string that abstractly represents the dimension.
For example,
- `('N', 'Q')` specifies a 2D shape wherein the first and second
dimensions of shape may or may not be equal.
- `('N', 'N', 'Q')` specifies a 3D shape wherein the first and second
dimensions are equal.
- `(1, 'N')` specifies a 2D shape wherein the first dimension is
exactly 1 and the second dimension can be any value.
Note that the abstract dimension letters take effect across different
tuple elements of the list. For example,
`tf.debugging.assert_shapes([(x, ('N', 'A')), (y, ('N', 'B'))]` asserts
that both `x` and `y` are rank-2 tensors and their first dimensions are
equal (`N`).
`shape` can also be a `tf.TensorShape`.
data: The tensors to print out if the condition is False. Defaults to error
message and first few entries of the violating tensor.
summarize: Print this many entries of the tensor.