Merge pull request #30370 from AyanmoI:amoitra/use_cudnn_backprop_filter_grouped_conv

PiperOrigin-RevId: 259548643

Commit bb27ef8604

tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc | 84 | Normal file → Executable file
@@ -89,13 +89,11 @@ bool CanImplementAsCudnnForwardConv(HloInstruction* conv) {
 // Try to match a backward filter pattern that contains "conv".
 // Precondition: "conv" is a kConvolution.
-std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
-    HloInstruction* conv) {
+std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
+MatchBackwardFilter(HloInstruction* conv) {
   const auto no_match_result =
-      std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
-  if (conv->feature_group_count() > 1) {
-    return no_match_result;
-  }
+      std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
 
   // Step 1: match the instruction pattern without considering the paddings and
   // dimension numbers just yet. We may need some generic pattern matcher
   // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h
@@ -155,6 +153,15 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
                "to fold it to a backward filter convolution.";
     return no_match_result;
   }
+  auto rhs_in =
+      conv->mutable_operand(1)->shape().dimensions(kernel_input_feature_dim);
+  if (conv->feature_group_count() > 1 && rhs_in == 1 &&
+      input_batch_dim == output_batch_dim) {
+    VLOG(1) << conv->ToString()
+            << " is a depthwise forward convolution. No need to fold to "
+               "backward filter.";
+    return no_match_result;
+  }
 
   // Step 3: fuse the matched HLOs into a backward convolution instruction.
   //
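For readers less familiar with the rewriter, here is a minimal standalone sketch (not part of the commit; the helper name is illustrative, and the example values are read off the BackwardFilterGroupConvolve test further down) of the depthwise early-out added above. The idea is that a convolution with feature_group_count > 1 whose kernel has an input-feature extent of 1 and whose input and output batch dimensions coincide is an ordinary depthwise forward convolution, so it should not be folded into a backward-filter call.

#include <cstdint>
#include <iostream>

// Mirrors the new check in MatchBackwardFilter: all three conditions must
// hold for the convolution to be treated as a plain depthwise forward
// convolution and skipped by the backward-filter rewrite.
bool LooksLikeDepthwiseForwardConv(int64_t feature_group_count,
                                   int64_t rhs_input_features,
                                   int64_t input_batch_dim,
                                   int64_t output_batch_dim) {
  return feature_group_count > 1 &&            // grouped convolution
         rhs_input_features == 1 &&            // kernel input-feature extent 1
         input_batch_dim == output_batch_dim;  // batch dim not repurposed
}

int main() {
  // A depthwise forward conv: 32 groups, kernel input-feature extent 1,
  // batch dimension 0 on both input and output -> true, leave it alone.
  std::cout << LooksLikeDepthwiseForwardConv(32, 1, 0, 0) << "\n";
  // The grouped backward-filter pattern from the test below: 4 groups,
  // kernel input-feature extent 8, batch dims 3 vs 2 -> false, fold it.
  std::cout << LooksLikeDepthwiseForwardConv(4, 8, 3, 2) << "\n";
  return 0;
}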
@@ -248,7 +255,62 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
     backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]);
   }
 
-  return std::make_tuple(true, backward_conv_window, backward_conv_dnums);
+  HloInstruction* lhs = conv->mutable_operand(0);
+  if (conv->feature_group_count() == 1) {
+    return std::make_tuple(true, backward_conv_window, backward_conv_dnums,
+                           lhs);
+  }
+
+  int64 input_batch_dimension = backward_conv_dnums.input_batch_dimension();
+  int64 input_feature_dimension = backward_conv_dnums.input_feature_dimension();
+
+  int64 input_batch = lhs->shape().dimensions(input_batch_dimension);
+  int64 input_feature = lhs->shape().dimensions(input_feature_dimension);
+
+  // Reshape batch_dim G*N -> [G,N]
+  std::vector<int64> reshape_dims = lhs->shape().dimensions();
+  auto num_groups = conv->feature_group_count();
+  CHECK_EQ(input_batch % num_groups, 0)
+      << "Input batch should be an exact multiple of feature group count";
+  reshape_dims[input_batch_dimension] =
+      reshape_dims[input_batch_dimension] / num_groups;
+  reshape_dims.insert(reshape_dims.begin() + input_batch_dimension, num_groups);
+
+  HloComputation* c = conv->parent();
+  HloInstruction* lhs_reshape_1 =
+      c->AddInstruction(HloInstruction::CreateReshape(
+          ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims),
+          lhs));
+
+  // Transpose G to the axis before C/G, For eg: [G, N, C/G, H, W] -> [N, G,
+  // C/G, H, W]
+  std::vector<int64> transpose_dims(lhs_reshape_1->shape().dimensions_size());
+  std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
+  transpose_dims.erase(transpose_dims.begin() + input_batch_dimension);
+  transpose_dims.insert(transpose_dims.begin() + input_feature_dimension,
+                        input_batch_dimension);
+  std::vector<int64> transpose_reshape_dims =
+      lhs_reshape_1->shape().dimensions();
+  transpose_reshape_dims.erase(transpose_reshape_dims.begin() +
+                               input_batch_dimension);
+  transpose_reshape_dims.insert(
+      transpose_reshape_dims.begin() + input_feature_dimension, num_groups);
+
+  HloInstruction* lhs_transpose =
+      c->AddInstruction(HloInstruction::CreateTranspose(
+          ShapeUtil::MakeShape(lhs_reshape_1->shape().element_type(),
+                               transpose_reshape_dims),
+          lhs_reshape_1, transpose_dims));
+
+  // Merge [G,C/G] -> [C]
+  Shape new_shape = lhs_transpose->shape();
+  new_shape.DeleteDimension(input_feature_dimension);
+  new_shape.set_dimensions(input_feature_dimension,
+                           input_feature * conv->feature_group_count());
+  HloInstruction* lhs_reshape_2 = c->AddInstruction(
+      HloInstruction::CreateReshape(new_shape, lhs_transpose));
+  return std::make_tuple(true, backward_conv_window, backward_conv_dnums,
+                         lhs_reshape_2);
 }
 
 // Try to match a backward input pattern that contains "conv".
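To make the reshape -> transpose -> reshape above concrete, here is a minimal, self-contained sketch (not part of the commit; function names are illustrative) that performs the same dimension bookkeeping on plain vectors, using the activation shape {32, 1, 3, 4}, batch dimension 0, feature dimension 3, and feature_group_count 4 taken from the BackwardFilterGroupConvolve test added in this commit.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Prints a shape such as {32, 1, 3, 4}.
void Print(const char* label, const std::vector<int64_t>& dims) {
  std::cout << label << ": {";
  for (size_t i = 0; i < dims.size(); ++i) {
    std::cout << dims[i] << (i + 1 < dims.size() ? ", " : "");
  }
  std::cout << "}\n";
}

int main() {
  // Values taken from the test: lhs shape {G*N, H, W, C/G} = {32, 1, 3, 4},
  // batch dimension 0, feature dimension 3, G = 4 groups.
  std::vector<int64_t> lhs = {32, 1, 3, 4};
  const int64_t batch_dim = 0, feature_dim = 3, groups = 4;

  // Step 1: split the batch dimension G*N into [G, N].
  std::vector<int64_t> reshape_dims = lhs;
  reshape_dims[batch_dim] /= groups;
  reshape_dims.insert(reshape_dims.begin() + batch_dim, groups);
  Print("after reshape  ", reshape_dims);  // {4, 8, 1, 3, 4}

  // Step 2: move G to the axis just before C/G (done with a transpose HLO in
  // the pass; output dim i takes input dim perm[i]).
  std::vector<int64_t> perm(reshape_dims.size());
  std::iota(perm.begin(), perm.end(), 0);
  perm.erase(perm.begin() + batch_dim);
  perm.insert(perm.begin() + feature_dim, batch_dim);
  std::vector<int64_t> transposed;
  for (int64_t axis : perm) transposed.push_back(reshape_dims[axis]);
  Print("after transpose", transposed);    // {8, 1, 3, 4, 4}

  // Step 3: merge [G, C/G] back into one feature dimension C = G * (C/G).
  std::vector<int64_t> merged = transposed;
  merged.erase(merged.begin() + feature_dim);
  merged[feature_dim] *= groups;
  Print("after merge    ", merged);        // {8, 1, 3, 16}
  return 0;
}

The final shape {8, 1, 3, 16} is exactly what the test checks on the rewritten custom call: batch 8 at dimension 0 and feature 16 at dimension 3.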
@@ -503,6 +565,7 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
   Window window;
   ConvolutionDimensionNumbers dnums;
   HloInstruction* rhs;
+  HloInstruction* lhs;
 
   std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv);
   if (match) {
@@ -511,12 +574,11 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
                            conv->feature_group_count(), conv->metadata());
   }
 
-  std::tie(match, window, dnums) = MatchBackwardFilter(conv);
+  std::tie(match, window, dnums, lhs) = MatchBackwardFilter(conv);
   if (match) {
     return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(),
-                           conv->mutable_operand(0), conv->mutable_operand(1),
-                           window, dnums, conv->feature_group_count(),
-                           conv->metadata());
+                           lhs, conv->mutable_operand(1), window, dnums,
+                           conv->feature_group_count(), conv->metadata());
   }
 
   // If all else fails, try a forward convolution.
@@ -135,6 +135,86 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolve) {
       << md_after_opt.DebugString() << " vs " << metadata.DebugString();
 }
 
+TEST_F(CudnnConvRewriterTest, BackwardFilterGroupConvolve) {
+  // In a nutshell, before pass:
+  // Input->batch_dim: 3 input_shape(3) = 4
+  // Input->feature_dim: 0 input_shape(0) = 32
+  // Kernel(gradient)->kernel_input_feature_dim (gradient_batch_dimension): 0
+  // Kernel(gradient)->kernel_output_feature_dim (gradient_feature_dimension): 3
+  // Output(dkernel)->output_batch_dim (dkernel_input_feature_dim): 2
+  // Output(dkernel)->output_feature_dim (dkernel_output_feature_dim): 3
+
+  // After pass: All shapes and dimension layout is brought
+  // back to normal as would be acceptable by cudnn
+  // Input->batch_dim: 0 input_shape(0) = 8
+  // Input->feature_dim: 3 input_shape(3) = 16
+  // Kernel(gradient)->kernel_input_feature_dim (gradient_batch_dimension): 2
+  // Kernel(gradient)->kernel_output_feature_dim (gradient_feature_dimension): 3
+  // Output(dkernel)->output_batch_dim (dkernel_input_feature_dim): 0
+  // Output(dkernel)->output_feature_dim (dkernel_output_feature_dim): 3
+  HloComputation::Builder builder(TestName());
+  HloInstruction* activations =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          0, ShapeUtil::MakeShape(F32, {32, 1, 3, 4}), "activations"));
+  HloInstruction* gradients =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          1, ShapeUtil::MakeShape(F32, {8, 1, 2, 16}), "gradients"));
+  Window conv_window = default_conv_window_;
+  conv_window.mutable_dimensions(1)->set_size(2);
+  conv_window.mutable_dimensions(1)->set_window_dilation(2);
+  auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+      ShapeInference::InferConvolveShape(
+          activations->shape(), gradients->shape(), /*feature_group_count=*/4,
+          /*batch_group_count=*/1, conv_window,
+          tf_default_dnums_for_backward_filter_)
+          .ConsumeValueOrDie(),
+      activations, gradients, /*feature_group_count=*/4,
+      /*batch_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
+  OpMetadata metadata;
+  metadata.set_op_name("bar");
+  conv->set_metadata(metadata);
+  auto module = CreateNewVerifiedModule();
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(RunPass(module.get()));
+  ASSERT_THAT(entry_computation->root_instruction(),
+              op::GetTupleElement(
+                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
+  // Check that metadata was preserved.
+  const auto& md_after_opt =
+      entry_computation->root_instruction()->operand(0)->metadata();
+  EXPECT_TRUE(protobuf_util::ProtobufEquals(md_after_opt, metadata))
+      << md_after_opt.DebugString() << " vs " << metadata.DebugString();
+  const HloInstruction* custom_call =
+      entry_computation->root_instruction()->operand(0);
+  const ConvolutionDimensionNumbers conv_dim =
+      custom_call->convolution_dimension_numbers();
+  const auto lhs_a = custom_call->operand(0);
+  const auto input_shape = lhs_a->shape();
+  // The input (lhs) batch_dim(dim 0 in the original NHWC layout) gets mapped to
+  // be the feature_dim(dim 3) with a value of N*g = 32 in tf2xla. As described
+  // in conv_grad_ops.h, this swap is required to implement backprop using fwd
+  // conv. After the pass the batch_dim gets remapped to dim 0. The batch_dim
+  // value gets scaled to N = N*g/g = 32/4 = 8 to be compatible with cudnn
+  EXPECT_EQ(0, conv_dim.input_batch_dimension());
+  EXPECT_EQ(8, input_shape.dimensions(conv_dim.input_batch_dimension()));
+  // Similarly, the input (lhs) feature_dim(dim 3 in the original NHWC layout)
+  // gets mapped to be the batch_dim(dim 0) with a value of C/g = 4 in tf2xla.
+  // After the pass the batch_dim gets remapped to dim 0. The feature_dim value
+  // gets scaled to C = C/g*g = 4*4 = 16 to be compatible with cudnn
+  EXPECT_EQ(3, conv_dim.input_feature_dimension());
+  EXPECT_EQ(16, input_shape.dimensions(conv_dim.input_feature_dimension()));
+  // Similarly, the feature and batch dims of the incoming gradients (used as
+  // rhs) and the in/out dims of the output of convolution i.e, dgrad have been
+  // been modified in tf2xla (as described in conv_grad_ops.h). This pass remaps
+  // everything back for the layout to be compatible with cudnn backprop APIs.
+  EXPECT_EQ(2, conv_dim.kernel_input_feature_dimension());
+  EXPECT_EQ(3, conv_dim.kernel_output_feature_dimension());
+  EXPECT_EQ(0, conv_dim.output_batch_dimension());
+  EXPECT_EQ(3, conv_dim.output_feature_dimension());
+}
+
 TEST_F(CudnnConvRewriterTest,
        BackwardFilterConvolveEquivalentToForwardConvolution) {
   HloComputation::Builder builder(TestName());
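As a sanity check on the expectations in BackwardFilterGroupConvolve, a tiny sketch (illustrative only, not part of the commit) of the arithmetic the test comments describe: tf2xla hands the rewriter an activation whose batch dimension holds N*g and whose feature dimension holds C/g, and after the rewrite the custom call sees batch N and feature C.

#include <cstdint>

int main() {
  // Values from the test: N*g = 32 with g = 4 groups, C/g = 4.
  constexpr int64_t kGroups = 4;
  constexpr int64_t kBatchTimesGroups = 32;  // lhs dim 0 before the pass
  constexpr int64_t kFeaturePerGroup = 4;    // lhs dim 3 before the pass

  // After the pass the batch dimension holds N = (N*g) / g ...
  static_assert(kBatchTimesGroups / kGroups == 8, "matches EXPECT_EQ(8, ...)");
  // ... and the feature dimension holds C = (C/g) * g.
  static_assert(kFeaturePerGroup * kGroups == 16, "matches EXPECT_EQ(16, ...)");
  return 0;
}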