[XLA] Fix the condition for rewriting batch group convolutions to include when the input batch is not equal to the batch group count.

PiperOrigin-RevId: 305581009
Change-Id: Ifb9d90a684e9de571fc6d4e03320a2f3438b36b5
This commit is contained in:
Blake Hechtman 2020-04-08 16:47:04 -07:00 committed by TensorFlower Gardener
parent 1e2c8c6873
commit 287cacfb99
2 changed files with 26 additions and 1 deletions

View File

@ -225,10 +225,12 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
const int64 kernel_output_feature_dimension =
dim_numbers.kernel_output_feature_dimension();
const int64 input_batch =
activation->shape().dimensions(input_batch_dimension);
const int64 output_feature =
filter->shape().dimensions(kernel_output_feature_dimension);
if (output_feature != batch_group_count) {
if (output_feature != batch_group_count || input_batch != batch_group_count) {
// Insert a spatial dimension to the activation before the input batch
// dimension to represent the batch group.
std::vector<int64> input_sizes(activation->shape().dimensions().begin(),

View File

@ -119,5 +119,28 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[16,19,19,512]{3,2,1,0}, filter: f32[16
EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kReduceWindow);
}
TEST_F(ConvolutionGroupConverterTest,
ConvertBatchGroupCountNotEqualToInputBatchDim) {
string hlo_string = R"(HloModule m
ENTRY main {
%input = f32[1,1,1,4] parameter(0)
%filter = f32[1,1,1,2] parameter(1)
ROOT %convolution = f32[1,1,2,2] convolution(%input,%filter),
window={size=1x1}, dim_labels=f01b_i01o->01fb, batch_group_count=2
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
auto computation = module->entry_computation();
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kConvolution);
auto cost_model = [](HloInstruction* conv) { return false; };
ConvolutionGroupConverter converter(cost_model, /*convert_batch_groups_only=*/
true);
// Make sure that batch group count is rewritten even if
// batch_group_count == output_feature but not input_batch
ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
}
} // namespace
} // namespace xla