From f3f5afbe7bf3499e8735df0655344e7dc7ae554e Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 26 Feb 2020 10:38:36 +0000 Subject: [PATCH 1/2] docs: add tip to prefer tf.shape(x) over x.shape when writing custom layers/models See #36991 for details. --- tensorflow/python/ops/array_ops.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 50afcfbc6e0..4f03b985b69 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -557,6 +557,14 @@ def shape_v2(input, out_type=dtypes.int32, name=None): >>> a.shape TensorShape([None, None, 10]) + + However, when defining custom layers and models that will be run in graph mode + at some point, prefer `tf.shape(x)` over `x.shape`. `x.shape` is the static shape + of `x` and usually evaluates to `None` in the first dimension during graph + construction (to represent the as yet unknown batch size). This can cause problems in + function calls like `tf.zeros(x.shape[0])` which don't support `None` values. + `tf.shape(x)` on the other hand gives the dynamic shape of `x` which isn't + evaluated until training/predicting begins where the full shape of `x` is known. `tf.shape` and `Tensor.shape` should be identical in eager mode. Within `tf.function` or within a `compat.v1` context, not all dimensions may be From 749bd23af669d0ae90e59ed655ab2818ec10a2ec Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 20 May 2020 06:56:24 +0200 Subject: [PATCH 2/2] shorten tf.shape docstring clarify when it's different from `x.shape` --- tensorflow/python/ops/array_ops.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 4f03b985b69..8c84fe1d450 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -535,19 +535,16 @@ def shape_v2(input, out_type=dtypes.int32, name=None): # pylint: disable=redefined-builtin """Returns the shape of a tensor. - This operation returns a 1-D integer tensor representing the shape of `input`. - This represents the minimal set of known information at definition time. + `tf.shape` returns a 1-D integer tensor representing the shape of `input`. For example: >>> t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]) >>> tf.shape(t) - >>> tf.shape(t).numpy() - array([2, 2, 3], dtype=int32) - Note: When using symbolic tensors, such as when using the Keras functional - API, tf.shape() will return the shape of the symbolic tensor. + Note: When using symbolic tensors, such as when using the Keras API, + tf.shape() will return the shape of the symbolic tensor. >>> a = tf.keras.layers.Input((None, 10)) >>> tf.shape(a) @@ -558,17 +555,12 @@ def shape_v2(input, out_type=dtypes.int32, name=None): >>> a.shape TensorShape([None, None, 10]) - However, when defining custom layers and models that will be run in graph mode - at some point, prefer `tf.shape(x)` over `x.shape`. `x.shape` is the static shape - of `x` and usually evaluates to `None` in the first dimension during graph - construction (to represent the as yet unknown batch size). This can cause problems in - function calls like `tf.zeros(x.shape[0])` which don't support `None` values. - `tf.shape(x)` on the other hand gives the dynamic shape of `x` which isn't - evaluated until training/predicting begins where the full shape of `x` is known. + (The first `None` represents the as yet unknown batch size.) `tf.shape` and `Tensor.shape` should be identical in eager mode. Within `tf.function` or within a `compat.v1` context, not all dimensions may be - known until execution time. + known until execution time. Hence when defining custom layers and models + for graph mode, prefer the dynamic `tf.shape(x)` over the static `x.shape`. Args: input: A `Tensor` or `SparseTensor`.