diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc
index ab959cb0087..323bf44dcd3 100644
--- a/tensorflow/compiler/xla/service/convolution_group_converter.cc
+++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc
@@ -225,10 +225,12 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
   const int64 kernel_output_feature_dimension =
       dim_numbers.kernel_output_feature_dimension();
 
+  const int64 input_batch =
+      activation->shape().dimensions(input_batch_dimension);
   const int64 output_feature =
       filter->shape().dimensions(kernel_output_feature_dimension);
 
-  if (output_feature != batch_group_count) {
+  if (output_feature != batch_group_count || input_batch != batch_group_count) {
     // Insert a spatial dimension to the activation before the input batch
     // dimension to represent the batch group.
     std::vector<int64> input_sizes(activation->shape().dimensions().begin(),
diff --git a/tensorflow/compiler/xla/service/convolution_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc
index fea37130c6d..143e071dc3c 100644
--- a/tensorflow/compiler/xla/service/convolution_group_converter_test.cc
+++ b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc
@@ -119,5 +119,28 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[16,19,19,512]{3,2,1,0}, filter: f32[16
   EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kReduceWindow);
 }
 
+TEST_F(ConvolutionGroupConverterTest,
+       ConvertBatchGroupCountNotEqualToInputBatchDim) {
+  string hlo_string = R"(HloModule m
+  ENTRY main {
+  %input = f32[1,1,1,4] parameter(0)
+  %filter = f32[1,1,1,2] parameter(1)
+  ROOT %convolution = f32[1,1,2,2] convolution(%input,%filter),
+    window={size=1x1}, dim_labels=f01b_i01o->01fb, batch_group_count=2
+  })";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  auto computation = module->entry_computation();
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_EQ(root->opcode(), HloOpcode::kConvolution);
+  auto cost_model = [](HloInstruction* conv) { return false; };
+  ConvolutionGroupConverter converter(cost_model, /*convert_batch_groups_only=*/
+                                      true);
+  // Make sure that batch group count is rewritten even if
+  // batch_group_count == output_feature but not input_batch
+  ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
+}
+
 }  // namespace
 }  // namespace xla