[XLA] Fix OOB dynamic-slice/pad simplification
PiperOrigin-RevId: 356177919 Change-Id: I3f69c1043e4fa55f2fb0ce10ea0383240847ea51
This commit is contained in:
parent
93b319bd07
commit
712be8cb4b
@ -4434,19 +4434,24 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
|
||||
compatible = false;
|
||||
break;
|
||||
}
|
||||
VLOG(2) << "slice :" << slice_dim_start->ToString();
|
||||
VLOG(2) << "slice: " << slice_dim_start->ToString();
|
||||
absl::optional<int64> beg =
|
||||
slice_dim_start->literal().GetFirstInteger();
|
||||
if (!beg) {
|
||||
compatible = false;
|
||||
break;
|
||||
}
|
||||
VLOG(2) << "beg value:" << *beg;
|
||||
VLOG(2) << "beg value: " << *beg;
|
||||
auto update_width = ShapeUtil::GetDimension(update_shape, dim);
|
||||
auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim);
|
||||
// Clamp beg so that it is non-negative.
|
||||
*beg = std::max<int64>(0, *beg);
|
||||
// Clamp beg so that it is in-bounds.
|
||||
*beg = std::min<int64>(bcast_width - update_width, *beg);
|
||||
VLOG(2) << "adjusted beg value: " << *beg;
|
||||
padding_config_dim->set_edge_padding_low(*beg);
|
||||
padding_config_dim->set_edge_padding_high(
|
||||
std::max(bcast_width - (*beg + update_width), int64{0}));
|
||||
padding_config_dim->set_edge_padding_high(bcast_width -
|
||||
(*beg + update_width));
|
||||
// dynamic_update_slice does not specify a stride
|
||||
padding_config_dim->set_interior_padding(0);
|
||||
}
|
||||
|
@ -7161,5 +7161,35 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithNonScalar) {
|
||||
GmockMatch(m::Tuple(m::Broadcast(
|
||||
m::Pad(m::Broadcast(m::Parameter()), m::Constant())))));
|
||||
}
|
||||
|
||||
// Test that dynamic-update-slice with a scalar broadcast becomes a pad when the
|
||||
// start_indices are too big.
|
||||
TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceOfBroadcastToPadOob) {
|
||||
const char* hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY f {
|
||||
constant.546 = f32[] constant(0)
|
||||
broadcast.467 = f32[2]{0} broadcast(constant.546), dimensions={}
|
||||
parameter.1 = f32[1]{0} parameter(0)
|
||||
constant.551 = s32[] constant(2)
|
||||
ROOT dynamic-update-slice.44 = f32[2]{0} dynamic-update-slice(broadcast.467, parameter.1, constant.551)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
VLOG(2) << "Before rewrite dus->pad\n" << module->ToString();
|
||||
AlgebraicSimplifier simplifier(default_options_);
|
||||
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
VLOG(2) << "After rewrite dus->pad\n" << module->ToString();
|
||||
auto* pad = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(pad,
|
||||
GmockMatch(m::Pad(m::Parameter(0), m::ConstantScalar(0.0f))));
|
||||
EXPECT_FALSE(HasInteriorPadding(pad->padding_config()));
|
||||
ASSERT_EQ(pad->padding_config().dimensions_size(), 1);
|
||||
EXPECT_EQ(pad->padding_config().dimensions(0).edge_padding_low(), 1);
|
||||
EXPECT_EQ(pad->padding_config().dimensions(0).edge_padding_high(), 0);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user