[XLA] Modify operation semantics description for XLA Gather. NFC.

PiperOrigin-RevId: 257298693
A. Unique TensorFlower 2019-07-09 16:47:51 -07:00 committed by TensorFlower Gardener
parent 7a6aa853e1
commit e70b1dd882


@ -1367,12 +1367,12 @@ For a more intuitive description, see the "Informal Description" section below.
: : : detailed description. :
| `offset_dims` | `ArraySlice<int64>` | The set of dimensions in the |
: : : output shape that offset into :
: : : an array sliced from operand. :
| `slice_sizes` | `ArraySlice<int64>` | `slice_sizes[i]` is the |
: : : bounds for the slice on :
: : : dimension `i`. :
| `collapsed_slice_dims` | `ArraySlice<int64>` | The set of dimensions in each |
: : : slice that are collapsed :
: : : away. These dimensions must :
: : : have size 1. :
| `start_index_map` | `ArraySlice<int64>` | A map that describes how to |
@ -1383,8 +1383,11 @@ For a more intuitive description, see the "Informal Description" section below.
For convenience, we label dimensions in the output array not in `offset_dims`
as `batch_dims`.
The output is an array of rank `batch_dims.size` + `offset_dims.size`.
The `operand.rank` must equal the sum of `offset_dims.size` and
`collapsed_slice_dims.size`. Also, `slice_sizes.size` has to be equal to
`operand.rank`.
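
To make this rank bookkeeping concrete, here is a minimal Python sketch; the
shapes and dimension numbers are hypothetical, chosen only for illustration:

```python
# Hypothetical dimension numbers for a simple row gather from a [16, 11] operand.
offset_dims          = (1,)     # output dims that offset into each slice
collapsed_slice_dims = (0,)     # sliced dims of size 1 that are collapsed away
slice_sizes          = (1, 11)  # one entry per operand dimension
operand_rank         = 2
start_indices_shape  = (5, 1)   # index_vector_dim == 1, the trailing dimension

# operand.rank must equal offset_dims.size + collapsed_slice_dims.size,
# and slice_sizes.size must equal operand.rank.
assert operand_rank == len(offset_dims) + len(collapsed_slice_dims)
assert len(slice_sizes) == operand_rank

# The output rank is batch_dims.size + offset_dims.size.
batch_dims_size = len(start_indices_shape) - 1    # start_indices minus index_vector_dim
output_rank = batch_dims_size + len(offset_dims)  # here: 1 + 1 == 2
```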
If `index_vector_dim` is equal to `start_indices.rank` we implicitly consider
`start_indices` to have a trailing `1` dimension (i.e. if `start_indices` was of
@ -1405,61 +1408,65 @@ accounting for `collapsed_slice_dims` (i.e. we pick
`adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes`
with the bounds at indices `collapsed_slice_dims` removed).
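
For example (the values are hypothetical), dropping the collapsed dimensions
from `slice_sizes` looks like this in Python:

```python
# Hypothetical slice sizes; dimension 0 is collapsed away.
slice_sizes = (1, 9, 8, 6)
collapsed_slice_dims = (0,)
adjusted_slice_sizes = tuple(s for d, s in enumerate(slice_sizes)
                             if d not in collapsed_slice_dims)
print(adjusted_slice_sizes)  # (9, 8, 6)
```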
Formally, the operand index `In` corresponding to a given output index `Out` is
calculated as follows (a Python sketch of this mapping is given after the
definition of `remapped_offset_dims` below):
1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out a
   vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where
   Combine(A, b) inserts b at position `index_vector_dim` into A. Note that
   this is well defined even if `G` is empty -- if `G` is empty then `S` =
   `start_indices`.
2. Create a starting index, `S`<sub>`in`</sub>, into `operand` using `S` by
   scattering `S` using `start_index_map`. More precisely:
   1. `S`<sub>`in`</sub>[`start_index_map`[`k`]] = `S`[`k`] if `k` <
      `start_index_map.size`.
   2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
3. Create an index `O`<sub>`in`</sub> into `operand` by scattering the indices
   at the offset dimensions in `Out` according to the `collapsed_slice_dims`
   set. More precisely:
   1. `O`<sub>`in`</sub>[`remapped_offset_dims`(`k`)] =
      `Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size`
      (`remapped_offset_dims` is defined below).
   2. `O`<sub>`in`</sub>[`_`] = `0` otherwise.
4. `In` is `O`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
   addition.

`remapped_offset_dims` is a monotonic function with domain [`0`,
`offset_dims.size`) and range [`0`, `operand.rank`) \ `collapsed_slice_dims`.
So if, e.g., `offset_dims.size` is `4`, `operand.rank` is `6` and
`collapsed_slice_dims` is {`0`,
`2`} then `remapped_offset_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}.
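
The following plain-NumPy sketch walks through steps 1-4 above for every output
index. It is illustrative only: the function name is made up, it assumes
`index_vector_dim` < `start_indices.rank` (as in all the examples below), and
it ignores XLA's clamping of out-of-bounds start indices.

```python
import numpy as np

def gather_reference(operand, start_indices, offset_dims, collapsed_slice_dims,
                     start_index_map, index_vector_dim, slice_sizes):
    """Illustrative-only sketch of the formal Gather index mapping above."""
    # Batch dims of the output come from start_indices minus index_vector_dim;
    # offset dims come from slice_sizes with the collapsed dims removed.
    batch_shape = [s for d, s in enumerate(start_indices.shape)
                   if d != index_vector_dim]
    adjusted_slice_sizes = [s for d, s in enumerate(slice_sizes)
                            if d not in collapsed_slice_dims]
    output_rank = len(batch_shape) + len(offset_dims)
    batch_dims = [d for d in range(output_rank) if d not in offset_dims]

    output_shape = [0] * output_rank
    for d, s in zip(batch_dims, batch_shape):
        output_shape[d] = s
    for d, s in zip(offset_dims, adjusted_slice_sizes):
        output_shape[d] = s

    # remapped_offset_dims: [0, offset_dims.size) -> operand dims not collapsed.
    remapped_offset_dims = [d for d in range(operand.ndim)
                            if d not in collapsed_slice_dims]

    output = np.empty(output_shape, dtype=operand.dtype)
    for out_idx in np.ndindex(*output_shape):
        # Step 1: G and S.  Combine(G, i) inserts i at position index_vector_dim.
        g = [out_idx[k] for k in batch_dims]
        s = [start_indices[tuple(g[:index_vector_dim] + [i] + g[index_vector_dim:])]
             for i in range(start_indices.shape[index_vector_dim])]
        # Step 2: scatter S into S_in using start_index_map.
        s_in = [0] * operand.ndim
        for k, d in enumerate(start_index_map):
            s_in[d] = s[k]
        # Step 3: scatter the offset coordinates of Out into O_in.
        o_in = [0] * operand.ndim
        for k, d in enumerate(offset_dims):
            o_in[remapped_offset_dims[k]] = out_idx[d]
        # Step 4: In = O_in + S_in (element-wise).
        output[out_idx] = operand[tuple(a + b for a, b in zip(o_in, s_in))]
    return output

# Example with hypothetical values: gather 5 full rows from a [16, 11] operand.
op = np.arange(16 * 11).reshape(16, 11)
idx = np.array([[2], [7], [0], [9], [4]])          # start_indices, shape [5, 1]
result = gather_reference(op, idx, offset_dims=(1,), collapsed_slice_dims=(0,),
                          start_index_map=(0,), index_vector_dim=1,
                          slice_sizes=(1, 11))
assert np.array_equal(result, op[idx[:, 0]])       # a plain row gather
```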
### Informal Description and Examples
Informally, every index `Out` in the output array corresponds to an element `E`
in the operand array, computed as follows:
- We use the batch dimensions in `Out` to look up a starting index from
  `start_indices`.
- We use `start_index_map` to map the starting index (whose size may be less
  than operand.rank) to a "full" starting index into the `operand`.
- We dynamic-slice out a slice with size `slice_sizes` using the full starting
  index.
- We reshape the slice by collapsing the `collapsed_slice_dims` dimensions.
  Since all collapsed slice dimensions must have a bound of 1, this reshape is
  always legal.
- We use the offset dimensions in `Out` to index into this slice to get the
  input element, `E`, corresponding to output index `Out`.
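
As a concrete, made-up illustration of this flow, the sketch below uses
`jax.lax.gather`, which exposes the same dimension numbers (JAX fixes
`index_vector_dim` to the trailing dimension of `start_indices`, matching the
convention used in the examples here). It gathers five `[8,6]` slices out of a
`[16,11]` operand, the same shapes used in the walkthrough below; the index
values are arbitrary.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.arange(16 * 11, dtype=jnp.float32).reshape(16, 11)
# Five 2-element starting indices (shape [5, 2]); values are arbitrary but
# chosen so every [8, 6] slice stays in bounds.
start_indices = jnp.array([[0, 0], [2, 3], [4, 1], [8, 5], [6, 0]])

dnums = lax.GatherDimensionNumbers(
    offset_dims=(1, 2),        # output dims 1 and 2 offset into each [8, 6] slice
    collapsed_slice_dims=(),   # no slice dimension is collapsed away
    start_index_map=(0, 1),    # start-index element i becomes operand dim i
)
result = lax.gather(operand, start_indices, dnums, slice_sizes=(8, 6))
print(result.shape)            # (5, 8, 6): one [8, 6] slice per starting index
```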
`index_vector_dim` is set to `start_indices.rank` - `1` in all of the examples
that follow. More interesting values for `index_vector_dim` do not change the
operation fundamentally, but make the visual representation more cumbersome.
To get an intuition on how all of the above fits together, let's look at an
example that gathers 5 slices of shape `[8,6]` from a `[16,11]` array. The
@ -1526,12 +1533,12 @@ As a final example, we use (2) and (3) to implement `tf.gather_nd`:
`G`<sub>`0`</sub> and `G`<sub>`1`</sub> are used to slice out a starting index
from the gather indices array as usual, except the starting index has only one
element, `X`. Similarly, there is only one output offset index with the value
`O`<sub>`0`</sub>. However, before being used as indices into the input array,
these are expanded in accordance with "Gather Index Mapping" (`start_index_map`
in the formal description) and "Offset Mapping" (`remapped_offset_dims` in the
formal description) into [`X`,`0`] and [`0`,`O`<sub>`0`</sub>] respectively,
adding up to [`X`,`O`<sub>`0`</sub>]. In other words, the output index
[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`O`<sub>`0`</sub>] maps to the input index
[`GatherIndices`[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`0`],`O`<sub>`0`</sub>] which gives us
the semantics for `tf.gather_nd`.
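
A hedged `jax.lax.gather` sketch of this `tf.gather_nd`-style configuration
follows; the shapes and index values are made up for illustration, and JAX
pins `index_vector_dim` to the trailing dimension of the indices.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical shapes: a [16, 11] input and a [3, 2, 1] gather-indices array,
# i.e. tf.gather_nd with index depth 1 (each index picks out a whole row).
operand = jnp.arange(16 * 11, dtype=jnp.float32).reshape(16, 11)
gather_indices = jnp.array([[[2], [7]], [[0], [9]], [[4], [15]]])  # shape [3, 2, 1]

dnums = lax.GatherDimensionNumbers(
    offset_dims=(2,),            # O_0: the surviving column dimension of each slice
    collapsed_slice_dims=(0,),   # the sliced row dimension (size 1) is collapsed
    start_index_map=(0,),        # the single index element X maps to operand dim 0
)
result = lax.gather(operand, gather_indices, dnums, slice_sizes=(1, 11))
print(result.shape)                                   # (3, 2, 11)
# Output index [G0, G1, O0] reads operand[gather_indices[G0, G1, 0], O0]:
assert jnp.array_equal(result, operand[gather_indices[..., 0]])
```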