[XLA] Fix the condition for rewriting batch group convolutions to include when the input batch is not equal to the batch group count.
PiperOrigin-RevId: 305581009 Change-Id: Ifb9d90a684e9de571fc6d4e03320a2f3438b36b5
This commit is contained in:
parent
1e2c8c6873
commit
287cacfb99
@ -225,10 +225,12 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
|
||||
const int64 kernel_output_feature_dimension =
|
||||
dim_numbers.kernel_output_feature_dimension();
|
||||
|
||||
const int64 input_batch =
|
||||
activation->shape().dimensions(input_batch_dimension);
|
||||
const int64 output_feature =
|
||||
filter->shape().dimensions(kernel_output_feature_dimension);
|
||||
|
||||
if (output_feature != batch_group_count) {
|
||||
if (output_feature != batch_group_count || input_batch != batch_group_count) {
|
||||
// Insert a spatial dimension to the activation before the input batch
|
||||
// dimension to represent the batch group.
|
||||
std::vector<int64> input_sizes(activation->shape().dimensions().begin(),
|
||||
|
||||
@ -119,5 +119,28 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[16,19,19,512]{3,2,1,0}, filter: f32[16
|
||||
EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kReduceWindow);
|
||||
}
|
||||
|
||||
TEST_F(ConvolutionGroupConverterTest,
|
||||
ConvertBatchGroupCountNotEqualToInputBatchDim) {
|
||||
string hlo_string = R"(HloModule m
|
||||
ENTRY main {
|
||||
%input = f32[1,1,1,4] parameter(0)
|
||||
%filter = f32[1,1,1,2] parameter(1)
|
||||
ROOT %convolution = f32[1,1,2,2] convolution(%input,%filter),
|
||||
window={size=1x1}, dim_labels=f01b_i01o->01fb, batch_group_count=2
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
auto computation = module->entry_computation();
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
EXPECT_EQ(root->opcode(), HloOpcode::kConvolution);
|
||||
auto cost_model = [](HloInstruction* conv) { return false; };
|
||||
ConvolutionGroupConverter converter(cost_model, /*convert_batch_groups_only=*/
|
||||
true);
|
||||
// Make sure that batch group count is rewritten even if
|
||||
// batch_group_count == output_feature but not input_batch
|
||||
ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user