[XLA:TPU] Implement 2D AllGather algorithm with use_global_device_ids = true.
PiperOrigin-RevId: 348684836 Change-Id: I31339c215221cb3aff281852096d2bd25933795c
This commit is contained in:
parent
9d2d8ca23d
commit
13329f5be2
@ -2704,7 +2704,8 @@ XlaOp XlaBuilder::AllGather(XlaOp operand, int64 all_gather_dimension,
|
||||
int64 shard_count,
|
||||
absl::Span<const ReplicaGroup> replica_groups,
|
||||
const absl::optional<ChannelHandle>& channel_id,
|
||||
const absl::optional<Layout>& layout) {
|
||||
const absl::optional<Layout>& layout,
|
||||
const absl::optional<bool> use_global_device_ids) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
||||
@ -2725,6 +2726,9 @@ XlaOp XlaBuilder::AllGather(XlaOp operand, int64 all_gather_dimension,
|
||||
if (channel_id.has_value()) {
|
||||
instr.set_channel_id(channel_id->handle());
|
||||
}
|
||||
if (use_global_device_ids.has_value()) {
|
||||
instr.set_use_global_device_ids(use_global_device_ids.value());
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto all_gather,
|
||||
@ -4549,10 +4553,11 @@ XlaOp AllGather(const XlaOp operand, int64 all_gather_dimension,
|
||||
int64 shard_count,
|
||||
absl::Span<const ReplicaGroup> replica_groups,
|
||||
const absl::optional<ChannelHandle>& channel_id,
|
||||
const absl::optional<Layout>& layout) {
|
||||
const absl::optional<Layout>& layout,
|
||||
const absl::optional<bool> use_global_device_ids) {
|
||||
return operand.builder()->AllGather(operand, all_gather_dimension,
|
||||
shard_count, replica_groups, channel_id,
|
||||
layout);
|
||||
layout, use_global_device_ids);
|
||||
}
|
||||
|
||||
XlaOp CrossReplicaSum(const XlaOp operand,
|
||||
|
||||
@ -731,7 +731,8 @@ class XlaBuilder {
|
||||
XlaOp operand, int64 all_gather_dimension, int64 shard_count,
|
||||
absl::Span<const ReplicaGroup> replica_groups = {},
|
||||
const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
|
||||
const absl::optional<Layout>& layout = absl::nullopt);
|
||||
const absl::optional<Layout>& layout = absl::nullopt,
|
||||
const absl::optional<bool> use_global_device_ids = absl::nullopt);
|
||||
|
||||
XlaOp AllReduce(
|
||||
XlaOp operand, const XlaComputation& computation,
|
||||
@ -1286,7 +1287,8 @@ class XlaBuilder {
|
||||
int64 shard_count,
|
||||
absl::Span<const ReplicaGroup> replica_groups,
|
||||
const absl::optional<ChannelHandle>& channel_id,
|
||||
const absl::optional<Layout>& layout);
|
||||
const absl::optional<Layout>& layout,
|
||||
const absl::optional<bool> use_global_device_ids);
|
||||
friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
|
||||
absl::Span<const ReplicaGroup> replica_groups,
|
||||
const absl::optional<ChannelHandle>& channel_id,
|
||||
@ -2161,10 +2163,12 @@ XlaOp ReduceWindowWithGeneralPadding(
|
||||
XlaOp CrossReplicaSum(XlaOp operand,
|
||||
absl::Span<const ReplicaGroup> replica_groups = {});
|
||||
|
||||
XlaOp AllGather(XlaOp operand, int64 all_gather_dimension, int64 shard_count,
|
||||
absl::Span<const ReplicaGroup> replica_groups = {},
|
||||
const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
|
||||
const absl::optional<Layout>& layout = absl::nullopt);
|
||||
XlaOp AllGather(
|
||||
XlaOp operand, int64 all_gather_dimension, int64 shard_count,
|
||||
absl::Span<const ReplicaGroup> replica_groups = {},
|
||||
const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
|
||||
const absl::optional<Layout>& layout = absl::nullopt,
|
||||
const absl::optional<bool> use_global_device_ids = absl::nullopt);
|
||||
|
||||
// Enqueues an operation that does an AllReduce of the operand across cores. Here
|
||||
// AllReduce means doing a reduction on the input operand across cores and then
|
||||
|
||||
@ -55,7 +55,8 @@ void BuildOpsSubmodule(py::module* m) {
|
||||
py::arg("all_gather_dimension"), py::arg("shard_count"),
|
||||
py::arg("replica_groups") = py::list(),
|
||||
py::arg("channel_id") = absl::nullopt,
|
||||
py::arg("shape_with_layout") = absl::nullopt);
|
||||
py::arg("shape_with_layout") = absl::nullopt,
|
||||
py::arg("use_global_device_ids") = absl::nullopt);
|
||||
ops.def(
|
||||
"AllReduce",
|
||||
static_cast<XlaOp (*)(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user