Add examples + explanation to tf.gather
Add example code for image, move image to the top.

PiperOrigin-RevId: 341512111
Change-Id: Ieeee2acdf9ebc878187f6ca16c7c102544d1c0a6
Parent: af3f3f9111
Commit: 99118cceb8
@ -4734,6 +4734,11 @@ def reverse_sequence_v2(input,
|
|||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["gather"])
|
@tf_export(v1=["gather"])
|
||||||
|
@deprecation.deprecated_args(None,
|
||||||
|
("The `validate_indices` argument has no effect. "
|
||||||
|
"Indices are always validated on CPU and never "
|
||||||
|
"validated on GPU."),
|
||||||
|
"validate_indices")
|
||||||
@dispatch.add_dispatch_support
|
@dispatch.add_dispatch_support
|
||||||
def gather(params,
|
def gather(params,
|
||||||
indices,
|
indices,
|
||||||
@ -4743,62 +4748,176 @@ def gather(params,
|
|||||||
batch_dims=0): # pylint: disable=g-doc-args
|
batch_dims=0): # pylint: disable=g-doc-args
|
||||||
r"""Gather slices from params axis `axis` according to indices.
|
r"""Gather slices from params axis `axis` according to indices.
|
||||||
|
|
||||||
Gather slices from params axis `axis` according to `indices`. `indices` must
|
Gather slices from `params` axis `axis` according to `indices`. `indices`
|
||||||
be an integer tensor of any dimension (usually 0-D or 1-D).
|
must be an integer tensor of any dimension (often 1-D).
|
||||||
|
|
||||||
For 0-D (scalar) `indices`:
|
`Tensor.__getitem__` works for scalars, `tf.newaxis`, and
|
||||||
|
[python slices](https://numpy.org/doc/stable/reference/arrays.indexing.html#basic-slicing-and-indexing).
|
||||||
|
|
||||||
$$\begin{align*}
|
`tf.gather` extends indexing to handle tensors of indices.
|
||||||
output[p_0, ..., p_{axis-1}, && &&& p_{axis + 1}, ..., p_{N-1}] = \\
|
|
||||||
params[p_0, ..., p_{axis-1}, && indices, &&& p_{axis + 1}, ..., p_{N-1}]
|
|
||||||
\end{align*}$$
|
|
||||||
|
|
||||||
Where *N* = `ndims(params)`.
|
In the simplest case it's identical to scalar indexing:
|
||||||
|
|
||||||
For 1-D (vector) `indices` with `batch_dims=0`:
|
>>> params = tf.constant(['p0', 'p1', 'p2', 'p3', 'p4', 'p5'])
|
||||||
|
>>> params[3].numpy()
|
||||||
|
b'p3'
|
||||||
|
>>> tf.gather(params, 3).numpy()
|
||||||
|
b'p3'
|
||||||
|
|
||||||
$$\begin{align*}
|
The most common case is to pass a single axis tensor of indices (this
|
||||||
output[p_0, ..., p_{axis-1}, && &i, &&p_{axis + 1}, ..., p_{N-1}] =\\
|
can't be expressed as a python slice because the indices are not sequential):
|
||||||
params[p_0, ..., p_{axis-1}, && indices[&i], &&p_{axis + 1}, ..., p_{N-1}]
|
|
||||||
\end{align*}$$
|
|
||||||
|
|
||||||
In the general case, produces an output tensor where:
|
>>> indices = [2, 0, 2, 5]
|
||||||
|
>>> tf.gather(params, indices).numpy()
|
||||||
$$\begin{align*}
|
array([b'p2', b'p0', b'p2', b'p5'], dtype=object)
|
||||||
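As a rough mental model for the 1-D case above, gathering along axis 0 behaves like a list comprehension over the indices. The sketch below uses plain Python lists in place of tensors; `gather_1d` is a hypothetical helper for illustration only, not part of TensorFlow.

```python
def gather_1d(params, indices):
    """Pure-Python sketch of 1-D gather (hypothetical helper, not a TF API)."""
    # Each index selects one element; indices may repeat and need not be sorted,
    # which is why a Python slice cannot express the same selection.
    return [params[i] for i in indices]


params = ['p0', 'p1', 'p2', 'p3', 'p4', 'p5']
picked = gather_1d(params, [2, 0, 2, 5])
print(picked)  # ['p2', 'p0', 'p2', 'p5']
```

Unlike `tf.gather`, plain list indexing accepts negative indices, so this sketch only matches for indices in `[0, len(params))`.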
output[p_0, &..., p_{axis-1}, &
|
|
||||||
&i_{B}, ..., i_{M-1}, &
|
|
||||||
p_{axis + 1}, &..., p_{N-1}] = \\
|
|
||||||
params[p_0, &..., p_{axis-1}, &
|
|
||||||
indices[p_0, ..., p_{B-1}, &i_{B}, ..., i_{M-1}], &
|
|
||||||
p_{axis + 1}, &..., p_{N-1}]
|
|
||||||
\end{align*}$$
|
|
||||||
|
|
||||||
Where *N* = `ndims(params)`, *M* = `ndims(indices)`, and *B* = `batch_dims`.
|
|
||||||
Note that `params.shape[:batch_dims]` must be identical to
|
|
||||||
`indices.shape[:batch_dims]`.
|
|
||||||
|
|
||||||
The shape of the output tensor is:
|
|
||||||
|
|
||||||
> `output.shape = params.shape[:axis] + indices.shape[batch_dims:] +
|
|
||||||
> params.shape[axis + 1:]`.
|
|
||||||
|
|
||||||
Note that on CPU, if an out of bound index is found, an error is returned.
|
|
||||||
On GPU, if an out of bound index is found, a 0 is stored in the corresponding
|
|
||||||
output value.
|
|
||||||
|
|
||||||
See also `tf.gather_nd`.
|
|
||||||
|
|
||||||
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||||
<img style="width:100%" src="https://www.tensorflow.org/images/Gather.png"
|
<img style="width:100%" src="https://www.tensorflow.org/images/Gather.png"
|
||||||
alt>
|
alt>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
The indices can have any shape. When `params` has 1 axis, the
|
||||||
|
output shape is equal to the shape of `indices`:
|
||||||
|
|
||||||
|
>>> tf.gather(params, [[2, 0], [2, 5]]).numpy()
|
||||||
|
array([[b'p2', b'p0'],
|
||||||
|
[b'p2', b'p5']], dtype=object)
|
||||||
|
|
||||||
|
The `params` may also have any shape. `gather` can select slices
|
||||||
|
across any axis depending on the `axis` argument (which defaults to 0).
|
||||||
|
Below it is used to gather first rows, then columns from a matrix:
|
||||||
|
|
||||||
|
>>> params = tf.constant([[0, 1.0, 2.0],
|
||||||
|
... [10.0, 11.0, 12.0],
|
||||||
|
... [20.0, 21.0, 22.0],
|
||||||
|
... [30.0, 31.0, 32.0]])
|
||||||
|
>>> tf.gather(params, indices=[3,1]).numpy()
|
||||||
|
array([[30., 31., 32.],
|
||||||
|
[10., 11., 12.]], dtype=float32)
|
||||||
|
>>> tf.gather(params, indices=[2,1], axis=1).numpy()
|
||||||
|
array([[ 2., 1.],
|
||||||
|
[12., 11.],
|
||||||
|
[22., 21.],
|
||||||
|
[32., 31.]], dtype=float32)
|
||||||
|
|
||||||
|
More generally, the output has the same shape as `params`, with the
|
||||||
|
indexed axis replaced by the shape of `indices`.
|
||||||
|
|
||||||
|
>>> def result_shape(p_shape, i_shape, axis=0):
|
||||||
|
... return p_shape[:axis] + i_shape + p_shape[axis+1:]
|
||||||
|
>>>
|
||||||
|
>>> result_shape([1, 2, 3], [], axis=1)
|
||||||
|
[1, 3]
|
||||||
|
>>> result_shape([1, 2, 3], [7], axis=1)
|
||||||
|
[1, 7, 3]
|
||||||
|
>>> result_shape([1, 2, 3], [7, 5], axis=1)
|
||||||
|
[1, 7, 5, 3]
|
||||||
|
|
||||||
|
Here are some examples:
|
||||||
|
|
||||||
|
>>> params.shape.as_list()
|
||||||
|
[4, 3]
|
||||||
|
>>> indices = tf.constant([[0, 2]])
|
||||||
|
>>> tf.gather(params, indices=indices, axis=0).shape.as_list()
|
||||||
|
[1, 2, 3]
|
||||||
|
>>> tf.gather(params, indices=indices, axis=1).shape.as_list()
|
||||||
|
[4, 1, 2]
|
||||||
|
|
||||||
|
>>> params = tf.random.normal(shape=(5, 6, 7, 8))
|
||||||
|
>>> indices = tf.random.uniform(shape=(10, 11), maxval=7, dtype=tf.int32)
|
||||||
|
>>> result = tf.gather(params, indices, axis=2)
|
||||||
|
>>> result.shape.as_list()
|
||||||
|
[5, 6, 10, 11, 8]
|
||||||
|
|
||||||
|
This is because each index takes a slice from `params`, and
|
||||||
|
places it at the corresponding location in the output. For the above example:
|
||||||
|
|
||||||
|
>>> # For any location in indices
|
||||||
|
>>> a, b = 0, 1
|
||||||
|
>>> tf.reduce_all(
|
||||||
|
... # the corresponding slice of the result
|
||||||
|
... result[:, :, a, b, :] ==
|
||||||
|
... # is equal to the slice of `params` along `axis` at the index.
|
||||||
|
... params[:, :, indices[a, b], :]
|
||||||
|
... ).numpy()
|
||||||
|
True
|
||||||
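The slice-placement rule described above can also be sketched without TensorFlow, using nested Python lists as stand-ins for tensors. `shape_of` and `gather_nested` are hypothetical helpers written for this illustration; they assume rectangular, non-empty nested lists and `batch_dims=0`, and the bounds check is an assumption modeled on the documented CPU behavior.

```python
def shape_of(x):
    """Shape of a rectangular, non-empty nested list (a scalar has shape [])."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]
    return shape


def gather_nested(params, indices, axis=0):
    """Pure-Python sketch of gather with batch_dims=0 (hypothetical helper)."""
    if axis > 0:
        # Leave the leading axes of params untouched and recurse inward.
        return [gather_nested(p, indices, axis - 1) for p in params]
    if isinstance(indices, list):
        # Each index (or nested list of indices) contributes one output slot.
        return [gather_nested(params, i, 0) for i in indices]
    if not 0 <= indices < len(params):
        # Assumption: mirror the documented CPU behavior and raise an error.
        raise IndexError(f"index {indices} out of range [0, {len(params)})")
    return params[indices]


params = [[0, 1, 2], [10, 11, 12], [20, 21, 22], [30, 31, 32]]  # shape [4, 3]
rows = gather_nested(params, [3, 1])
cols = gather_nested(params, [2, 1], axis=1)
print(rows)  # [[30, 31, 32], [10, 11, 12]]
print(cols)  # [[2, 1], [12, 11], [22, 21], [32, 31]]
print(shape_of(gather_nested(params, [[0, 2]], axis=1)))  # [4, 1, 2]
```

The recursion makes the shape rule visible: axes before `axis` are mapped over, the indexed axis is replaced by the nesting structure of the indices, and trailing axes come along with each selected slice.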
|
|
||||||
|
### Batching:
|
||||||
|
|
||||||
|
The `batch_dims` argument lets you gather different items from each element
|
||||||
|
of a batch.
|
||||||
|
|
||||||
|
Using `batch_dims=1` is equivalent to having an outer loop over the first
|
||||||
|
axis of `params` and `indices`:
|
||||||
|
|
||||||
|
>>> params = tf.constant([
|
||||||
|
... [0, 0, 1, 0, 2],
|
||||||
|
... [3, 0, 0, 0, 4],
|
||||||
|
... [0, 5, 0, 6, 0]])
|
||||||
|
>>> indices = tf.constant([
|
||||||
|
... [2, 4],
|
||||||
|
... [0, 4],
|
||||||
|
... [1, 3]])
|
||||||
|
|
||||||
|
>>> tf.gather(params, indices, axis=1, batch_dims=1).numpy()
|
||||||
|
array([[1, 2],
|
||||||
|
[3, 4],
|
||||||
|
[5, 6]], dtype=int32)
|
||||||
|
|
||||||
|
This is equivalent to:
|
||||||
|
|
||||||
|
>>> def manually_batched_gather(params, indices, axis):
|
||||||
|
... batch_dims=1
|
||||||
|
... result = []
|
||||||
|
... for p,i in zip(params, indices):
|
||||||
|
... r = tf.gather(p, i, axis=axis-batch_dims)
|
||||||
|
... result.append(r)
|
||||||
|
... return tf.stack(result)
|
||||||
|
>>> manually_batched_gather(params, indices, axis=1).numpy()
|
||||||
|
array([[1, 2],
|
||||||
|
[3, 4],
|
||||||
|
[5, 6]], dtype=int32)
|
||||||
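For intuition, the same `batch_dims=1`, `axis=1` pairing can be written on plain nested lists. This is a minimal sketch with a hypothetical helper name (`batched_gather_rows`); it handles only the 2-D case shown above and is not how TensorFlow implements batching.

```python
def batched_gather_rows(params, indices):
    """Sketch of gather with axis=1, batch_dims=1 on 2-D nested lists."""
    # Pair each row of params with its own row of indices, then pick
    # the requested elements from that row only.
    return [[row[i] for i in idx] for row, idx in zip(params, indices)]


params = [
    [0, 0, 1, 0, 2],
    [3, 0, 0, 0, 4],
    [0, 5, 0, 6, 0]]
indices = [
    [2, 4],
    [0, 4],
    [1, 3]]
batched = batched_gather_rows(params, indices)
print(batched)  # [[1, 2], [3, 4], [5, 6]]
```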
|
|
||||||
|
Higher values of `batch_dims` are equivalent to multiple nested loops over
|
||||||
|
the outer axes of `params` and `indices`. So the overall shape function is:
|
||||||
|
|
||||||
|
>>> def batched_result_shape(p_shape, i_shape, axis=0, batch_dims=0):
|
||||||
|
... return p_shape[:axis] + i_shape[batch_dims:] + p_shape[axis+1:]
|
||||||
|
>>>
|
||||||
|
>>> batched_result_shape(
|
||||||
|
... p_shape=params.shape.as_list(),
|
||||||
|
... i_shape=indices.shape.as_list(),
|
||||||
|
... axis=1,
|
||||||
|
... batch_dims=1)
|
||||||
|
[3, 2]
|
||||||
|
|
||||||
|
>>> tf.gather(params, indices, axis=1, batch_dims=1).shape.as_list()
|
||||||
|
[3, 2]
|
||||||
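The shape function above assumes valid inputs. A slightly more defensive variant might also check the two constraints this docstring states: `axis` must be greater than or equal to `batch_dims`, and the leading `batch_dims` dimensions of `params` and `indices` must match. The helper name and the use of `ValueError` below are assumptions for illustration; TensorFlow's own error types and messages may differ.

```python
def checked_batched_result_shape(p_shape, i_shape, axis=0, batch_dims=0):
    """Shape rule for batched gather with basic validation (hypothetical)."""
    if axis < batch_dims:
        raise ValueError("axis must be greater than or equal to batch_dims")
    if p_shape[:batch_dims] != i_shape[:batch_dims]:
        raise ValueError("batch dimensions of params and indices must match")
    # Same formula as batched_result_shape in the docstring.
    return p_shape[:axis] + i_shape[batch_dims:] + p_shape[axis + 1:]


print(checked_batched_result_shape([3, 5], [3, 2], axis=1, batch_dims=1))
# [3, 2]
print(checked_batched_result_shape([5, 6, 7, 8], [5, 6, 10], axis=2, batch_dims=2))
# [5, 6, 10, 8]
```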
|
|
||||||
|
See also:
|
||||||
|
|
||||||
|
* `tf.Tensor.__getitem__`: The direct tensor index operation (`t[]`), handles
|
||||||
|
scalars and python-slices `tensor[..., 7, 1:-1]`
|
||||||
|
* `tf.scatter`: A collection of operations similar to `__setitem__`
|
||||||
|
(`t[i] = x`)
|
||||||
|
* `tf.gather_nd`: An operation similar to `tf.gather` but gathers across
|
||||||
|
multiple axes at once (it can gather elements of a matrix instead of rows
|
||||||
|
or columns)
|
||||||
|
* `tf.boolean_mask`, `tf.where`: Binary indexing.
|
||||||
|
* `tf.slice` and `tf.strided_slice`: For lower level access to the
|
||||||
|
implementation of `__getitem__`'s python-slice handling (`t[1:-1:2]`)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params: The `Tensor` from which to gather values. Must be at least rank
|
params: The `Tensor` from which to gather values. Must be at least rank
|
||||||
`axis + 1`.
|
`axis + 1`.
|
||||||
indices: The index `Tensor`. Must be one of the following types: `int32`,
|
indices: The index `Tensor`. Must be one of the following types: `int32`,
|
||||||
`int64`. Must be in range `[0, params.shape[axis])`.
|
`int64`. The values must be in range `[0, params.shape[axis])`.
|
||||||
validate_indices: Deprecated, does nothing.
|
validate_indices: Deprecated, does nothing. Indices are always validated on
|
||||||
|
CPU, never validated on GPU.
|
||||||
|
|
||||||
|
Caution: On CPU, if an out of bound index is found, an error is raised.
|
||||||
|
On GPU, if an out of bound index is found, a 0 is stored in the
|
||||||
|
corresponding output value.
|
||||||
axis: A `Tensor`. Must be one of the following types: `int32`, `int64`. The
|
axis: A `Tensor`. Must be one of the following types: `int32`, `int64`. The
|
||||||
`axis` in `params` to gather `indices` from. Must be greater than or equal
|
`axis` in `params` to gather `indices` from. Must be greater than or equal
|
||||||
to `batch_dims`. Defaults to the first non-batch dimension. Supports
|
to `batch_dims`. Defaults to the first non-batch dimension. Supports
|
||||||
|