[XLA:SPMD] Detect shardable reshape in the other direction

PiperOrigin-RevId: 359835067
Change-Id: I9153dd2691e6539b836e63516a4649eb608eaf7c
This commit is contained in:
Yuanzhong Xu 2021-02-26 14:06:43 -08:00 committed by TensorFlower Gardener
parent e10f611ee2
commit 03e82caf2e
2 changed files with 61 additions and 31 deletions

View File

@ -1975,8 +1975,9 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
auto operand = GetPartitionedHlo(hlo->operand(0));
// The output shape is the source and the operand shape is the target to get
// the aligned sharding for the operand.
auto desired_operand_sharding = hlo_sharding_util::ReshapeSharding(
hlo->shape(), hlo->operand(0)->shape(), hlo->sharding());
absl::optional<HloSharding> desired_operand_sharding =
hlo_sharding_util::ReshapeSharding(hlo->shape(), hlo->operand(0)->shape(),
hlo->sharding());
if (desired_operand_sharding.has_value()) {
auto operand_hlo = operand.Reshard(*desired_operand_sharding).hlo();
SetPartitionedHlo(hlo, [&] {
@ -1985,6 +1986,21 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
});
return Status::OK();
}
absl::optional<HloSharding> desired_output_sharding =
hlo_sharding_util::ReshapeSharding(hlo->operand(0)->shape(), hlo->shape(),
operand.sharding());
if (desired_output_sharding.has_value()) {
auto reshape = b_.AddInstruction(hlo->CloneWithNewOperands(
MakePartitionedShape(hlo->shape(), *desired_output_sharding),
{operand.hlo()}));
reshape->set_sharding(*desired_output_sharding);
SetPartitionedHlo(hlo, [&] {
return PartitionedHlo(reshape, hlo->shape(), MakePartitioningState())
.Reshard(sharding)
.hlo();
});
return Status::OK();
}
// Check if operand sharding and sharding are both tiled or partial replicate.
// If both of them are partial replicate, check num_replications are the same.

View File

@ -2925,6 +2925,49 @@ ENTRY entry {
EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]")));
}
// Reshape where the operand is sharded on dim 0 but the output is sharded on
// dim 1: the operand sharding cannot be used directly, so the partitioner is
// expected to reshard the operand first (reshape -> all-to-all -> transpose ->
// reshape) and then perform the reshape locally.
TEST_F(SpmdPartitioningTest, ReshapeWithReshard) {
absl::string_view hlo_string = R"(
HloModule module
ENTRY entry {
%param0 = f32[38,38,324] parameter(0), sharding={devices=[2,1,1]0,1}
ROOT %reshape = f32[38,38,4,81] reshape(%param0),
sharding={devices=[1,2,1,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
// The cross-device redistribution of the operand to match the output's
// desired sharding.
auto input_reshard =
op::Reshape(op::Transpose(op::AllToAll(op::Reshape(op::Parameter(0)))));
// Per-device output shape: dim 1 (38) split over 2 devices -> 19.
EXPECT_THAT(root,
AllOf(op::Reshape(input_reshard), op::Shape("f32[38,19,4,81]")));
}
// Reshape in the "other direction": the reshape is first performed locally
// under a sharding derived from the operand (dim 0 sharded, local shape
// f32[19,38,2,162]), and the result is then resharded to the requested output
// sharding on the last dim via reshape -> all-to-all -> transpose -> reshape.
TEST_F(SpmdPartitioningTest, ReshapeWithReshard2) {
absl::string_view hlo_string = R"(
HloModule module
ENTRY entry {
%param0 = f32[38,38,324] parameter(0), sharding={devices=[2,1,1]0,1}
ROOT %reshape = f32[38,38,2,162] reshape(%param0),
sharding={devices=[1,1,1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
// The local reshape keeps the operand's dim-0 sharding: 38/2 = 19.
auto local_reshape =
AllOf(op::Reshape(op::Parameter(0)), op::Shape("f32[19,38,2,162]"));
// Final per-device shape: last dim (162) split over 2 devices -> 81.
EXPECT_THAT(root, AllOf(op::Shape("f32[38,38,2,81]"),
op::Reshape(op::Transpose(
op::AllToAll(op::Reshape(local_reshape))))));
}
TEST_F(SpmdPartitioningTest, PartialReplicateShardableReshape) {
absl::string_view hlo_string = R"(
HloModule module
@ -2949,35 +2992,6 @@ ENTRY entry {
EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]")));
}
// Reshape whose sharding cannot be propagated between operand and output
// (operand sharded on its last dim of 324; output last dim is 81): the
// partitioner falls back to replicating the operand (all-reduce), doing the
// full reshape, then re-partitioning the result — padding the last dim from
// 81 to 82 so it divides evenly by 2 before dynamic-slicing each shard.
TEST_F(SpmdPartitioningTest, NonShardableReshape) {
absl::string_view hlo_string = R"(
HloModule module
ENTRY entry {
%param0 = f32[38,38,324] parameter(0)
%param0.copy = f32[38,38,324] copy(%param0), sharding={devices=[1,1,2]0,1}
ROOT %transpose = f32[38,38,4,81] reshape(%param0.copy),
sharding={devices=[1,1,1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
// all-reduce -> full-shape reshape -> pad (81 -> 82) -> dynamic-slice to the
// per-device shard f32[38,38,4,41].
EXPECT_THAT(
root,
AllOf(op::DynamicSlice(
AllOf(op::Pad(
AllOf(op::Reshape(AllOf(op::AllReduce(),
op::Shape("f32[38,38,324]"))),
op::Shape("f32[38,38,4,81]")),
op::Constant()),
op::Shape("f32[38,38,4,82]")),
op::Constant(), op::Constant(), op::Constant(), op::Reshape()),
op::Shape("f32[38,38,4,41]")));
}
TEST_F(SpmdPartitioningTest, ReshapeMergeDimsWithHaloExchange) {
absl::string_view hlo_string = R"(
HloModule module