From c05dd7b074a011df79cc781e165da7c767ab5b0f Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Sat, 7 Mar 2020 00:33:38 -0800 Subject: [PATCH] [XLA] Add the ability to disable the Reduce(Reshape(X)) to Reduce(X) rewrite and have TransposeFolding default to folding all operands of Convolution and Dot PiperOrigin-RevId: 299518278 Change-Id: I295cc8ec91d5059b6f760bd4edd945fc10996820 --- .../compiler/xla/service/algebraic_simplifier.cc | 3 ++- .../compiler/xla/service/algebraic_simplifier.h | 7 +++++++ tensorflow/compiler/xla/service/transpose_folding.h | 13 +++++++++++-- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index fd373671b97..1f36d906e73 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -3647,7 +3647,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { // A reshape that collapses multiple dimensions into a dimension being // reduced can just reduce all of those dimensions instead of doing a // collapsing reshape before a reduction. - if (arg->opcode() == HloOpcode::kReshape) { + if (options_.enable_reduce_of_reshape() && + arg->opcode() == HloOpcode::kReshape) { std::vector> unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(), arg->shape()); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index ce364a16134..4251e7eb846 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -107,6 +107,12 @@ class AlgebraicSimplifierOptions { return metadata_.cudnn_batchnorm_forward_training_metadata; } + void set_enable_reduce_of_reshape(bool enable_reduce_of_reshape) { + enable_reduce_of_reshape_ = enable_reduce_of_reshape; + } + + bool enable_reduce_of_reshape() const { return enable_reduce_of_reshape_; } + private: // Metadata struct can be used to store any metadata information encapsulated // with the AlgebraicSimplierOptions that can be later used in an @@ -126,6 +132,7 @@ class AlgebraicSimplifierOptions { bool enable_dot_to_multiply_rewrite_{true}; bool enable_conv_simplification_{true}; bool enable_window_reduce_to_reduce_replacement_{true}; + bool enable_reduce_of_reshape_{true}; int64 very_small_gather_size_{4}; Metadata metadata_; }; diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h index f95f982eb89..ac5e1b80651 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.h +++ b/tensorflow/compiler/xla/service/transpose_folding.h @@ -39,6 +39,13 @@ class TransposeFolding : public HloModulePass { const OperandIndices&) { return {}; } + + // Helper function to always fold transposes. + static OperandIndices AlwaysFoldTranspose(const HloInstruction&, + const OperandIndices& ids) { + return ids; + } + // transposable_gemm_operands returns the set of operands it wants to fold if // the instruction argument is implemented as a GEMM kernel that supports // transposing its arguments. @@ -47,8 +54,10 @@ class TransposeFolding : public HloModulePass { // the instruction argument is implemented as a convolution that supports // transposing its arguments. explicit TransposeFolding( - TransposableGemmOperandsFn transposable_gemm_operands, - TransposableConvOperandsFn transposable_conv_operands); + TransposableGemmOperandsFn transposable_gemm_operands = + AlwaysFoldTranspose, + TransposableConvOperandsFn transposable_conv_operands = + AlwaysFoldTranspose); absl::string_view name() const override { return "transpose-folding"; } StatusOr Run(HloModule* module) override;