Correctly handle grouped backprop conv conversion to depthwise convs.
PiperOrigin-RevId: 268135215
This commit is contained in:
parent
ae0c14b8fd
commit
dd9b975b03
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
|
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -474,8 +475,6 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
|
|||||||
new_convolution)));
|
new_convolution)));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
int64 activation_input_feature_dim = dim_numbers.input_feature_dimension();
|
|
||||||
|
|
||||||
int64 output_feature =
|
int64 output_feature =
|
||||||
filter->shape().dimensions(kernel_output_feature_dim);
|
filter->shape().dimensions(kernel_output_feature_dim);
|
||||||
|
|
||||||
@ -487,11 +486,62 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
|
|||||||
// [3, 2, 4]{S, B, IF} depth conv [3, 1, 4]{S, IF, OF}, where S is the
|
// [3, 2, 4]{S, B, IF} depth conv [3, 1, 4]{S, IF, OF}, where S is the
|
||||||
// additional spatial dimension. The generated convolution output will be
|
// additional spatial dimension. The generated convolution output will be
|
||||||
// [1, 2, 4]{S, B, OF} and then reshape the output back to [2, 4] {B, OF}.
|
// [1, 2, 4]{S, B, OF} and then reshape the output back to [2, 4] {B, OF}.
|
||||||
|
// We only do this for b0..0f or f0..0b dimension labels on activations.
|
||||||
if (group_count == output_feature && !filter_expansion_) {
|
const int64 input_feature_dim = dim_numbers.input_feature_dimension();
|
||||||
|
const int64 input_batch_dim = dim_numbers.input_batch_dimension();
|
||||||
|
const int64 activations_dimension_count =
|
||||||
|
convolution->operand(0)->shape().dimensions().size();
|
||||||
|
if (group_count == output_feature && !filter_expansion_ &&
|
||||||
|
((input_feature_dim == 0 &&
|
||||||
|
input_batch_dim == activations_dimension_count - 1) ||
|
||||||
|
(input_batch_dim == 0 &&
|
||||||
|
input_feature_dim == activations_dimension_count - 1))) {
|
||||||
auto filter = convolution->mutable_operand(1);
|
auto filter = convolution->mutable_operand(1);
|
||||||
auto activation = convolution->mutable_operand(0);
|
auto activation = convolution->mutable_operand(0);
|
||||||
|
|
||||||
|
// We want b0..0f logical dimensions on activations. If they are f0..0b
|
||||||
|
// instead, we transpose the activations to have the right dimension
|
||||||
|
// ordering.
|
||||||
|
if (input_feature_dim < input_batch_dim) {
|
||||||
|
// Generate the required shape for activations by swapping batch and
|
||||||
|
// feature dimension sizes.
|
||||||
|
Shape new_act_shape = activation->shape();
|
||||||
|
new_act_shape.set_dimensions(dim_numbers.input_feature_dimension(),
|
||||||
|
activation->shape().dimensions(
|
||||||
|
dim_numbers.input_batch_dimension()));
|
||||||
|
new_act_shape.set_dimensions(
|
||||||
|
dim_numbers.input_batch_dimension(),
|
||||||
|
activation->shape().dimensions(
|
||||||
|
dim_numbers.input_feature_dimension()));
|
||||||
|
|
||||||
|
// Generate dimension mapping.
|
||||||
|
std::vector<int64> transpose_dims(new_act_shape.dimensions_size());
|
||||||
|
std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
|
||||||
|
std::iter_swap(transpose_dims.begin(), transpose_dims.end() - 1);
|
||||||
|
|
||||||
|
// Transpose the activations. Change the convolution input.
|
||||||
|
auto transposed_activations =
|
||||||
|
computation_->AddInstruction(HloInstruction::CreateTranspose(
|
||||||
|
new_act_shape, activation, transpose_dims));
|
||||||
|
TF_CHECK_OK(convolution->ReplaceOperandWithDifferentShape(
|
||||||
|
0, transposed_activations));
|
||||||
|
|
||||||
|
const int64 old_feature_dim = dim_numbers.input_feature_dimension();
|
||||||
|
const int64 old_batch_dim = dim_numbers.input_batch_dimension();
|
||||||
|
|
||||||
|
// Rectify the convolution dimension numbers.
|
||||||
|
dim_numbers.set_input_feature_dimension(old_batch_dim);
|
||||||
|
dim_numbers.set_input_batch_dimension(old_feature_dim);
|
||||||
|
convolution->set_convolution_dimension_numbers(dim_numbers);
|
||||||
|
|
||||||
|
// Update the data structures we'd use.
|
||||||
|
dim_numbers = convolution->convolution_dimension_numbers();
|
||||||
|
activation = convolution->mutable_operand(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64 activation_input_feature_dim =
|
||||||
|
dim_numbers.input_feature_dimension();
|
||||||
|
|
||||||
// Add spatial dimension to the activation, and reshape.
|
// Add spatial dimension to the activation, and reshape.
|
||||||
Shape reshaped_activation_shape = activation->shape();
|
Shape reshaped_activation_shape = activation->shape();
|
||||||
ShapeUtil::AppendMajorDimension(group_size, &reshaped_activation_shape);
|
ShapeUtil::AppendMajorDimension(group_size, &reshaped_activation_shape);
|
||||||
@ -534,12 +584,16 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
|
|||||||
/*batch_group_count=*/1, new_window, dim_numbers,
|
/*batch_group_count=*/1, new_window, dim_numbers,
|
||||||
convolution->precision_config()));
|
convolution->precision_config()));
|
||||||
|
|
||||||
|
VLOG(2) << "New convolution " << new_convolution->ToString();
|
||||||
|
|
||||||
// Delete the extra spatial dimension, and reshape.
|
// Delete the extra spatial dimension, and reshape.
|
||||||
Shape reshaped_convolution_shape =
|
Shape reshaped_convolution_shape =
|
||||||
ShapeUtil::DeleteDimension(new_spatial_dim, new_convolution->shape());
|
ShapeUtil::DeleteDimension(new_spatial_dim, new_convolution->shape());
|
||||||
auto reshaped_convolution = HloInstruction::CreateReshape(
|
auto reshaped_convolution = HloInstruction::CreateReshape(
|
||||||
reshaped_convolution_shape, new_convolution);
|
reshaped_convolution_shape, new_convolution);
|
||||||
|
|
||||||
|
VLOG(2) << "Reshaped convolution " << reshaped_convolution->ToString();
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
|
TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
|
||||||
convolution, std::move(reshaped_convolution)));
|
convolution, std::move(reshaped_convolution)));
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user