From dd1ce23e9397d597663a41ecfea8b640da115169 Mon Sep 17 00:00:00 2001
From: Yunxing Dai <yunxing@google.com>
Date: Thu, 4 Mar 2021 19:04:49 -0800
Subject: [PATCH] Skip space-to-batch optimization on convolutions that are
 consumed by a reduce-window or select-and-scatter of a different rank.

PiperOrigin-RevId: 361053054
Change-Id: Idf82848912ceea722aebbd07f0c87a1b14499673
---
 .../xla/service/space_to_batch_converter.cc   |  4 ++-
 .../service/space_to_batch_converter_test.cc  | 35 ++++++++++++++++++-
 2 files changed, 37 insertions(+), 2 deletions(-)
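
Note for reviewers (kept below the --- so it does not become part of the commit
message or the applied diff): after this change, the stride of a reduce-window
or select-and-scatter consumer is only taken into account when that consumer's
output rank matches the convolution's output rank. Below is a minimal
standalone sketch of just that check, using hypothetical toy types (ToyShape,
ToyInstruction, ShouldAccountForConsumerStride) rather than the real XLA
HloInstruction/Shape API:

  #include <cstdint>
  #include <iostream>
  #include <vector>

  // Toy stand-ins for XLA's Shape/HloInstruction, for illustration only.
  struct ToyShape {
    std::vector<int64_t> dimensions;
    int64_t rank() const { return static_cast<int64_t>(dimensions.size()); }
  };

  struct ToyInstruction {
    ToyShape shape;
  };

  // Mirrors the guard added in this patch: the consumer's window stride is
  // only folded into the halo-size choice when a reduce-window or
  // select-and-scatter consumer exists and its rank matches the
  // convolution's output rank.
  bool ShouldAccountForConsumerStride(const ToyInstruction* consumer,
                                      const ToyInstruction& convolution) {
    return consumer != nullptr &&
           consumer->shape.rank() == convolution.shape.rank();
  }

  int main() {
    ToyInstruction conv{{{1, 256, 256, 32}}};  // rank-4 convolution output
    ToyInstruction rw_rank1{{{3}}};            // rank-1 consumer (via a tuple)
    ToyInstruction rw_rank4{{{1, 128, 128, 32}}};

    std::cout << ShouldAccountForConsumerStride(&rw_rank1, conv) << "\n";  // 0
    std::cout << ShouldAccountForConsumerStride(&rw_rank4, conv) << "\n";  // 1
    std::cout << ShouldAccountForConsumerStride(nullptr, conv) << "\n";    // 0
  }

Only the rank comparison is modeled here; the real code additionally finds the
consumer by walking the convolution's users via
DoesConvolutionFeedReduceWindowOrSelectAndScatter, as shown in the hunk below.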

diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter.cc b/tensorflow/compiler/xla/service/space_to_batch_converter.cc
index 491835d8a0c..1f269240bf8 100644
--- a/tensorflow/compiler/xla/service/space_to_batch_converter.cc
+++ b/tensorflow/compiler/xla/service/space_to_batch_converter.cc
@@ -2798,7 +2798,9 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
   auto reduce_window_or_select_and_scatter =
       DoesConvolutionFeedReduceWindowOrSelectAndScatter(convolution);
 
-  if (reduce_window_or_select_and_scatter != nullptr) {
+  if (reduce_window_or_select_and_scatter != nullptr &&
+      reduce_window_or_select_and_scatter->shape().rank() ==
+          convolution->shape().rank()) {
     VLOG(2)
         << "DoesConvolutionFeedReduceWindowOrSelectAndScatter returned true";
     // Take into account the stride of the reduce window while choosing the
diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc
index 96cfe553b53..ac399bccded 100644
--- a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc
+++ b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc
@@ -64,13 +64,46 @@ ENTRY computation {
   EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 1);
 }
 
+TEST_F(ConvolutionSpaceToBatchConverterTest, SimpleBatch1WithReduceWindow) {
+  string hlo_string = R"(
+  HloModule module
+  adder (lhs: bf16[], rhs: bf16[]) -> bf16[] {
+    lhs = bf16[] parameter(0)
+    rhs = bf16[] parameter(1)
+    ROOT add = bf16[] add(lhs, rhs)
+  }
+
+  ENTRY computation {
+    %p0 = bf16[1,258,258,32] parameter(0)
+    %p1 = bf16[3,3,32,32] parameter(1)
+    %convolution = bf16[1,256,256,32] convolution(%p0, %p1), window={size=3x3},
+    dim_labels=b01f_01io->b01f
+    %constant = bf16[3] constant({1.0, 2.0, 3.0})
+    %tuple = (bf16[1,256,256,32], bf16[3]) tuple(%convolution, %constant)
+    ROOT %gte = bf16[1,256,256,32] get-tuple-element(%tuple), index=0
+    %gte2 = bf16[3] get-tuple-element(%tuple), index=1
+    %init = bf16[] constant(1.0)
+    %reduce-window = bf16[3] reduce-window(bf16[3] %gte2, bf16[] %init),
+      window={size=1}, to_apply=%adder
+  }
+
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  ConvolutionSpaceToBatchConverter converter;
+  // Test that a reduce-window consumer of a different rank does not hang the
+  // compiler.
+  ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
+}
+
 TEST_F(ConvolutionSpaceToBatchConverterTest, SimpleBatch2) {
   string hlo_string = R"(
   HloModule module
   ENTRY computation {
     %p0 = bf16[2,258,258,32] parameter(0)
     %p1 = bf16[3,3,32,32] parameter(1)
-    ROOT %convolution = bf16[2,256,256,32] convolution(%p0, %p1), window={size=3x3}, 
+    ROOT %convolution = bf16[2,256,256,32] convolution(%p0, %p1), window={size=3x3},
     dim_labels=b01f_01io->b01f
   }