From 287cacfb9971f43000b72c9badcbbe44e00575a8 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Wed, 8 Apr 2020 16:47:04 -0700 Subject: [PATCH] [XLA] Fix the condition for rewriting batch group convolutions to include when the input batch is not equal to the batch group count. PiperOrigin-RevId: 305581009 Change-Id: Ifb9d90a684e9de571fc6d4e03320a2f3438b36b5 --- .../service/convolution_group_converter.cc | 4 +++- .../convolution_group_converter_test.cc | 23 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc index ab959cb0087..323bf44dcd3 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc @@ -225,10 +225,12 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { const int64 kernel_output_feature_dimension = dim_numbers.kernel_output_feature_dimension(); + const int64 input_batch = + activation->shape().dimensions(input_batch_dimension); const int64 output_feature = filter->shape().dimensions(kernel_output_feature_dimension); - if (output_feature != batch_group_count) { + if (output_feature != batch_group_count || input_batch != batch_group_count) { // Insert a spatial dimension to the activation before the input batch // dimension to represent the batch group. std::vector input_sizes(activation->shape().dimensions().begin(), diff --git a/tensorflow/compiler/xla/service/convolution_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc index fea37130c6d..143e071dc3c 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter_test.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc @@ -119,5 +119,28 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[16,19,19,512]{3,2,1,0}, filter: f32[16 EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kReduceWindow); } +TEST_F(ConvolutionGroupConverterTest, + ConvertBatchGroupCountNotEqualToInputBatchDim) { + string hlo_string = R"(HloModule m + ENTRY main { + %input = f32[1,1,1,4] parameter(0) + %filter = f32[1,1,1,2] parameter(1) + ROOT %convolution = f32[1,1,2,2] convolution(%input,%filter), + window={size=1x1}, dim_labels=f01b_i01o->01fb, batch_group_count=2 + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + auto computation = module->entry_computation(); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); + auto cost_model = [](HloInstruction* conv) { return false; }; + ConvolutionGroupConverter converter(cost_model, /*convert_batch_groups_only=*/ + true); + // Make sure that batch group count is rewritten even if + // batch_group_count == output_feature but not input_batch + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); +} + } // namespace } // namespace xla