[XLA:TPU] Implement 2D AllGather algorithm with use_global_device_ids = true.

PiperOrigin-RevId: 348684836
Change-Id: I31339c215221cb3aff281852096d2bd25933795c
Authored by A. Unique TensorFlower on 2020-12-22 13:29:49 -08:00; committed by TensorFlower Gardener.
Parent: 9d2d8ca23d
Commit: 13329f5be2
3 changed files with 20 additions and 10 deletions
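
For orientation, a minimal usage sketch (not part of this change) of how a caller might pass the new use_global_device_ids argument through the xla::AllGather client API extended below. Only the AllGather signature comes from this commit; the 2x2 device mesh, the group contents, the channel handle, and the helper name are illustrative assumptions.

// Illustrative sketch: gather along dimension 0 within each row of an assumed
// 2x2 device mesh, naming devices by their global (flattened) ids.
#include "tensorflow/compiler/xla/client/xla_builder.h"

xla::XlaOp RowAllGather(xla::XlaOp operand) {
  // With use_global_device_ids = true, each group lists global device ids
  // rather than replica ids.
  xla::ReplicaGroup row0, row1;
  row0.add_replica_ids(0);
  row0.add_replica_ids(1);
  row1.add_replica_ids(2);
  row1.add_replica_ids(3);

  // Collectives that use global device ids are typically paired with a
  // channel handle; the handle id here is made up.
  xla::ChannelHandle channel;
  channel.set_handle(1);
  channel.set_type(xla::ChannelHandle::DEVICE_TO_DEVICE);

  return xla::AllGather(operand, /*all_gather_dimension=*/0,
                        /*shard_count=*/2, {row0, row1}, channel,
                        /*layout=*/absl::nullopt,
                        /*use_global_device_ids=*/true);
}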


@@ -2704,7 +2704,8 @@ XlaOp XlaBuilder::AllGather(XlaOp operand, int64 all_gather_dimension,
                             int64 shard_count,
                             absl::Span<const ReplicaGroup> replica_groups,
                             const absl::optional<ChannelHandle>& channel_id,
-                            const absl::optional<Layout>& layout) {
+                            const absl::optional<Layout>& layout,
+                            const absl::optional<bool> use_global_device_ids) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
@@ -2725,6 +2726,9 @@ XlaOp XlaBuilder::AllGather(XlaOp operand, int64 all_gather_dimension,
     if (channel_id.has_value()) {
       instr.set_channel_id(channel_id->handle());
     }
+    if (use_global_device_ids.has_value()) {
+      instr.set_use_global_device_ids(use_global_device_ids.value());
+    }

     TF_ASSIGN_OR_RETURN(
         auto all_gather,
@@ -4549,10 +4553,11 @@ XlaOp AllGather(const XlaOp operand, int64 all_gather_dimension,
                 int64 shard_count,
                 absl::Span<const ReplicaGroup> replica_groups,
                 const absl::optional<ChannelHandle>& channel_id,
-                const absl::optional<Layout>& layout) {
+                const absl::optional<Layout>& layout,
+                const absl::optional<bool> use_global_device_ids) {
   return operand.builder()->AllGather(operand, all_gather_dimension,
                                       shard_count, replica_groups, channel_id,
-                                      layout);
+                                      layout, use_global_device_ids);
 }

 XlaOp CrossReplicaSum(const XlaOp operand,


@@ -731,7 +731,8 @@ class XlaBuilder {
       XlaOp operand, int64 all_gather_dimension, int64 shard_count,
       absl::Span<const ReplicaGroup> replica_groups = {},
       const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
-      const absl::optional<Layout>& layout = absl::nullopt);
+      const absl::optional<Layout>& layout = absl::nullopt,
+      const absl::optional<bool> use_global_device_ids = absl::nullopt);

   XlaOp AllReduce(
       XlaOp operand, const XlaComputation& computation,
@@ -1286,7 +1287,8 @@
                          int64 shard_count,
                          absl::Span<const ReplicaGroup> replica_groups,
                          const absl::optional<ChannelHandle>& channel_id,
-                         const absl::optional<Layout>& layout);
+                         const absl::optional<Layout>& layout,
+                         const absl::optional<bool> use_global_device_ids);
   friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
                          absl::Span<const ReplicaGroup> replica_groups,
                          const absl::optional<ChannelHandle>& channel_id,
@@ -2161,10 +2163,12 @@ XlaOp ReduceWindowWithGeneralPadding(
 XlaOp CrossReplicaSum(XlaOp operand,
                       absl::Span<const ReplicaGroup> replica_groups = {});

-XlaOp AllGather(XlaOp operand, int64 all_gather_dimension, int64 shard_count,
-                absl::Span<const ReplicaGroup> replica_groups = {},
-                const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
-                const absl::optional<Layout>& layout = absl::nullopt);
+XlaOp AllGather(
+    XlaOp operand, int64 all_gather_dimension, int64 shard_count,
+    absl::Span<const ReplicaGroup> replica_groups = {},
+    const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
+    const absl::optional<Layout>& layout = absl::nullopt,
+    const absl::optional<bool> use_global_device_ids = absl::nullopt);

 // Enqueues an operation that do an AllReduce of the operand cross cores. Here
 // AllReduce means doing a reduction on the input operand cross cores and then

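Because the new parameter is a trailing optional that defaults to absl::nullopt in both the XlaBuilder method and the free function declared above, existing call sites compile unchanged. A rough sketch under that assumption (the builder name, shape, and variable names are illustrative, not from this commit):

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

void AllGatherDefaultsSketch() {
  xla::XlaBuilder builder("all_gather_example");
  xla::XlaOp p = xla::Parameter(
      &builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {8, 128}), "p");

  // Pre-existing call style: use_global_device_ids is simply omitted and
  // stays absl::nullopt, so the groups keep their replica-id interpretation.
  xla::XlaOp ag_old_style =
      xla::AllGather(p, /*all_gather_dimension=*/0, /*shard_count=*/4);

  // Passing the new argument explicitly as absl::nullopt is equivalent to
  // omitting it; only passing a value opts in to the new interpretation.
  xla::XlaOp ag_explicit = xla::AllGather(
      p, /*all_gather_dimension=*/0, /*shard_count=*/4,
      /*replica_groups=*/{}, /*channel_id=*/absl::nullopt,
      /*layout=*/absl::nullopt, /*use_global_device_ids=*/absl::nullopt);
  (void)ag_old_style;
  (void)ag_explicit;
}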

@@ -55,7 +55,8 @@ void BuildOpsSubmodule(py::module* m) {
       py::arg("all_gather_dimension"), py::arg("shard_count"),
       py::arg("replica_groups") = py::list(),
       py::arg("channel_id") = absl::nullopt,
-      py::arg("shape_with_layout") = absl::nullopt);
+      py::arg("shape_with_layout") = absl::nullopt,
+      py::arg("use_global_device_ids") = absl::nullopt);
   ops.def(
       "AllReduce",
       static_cast<XlaOp (*)(