[XLA:SPMD] Detect shardable reshape in the other direction
PiperOrigin-RevId: 359835067 Change-Id: I9153dd2691e6539b836e63516a4649eb608eaf7c
This commit is contained in:
parent
e10f611ee2
commit
03e82caf2e
@ -1975,8 +1975,9 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
|
||||
auto operand = GetPartitionedHlo(hlo->operand(0));
|
||||
// The output shape is the source and the operand shape is the target to get
|
||||
// the aligned sharding for the operand.
|
||||
auto desired_operand_sharding = hlo_sharding_util::ReshapeSharding(
|
||||
hlo->shape(), hlo->operand(0)->shape(), hlo->sharding());
|
||||
absl::optional<HloSharding> desired_operand_sharding =
|
||||
hlo_sharding_util::ReshapeSharding(hlo->shape(), hlo->operand(0)->shape(),
|
||||
hlo->sharding());
|
||||
if (desired_operand_sharding.has_value()) {
|
||||
auto operand_hlo = operand.Reshard(*desired_operand_sharding).hlo();
|
||||
SetPartitionedHlo(hlo, [&] {
|
||||
@ -1985,6 +1986,21 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
|
||||
});
|
||||
return Status::OK();
|
||||
}
|
||||
absl::optional<HloSharding> desired_output_sharding =
|
||||
hlo_sharding_util::ReshapeSharding(hlo->operand(0)->shape(), hlo->shape(),
|
||||
operand.sharding());
|
||||
if (desired_output_sharding.has_value()) {
|
||||
auto reshape = b_.AddInstruction(hlo->CloneWithNewOperands(
|
||||
MakePartitionedShape(hlo->shape(), *desired_output_sharding),
|
||||
{operand.hlo()}));
|
||||
reshape->set_sharding(*desired_output_sharding);
|
||||
SetPartitionedHlo(hlo, [&] {
|
||||
return PartitionedHlo(reshape, hlo->shape(), MakePartitioningState())
|
||||
.Reshard(sharding)
|
||||
.hlo();
|
||||
});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Check if operand sharding and sharding are both tiled or partial replicate.
|
||||
// If both of them are partial replicate, check num_replications are the same.
|
||||
|
@ -2925,6 +2925,49 @@ ENTRY entry {
|
||||
EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]")));
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest, ReshapeWithReshard) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY entry {
|
||||
%param0 = f32[38,38,324] parameter(0), sharding={devices=[2,1,1]0,1}
|
||||
ROOT %reshape = f32[38,38,4,81] reshape(%param0),
|
||||
sharding={devices=[1,2,1,1]0,1}
|
||||
})";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
PartitionComputation(hlo_string, /*num_devices=*/2));
|
||||
VLOG(1) << module->ToString();
|
||||
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
auto input_reshard =
|
||||
op::Reshape(op::Transpose(op::AllToAll(op::Reshape(op::Parameter(0)))));
|
||||
EXPECT_THAT(root,
|
||||
AllOf(op::Reshape(input_reshard), op::Shape("f32[38,19,4,81]")));
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest, ReshapeWithReshard2) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY entry {
|
||||
%param0 = f32[38,38,324] parameter(0), sharding={devices=[2,1,1]0,1}
|
||||
ROOT %reshape = f32[38,38,2,162] reshape(%param0),
|
||||
sharding={devices=[1,1,1,2]0,1}
|
||||
})";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
PartitionComputation(hlo_string, /*num_devices=*/2));
|
||||
VLOG(1) << module->ToString();
|
||||
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
auto local_reshape =
|
||||
AllOf(op::Reshape(op::Parameter(0)), op::Shape("f32[19,38,2,162]"));
|
||||
EXPECT_THAT(root, AllOf(op::Shape("f32[38,38,2,81]"),
|
||||
op::Reshape(op::Transpose(
|
||||
op::AllToAll(op::Reshape(local_reshape))))));
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest, PartialReplicateShardableReshape) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
@ -2949,35 +2992,6 @@ ENTRY entry {
|
||||
EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]")));
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest, NonShardableReshape) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY entry {
|
||||
%param0 = f32[38,38,324] parameter(0)
|
||||
%param0.copy = f32[38,38,324] copy(%param0), sharding={devices=[1,1,2]0,1}
|
||||
ROOT %transpose = f32[38,38,4,81] reshape(%param0.copy),
|
||||
sharding={devices=[1,1,1,2]0,1}
|
||||
})";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
PartitionComputation(hlo_string, /*num_devices=*/2));
|
||||
VLOG(1) << module->ToString();
|
||||
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(
|
||||
root,
|
||||
AllOf(op::DynamicSlice(
|
||||
AllOf(op::Pad(
|
||||
AllOf(op::Reshape(AllOf(op::AllReduce(),
|
||||
op::Shape("f32[38,38,324]"))),
|
||||
op::Shape("f32[38,38,4,81]")),
|
||||
op::Constant()),
|
||||
op::Shape("f32[38,38,4,82]")),
|
||||
op::Constant(), op::Constant(), op::Constant(), op::Reshape()),
|
||||
op::Shape("f32[38,38,4,41]")));
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest, ReshapeMergeDimsWithHaloExchange) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
Loading…
x
Reference in New Issue
Block a user