From a7440c393a7b700fee1e3d16d9b7ce91a8471766 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Tue, 9 Jul 2019 16:20:43 -0700 Subject: [PATCH] [XLA] Extract a common visitor for rewriting instructions. PiperOrigin-RevId: 257293820 --- tensorflow/compiler/xla/service/BUILD | 1 + .../xla/service/algebraic_simplifier.cc | 42 ++------------- .../xla/service/batchnorm_expander.cc | 37 +------------- .../service/bfloat16_conversion_folding.cc | 1 + .../xla/service/bfloat16_normalization.cc | 1 + .../compiler/xla/service/call_inliner.cc | 1 + .../service/dfs_hlo_visitor_with_default.h | 51 +++++++++++++++++-- .../service/gpu/cudnn_batchnorm_rewriter.cc | 2 + .../compiler/xla/service/gpu/gemm_rewriter.cc | 23 ++------- .../compiler/xla/service/hlo_computation.h | 1 - .../compiler/xla/service/hlo_verifier.cc | 1 + .../xla/service/logical_buffer_analysis.h | 1 + .../compiler/xla/tools/hlo_extractor.cc | 1 + 13 files changed, 65 insertions(+), 98 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 2f74a1378cb..880ea14fed5 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -644,6 +644,7 @@ cc_library( hdrs = ["call_inliner.h"], deps = [ ":call_graph", + ":hlo", ":hlo_dce", ":hlo_pass", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 10030ce6491..53a2a57617c 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -168,13 +168,8 @@ bool IsUnstridedSlice(const HloInstruction* hlo) { // algebraic expressions to simplified forms. Note: This only supports // simplifications that simply look at the operands of an instruction. For the // more general case a worklist based approach would be needed. 
-class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { +class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { public: - // Default visitor action is to do nothing and return OK. - Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { - return Status::OK(); - } - Status HandleAdd(HloInstruction* add) override; Status HandleAnd(HloInstruction* logical_and) override; @@ -250,9 +245,6 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleMap(HloInstruction* map) override; - // Returns whether algebraic simplification has occurred. - const bool changed() const { return changed_; } - // Runs the visitor on a computation. static bool Run(HloComputation* computation, const AlgebraicSimplifierOptions& options, @@ -350,35 +342,6 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { StatusOr<bool> TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( HloInstruction* broadcast); - // Replaces the existing HLO instruction old_instruction, with - // new_instruction, and marks the optimizer status as changed. - // Returns the Status representing the result of the replace operation. - Status ReplaceWithNewInstruction( - HloInstruction* old_instruction, - std::unique_ptr<HloInstruction> new_instruction) { - VLOG(3) << "Replacing instruction:"; - VLOG(3) << " old: " << old_instruction->ToString(); - VLOG(3) << " new: " << new_instruction->ToString(); - TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( - old_instruction, std::move(new_instruction))); - changed_ = true; - return Status::OK(); - } - - // Replaces the existing HLO instruction old_instruction, with - // new_instruction, and marks the optimizer status as changed. - // Returns the Status representing the result of the replace operation. 
- Status ReplaceInstruction(HloInstruction* old_instruction, - HloInstruction* new_instruction) { - VLOG(3) << "Replacing instruction:"; - VLOG(3) << " old: " << old_instruction->ToString(); - VLOG(3) << " new: " << new_instruction->ToString(); - TF_RETURN_IF_ERROR( - computation_->ReplaceInstruction(old_instruction, new_instruction)); - changed_ = true; - return Status::OK(); - } - StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot); StatusOr<HloInstruction*> OptimizeDotOfConcatHelper( const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim, @@ -445,7 +408,7 @@ bool AlgebraicSimplifierVisitor::Run(HloComputation* computation, AlgebraicSimplifier* simplifier) { AlgebraicSimplifierVisitor visitor(computation, options, simplifier); TF_CHECK_OK(computation->Accept(&visitor)); - return visitor.changed_; + return visitor.changed_ || visitor.changed(); } bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs, @@ -1723,6 +1686,7 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims( } Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { + CHECK(computation_ == dot->parent()); HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); if (options_.is_layout_sensitive()) { diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index d14e803be6a..131b50efc9c 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -46,13 +46,8 @@ using absl::optional; // BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm // operations into smaller operations. -class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { +class BatchNormExpanderVisitor : public DfsHloRewriteVisitor { public: - // Default visitor action is to do nothing and return OK. 
- Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { - return Status::OK(); - } - Status HandleBatchNormTraining(HloInstruction* batch_norm) override; Status HandleBatchNormInference(HloInstruction* batch_norm) override; @@ -63,9 +58,6 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { static bool Run(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, bool rewrite_grad_op); - // Returns whether any batch norm ops were rewritten. - const bool changed() const { return changed_; } - ~BatchNormExpanderVisitor() override = default; private: @@ -133,28 +125,6 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { elements_per_feature_u32); } - // Replaces the existing HLO instruction old_instruction, with - // new_instruction, and marks the optimizer status as changed. - // Returns the Status representing the result of the replace operation. - Status ReplaceWithNewInstruction( - HloInstruction* old_instruction, - std::unique_ptr<HloInstruction> new_instruction) { - TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( - old_instruction, std::move(new_instruction))); - changed_ = true; - return Status::OK(); - } - - // Replaces the existing HLO instruction old_instruction, with - // new_instruction, and marks the optimizer status as changed. - // Returns the Status representing the result of the replace operation. - Status ReplaceInstruction(HloInstruction* old_instruction, - HloInstruction* new_instruction) { - TF_RETURN_IF_ERROR( - computation_->ReplaceInstruction(old_instruction, new_instruction)); - changed_ = true; - return Status::OK(); - } // Current HloComputation instance the BatchNormExpander is // traversing. HloComputation* computation_; @@ -162,9 +132,6 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { bool rewrite_training_op_; bool rewrite_inference_op_; bool rewrite_grad_op_; - - // Whether rewrite has occurred. 
- bool changed_ = false; }; } // namespace @@ -179,7 +146,7 @@ bool BatchNormExpanderVisitor::Run(HloComputation* computation, /*rewrite_inference_op=*/rewrite_inference_op, /*rewrite_grad_op=*/rewrite_grad_op); TF_CHECK_OK(computation->Accept(&visitor)); - return visitor.changed_; + return visitor.changed(); } Status BatchNormExpanderVisitor::HandleBatchNormTraining( diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index 430172b474c..23d2a9225a8 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index 85e1113bf77..f1ab34d6141 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc index 1d421404440..062110af867 100644 --- a/tensorflow/compiler/xla/service/call_inliner.cc +++ b/tensorflow/compiler/xla/service/call_inliner.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 7f900d9fc55..fcb68a200d9 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -20,6 +20,8 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -29,9 +31,6 @@ limitations under the License. namespace xla { -class HloComputation; -class HloInstruction; - // DfsHloVisitor with default action based on the HloInstruction being visited. // Users should not use this class directly, but use the type aliases // DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead. 
@@ -246,6 +245,52 @@ using DfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase<HloInstruction*>; using ConstDfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase<const HloInstruction*>; +// A common base class for visitors performing rewriting operation. +// +// Subclasses call ReplaceWithNewInstruction and ReplaceInstruction while +// visiting. +class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault { + public: + // Default visitor action is to do nothing and return OK. + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + + bool changed() const { return changed_; } + + protected: + // Replaces the existing HLO instruction old_instruction, with + // new_instruction, and marks the optimizer status as changed. + // Returns the Status representing the result of the replace operation. + Status ReplaceWithNewInstruction( + HloInstruction* old_instruction, + std::unique_ptr<HloInstruction> new_instruction) { + VLOG(3) << "Replacing instruction:"; + VLOG(3) << " old: " << old_instruction->ToString(); + VLOG(3) << " new: " << new_instruction->ToString(); + TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction( + old_instruction, std::move(new_instruction))); + changed_ = true; + return Status::OK(); + } + + // Replaces the existing HLO instruction old_instruction, with + // new_instruction, and marks the optimizer status as changed. + // Returns the Status representing the result of the replace operation. + Status ReplaceInstruction(HloInstruction* old_instruction, + HloInstruction* new_instruction) { + VLOG(3) << "Replacing instruction:"; + VLOG(3) << " old: " << old_instruction->ToString(); + VLOG(3) << " new: " << new_instruction->ToString(); + TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceInstruction( + old_instruction, new_instruction)); + changed_ = true; + return Status::OK(); + } + + bool changed_ = false; +}; + // (Const)FunctionVisitor lets you transform an // std::function<Status((const) HloInstruction*)> into a (Const)DfsHloVisitor. 
// diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc index 2cceb0422d0..4d61f09a7a9 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc @@ -14,8 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" + #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index 36098cbfb72..df7ee3cdc69 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -61,7 +61,7 @@ static complex128 GetScalarConstantAsComplex(const Literal &literal) { // and provided C has no other users). // We then guide the buffer assignment to alias the buffer of the custom call // and C. 
-class GemmRewriterVisitor : public DfsHloVisitorWithDefault { +class GemmRewriterVisitor : public DfsHloRewriteVisitor { public: Status HandleDot(HloInstruction *instr) override { if (IsMatrixMultiplication(*instr)) { @@ -107,9 +107,7 @@ class GemmRewriterVisitor : public DfsHloVisitorWithDefault { config.set_alpha_real(new_alpha.real()); config.set_alpha_imag(new_alpha.imag()); TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(config)); - TF_RETURN_IF_ERROR( - instr->parent()->ReplaceInstruction(instr, existing_gemm)); - changed_ = true; + TF_RETURN_IF_ERROR(ReplaceInstruction(instr, existing_gemm)); } } return Status::OK(); @@ -141,27 +139,12 @@ class GemmRewriterVisitor : public DfsHloVisitorWithDefault { } return Status::OK(); } - - Status DefaultAction(HloInstruction *) override { return Status::OK(); } - - bool IsChanged() { return changed_; } - - private: - Status ReplaceWithNewInstruction( - HloInstruction *instr, std::unique_ptr<HloInstruction> replacement) { - TF_RETURN_IF_ERROR(instr->parent()->ReplaceWithNewInstruction( - instr, std::move(replacement))); - changed_ = true; - return Status::OK(); - } - - bool changed_ = false; }; static StatusOr<bool> RunOnComputation(HloComputation *computation) { GemmRewriterVisitor visitor; TF_RETURN_IF_ERROR(computation->Accept(&visitor)); - return visitor.IsChanged(); + return visitor.changed(); } StatusOr<bool> GemmRewriter::Run(HloModule *module) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index fc4aaedde15..316c3514aeb 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -30,7 +30,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 1de9a66adcb..fa2631bc364 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index 276a157a15a..5f774bb25a6 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" diff --git a/tensorflow/compiler/xla/tools/hlo_extractor.cc b/tensorflow/compiler/xla/tools/hlo_extractor.cc index f3ce5f99b0c..5d681f61ff6 100644 --- a/tensorflow/compiler/xla/tools/hlo_extractor.cc +++ b/tensorflow/compiler/xla/tools/hlo_extractor.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h"