Correctly handle the conversion of grouped backprop convolutions to depthwise convolutions.

PiperOrigin-RevId: 268135215
This commit is contained in:
A. Unique TensorFlower 2019-09-09 20:04:26 -07:00 committed by TensorFlower Gardener
parent ae0c14b8fd
commit dd9b975b03

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include "tensorflow/compiler/xla/service/convolution_group_converter.h"
#include <algorithm>
#include <memory> #include <memory>
#include <vector> #include <vector>
@ -474,8 +475,6 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
new_convolution))); new_convolution)));
} }
} else { } else {
int64 activation_input_feature_dim = dim_numbers.input_feature_dimension();
int64 output_feature = int64 output_feature =
filter->shape().dimensions(kernel_output_feature_dim); filter->shape().dimensions(kernel_output_feature_dim);
@ -487,11 +486,62 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
// [3, 2, 4]{S, B, IF} depth conv [3, 1, 4]{S, IF, OF}, where S is the // [3, 2, 4]{S, B, IF} depth conv [3, 1, 4]{S, IF, OF}, where S is the
// additional spatial dimension. The generated convolution output will be // additional spatial dimension. The generated convolution output will be
// [1, 2, 4]{S, B, OF} and then reshape the output back to [2, 4] {B, OF}. // [1, 2, 4]{S, B, OF} and then reshape the output back to [2, 4] {B, OF}.
// We only do this for b0..0f or f0..0b dimension labels on activations.
if (group_count == output_feature && !filter_expansion_) { const int64 input_feature_dim = dim_numbers.input_feature_dimension();
const int64 input_batch_dim = dim_numbers.input_batch_dimension();
const int64 activations_dimension_count =
convolution->operand(0)->shape().dimensions().size();
if (group_count == output_feature && !filter_expansion_ &&
((input_feature_dim == 0 &&
input_batch_dim == activations_dimension_count - 1) ||
(input_batch_dim == 0 &&
input_feature_dim == activations_dimension_count - 1))) {
auto filter = convolution->mutable_operand(1); auto filter = convolution->mutable_operand(1);
auto activation = convolution->mutable_operand(0); auto activation = convolution->mutable_operand(0);
// We want b0..0f logical dimensions on activations. If they are f0..0b
// instead, we transpose the activations to have the right dimension
// ordering.
if (input_feature_dim < input_batch_dim) {
// Generate the required shape for activations by swapping batch and
// feature dimension sizes.
Shape new_act_shape = activation->shape();
new_act_shape.set_dimensions(dim_numbers.input_feature_dimension(),
activation->shape().dimensions(
dim_numbers.input_batch_dimension()));
new_act_shape.set_dimensions(
dim_numbers.input_batch_dimension(),
activation->shape().dimensions(
dim_numbers.input_feature_dimension()));
// Generate dimension mapping.
std::vector<int64> transpose_dims(new_act_shape.dimensions_size());
std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
std::iter_swap(transpose_dims.begin(), transpose_dims.end() - 1);
// Transpose the activations. Change the convolution input.
auto transposed_activations =
computation_->AddInstruction(HloInstruction::CreateTranspose(
new_act_shape, activation, transpose_dims));
TF_CHECK_OK(convolution->ReplaceOperandWithDifferentShape(
0, transposed_activations));
const int64 old_feature_dim = dim_numbers.input_feature_dimension();
const int64 old_batch_dim = dim_numbers.input_batch_dimension();
// Rectify the convolution dimension numbers.
dim_numbers.set_input_feature_dimension(old_batch_dim);
dim_numbers.set_input_batch_dimension(old_feature_dim);
convolution->set_convolution_dimension_numbers(dim_numbers);
// Update the data structures we'd use.
dim_numbers = convolution->convolution_dimension_numbers();
activation = convolution->mutable_operand(0);
}
const int64 activation_input_feature_dim =
dim_numbers.input_feature_dimension();
// Add spatial dimension to the activation, and reshape. // Add spatial dimension to the activation, and reshape.
Shape reshaped_activation_shape = activation->shape(); Shape reshaped_activation_shape = activation->shape();
ShapeUtil::AppendMajorDimension(group_size, &reshaped_activation_shape); ShapeUtil::AppendMajorDimension(group_size, &reshaped_activation_shape);
@ -534,12 +584,16 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
/*batch_group_count=*/1, new_window, dim_numbers, /*batch_group_count=*/1, new_window, dim_numbers,
convolution->precision_config())); convolution->precision_config()));
VLOG(2) << "New convolution " << new_convolution->ToString();
// Delete the extra spatial dimension, and reshape. // Delete the extra spatial dimension, and reshape.
Shape reshaped_convolution_shape = Shape reshaped_convolution_shape =
ShapeUtil::DeleteDimension(new_spatial_dim, new_convolution->shape()); ShapeUtil::DeleteDimension(new_spatial_dim, new_convolution->shape());
auto reshaped_convolution = HloInstruction::CreateReshape( auto reshaped_convolution = HloInstruction::CreateReshape(
reshaped_convolution_shape, new_convolution); reshaped_convolution_shape, new_convolution);
VLOG(2) << "Reshaped convolution " << reshaped_convolution->ToString();
TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
convolution, std::move(reshaped_convolution))); convolution, std::move(reshaped_convolution)));