[XLA] Preserve layout when sinking broadcasts through slice/dynamic slice

AlgebraicSimplifier could be run after layout assignment passes so we want
to preserve the layout chosen for the input of operations that use slice
or dynamic slice.
Using MakeBroadcastHlo ignores the layout of the shape we pass as input.
Use CreateBroadcast instead that preserves the shape as-is.

PiperOrigin-RevId: 315619861
Change-Id: I7b16196b2de03f709cf395f60708cfa26ccf1cb9
This commit is contained in:
A. Unique TensorFlower 2020-06-09 21:01:48 -07:00 committed by TensorFlower Gardener
parent 3d53fd6875
commit 3cdb06cbab
2 changed files with 65 additions and 7 deletions

View File

@ -3623,12 +3623,17 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
new_slice_strides.push_back(slice->slice_strides(dim));
new_slice_limits.push_back(slice->slice_limits(dim));
}
VLOG(3) << "Sink broadcast through slice";
VLOG(3) << "Original slice: " << slice->ToString();
VLOG(3) << "Original broadcast: " << broadcast->ToString();
TF_ASSIGN_OR_RETURN(auto new_slice,
MakeSliceHlo(broadcast_operand, new_slice_starts,
new_slice_limits, new_slice_strides));
return ReplaceInstruction(
slice,
MakeBroadcastHlo(new_slice, broadcast->dimensions(), slice->shape()));
auto new_broadcast = HloInstruction::CreateBroadcast(
slice->shape(), new_slice, broadcast->dimensions());
VLOG(3) << "New slice: " << slice->ToString();
VLOG(3) << "New broadcast: " << new_broadcast->ToString();
return ReplaceWithNewInstruction(slice, std::move(new_broadcast));
}
// Try to simplify concat -> slice to an operand of concat.
@ -3708,16 +3713,21 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
new_indices.push_back(dynamic_slice->mutable_operand(1 + dim));
new_slice_sizes.push_back(dynamic_slice->slice_sizes(dim));
}
VLOG(3) << "Sink broadcast through dynamic slice";
VLOG(3) << "Original dynamic slice: " << dynamic_slice->ToString();
VLOG(3) << "Original broadcast: " << operand->ToString();
HloInstruction* new_dynamic_slice = broadcast_operand;
if (!new_slice_sizes.empty()) {
TF_ASSIGN_OR_RETURN(
new_dynamic_slice,
MakeDynamicSliceHlo(broadcast_operand, new_indices, new_slice_sizes));
}
return ReplaceInstruction(
dynamic_slice,
MakeBroadcastHlo(new_dynamic_slice, operand->dimensions(),
dynamic_slice->shape()));
auto new_broadcast = HloInstruction::CreateBroadcast(
dynamic_slice->shape(), new_dynamic_slice, operand->dimensions());
VLOG(3) << "New dynamic slice: " << dynamic_slice->ToString();
VLOG(3) << "New broadcast: " << new_broadcast->ToString();
return ReplaceWithNewInstruction(dynamic_slice, std::move(new_broadcast));
}
// Convert a dynamic slice into a slice if all offsets are constant and the

View File

@ -2539,6 +2539,28 @@ TEST_F(AlgebraicSimplifierTest, SliceOfBroadcast) {
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Slice(m::Parameter(0)))));
}
TEST_F(AlgebraicSimplifierTest, SliceOfBroadcastPreserveLayout) {
const char* hlo_string = R"(
HloModule module
ENTRY test {
p0 = f32[10,20] parameter(0)
b = f32[10,30,20]{2,0,1:T(256)} broadcast(p0), dimensions={0,2}
ROOT s = f32[5,5,5]{2,0,1:T(256)} slice(b), slice={[0:5:1], [5:25:4], [5:15:2]}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
const Shape original_slice_shape =
module->entry_computation()->root_instruction()->shape();
HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Slice(m::Parameter(0)))));
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), original_slice_shape));
}
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcast) {
const char* hlo_string = R"(
HloModule module
@ -2562,6 +2584,32 @@ TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcast) {
m::Parameter(0), m::Parameter(1), m::Parameter(3)))));
}
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcastPreserveLayout) {
const char* hlo_string = R"(
HloModule module
ENTRY test {
p0 = f32[10,20] parameter(0)
i0 = s32[] parameter(1)
i1 = s32[] parameter(2)
i2 = s32[] parameter(3)
b = f32[10,30,20]{2,0,1:T(256)} broadcast(p0), dimensions={0,2}
ROOT ds = f32[5,5,5]{2,0,1:T(256)} dynamic-slice(b, i0, i1, i2), dynamic_slice_sizes={5,5,5}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
const Shape original_dynslice_shape =
module->entry_computation()->root_instruction()->shape();
HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::DynamicSlice(
m::Parameter(0), m::Parameter(1), m::Parameter(3)))));
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), original_dynslice_shape));
}
TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) {
const char* hlo_string = R"(
HloModule module