Handle batch-grouped convolutions: convert backward-filter grouped convolutions

This commit is contained in:
amoitra 2020-05-11 17:21:23 -07:00
parent abaffb8ad1
commit cc4a431906
4 changed files with 84 additions and 74 deletions

View File

@ -411,6 +411,10 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
}
Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
if (is_cost_viable_(convolution)) {
return Status::OK();
}
if (convert_batch_groups_only_) {
return HandleBatchGroupCount(convolution);
}

View File

@ -190,6 +190,9 @@ Status ConvolutionVisitor::HandleBackwardFilterBatchGroupConvolution(
}
// Entry point for each convolution: rewrite it as a backward-filter
// batch-grouped convolution unless the cost model says the instruction is
// already viable as-is.
Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
  if (!is_cost_viable_(convolution)) {
    return HandleBackwardFilterBatchGroupConvolution(convolution);
  }
  // Cost model deems the convolution fine as-is; leave it untouched.
  return Status::OK();
}

View File

@ -152,17 +152,12 @@ Status GpuCompiler::OptimizeHloModule(
pipeline.AddPass<Convolution4DExpander>();
auto cost_model = [](HloInstruction* conv) {
auto operand = conv->operand(0);
return operand->shape().dimensions(conv->convolution_dimension_numbers()
.input_batch_dimension()) ==
conv->batch_group_count();
};
auto cost_model = [](HloInstruction* conv) { return true; };
pipeline.AddPass<DepthwiseConvolutionConverter>(cost_model);
// We use the ConvolutionGroupConverter to convert backprops of filter
// grouped convolutions into non-grouped equivalents.
auto batch_group_cost_model = [](HloInstruction*) { return false; };
auto batch_group_cost_model = [](HloInstruction*) { return true; };
pipeline.AddPass<ConvolutionGroupConverter>(
batch_group_cost_model,

View File

@ -64,6 +64,62 @@ HloInstruction* CreateGpuConv(const char* call_target, const Shape& shape,
return custom_call;
}
// Rewrites a batch-grouped convolution (batch_group_count == G > 1,
// feature_group_count == 1) into an equivalent feature-grouped convolution
// (feature_group_count == G, batch_group_count == 1).
//
// The LHS activations are massaged so the group dimension moves from the
// batch dimension into the feature dimension:
//   [G*N', ..., C] --reshape--> [G, N', ..., C]
//                  --transpose-> [N', ..., G, C]
//                  --reshape---> [N', ..., G*C]
// The convolution is then cloned (same output shape and window) with the
// group counts swapped. Returns the newly added convolution instruction.
HloInstruction* ConvertBatchGroupedToFeatureGroupedConvolution(
    HloInstruction* conv) {
  CHECK_EQ(conv->feature_group_count(), 1);
  int64 num_groups = conv->batch_group_count();
  auto dim_numbers = conv->convolution_dimension_numbers();
  auto lhs = conv->mutable_operand(0);
  auto rhs = conv->mutable_operand(1);

  int64 input_batch_dimension = dim_numbers.input_batch_dimension();
  int64 input_batch = lhs->shape().dimensions(input_batch_dimension);
  // The batch dimension carries the groups, so it must split evenly
  // (mirrors the identical validation done in MatchBackwardFilter).
  CHECK_EQ(input_batch % num_groups, 0)
      << "Input batch should be an exact multiple of batch group count";

  Shape output_shape = conv->shape();
  int64 input_feature_dimension = dim_numbers.input_feature_dimension();
  int64 input_feature = lhs->shape().dimensions(input_feature_dimension);

  HloComputation* computation = lhs->parent();
  // Convenience wrapper: add an instruction to the enclosing computation.
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    return computation->AddInstruction(std::move(inst));
  };

  // Reshape batch_dim N -> [G, N/G].
  std::vector<int64> reshape_dims = SpanToVector(lhs->shape().dimensions());
  reshape_dims[input_batch_dimension] =
      reshape_dims[input_batch_dimension] / num_groups;
  reshape_dims.insert(reshape_dims.begin() + input_batch_dimension,
                      num_groups);
  lhs = add(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims), lhs));

  // Transpose G to the axis before C, e.g. [G, N/G, H, W, C] -> [N/G, H, W,
  // G, C].
  // NOTE(review): the erase/insert arithmetic assumes
  // input_feature_dimension > input_batch_dimension (true for NCHW/NHWC) —
  // confirm before relying on this for exotic layouts.
  std::vector<int64> transpose_dims(lhs->shape().dimensions_size());
  std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
  transpose_dims.erase(transpose_dims.begin() + input_batch_dimension);
  transpose_dims.insert(transpose_dims.begin() + input_feature_dimension,
                        input_batch_dimension);
  std::vector<int64> transpose_reshape_dims =
      ComposePermutations(lhs->shape().dimensions(), transpose_dims);
  lhs = add(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(lhs->shape().element_type(),
                           transpose_reshape_dims),
      lhs, transpose_dims));

  // Merge [G, C] -> [G*C]: delete the group dim, then widen the feature dim
  // (which has shifted into the deleted slot) by the group count.
  Shape new_shape = lhs->shape();
  new_shape.DeleteDimension(input_feature_dimension);
  new_shape.set_dimensions(input_feature_dimension,
                           input_feature * num_groups);
  lhs = add(HloInstruction::CreateReshape(new_shape, lhs));

  // Clone the convolution with the reshaped LHS and swapped group counts.
  std::vector<HloInstruction*> new_operands = {lhs, rhs};
  auto new_conv = conv->CloneWithNewOperands(output_shape, new_operands);
  new_conv->set_feature_group_count(num_groups);
  new_conv->set_batch_group_count(1);
  new_conv->set_convolution_dimension_numbers(dim_numbers);
  return computation->AddInstruction(std::move(new_conv));
}
bool CanImplementAsGpuForwardConv(HloInstruction* conv) {
const ConvolutionDimensionNumbers& dnums =
conv->convolution_dimension_numbers();
@ -91,9 +147,19 @@ bool CanImplementAsGpuForwardConv(HloInstruction* conv) {
// Precondition: "conv" is a kConvolution.
std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
MatchBackwardFilter(HloInstruction* conv) {
VLOG(2) << "Trying to match convolution backward filter.";
const auto no_match_result =
std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
if (conv->feature_group_count() > 1) {
VLOG(1) << conv->ToString()
<< " is a forward convolution. All grouped backward filters are "
"mapped to batch grouped convolutions in tf2xla bridge. Hence backward filter "
"convolutions cannot have feature groups greater than 1 at this "
"point. No need to fold to backward filter.";
return no_match_result;
}
// Step 1: match the instruction pattern without considering the paddings and
// dimension numbers just yet. We may need some generic pattern matcher
// similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h
@ -122,7 +188,6 @@ MatchBackwardFilter(HloInstruction* conv) {
auto output_batch_dim = conv_dnums.output_batch_dimension();
auto output_feature_dim = conv_dnums.output_feature_dimension();
auto output_spatial_dims = conv_dnums.output_spatial_dimensions();
for (const WindowDimension& window_dim : conv->window().dimensions()) {
if (window_dim.stride() != 1) {
VLOG(1) << "Forward convolution's window "
@ -150,16 +215,7 @@ MatchBackwardFilter(HloInstruction* conv) {
!window_util::HasWindowDilation(conv->window())) {
VLOG(1) << conv->ToString()
<< " is a regular forward convolution. No need "
"to fold it to a backward filter convolution.";
return no_match_result;
}
auto rhs_in =
conv->mutable_operand(1)->shape().dimensions(kernel_input_feature_dim);
if (conv->feature_group_count() > 1 && rhs_in == 1 &&
input_batch_dim == output_batch_dim) {
VLOG(1) << conv->ToString()
<< " is a depthwise forward convolution. No need to fold to "
"backward filter.";
"to fold it to a backward filter convolution....";
return no_match_result;
}
@ -256,67 +312,14 @@ MatchBackwardFilter(HloInstruction* conv) {
}
HloInstruction* lhs = conv->mutable_operand(0);
if (conv->feature_group_count() == 1) {
return std::make_tuple(true, backward_conv_window, backward_conv_dnums,
lhs);
}
int64 input_batch_dimension = backward_conv_dnums.input_batch_dimension();
int64 input_feature_dimension = backward_conv_dnums.input_feature_dimension();
int64 input_batch = lhs->shape().dimensions(input_batch_dimension);
int64 input_feature = lhs->shape().dimensions(input_feature_dimension);
// Reshape batch_dim G*N -> [G,N]
std::vector<int64> reshape_dims = SpanToVector(lhs->shape().dimensions());
auto num_groups = conv->feature_group_count();
CHECK_EQ(input_batch % num_groups, 0)
<< "Input batch should be an exact multiple of feature group count";
reshape_dims[input_batch_dimension] =
reshape_dims[input_batch_dimension] / num_groups;
reshape_dims.insert(reshape_dims.begin() + input_batch_dimension, num_groups);
HloComputation* c = conv->parent();
HloInstruction* lhs_reshape_1 =
c->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims),
lhs));
// Transpose G to the axis before C/G, For eg: [G, N, C/G, H, W] -> [N, G,
// C/G, H, W]
std::vector<int64> transpose_dims(lhs_reshape_1->shape().dimensions_size());
std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
transpose_dims.erase(transpose_dims.begin() + input_batch_dimension);
transpose_dims.insert(transpose_dims.begin() + input_feature_dimension,
input_batch_dimension);
std::vector<int64> transpose_reshape_dims =
SpanToVector(lhs_reshape_1->shape().dimensions());
transpose_reshape_dims.erase(transpose_reshape_dims.begin() +
input_batch_dimension);
transpose_reshape_dims.insert(
transpose_reshape_dims.begin() + input_feature_dimension, num_groups);
HloInstruction* lhs_transpose =
c->AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(lhs_reshape_1->shape().element_type(),
transpose_reshape_dims),
lhs_reshape_1, transpose_dims));
// Merge [G,C/G] -> [C]
Shape new_shape = lhs_transpose->shape();
new_shape.DeleteDimension(input_feature_dimension);
new_shape.set_dimensions(input_feature_dimension,
input_feature * conv->feature_group_count());
HloInstruction* lhs_reshape_2 = c->AddInstruction(
HloInstruction::CreateReshape(new_shape, lhs_transpose));
return std::make_tuple(true, backward_conv_window, backward_conv_dnums,
lhs_reshape_2);
return std::make_tuple(true, backward_conv_window, backward_conv_dnums, lhs);
}
// Try to match a backward input pattern that contains "conv".
// Precondition: "conv" is a kConvolution.
std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
MatchBackwardInput(HloInstruction* conv) {
VLOG(2) << "Trying to match convolution backward input.";
const auto no_match_result =
std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
@ -639,9 +642,12 @@ static StatusOr<HloInstruction*> CreateCustomCallHelper(HloInstruction* conv) {
if (match) {
return CreateGpuConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), lhs,
conv->mutable_operand(1), window, dnums,
conv->feature_group_count(), conv->metadata());
conv->batch_group_count(), conv->metadata());
}
if (conv->batch_group_count() > 1) {
conv = ConvertBatchGroupedToFeatureGroupedConvolution(conv);
}
// If all else fails, try a forward convolution.
if (CanImplementAsGpuForwardConv(conv)) {
if (primitive_util::IsIntegralType(
@ -736,11 +742,13 @@ StatusOr<bool> RunOnComputation(HloComputation* computation) {
} // namespace
// Runs the conv rewriter over every non-fusion computation in the module and
// reports whether anything was rewritten.
StatusOr<bool> GpuConvRewriter::Run(HloModule* module) {
  XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), before:\n" + module->ToString());

  bool any_changed = false;
  for (HloComputation* comp : module->MakeNonfusionComputations()) {
    TF_ASSIGN_OR_RETURN(bool comp_changed, RunOnComputation(comp));
    any_changed = any_changed || comp_changed;
  }

  XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), after:\n" + module->ToString());
  return any_changed;
}