From dd9b975b03de035ac3ecbcf7113669f1e3c69a54 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 9 Sep 2019 20:04:26 -0700 Subject: [PATCH] Correctly handle grouped backprop conv conversion to depthwise convs. PiperOrigin-RevId: 268135215 --- .../service/convolution_group_converter.cc | 62 +++++++++++++++++-- 1 file changed, 58 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc index 20ebafcf780..cfcf059ba5f 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/convolution_group_converter.h" +#include #include #include @@ -474,8 +475,6 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { new_convolution))); } } else { - int64 activation_input_feature_dim = dim_numbers.input_feature_dimension(); - int64 output_feature = filter->shape().dimensions(kernel_output_feature_dim); @@ -487,11 +486,62 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { // [3, 2, 4]{S, B, IF} depth conv [3, 1, 4]{S, IF, OF}, where S is the // additional spatial dimension. The generated convolution output will be // [1, 2, 4]{S, B, OF} and then reshape the output back to [2, 4] {B, OF}. - - if (group_count == output_feature && !filter_expansion_) { + // We only do this for b0..0f or f0..0b dimension labels on activations. + const int64 input_feature_dim = dim_numbers.input_feature_dimension(); + const int64 input_batch_dim = dim_numbers.input_batch_dimension(); + const int64 activations_dimension_count = + convolution->operand(0)->shape().dimensions().size(); + if (group_count == output_feature && !filter_expansion_ && + ((input_feature_dim == 0 && + input_batch_dim == activations_dimension_count - 1) || + (input_batch_dim == 0 && + input_feature_dim == activations_dimension_count - 1))) { auto filter = convolution->mutable_operand(1); auto activation = convolution->mutable_operand(0); + // We want b0..0f logical dimensions on activations. If they are f0..0b + // instead, we transpose the activations to have the right dimension + // ordering. + if (input_feature_dim < input_batch_dim) { + // Generate the required shape for activations by swapping batch and + // feature dimension sizes. + Shape new_act_shape = activation->shape(); + new_act_shape.set_dimensions(dim_numbers.input_feature_dimension(), + activation->shape().dimensions( + dim_numbers.input_batch_dimension())); + new_act_shape.set_dimensions( + dim_numbers.input_batch_dimension(), + activation->shape().dimensions( + dim_numbers.input_feature_dimension())); + + // Generate dimension mapping. + std::vector transpose_dims(new_act_shape.dimensions_size()); + std::iota(transpose_dims.begin(), transpose_dims.end(), 0); + std::iter_swap(transpose_dims.begin(), transpose_dims.end() - 1); + + // Transpose the activations. Change the convolution input. + auto transposed_activations = + computation_->AddInstruction(HloInstruction::CreateTranspose( + new_act_shape, activation, transpose_dims)); + TF_CHECK_OK(convolution->ReplaceOperandWithDifferentShape( + 0, transposed_activations)); + + const int64 old_feature_dim = dim_numbers.input_feature_dimension(); + const int64 old_batch_dim = dim_numbers.input_batch_dimension(); + + // Rectify the convolution dimension numbers. + dim_numbers.set_input_feature_dimension(old_batch_dim); + dim_numbers.set_input_batch_dimension(old_feature_dim); + convolution->set_convolution_dimension_numbers(dim_numbers); + + // Update the data structures we'd use. + dim_numbers = convolution->convolution_dimension_numbers(); + activation = convolution->mutable_operand(0); + } + + const int64 activation_input_feature_dim = + dim_numbers.input_feature_dimension(); + // Add spatial dimension to the activation, and reshape. Shape reshaped_activation_shape = activation->shape(); ShapeUtil::AppendMajorDimension(group_size, &reshaped_activation_shape); @@ -534,12 +584,16 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { /*batch_group_count=*/1, new_window, dim_numbers, convolution->precision_config())); + VLOG(2) << "New convolution " << new_convolution->ToString(); + // Delete the extra spatial dimension, and reshape. Shape reshaped_convolution_shape = ShapeUtil::DeleteDimension(new_spatial_dim, new_convolution->shape()); auto reshaped_convolution = HloInstruction::CreateReshape( reshaped_convolution_shape, new_convolution); + VLOG(2) << "Reshaped convolution " << reshaped_convolution->ToString(); + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( convolution, std::move(reshaped_convolution)));