From 99118cceb86180920e939aca9ba4a960e4872d03 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Mon, 9 Nov 2020 16:45:08 -0800 Subject: [PATCH] Add examples + explanation to `tf.gather` Add example code for image, move image to the top. PiperOrigin-RevId: 341512111 Change-Id: Ieeee2acdf9ebc878187f6ca16c7c102544d1c0a6 --- tensorflow/python/ops/array_ops.py | 199 +++++++++++++++++++++++------ 1 file changed, 159 insertions(+), 40 deletions(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 843021ae046..284bc5147a3 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -4734,6 +4734,11 @@ def reverse_sequence_v2(input, @tf_export(v1=["gather"]) +@deprecation.deprecated_args(None, + ("The `validate_indices` argument has no effect. " + "Indices are always validated on CPU and never " + "validated on GPU."), + "validate_indices") @dispatch.add_dispatch_support def gather(params, indices, @@ -4743,62 +4748,176 @@ def gather(params, batch_dims=0): # pylint: disable=g-doc-args r"""Gather slices from params axis `axis` according to indices. - Gather slices from params axis `axis` according to `indices`. `indices` must - be an integer tensor of any dimension (usually 0-D or 1-D). + Gather slices from `params` axis `axis` according to `indices`. `indices` + must be an integer tensor of any dimension (often 1-D). - For 0-D (scalar) `indices`: + `Tensor.__getitem__` works for scalars, `tf.newaxis`, and + [python slices](https://numpy.org/doc/stable/reference/arrays.indexing.html#basic-slicing-and-indexing) - $$\begin{align*} - output[p_0, ..., p_{axis-1}, && &&& p_{axis + 1}, ..., p_{N-1}] = \\ - params[p_0, ..., p_{axis-1}, && indices, &&& p_{axis + 1}, ..., p_{N-1}] - \end{align*}$$ + `tf.gather` extends indexing to handle tensors of indices. - Where *N* = `ndims(params)`. + In the simplest case it's identical to scalar indexing: - For 1-D (vector) `indices` with `batch_dims=0`: + >>> params = tf.constant(['p0', 'p1', 'p2', 'p3', 'p4', 'p5']) + >>> params[3].numpy() + b'p3' + >>> tf.gather(params, 3).numpy() + b'p3' - $$\begin{align*} - output[p_0, ..., p_{axis-1}, && &i, &&p_{axis + 1}, ..., p_{N-1}] =\\ - params[p_0, ..., p_{axis-1}, && indices[&i], &&p_{axis + 1}, ..., p_{N-1}] - \end{align*}$$ + The most common case is to pass a single axis tensor of indices (this + can't be expressed as a python slice because the indices are not sequential): - In the general case, produces an output tensor where: - - $$\begin{align*} - output[p_0, &..., p_{axis-1}, & - &i_{B}, ..., i_{M-1}, & - p_{axis + 1}, &..., p_{N-1}] = \\ - params[p_0, &..., p_{axis-1}, & - indices[p_0, ..., p_{B-1}, &i_{B}, ..., i_{M-1}], & - p_{axis + 1}, &..., p_{N-1}] - \end{align*}$$ - - Where *N* = `ndims(params)`, *M* = `ndims(indices)`, and *B* = `batch_dims`. - Note that `params.shape[:batch_dims]` must be identical to - `indices.shape[:batch_dims]`. - - The shape of the output tensor is: - - > `output.shape = params.shape[:axis] + indices.shape[batch_dims:] + - > params.shape[axis + 1:]`. - - Note that on CPU, if an out of bound index is found, an error is returned. - On GPU, if an out of bound index is found, a 0 is stored in the corresponding - output value. - - See also `tf.gather_nd`. + >>> indices = [2, 0, 2, 5] + >>> tf.gather(params, indices).numpy() + array([b'p2', b'p0', b'p2', b'p5'], dtype=object)
+ The indices can have any shape. When the `params` has 1 axis, the + output shape is equal to the input shape: + + >>> tf.gather(params, [[2, 0], [2, 5]]).numpy() + array([[b'p2', b'p0'], + [b'p2', b'p5']], dtype=object) + + The `params` may also have any shape. `gather` can select slices + across any axis depending on the `axis` argument (which defaults to 0). + Below it is used to gather first rows, then columns from a matrix: + + >>> params = tf.constant([[0, 1.0, 2.0], + ... [10.0, 11.0, 12.0], + ... [20.0, 21.0, 22.0], + ... [30.0, 31.0, 32.0]]) + >>> tf.gather(params, indices=[3,1]).numpy() + array([[30., 31., 32.], + [10., 11., 12.]], dtype=float32) + >>> tf.gather(params, indices=[2,1], axis=1).numpy() + array([[ 2., 1.], + [12., 11.], + [22., 21.], + [32., 31.]], dtype=float32) + + More generally: The output shape has the same shape as the input, with the + indexed-axis replaced by the shape of the indices. + + >>> def result_shape(p_shape, i_shape, axis=0): + ... return p_shape[:axis] + i_shape + p_shape[axis+1:] + >>> + >>> result_shape([1, 2, 3], [], axis=1) + [1, 3] + >>> result_shape([1, 2, 3], [7], axis=1) + [1, 7, 3] + >>> result_shape([1, 2, 3], [7, 5], axis=1) + [1, 7, 5, 3] + + Here are some examples: + + >>> params.shape.as_list() + [4, 3] + >>> indices = tf.constant([[0, 2]]) + >>> tf.gather(params, indices=indices, axis=0).shape.as_list() + [1, 2, 3] + >>> tf.gather(params, indices=indices, axis=1).shape.as_list() + [4, 1, 2] + + >>> params = tf.random.normal(shape=(5, 6, 7, 8)) + >>> indices = tf.random.uniform(shape=(10, 11), maxval=7, dtype=tf.int32) + >>> result = tf.gather(params, indices, axis=2) + >>> result.shape.as_list() + [5, 6, 10, 11, 8] + + This is because each index takes a slice from `params`, and + places it at the corresponding location in the output. For the above example + + >>> # For any location in indices + >>> a, b = 0, 1 + >>> tf.reduce_all( + ... # the corresponding slice of the result + ... result[:, :, a, b, :] == + ... # is equal to the slice of `params` along `axis` at the index. + ... params[:, :, indices[a, b], :] + ... ).numpy() + True + + ### Batching: + + The `batch_dims` argument lets you gather different items from each element + of a batch. + + Using `batch_dims=1` is equivalent to having an outer loop over the first + axis of `params` and `indices`: + + >>> params = tf.constant([ + ... [0, 0, 1, 0, 2], + ... [3, 0, 0, 0, 4], + ... [0, 5, 0, 6, 0]]) + >>> indices = tf.constant([ + ... [2, 4], + ... [0, 4], + ... [1, 3]]) + + >>> tf.gather(params, indices, axis=1, batch_dims=1).numpy() + array([[1, 2], + [3, 4], + [5, 6]], dtype=int32) + + This is is equivalent to: + + >>> def manually_batched_gather(params, indices, axis): + ... batch_dims=1 + ... result = [] + ... for p,i in zip(params, indices): + ... r = tf.gather(p, i, axis=axis-batch_dims) + ... result.append(r) + ... return tf.stack(result) + >>> manually_batched_gather(params, indices, axis=1).numpy() + array([[1, 2], + [3, 4], + [5, 6]], dtype=int32) + + Higher values of `batch_dims` are equivalent to multiple nested loops over + the outer axes of `params` and `indices`. So the overall shape function is + + >>> def batched_result_shape(p_shape, i_shape, axis=0, batch_dims=0): + ... return p_shape[:axis] + i_shape[batch_dims:] + p_shape[axis+1:] + >>> + >>> batched_result_shape( + ... p_shape=params.shape.as_list(), + ... i_shape=indices.shape.as_list(), + ... axis=1, + ... batch_dims=1) + [3, 2] + + >>> tf.gather(params, indices, axis=1, batch_dims=1).shape.as_list() + [3, 2] + + See also: + + * `tf.Tensor.__getitem__`: The direct tensor index operation (`t[]`), handles + scalars and python-slices `tensor[..., 7, 1:-1]` + * `tf.scatter`: A collection of operations similar to `__setitem__` + (`t[i] = x`) + * `tf.gather_nd`: An operation similar to `tf.gather` but gathers across + multiple axis at once (it can gather elements of a matrix instead of rows + or columns) + * `tf.boolean_mask`, `tf.where`: Binary indexing. + * `tf.slice` and `tf.strided_slice`: For lower level access to the + implementation of `__getitem__`'s python-slice handling (`t[1:-1:2]) + Args: params: The `Tensor` from which to gather values. Must be at least rank `axis + 1`. indices: The index `Tensor`. Must be one of the following types: `int32`, - `int64`. Must be in range `[0, params.shape[axis])`. - validate_indices: Deprecated, does nothing. + `int64`. The values must be in range `[0, params.shape[axis])`. + validate_indices: Deprecated, does nothing. Indices are always validated on + CPU, never validated on GPU. + + Caution: On CPU, if an out of bound index is found, an error is raised. + On GPU, if an out of bound index is found, a 0 is stored in the + corresponding output value. axis: A `Tensor`. Must be one of the following types: `int32`, `int64`. The `axis` in `params` to gather `indices` from. Must be greater than or equal to `batch_dims`. Defaults to the first non-batch dimension. Supports