[XLA] Respect layout of outfeed when SPMD partitioning.
PiperOrigin-RevId: 351281979
Change-Id: I56af3d07ed4a2181de31b924a40d7cef13a360d7
parent 4517ba2893
commit 741b16bd0f
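What the change does: before this commit, the SPMD partitioner recreated each partitioned outfeed with operand->shape(), so any layout the user had requested on the outfeed (the outfeed_shape= attribute in HLO text) was silently replaced by the operand's layout. The fix copies the requested layout onto the partitioned shape before recreating the instruction. A minimal sketch of the pattern, assuming the context of SpmdPartitioningVisitor::HandleOutfeed, where hlo is the original outfeed and operand its partitioned data operand:

    #include "tensorflow/compiler/xla/layout_util.h"
    #include "tensorflow/compiler/xla/shape.h"

    // Start from the partitioned operand's shape (which carries the
    // operand's layout) and overwrite its layout with the one the user
    // set on the original outfeed instruction.
    Shape outfeed_shape = operand->shape();
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        /*src=*/hlo->outfeed_shape(), /*dst=*/&outfeed_shape));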
tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
@@ -3329,9 +3329,12 @@ Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
   auto token = GetPartitionedHlo(hlo->operand(1)).hlo();
 
   if (EvenlyPartitions(shape, sharding)) {
+    Shape outfeed_shape = operand->shape();
+    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(hlo->outfeed_shape(),
+                                                           &outfeed_shape));
     SetPartitionedHlo(hlo, [&]() {
       return b_.AddInstruction(HloInstruction::CreateOutfeed(
-          operand->shape(), operand, token, hlo->outfeed_config()));
+          outfeed_shape, operand, token, hlo->outfeed_config()));
     });
     return Status::OK();
   }
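Note on the hunk above: HloInstruction::CreateOutfeed takes the outfeed shape as an explicit argument, so in the evenly partitioned case it suffices to pass the layout-adjusted copy of the operand's shape; the dimensions are unchanged and only the layout differs. The TF_RETURN_IF_ERROR matters because CopyLayoutBetweenShapes can fail when the source and destination shapes disagree in structure, though here even partitioning preserves rank and tuple structure.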
@@ -3440,6 +3443,8 @@ Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
       };
       outfeed_data = slice_outfeed({}, outfeed_data);
     }
+    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+        hlo->outfeed_shape(), &per_branch_partitioned_shapes[i]));
     branch_b.AddInstruction(HloInstruction::CreateOutfeed(
         per_branch_partitioned_shapes[i], outfeed_data, outfeed_token,
         hlo->outfeed_config()));
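When the data is not evenly partitioned, HandleOutfeed instead emits a conditional whose branches outfeed the differing per-device shapes (this is what the test assertions on called_computations() below inspect), so the layout has to be applied to every branch shape. A sketch of the idea, assuming a vector of per-branch shapes as in the hunk above; the actual diff performs the copy at index i inside the branch-building loop:

    // Every branch's outfeed shape receives the layout the user
    // requested on the original, unpartitioned outfeed.
    for (Shape& branch_shape : per_branch_partitioned_shapes) {
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          hlo->outfeed_shape(), &branch_shape));
    }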
tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -708,7 +708,8 @@ ENTRY entry {
   token.0 = token[] after-all()
   data = (f32[1024,2]{1,0}, f32[2]{0}) parameter(0), sharding={{devices=[2,1]0,1},
     {devices=[2]0,1}}
-  ROOT outfeed = token[] outfeed(data, token.0), sharding={{devices=[2,1]0,1},
+  ROOT outfeed = token[] outfeed(data, token.0),
+    outfeed_shape=(f32[1024,2]{0,1}, f32[2]{0}), sharding={{devices=[2,1]0,1},
     {devices=[2]0,1}}
 })";
   TF_ASSERT_OK_AND_ASSIGN(auto module,
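For reading the test: in HLO text the braces after an array shape give its minor-to-major order, so f32[1024,2]{0,1} requests dimension 0 as fastest-varying (column-major), whereas {1,0} would be the rank-2 default. The test pins a non-default layout precisely so that dropping it would be detected. An illustrative way to build such a shape programmatically, assuming the XLA API of this era:

    #include "tensorflow/compiler/xla/shape_util.h"

    // f32[1024,2]{0,1}: dimensions {1024, 2}, minor-to-major {0, 1}.
    Shape with_layout =
        ShapeUtil::MakeShapeWithLayout(F32, {1024, 2}, {0, 1});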
@@ -717,6 +718,12 @@ ENTRY entry {
   HloInstruction* root = module->entry_computation()->root_instruction();
   EXPECT_THAT(root, AllOf(op::Shape("token[]"),
                           op::Outfeed(op::Parameter(), op::AfterAll())));
+  auto expected_layout0 = LayoutUtil::MakeLayout({0, 1});
+  auto expected_layout1 = LayoutUtil::MakeLayout({0});
+  EXPECT_TRUE(LayoutUtil::Equal(root->outfeed_shape().tuple_shapes(0).layout(),
+                                expected_layout0));
+  EXPECT_TRUE(LayoutUtil::Equal(root->outfeed_shape().tuple_shapes(1).layout(),
+                                expected_layout1));
 }
 
 TEST_F(SpmdPartitioningTest, OutfeedReplicated) {
@@ -746,7 +753,8 @@ ENTRY entry {
   token.0 = token[] after-all()
   data = (f32[1023,2]{1,0}, f32[3]{0}) parameter(0), sharding={{devices=[2,1]0,1},
     {devices=[2]0,1}}
-  outfeed = token[] outfeed(data, token.0), sharding={{devices=[2,1]0,1},
+  outfeed = token[] outfeed(data, token.0),
+    outfeed_shape=(f32[1023,2]{0,1}, f32[3]{0}), sharding={{devices=[2,1]0,1},
     {devices=[2]0,1}}
 })";
   TF_ASSERT_OK_AND_ASSIGN(auto module,
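This second test uses f32[1023,2], which does not split evenly across the [2,1] device mesh (by my arithmetic, one partition gets 512 rows and the other 511), so it exercises the conditional per-branch path patched in the second hunk; the assertions below check that both branch outfeeds keep the {0,1} and {0} layouts.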
@@ -770,6 +778,24 @@ ENTRY entry {
   EXPECT_THAT(root->called_computations()[1]->root_instruction(),
               AllOf(op::Shape("token[]"),
                     op::Outfeed(second_outfeed, op::GetTupleElement())));
+
+  auto expected_layout0 = LayoutUtil::MakeLayout({0, 1});
+  auto expected_layout1 = LayoutUtil::MakeLayout({0});
+  auto first_outfeed_instr = root->called_computations()[0]->root_instruction();
+  auto second_outfeed_instr =
+      root->called_computations()[1]->root_instruction();
+  EXPECT_TRUE(LayoutUtil::Equal(
+      first_outfeed_instr->outfeed_shape().tuple_shapes(0).layout(),
+      expected_layout0));
+  EXPECT_TRUE(LayoutUtil::Equal(
+      first_outfeed_instr->outfeed_shape().tuple_shapes(1).layout(),
+      expected_layout1));
+  EXPECT_TRUE(LayoutUtil::Equal(
+      second_outfeed_instr->outfeed_shape().tuple_shapes(0).layout(),
+      expected_layout0));
+  EXPECT_TRUE(LayoutUtil::Equal(
+      second_outfeed_instr->outfeed_shape().tuple_shapes(1).layout(),
+      expected_layout1));
 }
 
 TEST_F(SpmdPartitioningTest, ReduceWindowReplicatedInput) {