[XLA] Fix OOB dynamic-slice/pad simplification

PiperOrigin-RevId: 356177919
Change-Id: I3f69c1043e4fa55f2fb0ce10ea0383240847ea51
This commit is contained in:
David Majnemer 2021-02-07 19:34:19 -08:00 committed by TensorFlower Gardener
parent 93b319bd07
commit 712be8cb4b
2 changed files with 39 additions and 4 deletions

View File

@ -4434,19 +4434,24 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
compatible = false;
break;
}
VLOG(2) << "slice :" << slice_dim_start->ToString();
VLOG(2) << "slice: " << slice_dim_start->ToString();
absl::optional<int64> beg =
slice_dim_start->literal().GetFirstInteger();
if (!beg) {
compatible = false;
break;
}
VLOG(2) << "beg value:" << *beg;
VLOG(2) << "beg value: " << *beg;
auto update_width = ShapeUtil::GetDimension(update_shape, dim);
auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim);
// Clamp beg so that it is non-negative.
*beg = std::max<int64>(0, *beg);
// Clamp beg so that it is in-bounds.
*beg = std::min<int64>(bcast_width - update_width, *beg);
VLOG(2) << "adjusted beg value: " << *beg;
padding_config_dim->set_edge_padding_low(*beg);
padding_config_dim->set_edge_padding_high(
std::max(bcast_width - (*beg + update_width), int64{0}));
padding_config_dim->set_edge_padding_high(bcast_width -
(*beg + update_width));
// dynamic_update_slice does not specify a stride
padding_config_dim->set_interior_padding(0);
}

View File

@ -7161,5 +7161,35 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithNonScalar) {
GmockMatch(m::Tuple(m::Broadcast(
m::Pad(m::Broadcast(m::Parameter()), m::Constant())))));
}
// Test that dynamic-update-slice with a scalar broadcast becomes a pad when the
// start_indices are too big.
TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceOfBroadcastToPadOob) {
const char* hlo_string = R"(
HloModule module
ENTRY f {
constant.546 = f32[] constant(0)
broadcast.467 = f32[2]{0} broadcast(constant.546), dimensions={}
parameter.1 = f32[1]{0} parameter(0)
constant.551 = s32[] constant(2)
ROOT dynamic-update-slice.44 = f32[2]{0} dynamic-update-slice(broadcast.467, parameter.1, constant.551)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
VLOG(2) << "Before rewrite dus->pad\n" << module->ToString();
AlgebraicSimplifier simplifier(default_options_);
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
VLOG(2) << "After rewrite dus->pad\n" << module->ToString();
auto* pad = module->entry_computation()->root_instruction();
EXPECT_THAT(pad,
GmockMatch(m::Pad(m::Parameter(0), m::ConstantScalar(0.0f))));
EXPECT_FALSE(HasInteriorPadding(pad->padding_config()));
ASSERT_EQ(pad->padding_config().dimensions_size(), 1);
EXPECT_EQ(pad->padding_config().dimensions(0).edge_padding_low(), 1);
EXPECT_EQ(pad->padding_config().dimensions(0).edge_padding_high(), 0);
}
} // namespace
} // namespace xla