Merge pull request #30370 from AyanmoI:amoitra/use_cudnn_backprop_filter_grouped_conv

PiperOrigin-RevId: 259548643
TensorFlower Gardener 2019-07-23 09:17:32 -07:00
commit bb27ef8604
2 changed files with 153 additions and 11 deletions


@@ -89,13 +89,11 @@ bool CanImplementAsCudnnForwardConv(HloInstruction* conv) {
// Try to match a backward filter pattern that contains "conv".
// Precondition: "conv" is a kConvolution.
-std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
-    HloInstruction* conv) {
+std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
+MatchBackwardFilter(HloInstruction* conv) {
   const auto no_match_result =
-      std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
-  if (conv->feature_group_count() > 1) {
-    return no_match_result;
-  }
+      std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
// Step 1: match the instruction pattern without considering the paddings and
// dimension numbers just yet. We may need some generic pattern matcher
// similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h
@@ -155,6 +153,15 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
"to fold it to a backward filter convolution.";
return no_match_result;
}
auto rhs_in =
conv->mutable_operand(1)->shape().dimensions(kernel_input_feature_dim);
if (conv->feature_group_count() > 1 && rhs_in == 1 &&
input_batch_dim == output_batch_dim) {
VLOG(1) << conv->ToString()
<< " is a depthwise forward convolution. No need to fold to "
"backward filter.";
return no_match_result;
}
// Step 3: fuse the matched HLOs into a backward convolution instruction.
//
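The new bail-out above skips grouped convolutions that are really depthwise forward convolutions (the kernel sees a single input feature per group and the batch dimensions already line up), since there is nothing to fold into a backward-filter call. A minimal standalone sketch of that predicate on plain integers — the function and parameter names here are illustrative stand-ins, not XLA's HloInstruction API:

```cpp
// Sketch of the depthwise bail-out condition; mirrors the check on
// feature_group_count, rhs_in, and the batch dimensions in the patch above.
#include <cstdint>
#include <iostream>

bool IsDepthwiseForwardConv(int64_t feature_group_count,
                            int64_t kernel_input_features,  // rhs extent at kernel_input_feature_dim
                            int64_t input_batch_dim, int64_t output_batch_dim) {
  // A grouped conv whose kernel has one input feature per group and whose
  // batch dimensions already agree is a plain depthwise forward convolution.
  return feature_group_count > 1 && kernel_input_features == 1 &&
         input_batch_dim == output_batch_dim;
}

int main() {
  // e.g. a MobileNet-style depthwise conv: 16 groups, one input feature per
  // group; the matcher returns no_match_result for it.
  std::cout << std::boolalpha
            << IsDepthwiseForwardConv(/*feature_group_count=*/16,
                                      /*kernel_input_features=*/1,
                                      /*input_batch_dim=*/0,
                                      /*output_batch_dim=*/0)
            << "\n";  // prints: true
}
```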
@@ -248,7 +255,62 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]);
}
-  return std::make_tuple(true, backward_conv_window, backward_conv_dnums);
HloInstruction* lhs = conv->mutable_operand(0);
if (conv->feature_group_count() == 1) {
return std::make_tuple(true, backward_conv_window, backward_conv_dnums,
lhs);
}
int64 input_batch_dimension = backward_conv_dnums.input_batch_dimension();
int64 input_feature_dimension = backward_conv_dnums.input_feature_dimension();
int64 input_batch = lhs->shape().dimensions(input_batch_dimension);
int64 input_feature = lhs->shape().dimensions(input_feature_dimension);
// Reshape batch_dim G*N -> [G,N]
std::vector<int64> reshape_dims = lhs->shape().dimensions();
auto num_groups = conv->feature_group_count();
CHECK_EQ(input_batch % num_groups, 0)
<< "Input batch should be an exact multiple of feature group count";
reshape_dims[input_batch_dimension] =
reshape_dims[input_batch_dimension] / num_groups;
reshape_dims.insert(reshape_dims.begin() + input_batch_dimension, num_groups);
HloComputation* c = conv->parent();
HloInstruction* lhs_reshape_1 =
c->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims),
lhs));
// Transpose G to the axis before C/G, e.g. [G, N, C/G, H, W] -> [N, G,
// C/G, H, W].
std::vector<int64> transpose_dims(lhs_reshape_1->shape().dimensions_size());
std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
transpose_dims.erase(transpose_dims.begin() + input_batch_dimension);
transpose_dims.insert(transpose_dims.begin() + input_feature_dimension,
input_batch_dimension);
std::vector<int64> transpose_reshape_dims =
lhs_reshape_1->shape().dimensions();
transpose_reshape_dims.erase(transpose_reshape_dims.begin() +
input_batch_dimension);
transpose_reshape_dims.insert(
transpose_reshape_dims.begin() + input_feature_dimension, num_groups);
HloInstruction* lhs_transpose =
c->AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(lhs_reshape_1->shape().element_type(),
transpose_reshape_dims),
lhs_reshape_1, transpose_dims));
// Merge [G,C/G] -> [C]
Shape new_shape = lhs_transpose->shape();
new_shape.DeleteDimension(input_feature_dimension);
new_shape.set_dimensions(input_feature_dimension,
input_feature * conv->feature_group_count());
HloInstruction* lhs_reshape_2 = c->AddInstruction(
HloInstruction::CreateReshape(new_shape, lhs_transpose));
return std::make_tuple(true, backward_conv_window, backward_conv_dnums,
lhs_reshape_2);
}
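To see what the reshape -> transpose -> reshape above does to the activation shape, here is a dimension-only sketch on plain int64_t vectors (no XLA types; RewriteLhsDims and its parameters are hypothetical names standing in for the backward conv dnums). It uses the NHWC case from the new test below: feature_group_count G = 4 and an lhs of shape [G*N, H, W, C/G] = [32, 1, 3, 4] with batch dim 0 and feature dim 3:

```cpp
// Sketch of the lhs rewrite for grouped backward-filter convs: split the G*N
// batch dim into [G, N], move G next to the C/G feature dim, then merge
// [G, C/G] back into a single C dim.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> RewriteLhsDims(std::vector<int64_t> dims, int64_t batch_dim,
                                    int64_t feature_dim, int64_t num_groups) {
  // Step 1: reshape the batch dim G*N -> [G, N].
  dims[batch_dim] /= num_groups;
  dims.insert(dims.begin() + batch_dim, num_groups);
  // Step 2: transpose G to the axis just before C/G.
  dims.erase(dims.begin() + batch_dim);                 // drop G from the front ...
  dims.insert(dims.begin() + feature_dim, num_groups);  // ... reinsert it before C/G
  // Step 3: merge [G, C/G] -> [C].
  int64_t features_per_group = dims[feature_dim + 1];
  dims.erase(dims.begin() + feature_dim);
  dims[feature_dim] = num_groups * features_per_group;
  return dims;
}

int main() {
  // NHWC activations from the test below: [G*N, H, W, C/G] = [32, 1, 3, 4].
  for (int64_t d : RewriteLhsDims({32, 1, 3, 4}, /*batch_dim=*/0,
                                  /*feature_dim=*/3, /*num_groups=*/4)) {
    std::cout << d << " ";  // prints: 8 1 3 16, i.e. [N, H, W, C]
  }
  std::cout << "\n";
}
```

The end result is an activation tensor in the ordinary [N, H, W, C] form that cudnn's grouped backward-filter path expects, which is why the rewriter can now pass lhs_reshape_2 instead of the raw operand.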
// Try to match a backward input pattern that contains "conv".
@@ -503,6 +565,7 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
Window window;
ConvolutionDimensionNumbers dnums;
HloInstruction* rhs;
HloInstruction* lhs;
std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv);
if (match) {
@@ -511,12 +574,11 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
conv->feature_group_count(), conv->metadata());
}
-  std::tie(match, window, dnums) = MatchBackwardFilter(conv);
+  std::tie(match, window, dnums, lhs) = MatchBackwardFilter(conv);
   if (match) {
     return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(),
-                           conv->mutable_operand(0), conv->mutable_operand(1),
-                           window, dnums, conv->feature_group_count(),
-                           conv->metadata());
+                           lhs, conv->mutable_operand(1), window, dnums,
+                           conv->feature_group_count(), conv->metadata());
}
// If all else fails, try a forward convolution.


@@ -135,6 +135,86 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolve) {
<< md_after_opt.DebugString() << " vs " << metadata.DebugString();
}
TEST_F(CudnnConvRewriterTest, BackwardFilterGroupConvolve) {
// In a nutshell, before pass:
// Input->batch_dim: 3 input_shape(3) = 4
// Input->feature_dim: 0 input_shape(0) = 32
// Kernel(gradient)->kernel_input_feature_dim (gradient_batch_dimension): 0
// Kernel(gradient)->kernel_output_feature_dim (gradient_feature_dimension): 3
// Output(dkernel)->output_batch_dim (dkernel_input_feature_dim): 2
// Output(dkernel)->output_feature_dim (dkernel_output_feature_dim): 3
// After the pass, all shapes and the dimension layout are brought
// back to the normal form that cudnn accepts:
// Input->batch_dim: 0 input_shape(0) = 8
// Input->feature_dim: 3 input_shape(3) = 16
// Kernel(gradient)->kernel_input_feature_dim (gradient_batch_dimension): 2
// Kernel(gradient)->kernel_output_feature_dim (gradient_feature_dimension): 3
// Output(dkernel)->output_batch_dim (dkernel_input_feature_dim): 0
// Output(dkernel)->output_feature_dim (dkernel_output_feature_dim): 3
HloComputation::Builder builder(TestName());
HloInstruction* activations =
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {32, 1, 3, 4}), "activations"));
HloInstruction* gradients =
builder.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(F32, {8, 1, 2, 16}), "gradients"));
Window conv_window = default_conv_window_;
conv_window.mutable_dimensions(1)->set_size(2);
conv_window.mutable_dimensions(1)->set_window_dilation(2);
auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeInference::InferConvolveShape(
activations->shape(), gradients->shape(), /*feature_group_count=*/4,
/*batch_group_count=*/1, conv_window,
tf_default_dnums_for_backward_filter_)
.ConsumeValueOrDie(),
activations, gradients, /*feature_group_count=*/4,
/*batch_group_count=*/1, conv_window,
tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
OpMetadata metadata;
metadata.set_op_name("bar");
conv->set_metadata(metadata);
auto module = CreateNewVerifiedModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
EXPECT_TRUE(RunPass(module.get()));
ASSERT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
// Check that metadata was preserved.
const auto& md_after_opt =
entry_computation->root_instruction()->operand(0)->metadata();
EXPECT_TRUE(protobuf_util::ProtobufEquals(md_after_opt, metadata))
<< md_after_opt.DebugString() << " vs " << metadata.DebugString();
const HloInstruction* custom_call =
entry_computation->root_instruction()->operand(0);
const ConvolutionDimensionNumbers conv_dim =
custom_call->convolution_dimension_numbers();
const auto lhs_a = custom_call->operand(0);
const auto input_shape = lhs_a->shape();
// The input (lhs) batch_dim (dim 0 in the original NHWC layout) gets mapped
// to the feature_dim (dim 3) with a value of N*g = 32 in tf2xla. As described
// in conv_grad_ops.h, this swap is required to implement backprop using a
// forward conv. After the pass the batch_dim gets remapped to dim 0, and its
// value gets scaled back to N = N*g/g = 32/4 = 8 to be compatible with cudnn.
EXPECT_EQ(0, conv_dim.input_batch_dimension());
EXPECT_EQ(8, input_shape.dimensions(conv_dim.input_batch_dimension()));
// Similarly, the input (lhs) feature_dim (dim 3 in the original NHWC layout)
// gets mapped to the batch_dim (dim 0) with a value of C/g = 4 in tf2xla.
// After the pass the feature_dim gets remapped to dim 3, and its value gets
// scaled back to C = (C/g)*g = 4*4 = 16 to be compatible with cudnn.
EXPECT_EQ(3, conv_dim.input_feature_dimension());
EXPECT_EQ(16, input_shape.dimensions(conv_dim.input_feature_dimension()));
// Similarly, the feature and batch dims of the incoming gradients (used as
// the rhs) and the in/out dims of the convolution output, i.e. dkernel, were
// modified in tf2xla (as described in conv_grad_ops.h). This pass remaps
// everything back so the layout is compatible with the cudnn backprop APIs.
EXPECT_EQ(2, conv_dim.kernel_input_feature_dimension());
EXPECT_EQ(3, conv_dim.kernel_output_feature_dimension());
EXPECT_EQ(0, conv_dim.output_batch_dimension());
EXPECT_EQ(3, conv_dim.output_feature_dimension());
}
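The shape arithmetic the comments above describe can be sanity-checked in isolation. A back-of-the-envelope sketch (plain C++, values taken from this test, not part of the test itself):

```cpp
// Recompute the post-pass lhs extents the test expects from the tf2xla-emitted
// ones: batch N*g = 32 and feature C/g = 4 with feature_group_count g = 4.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t g = 4;           // feature_group_count
  const int64_t n_times_g = 32;  // lhs batch extent as emitted by tf2xla
  const int64_t c_over_g = 4;    // lhs feature extent as emitted by tf2xla
  assert(n_times_g / g == 8);    // EXPECT_EQ(8, ...) on the batch dimension
  assert(c_over_g * g == 16);    // EXPECT_EQ(16, ...) on the feature dimension
  return 0;
}
```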
TEST_F(CudnnConvRewriterTest,
BackwardFilterConvolveEquivalentToForwardConvolution) {
HloComputation::Builder builder(TestName());