Speed up backprop grouped convolutions
PiperOrigin-RevId: 287216138
Change-Id: I275d92e286d9d5114dfe8d1ba614a5e63aa6b062
parent 5078cab51c
commit 44310d275b

@@ -68,21 +68,21 @@ def ConfigsToTest():
     Tuple (input_size, filter_size, out_size, stride, padding), the depthwise
     convolution parameters.
   """
-  input_sizes = [[4, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], [4, 9, 27, 8],
-                 [4, 31, 31, 7], [4, 35, 35, 2], [4, 147, 147, 2],
-                 [3, 299, 299, 3], [5, 183, 183, 1]]
-  filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [3, 1, 48, 4], [3, 3, 8, 1],
-                  [3, 3, 7, 1], [5, 5, 2, 1], [3, 3, 2, 8], [2, 2, 3,
-                                                             8], [5, 5, 1, 2]]
-  out_sizes = [[4, 5, 5, 96], [4, 8, 8, 84], [4, 17, 17, 192], [4, 9, 27, 8],
-               [4, 31, 31, 7], [4, 35, 35, 2], [4, 49, 49, 16],
+  input_sizes = [[4, 5, 5, 48], [2, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48],
+                 [4, 9, 27, 8], [4, 31, 31, 7], [4, 35, 35, 2],
+                 [4, 147, 147, 2], [3, 299, 299, 3], [5, 183, 183, 1]]
+  filter_sizes = [[1, 1, 48, 2], [2, 2, 48, 8], [1, 3, 84, 1], [3, 1, 48, 4],
+                  [3, 3, 8, 1], [3, 3, 7, 1], [5, 5, 2, 1], [3, 3, 2, 8],
+                  [2, 2, 3, 8], [5, 5, 1, 2]]
+  out_sizes = [[4, 5, 5, 96], [2, 5, 5, 384], [4, 8, 8, 84], [4, 17, 17, 192],
+               [4, 9, 27, 8], [4, 31, 31, 7], [4, 35, 35, 2], [4, 49, 49, 16],
                [3, 150, 150, 24], [5, 92, 92, 2]]
-  strides = [1, 1, 1, 1, 1, 1, 3, 2, 2]
+  strides = [1, 1, 1, 1, 1, 1, 1, 3, 2, 2]
   # pylint: disable=invalid-name
   VALID = "VALID"
   SAME = "SAME"
   # pylint: enable=invalid-name
-  paddings = [SAME, SAME, SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME]
+  paddings = [SAME, SAME, SAME, SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME]
   for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
                            paddings):
     yield i, f, o, s, p

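Note: the new test tuple exercises a depthwise convolution with a channel multiplier of 8. As a quick shape sanity check (a minimal sketch assuming TF 2.x eager mode, not part of the change itself), the added config corresponds to:

    import tensorflow as tf

    x = tf.random.normal([2, 5, 5, 48])   # new input_sizes entry
    w = tf.random.normal([2, 2, 48, 8])   # new filter_sizes entry, multiplier 8
    y = tf.nn.depthwise_conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME")
    print(y.shape)  # (2, 5, 5, 384): matches the new out_sizes entry
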
@@ -512,22 +512,26 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
         filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims),
         feature_group_count = in_depth / filter_in_depth;
 
+  // In the case of depthwise convolutions, the computation can be done by the
+  // batch_group_count parameter.
+  bool use_batch_group_count = in_depth > 1 && in_depth == filter_in_depth &&
+                               (feature_group_count != 1 || attrs.depthwise);
+
+  if (use_batch_group_count) {
+    feature_group_count = 1;
+  }
+
   // The activations (inputs) form the LHS of the convolution.
   // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
   // For the gradient computation, we need to:
   // 1. In the case of group convolution, move the num_groups dimension before
   //    the batch dimension
   // 2. Swap the roles of the batch and feature dimensions.
-  if (feature_group_count != 1 && !attrs.depthwise) {
+  if (!use_batch_group_count && feature_group_count != 1 && !attrs.depthwise) {
     activations = TransposeInputForGroupConvolutionBackpropFilter(
         activations, input_shape, feature_group_count, n_dim, c_dim);
   }
 
-  // In the case of depthwise convolution with no multiplier,
-  // the computation can be done by the batch_group_count parameter.
-  bool use_batch_group_count =
-      filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise;
-
   std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
   std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
   std::vector<int64> window_strides(attrs.num_spatial_dims);

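Note: the use_batch_group_count predicate now fires for any backprop-filter convolution whose activation depth equals the filter's input depth (not only the multiplier-1 depthwise case), and when it fires the grouped-feature transpose path is skipped. A standalone Python rendering of just the boolean logic, with illustrative values rather than the real C++ inputs:

    # Hypothetical sizes chosen to mirror the depthwise case.
    in_depth, filter_in_depth, depthwise = 48, 48, True
    feature_group_count = in_depth // filter_in_depth        # 1

    use_batch_group_count = (in_depth > 1 and in_depth == filter_in_depth and
                             (feature_group_count != 1 or depthwise))
    if use_batch_group_count:
      feature_group_count = 1          # the grouped-feature path is bypassed
    print(use_batch_group_count, feature_group_count)        # True 1
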
@@ -218,14 +218,127 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
 
   int64 input_batch_dimension = dim_numbers.input_batch_dimension();
   int64 output_batch_dimension = dim_numbers.output_batch_dimension();
+  const int64 kernel_output_feature_dimension =
+      dim_numbers.kernel_output_feature_dimension();
   int64 output_feature_dimension = dim_numbers.output_feature_dimension();
 
   int64 input_batch = activation->shape().dimensions(input_batch_dimension);
 
+  const int64 output_feature =
+      filter->shape().dimensions(kernel_output_feature_dimension);
+
+  VLOG(2) << "is_cost_viable_ " << is_cost_viable_(convolution);
+  const bool cost_too_high = !is_cost_viable_(convolution);
+
+  if (output_feature != batch_group_count) {
+    const int64 group_size = output_feature / batch_group_count;
+
+    VLOG(2) << "Need to insert a spatial dimension in activations and in the "
+               "kernel to deal with backprop of grouped convolutions "
+            << " group size " << group_size;
+
+    // Add spatial dimension to the activation, and reshape.
+    Shape reshaped_activation_shape = activation->shape();
+    ShapeUtil::AppendMajorDimension(1, &reshaped_activation_shape);
+    const int64 new_spatial_dim =
+        reshaped_activation_shape.dimensions().size() - 1;
+
+    activation = add(
+        HloInstruction::CreateReshape(reshaped_activation_shape, activation));
+
+    // Insert new spatial dimension after the output feature dimension on the
+    // kernel.
+    auto dims = filter->shape().dimensions();
+    std::vector<int64> new_dims;
+    for (int i = 0; i < dims.size(); i++) {
+      if (i == kernel_output_feature_dimension) {
+        new_dims.push_back(batch_group_count);
+        new_dims.push_back(group_size);
+      } else {
+        new_dims.push_back(dims[i]);
+      }
+    }
+
+    Shape reshaped_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
+        filter->shape().element_type(), new_dims);
+
+    filter = add(HloInstruction::CreateReshape(reshaped_filter_shape, filter));
+
+    Shape new_output_shape = convolution->shape();
+    ShapeUtil::AppendMajorDimension(1, &new_output_shape);
+
+    // Edit convolution dimension numbers. Note that kernel_input_feature_dim
+    // now becomes a spatial dimension, and the newly added dimension of size
+    // 1 is the new kernel_input_feature_dim.
+    dim_numbers.add_input_spatial_dimensions(new_spatial_dim);
+
+    // Update spatial dimension numbers if they show up after the newly added
+    // spatial dimension.
+    for (auto& d : *dim_numbers.mutable_kernel_spatial_dimensions()) {
+      if (d > kernel_output_feature_dimension) {
+        ++d;
+      }
+    }
+
+    // Same for input feature dimension.
+    if (dim_numbers.kernel_input_feature_dimension() >
+        kernel_output_feature_dimension) {
+      dim_numbers.set_kernel_input_feature_dimension(
+          dim_numbers.kernel_input_feature_dimension() + 1);
+    }
+
+    dim_numbers.add_kernel_spatial_dimensions(kernel_output_feature_dimension +
+                                              1);
+
+    dim_numbers.add_output_spatial_dimensions(output_batch_dimension);
+
+    dim_numbers.set_output_batch_dimension(new_spatial_dim);
+
+    // Add window for the new spatial dimension.
+    Window new_window = convolution->window();
+    auto* dim = new_window.add_dimensions();
+    dim->set_window_dilation(1);
+    dim->set_base_dilation(1);
+    dim->set_stride(1);
+    dim->set_size(group_size);
+    dim->set_padding_high(group_size - 1);
+    dim->set_padding_low(group_size - 1);
+    dim->set_window_reversal(false);
+
+    auto new_convolution = add(HloInstruction::CreateConvolve(
+        new_output_shape, activation, filter, /*feature_group_count=*/1,
+        batch_group_count, new_window, dim_numbers,
+        convolution->precision_config()));
+
+    VLOG(2) << "New convolution " << new_convolution->ToString();
+
+    // This reversal is not done via set_window_reversal because GPUs don't
+    // support it.
+    auto rev = add(HloInstruction::CreateReverse(
+        new_output_shape, new_convolution, {output_batch_dimension}));
+
+    // Delete the extra spatial dimension, and reshape.
+    Shape reshaped_convolution_shape =
+        ShapeUtil::DeleteDimension(new_spatial_dim, rev->shape());
+    auto reshaped_convolution =
+        HloInstruction::CreateReshape(reshaped_convolution_shape, rev);
+
+    VLOG(2) << "Reshaped convolution " << reshaped_convolution->ToString();
+
+    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
+        convolution, std::move(reshaped_convolution)));
+
+    changed_ = true;
+
+    convolution = new_convolution;
+    dim_numbers = convolution->convolution_dimension_numbers();
+    output_batch_dimension = new_spatial_dim;
+  }
+
   // We are not yet supporting batch_group of sizes greater than 1.
   TF_RET_CHECK(input_batch == batch_group_count);
 
-  if (!is_cost_viable_(convolution) || filter_expansion_) {
+  if (cost_too_high || filter_expansion_) {
     // We first obtain the expanded the filter (which is the convolution
     // output). The batch dimension is the expanded one (which originally
     // represents kernel input feature dimension). We mask the filter to zero
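Note: the new branch for output_feature != batch_group_count is the core of the speed-up. The kernel's output-feature axis is split into (batch_group_count, group_size) and a trailing size-1 spatial dimension is appended to the activations, so the grouped backprop can stay a single batch-grouped convolution instead of being fully expanded. The dimension-splitting bookkeeping, as a standalone Python sketch with made-up sizes and a hypothetical helper name (the real pass manipulates XLA Shape objects):

    def split_output_feature_dim(dims, kernel_output_feature_dim,
                                 batch_group_count, group_size):
      """Mirrors the new_dims loop in the hunk above: one axis becomes two."""
      new_dims = []
      for i, d in enumerate(dims):
        if i == kernel_output_feature_dim:
          new_dims.append(batch_group_count)
          new_dims.append(group_size)
        else:
          new_dims.append(d)
      return new_dims

    # E.g. a kernel whose output-feature axis (index 3) holds
    # batch_group_count * group_size = 48 * 8 = 384 features:
    print(split_output_feature_dim([2, 2, 1, 384], 3, 48, 8))  # [2, 2, 1, 48, 8]
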
@@ -238,11 +351,17 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
     auto expanded_filter_shape = ExpandedFilterShape(
         convolution->shape(), batch_group_count, output_batch_dimension);
 
+    VLOG(2) << "output_batch_dimension " << output_batch_dimension;
+    VLOG(2) << "New output shape of convolution "
+            << expanded_filter_shape.ToString();
+
     auto new_convolution = add(HloInstruction::CreateConvolve(
         expanded_filter_shape, activation, filter,
         /*feature_group_count=*/1, /*batch_group_count=*/1,
         convolution->window(), dim_numbers, convolution->precision_config()));
 
+    VLOG(2) << "Expanded convolution " << new_convolution->ToString();
+
     auto zero = add(HloInstruction::CreateConstant(
         LiteralUtil::Zero(expanded_filter_shape.element_type())));
     auto zero_filter =

@@ -354,6 +473,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
     changed_ = false;
     return Status::OK();
   }
+  VLOG(2) << "is_cost_viable_ " << is_cost_viable_(convolution);
   // We want to repeat 'filter' in the 'input_feature_dim' dimension
   // 'group_count' times.
   if (!is_cost_viable_(convolution) || filter_expansion_) {

@@ -1116,6 +1116,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:call_inliner",
         "//tensorflow/compiler/xla/service:conditional_simplifier",
+        "//tensorflow/compiler/xla/service:convolution_group_converter",
        "//tensorflow/compiler/xla/service:depthwise_convolution_converter",
        "//tensorflow/compiler/xla/service:dot_decomposer",
        "//tensorflow/compiler/xla/service:dump",

@@ -36,6 +36,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/call_inliner.h"
 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
+#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
 #include "tensorflow/compiler/xla/service/depthwise_convolution_converter.h"
 #include "tensorflow/compiler/xla/service/dot_decomposer.h"
 #include "tensorflow/compiler/xla/service/dump.h"

@@ -138,11 +139,28 @@ Status GpuCompiler::OptimizeHloModule(
 
     // TODO(b/64094172): make Call work on GPU instead of inlining.
     pipeline.AddPass<CallInliner>();
+
+    pipeline.AddPass<DotDecomposer>();
+
+    // We use the ConvolutionGroupConverter to convert backprops of filter
+    // grouped convolutions into non-grouped equivalents.
+    auto batch_group_cost_model = [](HloInstruction* conv) {
+      auto dim_numbers = conv->convolution_dimension_numbers();
+      const int64 input_batch_size = conv->operand(0)->shape().dimensions(
+          dim_numbers.input_batch_dimension());
+      return conv->batch_group_count() != input_batch_size;
+    };
+
+    pipeline.AddPass<ConvolutionGroupConverter>(
+        batch_group_cost_model,
+        /*convert_batch_groups_only=*/true,
+        /*canonicalize_depthwise_filter=*/false);
+
     auto cost_model = [](HloInstruction* conv) {
       // We need a cost model for GPUs. Currently, do nothing.
       return false;
     };
-    pipeline.AddPass<DotDecomposer>();
     pipeline.AddPass<DepthwiseConvolutionConverter>(cost_model);
     // Expand the sort op to support stable sorting if required.
     pipeline.AddPass<StableSortExpander>();

@@ -1720,7 +1720,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
   const int64 kernel_output_features =
       rhs.dimensions(dnums.kernel_output_feature_dimension());
 
-  if (batch_group_count > 1 && kernel_output_features != batch_group_count) {
+  if (batch_group_count > 1 &&
+      kernel_output_features % batch_group_count != 0) {
     return InvalidArgument(
         "Expected output feature dimension size (value %d) to be equal to "
         "batch group count %d; got <conv>(%s, %s)\n"

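Note: shape inference now only requires the kernel output-feature count to be a multiple of batch_group_count rather than exactly equal to it, which is what admits the split kernels produced by the converter above. A tiny arithmetic check with hypothetical sizes:

    kernel_output_features, batch_group_count = 384, 48
    old_check_rejects = kernel_output_features != batch_group_count        # True: previously invalid
    new_check_rejects = kernel_output_features % batch_group_count != 0    # False: now accepted
    print(old_check_rejects, new_check_rejects)                            # True False
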
@@ -1759,7 +1760,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
         dnums.DebugString());
   }
 
-  if (input_batch % batch_group_count > 0) {
+  if (input_batch % batch_group_count != 0) {
     return InvalidArgument(
         "Expected input batch dimension (value %d) to be divisible by "
         "batch_group_count (value %d); "

@@ -1793,6 +1794,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
   std::vector<int64> dimensions(num_dims);
   dimensions[dnums.output_batch_dimension()] = input_batch / batch_group_count;
   dimensions[dnums.output_feature_dimension()] = kernel_output_features;
+
+  if (batch_group_count > 1) {
+    dimensions[dnums.output_batch_dimension()] =
+        kernel_output_features / batch_group_count;
+    dimensions[dnums.output_feature_dimension()] = batch_group_count;
+  }
+
   for (int i = 0; i < num_spatial_dims; ++i) {
     dimensions[dnums.output_spatial_dimensions(i)] =
         window_output_shape.dimensions(i);

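Note: with batch_group_count > 1, the inferred output now places kernel_output_features / batch_group_count on the batch dimension and batch_group_count on the feature dimension, overriding the default rule. A worked instance of just that arithmetic, with hypothetical sizes:

    input_batch, batch_group_count, kernel_output_features = 48, 48, 384

    out_batch = input_batch // batch_group_count                  # default rule: 1
    out_feature = kernel_output_features                          # default rule: 384
    if batch_group_count > 1:
      out_batch = kernel_output_features // batch_group_count     # 8
      out_feature = batch_group_count                             # 48
    print(out_batch, out_feature)                                 # 8 48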