[XLA:CPU/GPU] Implement all-reduce and/or of pred
This uses the fact that we store pred as 8-bit integers, so we just forward to the implementation of uint8 min/max. PiperOrigin-RevId: 303779071 Change-Id: I899e787d0d0db023219d270cdc81f965781ff814
This commit is contained in:
parent
322cfa2290
commit
60d69096d2
@ -30,13 +30,18 @@ absl::optional<ReductionKind> MatchReductionComputation(
|
||||
.WithShape(m::Shape().IsEffectiveScalar()));
|
||||
};
|
||||
|
||||
// Match the operation to a reduction kind. We can represent and/or of pred as
|
||||
// min/max. This works because pred is stored as an 8-bit int of value 0 or 1.
|
||||
PrimitiveType type = computation->root_instruction()->shape().element_type();
|
||||
if (match_opcode(HloOpcode::kAdd)) {
|
||||
return ReductionKind::SUM;
|
||||
} else if (match_opcode(HloOpcode::kMultiply)) {
|
||||
return ReductionKind::PRODUCT;
|
||||
} else if (match_opcode(HloOpcode::kMinimum)) {
|
||||
} else if (match_opcode(HloOpcode::kMinimum) ||
|
||||
(type == PRED && match_opcode(HloOpcode::kAnd))) {
|
||||
return ReductionKind::MIN;
|
||||
} else if (match_opcode(HloOpcode::kMaximum)) {
|
||||
} else if (match_opcode(HloOpcode::kMaximum) ||
|
||||
(type == PRED && match_opcode(HloOpcode::kOr))) {
|
||||
return ReductionKind::MAX;
|
||||
} else {
|
||||
return absl::nullopt;
|
||||
|
||||
@ -278,6 +278,7 @@ class CpuAllReduceRendezvous : public xla::Rendezvous<std::nullptr_t> {
|
||||
case xla::S8:
|
||||
DoAllReduce<xla::S8>(participant);
|
||||
break;
|
||||
case xla::PRED:
|
||||
case xla::U8:
|
||||
DoAllReduce<xla::U8>(participant);
|
||||
break;
|
||||
|
||||
@ -1416,6 +1416,7 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
|
||||
bool is_datatype_supported = [&] {
|
||||
// TODO(cheshire): Fix duplication wrt. cpu_runtime
|
||||
switch (datatype) {
|
||||
case PRED:
|
||||
case S8:
|
||||
case U8:
|
||||
case S32:
|
||||
|
||||
@ -173,6 +173,7 @@ absl::optional<ncclDataType_t> DatatypeToNccl(PrimitiveType element_type) {
|
||||
switch (element_type) {
|
||||
case S8:
|
||||
return ncclInt8;
|
||||
case PRED:
|
||||
case U8:
|
||||
return ncclUint8;
|
||||
case S32:
|
||||
|
||||
@ -218,6 +218,88 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_half) {
|
||||
TestAllOps<Eigen::half>();
|
||||
}
|
||||
|
||||
XLA_TEST_F(CollectiveOpsTest, AllReduceAnd_Pred) {
|
||||
// Test with equal elements.
|
||||
TestTwoReplicasOneOperand<bool>(
|
||||
"and",
|
||||
/*input_value=*/LiteralUtil::CreateR1<bool>({true, false}),
|
||||
/*expected_value=*/LiteralUtil::CreateR1<bool>({true, false}));
|
||||
|
||||
// Test with {true, false}.
|
||||
const char* hlo_module = R"(
|
||||
HloModule test
|
||||
|
||||
apply_op {
|
||||
x = pred[] parameter(0)
|
||||
y = pred[] parameter(1)
|
||||
ROOT apply_op = pred[] and(x, y)
|
||||
}
|
||||
|
||||
ENTRY test_computation {
|
||||
id = u32[] replica-id()
|
||||
c = u32[] constant(0)
|
||||
p = pred[] compare(id, c), direction=EQ
|
||||
p2 = pred[1] bitcast(p)
|
||||
crs = pred[1] all-reduce(p2), replica_groups={}, to_apply=apply_op
|
||||
copy = pred[1] copy(crs)
|
||||
ROOT out = pred[1] bitcast(copy)
|
||||
}
|
||||
)";
|
||||
|
||||
auto config = GetModuleConfigForTest();
|
||||
config.set_replica_count(2);
|
||||
auto module = ParseAndReturnVerifiedModule(hlo_module, config).ValueOrDie();
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
|
||||
ExecuteReplicated(std::move(module), {},
|
||||
/*num_replicas=*/2,
|
||||
/*use_threads=*/true));
|
||||
for (int replica_idx = 0; replica_idx < 2; replica_idx++) {
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<bool>({false}),
|
||||
results[replica_idx]));
|
||||
}
|
||||
}
|
||||
|
||||
XLA_TEST_F(CollectiveOpsTest, AllReduceOr_Pred) {
|
||||
// Test with equal elements.
|
||||
TestTwoReplicasOneOperand<bool>(
|
||||
"or",
|
||||
/*input_value=*/LiteralUtil::CreateR1<bool>({true, false}),
|
||||
/*expected_value=*/LiteralUtil::CreateR1<bool>({true, false}));
|
||||
|
||||
// Test with {true, false}.
|
||||
const char* hlo_module = R"(
|
||||
HloModule test
|
||||
|
||||
apply_op {
|
||||
x = pred[] parameter(0)
|
||||
y = pred[] parameter(1)
|
||||
ROOT apply_op = pred[] or(x, y)
|
||||
}
|
||||
|
||||
ENTRY test_computation {
|
||||
id = u32[] replica-id()
|
||||
c = u32[] constant(0)
|
||||
p = pred[] compare(id, c), direction=EQ
|
||||
p2 = pred[1] bitcast(p)
|
||||
crs = pred[1] all-reduce(p2), replica_groups={}, to_apply=apply_op
|
||||
copy = pred[1] copy(crs)
|
||||
ROOT out = pred[1] bitcast(copy)
|
||||
}
|
||||
)";
|
||||
|
||||
auto config = GetModuleConfigForTest();
|
||||
config.set_replica_count(2);
|
||||
auto module = ParseAndReturnVerifiedModule(hlo_module, config).ValueOrDie();
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
|
||||
ExecuteReplicated(std::move(module), {},
|
||||
/*num_replicas=*/2,
|
||||
/*use_threads=*/true));
|
||||
for (int replica_idx = 0; replica_idx < 2; replica_idx++) {
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<bool>({true}),
|
||||
results[replica_idx]));
|
||||
}
|
||||
}
|
||||
|
||||
// Tries all-to-all operations across all 2^kNumDevices - 1 combinations of
|
||||
// devices in sequence.
|
||||
XLA_TEST_F(CollectiveOpsTest, AllReduce_AllCombinations) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user