diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index cfbcb5a4fe2..fd373671b97 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -3204,53 +3204,6 @@ StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( return false; } - if (slice->operand(0)->opcode() == HloOpcode::kPad) { - VLOG(10) << "Trying to simplify scalar slice of pad"; - // Check there's no internal padding. Again, we could handle that too, since - // everything is statically known, but it's not worth it. - auto pad = Cast(slice->mutable_operand(0)); - auto padding_config = pad->padding_config(); - int64 rank = padding_config.dimensions_size(); - if (HasInteriorPadding(padding_config)) { - VLOG(10) << "Not folding scalar slice of pad, pad has interior padding"; - return false; - } - - // Check whether the scalar we're slicing out falls into the padding. - bool in_padding = [&]() { - for (int64 i = 0; i < rank; ++i) { - int64 start = slice->slice_starts(i); - int64 low = padding_config.dimensions(i).edge_padding_low(); - int64 data = pad->operand(0)->shape().dimensions(i); - if (start < low || start >= low + data) { - return true; - } - } - return false; - }(); - - if (in_padding) { - VLOG(10) << "Folding scalar slice of pad into padding value"; - TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( - slice, HloInstruction::CreateReshape(slice->shape(), - pad->mutable_padding_value()))); - return true; - } else { - // We already know the output of the slice is scalar. If the padded - // value is scalar, and it's not in the padding, then it's exactly the - // output value. - bool replaced = - ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0)); - if (replaced) { - VLOG(10) << "Folding scalar slice of pad into padded value"; - } else { - VLOG(10) << "Not folding scalar slice of pad into padded value as they " - "have different shapes."; - } - return replaced; - } - } - if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) { VLOG(10) << "Trying to simplify scalar slice of concat"; // Only do this for R1, there's no chance of this being useful otherwise. @@ -3356,20 +3309,54 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { HloInstruction* pad; HloInstruction* pad_operand; if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) { + // Is the result of the slice the pad operand. bool slice_undoes_pad = true; + // Can the slice be moved to the pad_operand without any padding being read. + bool slice_inside_pad = true; + // Does this slice slice out pading only. + bool slice_in_padding = false; + std::vector new_starts = slice->slice_starts(); + std::vector new_limits = slice->slice_limits(); for (int64 i = 0; i < slice->shape().rank(); ++i) { - if (slice->slice_starts(i) != - pad->padding_config().dimensions(i).edge_padding_low()) { + const int64 start = slice->slice_starts(i); + const int64 stride = slice->slice_strides(i); + const int64 limit = slice->slice_limits(i); + const int64 size = pad->shape().dimensions(i); + + const auto& dim = pad->padding_config().dimensions(i); + const int64 low = dim.edge_padding_low(); + const int64 high = dim.edge_padding_high(); + const int64 interior = dim.interior_padding(); + const int64 edge = size - high; + + if (limit <= low || start >= edge) { + slice_in_padding = true; + break; + } + + if (start != low || stride - 1 != interior) { slice_undoes_pad = false; } - if (slice->slice_strides(i) - 1 != - pad->padding_config().dimensions(i).interior_padding()) { - slice_undoes_pad = false; + + if (start < low || limit > edge || interior != 0 || stride != 1) { + slice_inside_pad = false; } + new_starts[i] -= low; + new_limits[i] -= low; + } + if (slice_in_padding) { + return ReplaceInstruction( + slice, MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape())); } if (slice_undoes_pad && ReplaceInstructionIfSameShape(slice, pad_operand)) { return Status::OK(); } + if (slice_inside_pad) { + TF_ASSIGN_OR_RETURN(HloInstruction * new_slice, + MakeSliceHlo(pad_operand, new_starts, new_limits, + slice->slice_strides())); + return ReplaceInstruction(slice, new_slice); + } } if (slice->operand(0)->opcode() == HloOpcode::kSlice && diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 8f66f8084f3..31fa125b3e1 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -4389,7 +4389,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) { AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { @@ -4410,7 +4410,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { @@ -4429,7 +4429,31 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Slice(m::Parameter(0)))); +} + +TEST_F(AlgebraicSimplifierTest, SliceOfPad) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[3,4] parameter(0) + constant = f32[] constant(0.0) + pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5 + ROOT slice = f32[2,3] slice(f32[8,10] pad), slice={[4:6],[2:5]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Slice(m::Parameter(0)))); + EXPECT_THAT(root->slice_starts(), ElementsAre(1, 1)); } TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) { @@ -4450,7 +4474,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) { AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { @@ -4494,7 +4518,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadSomeDimsInPadding) { AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, GmockMatch(m::Reshape(m::ConstantScalar(-7.0)))); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::ConstantScalar(-7.0)))); } TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) {