Also, start with the eager example before tf.function
.
Mention that `tf.ensure_shape` works with `None`. PiperOrigin-RevId: 355269224 Change-Id: Ic723e64b6d67e9aaabab56791b006c2e018a545b
This commit is contained in:
parent
4d73498e4b
commit
adfa69b27b
@ -2222,37 +2222,59 @@ def assert_scalar(tensor, name=None, message=None):
|
||||
def ensure_shape(x, shape, name=None):
|
||||
"""Updates the shape of a tensor and checks at runtime that the shape holds.
|
||||
|
||||
For example:
|
||||
When executed, this operation asserts that the input tensor `x`'s shape
|
||||
is compatible with the `shape` argument.
|
||||
See `tf.TensorShape.is_compatible_with` for details.
|
||||
|
||||
>>> @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
|
||||
... def f(tensor):
|
||||
... return tf.ensure_shape(tensor, [3, 3])
|
||||
>>>
|
||||
>>> f(tf.zeros([3, 3])) # Passes
|
||||
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
|
||||
array([[0., 0., 0.],
|
||||
[0., 0., 0.],
|
||||
[0., 0., 0.]], dtype=float32)>
|
||||
>>> f([1, 2, 3]) # fails
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
InvalidArgumentError: Shape of tensor x [3] is not compatible with expected shape [3,3].
|
||||
>>> x = tf.constant([[1, 2, 3],
|
||||
... [4, 5, 6]])
|
||||
>>> x = tf.ensure_shape(x, [2, 3])
|
||||
|
||||
The above example raises `tf.errors.InvalidArgumentError`,
|
||||
because the shape (3,) is not compatible with the shape (None, 3, 3)
|
||||
Use `None` for unknown dimensions:
|
||||
|
||||
With eager execution this is a shape assertion, that returns the input:
|
||||
>>> x = tf.ensure_shape(x, [None, 3])
|
||||
>>> x = tf.ensure_shape(x, [2, None])
|
||||
|
||||
If the tensor's shape is not compatible with the `shape` argument, an error
|
||||
is raised:
|
||||
|
||||
>>> x = tf.constant([1,2,3])
|
||||
>>> print(x.shape)
|
||||
(3,)
|
||||
>>> x = tf.ensure_shape(x, [3])
|
||||
>>> x = tf.ensure_shape(x, [5])
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not
|
||||
compatible with expected shape [5]. [Op:EnsureShape]
|
||||
|
||||
During graph construction (typically tracing a `tf.function`),
|
||||
`tf.ensure_shape` updates the static-shape of the **result** tensor by
|
||||
merging the two shapes. See `tf.TensorShape.merge_with` for details.
|
||||
|
||||
This is most useful when **you** know a shape that can't be determined
|
||||
statically by TensorFlow.
|
||||
|
||||
The following trivial `tf.function` prints the input tensor's
|
||||
static-shape before and after `ensure_shape` is applied.
|
||||
|
||||
>>> @tf.function
|
||||
... def f(tensor):
|
||||
... print("Static-shape before:", tensor.shape)
|
||||
... tensor = tf.ensure_shape(tensor, [None, 3])
|
||||
... print("Static-shape after:", tensor.shape)
|
||||
... return tensor
|
||||
|
||||
This lets you see the effect of `tf.ensure_shape` when the function is traced:
|
||||
>>> cf = f.get_concrete_function(tf.TensorSpec([None, None]))
|
||||
Static-shape before: (None, None)
|
||||
Static-shape after: (None, 3)
|
||||
|
||||
>>> cf(tf.zeros([3, 3])) # Passes
|
||||
>>> cf(tf.constant([1, 2, 3])) # fails
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
InvalidArgumentError: Shape of tensor x [3] is not compatible with expected shape [3,3].
|
||||
|
||||
The above example raises `tf.errors.InvalidArgumentError`, because `x`'s
|
||||
shape, `(3,)`, is not compatible with the `shape` argument, `(None, 3)`
|
||||
|
||||
Inside a `tf.function` or `v1.Graph` context it checks both the buildtime and
|
||||
runtime shapes. This is stricter than `tf.Tensor.set_shape` which only
|
||||
checks the buildtime shape.
|
||||
|
Loading…
x
Reference in New Issue
Block a user