[XLA:SPMD] Detect shardable reshape in the other direction
PiperOrigin-RevId: 359835067 Change-Id: I9153dd2691e6539b836e63516a4649eb608eaf7c
This commit is contained in:
parent
e10f611ee2
commit
03e82caf2e
@ -1975,8 +1975,9 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
|
|||||||
auto operand = GetPartitionedHlo(hlo->operand(0));
|
auto operand = GetPartitionedHlo(hlo->operand(0));
|
||||||
// The output shape is the source and the operand shape is the target to get
|
// The output shape is the source and the operand shape is the target to get
|
||||||
// the aligned sharding for the operand.
|
// the aligned sharding for the operand.
|
||||||
auto desired_operand_sharding = hlo_sharding_util::ReshapeSharding(
|
absl::optional<HloSharding> desired_operand_sharding =
|
||||||
hlo->shape(), hlo->operand(0)->shape(), hlo->sharding());
|
hlo_sharding_util::ReshapeSharding(hlo->shape(), hlo->operand(0)->shape(),
|
||||||
|
hlo->sharding());
|
||||||
if (desired_operand_sharding.has_value()) {
|
if (desired_operand_sharding.has_value()) {
|
||||||
auto operand_hlo = operand.Reshard(*desired_operand_sharding).hlo();
|
auto operand_hlo = operand.Reshard(*desired_operand_sharding).hlo();
|
||||||
SetPartitionedHlo(hlo, [&] {
|
SetPartitionedHlo(hlo, [&] {
|
||||||
@ -1985,6 +1986,21 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
|
|||||||
});
|
});
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
absl::optional<HloSharding> desired_output_sharding =
|
||||||
|
hlo_sharding_util::ReshapeSharding(hlo->operand(0)->shape(), hlo->shape(),
|
||||||
|
operand.sharding());
|
||||||
|
if (desired_output_sharding.has_value()) {
|
||||||
|
auto reshape = b_.AddInstruction(hlo->CloneWithNewOperands(
|
||||||
|
MakePartitionedShape(hlo->shape(), *desired_output_sharding),
|
||||||
|
{operand.hlo()}));
|
||||||
|
reshape->set_sharding(*desired_output_sharding);
|
||||||
|
SetPartitionedHlo(hlo, [&] {
|
||||||
|
return PartitionedHlo(reshape, hlo->shape(), MakePartitioningState())
|
||||||
|
.Reshard(sharding)
|
||||||
|
.hlo();
|
||||||
|
});
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
// Check if operand sharding and sharding are both tiled or partial replicate.
|
// Check if operand sharding and sharding are both tiled or partial replicate.
|
||||||
// If both of them are partial replicate, check num_replications are the same.
|
// If both of them are partial replicate, check num_replications are the same.
|
||||||
|
@ -2925,6 +2925,49 @@ ENTRY entry {
|
|||||||
EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]")));
|
EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]")));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(SpmdPartitioningTest, ReshapeWithReshard) {
|
||||||
|
absl::string_view hlo_string = R"(
|
||||||
|
HloModule module
|
||||||
|
|
||||||
|
ENTRY entry {
|
||||||
|
%param0 = f32[38,38,324] parameter(0), sharding={devices=[2,1,1]0,1}
|
||||||
|
ROOT %reshape = f32[38,38,4,81] reshape(%param0),
|
||||||
|
sharding={devices=[1,2,1,1]0,1}
|
||||||
|
})";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
PartitionComputation(hlo_string, /*num_devices=*/2));
|
||||||
|
VLOG(1) << module->ToString();
|
||||||
|
|
||||||
|
auto root = module->entry_computation()->root_instruction();
|
||||||
|
auto input_reshard =
|
||||||
|
op::Reshape(op::Transpose(op::AllToAll(op::Reshape(op::Parameter(0)))));
|
||||||
|
EXPECT_THAT(root,
|
||||||
|
AllOf(op::Reshape(input_reshard), op::Shape("f32[38,19,4,81]")));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SpmdPartitioningTest, ReshapeWithReshard2) {
|
||||||
|
absl::string_view hlo_string = R"(
|
||||||
|
HloModule module
|
||||||
|
|
||||||
|
ENTRY entry {
|
||||||
|
%param0 = f32[38,38,324] parameter(0), sharding={devices=[2,1,1]0,1}
|
||||||
|
ROOT %reshape = f32[38,38,2,162] reshape(%param0),
|
||||||
|
sharding={devices=[1,1,1,2]0,1}
|
||||||
|
})";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
PartitionComputation(hlo_string, /*num_devices=*/2));
|
||||||
|
VLOG(1) << module->ToString();
|
||||||
|
|
||||||
|
auto root = module->entry_computation()->root_instruction();
|
||||||
|
auto local_reshape =
|
||||||
|
AllOf(op::Reshape(op::Parameter(0)), op::Shape("f32[19,38,2,162]"));
|
||||||
|
EXPECT_THAT(root, AllOf(op::Shape("f32[38,38,2,81]"),
|
||||||
|
op::Reshape(op::Transpose(
|
||||||
|
op::AllToAll(op::Reshape(local_reshape))))));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(SpmdPartitioningTest, PartialReplicateShardableReshape) {
|
TEST_F(SpmdPartitioningTest, PartialReplicateShardableReshape) {
|
||||||
absl::string_view hlo_string = R"(
|
absl::string_view hlo_string = R"(
|
||||||
HloModule module
|
HloModule module
|
||||||
@ -2949,35 +2992,6 @@ ENTRY entry {
|
|||||||
EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]")));
|
EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]")));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SpmdPartitioningTest, NonShardableReshape) {
|
|
||||||
absl::string_view hlo_string = R"(
|
|
||||||
HloModule module
|
|
||||||
|
|
||||||
ENTRY entry {
|
|
||||||
%param0 = f32[38,38,324] parameter(0)
|
|
||||||
%param0.copy = f32[38,38,324] copy(%param0), sharding={devices=[1,1,2]0,1}
|
|
||||||
ROOT %transpose = f32[38,38,4,81] reshape(%param0.copy),
|
|
||||||
sharding={devices=[1,1,1,2]0,1}
|
|
||||||
})";
|
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
|
||||||
PartitionComputation(hlo_string, /*num_devices=*/2));
|
|
||||||
VLOG(1) << module->ToString();
|
|
||||||
|
|
||||||
auto root = module->entry_computation()->root_instruction();
|
|
||||||
EXPECT_THAT(
|
|
||||||
root,
|
|
||||||
AllOf(op::DynamicSlice(
|
|
||||||
AllOf(op::Pad(
|
|
||||||
AllOf(op::Reshape(AllOf(op::AllReduce(),
|
|
||||||
op::Shape("f32[38,38,324]"))),
|
|
||||||
op::Shape("f32[38,38,4,81]")),
|
|
||||||
op::Constant()),
|
|
||||||
op::Shape("f32[38,38,4,82]")),
|
|
||||||
op::Constant(), op::Constant(), op::Constant(), op::Reshape()),
|
|
||||||
op::Shape("f32[38,38,4,41]")));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(SpmdPartitioningTest, ReshapeMergeDimsWithHaloExchange) {
|
TEST_F(SpmdPartitioningTest, ReshapeMergeDimsWithHaloExchange) {
|
||||||
absl::string_view hlo_string = R"(
|
absl::string_view hlo_string = R"(
|
||||||
HloModule module
|
HloModule module
|
||||||
|
Loading…
x
Reference in New Issue
Block a user