[XLA] Modify operation semantics description for XLA Gather. NFC.
PiperOrigin-RevId: 257298693
This commit is contained in:
parent
7a6aa853e1
commit
e70b1dd882
@ -1367,12 +1367,12 @@ For a more intuitive description, see the "Informal Description" section below.
|
||||
: : : detailed description. :
|
||||
| `offset_dims` | `ArraySlice<int64>` | The set of dimensions in the |
|
||||
: : : output shape that offset into :
|
||||
: : : a array sliced from operand. :
|
||||
: : : an array sliced from operand. :
|
||||
| `slice_sizes` | `ArraySlice<int64>` | `slice_sizes[i]` is the |
|
||||
: : : bounds for the slice on :
|
||||
: : : dimension `i`. :
|
||||
| `collapsed_slice_dims` | `ArraySlice<int64>` | The set of dimensions in each |
|
||||
: : : \: slice that are collapsed :
|
||||
: : : slice that are collapsed :
|
||||
: : : away. These dimensions must :
|
||||
: : : have size 1. :
|
||||
| `start_index_map` | `ArraySlice<int64>` | A map that describes how to |
|
||||
@ -1383,8 +1383,11 @@ For a more intuitive description, see the "Informal Description" section below.
|
||||
For convenience, we label dimensions in the output array not in `offset_dims`
|
||||
as `batch_dims`.
|
||||
|
||||
The output is an array of rank `batch_dims.size` + `operand.rank` -
|
||||
`collapsed_slice_dims`.size.
|
||||
The output is an array of rank `batch_dims.size` + `offset_dims.size`.
|
||||
|
||||
The `operand.rank` must equal the sume of `offset_dims.size` and
|
||||
`collapsed_slice_dims`. Also, `slice_sizes.size` has to be equal to
|
||||
`operand.rank`.
|
||||
|
||||
If `index_vector_dim` is equal to `start_indices.rank` we implicitly consider
|
||||
`start_indices` to have a trailing `1` dimension (i.e. if `start_indices` was of
|
||||
@ -1405,61 +1408,65 @@ accounting for `collapsed_slice_dims` (i.e. we pick
|
||||
`adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes`
|
||||
with the bounds at indices `collapsed_slice_dims` removed).
|
||||
|
||||
Formally, the operand index `In` corresponding to an output index `Out` is
|
||||
computed as follows:
|
||||
Formally, the operand index `In` corresponding to a given output index `Out` is
|
||||
calculated as follows:
|
||||
|
||||
1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out
|
||||
vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where
|
||||
Combine(A, b) inserts b at position `index_vector_dim` into A. Note that
|
||||
this is well defined even if `G` is empty -- if `G` is empty then `S` =
|
||||
`start_indices`.
|
||||
1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out a
|
||||
vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where
|
||||
Combine(A, b) inserts b at position `index_vector_dim` into A. Note that
|
||||
this is well defined even if `G` is empty -- if `G` is empty then `S` =
|
||||
`start_indices`.
|
||||
|
||||
2. Create a starting index, `S`<sub>`in`</sub>, into `operand` using `S` by
|
||||
scattering `S` using `start_index_map`. More precisely:
|
||||
1. `S`<sub>`in`</sub>[`start_index_map`[`k`]] = `S`[`k`] if `k` <
|
||||
`start_index_map.size`.
|
||||
2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
|
||||
2. Create a starting index, `S`<sub>`in`</sub>, into `operand` using `S` by
|
||||
scattering `S` using `start_index_map`. More precisely:
|
||||
|
||||
3. Create an index `O`<sub>`in`</sub> into `operand` by scattering the indices
|
||||
at the offset dimensions in `Out` according to the `collapsed_slice_dims`
|
||||
set. More precisely:
|
||||
1. `O`<sub>`in`</sub>[`expand_offset_dims`(`k`)] =
|
||||
`Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size`
|
||||
(`expand_offset_dims` is defined below).
|
||||
2. `O`<sub>`in`</sub>[`_`] = `0` otherwise.
|
||||
4. `In` is `O`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
|
||||
addition.
|
||||
1. `S`<sub>`in`</sub>[`start_index_map`[`k`]] = `S`[`k`] if `k` <
|
||||
`start_index_map.size`.
|
||||
|
||||
`expand_offset_dims` is the monotonic function with domain [`0`, `offset.size`)
|
||||
and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g.,
|
||||
2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
|
||||
|
||||
3. Create an index `O`<sub>`in`</sub> into `operand` by scattering the indices
|
||||
at the offset dimensions in `Out` according to the `collapsed_slice_dims`
|
||||
set. More precisely:
|
||||
|
||||
1. `O`<sub>`in`</sub>[`remapped_offset_dims`(`k`)] =
|
||||
`Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size`
|
||||
(`remapped_offset_dims` is defined below).
|
||||
|
||||
2. `O`<sub>`in`</sub>[`_`] = `0` otherwise.
|
||||
|
||||
4. `In` is `O`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
|
||||
addition.
|
||||
|
||||
`remapped_offset_dims` is a monotonic function with domain [`0`, `offset.size`)
|
||||
and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g.,
|
||||
`offset.size` is `4`, `operand.rank` is `6` and `collapsed_slice_dims` is {`0`,
|
||||
`2`} then `expand_offset_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}.
|
||||
`2`} then `remapped_offset_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}.
|
||||
|
||||
### Informal Description and Examples
|
||||
|
||||
Informally, every index `Out` in the output array corresponds to an element `E`
|
||||
in the operand array, computed as follows:
|
||||
|
||||
- We use the batch dimensions in `Out` to look up a starting index from
|
||||
`start_indices`.
|
||||
- We use the batch dimensions in `Out` to look up a starting index from
|
||||
`start_indices`.
|
||||
|
||||
- We use `start_index_map` to map the starting index (which may have size less
|
||||
than operand.rank) to a "full" starting index into operand.
|
||||
- We use `start_index_map` to map the starting index (whose size may be less
|
||||
than operand.rank) to a "full" starting index into the `operand`.
|
||||
|
||||
- We dynamic-slice out a slice with size `slice_sizes` using the full starting
|
||||
index.
|
||||
- We dynamic-slice out a slice with size `slice_sizes` using the full starting
|
||||
index.
|
||||
|
||||
- We reshape the slice by collapsing the `collapsed_slice_dims` dimensions.
|
||||
Since all collapsed slice dimensions have to have bound 1 this reshape is
|
||||
always legal.
|
||||
- We reshape the slice by collapsing the `collapsed_slice_dims` dimensions.
|
||||
Since all collapsed slice dimensions must have a bound of 1, this reshape is
|
||||
always legal.
|
||||
|
||||
- We use the offset dimensions in `Out` to index into this slice to get the
|
||||
input element, `E`, corresponding to output index `Out`.
|
||||
- We use the offset dimensions in `Out` to index into this slice to get the
|
||||
input element, `E`, corresponding to output index `Out`.
|
||||
|
||||
`index_vector_dim` is set to `start_indices.rank` - `1` in all of the
|
||||
examples that follow. More interesting values for `index_vector_dim` does not
|
||||
change the operation fundamentally, but makes the visual representation more
|
||||
cumbersome.
|
||||
`index_vector_dim` is set to `start_indices.rank` - `1` in all of the examples
|
||||
that follow. More interesting values for `index_vector_dim` do not change the
|
||||
operation fundamentally, but make the visual representation more cumbersome.
|
||||
|
||||
To get an intuition on how all of the above fits together, let's look at an
|
||||
example that gathers 5 slices of shape `[8,6]` from a `[16,11]` array. The
|
||||
@ -1526,12 +1533,12 @@ As a final example, we use (2) and (3) to implement `tf.gather_nd`:
|
||||
|
||||
`G`<sub>`0`</sub> and `G`<sub>`1`</sub> are used to slice out a starting index
|
||||
from the gather indices array as usual, except the starting index has only one
|
||||
element, `X`. Similarly, there is only one output offset index with the value
|
||||
`O`<sub>`0`</sub>. However, before being used as indices into the input array,
|
||||
element, `X`. Similarly, there is only one output offset index with the value
|
||||
`O`<sub>`0`</sub>. However, before being used as indices into the input array,
|
||||
these are expanded in accordance to "Gather Index Mapping" (`start_index_map` in
|
||||
the formal description) and "Offset Mapping" (`expand_offset_dims` in the formal
|
||||
description) into [`X`,`0`] and [`0`,`O`<sub>`0`</sub>] respectively, adding up
|
||||
to [`X`,`O`<sub>`0`</sub>]. In other words, the output index
|
||||
the formal description) and "Offset Mapping" (`remapped_offset_dims` in the
|
||||
formal description) into [`X`,`0`] and [`0`,`O`<sub>`0`</sub>] respectively,
|
||||
adding up to [`X`,`O`<sub>`0`</sub>]. In other words, the output index
|
||||
[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`O`<sub>`0`</sub>] maps to the input index
|
||||
[`GatherIndices`[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`0`],`X`] which gives us
|
||||
the semantics for `tf.gather_nd`.
|
||||
|
Loading…
Reference in New Issue
Block a user