[XLA] Make the canonicalize all-gather pass add a new/exclusive all-gather dimension.

Make the canonicalizer always add a completely new all-gather dimension to the shape
of the all-gather operand to help layout assignment.

PiperOrigin-RevId: 360612458
Change-Id: I8299301219f704f14661f6fb2619ef604bb4d7c4
This commit is contained in:
Marcello Maggioni 2021-03-03 01:01:54 -08:00 committed by TensorFlower Gardener
parent 8e5ab11c65
commit 71fd1ab1d5
4 changed files with 122 additions and 10 deletions

View File

@ -71,9 +71,19 @@ StatusOr<bool> CanonicalizeAllGatherForCSE::RunOnComputation(
// adding the dimension the all-gather is operating on then perform the
// canonicalization.
if (real_data != ag->operand(0)) {
std::vector<int64> new_dimensions(real_data->shape().dimensions().begin(),
real_data->shape().dimensions().end());
new_dimensions[0] *= all_gather_participants;
std::vector<int64> new_dimensions;
new_dimensions.reserve(real_data->shape().dimensions_size() + 1);
new_dimensions.push_back(1);
new_dimensions.insert(new_dimensions.end(),
real_data->shape().dimensions().begin(),
real_data->shape().dimensions().end());
// Adding specialized all-gather dimension.
HloInstruction* ag_input =
comp->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(real_data->shape().element_type(),
new_dimensions),
real_data));
new_dimensions[0] = all_gather_participants;
absl::optional<int64> new_channel_id =
ag->channel_id() ? absl::make_optional(this->NextChannelId())
: absl::nullopt;
@ -81,7 +91,7 @@ StatusOr<bool> CanonicalizeAllGatherForCSE::RunOnComputation(
comp->AddInstruction(HloInstruction::CreateAllGather(
ShapeUtil::MakeShape(real_data->shape().element_type(),
new_dimensions),
real_data, /*all_gather_dimension=*/0, ag->replica_groups(),
ag_input, /*all_gather_dimension=*/0, ag->replica_groups(),
ag->constrain_layout(), new_channel_id,
ag->use_global_device_ids()));
HloInstruction* new_formatting = comp->AddInstruction(

View File

@ -85,8 +85,9 @@ ENTRY entry {
auto module = module_status.ConsumeValueOrDie();
const HloInstruction* const reshape =
module->entry_computation()->root_instruction();
EXPECT_THAT(reshape,
AllOf(op::Reshape(op::AllGather(_)), op::Shape("s32[2,8,1,1]")));
EXPECT_THAT(reshape, AllOf(op::Reshape(op::AllGather(
AllOf(op::Reshape(_), op::Shape("s32[1,8]")))),
op::Shape("s32[2,8,1,1]")));
}
TEST_F(AllGatherCanonicalizeTest, MultipleDegenerateReshapes2) {
@ -105,8 +106,9 @@ ENTRY entry {
auto module = module_status.ConsumeValueOrDie();
const HloInstruction* const reshape =
module->entry_computation()->root_instruction();
EXPECT_THAT(reshape,
AllOf(op::Reshape(op::AllGather(_)), op::Shape("s32[2,8,1,1]")));
EXPECT_THAT(reshape, AllOf(op::Reshape(op::AllGather(
AllOf(op::Reshape(_), op::Shape("s32[1,8]")))),
op::Shape("s32[2,8,1,1]")));
}
TEST_F(AllGatherCanonicalizeTest, MultipleDegenerateReshapesNoDim0) {

View File

@ -25,6 +25,30 @@ limitations under the License.
namespace xla {
namespace {
// Returns if an instructions adds only degenerate dimensions to the shape of
// the input, like going from [X,Y] to [1,X,Y,1].
bool IsAddingOnlyDegenerateDimensions(const HloInstruction* inst) {
if (inst->opcode() != HloOpcode::kBitcast &&
inst->opcode() != HloOpcode::kReshape) {
return false;
}
const Shape& in_shape = inst->operand(0)->shape();
const Shape& out_shape = inst->shape();
return ShapeUtil::ElementsIn(in_shape) == ShapeUtil::ElementsIn(out_shape) &&
ShapeUtil::DimensionsUnmodifiedByReshape(in_shape, out_shape).size() ==
in_shape.rank();
}
// Walks backwards past any chain of reshapes/bitcasts that merely add
// degenerate dimensions, returning the first producer that does not.
const HloInstruction* PassthroughDegenerateAddingReshapes(
    const HloInstruction* inst) {
  const HloInstruction* current = inst;
  for (; IsAddingOnlyDegenerateDimensions(current);
       current = current->operand(0)) {
  }
  return current;
}
HloCollectiveInstruction* MayConsiderAsAllGather(HloInstruction* hlo,
bool for_replicas) {
auto coll = DynCast<HloCollectiveInstruction>(hlo);
@ -85,16 +109,23 @@ StatusOr<bool> RunOnComputation(HloComputation* comp, bool for_replicas,
if (!ag) {
continue;
}
auto& earlier_ags = operand_to_ag[ag->operand(0)];
auto& earlier_ags =
operand_to_ag[PassthroughDegenerateAddingReshapes(ag->operand(0))];
bool found = false;
int64 ag_height = height[ag];
for (auto& eag : earlier_ags) {
if (!ShapeUtil::Equal(eag->shape(), ag->shape())) {
continue;
}
HloInstruction* ag_operand = ag->mutable_operand(0);
TF_RETURN_IF_ERROR(ag->ReplaceOperandWith(0, eag->mutable_operand(0)));
if (!eag->IdenticalIgnoringChannelIdValues(*ag)) {
TF_RETURN_IF_ERROR(ag->ReplaceOperandWith(0, ag_operand));
continue;
}
found = true;
if (lowest_user_height(eag) > ag_height + distance_threshold) {
TF_RETURN_IF_ERROR(ag->ReplaceOperandWith(0, ag_operand));
eag = ag;
continue;
}

View File

@ -63,6 +63,52 @@ ENTRY entry {
EXPECT_EQ(tuple->operand(0), tuple->operand(1));
}
// Two all-gathers fed by distinct but identical reshapes of the same
// parameter, where each reshape only adds a degenerate leading dimension:
// the pass should look through the reshapes and CSE the all-gathers.
TEST_F(AllGatherCseTest, SimpleCseReshapeLookthrough) {
absl::string_view hlo_string = R"(
HloModule module
ENTRY entry {
param0 = s32[8]{0} parameter(0)
rshp = s32[1,8]{1,0} reshape(param0)
rshp2 = s32[1,8]{1,0} reshape(param0)
ag1 = s32[2,8]{1,0} all-gather(rshp), replica_groups={{0,1}}, dimensions={0},
channel_id=0, use_global_device_ids=true
ag2 = s32[2,8]{1,0} all-gather(rshp2), replica_groups={{0,1}}, dimensions={0},
channel_id=1, use_global_device_ids=true
ROOT tuple = (s32[2,8]{1,0}, s32[2,8]{1,0}) tuple(ag1, ag2)
})";
auto module_status = RunPass(hlo_string);
EXPECT_TRUE(module_status.status().ok());
auto module = module_status.ConsumeValueOrDie();
HloInstruction* tuple = module->entry_computation()->root_instruction();
EXPECT_EQ(tuple->opcode(), HloOpcode::kTuple);
EXPECT_EQ(tuple->operand_count(), 2);
// After CSE both tuple elements must be the very same all-gather instruction.
EXPECT_EQ(tuple->operand(0), tuple->operand(1));
}
// Negative case: the reshapes [8] -> [2,4] modify a non-degenerate dimension,
// so the lookthrough must not apply; the all-gathers have non-equivalent
// operands and must not be CSE'd.
TEST_F(AllGatherCseTest, SimpleNoCseInvalidReshapes) {
absl::string_view hlo_string = R"(
HloModule module
ENTRY entry {
param0 = s32[8]{0} parameter(0)
rshp = s32[2,4]{1,0} reshape(param0)
rshp2 = s32[2,4]{1,0} reshape(param0)
ag1 = s32[4,4]{1,0} all-gather(rshp), replica_groups={{0,1}}, dimensions={0},
channel_id=0, use_global_device_ids=true
ag2 = s32[4,4]{1,0} all-gather(rshp2), replica_groups={{0,1}}, dimensions={0},
channel_id=1, use_global_device_ids=true
ROOT tuple = (s32[4,4]{1,0}, s32[4,4]{1,0}) tuple(ag1, ag2)
})";
auto module_status = RunPass(hlo_string);
EXPECT_TRUE(module_status.status().ok());
auto module = module_status.ConsumeValueOrDie();
HloInstruction* tuple = module->entry_computation()->root_instruction();
EXPECT_EQ(tuple->opcode(), HloOpcode::kTuple);
EXPECT_EQ(tuple->operand_count(), 2);
// The two all-gathers must remain distinct instructions.
EXPECT_NE(tuple->operand(0), tuple->operand(1));
}
TEST_F(AllGatherCseTest, SimpleCseDifferentDim) {
absl::string_view hlo_string = R"(
HloModule module
@ -84,6 +130,29 @@ ENTRY entry {
EXPECT_EQ(tuple->operand(0), tuple->operand(1));
}
// Like SimpleCseReshapeLookthrough but gathering along dimension 1: the pass
// should still look through the degenerate-adding reshapes and CSE the two
// all-gathers.
TEST_F(AllGatherCseTest, SimpleCseDifferentDimReshapeLookthrough) {
// NOTE: the declared tuple element shape for ag2 previously read
// s32[2,8,1,1]{3,2,1,0} (a copy-paste from the canonicalizer test), which
// does not match ag2's actual shape s32[1,16]{1,0}; fixed below.
absl::string_view hlo_string = R"(
HloModule module
ENTRY entry {
param0 = s32[8]{0} parameter(0)
rshp = s32[1,8]{1,0} reshape(param0)
rshp2 = s32[1,8]{1,0} reshape(param0)
ag1 = s32[1,16]{1,0} all-gather(rshp), replica_groups={{0,1}}, dimensions={1},
channel_id=0, use_global_device_ids=true
ag2 = s32[1,16]{1,0} all-gather(rshp2), replica_groups={{0,1}},
dimensions={1}, channel_id=1, use_global_device_ids=true
ROOT tuple = (s32[1,16]{1,0}, s32[1,16]{1,0}) tuple(ag1, ag2)
})";
auto module_status = RunPass(hlo_string);
EXPECT_TRUE(module_status.status().ok());
auto module = module_status.ConsumeValueOrDie();
HloInstruction* tuple = module->entry_computation()->root_instruction();
EXPECT_EQ(tuple->opcode(), HloOpcode::kTuple);
EXPECT_EQ(tuple->operand_count(), 2);
// Both tuple elements must resolve to the same all-gather after CSE.
EXPECT_EQ(tuple->operand(0), tuple->operand(1));
}
TEST_F(AllGatherCseTest, NoCseGlobalDevice) {
absl::string_view hlo_string = R"(
HloModule module