diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 9b5ff4d8d9b..78908e9ab4a 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -2222,37 +2222,59 @@ def assert_scalar(tensor, name=None, message=None): def ensure_shape(x, shape, name=None): """Updates the shape of a tensor and checks at runtime that the shape holds. - For example: + When executed, this operation asserts that the input tensor `x`'s shape + is compatible with the `shape` argument. + See `tf.TensorShape.is_compatible_with` for details. - >>> @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)]) - ... def f(tensor): - ... return tf.ensure_shape(tensor, [3, 3]) - >>> - >>> f(tf.zeros([3, 3])) # Passes - - >>> f([1, 2, 3]) # fails - Traceback (most recent call last): - ... - InvalidArgumentError: Shape of tensor x [3] is not compatible with expected shape [3,3]. + >>> x = tf.constant([[1, 2, 3], + ... [4, 5, 6]]) + >>> x = tf.ensure_shape(x, [2, 3]) - The above example raises `tf.errors.InvalidArgumentError`, - because the shape (3,) is not compatible with the shape (None, 3, 3) + Use `None` for unknown dimensions: - With eager execution this is a shape assertion, that returns the input: + >>> x = tf.ensure_shape(x, [None, 3]) + >>> x = tf.ensure_shape(x, [2, None]) + + If the tensor's shape is not compatible with the `shape` argument, an error + is raised: - >>> x = tf.constant([1,2,3]) - >>> print(x.shape) - (3,) - >>> x = tf.ensure_shape(x, [3]) >>> x = tf.ensure_shape(x, [5]) Traceback (most recent call last): ... tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not compatible with expected shape [5]. [Op:EnsureShape] + During graph construction (typically tracing a `tf.function`), + `tf.ensure_shape` updates the static-shape of the **result** tensor by + merging the two shapes. See `tf.TensorShape.merge_with` for details. + + This is most useful when **you** know a shape that can't be determined + statically by TensorFlow. + + The following trivial `tf.function` prints the input tensor's + static-shape before and after `ensure_shape` is applied. + + >>> @tf.function + ... def f(tensor): + ... print("Static-shape before:", tensor.shape) + ... tensor = tf.ensure_shape(tensor, [None, 3]) + ... print("Static-shape after:", tensor.shape) + ... return tensor + + This lets you see the effect of `tf.ensure_shape` when the function is traced: + >>> cf = f.get_concrete_function(tf.TensorSpec([None, None])) + Static-shape before: (None, None) + Static-shape after: (None, 3) + + >>> cf(tf.zeros([3, 3])) # Passes + >>> cf(tf.constant([1, 2, 3])) # fails + Traceback (most recent call last): + ... + InvalidArgumentError: Shape of tensor x [3] is not compatible with expected shape [3,3]. + + The above example raises `tf.errors.InvalidArgumentError`, because `x`'s + shape, `(3,)`, is not compatible with the `shape` argument, `(None, 3)` + Inside a `tf.function` or `v1.Graph` context it checks both the buildtime and runtime shapes. This is stricter than `tf.Tensor.set_shape` which only checks the buildtime shape.