From cc4a4319066374fec1bbb79238bff1076a79534c Mon Sep 17 00:00:00 2001 From: amoitra Date: Mon, 11 May 2020 17:21:23 -0700 Subject: [PATCH] Handle batch group conv - backward filter grouped conv --- .../service/convolution_group_converter.cc | 4 + .../depthwise_convolution_converter.cc | 3 + .../compiler/xla/service/gpu/gpu_compiler.cc | 9 +- .../xla/service/gpu/gpu_conv_rewriter.cc | 142 +++++++++--------- 4 files changed, 84 insertions(+), 74 deletions(-) diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc index 323bf44dcd3..254d2ea2429 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc @@ -411,6 +411,10 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { } Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { + if (is_cost_viable_(convolution)) { + return Status::OK(); + } + if (convert_batch_groups_only_) { return HandleBatchGroupCount(convolution); } diff --git a/tensorflow/compiler/xla/service/depthwise_convolution_converter.cc b/tensorflow/compiler/xla/service/depthwise_convolution_converter.cc index ad4d8118835..7d72707380a 100644 --- a/tensorflow/compiler/xla/service/depthwise_convolution_converter.cc +++ b/tensorflow/compiler/xla/service/depthwise_convolution_converter.cc @@ -190,6 +190,9 @@ Status ConvolutionVisitor::HandleBackwardFilterBatchGroupConvolution( } Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { + if (is_cost_viable_(convolution)) { + return Status::OK(); + } return HandleBackwardFilterBatchGroupConvolution(convolution); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 5f6dfd7d3a5..0b587a762a6 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ 
b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -152,17 +152,12 @@ Status GpuCompiler::OptimizeHloModule(
 
     pipeline.AddPass();
 
-    auto cost_model = [](HloInstruction* conv) {
-      auto operand = conv->operand(0);
-      return operand->shape().dimensions(conv->convolution_dimension_numbers()
-                                             .input_batch_dimension()) ==
-             conv->batch_group_count();
-    };
+    auto cost_model = [](HloInstruction* conv) { return true; };
     pipeline.AddPass(cost_model);
 
     // We use the ConvolutionGroupConverter to convert backprops of filter
     // grouped convolutions into non-grouped equivalents.
-    auto batch_group_cost_model = [](HloInstruction*) { return false; };
+    auto batch_group_cost_model = [](HloInstruction*) { return true; };
 
     pipeline.AddPass(
         batch_group_cost_model,
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc
index 4a4448f668c..3856f1cd150 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc
@@ -64,6 +64,72 @@ HloInstruction* CreateGpuConv(const char* call_target, const Shape& shape,
   return custom_call;
 }
 
+// Rewrites a batch-grouped convolution (batch_group_count() == G) into an
+// equivalent feature-grouped convolution: the LHS batch dimension N is split
+// into [G, N/G] and G is merged into the feature dimension, so N -> N/G and
+// C -> C*G. The rewritten convolution reuses conv's output shape, RHS and
+// dimension numbers, with feature_group_count = G and batch_group_count = 1.
+//
+// Precondition: conv->feature_group_count() == 1 and the LHS batch size is
+// divisible by G (both CHECKed). Returns the new convolution, which is added
+// to the LHS operand's computation; conv itself is not removed here.
+HloInstruction* ConvertBatchGroupedToFeatureGroupedConvolution(
+    HloInstruction* conv) {
+  CHECK_EQ(conv->feature_group_count(), 1);
+  const int64 num_groups = conv->batch_group_count();
+  const auto& dim_numbers = conv->convolution_dimension_numbers();
+  HloInstruction* lhs = conv->mutable_operand(0);
+  HloInstruction* rhs = conv->mutable_operand(1);
+
+  const int64 input_batch_dimension = dim_numbers.input_batch_dimension();
+  const int64 input_batch = lhs->shape().dimensions(input_batch_dimension);
+  // The N -> [G, N/G] reshape below uses integer division; fail loudly rather
+  // than silently truncating the batch.
+  CHECK_EQ(input_batch % num_groups, 0)
+      << "Input batch should be an exact multiple of batch group count";
+
+  const int64 input_feature_dimension = dim_numbers.input_feature_dimension();
+  const int64 input_feature = lhs->shape().dimensions(input_feature_dimension);
+
+  HloComputation* computation = lhs->parent();
+  auto add = [&](std::unique_ptr<HloInstruction> inst) {
+    return computation->AddInstruction(std::move(inst));
+  };
+
+  // Reshape batch_dim N -> [G, N/G]
+  std::vector<int64> reshape_dims = SpanToVector(lhs->shape().dimensions());
+  reshape_dims[input_batch_dimension] = input_batch / num_groups;
+  reshape_dims.insert(reshape_dims.begin() + input_batch_dimension, num_groups);
+  lhs = add(HloInstruction::CreateReshape(
+      ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims), lhs));
+
+  // Transpose G to the axis before C, e.g. [G, N/G, H, W, C] ->
+  // [N/G, H, W, G, C]
+  std::vector<int64> transpose_dims(lhs->shape().dimensions_size());
+  std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
+  transpose_dims.erase(transpose_dims.begin() + input_batch_dimension);
+  transpose_dims.insert(transpose_dims.begin() + input_feature_dimension,
+                        input_batch_dimension);
+  std::vector<int64> transpose_reshape_dims =
+      ComposePermutations(lhs->shape().dimensions(), transpose_dims);
+  lhs = add(HloInstruction::CreateTranspose(
+      ShapeUtil::MakeShape(lhs->shape().element_type(), transpose_reshape_dims),
+      lhs, transpose_dims));
+
+  // Merge [G,C] -> [C*G]
+  Shape new_shape = lhs->shape();
+  new_shape.DeleteDimension(input_feature_dimension);
+  new_shape.set_dimensions(input_feature_dimension, input_feature * num_groups);
+  lhs = add(HloInstruction::CreateReshape(new_shape, lhs));
+
+  std::vector<HloInstruction*> new_operands = {lhs, rhs};
+  auto new_conv = conv->CloneWithNewOperands(conv->shape(), new_operands);
+  new_conv->set_feature_group_count(num_groups);
+  new_conv->set_batch_group_count(1);
+  new_conv->set_convolution_dimension_numbers(dim_numbers);
+  return computation->AddInstruction(std::move(new_conv));
+}
+
 bool CanImplementAsGpuForwardConv(HloInstruction* conv) {
   const ConvolutionDimensionNumbers& dnums =
       conv->convolution_dimension_numbers();
@@ -91,9 +157,19 @@ bool CanImplementAsGpuForwardConv(HloInstruction* conv) {
 // Precondition: "conv" is a kConvolution.
std::tuple MatchBackwardFilter(HloInstruction* conv) { + VLOG(2) << "Trying to match convolution backward filter."; const auto no_match_result = std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr); + if (conv->feature_group_count() > 1) { + VLOG(1) << conv->ToString() + << " is a forward convolution. All grouped backward filters are " + "mapped to batch grouped convolutions in tf2xla bridge. Hence backward filter " + "convolutions cannot have feature groups greater than 1 at this " + "point. No need to fold to backward filter."; + return no_match_result; + } + // Step 1: match the instruction pattern without considering the paddings and // dimension numbers just yet. We may need some generic pattern matcher // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h @@ -122,7 +188,6 @@ MatchBackwardFilter(HloInstruction* conv) { auto output_batch_dim = conv_dnums.output_batch_dimension(); auto output_feature_dim = conv_dnums.output_feature_dimension(); auto output_spatial_dims = conv_dnums.output_spatial_dimensions(); - for (const WindowDimension& window_dim : conv->window().dimensions()) { if (window_dim.stride() != 1) { VLOG(1) << "Forward convolution's window " @@ -150,16 +215,7 @@ MatchBackwardFilter(HloInstruction* conv) { !window_util::HasWindowDilation(conv->window())) { VLOG(1) << conv->ToString() << " is a regular forward convolution. No need " - "to fold it to a backward filter convolution."; - return no_match_result; - } - auto rhs_in = - conv->mutable_operand(1)->shape().dimensions(kernel_input_feature_dim); - if (conv->feature_group_count() > 1 && rhs_in == 1 && - input_batch_dim == output_batch_dim) { - VLOG(1) << conv->ToString() - << " is a depthwise forward convolution. 
No need to fold to " - "backward filter."; + "to fold it to a backward filter convolution...."; return no_match_result; } @@ -256,67 +312,14 @@ MatchBackwardFilter(HloInstruction* conv) { } HloInstruction* lhs = conv->mutable_operand(0); - if (conv->feature_group_count() == 1) { - return std::make_tuple(true, backward_conv_window, backward_conv_dnums, - lhs); - } - - int64 input_batch_dimension = backward_conv_dnums.input_batch_dimension(); - int64 input_feature_dimension = backward_conv_dnums.input_feature_dimension(); - - int64 input_batch = lhs->shape().dimensions(input_batch_dimension); - int64 input_feature = lhs->shape().dimensions(input_feature_dimension); - - // Reshape batch_dim G*N -> [G,N] - std::vector reshape_dims = SpanToVector(lhs->shape().dimensions()); - auto num_groups = conv->feature_group_count(); - CHECK_EQ(input_batch % num_groups, 0) - << "Input batch should be an exact multiple of feature group count"; - reshape_dims[input_batch_dimension] = - reshape_dims[input_batch_dimension] / num_groups; - reshape_dims.insert(reshape_dims.begin() + input_batch_dimension, num_groups); - - HloComputation* c = conv->parent(); - HloInstruction* lhs_reshape_1 = - c->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims), - lhs)); - - // Transpose G to the axis before C/G, For eg: [G, N, C/G, H, W] -> [N, G, - // C/G, H, W] - std::vector transpose_dims(lhs_reshape_1->shape().dimensions_size()); - std::iota(transpose_dims.begin(), transpose_dims.end(), 0); - transpose_dims.erase(transpose_dims.begin() + input_batch_dimension); - transpose_dims.insert(transpose_dims.begin() + input_feature_dimension, - input_batch_dimension); - std::vector transpose_reshape_dims = - SpanToVector(lhs_reshape_1->shape().dimensions()); - transpose_reshape_dims.erase(transpose_reshape_dims.begin() + - input_batch_dimension); - transpose_reshape_dims.insert( - transpose_reshape_dims.begin() + input_feature_dimension, 
num_groups); - - HloInstruction* lhs_transpose = - c->AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(lhs_reshape_1->shape().element_type(), - transpose_reshape_dims), - lhs_reshape_1, transpose_dims)); - - // Merge [G,C/G] -> [C] - Shape new_shape = lhs_transpose->shape(); - new_shape.DeleteDimension(input_feature_dimension); - new_shape.set_dimensions(input_feature_dimension, - input_feature * conv->feature_group_count()); - HloInstruction* lhs_reshape_2 = c->AddInstruction( - HloInstruction::CreateReshape(new_shape, lhs_transpose)); - return std::make_tuple(true, backward_conv_window, backward_conv_dnums, - lhs_reshape_2); + return std::make_tuple(true, backward_conv_window, backward_conv_dnums, lhs); } // Try to match a backward input pattern that contains "conv". // Precondition: "conv" is a kConvolution. std::tuple MatchBackwardInput(HloInstruction* conv) { + VLOG(2) << "Trying to match convolution backward input."; const auto no_match_result = std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr); @@ -639,9 +642,12 @@ static StatusOr CreateCustomCallHelper(HloInstruction* conv) { if (match) { return CreateGpuConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), lhs, conv->mutable_operand(1), window, dnums, - conv->feature_group_count(), conv->metadata()); + conv->batch_group_count(), conv->metadata()); } + if (conv->batch_group_count() > 1) { + conv = ConvertBatchGroupedToFeatureGroupedConvolution(conv); + } // If all else fails, try a forward convolution. 
if (CanImplementAsGpuForwardConv(conv)) { if (primitive_util::IsIntegralType( @@ -736,11 +742,13 @@ StatusOr RunOnComputation(HloComputation* computation) { } // namespace StatusOr GpuConvRewriter::Run(HloModule* module) { + XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), before:\n" + module->ToString()); bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations()) { TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); changed |= result; } + XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), after:\n" + module->ToString()); return changed; }