[XLA] Make the canonicalize all-gather pass add a new/exclusive all-gather dimension.
Make the canonicalizer always add a completely new all gather dimension to the shape of the all-gather operand to help layout assignment. PiperOrigin-RevId: 360612458 Change-Id: I8299301219f704f14661f6fb2619ef604bb4d7c4
This commit is contained in:
parent
8e5ab11c65
commit
71fd1ab1d5
@ -71,9 +71,19 @@ StatusOr<bool> CanonicalizeAllGatherForCSE::RunOnComputation(
|
||||
// adding the dimension the all-gather is operating on then perform the
|
||||
// canonicalization.
|
||||
if (real_data != ag->operand(0)) {
|
||||
std::vector<int64> new_dimensions(real_data->shape().dimensions().begin(),
|
||||
real_data->shape().dimensions().end());
|
||||
new_dimensions[0] *= all_gather_participants;
|
||||
std::vector<int64> new_dimensions;
|
||||
new_dimensions.reserve(real_data->shape().dimensions_size() + 1);
|
||||
new_dimensions.push_back(1);
|
||||
new_dimensions.insert(new_dimensions.end(),
|
||||
real_data->shape().dimensions().begin(),
|
||||
real_data->shape().dimensions().end());
|
||||
// Adding specialized all-gather dimension.
|
||||
HloInstruction* ag_input =
|
||||
comp->AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(real_data->shape().element_type(),
|
||||
new_dimensions),
|
||||
real_data));
|
||||
new_dimensions[0] = all_gather_participants;
|
||||
absl::optional<int64> new_channel_id =
|
||||
ag->channel_id() ? absl::make_optional(this->NextChannelId())
|
||||
: absl::nullopt;
|
||||
@ -81,7 +91,7 @@ StatusOr<bool> CanonicalizeAllGatherForCSE::RunOnComputation(
|
||||
comp->AddInstruction(HloInstruction::CreateAllGather(
|
||||
ShapeUtil::MakeShape(real_data->shape().element_type(),
|
||||
new_dimensions),
|
||||
real_data, /*all_gather_dimension=*/0, ag->replica_groups(),
|
||||
ag_input, /*all_gather_dimension=*/0, ag->replica_groups(),
|
||||
ag->constrain_layout(), new_channel_id,
|
||||
ag->use_global_device_ids()));
|
||||
HloInstruction* new_formatting = comp->AddInstruction(
|
||||
|
@ -85,8 +85,9 @@ ENTRY entry {
|
||||
auto module = module_status.ConsumeValueOrDie();
|
||||
const HloInstruction* const reshape =
|
||||
module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(reshape,
|
||||
AllOf(op::Reshape(op::AllGather(_)), op::Shape("s32[2,8,1,1]")));
|
||||
EXPECT_THAT(reshape, AllOf(op::Reshape(op::AllGather(
|
||||
AllOf(op::Reshape(_), op::Shape("s32[1,8]")))),
|
||||
op::Shape("s32[2,8,1,1]")));
|
||||
}
|
||||
|
||||
TEST_F(AllGatherCanonicalizeTest, MultipleDegenerateReshapes2) {
|
||||
@ -105,8 +106,9 @@ ENTRY entry {
|
||||
auto module = module_status.ConsumeValueOrDie();
|
||||
const HloInstruction* const reshape =
|
||||
module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(reshape,
|
||||
AllOf(op::Reshape(op::AllGather(_)), op::Shape("s32[2,8,1,1]")));
|
||||
EXPECT_THAT(reshape, AllOf(op::Reshape(op::AllGather(
|
||||
AllOf(op::Reshape(_), op::Shape("s32[1,8]")))),
|
||||
op::Shape("s32[2,8,1,1]")));
|
||||
}
|
||||
|
||||
TEST_F(AllGatherCanonicalizeTest, MultipleDegenerateReshapesNoDim0) {
|
||||
|
@ -25,6 +25,30 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
// Returns if an instructions adds only degenerate dimensions to the shape of
|
||||
// the input, like going from [X,Y] to [1,X,Y,1].
|
||||
bool IsAddingOnlyDegenerateDimensions(const HloInstruction* inst) {
|
||||
if (inst->opcode() != HloOpcode::kBitcast &&
|
||||
inst->opcode() != HloOpcode::kReshape) {
|
||||
return false;
|
||||
}
|
||||
const Shape& in_shape = inst->operand(0)->shape();
|
||||
const Shape& out_shape = inst->shape();
|
||||
return ShapeUtil::ElementsIn(in_shape) == ShapeUtil::ElementsIn(out_shape) &&
|
||||
ShapeUtil::DimensionsUnmodifiedByReshape(in_shape, out_shape).size() ==
|
||||
in_shape.rank();
|
||||
}
|
||||
|
||||
// Passthrough reshapes or bitcasts adding only degenerate hdimensions to some
|
||||
// shape.
|
||||
const HloInstruction* PassthroughDegenerateAddingReshapes(
|
||||
const HloInstruction* inst) {
|
||||
while (IsAddingOnlyDegenerateDimensions(inst)) {
|
||||
inst = inst->operand(0);
|
||||
}
|
||||
return inst;
|
||||
}
|
||||
|
||||
HloCollectiveInstruction* MayConsiderAsAllGather(HloInstruction* hlo,
|
||||
bool for_replicas) {
|
||||
auto coll = DynCast<HloCollectiveInstruction>(hlo);
|
||||
@ -85,16 +109,23 @@ StatusOr<bool> RunOnComputation(HloComputation* comp, bool for_replicas,
|
||||
if (!ag) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto& earlier_ags = operand_to_ag[ag->operand(0)];
|
||||
auto& earlier_ags =
|
||||
operand_to_ag[PassthroughDegenerateAddingReshapes(ag->operand(0))];
|
||||
bool found = false;
|
||||
int64 ag_height = height[ag];
|
||||
for (auto& eag : earlier_ags) {
|
||||
if (!ShapeUtil::Equal(eag->shape(), ag->shape())) {
|
||||
continue;
|
||||
}
|
||||
HloInstruction* ag_operand = ag->mutable_operand(0);
|
||||
TF_RETURN_IF_ERROR(ag->ReplaceOperandWith(0, eag->mutable_operand(0)));
|
||||
if (!eag->IdenticalIgnoringChannelIdValues(*ag)) {
|
||||
TF_RETURN_IF_ERROR(ag->ReplaceOperandWith(0, ag_operand));
|
||||
continue;
|
||||
}
|
||||
found = true;
|
||||
if (lowest_user_height(eag) > ag_height + distance_threshold) {
|
||||
TF_RETURN_IF_ERROR(ag->ReplaceOperandWith(0, ag_operand));
|
||||
eag = ag;
|
||||
continue;
|
||||
}
|
||||
|
@ -63,6 +63,52 @@ ENTRY entry {
|
||||
EXPECT_EQ(tuple->operand(0), tuple->operand(1));
|
||||
}
|
||||
|
||||
TEST_F(AllGatherCseTest, SimpleCseReshapeLookthrough) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY entry {
|
||||
param0 = s32[8]{0} parameter(0)
|
||||
rshp = s32[1,8]{1,0} reshape(param0)
|
||||
rshp2 = s32[1,8]{1,0} reshape(param0)
|
||||
ag1 = s32[2,8]{1,0} all-gather(rshp), replica_groups={{0,1}}, dimensions={0},
|
||||
channel_id=0, use_global_device_ids=true
|
||||
ag2 = s32[2,8]{1,0} all-gather(rshp2), replica_groups={{0,1}}, dimensions={0},
|
||||
channel_id=1, use_global_device_ids=true
|
||||
ROOT tuple = (s32[2,8]{1,0}, s32[2,8]{1,0}) tuple(ag1, ag2)
|
||||
})";
|
||||
auto module_status = RunPass(hlo_string);
|
||||
EXPECT_TRUE(module_status.status().ok());
|
||||
auto module = module_status.ConsumeValueOrDie();
|
||||
HloInstruction* tuple = module->entry_computation()->root_instruction();
|
||||
EXPECT_EQ(tuple->opcode(), HloOpcode::kTuple);
|
||||
EXPECT_EQ(tuple->operand_count(), 2);
|
||||
EXPECT_EQ(tuple->operand(0), tuple->operand(1));
|
||||
}
|
||||
|
||||
TEST_F(AllGatherCseTest, SimpleNoCseInvalidReshapes) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY entry {
|
||||
param0 = s32[8]{0} parameter(0)
|
||||
rshp = s32[2,4]{1,0} reshape(param0)
|
||||
rshp2 = s32[2,4]{1,0} reshape(param0)
|
||||
ag1 = s32[4,4]{1,0} all-gather(rshp), replica_groups={{0,1}}, dimensions={0},
|
||||
channel_id=0, use_global_device_ids=true
|
||||
ag2 = s32[4,4]{1,0} all-gather(rshp2), replica_groups={{0,1}}, dimensions={0},
|
||||
channel_id=1, use_global_device_ids=true
|
||||
ROOT tuple = (s32[4,4]{1,0}, s32[4,4]{1,0}) tuple(ag1, ag2)
|
||||
})";
|
||||
auto module_status = RunPass(hlo_string);
|
||||
EXPECT_TRUE(module_status.status().ok());
|
||||
auto module = module_status.ConsumeValueOrDie();
|
||||
HloInstruction* tuple = module->entry_computation()->root_instruction();
|
||||
EXPECT_EQ(tuple->opcode(), HloOpcode::kTuple);
|
||||
EXPECT_EQ(tuple->operand_count(), 2);
|
||||
EXPECT_NE(tuple->operand(0), tuple->operand(1));
|
||||
}
|
||||
|
||||
TEST_F(AllGatherCseTest, SimpleCseDifferentDim) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
@ -84,6 +130,29 @@ ENTRY entry {
|
||||
EXPECT_EQ(tuple->operand(0), tuple->operand(1));
|
||||
}
|
||||
|
||||
TEST_F(AllGatherCseTest, SimpleCseDifferentDimReshapeLookthrough) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY entry {
|
||||
param0 = s32[8]{0} parameter(0)
|
||||
rshp = s32[1,8]{1,0} reshape(param0)
|
||||
rshp2 = s32[1,8]{1,0} reshape(param0)
|
||||
ag1 = s32[1,16]{1,0} all-gather(rshp), replica_groups={{0,1}}, dimensions={1},
|
||||
channel_id=0, use_global_device_ids=true
|
||||
ag2 = s32[1,16]{1,0} all-gather(rshp2), replica_groups={{0,1}},
|
||||
dimensions={1}, channel_id=1, use_global_device_ids=true
|
||||
ROOT tuple = (s32[1,16]{1,0}, s32[2,8,1,1]{3,2,1,0}) tuple(ag1, ag2)
|
||||
})";
|
||||
auto module_status = RunPass(hlo_string);
|
||||
EXPECT_TRUE(module_status.status().ok());
|
||||
auto module = module_status.ConsumeValueOrDie();
|
||||
HloInstruction* tuple = module->entry_computation()->root_instruction();
|
||||
EXPECT_EQ(tuple->opcode(), HloOpcode::kTuple);
|
||||
EXPECT_EQ(tuple->operand_count(), 2);
|
||||
EXPECT_EQ(tuple->operand(0), tuple->operand(1));
|
||||
}
|
||||
|
||||
TEST_F(AllGatherCseTest, NoCseGlobalDevice) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
Loading…
Reference in New Issue
Block a user