[XLA] Add AllGather documentation
PiperOrigin-RevId: 352450526 Change-Id: Iaf4fd10cbdaf41c443793b3e853c722fc6b7469b
This commit is contained in:
parent
da0884c7d4
commit
673079ce30
@ -29,6 +29,45 @@ Arguments | Type | Semantics
|
||||
---------- | ------- | -------------------------
|
||||
`operands` | `XlaOp` | variadic number of tokens
|
||||
|
||||
## AllGather
|
||||
|
||||
See also
|
||||
[`XlaBuilder::AllGather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
|
||||
|
||||
Performs concatenation across replicas.
|
||||
|
||||
<b> `AllGather(operand, all_gather_dim, shard_count, replica_group_ids,
|
||||
channel_id)` </b>
|
||||
|
||||
| Arguments | Type | Semantics |
|
||||
| ---------------- | -------------------- | --------------------------- |
|
||||
| `operand` | `XlaOp` | Array to concatenate across |
|
||||
: : : replicas. :
|
||||
| `all_gather_dim` | `int64` | Concatenation dimension. |
|
||||
| `replica_groups` | vector of vectors of | Groups between which the |
|
||||
: : `int64` : concatenation is performed. :
|
||||
| `channel_id` | optional `int64` | Optional channel ID for |
|
||||
: : : cross-module communication. :
|
||||
|
||||
- `replica_groups` is a list of replica groups between which the concatenation
|
||||
is performed (replica id for the current replica can be retrieved using
|
||||
[`ReplicaId`](#replicaid)). The order of replicas in each group determines
|
||||
the order in which their inputs are located in the result. `replica_groups`
|
||||
must either be empty (in which case all replicas belong to a single group,
|
||||
ordered from `0` to `N - 1`), or contain the same number of elements as the
|
||||
number of replicas. For example, `replica_groups = {0, 2}, {1, 3}` performs
|
||||
concatenation between the replicas `0` and `2`, and `1` and `3`.
|
||||
- `shard_count` is the size of each replica group. We need this in cases where
|
||||
`replica_groups` are empty.
|
||||
- `channel_id` is used for cross-module communication: only `all-gather`
|
||||
operations with the same `channel_id` can communicate to each other.
|
||||
|
||||
The output shape is the input shape with the `all_gather_dim` made `shard_count`
|
||||
times larger. For example, if there are two replicas and the operand has the
|
||||
value `[1.0, 2.5]` and `[3.0, 5.25]` respectively on the two replicas, then the
|
||||
output value from this op where `all_gather_dim` is `0` will be `[1.0, 2.5, 3.0,
|
||||
5.25]` on both replicas.
|
||||
|
||||
## AllReduce
|
||||
|
||||
See also
|
||||
|
Loading…
Reference in New Issue
Block a user