[XLA] Preserve layout when sinking broadcasts through slice/dynamic slice
AlgebraicSimplifier could be run after layout assignment passes so we want to preserve the layout chosen for the input of operations that use slice or dynamic slice. Using MakeBroadcastHlo ignores the layout of the shape we pass as input. Use CreateBroadcast instead that preserves the shape as-is. PiperOrigin-RevId: 315619861 Change-Id: I7b16196b2de03f709cf395f60708cfa26ccf1cb9
This commit is contained in:
parent
3d53fd6875
commit
3cdb06cbab
@ -3623,12 +3623,17 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
|
||||
new_slice_strides.push_back(slice->slice_strides(dim));
|
||||
new_slice_limits.push_back(slice->slice_limits(dim));
|
||||
}
|
||||
VLOG(3) << "Sink broadcast through slice";
|
||||
VLOG(3) << "Original slice: " << slice->ToString();
|
||||
VLOG(3) << "Original broadcast: " << broadcast->ToString();
|
||||
TF_ASSIGN_OR_RETURN(auto new_slice,
|
||||
MakeSliceHlo(broadcast_operand, new_slice_starts,
|
||||
new_slice_limits, new_slice_strides));
|
||||
return ReplaceInstruction(
|
||||
slice,
|
||||
MakeBroadcastHlo(new_slice, broadcast->dimensions(), slice->shape()));
|
||||
auto new_broadcast = HloInstruction::CreateBroadcast(
|
||||
slice->shape(), new_slice, broadcast->dimensions());
|
||||
VLOG(3) << "New slice: " << slice->ToString();
|
||||
VLOG(3) << "New broadcast: " << new_broadcast->ToString();
|
||||
return ReplaceWithNewInstruction(slice, std::move(new_broadcast));
|
||||
}
|
||||
|
||||
// Try to simplify concat -> slice to an operand of concat.
|
||||
@ -3708,16 +3713,21 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
|
||||
new_indices.push_back(dynamic_slice->mutable_operand(1 + dim));
|
||||
new_slice_sizes.push_back(dynamic_slice->slice_sizes(dim));
|
||||
}
|
||||
|
||||
VLOG(3) << "Sink broadcast through dynamic slice";
|
||||
VLOG(3) << "Original dynamic slice: " << dynamic_slice->ToString();
|
||||
VLOG(3) << "Original broadcast: " << operand->ToString();
|
||||
HloInstruction* new_dynamic_slice = broadcast_operand;
|
||||
if (!new_slice_sizes.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
new_dynamic_slice,
|
||||
MakeDynamicSliceHlo(broadcast_operand, new_indices, new_slice_sizes));
|
||||
}
|
||||
return ReplaceInstruction(
|
||||
dynamic_slice,
|
||||
MakeBroadcastHlo(new_dynamic_slice, operand->dimensions(),
|
||||
dynamic_slice->shape()));
|
||||
auto new_broadcast = HloInstruction::CreateBroadcast(
|
||||
dynamic_slice->shape(), new_dynamic_slice, operand->dimensions());
|
||||
VLOG(3) << "New dynamic slice: " << dynamic_slice->ToString();
|
||||
VLOG(3) << "New broadcast: " << new_broadcast->ToString();
|
||||
return ReplaceWithNewInstruction(dynamic_slice, std::move(new_broadcast));
|
||||
}
|
||||
|
||||
// Convert a dynamic slice into a slice if all offsets are constant and the
|
||||
|
@ -2539,6 +2539,28 @@ TEST_F(AlgebraicSimplifierTest, SliceOfBroadcast) {
|
||||
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Slice(m::Parameter(0)))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, SliceOfBroadcastPreserveLayout) {
|
||||
const char* hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY test {
|
||||
p0 = f32[10,20] parameter(0)
|
||||
b = f32[10,30,20]{2,0,1:T(256)} broadcast(p0), dimensions={0,2}
|
||||
ROOT s = f32[5,5,5]{2,0,1:T(256)} slice(b), slice={[0:5:1], [5:25:4], [5:15:2]}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
const Shape original_slice_shape =
|
||||
module->entry_computation()->root_instruction()->shape();
|
||||
HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
|
||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Slice(m::Parameter(0)))));
|
||||
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), original_slice_shape));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcast) {
|
||||
const char* hlo_string = R"(
|
||||
HloModule module
|
||||
@ -2562,6 +2584,32 @@ TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcast) {
|
||||
m::Parameter(0), m::Parameter(1), m::Parameter(3)))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcastPreserveLayout) {
|
||||
const char* hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY test {
|
||||
p0 = f32[10,20] parameter(0)
|
||||
i0 = s32[] parameter(1)
|
||||
i1 = s32[] parameter(2)
|
||||
i2 = s32[] parameter(3)
|
||||
b = f32[10,30,20]{2,0,1:T(256)} broadcast(p0), dimensions={0,2}
|
||||
ROOT ds = f32[5,5,5]{2,0,1:T(256)} dynamic-slice(b, i0, i1, i2), dynamic_slice_sizes={5,5,5}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
const Shape original_dynslice_shape =
|
||||
module->entry_computation()->root_instruction()->shape();
|
||||
HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
|
||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::DynamicSlice(
|
||||
m::Parameter(0), m::Parameter(1), m::Parameter(3)))));
|
||||
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), original_dynslice_shape));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) {
|
||||
const char* hlo_string = R"(
|
||||
HloModule module
|
||||
|
Loading…
Reference in New Issue
Block a user