Speed up backprop grouped convolutions
PiperOrigin-RevId: 287216138 Change-Id: I275d92e286d9d5114dfe8d1ba614a5e63aa6b062
parent 5078cab51c
commit 44310d275b
@@ -68,21 +68,21 @@ def ConfigsToTest():
      Tuple (input_size, filter_size, out_size, stride, padding), the depthwise
      convolution parameters.
  """
  input_sizes = [[4, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], [4, 9, 27, 8],
                 [4, 31, 31, 7], [4, 35, 35, 2], [4, 147, 147, 2],
                 [3, 299, 299, 3], [5, 183, 183, 1]]
  filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [3, 1, 48, 4], [3, 3, 8, 1],
                  [3, 3, 7, 1], [5, 5, 2, 1], [3, 3, 2, 8], [2, 2, 3,
                                                             8], [5, 5, 1, 2]]
  out_sizes = [[4, 5, 5, 96], [4, 8, 8, 84], [4, 17, 17, 192], [4, 9, 27, 8],
               [4, 31, 31, 7], [4, 35, 35, 2], [4, 49, 49, 16],
  input_sizes = [[4, 5, 5, 48], [2, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48],
                 [4, 9, 27, 8], [4, 31, 31, 7], [4, 35, 35, 2],
                 [4, 147, 147, 2], [3, 299, 299, 3], [5, 183, 183, 1]]
  filter_sizes = [[1, 1, 48, 2], [2, 2, 48, 8], [1, 3, 84, 1], [3, 1, 48, 4],
                  [3, 3, 8, 1], [3, 3, 7, 1], [5, 5, 2, 1], [3, 3, 2, 8],
                  [2, 2, 3, 8], [5, 5, 1, 2]]
  out_sizes = [[4, 5, 5, 96], [2, 5, 5, 384], [4, 8, 8, 84], [4, 17, 17, 192],
               [4, 9, 27, 8], [4, 31, 31, 7], [4, 35, 35, 2], [4, 49, 49, 16],
               [3, 150, 150, 24], [5, 92, 92, 2]]
  strides = [1, 1, 1, 1, 1, 1, 3, 2, 2]
  strides = [1, 1, 1, 1, 1, 1, 1, 3, 2, 2]
  # pylint: disable=invalid-name
  VALID = "VALID"
  SAME = "SAME"
  # pylint: enable=invalid-name
  paddings = [SAME, SAME, SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME]
  paddings = [SAME, SAME, SAME, SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME]
  for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
                           paddings):
    yield i, f, o, s, p
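For context on the new test case: the added config (input [2, 5, 5, 48], filter [2, 2, 48, 8], stride 1, SAME padding, output [2, 5, 5, 384]) exercises a depthwise convolution with channel multiplier 8, whose filter gradient is what the grouped-convolution backprop speedup targets. Below is a minimal sketch, not part of the change, of how such a config maps to a depthwise convolution and its filter gradient; it assumes TF 2.x eager APIs rather than the test harness used here.

# Illustrative only: shapes follow the new config entry.
import tensorflow as tf

x = tf.random.normal([2, 5, 5, 48])
w = tf.Variable(tf.random.normal([2, 2, 48, 8]))  # channel multiplier 8

with tf.GradientTape() as tape:
  y = tf.nn.depthwise_conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME")
  loss = tf.reduce_sum(y)

# dw is the "backprop filter" that the kernels covered by these configs produce.
dw = tape.gradient(loss, w)
print(y.shape)   # (2, 5, 5, 384)
print(dw.shape)  # (2, 2, 48, 8)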
@@ -512,22 +512,26 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
      filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims),
      feature_group_count = in_depth / filter_in_depth;

  // In the case of depthwise convolutions, the computation can be done by the
  // batch_group_count parameter.
  bool use_batch_group_count = in_depth > 1 && in_depth == filter_in_depth &&
                               (feature_group_count != 1 || attrs.depthwise);

  if (use_batch_group_count) {
    feature_group_count = 1;
  }

  // The activations (inputs) form the LHS of the convolution.
  // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
  // For the gradient computation, we need to:
  // 1. In the case of group convolution, move the num_groups dimension before
  //    the batch dimension
  // 2. Swap the roles of the batch and feature dimensions.
  if (feature_group_count != 1 && !attrs.depthwise) {
  if (!use_batch_group_count && feature_group_count != 1 && !attrs.depthwise) {
    activations = TransposeInputForGroupConvolutionBackpropFilter(
        activations, input_shape, feature_group_count, n_dim, c_dim);
  }

  // In the case of depthwise convolution with no multiplier,
  // the computation can be done by the batch_group_count parameter.
  bool use_batch_group_count =
      filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise;

  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
  std::vector<int64> window_strides(attrs.num_spatial_dims);
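The idea in this hunk is that when in_depth == filter_in_depth (the depthwise case), the filter backprop can be expressed directly as a convolution with batch_group_count set, so the activation transpose used for feature-grouped backprop is skipped. A small NumPy sketch of the underlying identity, illustrative only and assuming stride 1, VALID padding, and channel multiplier 1: each channel's filter gradient is an independent correlation summed over the batch, which is exactly what batch grouping expresses.

# dW[kh, kw, c] = sum_{b, i, j} x[b, i + kh, j + kw, c] * dy[b, i, j, c]
import numpy as np

B, H, W, C, K = 2, 6, 6, 3, 3            # batch, height, width, channels, kernel
x = np.random.randn(B, H, W, C)          # activations
dy = np.random.randn(B, H - K + 1, W - K + 1, C)  # output gradient

dw = np.zeros((K, K, C))
for kh in range(K):
  for kw in range(K):
    for c in range(C):
      # Each channel c is an independent correlation over the batch.
      dw[kh, kw, c] = np.sum(
          x[:, kh:kh + H - K + 1, kw:kw + W - K + 1, c] * dy[:, :, :, c])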
@@ -218,14 +218,127 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {

  int64 input_batch_dimension = dim_numbers.input_batch_dimension();
  int64 output_batch_dimension = dim_numbers.output_batch_dimension();
  const int64 kernel_output_feature_dimension =
      dim_numbers.kernel_output_feature_dimension();
  int64 output_feature_dimension = dim_numbers.output_feature_dimension();

  int64 input_batch = activation->shape().dimensions(input_batch_dimension);

  const int64 output_feature =
      filter->shape().dimensions(kernel_output_feature_dimension);

  VLOG(2) << "is_cost_viable_ " << is_cost_viable_(convolution);
  const bool cost_too_high = !is_cost_viable_(convolution);

  if (output_feature != batch_group_count) {
    const int64 group_size = output_feature / batch_group_count;

    VLOG(2) << "Need to insert a spatial dimension in activations and in the "
               "kernel to deal with backprop of grouped convolutions "
            << " group size " << group_size;

    // Add spatial dimension to the activation, and reshape.
    Shape reshaped_activation_shape = activation->shape();
    ShapeUtil::AppendMajorDimension(1, &reshaped_activation_shape);
    const int64 new_spatial_dim =
        reshaped_activation_shape.dimensions().size() - 1;

    activation = add(
        HloInstruction::CreateReshape(reshaped_activation_shape, activation));

    // Insert new spatial dimension after the output feature dimension on the
    // kernel.
    auto dims = filter->shape().dimensions();
    std::vector<int64> new_dims;
    for (int i = 0; i < dims.size(); i++) {
      if (i == kernel_output_feature_dimension) {
        new_dims.push_back(batch_group_count);
        new_dims.push_back(group_size);
      } else {
        new_dims.push_back(dims[i]);
      }
    }

    Shape reshaped_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
        filter->shape().element_type(), new_dims);

    filter = add(HloInstruction::CreateReshape(reshaped_filter_shape, filter));

    Shape new_output_shape = convolution->shape();
    ShapeUtil::AppendMajorDimension(1, &new_output_shape);

    // Edit convolution dimension numbers. Note that kernel_input_feature_dim
    // now becomes a spatial dimension, and the newly added dimension of size
    // 1 is the new kernel_input_feature_dim.
    dim_numbers.add_input_spatial_dimensions(new_spatial_dim);

    // Update spatial dimension numbers if they show up after the newly added
    // spatial dimension.
    for (auto& d : *dim_numbers.mutable_kernel_spatial_dimensions()) {
      if (d > kernel_output_feature_dimension) {
        ++d;
      }
    }

    // Same for input feature dimension.
    if (dim_numbers.kernel_input_feature_dimension() >
        kernel_output_feature_dimension) {
      dim_numbers.set_kernel_input_feature_dimension(
          dim_numbers.kernel_input_feature_dimension() + 1);
    }

    dim_numbers.add_kernel_spatial_dimensions(kernel_output_feature_dimension +
                                              1);

    dim_numbers.add_output_spatial_dimensions(output_batch_dimension);

    dim_numbers.set_output_batch_dimension(new_spatial_dim);

    // Add window for the new spatial dimension.
    Window new_window = convolution->window();
    auto* dim = new_window.add_dimensions();
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
    dim->set_stride(1);
    dim->set_size(group_size);
    dim->set_padding_high(group_size - 1);
    dim->set_padding_low(group_size - 1);
    dim->set_window_reversal(false);

    auto new_convolution = add(HloInstruction::CreateConvolve(
        new_output_shape, activation, filter, /*feature_group_count=*/1,
        batch_group_count, new_window, dim_numbers,
        convolution->precision_config()));

    VLOG(2) << "New convolution " << new_convolution->ToString();

    // This reversal is not done via set_window_reversal because GPUs don't
    // support it.
    auto rev = add(HloInstruction::CreateReverse(
        new_output_shape, new_convolution, {output_batch_dimension}));

    // Delete the extra spatial dimension, and reshape.
    Shape reshaped_convolution_shape =
        ShapeUtil::DeleteDimension(new_spatial_dim, rev->shape());
    auto reshaped_convolution =
        HloInstruction::CreateReshape(reshaped_convolution_shape, rev);

    VLOG(2) << "Reshaped convolution " << reshaped_convolution->ToString();

    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
        convolution, std::move(reshaped_convolution)));

    changed_ = true;

    convolution = new_convolution;
    dim_numbers = convolution->convolution_dimension_numbers();
    output_batch_dimension = new_spatial_dim;
  }

  // We are not yet supporting batch_group of sizes greater than 1.
  TF_RET_CHECK(input_batch == batch_group_count);

  if (!is_cost_viable_(convolution) || filter_expansion_) {
  if (cost_too_high || filter_expansion_) {
    // We first obtain the expanded the filter (which is the convolution
    // output). The batch dimension is the expanded one (which originally
    // represents kernel input feature dimension). We mask the filter to zero
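To see why the inserted degenerate spatial dimension recovers the per-group slices: the activation gains a spatial dimension of size 1, the kernel output feature dimension is split into [batch_group_count, group_size], and the new window has size group_size with padding group_size - 1 on both sides, so that dimension yields exactly group_size output positions. The explicit Reverse then stands in for window reversal, which, as the comment above notes, GPUs do not support. A quick sketch of the window arithmetic, illustrative only and using the standard window output-size formula:

# Sketch of the shape bookkeeping behind the inserted size-1 spatial dimension.
def window_output_size(input_size, window, pad_lo, pad_hi, stride=1, dilation=1):
  padded = input_size + pad_lo + pad_hi
  effective_window = (window - 1) * dilation + 1
  return (padded - effective_window) // stride + 1

group_size = 4  # output_feature / batch_group_count in the pass above
# New activation spatial dim has size 1; the new window has size group_size
# and padding group_size - 1 on both sides:
print(window_output_size(1, group_size, group_size - 1, group_size - 1))  # -> 4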
@@ -238,11 +351,17 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
    auto expanded_filter_shape = ExpandedFilterShape(
        convolution->shape(), batch_group_count, output_batch_dimension);

    VLOG(2) << "output_batch_dimension " << output_batch_dimension;
    VLOG(2) << "New output shape of convolution "
            << expanded_filter_shape.ToString();

    auto new_convolution = add(HloInstruction::CreateConvolve(
        expanded_filter_shape, activation, filter,
        /*feature_group_count=*/1, /*batch_group_count=*/1,
        convolution->window(), dim_numbers, convolution->precision_config()));

    VLOG(2) << "Expanded convolution " << new_convolution->ToString();

    auto zero = add(HloInstruction::CreateConstant(
        LiteralUtil::Zero(expanded_filter_shape.element_type())));
    auto zero_filter =
@@ -354,6 +473,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
      changed_ = false;
      return Status::OK();
    }
    VLOG(2) << "is_cost_viable_ " << is_cost_viable_(convolution);
    // We want to repeat 'filter' in the 'input_feature_dim' dimension
    // 'group_count' times.
    if (!is_cost_viable_(convolution) || filter_expansion_) {
@@ -1116,6 +1116,7 @@ cc_library(
        "//tensorflow/compiler/xla/service:buffer_assignment",
        "//tensorflow/compiler/xla/service:call_inliner",
        "//tensorflow/compiler/xla/service:conditional_simplifier",
        "//tensorflow/compiler/xla/service:convolution_group_converter",
        "//tensorflow/compiler/xla/service:depthwise_convolution_converter",
        "//tensorflow/compiler/xla/service:dot_decomposer",
        "//tensorflow/compiler/xla/service:dump",
@@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
#include "tensorflow/compiler/xla/service/depthwise_convolution_converter.h"
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/dump.h"
@@ -138,11 +139,28 @@ Status GpuCompiler::OptimizeHloModule(

    // TODO(b/64094172): make Call work on GPU instead of inlining.
    pipeline.AddPass<CallInliner>();

    pipeline.AddPass<DotDecomposer>();

    // We use the ConvolutionGroupConverter to convert backprops of filter
    // grouped convolutions into non-grouped equivalents.
    auto batch_group_cost_model = [](HloInstruction* conv) {
      auto dim_numbers = conv->convolution_dimension_numbers();
      const int64 input_batch_size = conv->operand(0)->shape().dimensions(
          dim_numbers.input_batch_dimension());
      return conv->batch_group_count() != input_batch_size;
    };

    pipeline.AddPass<ConvolutionGroupConverter>(
        batch_group_cost_model,
        /*convert_batch_groups_only=*/true,
        /*canonicalize_depthwise_filter=*/false);

    auto cost_model = [](HloInstruction* conv) {
      // We need a cost model for GPUs. Currently, do nothing.
      return false;
    };
    pipeline.AddPass<DotDecomposer>();

    pipeline.AddPass<DepthwiseConvolutionConverter>(cost_model);
    // Expand the sort op to support stable sorting if required.
    pipeline.AddPass<StableSortExpander>();
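Note on batch_group_cost_model: HandleBatchGroupCount above treats !is_cost_viable_ as "cost too high" and rewrites in that case, so this lambda marks exactly the convolutions whose batch_group_count equals the input batch size for conversion. A hypothetical restatement of the predicate, with illustrative names that are not part of the HLO API:

# Hypothetical restatement of batch_group_cost_model; a False return is treated
# by the pass as "cost too high", i.e. the convolution gets rewritten.
def is_cost_viable(batch_group_count, input_batch_size):
  return batch_group_count != input_batch_size

print(is_cost_viable(48, 48))  # False -> ConvolutionGroupConverter rewrites it
print(is_cost_viable(48, 4))   # True  -> left as-is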
@@ -1720,7 +1720,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
  const int64 kernel_output_features =
      rhs.dimensions(dnums.kernel_output_feature_dimension());

  if (batch_group_count > 1 && kernel_output_features != batch_group_count) {
  if (batch_group_count > 1 &&
      kernel_output_features % batch_group_count != 0) {
    return InvalidArgument(
        "Expected output feature dimension size (value %d) to be equal to "
        "batch group count %d; got <conv>(%s, %s)\n"
@@ -1759,7 +1760,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
        dnums.DebugString());
  }

  if (input_batch % batch_group_count > 0) {
  if (input_batch % batch_group_count != 0) {
    return InvalidArgument(
        "Expected input batch dimension (value %d) to be divisible by "
        "batch_group_count (value %d); "
@@ -1793,6 +1794,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
  std::vector<int64> dimensions(num_dims);
  dimensions[dnums.output_batch_dimension()] = input_batch / batch_group_count;
  dimensions[dnums.output_feature_dimension()] = kernel_output_features;

  if (batch_group_count > 1) {
    dimensions[dnums.output_batch_dimension()] =
        kernel_output_features / batch_group_count;
    dimensions[dnums.output_feature_dimension()] = batch_group_count;
  }

  for (int i = 0; i < num_spatial_dims; ++i) {
    dimensions[dnums.output_spatial_dimensions(i)] =
        window_output_shape.dimensions(i);
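The new branch changes what the inferred output dimensions mean for batch-grouped convolutions: the output batch dimension carries the per-group size (kernel_output_features / batch_group_count) and the output feature dimension carries batch_group_count. A small sketch of that rule, illustrative only and not XLA code:

# Illustrative recomputation of the batch-group output-shape rule above.
def batch_group_output_dims(input_batch, kernel_output_features, batch_group_count):
  # Default rule.
  out_batch = input_batch // batch_group_count
  out_feature = kernel_output_features
  # Batch-grouped rule: kernel output features carry batch_group_count groups
  # of size kernel_output_features // batch_group_count.
  if batch_group_count > 1:
    out_batch = kernel_output_features // batch_group_count
    out_feature = batch_group_count
  return out_batch, out_feature

# e.g. a depthwise backprop-filter conv: 48 input channels (= batch groups),
# channel multiplier 8 -> 384 kernel output features.
print(batch_group_output_dims(input_batch=48, kernel_output_features=384,
                              batch_group_count=48))  # -> (8, 48)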