[TF:XLA] Handle more patterns in ArCrsCombiner, and handle sequences of patterns.
Now, we optimize any sequence of the form: AR [Bitcast|Transpose|Reshape|Convert|Multiply|Add|Subtract]* CRS PiperOrigin-RevId: 225090998
This commit is contained in:
parent
9748092a5d
commit
184223ec16
@ -36,24 +36,40 @@ namespace {
|
||||
|
||||
namespace m = match;
|
||||
|
||||
// If the argument instruction is a CRS in the sequence
|
||||
// AR -> Convert -> Add -> CRS
|
||||
// then return the AR in the sequence.
|
||||
// TODO(b/117554291): Rewrite this to recognize more general patterns,
|
||||
// not just the specific one of AR -> Add -> Convert -> CRS.
|
||||
absl::optional<HloInstruction*> MatchesArCrsPattern(
|
||||
HloInstruction* instruction) {
|
||||
HloInstruction *ar, *convert, *add, *crs;
|
||||
if (Match(instruction,
|
||||
m::CrossReplicaSum(
|
||||
&crs, m::Add(&add, m::Op(),
|
||||
m::Convert(&convert,
|
||||
m::CrossReplicaSum(&ar, m::Op()))))) &&
|
||||
ar->users().size() == 1 && ar->shape().element_type() == BF16 &&
|
||||
convert->shape().element_type() == F32 && !crs->all_reduce_id()) {
|
||||
return ar;
|
||||
// Returns true iff the argument instruction is an AllReduce, followed by a
|
||||
// certain sequence of instructions and then a CRS. It must be possible to move
|
||||
// the AR past each instruction in the sequence.
|
||||
bool MatchesArCrsPattern(HloInstruction* instruction) {
|
||||
auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool {
|
||||
if (instruction->user_count() != 1) {
|
||||
return false;
|
||||
}
|
||||
auto opcode = instruction->opcode();
|
||||
return opcode == HloOpcode::kBitcast || opcode == HloOpcode::kTranspose ||
|
||||
opcode == HloOpcode::kReshape || opcode == HloOpcode::kConvert ||
|
||||
opcode == HloOpcode::kAdd || opcode == HloOpcode::kSubtract ||
|
||||
opcode == HloOpcode::kMultiply;
|
||||
};
|
||||
|
||||
auto computation_is_addition = [](HloComputation* c) {
|
||||
return c->instruction_count() == 3 &&
|
||||
Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter()));
|
||||
};
|
||||
|
||||
if (!instruction->IsCrossModuleAllReduce() ||
|
||||
!computation_is_addition(instruction->called_computations()[0]) ||
|
||||
instruction->user_count() != 1) {
|
||||
return false;
|
||||
}
|
||||
return absl::optional<HloInstruction*>();
|
||||
auto next = instruction->users()[0];
|
||||
while (!next->IsCrossReplicaAllReduce()) {
|
||||
if (can_ar_move_past_instruction(next)) {
|
||||
next = next->users()[0];
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return computation_is_addition(next->called_computations()[0]);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -195,9 +211,8 @@ bool ArCrsCombiner::InstructionsComputeSameValue(
|
||||
void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
|
||||
for (HloComputation* computation : module->MakeNonfusionComputations()) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
auto ar = MatchesArCrsPattern(instruction);
|
||||
if (ar) {
|
||||
all_reduce_map_[*((*ar)->all_reduce_id())].push_back(*ar);
|
||||
if (MatchesArCrsPattern(instruction)) {
|
||||
all_reduce_map_[*(instruction->all_reduce_id())].push_back(instruction);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -205,21 +220,23 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
|
||||
|
||||
void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
|
||||
for (auto it : all_reduce_map_) {
|
||||
auto all_reduce_id = it.first;
|
||||
auto instruction_vec = it.second;
|
||||
CHECK_EQ(instruction_vec.size(), num_spatial_partitions_);
|
||||
|
||||
auto instr_0 = instruction_vec[0];
|
||||
auto add_0 = instr_0->users()[0]->users()[0];
|
||||
CHECK_EQ(HloOpcode::kAdd, add_0->opcode());
|
||||
|
||||
for (int i = 1; i < instruction_vec.size(); ++i) {
|
||||
auto instr_i = instruction_vec[i];
|
||||
auto add_i = instr_i->users()[0]->users()[0];
|
||||
CHECK_EQ(HloOpcode::kAdd, add_i->opcode());
|
||||
auto next_0 = instr_0->users()[0];
|
||||
auto next_i = instr_i->users()[0];
|
||||
absl::flat_hash_map<int64, int64> visited_pairs;
|
||||
if (!InstructionsComputeSameValue(add_0, add_i, &visited_pairs)) {
|
||||
all_reduce_map_.erase(it.first);
|
||||
}
|
||||
do {
|
||||
if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) {
|
||||
all_reduce_map_.erase(all_reduce_id);
|
||||
break;
|
||||
}
|
||||
next_0 = next_0->users()[0];
|
||||
next_i = next_i->users()[0];
|
||||
} while (!next_0->IsCrossReplicaAllReduce());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -228,47 +245,51 @@ StatusOr<bool> ArCrsCombiner::RewriteGraph() {
|
||||
if (all_reduce_map_.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto computation_is_addition = [](HloComputation* c) {
|
||||
return c->instruction_count() == 3 &&
|
||||
Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter()));
|
||||
};
|
||||
|
||||
for (auto it : all_reduce_map_) {
|
||||
auto instruction_vec = it.second;
|
||||
for (auto all_reduce : instruction_vec) {
|
||||
auto parent_computation = all_reduce->parent();
|
||||
auto convert = all_reduce->users()[0];
|
||||
auto add = convert->users()[0];
|
||||
auto crs = add->users()[0];
|
||||
|
||||
if (!computation_is_addition(all_reduce->called_computations()[0]) ||
|
||||
!computation_is_addition(crs->called_computations()[0])) {
|
||||
continue;
|
||||
}
|
||||
HloInstruction* other_summand = (add->operands()[0] == convert)
|
||||
? add->operands()[1]
|
||||
: add->operands()[0];
|
||||
// To move the AR past the addition, we need to divide other_summand by
|
||||
// the number of spatial partitions.
|
||||
CHECK_EQ(all_reduce->user_count(), 1);
|
||||
TF_CHECK_OK(
|
||||
all_reduce->ReplaceAllUsesWith(all_reduce->mutable_operand(0)));
|
||||
auto shape = other_summand->shape();
|
||||
Literal lit(shape);
|
||||
lit.PopulateWithValue<float>(num_spatial_partitions_);
|
||||
auto divisor = parent_computation->AddInstruction(
|
||||
HloInstruction::CreateConstant(lit.Clone()));
|
||||
auto division =
|
||||
parent_computation->AddInstruction(HloInstruction::CreateBinary(
|
||||
shape, HloOpcode::kDivide, other_summand, divisor));
|
||||
TF_CHECK_OK(other_summand->ReplaceUseWith(add, division));
|
||||
// The AllReduce and the CRS are combined to an all-core AllReduce.
|
||||
crs->set_all_reduce_id(all_reduce->all_reduce_id());
|
||||
auto all_reduce_id = all_reduce->all_reduce_id();
|
||||
auto prev = all_reduce->mutable_operand(0);
|
||||
auto next = all_reduce->users()[0];
|
||||
TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev));
|
||||
TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce));
|
||||
while (!next->IsCrossReplicaAllReduce()) {
|
||||
switch (next->opcode()) {
|
||||
case HloOpcode::kBitcast:
|
||||
case HloOpcode::kTranspose:
|
||||
case HloOpcode::kReshape:
|
||||
case HloOpcode::kConvert:
|
||||
case HloOpcode::kMultiply:
|
||||
break;
|
||||
case HloOpcode::kAdd:
|
||||
case HloOpcode::kSubtract: {
|
||||
auto other_operand = (next->operands()[0] == prev)
|
||||
? next->operands()[1]
|
||||
: next->operands()[0];
|
||||
// To move the AR past the addition/subtraction, we need to divide
|
||||
// other_operand by the number of spatial partitions.
|
||||
auto shape = other_operand->shape();
|
||||
Literal lit(shape);
|
||||
lit.PopulateWithValue<float>(num_spatial_partitions_);
|
||||
auto divisor = parent_computation->AddInstruction(
|
||||
HloInstruction::CreateConstant(lit.Clone()));
|
||||
auto division =
|
||||
parent_computation->AddInstruction(HloInstruction::CreateBinary(
|
||||
shape, HloOpcode::kDivide, other_operand, divisor));
|
||||
TF_CHECK_OK(other_operand->ReplaceUseWith(next, division));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "Unexpected instruction: " << next->ToShortString();
|
||||
}
|
||||
prev = next;
|
||||
next = next->users()[0];
|
||||
}
|
||||
// The AllReduce and the CRS are combined to an all-core AllReduce.
|
||||
next->set_all_reduce_id(all_reduce_id);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -25,9 +25,12 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Combine an AllReduce and a CrossReplicaSum when they are close to each other
|
||||
// in the graph, to use an efficient CrossReplicaSum implementation that
|
||||
// fully utilizes the interconnect bandwidth.
|
||||
// When the HLO graph contains an AllReduce, followed by some simple linear
|
||||
// operations, followed by a CrossReplicaSum, we can combine the AR and the CRS,
|
||||
// to use an efficient CrossReplicaSum implementation that fully utilizes the
|
||||
// interconnect bandwidth.
|
||||
// Such sequences appear in spatially partitioned models.
|
||||
// This pass must run right after spatial partitioning.
|
||||
class ArCrsCombiner : public HloModulePass {
|
||||
public:
|
||||
ArCrsCombiner(int num_spatial_partitions)
|
||||
|
@ -326,11 +326,27 @@ ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
|
||||
EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2));
|
||||
}
|
||||
|
||||
TEST_F(ArCrsCombinerTest, RewritePatternArConvertAddCrs) {
|
||||
void CompareReplicaGroups(const std::vector<ReplicaGroup>& groups_before,
|
||||
const std::vector<ReplicaGroup>& groups_after) {
|
||||
ASSERT_EQ(groups_before.size(), groups_after.size());
|
||||
for (int i = 0; i < groups_before.size(); ++i) {
|
||||
// Somewhat verbose way to compare the replica_ids, because EqualsProto
|
||||
// is not available in the open-source build.
|
||||
auto group_before = groups_before[i];
|
||||
std::vector<int64> ids_before(group_before.replica_ids().begin(),
|
||||
group_before.replica_ids().end());
|
||||
auto group_after = groups_after[i];
|
||||
std::vector<int64> ids_after(group_after.replica_ids().begin(),
|
||||
group_after.replica_ids().end());
|
||||
EXPECT_EQ(ids_before, ids_after);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ArCrsCombinerTest, RewriteArConvertCrs) {
|
||||
const char* module_str = R"(
|
||||
HloModule foobar
|
||||
|
||||
%binary_add (a: bf16[], b: bf16[]) -> bf16[] {
|
||||
%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
|
||||
%a = bf16[] parameter(0)
|
||||
%b = bf16[] parameter(1)
|
||||
ROOT %add = bf16[] add(%a, %b)
|
||||
@ -342,48 +358,257 @@ HloModule foobar
|
||||
ROOT %add = f32[] add(%x, %y)
|
||||
}
|
||||
|
||||
ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
|
||||
%p = f32[2,2] parameter(0)
|
||||
%constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}})
|
||||
%constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
|
||||
ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
|
||||
%p = bf16[] parameter(0)
|
||||
|
||||
%cross-replica-sum.ar.1 = bf16[2,2]
|
||||
%cross-replica-sum.ar.1 = bf16[]
|
||||
cross-replica-sum(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
to_apply=%sum.bf16,
|
||||
sharding={maximal device=0}
|
||||
%convert.1 = f32[]
|
||||
convert(%cross-replica-sum.ar.1),
|
||||
sharding={maximal device=0}
|
||||
%cross-replica-sum.1 = f32[]
|
||||
cross-replica-sum(%convert.1),
|
||||
replica_groups={{0,1}},
|
||||
to_apply=%sum.f32,
|
||||
sharding={maximal device=0}
|
||||
|
||||
%cross-replica-sum.ar.2 = bf16[]
|
||||
cross-replica-sum(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
to_apply=%sum.bf16,
|
||||
sharding={maximal device=1}
|
||||
%convert.2 = f32[]
|
||||
convert(%cross-replica-sum.ar.2),
|
||||
sharding={maximal device=1}
|
||||
%cross-replica-sum.2 = f32[]
|
||||
cross-replica-sum(%convert.2),
|
||||
replica_groups={{0,1}},
|
||||
to_apply=%sum.f32,
|
||||
sharding={maximal device=1}
|
||||
|
||||
ROOT %tuple = (f32[], f32[])
|
||||
tuple(%cross-replica-sum.1, %cross-replica-sum.2),
|
||||
sharding={{maximal device=0}, {maximal device=1}}
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
ArCrsCombiner combiner(2);
|
||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||
EXPECT_TRUE(changed);
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::Tuple(op::CrossReplicaSum(op::Convert(op::Parameter())),
|
||||
op::CrossReplicaSum(op::Convert(op::Parameter()))));
|
||||
auto crs_after =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_after = crs_after->replica_groups();
|
||||
CompareReplicaGroups(replica_groups_before, replica_groups_after);
|
||||
}
|
||||
|
||||
TEST_F(ArCrsCombinerTest, RewriteArBitcastCrs) {
|
||||
const char* module_str = R"(
|
||||
HloModule foobar
|
||||
|
||||
%sum.1 (a: f32[2,1], b: f32[2,1]) -> f32[2,1] {
|
||||
%a = f32[2,1] parameter(0)
|
||||
%b = f32[2,1] parameter(1)
|
||||
ROOT %add = f32[2,1] add(%a, %b)
|
||||
}
|
||||
|
||||
%sum.2 (x: f32[2], y: f32[2]) -> f32[2] {
|
||||
%x = f32[2] parameter(0)
|
||||
%y = f32[2] parameter(1)
|
||||
ROOT %add = f32[2] add(%x, %y)
|
||||
}
|
||||
|
||||
ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
|
||||
%p = f32[2,1] parameter(0)
|
||||
|
||||
%cross-replica-sum.ar.1 = f32[2,1]
|
||||
cross-replica-sum(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
to_apply=%sum.1,
|
||||
sharding={maximal device=0}
|
||||
%bitcast.1 = f32[2]{0} bitcast(f32[2,1]{1,0} %cross-replica-sum.ar.1)
|
||||
%cross-replica-sum.1 = f32[2]
|
||||
cross-replica-sum(%bitcast.1),
|
||||
replica_groups={{0,1}},
|
||||
to_apply=%sum.2,
|
||||
sharding={maximal device=0}
|
||||
|
||||
%cross-replica-sum.ar.2 = f32[2,1]
|
||||
cross-replica-sum(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
to_apply=%sum.1,
|
||||
sharding={maximal device=1}
|
||||
%bitcast.2 = f32[2]{0} bitcast(f32[2,1]{1,0} %cross-replica-sum.ar.2)
|
||||
%cross-replica-sum.2 = f32[2]
|
||||
cross-replica-sum(%bitcast.2),
|
||||
replica_groups={{0,1}},
|
||||
to_apply=%sum.2,
|
||||
sharding={maximal device=1}
|
||||
|
||||
ROOT %tuple = (f32[], f32[])
|
||||
tuple(%cross-replica-sum.1, %cross-replica-sum.2),
|
||||
sharding={{maximal device=0}, {maximal device=1}}
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
ArCrsCombiner combiner(2);
|
||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||
EXPECT_TRUE(changed);
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::Tuple(op::CrossReplicaSum(op::Bitcast(op::Parameter())),
|
||||
op::CrossReplicaSum(op::Bitcast(op::Parameter()))));
|
||||
auto crs_after =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_after = crs_after->replica_groups();
|
||||
CompareReplicaGroups(replica_groups_before, replica_groups_after);
|
||||
}
|
||||
|
||||
TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrs) {
|
||||
const char* module_str = R"(
|
||||
HloModule foobar
|
||||
|
||||
%sum.f32 (x: f32[], y: f32[]) -> f32[] {
|
||||
%x = f32[] parameter(0)
|
||||
%y = f32[] parameter(1)
|
||||
ROOT %add = f32[] add(%x, %y)
|
||||
}
|
||||
|
||||
ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%p = f32[] parameter(0)
|
||||
%constant.f32 = f32[] constant(123)
|
||||
|
||||
%cross-replica-sum.ar.1 = f32[]
|
||||
cross-replica-sum(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
to_apply=%sum.f32,
|
||||
sharding={maximal device=0}
|
||||
%multiply.1 = f32[]
|
||||
multiply(%cross-replica-sum.ar.1, %constant.f32),
|
||||
sharding={maximal device=0}
|
||||
%cross-replica-sum.1 = f32[]
|
||||
cross-replica-sum(%multiply.1),
|
||||
replica_groups={{0,1}},
|
||||
to_apply=%sum.f32,
|
||||
sharding={maximal device=0}
|
||||
|
||||
%cross-replica-sum.ar.2 = f32[]
|
||||
cross-replica-sum(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
to_apply=%sum.f32,
|
||||
sharding={maximal device=1}
|
||||
%multiply.2 = f32[]
|
||||
multiply(%cross-replica-sum.ar.2, %constant.f32),
|
||||
sharding={maximal device=1}
|
||||
%cross-replica-sum.2 = f32[]
|
||||
cross-replica-sum(%multiply.2),
|
||||
replica_groups={{0,1}},
|
||||
to_apply=%sum.f32,
|
||||
sharding={maximal device=1}
|
||||
|
||||
ROOT %tuple = (f32[], f32[])
|
||||
tuple(%cross-replica-sum.1, %cross-replica-sum.2),
|
||||
sharding={{maximal device=0}, {maximal device=1}}
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
ArCrsCombiner combiner(2);
|
||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||
EXPECT_TRUE(changed);
|
||||
EXPECT_THAT(
|
||||
module->entry_computation()->root_instruction(),
|
||||
op::Tuple(
|
||||
op::CrossReplicaSum(op::Multiply(op::Parameter(), op::Constant())),
|
||||
op::CrossReplicaSum(op::Multiply(op::Parameter(), op::Constant()))));
|
||||
auto crs_after =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_after = crs_after->replica_groups();
|
||||
CompareReplicaGroups(replica_groups_before, replica_groups_after);
|
||||
}
|
||||
|
||||
TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrs) {
|
||||
const char* module_str = R"(
|
||||
HloModule foobar
|
||||
|
||||
%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
|
||||
%a = bf16[] parameter(0)
|
||||
%b = bf16[] parameter(1)
|
||||
ROOT %add = bf16[] add(%a, %b)
|
||||
}
|
||||
|
||||
%sum.f32 (x: f32[], y: f32[]) -> f32[] {
|
||||
%x = f32[] parameter(0)
|
||||
%y = f32[] parameter(1)
|
||||
ROOT %add = f32[] add(%x, %y)
|
||||
}
|
||||
|
||||
ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%p = f32[] parameter(0)
|
||||
%constant.bf16 = bf16[] constant(1)
|
||||
%constant.f32 = f32[] constant(2)
|
||||
|
||||
%cross-replica-sum.ar.1 = bf16[]
|
||||
cross-replica-sum(%constant.bf16),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
to_apply=%binary_add,
|
||||
to_apply=%sum.bf16,
|
||||
sharding={maximal device=0}
|
||||
%convert.1 = f32[2,2]
|
||||
%convert.1 = f32[]
|
||||
convert(%cross-replica-sum.ar.1),
|
||||
sharding={maximal device=0}
|
||||
%add.1 = f32[2,2]
|
||||
%add.1 = f32[]
|
||||
add(%constant.f32, %convert.1),
|
||||
sharding={maximal device=0}
|
||||
%cross-replica-sum.1 = f32[2,2]
|
||||
%cross-replica-sum.1 = f32[]
|
||||
cross-replica-sum(%add.1),
|
||||
replica_groups={{0,1}},
|
||||
to_apply=%sum.f32,
|
||||
sharding={maximal device=0}
|
||||
|
||||
%cross-replica-sum.ar.2 = bf16[2,2]
|
||||
%cross-replica-sum.ar.2 = bf16[]
|
||||
cross-replica-sum(%constant.bf16),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
to_apply=%binary_add,
|
||||
to_apply=%sum.bf16,
|
||||
sharding={maximal device=1}
|
||||
%convert.2 = f32[2,2]
|
||||
%convert.2 = f32[]
|
||||
convert(%cross-replica-sum.ar.2),
|
||||
sharding={maximal device=1}
|
||||
%add.2 = f32[2,2]
|
||||
%add.2 = f32[]
|
||||
add(%constant.f32, %convert.2),
|
||||
sharding={maximal device=1}
|
||||
%cross-replica-sum.2 = f32[2,2]
|
||||
%cross-replica-sum.2 = f32[]
|
||||
cross-replica-sum(%add.2),
|
||||
replica_groups={{0,1}},
|
||||
to_apply=%sum.f32,
|
||||
sharding={maximal device=1}
|
||||
|
||||
ROOT %tuple = (f32[2,2], f32[2,2])
|
||||
ROOT %tuple = (f32[], f32[])
|
||||
tuple(%cross-replica-sum.1, %cross-replica-sum.2),
|
||||
sharding={{maximal device=0}, {maximal device=1}}
|
||||
}
|
||||
@ -407,25 +632,14 @@ ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
|
||||
auto crs_after =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_after = crs_after->replica_groups();
|
||||
ASSERT_EQ(replica_groups_before.size(), replica_groups_after.size());
|
||||
for (int i = 0; i < replica_groups_before.size(); ++i) {
|
||||
// Somewhat verbose way to compare the replica_ids, because EqualsProto
|
||||
// is not available in the open-source build.
|
||||
auto group_before = replica_groups_before[i];
|
||||
std::vector<int64> ids_before(group_before.replica_ids().begin(),
|
||||
group_before.replica_ids().end());
|
||||
auto group_after = replica_groups_after[i];
|
||||
std::vector<int64> ids_after(group_after.replica_ids().begin(),
|
||||
group_after.replica_ids().end());
|
||||
EXPECT_EQ(ids_before, ids_after);
|
||||
}
|
||||
CompareReplicaGroups(replica_groups_before, replica_groups_after);
|
||||
}
|
||||
|
||||
TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) {
|
||||
const char* module_str = R"(
|
||||
HloModule foobar
|
||||
|
||||
%binary_add (a: bf16[], b: bf16[]) -> bf16[] {
|
||||
%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
|
||||
%a = bf16[] parameter(0)
|
||||
%b = bf16[] parameter(1)
|
||||
ROOT %add = bf16[] add(%a, %b)
|
||||
@ -437,49 +651,49 @@ HloModule foobar
|
||||
ROOT %add = f32[] add(%x, %y)
|
||||
}
|
||||
|
||||
ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
|
||||
%p = f32[2,2] parameter(0)
|
||||
%constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}})
|
||||
%constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
|
||||
%constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}})
|
||||
ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%p = f32[] parameter(0)
|
||||
%constant.bf16 = bf16[] constant(1)
|
||||
%constant.f32.1 = f32[] constant(2)
|
||||
%constant.f32.2 = f32[] constant(3)
|
||||
|
||||
%cross-replica-sum.ar.1 = bf16[2,2]
|
||||
%cross-replica-sum.ar.1 = bf16[]
|
||||
cross-replica-sum(%constant.bf16),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
to_apply=%binary_add,
|
||||
to_apply=%sum.bf16,
|
||||
sharding={maximal device=0}
|
||||
%convert.1 = f32[2,2]
|
||||
%convert.1 = f32[]
|
||||
convert(%cross-replica-sum.ar.1),
|
||||
sharding={maximal device=0}
|
||||
%add.1 = f32[2,2]
|
||||
%add.1 = f32[]
|
||||
add(%constant.f32.1, %convert.1),
|
||||
sharding={maximal device=0}
|
||||
%cross-replica-sum.1 = f32[2,2]
|
||||
%cross-replica-sum.1 = f32[]
|
||||
cross-replica-sum(%add.1),
|
||||
replica_groups={{0,1}},
|
||||
to_apply=%sum.f32,
|
||||
sharding={maximal device=0}
|
||||
|
||||
%cross-replica-sum.ar.2 = bf16[2,2]
|
||||
%cross-replica-sum.ar.2 = bf16[]
|
||||
cross-replica-sum(%constant.bf16),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
to_apply=%binary_add,
|
||||
to_apply=%sum.bf16,
|
||||
sharding={maximal device=1}
|
||||
%convert.2 = f32[2,2]
|
||||
%convert.2 = f32[]
|
||||
convert(%cross-replica-sum.ar.2),
|
||||
sharding={maximal device=1}
|
||||
%add.2 = f32[2,2]
|
||||
%add.2 = f32[]
|
||||
add(%constant.f32.2, %convert.2),
|
||||
sharding={maximal device=1}
|
||||
%cross-replica-sum.2 = f32[2,2]
|
||||
%cross-replica-sum.2 = f32[]
|
||||
cross-replica-sum(%add.2),
|
||||
replica_groups={{0,1}},
|
||||
to_apply=%sum.f32,
|
||||
sharding={maximal device=1}
|
||||
|
||||
ROOT %tuple = (f32[2,2], f32[2,2])
|
||||
ROOT %tuple = (f32[], f32[])
|
||||
tuple(%cross-replica-sum.1, %cross-replica-sum.2),
|
||||
sharding={{maximal device=0}, {maximal device=1}}
|
||||
}
|
||||
|
@ -2060,6 +2060,10 @@ bool HloInstruction::IsCrossModuleAllReduce() const {
|
||||
return opcode() == HloOpcode::kCrossReplicaSum && all_reduce_id();
|
||||
}
|
||||
|
||||
bool HloInstruction::IsCrossReplicaAllReduce() const {
|
||||
return opcode() == HloOpcode::kCrossReplicaSum && !all_reduce_id();
|
||||
}
|
||||
|
||||
string HloInstruction::ToStringWithCanonicalNameMap(
|
||||
const HloPrintOptions& options,
|
||||
CanonicalNameMap* canonical_name_map) const {
|
||||
|
@ -1174,9 +1174,12 @@ class HloInstruction {
|
||||
// Returns true if this instruction is elementwise on all its operands.
|
||||
bool IsElementwise() const;
|
||||
|
||||
// Returns true if this is an cross module all-reduce instrucion.
|
||||
// Returns true if this is a cross module all-reduce instruction.
|
||||
bool IsCrossModuleAllReduce() const;
|
||||
|
||||
// Returns true if this is a cross-replica all-reduce instruction.
|
||||
bool IsCrossReplicaAllReduce() const;
|
||||
|
||||
// Returns true if this elementwise instruction implicitly broadcasts operand
|
||||
// `operand_idx`.
|
||||
//
|
||||
|
Loading…
Reference in New Issue
Block a user