Also, start with the eager example before tf.function.

Mention that `tf.ensure_shape` works with `None`.

PiperOrigin-RevId: 355269224
Change-Id: Ic723e64b6d67e9aaabab56791b006c2e018a545b
This commit is contained in:
Mark Daoust 2021-02-02 15:43:56 -08:00 committed by TensorFlower Gardener
parent 4d73498e4b
commit adfa69b27b

View File

@ -2222,37 +2222,59 @@ def assert_scalar(tensor, name=None, message=None):
def ensure_shape(x, shape, name=None):
"""Updates the shape of a tensor and checks at runtime that the shape holds.
For example:
When executed, this operation asserts that the input tensor `x`'s shape
is compatible with the `shape` argument.
See `tf.TensorShape.is_compatible_with` for details.
>>> @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
... def f(tensor):
... return tf.ensure_shape(tensor, [3, 3])
>>>
>>> f(tf.zeros([3, 3])) # Passes
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float32)>
>>> f([1, 2, 3]) # fails
Traceback (most recent call last):
...
InvalidArgumentError: Shape of tensor x [3] is not compatible with expected shape [3,3].
>>> x = tf.constant([[1, 2, 3],
... [4, 5, 6]])
>>> x = tf.ensure_shape(x, [2, 3])
The above example raises `tf.errors.InvalidArgumentError`,
because the shape (3,) is not compatible with the shape (None, 3, 3)
Use `None` for unknown dimensions:
With eager execution this is a shape assertion, that returns the input:
>>> x = tf.ensure_shape(x, [None, 3])
>>> x = tf.ensure_shape(x, [2, None])
If the tensor's shape is not compatible with the `shape` argument, an error
is raised:
>>> x = tf.constant([1,2,3])
>>> print(x.shape)
(3,)
>>> x = tf.ensure_shape(x, [3])
>>> x = tf.ensure_shape(x, [5])
Traceback (most recent call last):
...
tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not
compatible with expected shape [5]. [Op:EnsureShape]
During graph construction (typically tracing a `tf.function`),
`tf.ensure_shape` updates the static-shape of the **result** tensor by
merging the two shapes. See `tf.TensorShape.merge_with` for details.
This is most useful when **you** know a shape that can't be determined
statically by TensorFlow.
The following trivial `tf.function` prints the input tensor's
static-shape before and after `ensure_shape` is applied.
>>> @tf.function
... def f(tensor):
... print("Static-shape before:", tensor.shape)
... tensor = tf.ensure_shape(tensor, [None, 3])
... print("Static-shape after:", tensor.shape)
... return tensor
This lets you see the effect of `tf.ensure_shape` when the function is traced:
>>> cf = f.get_concrete_function(tf.TensorSpec([None, None]))
Static-shape before: (None, None)
Static-shape after: (None, 3)
>>> cf(tf.zeros([3, 3])) # Passes
>>> cf(tf.constant([1, 2, 3])) # fails
Traceback (most recent call last):
...
InvalidArgumentError: Shape of tensor x [3] is not compatible with expected shape [3,3].
The above example raises `tf.errors.InvalidArgumentError`, because `x`'s
shape, `(3,)`, is not compatible with the `shape` argument, `(None, 3)`
Inside a `tf.function` or `v1.Graph` context it checks both the buildtime and
runtime shapes. This is stricter than `tf.Tensor.set_shape` which only
checks the buildtime shape.