[XLA] Swap convolution operands if a naive convolution implementation would do fewer flops.

PiperOrigin-RevId: 306657217
Change-Id: I1e48bb858da585fbc421092da0e6172ea3cf5d56
Author: Blake Hechtman, 2020-04-15 09:19:45 -07:00 (committed by TensorFlower Gardener)
Parent: 0592ae692b
Commit: 902bd615ea
6 changed files with 191 additions and 0 deletions
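
The idea behind the rewrite: a naive convolution does roughly one multiply-add per output element per element of the (dilated) kernel window, so whichever operand has the smaller spatial volume should act as the kernel. A minimal standalone sketch of the heuristic follows; the helper name and shape encoding are illustrative, not part of this change.

#include <cstdint>
#include <vector>

// Returns true when the swapped convolution would use a strictly smaller
// window volume, i.e. when a naive implementation would do fewer flops.
bool ShouldSwapConvOperands(const std::vector<int64_t>& kernel_spatial_sizes,
                            const std::vector<int64_t>& input_spatial_sizes) {
  int64_t kernel_product = 1;
  int64_t swapped_kernel_product = 1;
  for (int64_t size : kernel_spatial_sizes) kernel_product *= size;
  for (int64_t size : input_spatial_sizes) swapped_kernel_product *= size;
  return kernel_product > swapped_kernel_product;
}

In the unit test added below, the window volume is 32 * 32 = 1024 against an input spatial volume of 3 * 3 = 9, so the swap fires.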


@@ -59,6 +59,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/lib/statusor.h"
@@ -494,6 +495,10 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
StatusOr<bool> FoldConvInputPad(HloInstruction* convolution);
StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution);
// Tries to swap convolution operands if they would result in a more efficient
// convolution.
StatusOr<bool> SwapConvOperands(HloInstruction* convolution);
// Tries to use a kDot in place of the given convolution.
StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution);
@@ -4481,6 +4486,107 @@ StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad(
return true;
}

StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
HloInstruction* convolution) {
if (!options_.enable_conv_operand_swap() || options_.is_layout_sensitive()) {
return false;
}
if (convolution->feature_group_count() > 1 ||
convolution->batch_group_count() > 1) {
return false;
}
const auto& dnums = convolution->convolution_dimension_numbers();
const auto& window_dims = convolution->window().dimensions();
Window swapped_window;
HloInstruction *input = convolution->mutable_operand(0),
*kernel = convolution->mutable_operand(1);
int64 kernel_product = 1;
int64 swapped_kernel_product = 1;
DimensionVector reverse_dimensions;
for (int64 spatial_dim = 0;
spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) {
const int64 kernel_size = window_dims[spatial_dim].size();
kernel_product *= kernel_size;
const int64 dilated_kernel_size =
1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
const int64 input_size =
input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim));
swapped_kernel_product *= input_size;
const int64 dilated_input_size =
1 + (input_size - 1) * window_dims[spatial_dim].base_dilation();
auto new_dim = swapped_window.add_dimensions();
new_dim->set_size(input_size);
// If the window is not reversed, the kernel (which becomes the new
// activations) must be manually reversed.
if (!window_dims[spatial_dim].window_reversal()) {
reverse_dimensions.push_back(
dnums.kernel_spatial_dimensions(spatial_dim));
}
// The input (which becomes the new kernel) was not originally reversed, so
// the swapped window must be marked as reversed.
new_dim->set_window_reversal(true);
// Base dilation and window dilation switch places.
new_dim->set_base_dilation(window_dims[spatial_dim].window_dilation());
new_dim->set_window_dilation(window_dims[spatial_dim].base_dilation());
new_dim->set_stride(window_dims[spatial_dim].stride());
new_dim->set_padding_low(dilated_input_size +
window_dims[spatial_dim].padding_low() -
dilated_kernel_size);
new_dim->set_padding_high(dilated_input_size +
window_dims[spatial_dim].padding_high() -
dilated_kernel_size);
}
// Only transform when a naive implementation of the swapped convolution
// would do strictly fewer flops.
if (kernel_product <= swapped_kernel_product) {
return false;
}
ConvolutionDimensionNumbers swapped_dnums;
*swapped_dnums.mutable_output_spatial_dimensions() =
dnums.output_spatial_dimensions();
// Swap the batch and feature dimensions of the output.
swapped_dnums.set_output_batch_dimension(dnums.output_feature_dimension());
swapped_dnums.set_output_feature_dimension(dnums.output_batch_dimension());
// Swap input dnums with kernel dnums.
*swapped_dnums.mutable_input_spatial_dimensions() =
dnums.kernel_spatial_dimensions();
swapped_dnums.set_input_batch_dimension(
dnums.kernel_output_feature_dimension());
swapped_dnums.set_input_feature_dimension(
dnums.kernel_input_feature_dimension());
// Swap kernel dnums with input dnums.
*swapped_dnums.mutable_kernel_spatial_dimensions() =
dnums.input_spatial_dimensions();
swapped_dnums.set_kernel_output_feature_dimension(
dnums.input_batch_dimension());
swapped_dnums.set_kernel_input_feature_dimension(
dnums.input_feature_dimension());
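// The operand precisions swap along with the operands.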
PrecisionConfig precision_config;
precision_config.add_operand_precision(
convolution->precision_config().operand_precision(1));
precision_config.add_operand_precision(
convolution->precision_config().operand_precision(0));
if (!reverse_dimensions.empty()) {
TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions));
}
TF_ASSIGN_OR_RETURN(
HloInstruction * new_convolution,
MakeConvolveHlo(kernel, input, /*feature_group_count=*/1, swapped_window,
swapped_dnums, precision_config));
convolution->SetupDerivedInstruction(new_convolution);
TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution));
return true;
}
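
To make the padding arithmetic above concrete: for each spatial dimension the swapped padding is dilated_input_size + old_padding - dilated_kernel_size. A small standalone check against the numbers in the SwapConvOperands unit test below (the helper is hypothetical, not part of the change):

#include <cassert>
#include <cstdint>

int64_t SwappedPadding(int64_t input_size, int64_t base_dilation,
                       int64_t kernel_size, int64_t window_dilation,
                       int64_t old_padding) {
  const int64_t dilated_input_size = 1 + (input_size - 1) * base_dilation;
  const int64_t dilated_kernel_size = 1 + (kernel_size - 1) * window_dilation;
  return dilated_input_size + old_padding - dilated_kernel_size;
}

int main() {
  // Input spatial size 3, window size 32, padding 30, no dilation:
  // 3 + 30 - 32 = 1, matching the expected padding after the rewrite.
  assert(SwappedPadding(3, 1, 32, 1, 30) == 1);
  return 0;
}
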
StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
HloInstruction* convolution) {
auto* lhs = convolution->mutable_operand(0);
@@ -4619,6 +4725,11 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
return Status::OK();
}
// Try to swap convolution operands.
TF_ASSIGN_OR_RETURN(bool swapped, SwapConvOperands(convolution));
if (swapped) {
return Status::OK();
}
// Try to replace the convolution with a kDot instruction.
TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution));
if (replaced_with_dot) {


@@ -80,6 +80,12 @@ class AlgebraicSimplifierOptions {
return enable_conv_simplification_;
}
// Enable convolution operand swapping on platforms where it is supported.
void set_enable_conv_operand_swap(bool enable_conv_operand_swap) {
enable_conv_operand_swap_ = enable_conv_operand_swap;
}
bool enable_conv_operand_swap() const { return enable_conv_operand_swap_; }
// If enable_window_reduce_replacement is true, the kReduceWindow instruction
// can be optimized by replacement with simpler operations.
void set_enable_window_reduce_to_reduce_replacement(
@@ -139,6 +145,7 @@ class AlgebraicSimplifierOptions {
bool enable_dot_strength_reduction_{true};
bool enable_dot_to_multiply_rewrite_{true};
bool enable_conv_simplification_{true};
bool enable_conv_operand_swap_{true};
bool enable_window_reduce_to_reduce_replacement_{true};
bool enable_reduce_of_reshape_{true};
bool replace_transpose_with_bitcast_{true};
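
For reference, a sketch of how a backend opts out of the rewrite, mirroring the GPU pipeline changes below; convolutions bound for vendor libraries are not naive, so the flop heuristic does not reflect their real cost. The wrapper function is illustrative:

#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"

xla::AlgebraicSimplifierOptions MakeSimplifierOptions() {
  xla::AlgebraicSimplifierOptions options;
  // Convolutions here are lowered to a library, so keep operands as
  // written and let the library choose the algorithm.
  options.set_enable_conv_operand_swap(false);
  return options;
}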


@@ -6414,5 +6414,33 @@ TEST_F(AlgebraicSimplifierTest, ScalarScatter) {
// Combine Scatters
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
}

TEST_F(AlgebraicSimplifierTest, SwapConvOperands) {
const char* hlo_string = R"(
HloModule m
test {
a = f32[3,3,160,160] parameter(0)
b = f32[128,32,32,160] parameter(1)
ROOT c = f32[128,32,32,160] convolution(a,b),
window={size=32x32 pad=30_30x30_30 rhs_reversal=1x1},
dim_labels=01bf_o01i->f01b
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
// Swap the convolution operands.
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
const HloInstruction* conv = m->entry_computation()->root_instruction();
EXPECT_THAT(conv,
GmockMatch(m::Convolution(m::Parameter(1), m::Parameter(0))));
EXPECT_EQ(conv->window().dimensions(0).size(), 3);
EXPECT_EQ(conv->window().dimensions(1).size(), 3);
EXPECT_EQ(conv->window().dimensions(0).window_reversal(), true);
EXPECT_EQ(conv->window().dimensions(1).window_reversal(), true);
EXPECT_EQ(conv->window().dimensions(0).padding_low(), 1);
EXPECT_EQ(conv->window().dimensions(1).padding_low(), 1);
EXPECT_EQ(conv->window().dimensions(0).padding_high(), 1);
EXPECT_EQ(conv->window().dimensions(1).padding_high(), 1);
}
} // namespace
} // namespace xla


@@ -216,6 +216,7 @@ Status GpuCompiler::OptimizeHloModule(
// bitcast. This leads to having to linearize and then delinearize the
// index.
options.set_replace_transpose_with_bitcast(false);
options.set_enable_conv_operand_swap(false);
pass.AddPass<AlgebraicSimplifier>(options);
// AlgebraicSimplifier may add contracting dimensions to a dot.
pass.AddPass<DotDecomposer>();
@@ -321,6 +322,7 @@ Status GpuCompiler::OptimizeHloModule(
HloPassPipeline pipeline("final_algebraic_simplifier");
AlgebraicSimplifierOptions options;
options.set_is_layout_sensitive(true);
options.set_enable_conv_operand_swap(false);
pipeline.AddPass<AlgebraicSimplifier>(options);
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}
@@ -399,6 +401,7 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment(
// bitcast. This leads to having to linearize and then delinearize the
// index.
options.set_replace_transpose_with_bitcast(false);
options.set_enable_conv_operand_swap(false);
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
if (RequireDeterminism() ||


@@ -141,6 +141,7 @@ Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization(
// bitcast. This leads to having to linearize and then delinearize the
// index.
options.set_replace_transpose_with_bitcast(false);
options.set_enable_conv_operand_swap(false);
options.set_cudnn_batchnorm_forward_training_metadata(
kCudnnBatchNormForwardTrainingCallTarget);
pass.AddPass<AlgebraicSimplifier>(options);


@@ -2008,6 +2008,47 @@ ENTRY Test {
EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
}

XLA_TEST_F(ConvolutionHloTest, SwappedOperandConvolve) {
constexpr char kHlo[] = R"(
HloModule TestModule
ENTRY Test {
%lhs = f32[3,3,7,7] parameter(0)
%rhs = f32[5,11,11,7] parameter(1)
ROOT %convolution = f32[5,21,2,7] convolution(lhs, rhs),
window={size=11x11 pad=3_25x3_6},
dim_labels=01bf_o01i->f01b
})";
EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
}
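
The ROOT shape can be sanity-checked with the usual output-size formula, output = input + pad_low + pad_high - window + 1 at unit stride. A standalone sketch with an illustrative helper:

#include <cassert>
#include <cstdint>

int64_t OutputSize(int64_t input, int64_t pad_low, int64_t pad_high,
                   int64_t window) {
  return input + pad_low + pad_high - window + 1;
}

int main() {
  // lhs spatial dims are 3x3 (dim_labels 01bf); window 11x11, pad 3_25x3_6.
  assert(OutputSize(3, 3, 25, 11) == 21);  // first output spatial dimension
  assert(OutputSize(3, 3, 6, 11) == 2);    // second output spatial dimension
  return 0;
}
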
XLA_TEST_F(ConvolutionHloTest, SwappedOperandConvolveWithStride) {
constexpr char kHlo[] = R"(
HloModule TestModule
ENTRY Test {
%lhs = f32[3,3,7,7] parameter(0)
%rhs = f32[5,11,11,7] parameter(1)
ROOT %convolution = f32[5,11,2,7] convolution(lhs, rhs),
window={size=11x11 pad=3_26x3_6 stride=2x1},
dim_labels=01bf_o01i->f01b
})";
EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
}

XLA_TEST_F(ConvolutionHloTest, SwappedOperandConvolve2) {
constexpr char kHlo[] = R"(
HloModule TestModule
ENTRY Test {
%lhs = f32[3,3,7,7] parameter(0)
%rhs = f32[5,11,11,7] parameter(1)
ROOT %convolution = f32[5,11,4,7] convolution(lhs, rhs),
window={size=11x11 pad=3_25x3_6 lhs_dilate=1x2 rhs_dilate=2x1},
dim_labels=01bf_o01i->f01b
})";
EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
}
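
The same check extends to stride and dilation, where a size n dilated by d spans 1 + (n - 1) * d elements. A standalone sketch (helper names illustrative) verified against the two tests above:

#include <cassert>
#include <cstdint>

int64_t Dilated(int64_t size, int64_t dilation) {
  return 1 + (size - 1) * dilation;
}

int64_t OutputSize(int64_t input, int64_t lhs_dilate, int64_t pad_low,
                   int64_t pad_high, int64_t window, int64_t rhs_dilate,
                   int64_t stride) {
  return (Dilated(input, lhs_dilate) + pad_low + pad_high -
          Dilated(window, rhs_dilate)) / stride + 1;
}

int main() {
  // SwappedOperandConvolveWithStride: pad=3_26x3_6, stride=2x1 -> 11x2.
  assert(OutputSize(3, 1, 3, 26, 11, 1, 2) == 11);
  assert(OutputSize(3, 1, 3, 6, 11, 1, 1) == 2);
  // SwappedOperandConvolve2: lhs_dilate=1x2, rhs_dilate=2x1 -> 11x4.
  assert(OutputSize(3, 1, 3, 25, 11, 2, 1) == 11);
  assert(OutputSize(3, 2, 3, 6, 11, 1, 1) == 4);
  return 0;
}
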
XLA_TEST_F(ConvolutionHloTest, TestConv0D) {
constexpr char kHlo[] = R"(
HloModule TestModule