From 712be8cb4b002bbc1b067eca79de8c01180e864c Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Sun, 7 Feb 2021 19:34:19 -0800 Subject: [PATCH] [XLA] Fix OOB dynamic-slice/pad simplification PiperOrigin-RevId: 356177919 Change-Id: I3f69c1043e4fa55f2fb0ce10ea0383240847ea51 --- .../xla/service/algebraic_simplifier.cc | 13 +++++--- .../xla/service/algebraic_simplifier_test.cc | 30 +++++++++++++++++++ 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 8ce7d811c51..cddad120875 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -4434,19 +4434,24 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( compatible = false; break; } - VLOG(2) << "slice :" << slice_dim_start->ToString(); + VLOG(2) << "slice: " << slice_dim_start->ToString(); absl::optional beg = slice_dim_start->literal().GetFirstInteger(); if (!beg) { compatible = false; break; } - VLOG(2) << "beg value:" << *beg; + VLOG(2) << "beg value: " << *beg; auto update_width = ShapeUtil::GetDimension(update_shape, dim); auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim); + // Clamp beg so that it is non-negative. + *beg = std::max(0, *beg); + // Clamp beg so that it is in-bounds. + *beg = std::min(bcast_width - update_width, *beg); + VLOG(2) << "adjusted beg value: " << *beg; padding_config_dim->set_edge_padding_low(*beg); - padding_config_dim->set_edge_padding_high( - std::max(bcast_width - (*beg + update_width), int64{0})); + padding_config_dim->set_edge_padding_high(bcast_width - + (*beg + update_width)); // dynamic_update_slice does not specify a stride padding_config_dim->set_interior_padding(0); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index a900008a59a..88f45e817fd 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -7161,5 +7161,35 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithNonScalar) { GmockMatch(m::Tuple(m::Broadcast( m::Pad(m::Broadcast(m::Parameter()), m::Constant()))))); } + +// Test that dynamic-update-slice with a scalar broadcast becomes a pad when the +// start_indices are too big. +TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceOfBroadcastToPadOob) { + const char* hlo_string = R"( +HloModule module + +ENTRY f { + constant.546 = f32[] constant(0) + broadcast.467 = f32[2]{0} broadcast(constant.546), dimensions={} + parameter.1 = f32[1]{0} parameter(0) + constant.551 = s32[] constant(2) + ROOT dynamic-update-slice.44 = f32[2]{0} dynamic-update-slice(broadcast.467, parameter.1, constant.551) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + VLOG(2) << "Before rewrite dus->pad\n" << module->ToString(); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + VLOG(2) << "After rewrite dus->pad\n" << module->ToString(); + auto* pad = module->entry_computation()->root_instruction(); + EXPECT_THAT(pad, + GmockMatch(m::Pad(m::Parameter(0), m::ConstantScalar(0.0f)))); + EXPECT_FALSE(HasInteriorPadding(pad->padding_config())); + ASSERT_EQ(pad->padding_config().dimensions_size(), 1); + EXPECT_EQ(pad->padding_config().dimensions(0).edge_padding_low(), 1); + EXPECT_EQ(pad->padding_config().dimensions(0).edge_padding_high(), 0); +} + } // namespace } // namespace xla