[XLA:TPU] Implement 2D AllGather algorithm with use_global_device_ids = true.
PiperOrigin-RevId: 348684836 Change-Id: I31339c215221cb3aff281852096d2bd25933795c
This commit is contained in:
parent
9d2d8ca23d
commit
13329f5be2
@ -2704,7 +2704,8 @@ XlaOp XlaBuilder::AllGather(XlaOp operand, int64 all_gather_dimension,
|
|||||||
int64 shard_count,
|
int64 shard_count,
|
||||||
absl::Span<const ReplicaGroup> replica_groups,
|
absl::Span<const ReplicaGroup> replica_groups,
|
||||||
const absl::optional<ChannelHandle>& channel_id,
|
const absl::optional<ChannelHandle>& channel_id,
|
||||||
const absl::optional<Layout>& layout) {
|
const absl::optional<Layout>& layout,
|
||||||
|
const absl::optional<bool> use_global_device_ids) {
|
||||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
HloInstructionProto instr;
|
HloInstructionProto instr;
|
||||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
||||||
@ -2725,6 +2726,9 @@ XlaOp XlaBuilder::AllGather(XlaOp operand, int64 all_gather_dimension,
|
|||||||
if (channel_id.has_value()) {
|
if (channel_id.has_value()) {
|
||||||
instr.set_channel_id(channel_id->handle());
|
instr.set_channel_id(channel_id->handle());
|
||||||
}
|
}
|
||||||
|
if (use_global_device_ids.has_value()) {
|
||||||
|
instr.set_use_global_device_ids(use_global_device_ids.value());
|
||||||
|
}
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
auto all_gather,
|
auto all_gather,
|
||||||
@ -4549,10 +4553,11 @@ XlaOp AllGather(const XlaOp operand, int64 all_gather_dimension,
|
|||||||
int64 shard_count,
|
int64 shard_count,
|
||||||
absl::Span<const ReplicaGroup> replica_groups,
|
absl::Span<const ReplicaGroup> replica_groups,
|
||||||
const absl::optional<ChannelHandle>& channel_id,
|
const absl::optional<ChannelHandle>& channel_id,
|
||||||
const absl::optional<Layout>& layout) {
|
const absl::optional<Layout>& layout,
|
||||||
|
const absl::optional<bool> use_global_device_ids) {
|
||||||
return operand.builder()->AllGather(operand, all_gather_dimension,
|
return operand.builder()->AllGather(operand, all_gather_dimension,
|
||||||
shard_count, replica_groups, channel_id,
|
shard_count, replica_groups, channel_id,
|
||||||
layout);
|
layout, use_global_device_ids);
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp CrossReplicaSum(const XlaOp operand,
|
XlaOp CrossReplicaSum(const XlaOp operand,
|
||||||
|
|||||||
@ -731,7 +731,8 @@ class XlaBuilder {
|
|||||||
XlaOp operand, int64 all_gather_dimension, int64 shard_count,
|
XlaOp operand, int64 all_gather_dimension, int64 shard_count,
|
||||||
absl::Span<const ReplicaGroup> replica_groups = {},
|
absl::Span<const ReplicaGroup> replica_groups = {},
|
||||||
const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
|
const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
|
||||||
const absl::optional<Layout>& layout = absl::nullopt);
|
const absl::optional<Layout>& layout = absl::nullopt,
|
||||||
|
const absl::optional<bool> use_global_device_ids = absl::nullopt);
|
||||||
|
|
||||||
XlaOp AllReduce(
|
XlaOp AllReduce(
|
||||||
XlaOp operand, const XlaComputation& computation,
|
XlaOp operand, const XlaComputation& computation,
|
||||||
@ -1286,7 +1287,8 @@ class XlaBuilder {
|
|||||||
int64 shard_count,
|
int64 shard_count,
|
||||||
absl::Span<const ReplicaGroup> replica_groups,
|
absl::Span<const ReplicaGroup> replica_groups,
|
||||||
const absl::optional<ChannelHandle>& channel_id,
|
const absl::optional<ChannelHandle>& channel_id,
|
||||||
const absl::optional<Layout>& layout);
|
const absl::optional<Layout>& layout,
|
||||||
|
const absl::optional<bool> use_global_device_ids);
|
||||||
friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
|
friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
|
||||||
absl::Span<const ReplicaGroup> replica_groups,
|
absl::Span<const ReplicaGroup> replica_groups,
|
||||||
const absl::optional<ChannelHandle>& channel_id,
|
const absl::optional<ChannelHandle>& channel_id,
|
||||||
@ -2161,10 +2163,12 @@ XlaOp ReduceWindowWithGeneralPadding(
|
|||||||
XlaOp CrossReplicaSum(XlaOp operand,
|
XlaOp CrossReplicaSum(XlaOp operand,
|
||||||
absl::Span<const ReplicaGroup> replica_groups = {});
|
absl::Span<const ReplicaGroup> replica_groups = {});
|
||||||
|
|
||||||
XlaOp AllGather(XlaOp operand, int64 all_gather_dimension, int64 shard_count,
|
XlaOp AllGather(
|
||||||
absl::Span<const ReplicaGroup> replica_groups = {},
|
XlaOp operand, int64 all_gather_dimension, int64 shard_count,
|
||||||
const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
|
absl::Span<const ReplicaGroup> replica_groups = {},
|
||||||
const absl::optional<Layout>& layout = absl::nullopt);
|
const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
|
||||||
|
const absl::optional<Layout>& layout = absl::nullopt,
|
||||||
|
const absl::optional<bool> use_global_device_ids = absl::nullopt);
|
||||||
|
|
||||||
// Enqueues an operation that do an AllReduce of the operand cross cores. Here
|
// Enqueues an operation that do an AllReduce of the operand cross cores. Here
|
||||||
// AllReduce means doing a reduction on the input operand cross cores and then
|
// AllReduce means doing a reduction on the input operand cross cores and then
|
||||||
|
|||||||
@ -55,7 +55,8 @@ void BuildOpsSubmodule(py::module* m) {
|
|||||||
py::arg("all_gather_dimension"), py::arg("shard_count"),
|
py::arg("all_gather_dimension"), py::arg("shard_count"),
|
||||||
py::arg("replica_groups") = py::list(),
|
py::arg("replica_groups") = py::list(),
|
||||||
py::arg("channel_id") = absl::nullopt,
|
py::arg("channel_id") = absl::nullopt,
|
||||||
py::arg("shape_with_layout") = absl::nullopt);
|
py::arg("shape_with_layout") = absl::nullopt,
|
||||||
|
py::arg("use_global_device_ids") = absl::nullopt);
|
||||||
ops.def(
|
ops.def(
|
||||||
"AllReduce",
|
"AllReduce",
|
||||||
static_cast<XlaOp (*)(
|
static_cast<XlaOp (*)(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user