STT-tensorflow/tensorflow/compiler/xla/service/convolution_4d_expander.cc
Adrian Kuegel 682cef5c04 Convert 4D convolutions with trivial dimensions to lower-dimensional convolutions.
cuDNN cannot handle convolutions with 4 or more spatial dimensions. However, we
can work around this in some cases by removing trivial dimensions.

PiperOrigin-RevId: 291894511
Change-Id: Ic0e3fa4f4181e105ca62f92a235a55413acd7253
2020-01-28 02:40:48 -08:00

176 lines
7.8 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/convolution_4d_expander.h"
#include <algorithm>
#include <functional>
#include <vector>
#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
// Returns true iff `instruction` is a convolution with exactly four spatial
// dimensions of which at least one is trivial, i.e. the input has size 1 in
// that spatial dimension and the window applies neither low nor high padding
// to it.
//
// NOTE(review): the kernel's extent in the trivial dimension is not inspected
// here; presumably it is always 1 for a non-empty output -- confirm.
bool Convolution4DExpander::InstructionMatchesPattern(
    HloInstruction* instruction) {
  if (instruction->opcode() != HloOpcode::kConvolution) {
    return false;
  }

  // Check whether it is a 4D convolution and whether there is at least one
  // trivial dimension.
  const ConvolutionDimensionNumbers& dim_nums =
      instruction->convolution_dimension_numbers();
  if (dim_nums.input_spatial_dimensions().size() != 4) {
    return false;
  }

  // Only dimension sizes are read below, so bind the operand's shape and the
  // window by reference instead of copying them.
  const Shape& input_shape = instruction->operand(0)->shape();
  const Window& window = instruction->window();
  for (int64 i = 0; i < dim_nums.input_spatial_dimensions().size(); ++i) {
    const int64 spatial_dim = dim_nums.input_spatial_dimensions(i);
    const auto& window_dim = window.dimensions(i);
    if (input_shape.dimensions(spatial_dim) == 1 &&
        window_dim.padding_low() == 0 && window_dim.padding_high() == 0) {
      return true;
    }
  }
  return false;
}
// Replaces a convolution that has trivial spatial dimensions with an
// equivalent convolution over fewer spatial dimensions.
//
// A spatial dimension counts as trivial when the input has size 1 along it and
// the window pads it with neither low nor high padding. Trivial dimensions are
// stripped from the input, the kernel and the output. The result of the
// smaller convolution is then reshaped to the shape of `instruction`, and that
// final reshape is returned.
//
// Note that the window and the dimension numbers of `instruction` itself are
// overwritten along the way.
StatusOr<HloInstruction*> Convolution4DExpander::ExpandInstruction(
    HloInstruction* instruction) {
  HloComputation* computation = instruction->parent();
  HloInstruction* input = instruction->mutable_operand(0);
  HloInstruction* kernel = instruction->mutable_operand(1);

  // Work on copies of the dimension numbers: `instruction` gets new dimension
  // numbers assigned further down.
  ConvolutionDimensionNumbers dim_nums =
      instruction->convolution_dimension_numbers();
  ConvolutionDimensionNumbers new_dim_nums = dim_nums;
  new_dim_nums.clear_input_spatial_dimensions();
  new_dim_nums.clear_output_spatial_dimensions();
  new_dim_nums.clear_kernel_spatial_dimensions();
  Window new_window;

  std::vector<int64> dropped_input_dims;
  std::vector<int64> dropped_kernel_dims;
  std::vector<int64> dropped_output_dims;

  // Split the spatial dimensions into those that are kept (recorded in
  // `new_dim_nums` / `new_window`, still with their old indices) and the
  // trivial ones that are dropped from input, kernel and output.
  for (int64 i = 0; i < dim_nums.input_spatial_dimensions().size(); ++i) {
    const int64 input_dim = dim_nums.input_spatial_dimensions(i);
    const int64 output_dim = dim_nums.output_spatial_dimensions(i);
    const int64 kernel_dim = dim_nums.kernel_spatial_dimensions(i);
    const auto& window_dim = instruction->window().dimensions(i);

    const bool keep = input->shape().dimensions(input_dim) != 1 ||
                      window_dim.padding_low() != 0 ||
                      window_dim.padding_high() != 0;
    if (keep) {
      *new_window.add_dimensions() = window_dim;
      new_dim_nums.add_input_spatial_dimensions(input_dim);
      new_dim_nums.add_output_spatial_dimensions(output_dim);
      new_dim_nums.add_kernel_spatial_dimensions(kernel_dim);
      continue;
    }
    dropped_input_dims.push_back(input_dim);
    dropped_output_dims.push_back(output_dim);
    dropped_kernel_dims.push_back(kernel_dim);
  }

  // Order the dropped dimensions from highest to lowest. Deleting in that
  // order keeps the indices that are still to be deleted valid.
  std::sort(dropped_input_dims.rbegin(), dropped_input_dims.rend());
  std::sort(dropped_output_dims.rbegin(), dropped_output_dims.rend());
  std::sort(dropped_kernel_dims.rbegin(), dropped_kernel_dims.rend());

  // Returns `shape` with every dimension listed in `dims` deleted; `dims` must
  // already be in descending order.
  const auto without_dimensions = [](Shape shape,
                                     const std::vector<int64>& dims) {
    for (const int64 dim : dims) {
      shape.DeleteDimension(dim);
    }
    return shape;
  };
  const Shape new_input_shape =
      without_dimensions(input->shape(), dropped_input_dims);
  const Shape new_kernel_shape =
      without_dimensions(kernel->shape(), dropped_kernel_dims);
  const Shape new_output_shape =
      without_dimensions(instruction->shape(), dropped_output_dims);

  // Maps an old dimension index to its index after the deletions: it moves
  // down by one for every dropped dimension that was located below it.
  const auto renumber = [](const std::vector<int64>& dropped, int64 dim) {
    int64 shift = 0;
    for (const int64 dropped_dim : dropped) {
      if (dropped_dim < dim) {
        ++shift;
      }
    }
    return dim - shift;
  };

  // Input dimension numbers.
  new_dim_nums.set_input_batch_dimension(
      renumber(dropped_input_dims, new_dim_nums.input_batch_dimension()));
  new_dim_nums.set_input_feature_dimension(
      renumber(dropped_input_dims, new_dim_nums.input_feature_dimension()));
  for (int64 i = 0; i < new_dim_nums.input_spatial_dimensions().size(); ++i) {
    const int64 old_dim = new_dim_nums.input_spatial_dimensions(i);
    new_dim_nums.set_input_spatial_dimensions(
        i, renumber(dropped_input_dims, old_dim));
  }

  // Output dimension numbers.
  new_dim_nums.set_output_batch_dimension(
      renumber(dropped_output_dims, new_dim_nums.output_batch_dimension()));
  new_dim_nums.set_output_feature_dimension(
      renumber(dropped_output_dims, new_dim_nums.output_feature_dimension()));
  for (int64 i = 0; i < new_dim_nums.output_spatial_dimensions().size();
       ++i) {
    const int64 old_dim = new_dim_nums.output_spatial_dimensions(i);
    new_dim_nums.set_output_spatial_dimensions(
        i, renumber(dropped_output_dims, old_dim));
  }

  // Kernel dimension numbers.
  new_dim_nums.set_kernel_input_feature_dimension(renumber(
      dropped_kernel_dims, new_dim_nums.kernel_input_feature_dimension()));
  new_dim_nums.set_kernel_output_feature_dimension(renumber(
      dropped_kernel_dims, new_dim_nums.kernel_output_feature_dimension()));
  for (int64 i = 0; i < new_dim_nums.kernel_spatial_dimensions().size();
       ++i) {
    const int64 old_dim = new_dim_nums.kernel_spatial_dimensions(i);
    new_dim_nums.set_kernel_spatial_dimensions(
        i, renumber(dropped_kernel_dims, old_dim));
  }

  // Reshape both operands to their lower-dimensional shapes.
  HloInstruction* reshaped_input = computation->AddInstruction(
      HloInstruction::CreateReshape(new_input_shape, input));
  HloInstruction* reshaped_kernel = computation->AddInstruction(
      HloInstruction::CreateReshape(new_kernel_shape, kernel));

  // CloneWithNewOperands cannot substitute the window or the
  // ConvolutionDimensionNumbers, so both are set on the old instruction (which
  // is going to be removed anyway) right before it is cloned.
  instruction->set_convolution_dimension_numbers(new_dim_nums);
  instruction->set_window(new_window);
  HloInstruction* new_convolution =
      computation->AddInstruction(instruction->CloneWithNewOperands(
          new_output_shape, {reshaped_input, reshaped_kernel}));

  // Restore the original output shape for the users of the old convolution.
  return computation->AddInstruction(
      HloInstruction::CreateReshape(instruction->shape(), new_convolution));
}
} // namespace xla