[TF:XLA] Handle more patterns in ArCrsCombiner, and handle sequences of patterns.
Now, we optimize any sequence of the form:
AR [Bitcast|Transpose|Reshape|Convert|Multiply|Add|Subtract]* CRS

PiperOrigin-RevId: 225090998
This commit is contained in:
parent
9748092a5d
commit
184223ec16
@ -36,24 +36,40 @@ namespace {
|
|||||||
|
|
||||||
namespace m = match;
|
namespace m = match;
|
||||||
|
|
||||||
// If the argument instruction is a CRS in the sequence
|
// Returns true iff the argument instruction is an AllReduce, followed by a
|
||||||
// AR -> Convert -> Add -> CRS
|
// certain sequence of instructions and then a CRS. It must be possible to move
|
||||||
// then return the AR in the sequence.
|
// the AR past each instruction in the sequence.
|
||||||
// TODO(b/117554291): Rewrite this to recognize more general patterns,
|
bool MatchesArCrsPattern(HloInstruction* instruction) {
|
||||||
// not just the specific one of AR -> Add -> Convert -> CRS.
|
auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool {
|
||||||
absl::optional<HloInstruction*> MatchesArCrsPattern(
|
if (instruction->user_count() != 1) {
|
||||||
HloInstruction* instruction) {
|
return false;
|
||||||
HloInstruction *ar, *convert, *add, *crs;
|
|
||||||
if (Match(instruction,
|
|
||||||
m::CrossReplicaSum(
|
|
||||||
&crs, m::Add(&add, m::Op(),
|
|
||||||
m::Convert(&convert,
|
|
||||||
m::CrossReplicaSum(&ar, m::Op()))))) &&
|
|
||||||
ar->users().size() == 1 && ar->shape().element_type() == BF16 &&
|
|
||||||
convert->shape().element_type() == F32 && !crs->all_reduce_id()) {
|
|
||||||
return ar;
|
|
||||||
}
|
}
|
||||||
return absl::optional<HloInstruction*>();
|
auto opcode = instruction->opcode();
|
||||||
|
return opcode == HloOpcode::kBitcast || opcode == HloOpcode::kTranspose ||
|
||||||
|
opcode == HloOpcode::kReshape || opcode == HloOpcode::kConvert ||
|
||||||
|
opcode == HloOpcode::kAdd || opcode == HloOpcode::kSubtract ||
|
||||||
|
opcode == HloOpcode::kMultiply;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto computation_is_addition = [](HloComputation* c) {
|
||||||
|
return c->instruction_count() == 3 &&
|
||||||
|
Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter()));
|
||||||
|
};
|
||||||
|
|
||||||
|
if (!instruction->IsCrossModuleAllReduce() ||
|
||||||
|
!computation_is_addition(instruction->called_computations()[0]) ||
|
||||||
|
instruction->user_count() != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto next = instruction->users()[0];
|
||||||
|
while (!next->IsCrossReplicaAllReduce()) {
|
||||||
|
if (can_ar_move_past_instruction(next)) {
|
||||||
|
next = next->users()[0];
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return computation_is_addition(next->called_computations()[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -195,9 +211,8 @@ bool ArCrsCombiner::InstructionsComputeSameValue(
|
|||||||
void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
|
void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
|
||||||
for (HloComputation* computation : module->MakeNonfusionComputations()) {
|
for (HloComputation* computation : module->MakeNonfusionComputations()) {
|
||||||
for (HloInstruction* instruction : computation->instructions()) {
|
for (HloInstruction* instruction : computation->instructions()) {
|
||||||
auto ar = MatchesArCrsPattern(instruction);
|
if (MatchesArCrsPattern(instruction)) {
|
||||||
if (ar) {
|
all_reduce_map_[*(instruction->all_reduce_id())].push_back(instruction);
|
||||||
all_reduce_map_[*((*ar)->all_reduce_id())].push_back(*ar);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -205,21 +220,23 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
|
|||||||
|
|
||||||
void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
|
void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
|
||||||
for (auto it : all_reduce_map_) {
|
for (auto it : all_reduce_map_) {
|
||||||
|
auto all_reduce_id = it.first;
|
||||||
auto instruction_vec = it.second;
|
auto instruction_vec = it.second;
|
||||||
CHECK_EQ(instruction_vec.size(), num_spatial_partitions_);
|
CHECK_EQ(instruction_vec.size(), num_spatial_partitions_);
|
||||||
|
|
||||||
auto instr_0 = instruction_vec[0];
|
auto instr_0 = instruction_vec[0];
|
||||||
auto add_0 = instr_0->users()[0]->users()[0];
|
|
||||||
CHECK_EQ(HloOpcode::kAdd, add_0->opcode());
|
|
||||||
|
|
||||||
for (int i = 1; i < instruction_vec.size(); ++i) {
|
for (int i = 1; i < instruction_vec.size(); ++i) {
|
||||||
auto instr_i = instruction_vec[i];
|
auto instr_i = instruction_vec[i];
|
||||||
auto add_i = instr_i->users()[0]->users()[0];
|
auto next_0 = instr_0->users()[0];
|
||||||
CHECK_EQ(HloOpcode::kAdd, add_i->opcode());
|
auto next_i = instr_i->users()[0];
|
||||||
absl::flat_hash_map<int64, int64> visited_pairs;
|
absl::flat_hash_map<int64, int64> visited_pairs;
|
||||||
if (!InstructionsComputeSameValue(add_0, add_i, &visited_pairs)) {
|
do {
|
||||||
all_reduce_map_.erase(it.first);
|
if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) {
|
||||||
|
all_reduce_map_.erase(all_reduce_id);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
next_0 = next_0->users()[0];
|
||||||
|
next_i = next_i->users()[0];
|
||||||
|
} while (!next_0->IsCrossReplicaAllReduce());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -228,47 +245,51 @@ StatusOr<bool> ArCrsCombiner::RewriteGraph() {
|
|||||||
if (all_reduce_map_.empty()) {
|
if (all_reduce_map_.empty()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto computation_is_addition = [](HloComputation* c) {
|
|
||||||
return c->instruction_count() == 3 &&
|
|
||||||
Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter()));
|
|
||||||
};
|
|
||||||
|
|
||||||
for (auto it : all_reduce_map_) {
|
for (auto it : all_reduce_map_) {
|
||||||
auto instruction_vec = it.second;
|
auto instruction_vec = it.second;
|
||||||
for (auto all_reduce : instruction_vec) {
|
for (auto all_reduce : instruction_vec) {
|
||||||
auto parent_computation = all_reduce->parent();
|
auto parent_computation = all_reduce->parent();
|
||||||
auto convert = all_reduce->users()[0];
|
auto all_reduce_id = all_reduce->all_reduce_id();
|
||||||
auto add = convert->users()[0];
|
auto prev = all_reduce->mutable_operand(0);
|
||||||
auto crs = add->users()[0];
|
auto next = all_reduce->users()[0];
|
||||||
|
TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev));
|
||||||
if (!computation_is_addition(all_reduce->called_computations()[0]) ||
|
TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce));
|
||||||
!computation_is_addition(crs->called_computations()[0])) {
|
while (!next->IsCrossReplicaAllReduce()) {
|
||||||
continue;
|
switch (next->opcode()) {
|
||||||
}
|
case HloOpcode::kBitcast:
|
||||||
HloInstruction* other_summand = (add->operands()[0] == convert)
|
case HloOpcode::kTranspose:
|
||||||
? add->operands()[1]
|
case HloOpcode::kReshape:
|
||||||
: add->operands()[0];
|
case HloOpcode::kConvert:
|
||||||
// To move the AR past the addition, we need to divide other_summand by
|
case HloOpcode::kMultiply:
|
||||||
// the number of spatial partitions.
|
break;
|
||||||
CHECK_EQ(all_reduce->user_count(), 1);
|
case HloOpcode::kAdd:
|
||||||
TF_CHECK_OK(
|
case HloOpcode::kSubtract: {
|
||||||
all_reduce->ReplaceAllUsesWith(all_reduce->mutable_operand(0)));
|
auto other_operand = (next->operands()[0] == prev)
|
||||||
auto shape = other_summand->shape();
|
? next->operands()[1]
|
||||||
|
: next->operands()[0];
|
||||||
|
// To move the AR past the addition/subtraction, we need to divide
|
||||||
|
// other_operand by the number of spatial partitions.
|
||||||
|
auto shape = other_operand->shape();
|
||||||
Literal lit(shape);
|
Literal lit(shape);
|
||||||
lit.PopulateWithValue<float>(num_spatial_partitions_);
|
lit.PopulateWithValue<float>(num_spatial_partitions_);
|
||||||
auto divisor = parent_computation->AddInstruction(
|
auto divisor = parent_computation->AddInstruction(
|
||||||
HloInstruction::CreateConstant(lit.Clone()));
|
HloInstruction::CreateConstant(lit.Clone()));
|
||||||
auto division =
|
auto division =
|
||||||
parent_computation->AddInstruction(HloInstruction::CreateBinary(
|
parent_computation->AddInstruction(HloInstruction::CreateBinary(
|
||||||
shape, HloOpcode::kDivide, other_summand, divisor));
|
shape, HloOpcode::kDivide, other_operand, divisor));
|
||||||
TF_CHECK_OK(other_summand->ReplaceUseWith(add, division));
|
TF_CHECK_OK(other_operand->ReplaceUseWith(next, division));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
LOG(FATAL) << "Unexpected instruction: " << next->ToShortString();
|
||||||
|
}
|
||||||
|
prev = next;
|
||||||
|
next = next->users()[0];
|
||||||
|
}
|
||||||
// The AllReduce and the CRS are combined to an all-core AllReduce.
|
// The AllReduce and the CRS are combined to an all-core AllReduce.
|
||||||
crs->set_all_reduce_id(all_reduce->all_reduce_id());
|
next->set_all_reduce_id(all_reduce_id);
|
||||||
TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,9 +25,12 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
// Combine an AllReduce and a CrossReplicaSum when they are close to each other
|
// When the HLO graph contains an AllReduce, followed by some simple linear
|
||||||
// in the graph, to use an efficient CrossReplicaSum implementation that
|
// operations, followed by a CrossReplicaSum, we can combine the AR and the CRS,
|
||||||
// fully utilizes the interconnect bandwidth.
|
// to use an efficient CrossReplicaSum implementation that fully utilizes the
|
||||||
|
// interconnect bandwidth.
|
||||||
|
// Such sequences appear in spatially partitioned models.
|
||||||
|
// This pass must run right after spatial partitioning.
|
||||||
class ArCrsCombiner : public HloModulePass {
|
class ArCrsCombiner : public HloModulePass {
|
||||||
public:
|
public:
|
||||||
ArCrsCombiner(int num_spatial_partitions)
|
ArCrsCombiner(int num_spatial_partitions)
|
||||||
|
@ -326,11 +326,27 @@ ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
|
|||||||
EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2));
|
EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ArCrsCombinerTest, RewritePatternArConvertAddCrs) {
|
void CompareReplicaGroups(const std::vector<ReplicaGroup>& groups_before,
|
||||||
|
const std::vector<ReplicaGroup>& groups_after) {
|
||||||
|
ASSERT_EQ(groups_before.size(), groups_after.size());
|
||||||
|
for (int i = 0; i < groups_before.size(); ++i) {
|
||||||
|
// Somewhat verbose way to compare the replica_ids, because EqualsProto
|
||||||
|
// is not available in the open-source build.
|
||||||
|
auto group_before = groups_before[i];
|
||||||
|
std::vector<int64> ids_before(group_before.replica_ids().begin(),
|
||||||
|
group_before.replica_ids().end());
|
||||||
|
auto group_after = groups_after[i];
|
||||||
|
std::vector<int64> ids_after(group_after.replica_ids().begin(),
|
||||||
|
group_after.replica_ids().end());
|
||||||
|
EXPECT_EQ(ids_before, ids_after);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ArCrsCombinerTest, RewriteArConvertCrs) {
|
||||||
const char* module_str = R"(
|
const char* module_str = R"(
|
||||||
HloModule foobar
|
HloModule foobar
|
||||||
|
|
||||||
%binary_add (a: bf16[], b: bf16[]) -> bf16[] {
|
%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
|
||||||
%a = bf16[] parameter(0)
|
%a = bf16[] parameter(0)
|
||||||
%b = bf16[] parameter(1)
|
%b = bf16[] parameter(1)
|
||||||
ROOT %add = bf16[] add(%a, %b)
|
ROOT %add = bf16[] add(%a, %b)
|
||||||
@ -342,48 +358,257 @@ HloModule foobar
|
|||||||
ROOT %add = f32[] add(%x, %y)
|
ROOT %add = f32[] add(%x, %y)
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
|
ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
|
||||||
%p = f32[2,2] parameter(0)
|
%p = bf16[] parameter(0)
|
||||||
%constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}})
|
|
||||||
%constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
|
|
||||||
|
|
||||||
%cross-replica-sum.ar.1 = bf16[2,2]
|
%cross-replica-sum.ar.1 = bf16[]
|
||||||
|
cross-replica-sum(%p),
|
||||||
|
replica_groups={{0},{1}},
|
||||||
|
all_reduce_id=1,
|
||||||
|
to_apply=%sum.bf16,
|
||||||
|
sharding={maximal device=0}
|
||||||
|
%convert.1 = f32[]
|
||||||
|
convert(%cross-replica-sum.ar.1),
|
||||||
|
sharding={maximal device=0}
|
||||||
|
%cross-replica-sum.1 = f32[]
|
||||||
|
cross-replica-sum(%convert.1),
|
||||||
|
replica_groups={{0,1}},
|
||||||
|
to_apply=%sum.f32,
|
||||||
|
sharding={maximal device=0}
|
||||||
|
|
||||||
|
%cross-replica-sum.ar.2 = bf16[]
|
||||||
|
cross-replica-sum(%p),
|
||||||
|
replica_groups={{0},{1}},
|
||||||
|
all_reduce_id=1,
|
||||||
|
to_apply=%sum.bf16,
|
||||||
|
sharding={maximal device=1}
|
||||||
|
%convert.2 = f32[]
|
||||||
|
convert(%cross-replica-sum.ar.2),
|
||||||
|
sharding={maximal device=1}
|
||||||
|
%cross-replica-sum.2 = f32[]
|
||||||
|
cross-replica-sum(%convert.2),
|
||||||
|
replica_groups={{0,1}},
|
||||||
|
to_apply=%sum.f32,
|
||||||
|
sharding={maximal device=1}
|
||||||
|
|
||||||
|
ROOT %tuple = (f32[], f32[])
|
||||||
|
tuple(%cross-replica-sum.1, %cross-replica-sum.2),
|
||||||
|
sharding={{maximal device=0}, {maximal device=1}}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str));
|
||||||
|
auto crs_before =
|
||||||
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
|
ArCrsCombiner combiner(2);
|
||||||
|
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||||
|
EXPECT_TRUE(changed);
|
||||||
|
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||||
|
op::Tuple(op::CrossReplicaSum(op::Convert(op::Parameter())),
|
||||||
|
op::CrossReplicaSum(op::Convert(op::Parameter()))));
|
||||||
|
auto crs_after =
|
||||||
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
|
auto replica_groups_after = crs_after->replica_groups();
|
||||||
|
CompareReplicaGroups(replica_groups_before, replica_groups_after);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ArCrsCombinerTest, RewriteArBitcastCrs) {
|
||||||
|
const char* module_str = R"(
|
||||||
|
HloModule foobar
|
||||||
|
|
||||||
|
%sum.1 (a: f32[2,1], b: f32[2,1]) -> f32[2,1] {
|
||||||
|
%a = f32[2,1] parameter(0)
|
||||||
|
%b = f32[2,1] parameter(1)
|
||||||
|
ROOT %add = f32[2,1] add(%a, %b)
|
||||||
|
}
|
||||||
|
|
||||||
|
%sum.2 (x: f32[2], y: f32[2]) -> f32[2] {
|
||||||
|
%x = f32[2] parameter(0)
|
||||||
|
%y = f32[2] parameter(1)
|
||||||
|
ROOT %add = f32[2] add(%x, %y)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
|
||||||
|
%p = f32[2,1] parameter(0)
|
||||||
|
|
||||||
|
%cross-replica-sum.ar.1 = f32[2,1]
|
||||||
|
cross-replica-sum(%p),
|
||||||
|
replica_groups={{0},{1}},
|
||||||
|
all_reduce_id=1,
|
||||||
|
to_apply=%sum.1,
|
||||||
|
sharding={maximal device=0}
|
||||||
|
%bitcast.1 = f32[2]{0} bitcast(f32[2,1]{1,0} %cross-replica-sum.ar.1)
|
||||||
|
%cross-replica-sum.1 = f32[2]
|
||||||
|
cross-replica-sum(%bitcast.1),
|
||||||
|
replica_groups={{0,1}},
|
||||||
|
to_apply=%sum.2,
|
||||||
|
sharding={maximal device=0}
|
||||||
|
|
||||||
|
%cross-replica-sum.ar.2 = f32[2,1]
|
||||||
|
cross-replica-sum(%p),
|
||||||
|
replica_groups={{0},{1}},
|
||||||
|
all_reduce_id=1,
|
||||||
|
to_apply=%sum.1,
|
||||||
|
sharding={maximal device=1}
|
||||||
|
%bitcast.2 = f32[2]{0} bitcast(f32[2,1]{1,0} %cross-replica-sum.ar.2)
|
||||||
|
%cross-replica-sum.2 = f32[2]
|
||||||
|
cross-replica-sum(%bitcast.2),
|
||||||
|
replica_groups={{0,1}},
|
||||||
|
to_apply=%sum.2,
|
||||||
|
sharding={maximal device=1}
|
||||||
|
|
||||||
|
ROOT %tuple = (f32[], f32[])
|
||||||
|
tuple(%cross-replica-sum.1, %cross-replica-sum.2),
|
||||||
|
sharding={{maximal device=0}, {maximal device=1}}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str));
|
||||||
|
auto crs_before =
|
||||||
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
|
ArCrsCombiner combiner(2);
|
||||||
|
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||||
|
EXPECT_TRUE(changed);
|
||||||
|
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||||
|
op::Tuple(op::CrossReplicaSum(op::Bitcast(op::Parameter())),
|
||||||
|
op::CrossReplicaSum(op::Bitcast(op::Parameter()))));
|
||||||
|
auto crs_after =
|
||||||
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
|
auto replica_groups_after = crs_after->replica_groups();
|
||||||
|
CompareReplicaGroups(replica_groups_before, replica_groups_after);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrs) {
|
||||||
|
const char* module_str = R"(
|
||||||
|
HloModule foobar
|
||||||
|
|
||||||
|
%sum.f32 (x: f32[], y: f32[]) -> f32[] {
|
||||||
|
%x = f32[] parameter(0)
|
||||||
|
%y = f32[] parameter(1)
|
||||||
|
ROOT %add = f32[] add(%x, %y)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||||
|
%p = f32[] parameter(0)
|
||||||
|
%constant.f32 = f32[] constant(123)
|
||||||
|
|
||||||
|
%cross-replica-sum.ar.1 = f32[]
|
||||||
|
cross-replica-sum(%p),
|
||||||
|
replica_groups={{0},{1}},
|
||||||
|
all_reduce_id=1,
|
||||||
|
to_apply=%sum.f32,
|
||||||
|
sharding={maximal device=0}
|
||||||
|
%multiply.1 = f32[]
|
||||||
|
multiply(%cross-replica-sum.ar.1, %constant.f32),
|
||||||
|
sharding={maximal device=0}
|
||||||
|
%cross-replica-sum.1 = f32[]
|
||||||
|
cross-replica-sum(%multiply.1),
|
||||||
|
replica_groups={{0,1}},
|
||||||
|
to_apply=%sum.f32,
|
||||||
|
sharding={maximal device=0}
|
||||||
|
|
||||||
|
%cross-replica-sum.ar.2 = f32[]
|
||||||
|
cross-replica-sum(%p),
|
||||||
|
replica_groups={{0},{1}},
|
||||||
|
all_reduce_id=1,
|
||||||
|
to_apply=%sum.f32,
|
||||||
|
sharding={maximal device=1}
|
||||||
|
%multiply.2 = f32[]
|
||||||
|
multiply(%cross-replica-sum.ar.2, %constant.f32),
|
||||||
|
sharding={maximal device=1}
|
||||||
|
%cross-replica-sum.2 = f32[]
|
||||||
|
cross-replica-sum(%multiply.2),
|
||||||
|
replica_groups={{0,1}},
|
||||||
|
to_apply=%sum.f32,
|
||||||
|
sharding={maximal device=1}
|
||||||
|
|
||||||
|
ROOT %tuple = (f32[], f32[])
|
||||||
|
tuple(%cross-replica-sum.1, %cross-replica-sum.2),
|
||||||
|
sharding={{maximal device=0}, {maximal device=1}}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str));
|
||||||
|
auto crs_before =
|
||||||
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
|
ArCrsCombiner combiner(2);
|
||||||
|
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||||
|
EXPECT_TRUE(changed);
|
||||||
|
EXPECT_THAT(
|
||||||
|
module->entry_computation()->root_instruction(),
|
||||||
|
op::Tuple(
|
||||||
|
op::CrossReplicaSum(op::Multiply(op::Parameter(), op::Constant())),
|
||||||
|
op::CrossReplicaSum(op::Multiply(op::Parameter(), op::Constant()))));
|
||||||
|
auto crs_after =
|
||||||
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
|
auto replica_groups_after = crs_after->replica_groups();
|
||||||
|
CompareReplicaGroups(replica_groups_before, replica_groups_after);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrs) {
|
||||||
|
const char* module_str = R"(
|
||||||
|
HloModule foobar
|
||||||
|
|
||||||
|
%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
|
||||||
|
%a = bf16[] parameter(0)
|
||||||
|
%b = bf16[] parameter(1)
|
||||||
|
ROOT %add = bf16[] add(%a, %b)
|
||||||
|
}
|
||||||
|
|
||||||
|
%sum.f32 (x: f32[], y: f32[]) -> f32[] {
|
||||||
|
%x = f32[] parameter(0)
|
||||||
|
%y = f32[] parameter(1)
|
||||||
|
ROOT %add = f32[] add(%x, %y)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||||
|
%p = f32[] parameter(0)
|
||||||
|
%constant.bf16 = bf16[] constant(1)
|
||||||
|
%constant.f32 = f32[] constant(2)
|
||||||
|
|
||||||
|
%cross-replica-sum.ar.1 = bf16[]
|
||||||
cross-replica-sum(%constant.bf16),
|
cross-replica-sum(%constant.bf16),
|
||||||
replica_groups={{0},{1}},
|
replica_groups={{0},{1}},
|
||||||
all_reduce_id=1,
|
all_reduce_id=1,
|
||||||
to_apply=%binary_add,
|
to_apply=%sum.bf16,
|
||||||
sharding={maximal device=0}
|
sharding={maximal device=0}
|
||||||
%convert.1 = f32[2,2]
|
%convert.1 = f32[]
|
||||||
convert(%cross-replica-sum.ar.1),
|
convert(%cross-replica-sum.ar.1),
|
||||||
sharding={maximal device=0}
|
sharding={maximal device=0}
|
||||||
%add.1 = f32[2,2]
|
%add.1 = f32[]
|
||||||
add(%constant.f32, %convert.1),
|
add(%constant.f32, %convert.1),
|
||||||
sharding={maximal device=0}
|
sharding={maximal device=0}
|
||||||
%cross-replica-sum.1 = f32[2,2]
|
%cross-replica-sum.1 = f32[]
|
||||||
cross-replica-sum(%add.1),
|
cross-replica-sum(%add.1),
|
||||||
replica_groups={{0,1}},
|
replica_groups={{0,1}},
|
||||||
to_apply=%sum.f32,
|
to_apply=%sum.f32,
|
||||||
sharding={maximal device=0}
|
sharding={maximal device=0}
|
||||||
|
|
||||||
%cross-replica-sum.ar.2 = bf16[2,2]
|
%cross-replica-sum.ar.2 = bf16[]
|
||||||
cross-replica-sum(%constant.bf16),
|
cross-replica-sum(%constant.bf16),
|
||||||
replica_groups={{0},{1}},
|
replica_groups={{0},{1}},
|
||||||
all_reduce_id=1,
|
all_reduce_id=1,
|
||||||
to_apply=%binary_add,
|
to_apply=%sum.bf16,
|
||||||
sharding={maximal device=1}
|
sharding={maximal device=1}
|
||||||
%convert.2 = f32[2,2]
|
%convert.2 = f32[]
|
||||||
convert(%cross-replica-sum.ar.2),
|
convert(%cross-replica-sum.ar.2),
|
||||||
sharding={maximal device=1}
|
sharding={maximal device=1}
|
||||||
%add.2 = f32[2,2]
|
%add.2 = f32[]
|
||||||
add(%constant.f32, %convert.2),
|
add(%constant.f32, %convert.2),
|
||||||
sharding={maximal device=1}
|
sharding={maximal device=1}
|
||||||
%cross-replica-sum.2 = f32[2,2]
|
%cross-replica-sum.2 = f32[]
|
||||||
cross-replica-sum(%add.2),
|
cross-replica-sum(%add.2),
|
||||||
replica_groups={{0,1}},
|
replica_groups={{0,1}},
|
||||||
to_apply=%sum.f32,
|
to_apply=%sum.f32,
|
||||||
sharding={maximal device=1}
|
sharding={maximal device=1}
|
||||||
|
|
||||||
ROOT %tuple = (f32[2,2], f32[2,2])
|
ROOT %tuple = (f32[], f32[])
|
||||||
tuple(%cross-replica-sum.1, %cross-replica-sum.2),
|
tuple(%cross-replica-sum.1, %cross-replica-sum.2),
|
||||||
sharding={{maximal device=0}, {maximal device=1}}
|
sharding={{maximal device=0}, {maximal device=1}}
|
||||||
}
|
}
|
||||||
@ -407,25 +632,14 @@ ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
|
|||||||
auto crs_after =
|
auto crs_after =
|
||||||
module->entry_computation()->root_instruction()->operands()[0];
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
auto replica_groups_after = crs_after->replica_groups();
|
auto replica_groups_after = crs_after->replica_groups();
|
||||||
ASSERT_EQ(replica_groups_before.size(), replica_groups_after.size());
|
CompareReplicaGroups(replica_groups_before, replica_groups_after);
|
||||||
for (int i = 0; i < replica_groups_before.size(); ++i) {
|
|
||||||
// Somewhat verbose way to compare the replica_ids, because EqualsProto
|
|
||||||
// is not available in the open-source build.
|
|
||||||
auto group_before = replica_groups_before[i];
|
|
||||||
std::vector<int64> ids_before(group_before.replica_ids().begin(),
|
|
||||||
group_before.replica_ids().end());
|
|
||||||
auto group_after = replica_groups_after[i];
|
|
||||||
std::vector<int64> ids_after(group_after.replica_ids().begin(),
|
|
||||||
group_after.replica_ids().end());
|
|
||||||
EXPECT_EQ(ids_before, ids_after);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) {
|
TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) {
|
||||||
const char* module_str = R"(
|
const char* module_str = R"(
|
||||||
HloModule foobar
|
HloModule foobar
|
||||||
|
|
||||||
%binary_add (a: bf16[], b: bf16[]) -> bf16[] {
|
%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
|
||||||
%a = bf16[] parameter(0)
|
%a = bf16[] parameter(0)
|
||||||
%b = bf16[] parameter(1)
|
%b = bf16[] parameter(1)
|
||||||
ROOT %add = bf16[] add(%a, %b)
|
ROOT %add = bf16[] add(%a, %b)
|
||||||
@ -437,49 +651,49 @@ HloModule foobar
|
|||||||
ROOT %add = f32[] add(%x, %y)
|
ROOT %add = f32[] add(%x, %y)
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
|
ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||||
%p = f32[2,2] parameter(0)
|
%p = f32[] parameter(0)
|
||||||
%constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}})
|
%constant.bf16 = bf16[] constant(1)
|
||||||
%constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
|
%constant.f32.1 = f32[] constant(2)
|
||||||
%constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}})
|
%constant.f32.2 = f32[] constant(3)
|
||||||
|
|
||||||
%cross-replica-sum.ar.1 = bf16[2,2]
|
%cross-replica-sum.ar.1 = bf16[]
|
||||||
cross-replica-sum(%constant.bf16),
|
cross-replica-sum(%constant.bf16),
|
||||||
replica_groups={{0},{1}},
|
replica_groups={{0},{1}},
|
||||||
all_reduce_id=1,
|
all_reduce_id=1,
|
||||||
to_apply=%binary_add,
|
to_apply=%sum.bf16,
|
||||||
sharding={maximal device=0}
|
sharding={maximal device=0}
|
||||||
%convert.1 = f32[2,2]
|
%convert.1 = f32[]
|
||||||
convert(%cross-replica-sum.ar.1),
|
convert(%cross-replica-sum.ar.1),
|
||||||
sharding={maximal device=0}
|
sharding={maximal device=0}
|
||||||
%add.1 = f32[2,2]
|
%add.1 = f32[]
|
||||||
add(%constant.f32.1, %convert.1),
|
add(%constant.f32.1, %convert.1),
|
||||||
sharding={maximal device=0}
|
sharding={maximal device=0}
|
||||||
%cross-replica-sum.1 = f32[2,2]
|
%cross-replica-sum.1 = f32[]
|
||||||
cross-replica-sum(%add.1),
|
cross-replica-sum(%add.1),
|
||||||
replica_groups={{0,1}},
|
replica_groups={{0,1}},
|
||||||
to_apply=%sum.f32,
|
to_apply=%sum.f32,
|
||||||
sharding={maximal device=0}
|
sharding={maximal device=0}
|
||||||
|
|
||||||
%cross-replica-sum.ar.2 = bf16[2,2]
|
%cross-replica-sum.ar.2 = bf16[]
|
||||||
cross-replica-sum(%constant.bf16),
|
cross-replica-sum(%constant.bf16),
|
||||||
replica_groups={{0},{1}},
|
replica_groups={{0},{1}},
|
||||||
all_reduce_id=1,
|
all_reduce_id=1,
|
||||||
to_apply=%binary_add,
|
to_apply=%sum.bf16,
|
||||||
sharding={maximal device=1}
|
sharding={maximal device=1}
|
||||||
%convert.2 = f32[2,2]
|
%convert.2 = f32[]
|
||||||
convert(%cross-replica-sum.ar.2),
|
convert(%cross-replica-sum.ar.2),
|
||||||
sharding={maximal device=1}
|
sharding={maximal device=1}
|
||||||
%add.2 = f32[2,2]
|
%add.2 = f32[]
|
||||||
add(%constant.f32.2, %convert.2),
|
add(%constant.f32.2, %convert.2),
|
||||||
sharding={maximal device=1}
|
sharding={maximal device=1}
|
||||||
%cross-replica-sum.2 = f32[2,2]
|
%cross-replica-sum.2 = f32[]
|
||||||
cross-replica-sum(%add.2),
|
cross-replica-sum(%add.2),
|
||||||
replica_groups={{0,1}},
|
replica_groups={{0,1}},
|
||||||
to_apply=%sum.f32,
|
to_apply=%sum.f32,
|
||||||
sharding={maximal device=1}
|
sharding={maximal device=1}
|
||||||
|
|
||||||
ROOT %tuple = (f32[2,2], f32[2,2])
|
ROOT %tuple = (f32[], f32[])
|
||||||
tuple(%cross-replica-sum.1, %cross-replica-sum.2),
|
tuple(%cross-replica-sum.1, %cross-replica-sum.2),
|
||||||
sharding={{maximal device=0}, {maximal device=1}}
|
sharding={{maximal device=0}, {maximal device=1}}
|
||||||
}
|
}
|
||||||
|
@ -2060,6 +2060,10 @@ bool HloInstruction::IsCrossModuleAllReduce() const {
|
|||||||
return opcode() == HloOpcode::kCrossReplicaSum && all_reduce_id();
|
return opcode() == HloOpcode::kCrossReplicaSum && all_reduce_id();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool HloInstruction::IsCrossReplicaAllReduce() const {
|
||||||
|
return opcode() == HloOpcode::kCrossReplicaSum && !all_reduce_id();
|
||||||
|
}
|
||||||
|
|
||||||
string HloInstruction::ToStringWithCanonicalNameMap(
|
string HloInstruction::ToStringWithCanonicalNameMap(
|
||||||
const HloPrintOptions& options,
|
const HloPrintOptions& options,
|
||||||
CanonicalNameMap* canonical_name_map) const {
|
CanonicalNameMap* canonical_name_map) const {
|
||||||
|
@ -1174,9 +1174,12 @@ class HloInstruction {
|
|||||||
// Returns true if this instruction is elementwise on all its operands.
|
// Returns true if this instruction is elementwise on all its operands.
|
||||||
bool IsElementwise() const;
|
bool IsElementwise() const;
|
||||||
|
|
||||||
// Returns true if this is an cross module all-reduce instrucion.
|
// Returns true if this is a cross module all-reduce instruction.
|
||||||
bool IsCrossModuleAllReduce() const;
|
bool IsCrossModuleAllReduce() const;
|
||||||
|
|
||||||
|
// Returns true if this is a cross-replica all-reduce instruction.
|
||||||
|
bool IsCrossReplicaAllReduce() const;
|
||||||
|
|
||||||
// Returns true if this elementwise instruction implicitly broadcasts operand
|
// Returns true if this elementwise instruction implicitly broadcasts operand
|
||||||
// `operand_idx`.
|
// `operand_idx`.
|
||||||
//
|
//
|
||||||
|
Loading…
Reference in New Issue
Block a user