From 41b6bae3d1b0c103baa331036debc92de9422a7e Mon Sep 17 00:00:00 2001
From: Blake Hechtman <blakehechtman@google.com>
Date: Thu, 20 Feb 2020 21:16:59 -0800
Subject: [PATCH] [XLA] Add some more slice of pad optimizations.

PiperOrigin-RevId: 296361878
Change-Id: I4dbef5e94d95f3337c1004e8c3f09c7a94148075
---
 .../xla/service/algebraic_simplifier.cc       | 91 ++++++++-----------
 .../xla/service/algebraic_simplifier_test.cc  | 34 ++++++-
 2 files changed, 68 insertions(+), 57 deletions(-)

diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index cfbcb5a4fe2..fd373671b97 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -3204,53 +3204,6 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
     return false;
   }
 
-  if (slice->operand(0)->opcode() == HloOpcode::kPad) {
-    VLOG(10) << "Trying to simplify scalar slice of pad";
-    // Check there's no internal padding. Again, we could handle that too, since
-    // everything is statically known, but it's not worth it.
-    auto pad = Cast<HloPadInstruction>(slice->mutable_operand(0));
-    auto padding_config = pad->padding_config();
-    int64 rank = padding_config.dimensions_size();
-    if (HasInteriorPadding(padding_config)) {
-      VLOG(10) << "Not folding scalar slice of pad, pad has interior padding";
-      return false;
-    }
-
-    // Check whether the scalar we're slicing out falls into the padding.
-    bool in_padding = [&]() {
-      for (int64 i = 0; i < rank; ++i) {
-        int64 start = slice->slice_starts(i);
-        int64 low = padding_config.dimensions(i).edge_padding_low();
-        int64 data = pad->operand(0)->shape().dimensions(i);
-        if (start < low || start >= low + data) {
-          return true;
-        }
-      }
-      return false;
-    }();
-
-    if (in_padding) {
-      VLOG(10) << "Folding scalar slice of pad into padding value";
-      TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
-          slice, HloInstruction::CreateReshape(slice->shape(),
-                                               pad->mutable_padding_value())));
-      return true;
-    } else {
-      // We already know the output of the slice is scalar. If the padded
-      // value is scalar, and it's not in the padding, then it's exactly the
-      // output value.
-      bool replaced =
-          ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0));
-      if (replaced) {
-        VLOG(10) << "Folding scalar slice of pad into padded value";
-      } else {
-        VLOG(10) << "Not folding scalar slice of pad into padded value as they "
-                    "have different shapes.";
-      }
-      return replaced;
-    }
-  }
-
   if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) {
     VLOG(10) << "Trying to simplify scalar slice of concat";
     // Only do this for R1, there's no chance of this being useful otherwise.
@@ -3356,20 +3309,54 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
   HloInstruction* pad;
   HloInstruction* pad_operand;
   if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) {
+    // Is the result of the slice the pad operand.
     bool slice_undoes_pad = true;
+    // Can the slice be moved to the pad_operand without any padding being read.
+    bool slice_inside_pad = true;
+    // Does this slice slice out pading only.
+    bool slice_in_padding = false;
+    std::vector<int64> new_starts = slice->slice_starts();
+    std::vector<int64> new_limits = slice->slice_limits();
     for (int64 i = 0; i < slice->shape().rank(); ++i) {
-      if (slice->slice_starts(i) !=
-          pad->padding_config().dimensions(i).edge_padding_low()) {
+      const int64 start = slice->slice_starts(i);
+      const int64 stride = slice->slice_strides(i);
+      const int64 limit = slice->slice_limits(i);
+      const int64 size = pad->shape().dimensions(i);
+
+      const auto& dim = pad->padding_config().dimensions(i);
+      const int64 low = dim.edge_padding_low();
+      const int64 high = dim.edge_padding_high();
+      const int64 interior = dim.interior_padding();
+      const int64 edge = size - high;
+
+      if (limit <= low || start >= edge) {
+        slice_in_padding = true;
+        break;
+      }
+
+      if (start != low || stride - 1 != interior) {
         slice_undoes_pad = false;
       }
-      if (slice->slice_strides(i) - 1 !=
-          pad->padding_config().dimensions(i).interior_padding()) {
-        slice_undoes_pad = false;
+
+      if (start < low || limit > edge || interior != 0 || stride != 1) {
+        slice_inside_pad = false;
       }
+      new_starts[i] -= low;
+      new_limits[i] -= low;
+    }
+    if (slice_in_padding) {
+      return ReplaceInstruction(
+          slice, MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape()));
     }
     if (slice_undoes_pad && ReplaceInstructionIfSameShape(slice, pad_operand)) {
       return Status::OK();
     }
+    if (slice_inside_pad) {
+      TF_ASSIGN_OR_RETURN(HloInstruction * new_slice,
+                          MakeSliceHlo(pad_operand, new_starts, new_limits,
+                                       slice->slice_strides()));
+      return ReplaceInstruction(slice, new_slice);
+    }
   }
 
   if (slice->operand(0)->opcode() == HloOpcode::kSlice &&
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 8f66f8084f3..31fa125b3e1 100755
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -4389,7 +4389,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) {
   AlgebraicSimplifier simplifier(options);
   EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
   auto root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant())));
+  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
 }
 
 TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) {
@@ -4410,7 +4410,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) {
   AlgebraicSimplifier simplifier(options);
   EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
   auto root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant())));
+  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
 }
 
 TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) {
@@ -4429,7 +4429,31 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) {
 
   AlgebraicSimplifierOptions options;
   AlgebraicSimplifier simplifier(options);
-  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Slice(m::Parameter(0))));
+}
+
+TEST_F(AlgebraicSimplifierTest, SliceOfPad) {
+  const char* hlo_string = R"(
+    HloModule module
+
+    ENTRY test {
+      param = f32[3,4] parameter(0)
+      constant = f32[] constant(0.0)
+      pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
+      ROOT slice = f32[2,3] slice(f32[8,10] pad), slice={[4:6],[2:5]}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  AlgebraicSimplifierOptions options;
+  AlgebraicSimplifier simplifier(options);
+  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, GmockMatch(m::Slice(m::Parameter(0))));
+  EXPECT_THAT(root->slice_starts(), ElementsAre(1, 1));
 }
 
 TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) {
@@ -4450,7 +4474,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) {
   AlgebraicSimplifier simplifier(options);
   EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
   auto root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant())));
+  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
 }
 
 TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) {
@@ -4494,7 +4518,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadSomeDimsInPadding) {
   AlgebraicSimplifier simplifier(options);
   EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
   auto root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(root, GmockMatch(m::Reshape(m::ConstantScalar(-7.0))));
+  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::ConstantScalar(-7.0))));
 }
 
 TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) {