Handle batch group conv - backward filter grouped conv
This commit is contained in: parent abaffb8ad1, commit cc4a431906
@@ -411,6 +411,10 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
}

Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
  if (is_cost_viable_(convolution)) {
    return Status::OK();
  }

  if (convert_batch_groups_only_) {
    return HandleBatchGroupCount(convolution);
  }

@@ -190,6 +190,9 @@ Status ConvolutionVisitor::HandleBackwardFilterBatchGroupConvolution(
}

Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
  if (is_cost_viable_(convolution)) {
    return Status::OK();
  }
  return HandleBackwardFilterBatchGroupConvolution(convolution);
}

@@ -152,17 +152,12 @@ Status GpuCompiler::OptimizeHloModule(

    pipeline.AddPass<Convolution4DExpander>();

    auto cost_model = [](HloInstruction* conv) {
      auto operand = conv->operand(0);
      return operand->shape().dimensions(conv->convolution_dimension_numbers()
                                             .input_batch_dimension()) ==
             conv->batch_group_count();
    };
    auto cost_model = [](HloInstruction* conv) { return true; };
    pipeline.AddPass<DepthwiseConvolutionConverter>(cost_model);

    // We use the ConvolutionGroupConverter to convert backprops of filter
    // grouped convolutions into non-grouped equivalents.
    auto batch_group_cost_model = [](HloInstruction*) { return false; };
    auto batch_group_cost_model = [](HloInstruction*) { return true; };

    pipeline.AddPass<ConvolutionGroupConverter>(
        batch_group_cost_model,

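A minimal standalone sketch (not part of the commit) of how these cost-model callbacks behave: they feed the is_cost_viable_ gate shown in the HandleConvolution hunks above, so a callback that returns true leaves the grouped convolution untouched, while false asks the converter pass to expand it. ToyConv, WouldRewrite, and the lambda names are hypothetical stand-ins for the real HloInstruction-based types.

#include <functional>
#include <iostream>

struct ToyConv {
  long batch_group_count = 4;
};

using CostModel = std::function<bool(const ToyConv&)>;

// Mirrors ConvolutionVisitor::HandleConvolution: a "viable" conv is returned
// early and never rewritten; only non-viable convs get expanded.
bool WouldRewrite(const ToyConv& conv, const CostModel& is_cost_viable) {
  if (is_cost_viable(conv)) {
    return false;  // keep the grouped convolution as-is
  }
  return true;  // expand it into a non-grouped equivalent
}

int main() {
  ToyConv conv;
  CostModel leave_alone = [](const ToyConv&) { return true; };
  CostModel always_expand = [](const ToyConv&) { return false; };
  std::cout << WouldRewrite(conv, leave_alone) << "\n";    // 0
  std::cout << WouldRewrite(conv, always_expand) << "\n";  // 1
}
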
@@ -64,6 +64,62 @@ HloInstruction* CreateGpuConv(const char* call_target, const Shape& shape,
  return custom_call;
}

HloInstruction* ConvertBatchGroupedToFeatureGroupedConvolution(
    HloInstruction* conv) {
  CHECK_EQ(conv->feature_group_count(), 1);
  int64 num_groups = conv->batch_group_count();
  auto dim_numbers = conv->convolution_dimension_numbers();
  auto lhs = conv->mutable_operand(0);
  auto rhs = conv->mutable_operand(1);

  int64 input_batch_dimension = dim_numbers.input_batch_dimension();
  int64 input_batch = lhs->shape().dimensions(input_batch_dimension);

  Shape output_shape = conv->shape();
  int64 input_feature_dimension = dim_numbers.input_feature_dimension();
  int64 input_feature = lhs->shape().dimensions(input_feature_dimension);

  HloComputation* computation = lhs->parent();
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    return computation->AddInstruction(std::move(inst));
  };
  // Reshape batch_dim N -> [G, N/G]
  std::vector<int64> reshape_dims = SpanToVector(lhs->shape().dimensions());
  reshape_dims[input_batch_dimension] =
      reshape_dims[input_batch_dimension] / num_groups;
  reshape_dims.insert(reshape_dims.begin() + input_batch_dimension,
                      num_groups);
  lhs = add(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims), lhs));

  // Transpose G to the axis before C. For example: [G, N/G, H, W, C] ->
  // [N/G, H, W, G, C]
  std::vector<int64> transpose_dims(lhs->shape().dimensions_size());
  std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
  transpose_dims.erase(transpose_dims.begin() + input_batch_dimension);
  transpose_dims.insert(transpose_dims.begin() + input_feature_dimension,
                        input_batch_dimension);
  std::vector<int64> transpose_reshape_dims =
      ComposePermutations(lhs->shape().dimensions(), transpose_dims);
  lhs = add(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(lhs->shape().element_type(),
                           transpose_reshape_dims),
      lhs, transpose_dims));

  // Merge [G, C] -> [C*G]
  Shape new_shape = lhs->shape();
  new_shape.DeleteDimension(input_feature_dimension);
  new_shape.set_dimensions(input_feature_dimension,
                           input_feature * num_groups);
  lhs = add(HloInstruction::CreateReshape(new_shape, lhs));

  std::vector<HloInstruction*> new_operands = {lhs, rhs};
  auto new_conv = conv->CloneWithNewOperands(output_shape, new_operands);
  new_conv->set_feature_group_count(num_groups);
  new_conv->set_batch_group_count(1);
  new_conv->set_convolution_dimension_numbers(dim_numbers);
  return computation->AddInstruction(std::move(new_conv));
}

bool CanImplementAsGpuForwardConv(HloInstruction* conv) {
  const ConvolutionDimensionNumbers& dnums =
      conv->convolution_dimension_numbers();

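A minimal standalone sketch (not part of the commit) of the shape bookkeeping ConvertBatchGroupedToFeatureGroupedConvolution performs on the activations, traced on a hypothetical NHWC input [N=8, H=3, W=3, C=2] with batch_group_count G=4: split N into [G, N/G], move G next to C, then merge [G, C] into a single feature axis of size C*G so the convolution can run with feature_group_count = G and batch_group_count = 1. Only dimension vectors are manipulated here; Dims, Print, and the constants are made up for the example.

#include <cstdint>
#include <iostream>
#include <vector>

using Dims = std::vector<int64_t>;

void Print(const char* label, const Dims& dims) {
  std::cout << label << " [";
  for (size_t i = 0; i < dims.size(); ++i) std::cout << (i ? "," : "") << dims[i];
  std::cout << "]\n";
}

int main() {
  const int64_t batch_dim = 0;    // N comes first in NHWC
  const int64_t feature_dim = 3;  // C comes last in NHWC
  const int64_t groups = 4;       // batch_group_count
  Dims lhs = {8, 3, 3, 2};        // [N, H, W, C]
  const int64_t input_feature = lhs[feature_dim];
  Print("input          ", lhs);

  // Step 1: reshape batch_dim N -> [G, N/G].
  lhs[batch_dim] /= groups;
  lhs.insert(lhs.begin() + batch_dim, groups);
  Print("after reshape  ", lhs);  // [4,2,3,3,2] = [G, N/G, H, W, C]

  // Step 2: transpose G to the axis before C: [G, N/G, H, W, C] -> [N/G, H, W, G, C].
  Dims perm(lhs.size());
  for (size_t i = 0; i < perm.size(); ++i) perm[i] = static_cast<int64_t>(i);
  perm.erase(perm.begin() + batch_dim);
  perm.insert(perm.begin() + feature_dim, batch_dim);  // perm = [1,2,3,0,4]
  Dims transposed(lhs.size());
  for (size_t i = 0; i < perm.size(); ++i) transposed[i] = lhs[perm[i]];
  Print("after transpose", transposed);  // [2,3,3,4,2]

  // Step 3: merge the now-adjacent [G, C] pair into one feature axis of size C*G.
  Dims merged = transposed;
  merged.erase(merged.begin() + feature_dim);    // drop G
  merged[feature_dim] = input_feature * groups;  // C*G = 8
  Print("after merge    ", merged);  // [2,3,3,8]; the conv then uses feature_group_count = G
}
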
@@ -91,9 +147,19 @@ bool CanImplementAsGpuForwardConv(HloInstruction* conv) {
// Precondition: "conv" is a kConvolution.
std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
MatchBackwardFilter(HloInstruction* conv) {
  VLOG(2) << "Trying to match convolution backward filter.";
  const auto no_match_result =
      std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);

  if (conv->feature_group_count() > 1) {
    VLOG(1) << conv->ToString()
            << " is a forward convolution. All grouped backward filters are "
               "mapped to batch grouped convolutions in the tf2xla bridge. "
               "Hence backward filter convolutions cannot have feature groups "
               "greater than 1 at this point. No need to fold to backward "
               "filter.";
    return no_match_result;
  }

  // Step 1: match the instruction pattern without considering the paddings and
  // dimension numbers just yet. We may need some generic pattern matcher
  // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h

@@ -122,7 +188,6 @@ MatchBackwardFilter(HloInstruction* conv) {
  auto output_batch_dim = conv_dnums.output_batch_dimension();
  auto output_feature_dim = conv_dnums.output_feature_dimension();
  auto output_spatial_dims = conv_dnums.output_spatial_dimensions();

  for (const WindowDimension& window_dim : conv->window().dimensions()) {
    if (window_dim.stride() != 1) {
      VLOG(1) << "Forward convolution's window "

@@ -150,16 +215,7 @@ MatchBackwardFilter(HloInstruction* conv) {
      !window_util::HasWindowDilation(conv->window())) {
    VLOG(1) << conv->ToString()
            << " is a regular forward convolution. No need "
               "to fold it to a backward filter convolution.";
    return no_match_result;
  }
  auto rhs_in =
      conv->mutable_operand(1)->shape().dimensions(kernel_input_feature_dim);
  if (conv->feature_group_count() > 1 && rhs_in == 1 &&
      input_batch_dim == output_batch_dim) {
    VLOG(1) << conv->ToString()
            << " is a depthwise forward convolution. No need to fold to "
               "backward filter.";
               "to fold it to a backward filter convolution....";
    return no_match_result;
  }

@@ -256,67 +312,14 @@ MatchBackwardFilter(HloInstruction* conv) {
  }

  HloInstruction* lhs = conv->mutable_operand(0);
  if (conv->feature_group_count() == 1) {
    return std::make_tuple(true, backward_conv_window, backward_conv_dnums,
                           lhs);
  }

  int64 input_batch_dimension = backward_conv_dnums.input_batch_dimension();
  int64 input_feature_dimension = backward_conv_dnums.input_feature_dimension();

  int64 input_batch = lhs->shape().dimensions(input_batch_dimension);
  int64 input_feature = lhs->shape().dimensions(input_feature_dimension);

  // Reshape batch_dim G*N -> [G, N]
  std::vector<int64> reshape_dims = SpanToVector(lhs->shape().dimensions());
  auto num_groups = conv->feature_group_count();
  CHECK_EQ(input_batch % num_groups, 0)
      << "Input batch should be an exact multiple of feature group count";
  reshape_dims[input_batch_dimension] =
      reshape_dims[input_batch_dimension] / num_groups;
  reshape_dims.insert(reshape_dims.begin() + input_batch_dimension, num_groups);

  HloComputation* c = conv->parent();
  HloInstruction* lhs_reshape_1 =
      c->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims),
          lhs));

  // Transpose G to the axis before C/G. For example: [G, N, C/G, H, W] ->
  // [N, G, C/G, H, W]
  std::vector<int64> transpose_dims(lhs_reshape_1->shape().dimensions_size());
  std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
  transpose_dims.erase(transpose_dims.begin() + input_batch_dimension);
  transpose_dims.insert(transpose_dims.begin() + input_feature_dimension,
                        input_batch_dimension);
  std::vector<int64> transpose_reshape_dims =
      SpanToVector(lhs_reshape_1->shape().dimensions());
  transpose_reshape_dims.erase(transpose_reshape_dims.begin() +
                               input_batch_dimension);
  transpose_reshape_dims.insert(
      transpose_reshape_dims.begin() + input_feature_dimension, num_groups);

  HloInstruction* lhs_transpose =
      c->AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(lhs_reshape_1->shape().element_type(),
                               transpose_reshape_dims),
          lhs_reshape_1, transpose_dims));

  // Merge [G, C/G] -> [C]
  Shape new_shape = lhs_transpose->shape();
  new_shape.DeleteDimension(input_feature_dimension);
  new_shape.set_dimensions(input_feature_dimension,
                           input_feature * conv->feature_group_count());
  HloInstruction* lhs_reshape_2 = c->AddInstruction(
      HloInstruction::CreateReshape(new_shape, lhs_transpose));
  return std::make_tuple(true, backward_conv_window, backward_conv_dnums,
                         lhs_reshape_2);
  return std::make_tuple(true, backward_conv_window, backward_conv_dnums, lhs);
}

// Try to match a backward input pattern that contains "conv".
// Precondition: "conv" is a kConvolution.
std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
MatchBackwardInput(HloInstruction* conv) {
  VLOG(2) << "Trying to match convolution backward input.";
  const auto no_match_result =
      std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);

@@ -639,9 +642,12 @@ static StatusOr<HloInstruction*> CreateCustomCallHelper(HloInstruction* conv) {
  if (match) {
    return CreateGpuConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), lhs,
                         conv->mutable_operand(1), window, dnums,
                         conv->feature_group_count(), conv->metadata());
                         conv->batch_group_count(), conv->metadata());
  }

  if (conv->batch_group_count() > 1) {
    conv = ConvertBatchGroupedToFeatureGroupedConvolution(conv);
  }
  // If all else fails, try a forward convolution.
  if (CanImplementAsGpuForwardConv(conv)) {
    if (primitive_util::IsIntegralType(

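A minimal standalone sketch (not part of the commit) of the lowering order in CreateCustomCallHelper after this change: a matched backward-filter pattern carries its batch_group_count onto the custom call, and a convolution that is still batch-grouped when the forward-convolution fallback is reached is first rewritten into an equivalent feature-grouped convolution. ToyConv and Lower are hypothetical stand-ins.

#include <iostream>
#include <string>

struct ToyConv {
  bool matches_backward_filter;
  long batch_group_count;
  long feature_group_count;
};

std::string Lower(ToyConv conv) {
  if (conv.matches_backward_filter) {
    return "backward-filter custom call, batch_group_count=" +
           std::to_string(conv.batch_group_count);
  }
  if (conv.batch_group_count > 1) {
    // Stand-in for ConvertBatchGroupedToFeatureGroupedConvolution.
    conv.feature_group_count = conv.batch_group_count;
    conv.batch_group_count = 1;
  }
  return "forward custom call, feature_group_count=" +
         std::to_string(conv.feature_group_count);
}

int main() {
  std::cout << Lower({true, 4, 1}) << "\n";   // backward-filter path
  std::cout << Lower({false, 4, 1}) << "\n";  // converted, then forward path
}
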
@@ -736,11 +742,13 @@ StatusOr<bool> RunOnComputation(HloComputation* computation) {
}  // namespace

StatusOr<bool> GpuConvRewriter::Run(HloModule* module) {
  XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }
  XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), after:\n" + module->ToString());
  return changed;
}