[XLA] Respect layout of outfeed when SPMD partioning.
PiperOrigin-RevId: 351281979 Change-Id: I56af3d07ed4a2181de31b924a40d7cef13a360d7
This commit is contained in:
parent
4517ba2893
commit
741b16bd0f
@ -3329,9 +3329,12 @@ Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
|
|||||||
auto token = GetPartitionedHlo(hlo->operand(1)).hlo();
|
auto token = GetPartitionedHlo(hlo->operand(1)).hlo();
|
||||||
|
|
||||||
if (EvenlyPartitions(shape, sharding)) {
|
if (EvenlyPartitions(shape, sharding)) {
|
||||||
|
Shape outfeed_shape = operand->shape();
|
||||||
|
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(hlo->outfeed_shape(),
|
||||||
|
&outfeed_shape));
|
||||||
SetPartitionedHlo(hlo, [&]() {
|
SetPartitionedHlo(hlo, [&]() {
|
||||||
return b_.AddInstruction(HloInstruction::CreateOutfeed(
|
return b_.AddInstruction(HloInstruction::CreateOutfeed(
|
||||||
operand->shape(), operand, token, hlo->outfeed_config()));
|
outfeed_shape, operand, token, hlo->outfeed_config()));
|
||||||
});
|
});
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -3440,6 +3443,8 @@ Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
|
|||||||
};
|
};
|
||||||
outfeed_data = slice_outfeed({}, outfeed_data);
|
outfeed_data = slice_outfeed({}, outfeed_data);
|
||||||
}
|
}
|
||||||
|
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
|
||||||
|
hlo->outfeed_shape(), &per_branch_partitioned_shapes[i]));
|
||||||
branch_b.AddInstruction(HloInstruction::CreateOutfeed(
|
branch_b.AddInstruction(HloInstruction::CreateOutfeed(
|
||||||
per_branch_partitioned_shapes[i], outfeed_data, outfeed_token,
|
per_branch_partitioned_shapes[i], outfeed_data, outfeed_token,
|
||||||
hlo->outfeed_config()));
|
hlo->outfeed_config()));
|
||||||
|
@ -708,7 +708,8 @@ ENTRY entry {
|
|||||||
token.0 = token[] after-all()
|
token.0 = token[] after-all()
|
||||||
data = (f32[1024,2]{1,0}, f32[2]{0}) parameter(0), sharding={{devices=[2,1]0,1},
|
data = (f32[1024,2]{1,0}, f32[2]{0}) parameter(0), sharding={{devices=[2,1]0,1},
|
||||||
{devices=[2]0,1}}
|
{devices=[2]0,1}}
|
||||||
ROOT outfeed = token[] outfeed(data, token.0), sharding={{devices=[2,1]0,1},
|
ROOT outfeed = token[] outfeed(data, token.0),
|
||||||
|
outfeed_shape=(f32[1024,2]{0,1}, f32[2]{0}), sharding={{devices=[2,1]0,1},
|
||||||
{devices=[2]0,1}}
|
{devices=[2]0,1}}
|
||||||
})";
|
})";
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
@ -717,6 +718,12 @@ ENTRY entry {
|
|||||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||||
EXPECT_THAT(root, AllOf(op::Shape("token[]"),
|
EXPECT_THAT(root, AllOf(op::Shape("token[]"),
|
||||||
op::Outfeed(op::Parameter(), op::AfterAll())));
|
op::Outfeed(op::Parameter(), op::AfterAll())));
|
||||||
|
auto expected_layout0 = LayoutUtil::MakeLayout({0, 1});
|
||||||
|
auto expected_layout1 = LayoutUtil::MakeLayout({0});
|
||||||
|
EXPECT_TRUE(LayoutUtil::Equal(root->outfeed_shape().tuple_shapes(0).layout(),
|
||||||
|
expected_layout0));
|
||||||
|
EXPECT_TRUE(LayoutUtil::Equal(root->outfeed_shape().tuple_shapes(1).layout(),
|
||||||
|
expected_layout1));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SpmdPartitioningTest, OutfeedReplicated) {
|
TEST_F(SpmdPartitioningTest, OutfeedReplicated) {
|
||||||
@ -746,7 +753,8 @@ ENTRY entry {
|
|||||||
token.0 = token[] after-all()
|
token.0 = token[] after-all()
|
||||||
data = (f32[1023,2]{1,0}, f32[3]{0}) parameter(0), sharding={{devices=[2,1]0,1},
|
data = (f32[1023,2]{1,0}, f32[3]{0}) parameter(0), sharding={{devices=[2,1]0,1},
|
||||||
{devices=[2]0,1}}
|
{devices=[2]0,1}}
|
||||||
outfeed = token[] outfeed(data, token.0), sharding={{devices=[2,1]0,1},
|
outfeed = token[] outfeed(data, token.0),
|
||||||
|
outfeed_shape=(f32[1023,2]{0,1}, f32[3]{0}), sharding={{devices=[2,1]0,1},
|
||||||
{devices=[2]0,1}}
|
{devices=[2]0,1}}
|
||||||
})";
|
})";
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
@ -770,6 +778,24 @@ ENTRY entry {
|
|||||||
EXPECT_THAT(root->called_computations()[1]->root_instruction(),
|
EXPECT_THAT(root->called_computations()[1]->root_instruction(),
|
||||||
AllOf(op::Shape("token[]"),
|
AllOf(op::Shape("token[]"),
|
||||||
op::Outfeed(second_outfeed, op::GetTupleElement())));
|
op::Outfeed(second_outfeed, op::GetTupleElement())));
|
||||||
|
|
||||||
|
auto expected_layout0 = LayoutUtil::MakeLayout({0, 1});
|
||||||
|
auto expected_layout1 = LayoutUtil::MakeLayout({0});
|
||||||
|
auto first_outfeed_instr = root->called_computations()[0]->root_instruction();
|
||||||
|
auto second_outfeed_instr =
|
||||||
|
root->called_computations()[1]->root_instruction();
|
||||||
|
EXPECT_TRUE(LayoutUtil::Equal(
|
||||||
|
first_outfeed_instr->outfeed_shape().tuple_shapes(0).layout(),
|
||||||
|
expected_layout0));
|
||||||
|
EXPECT_TRUE(LayoutUtil::Equal(
|
||||||
|
first_outfeed_instr->outfeed_shape().tuple_shapes(1).layout(),
|
||||||
|
expected_layout1));
|
||||||
|
EXPECT_TRUE(LayoutUtil::Equal(
|
||||||
|
second_outfeed_instr->outfeed_shape().tuple_shapes(0).layout(),
|
||||||
|
expected_layout0));
|
||||||
|
EXPECT_TRUE(LayoutUtil::Equal(
|
||||||
|
second_outfeed_instr->outfeed_shape().tuple_shapes(1).layout(),
|
||||||
|
expected_layout1));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SpmdPartitioningTest, ReduceWindowReplicatedInput) {
|
TEST_F(SpmdPartitioningTest, ReduceWindowReplicatedInput) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user