[XLA] Respect layout of outfeed when SPMD partioning.

PiperOrigin-RevId: 351281979
Change-Id: I56af3d07ed4a2181de31b924a40d7cef13a360d7
This commit is contained in:
Marcello Maggioni 2021-01-11 18:57:48 -08:00 committed by TensorFlower Gardener
parent 4517ba2893
commit 741b16bd0f
2 changed files with 34 additions and 3 deletions

View File

@ -3329,9 +3329,12 @@ Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
auto token = GetPartitionedHlo(hlo->operand(1)).hlo();
if (EvenlyPartitions(shape, sharding)) {
Shape outfeed_shape = operand->shape();
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(hlo->outfeed_shape(),
&outfeed_shape));
SetPartitionedHlo(hlo, [&]() {
return b_.AddInstruction(HloInstruction::CreateOutfeed(
operand->shape(), operand, token, hlo->outfeed_config()));
outfeed_shape, operand, token, hlo->outfeed_config()));
});
return Status::OK();
}
@ -3440,6 +3443,8 @@ Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
};
outfeed_data = slice_outfeed({}, outfeed_data);
}
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
hlo->outfeed_shape(), &per_branch_partitioned_shapes[i]));
branch_b.AddInstruction(HloInstruction::CreateOutfeed(
per_branch_partitioned_shapes[i], outfeed_data, outfeed_token,
hlo->outfeed_config()));

View File

@ -708,7 +708,8 @@ ENTRY entry {
token.0 = token[] after-all()
data = (f32[1024,2]{1,0}, f32[2]{0}) parameter(0), sharding={{devices=[2,1]0,1},
{devices=[2]0,1}}
ROOT outfeed = token[] outfeed(data, token.0), sharding={{devices=[2,1]0,1},
ROOT outfeed = token[] outfeed(data, token.0),
outfeed_shape=(f32[1024,2]{0,1}, f32[2]{0}), sharding={{devices=[2,1]0,1},
{devices=[2]0,1}}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
@ -717,6 +718,12 @@ ENTRY entry {
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Shape("token[]"),
op::Outfeed(op::Parameter(), op::AfterAll())));
auto expected_layout0 = LayoutUtil::MakeLayout({0, 1});
auto expected_layout1 = LayoutUtil::MakeLayout({0});
EXPECT_TRUE(LayoutUtil::Equal(root->outfeed_shape().tuple_shapes(0).layout(),
expected_layout0));
EXPECT_TRUE(LayoutUtil::Equal(root->outfeed_shape().tuple_shapes(1).layout(),
expected_layout1));
}
TEST_F(SpmdPartitioningTest, OutfeedReplicated) {
@ -746,7 +753,8 @@ ENTRY entry {
token.0 = token[] after-all()
data = (f32[1023,2]{1,0}, f32[3]{0}) parameter(0), sharding={{devices=[2,1]0,1},
{devices=[2]0,1}}
outfeed = token[] outfeed(data, token.0), sharding={{devices=[2,1]0,1},
outfeed = token[] outfeed(data, token.0),
outfeed_shape=(f32[1023,2]{0,1}, f32[3]{0}), sharding={{devices=[2,1]0,1},
{devices=[2]0,1}}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
@ -770,6 +778,24 @@ ENTRY entry {
EXPECT_THAT(root->called_computations()[1]->root_instruction(),
AllOf(op::Shape("token[]"),
op::Outfeed(second_outfeed, op::GetTupleElement())));
auto expected_layout0 = LayoutUtil::MakeLayout({0, 1});
auto expected_layout1 = LayoutUtil::MakeLayout({0});
auto first_outfeed_instr = root->called_computations()[0]->root_instruction();
auto second_outfeed_instr =
root->called_computations()[1]->root_instruction();
EXPECT_TRUE(LayoutUtil::Equal(
first_outfeed_instr->outfeed_shape().tuple_shapes(0).layout(),
expected_layout0));
EXPECT_TRUE(LayoutUtil::Equal(
first_outfeed_instr->outfeed_shape().tuple_shapes(1).layout(),
expected_layout1));
EXPECT_TRUE(LayoutUtil::Equal(
second_outfeed_instr->outfeed_shape().tuple_shapes(0).layout(),
expected_layout0));
EXPECT_TRUE(LayoutUtil::Equal(
second_outfeed_instr->outfeed_shape().tuple_shapes(1).layout(),
expected_layout1));
}
TEST_F(SpmdPartitioningTest, ReduceWindowReplicatedInput) {