From dd1ce23e9397d597663a41ecfea8b640da115169 Mon Sep 17 00:00:00 2001
From: Yunxing Dai <yunxing@google.com>
Date: Thu, 4 Mar 2021 19:04:49 -0800
Subject: [PATCH] Skip space-to-batch optimization on convs that are used by a
 different rank reduce-window or select-and-scatter.

PiperOrigin-RevId: 361053054
Change-Id: Idf82848912ceea722aebbd07f0c87a1b14499673
---
 .../xla/service/space_to_batch_converter.cc   |  4 ++-
 .../service/space_to_batch_converter_test.cc  | 35 ++++++++++++++++++-
 2 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter.cc b/tensorflow/compiler/xla/service/space_to_batch_converter.cc
index 491835d8a0c..1f269240bf8 100644
--- a/tensorflow/compiler/xla/service/space_to_batch_converter.cc
+++ b/tensorflow/compiler/xla/service/space_to_batch_converter.cc
@@ -2798,7 +2798,9 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
   auto reduce_window_or_select_and_scatter =
       DoesConvolutionFeedReduceWindowOrSelectAndScatter(convolution);
 
-  if (reduce_window_or_select_and_scatter != nullptr) {
+  if (reduce_window_or_select_and_scatter != nullptr &&
+      reduce_window_or_select_and_scatter->shape().rank() ==
+          convolution->shape().rank()) {
     VLOG(2)
         << "DoesConvolutionFeedReduceWindowOrSelectAndScatter returned true";
     // Take into account the stride of the reduce window while choosing the
diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc
index 96cfe553b53..ac399bccded 100644
--- a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc
+++ b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc
@@ -64,13 +64,46 @@ ENTRY computation {
   EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 1);
 }
 
+TEST_F(ConvolutionSpaceToBatchConverterTest, SimpleBatch1WithReduceWindow) {
+  string hlo_string = R"(
+  HloModule module
+  adder (lhs: bf16[], rhs: bf16[]) -> bf16[] {
+    lhs = bf16[] parameter(0)
+    rhs = bf16[] parameter(1)
+    ROOT add = bf16[] add(lhs, rhs)
+  }
+
+  ENTRY computation {
+    %p0 = bf16[1,258,258,32] parameter(0)
+    %p1 = bf16[3,3,32,32] parameter(1)
+    %convolution = bf16[1,256,256,32] convolution(%p0, %p1), window={size=3x3},
+      dim_labels=b01f_01io->b01f
+    %constant = bf16[3] constant({1.0, 2.0, 3.0})
+    %tuple = (bf16[1,256,256,32], bf16[3]) tuple(%convolution, %constant)
+    ROOT %gte = bf16[1,256,256,32] get-tuple-element(%tuple), index=0
+    %gte2 = bf16[3] get-tuple-element(%tuple), index=1
+    %init = bf16[] constant(1.0)
+    %reduce-window = bf16[3] reduce-window(bf16[3] %gte2, bf16[] %init),
+      window={size=1}, to_apply=%adder
+  }
+
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  ConvolutionSpaceToBatchConverter converter;
+  // Test that a reduce-window consumer with a different rank won't freeze the
+  // compiler.
+  ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
+}
+
 TEST_F(ConvolutionSpaceToBatchConverterTest, SimpleBatch2) {
   string hlo_string = R"(
   HloModule module
  ENTRY computation {
   %p0 = bf16[2,258,258,32] parameter(0)
   %p1 = bf16[3,3,32,32] parameter(1)
-  ROOT %convolution = bf16[2,256,256,32] convolution(%p0, %p1), window={size=3x3},
+  ROOT %convolution = bf16[2,256,256,32] convolution(%p0, %p1), window={size=3x3},
   dim_labels=b01f_01io->b01f
 }