[XLA:CPU/GPU] Implement all-reduce and/or of pred

This uses the fact that we store pred as 8-bit integers, so we just forward to
the implementation of uint8 min/max.

PiperOrigin-RevId: 303779071
Change-Id: I899e787d0d0db023219d270cdc81f965781ff814
This commit is contained in:
Benjamin Kramer 2020-03-30 11:07:08 -07:00 committed by TensorFlower Gardener
parent 322cfa2290
commit 60d69096d2
5 changed files with 92 additions and 2 deletions

View File

@@ -30,13 +30,18 @@ absl::optional<ReductionKind> MatchReductionComputation(
.WithShape(m::Shape().IsEffectiveScalar()));
};
// Match the operation to a reduction kind. We can represent and/or of pred as
// min/max. This works because pred is stored as an 8-bit int of value 0 or 1.
PrimitiveType type = computation->root_instruction()->shape().element_type();
if (match_opcode(HloOpcode::kAdd)) {
return ReductionKind::SUM;
} else if (match_opcode(HloOpcode::kMultiply)) {
return ReductionKind::PRODUCT;
} else if (match_opcode(HloOpcode::kMinimum)) {
} else if (match_opcode(HloOpcode::kMinimum) ||
(type == PRED && match_opcode(HloOpcode::kAnd))) {
return ReductionKind::MIN;
} else if (match_opcode(HloOpcode::kMaximum)) {
} else if (match_opcode(HloOpcode::kMaximum) ||
(type == PRED && match_opcode(HloOpcode::kOr))) {
return ReductionKind::MAX;
} else {
return absl::nullopt;

View File

@@ -278,6 +278,7 @@ class CpuAllReduceRendezvous : public xla::Rendezvous<std::nullptr_t> {
case xla::S8:
DoAllReduce<xla::S8>(participant);
break;
case xla::PRED:
case xla::U8:
DoAllReduce<xla::U8>(participant);
break;

View File

@@ -1416,6 +1416,7 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
bool is_datatype_supported = [&] {
// TODO(cheshire): Fix duplication wrt. cpu_runtime
switch (datatype) {
case PRED:
case S8:
case U8:
case S32:

View File

@@ -173,6 +173,7 @@ absl::optional<ncclDataType_t> DatatypeToNccl(PrimitiveType element_type) {
switch (element_type) {
case S8:
return ncclInt8;
case PRED:
case U8:
return ncclUint8;
case S32:

View File

@@ -218,6 +218,88 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_half) {
TestAllOps<Eigen::half>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceAnd_Pred) {
  // Case 1: both replicas hold the same operand, so the and-reduction is the
  // identity on the input.
  TestTwoReplicasOneOperand<bool>(
      "and",
      /*input_value=*/LiteralUtil::CreateR1<bool>({true, false}),
      /*expected_value=*/LiteralUtil::CreateR1<bool>({true, false}));

  // Case 2: replica 0 contributes true and replica 1 contributes false, so
  // and-reducing across replicas must yield false everywhere.
  constexpr char kModuleStr[] = R"(
HloModule test

apply_op {
x = pred[] parameter(0)
y = pred[] parameter(1)
ROOT apply_op = pred[] and(x, y)
}

ENTRY test_computation {
id = u32[] replica-id()
c = u32[] constant(0)
p = pred[] compare(id, c), direction=EQ
p2 = pred[1] bitcast(p)
crs = pred[1] all-reduce(p2), replica_groups={}, to_apply=apply_op
copy = pred[1] copy(crs)
ROOT out = pred[1] bitcast(copy)
}
)";

  auto module_config = GetModuleConfigForTest();
  module_config.set_replica_count(2);
  auto verified_module =
      ParseAndReturnVerifiedModule(kModuleStr, module_config).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(verified_module), {},
                                            /*num_replicas=*/2,
                                            /*use_threads=*/true));
  const Literal expected = LiteralUtil::CreateR1<bool>({false});
  for (int i = 0; i < 2; ++i) {
    EXPECT_TRUE(LiteralTestUtil::Equal(expected, results[i]));
  }
}
XLA_TEST_F(CollectiveOpsTest, AllReduceOr_Pred) {
  // Case 1: both replicas hold the same operand, so the or-reduction is the
  // identity on the input.
  TestTwoReplicasOneOperand<bool>(
      "or",
      /*input_value=*/LiteralUtil::CreateR1<bool>({true, false}),
      /*expected_value=*/LiteralUtil::CreateR1<bool>({true, false}));

  // Case 2: replica 0 contributes true and replica 1 contributes false, so
  // or-reducing across replicas must yield true everywhere.
  constexpr char kModuleStr[] = R"(
HloModule test

apply_op {
x = pred[] parameter(0)
y = pred[] parameter(1)
ROOT apply_op = pred[] or(x, y)
}

ENTRY test_computation {
id = u32[] replica-id()
c = u32[] constant(0)
p = pred[] compare(id, c), direction=EQ
p2 = pred[1] bitcast(p)
crs = pred[1] all-reduce(p2), replica_groups={}, to_apply=apply_op
copy = pred[1] copy(crs)
ROOT out = pred[1] bitcast(copy)
}
)";

  auto module_config = GetModuleConfigForTest();
  module_config.set_replica_count(2);
  auto verified_module =
      ParseAndReturnVerifiedModule(kModuleStr, module_config).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(verified_module), {},
                                            /*num_replicas=*/2,
                                            /*use_threads=*/true));
  const Literal expected = LiteralUtil::CreateR1<bool>({true});
  for (int i = 0; i < 2; ++i) {
    EXPECT_TRUE(LiteralTestUtil::Equal(expected, results[i]));
  }
}
// Tries all-to-all operations across all 2^kNumDevices - 1 combinations of
// devices in sequence.
XLA_TEST_F(CollectiveOpsTest, AllReduce_AllCombinations) {