diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 4c0dcbbd2ad..fcf5c9c7952 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -59,6 +59,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/bits.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
@@ -494,6 +495,10 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
   StatusOr<bool> FoldConvInputPad(HloInstruction* convolution);
   StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution);

+  // Tries to swap convolution operands if they would result in a more
+  // efficient convolution.
+  StatusOr<bool> SwapConvOperands(HloInstruction* convolution);
+
   // Tries to use a kDot in place of the given convolution.
   StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution);

@@ -4481,6 +4486,107 @@ StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad(
   return true;
 }

+StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
+    HloInstruction* convolution) {
+  if (!options_.enable_conv_operand_swap() || options_.is_layout_sensitive()) {
+    return false;
+  }
+  if (convolution->feature_group_count() > 1 ||
+      convolution->batch_group_count() > 1) {
+    return false;
+  }
+
+  const auto& dnums = convolution->convolution_dimension_numbers();
+  const auto& window_dims = convolution->window().dimensions();
+  Window swapped_window;
+
+  HloInstruction *input = convolution->mutable_operand(0),
+                 *kernel = convolution->mutable_operand(1);
+  int64 kernel_product = 1;
+  int64 swapped_kernel_product = 1;
+  DimensionVector reverse_dimensions;
+  for (int64 spatial_dim = 0;
+       spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) {
+    const int64 kernel_size = window_dims[spatial_dim].size();
+    kernel_product *= kernel_size;
+    const int64 dilated_kernel_size =
+        1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
+
+    const int64 input_size =
+        input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim));
+    swapped_kernel_product *= input_size;
+    const int64 dilated_input_size =
+        1 + (input_size - 1) * window_dims[spatial_dim].base_dilation();
+
+    auto new_dim = swapped_window.add_dimensions();
+    new_dim->set_size(input_size);
+    // If the kernel is not reversed, the activations must be manually
+    // reversed.
+    if (!window_dims[spatial_dim].window_reversal()) {
+      reverse_dimensions.push_back(
+          dnums.kernel_spatial_dimensions(spatial_dim));
+    }
+    // The input is not originally reversed so it must be reversed to move the
+    // kernel.
+    new_dim->set_window_reversal(true);
+    // Base dilation and window dilation switch places.
+    new_dim->set_base_dilation(window_dims[spatial_dim].window_dilation());
+    new_dim->set_window_dilation(window_dims[spatial_dim].base_dilation());
+    new_dim->set_stride(window_dims[spatial_dim].stride());
+    new_dim->set_padding_low(dilated_input_size +
+                             window_dims[spatial_dim].padding_low() -
+                             dilated_kernel_size);
+    new_dim->set_padding_high(dilated_input_size +
+                              window_dims[spatial_dim].padding_high() -
+                              dilated_kernel_size);
+  }
+
+  // Don't transform if a naive convolution implementation would not have fewer
+  // flops.
+  if (kernel_product <= swapped_kernel_product) {
+    return false;
+  }
+  ConvolutionDimensionNumbers swapped_dnums;
+  *swapped_dnums.mutable_output_spatial_dimensions() =
+      dnums.output_spatial_dimensions();
+  // Swap batch and output feature of the output.
+  swapped_dnums.set_output_batch_dimension(dnums.output_feature_dimension());
+  swapped_dnums.set_output_feature_dimension(dnums.output_batch_dimension());
+
+  // Swap input dnums with kernel dnums
+  *swapped_dnums.mutable_input_spatial_dimensions() =
+      dnums.kernel_spatial_dimensions();
+  swapped_dnums.set_input_batch_dimension(
+      dnums.kernel_output_feature_dimension());
+  swapped_dnums.set_input_feature_dimension(
+      dnums.kernel_input_feature_dimension());
+
+  // Swap kernel dnums with input dnums
+  *swapped_dnums.mutable_kernel_spatial_dimensions() =
+      dnums.input_spatial_dimensions();
+  swapped_dnums.set_kernel_output_feature_dimension(
+      dnums.input_batch_dimension());
+  swapped_dnums.set_kernel_input_feature_dimension(
+      dnums.input_feature_dimension());
+
+  PrecisionConfig precision_config;
+  precision_config.add_operand_precision(
+      convolution->precision_config().operand_precision(1));
+  precision_config.add_operand_precision(
+      convolution->precision_config().operand_precision(0));
+  if (!reverse_dimensions.empty()) {
+    TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions));
+  }
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * new_convolution,
+      MakeConvolveHlo(kernel, input, /*feature_group_count=*/1, swapped_window,
+                      swapped_dnums, precision_config));
+
+  convolution->SetupDerivedInstruction(new_convolution);
+  TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution));
+
+  return true;
+}
+
 StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
     HloInstruction* convolution) {
   auto* lhs = convolution->mutable_operand(0);
@@ -4619,6 +4725,11 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
     return Status::OK();
   }

+  // Try to swap convolution operands.
+  TF_ASSIGN_OR_RETURN(bool swapped, SwapConvOperands(convolution));
+  if (swapped) {
+    return Status::OK();
+  }
   // Try to replace the convolution with a kDot instruction.
   TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution));
   if (replaced_with_dot) {
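A worked check of the padding arithmetic above (the symbols below are ad-hoc names, not identifiers from the patch): for one spatial dimension let I and K be the input and kernel sizes, d_b and d_w the base and window dilations, s the stride, and write I' = 1 + (I - 1) * d_b and K' = 1 + (K - 1) * d_w for the dilated sizes. The original convolution produces

    out = floor((I' + pad_lo + pad_hi - K') / s) + 1

elements in that dimension. After the swap the dilated sizes trade roles (the new dilated input is K', the new dilated kernel is I') and the new padding is pad' = I' + pad - K', so

    out' = floor((K' + (I' + pad_lo - K') + (I' + pad_hi - K') - I') / s) + 1
         = floor((I' + pad_lo + pad_hi - K') / s) + 1
         = out,

i.e. the padding chosen by SwapConvOperands preserves the output spatial shape by construction.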
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h
index d3c276e9bc3..9f29df3c209 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.h
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h
@@ -80,6 +80,12 @@ class AlgebraicSimplifierOptions {
     return enable_conv_simplification_;
   }

+  // Enable convolution operand swapping on platforms where it is supported.
+  void set_enable_conv_operand_swap(bool enable_conv_operand_swap) {
+    enable_conv_operand_swap_ = enable_conv_operand_swap;
+  }
+  bool enable_conv_operand_swap() const { return enable_conv_operand_swap_; }
+
   // If enable_window_reduce_replacement is true, the kReduceWindow instruction
   // can be optimized by replacement with simpler operations.
   void set_enable_window_reduce_to_reduce_replacement(
@@ -139,6 +145,7 @@ class AlgebraicSimplifierOptions {
   bool enable_dot_strength_reduction_{true};
   bool enable_dot_to_multiply_rewrite_{true};
   bool enable_conv_simplification_{true};
+  bool enable_conv_operand_swap_{true};
   bool enable_window_reduce_to_reduce_replacement_{true};
   bool enable_reduce_of_reshape_{true};
   bool replace_transpose_with_bitcast_{true};
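For reference, the following standalone sketch (plain C++, not XLA code; DimParams and SwapDim are made-up names for illustration) mirrors the per-dimension window arithmetic of SwapConvOperands and reproduces the window that the unit test below expects:

#include <cstdint>
#include <iostream>

// Per-dimension window parameters, in the order they appear in an XLA
// WindowDimension: window size, stride, padding, and the two dilations.
struct DimParams {
  int64_t size;             // window (kernel) size
  int64_t stride;
  int64_t padding_low;
  int64_t padding_high;
  int64_t base_dilation;    // LHS (activation) dilation
  int64_t window_dilation;  // RHS (kernel) dilation
};

// Computes the window of the swapped convolution for one spatial dimension,
// following the same formulas as the patch: the dilations trade places, the
// new window size is the old input size, and the padding is adjusted so that
// the output shape is preserved.
DimParams SwapDim(int64_t input_size, const DimParams& d) {
  const int64_t dilated_kernel = 1 + (d.size - 1) * d.window_dilation;
  const int64_t dilated_input = 1 + (input_size - 1) * d.base_dilation;
  DimParams s;
  s.size = input_size;
  s.stride = d.stride;
  s.base_dilation = d.window_dilation;
  s.window_dilation = d.base_dilation;
  s.padding_low = dilated_input + d.padding_low - dilated_kernel;
  s.padding_high = dilated_input + d.padding_high - dilated_kernel;
  return s;
}

int main() {
  // Numbers from the SwapConvOperands unit test below: a 3-wide activation
  // dimension convolved with a 32-wide window, pad 30_30, no dilation.
  DimParams d{/*size=*/32, /*stride=*/1, /*padding_low=*/30,
              /*padding_high=*/30, /*base_dilation=*/1, /*window_dilation=*/1};
  DimParams s = SwapDim(/*input_size=*/3, d);
  std::cout << s.size << " " << s.padding_low << " " << s.padding_high << "\n";
}

Running it prints "3 1 1": the swapped window has size 3 and symmetric padding of 1, matching the EXPECT_EQ checks in the test below.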
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 10b437506b3..dde1bcbf785 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -6414,5 +6414,33 @@ TEST_F(AlgebraicSimplifierTest, ScalarScatter) {
   // Combine Scatters
   ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
 }
+
+TEST_F(AlgebraicSimplifierTest, SwapConvOperands) {
+  const char* hlo_string = R"(
+  HloModule m
+  test {
+    a = f32[3,3,160,160] parameter(0)
+    b = f32[128,32,32,160] parameter(1)
+    ROOT c = f32[128,32,32,160] convolution(a,b),
+     window={size=32x32 pad=30_30x30_30 rhs_reversal=1x1},
+     dim_labels=01bf_o01i->f01b
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+  // Verify that the convolution operands get swapped.
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  const HloInstruction* conv = m->entry_computation()->root_instruction();
+  EXPECT_THAT(conv,
+              GmockMatch(m::Convolution(m::Parameter(1), m::Parameter(0))));
+  EXPECT_EQ(conv->window().dimensions(0).size(), 3);
+  EXPECT_EQ(conv->window().dimensions(1).size(), 3);
+  EXPECT_EQ(conv->window().dimensions(0).window_reversal(), true);
+  EXPECT_EQ(conv->window().dimensions(1).window_reversal(), true);
+  EXPECT_EQ(conv->window().dimensions(0).padding_low(), 1);
+  EXPECT_EQ(conv->window().dimensions(1).padding_low(), 1);
+  EXPECT_EQ(conv->window().dimensions(0).padding_high(), 1);
+  EXPECT_EQ(conv->window().dimensions(1).padding_high(), 1);
+}
+
 }  // namespace
 }  // namespace xla
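In this test the rewrite fires because swapping shrinks the spatial window from 32 * 32 = 1024 taps (kernel_product) to 3 * 3 = 9 (swapped_kernel_product) per output element, while the contraction over the 160 input features is unchanged, so a naive implementation does over a hundred times fewer multiply-adds. Backends that do not want this canonicalization can turn it off with the new option, which is what the GPU pipelines below do.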
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 767c34b3a99..b6c1e671986 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -216,6 +216,7 @@ Status GpuCompiler::OptimizeHloModule(
     // bitcast. This leads to having to linearize and then delinearize the
     // index.
     options.set_replace_transpose_with_bitcast(false);
+    options.set_enable_conv_operand_swap(false);
     pass.AddPass<AlgebraicSimplifier>(options);
     // AlgebraicSimplifier may add contracting dimensions to a dot.
     pass.AddPass<DotDecomposer>();
@@ -321,6 +322,7 @@
     HloPassPipeline pipeline("final_algebraic_simplifier");
     AlgebraicSimplifierOptions options;
     options.set_is_layout_sensitive(true);
+    options.set_enable_conv_operand_swap(false);
     pipeline.AddPass<AlgebraicSimplifier>(options);
     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
   }
@@ -399,6 +401,7 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment(
     // bitcast. This leads to having to linearize and then delinearize the
     // index.
     options.set_replace_transpose_with_bitcast(false);
+    options.set_enable_conv_operand_swap(false);
     pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);

     if (RequireDeterminism() ||
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index d905e56b66f..b4ec6b92e45 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -141,6 +141,7 @@ Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization(
     // bitcast. This leads to having to linearize and then delinearize the
     // index.
     options.set_replace_transpose_with_bitcast(false);
+    options.set_enable_conv_operand_swap(false);
     options.set_cudnn_batchnorm_forward_training_metadata(
         kCudnnBatchNormForwardTrainingCallTarget);
     pass.AddPass<AlgebraicSimplifier>(options);
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 15d3f7f1cbb..c63f1d0edf3 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -2008,6 +2008,47 @@ ENTRY Test {
   EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
 }

+XLA_TEST_F(ConvolutionHloTest, SwappedOperandConvolve) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+ENTRY Test {
+  %lhs = f32[3,3,7,7] parameter(0)
+  %rhs = f32[5,11,11,7] parameter(1)
+  ROOT %convolution = f32[5,21,2,7] convolution(lhs, rhs),
+     window={size=11x11 pad=3_25x3_6},
+     dim_labels=01bf_o01i->f01b
+})";
+  EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
+}
+
+XLA_TEST_F(ConvolutionHloTest, SwappedOperandConvolveWithStride) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+ENTRY Test {
+  %lhs = f32[3,3,7,7] parameter(0)
+  %rhs = f32[5,11,11,7] parameter(1)
+  ROOT %convolution = f32[5,11,2,7] convolution(lhs, rhs),
+     window={size=11x11 pad=3_26x3_6 stride=2x1},
+     dim_labels=01bf_o01i->f01b
+})";
+  EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
+}
+
+XLA_TEST_F(ConvolutionHloTest, SwappedOperandConvolve2) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+ENTRY Test {
+  %lhs = f32[3,3,7,7] parameter(0)
+  %rhs = f32[5,11,11,7] parameter(1)
+  ROOT %convolution = f32[5,11,4,7] convolution(lhs, rhs),
+     window={size=11x11 pad=3_25x3_6 lhs_dilate=1x2 rhs_dilate=2x1},
+     dim_labels=01bf_o01i->f01b
+})";
+  EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
+}
+
 XLA_TEST_F(ConvolutionHloTest, TestConv0D) {
   constexpr char kHlo[] = R"(
 HloModule TestModule