From f8655c08cfe3bd99ec1703211e1c9154a14a6150 Mon Sep 17 00:00:00 2001 From: Mark Heffernan Date: Wed, 19 Sep 2018 08:12:29 -0700 Subject: [PATCH] Add interface for HLO passes which run on HloModuleGroup. Derive HloModulePass and HloModuleGroupPass from HloPassInterface which run module-scoped and module-group-scoped respectively. Replace all existing uses of HloPassInterface with HloModulePass because all existing passes are module-scoped. Also rewrite HloPassPipeline to support both module-scoped and module-group-scoped passes. PiperOrigin-RevId: 213629604 --- tensorflow/compiler/xla/service/BUILD | 21 ++ .../xla/service/algebraic_simplifier.h | 2 +- .../xla/service/batch_dot_simplification.h | 2 +- .../compiler/xla/service/batchnorm_expander.h | 2 +- .../xla/service/bfloat16_conversion_folding.h | 2 +- .../xla/service/bfloat16_normalization.h | 4 +- .../xla/service/bfloat16_propagation.h | 2 +- .../compiler/xla/service/call_inliner.h | 2 +- .../xla/service/conditional_simplifier.h | 2 +- .../convolution_feature_group_converter.h | 2 +- .../compiler/xla/service/copy_insertion.h | 2 +- .../xla/service/cpu/conv_canonicalization.h | 2 +- .../xla/service/cpu/cpu_copy_insertion.h | 2 +- .../xla/service/cpu/cpu_hlo_support_checker.h | 2 +- .../service/cpu/parallel_task_assignment.h | 2 +- tensorflow/compiler/xla/service/defuser.h | 2 +- .../compiler/xla/service/despecializer.cc | 2 +- .../compiler/xla/service/despecializer.h | 2 +- .../compiler/xla/service/dot_decomposer.h | 2 +- .../compiler/xla/service/flatten_call_graph.h | 2 +- .../compiler/xla/service/gather_expander.h | 2 +- .../service/gpu/cudnn_batchnorm_rewriter.h | 2 +- .../gpu/cudnn_convolution_algorithm_picker.h | 2 +- .../service/gpu/cudnn_convolution_rewriter.h | 2 +- .../compiler/xla/service/gpu/fusion_merger.h | 2 +- .../xla/service/gpu/gpu_copy_insertion.cc | 9 - .../xla/service/gpu/gpu_copy_insertion.h | 11 +- .../xla/service/gpu/gpu_hlo_support_checker.h | 2 +- .../xla/service/gpu/pad_for_tensor_cores.h | 2 +- .../compiler/xla/service/gpu/pad_insertion.h | 2 +- .../xla/service/hlo_constant_folding.h | 2 +- tensorflow/compiler/xla/service/hlo_cse.h | 2 +- tensorflow/compiler/xla/service/hlo_dce.h | 2 +- .../xla/service/hlo_domain_isolator.h | 2 +- .../compiler/xla/service/hlo_domain_remover.h | 2 +- .../xla/service/hlo_domain_verifier.h | 2 +- .../xla/service/hlo_element_type_converter.h | 2 +- .../xla/service/hlo_memory_scheduler.h | 4 +- .../compiler/xla/service/hlo_module_dce.h | 2 +- .../compiler/xla/service/hlo_pass_interface.h | 35 ++- .../compiler/xla/service/hlo_pass_pipeline.cc | 195 +++++++------ .../compiler/xla/service/hlo_pass_pipeline.h | 38 ++- .../xla/service/hlo_pass_pipeline_test.cc | 259 ++++++++++++++++++ .../xla/service/hlo_rematerialization.cc | 6 + .../xla/service/hlo_rematerialization.h | 2 +- .../service/hlo_subcomputation_unification.h | 2 +- .../compiler/xla/service/hlo_verifier.h | 2 +- .../xla/service/implicit_broadcast_remover.h | 2 +- .../xla/service/indexed_array_analysis.h | 2 +- tensorflow/compiler/xla/service/inliner.h | 2 +- .../compiler/xla/service/instruction_fusion.h | 2 +- .../compiler/xla/service/layout_assignment.h | 2 +- .../xla/service/multi_output_fusion.h | 2 +- .../xla/service/reduce_precision_insertion.h | 2 +- .../compiler/xla/service/reshape_mover.h | 2 +- .../compiler/xla/service/scatter_expander.h | 2 +- .../compiler/xla/service/transpose_folding.h | 2 +- .../compiler/xla/service/tuple_simplifier.h | 2 +- .../xla/service/while_loop_constant_sinking.h | 2 +- 
.../while_loop_invariant_code_motion.h | 2 +- .../xla/service/while_loop_simplifier.h | 2 +- .../xla/service/zero_sized_hlo_elimination.h | 2 +- 62 files changed, 520 insertions(+), 166 deletions(-) create mode 100644 tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 68bf56c1b14..4c3208a2422 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2560,6 +2560,7 @@ cc_library( ], deps = [ ":hlo", + ":hlo_module_group", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -2591,6 +2592,26 @@ cc_library( ], ) +tf_cc_test( + name = "hlo_pass_pipeline_test", + srcs = ["hlo_pass_pipeline_test.cc"], + deps = [ + ":hlo", + ":hlo_parser", + ":hlo_pass_pipeline", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_cse", srcs = ["hlo_cse.cc"], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index b864c372fa5..9f8d0ee88bd 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -24,7 +24,7 @@ limitations under the License. namespace xla { // A pass which performs algebraic simplifications. -class AlgebraicSimplifier : public HloPassInterface { +class AlgebraicSimplifier : public HloModulePass { public: // Given shapes 'from_shape' and 'to_shape', determines if it is valid to // bitcast from 'from_shape' to 'to_shape' after considering platform diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h index 79d37f08d35..5b625bf3b98 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.h +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h @@ -25,7 +25,7 @@ namespace xla { // Normally these would live in the algebraic simplifier, but we want to run // this to fixpoint (this pass reaches fixed point in one execution) before we // run the DotDecomposer. -class BatchDotSimplification : public HloPassInterface { +class BatchDotSimplification : public HloModulePass { public: StatusOr Run(HloModule* module) override; absl::string_view name() const override; diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h index 76e32174f3e..147f3ae7b6d 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -26,7 +26,7 @@ namespace xla { // A pass which rewrites batch norm operations into more operations. Breaking a // big operation into smaller operations helps leverage our generic fusion // logic. -class BatchNormExpander : public HloPassInterface { +class BatchNormExpander : public HloModulePass { public: // When use_fusion is set, a multi-output fusion node is created. 
BatchNormExpander(bool rewrite_training_op = false, diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h index 5dcd31b83d2..cb3d12f0bfd 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h @@ -31,7 +31,7 @@ namespace xla { // optimization pipeline followed by a DCE pass. If other passes are needed // after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the // changed made by this pass. -class BFloat16ConversionFolding : public HloPassInterface { +class BFloat16ConversionFolding : public HloModulePass { public: explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support) : bfloat16_support_(bfloat16_support) {} diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h index 30b63463127..f48e925823c 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.h +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h @@ -25,7 +25,7 @@ namespace xla { // A pass which adds F32 <-> BF16 conversions for HLO instructions that do not // support BF16 input/output or mixed precision, according to the passed-in // backend-specific BF16 support rules. -class BFloat16Normalization : public HloPassInterface { +class BFloat16Normalization : public HloModulePass { public: explicit BFloat16Normalization(const BFloat16Support* bfloat16_support) : bfloat16_support_(bfloat16_support) {} @@ -48,7 +48,7 @@ class BFloat16Normalization : public HloPassInterface { // use mixed precision; it removes mixed precision even if the backend supports // it. This pass is used to make the HLO module valid for other HLO passes which // do not support mixed precision. -class BFloat16MixedPrecisionRemoval : public HloPassInterface { +class BFloat16MixedPrecisionRemoval : public HloModulePass { public: BFloat16MixedPrecisionRemoval() {} diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index 1ee64971ab5..6a62439f887 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -58,7 +58,7 @@ namespace xla { // BFloat16ConversionFolding. If other passes are needed after this pass, run // BFloat16MixedPrecisionRemoval first to undo some of the changes made by this // pass. -class BFloat16Propagation : public HloPassInterface { +class BFloat16Propagation : public HloModulePass { public: explicit BFloat16Propagation(const BFloat16Support* bfloat16_support); diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h index c5cd88b9ea2..08c4aff4f7f 100644 --- a/tensorflow/compiler/xla/service/call_inliner.h +++ b/tensorflow/compiler/xla/service/call_inliner.h @@ -25,7 +25,7 @@ namespace xla { // For every kCall operation in the main computation, we inline the body of the // called function, and proceed recursively. 
-class CallInliner : public HloPassInterface { +class CallInliner : public HloModulePass { public: using InlinedInstructionMap = std::unordered_map; diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h index 3de50cbd7ff..2223ad67534 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.h +++ b/tensorflow/compiler/xla/service/conditional_simplifier.h @@ -25,7 +25,7 @@ namespace xla { // HLO pass that removes kConditional with a constant predicate, replacing them // with their true or false computation as appropriate. -class ConditionalSimplifier : public HloPassInterface { +class ConditionalSimplifier : public HloModulePass { public: absl::string_view name() const override { return "simplify-conditional"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h index 498894737fa..ce0138e56fb 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h @@ -25,7 +25,7 @@ namespace xla { // A pass which rewrites convolutions with feature_group_count > 1 into // convolutions with feature_group_count = 1. -class ConvolutionFeatureGroupConverter : public HloPassInterface { +class ConvolutionFeatureGroupConverter : public HloModulePass { public: ConvolutionFeatureGroupConverter() {} diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index d308f6bc846..c097089e30d 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -43,7 +43,7 @@ namespace xla { // (3) The buffer set of the root instruction of the entry computation must be // unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous and // InstructionAliasSet::IsDistinct return true. -class CopyInsertion : public HloPassInterface { +class CopyInsertion : public HloModulePass { public: absl::string_view name() const override { return "copy-insertion"; } diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h index 59437e88af2..becee3f81fc 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h @@ -31,7 +31,7 @@ namespace cpu { // called canonical convolutions). This pass expands non-canonical convolutions // into reshapes and canonical convolutions, so that these non-canonical // convolutions can run faster. -class ConvCanonicalization : public HloPassInterface { +class ConvCanonicalization : public HloModulePass { public: explicit ConvCanonicalization( const TargetMachineFeatures* target_machine_features) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h index d49f7d7cc2d..076235f8874 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h @@ -30,7 +30,7 @@ namespace xla { // // TODO(b/62548313): Remove this when buffer assignment is smarter // (module-scoped). 
-class CpuCopyInsertion : public HloPassInterface { +class CpuCopyInsertion : public HloModulePass { public: absl::string_view name() const override { return "copy-insertion"; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h index 6af724b2a5d..a39a9d47246 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h @@ -23,7 +23,7 @@ namespace xla { // This pass should run early in the HLO pipeline and checks for HLO constructs // which are not supported by the CPU backend and cannot be removed via HLO // transformations (eg, sparse layouts). -class CpuHloSupportChecker : public HloPassInterface { +class CpuHloSupportChecker : public HloModulePass { public: CpuHloSupportChecker() = default; ~CpuHloSupportChecker() override = default; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index a99cd99c14a..3822d5300e3 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -60,7 +60,7 @@ class ParallelTaskAssignment { // own embedded computation, which is compiled as a parallel compute function, // and which is invoked from a kCall instruction that is lowered in codegen to // a runtime parallel fork/join call. -class ParallelTaskAssigner : public HloPassInterface { +class ParallelTaskAssigner : public HloModulePass { public: // 'max_parallelism': the maximum parallel task count per instruction. // 'shape_size': shape size function used by HloCostAnalysis during parallel diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h index c326beb899f..aaa41fc4fe7 100644 --- a/tensorflow/compiler/xla/service/defuser.h +++ b/tensorflow/compiler/xla/service/defuser.h @@ -25,7 +25,7 @@ namespace xla { // A pass which replaces all fusion instructions with the equivalent un-fused // instructions. -class Defuser : public HloPassInterface { +class Defuser : public HloModulePass { public: Defuser() {} ~Defuser() override {} diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index ba2a674d9af..b3549acfc29 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -24,7 +24,7 @@ namespace xla { namespace { // Pass which strips control dependencies from all instructions in the module. -class ControlDepRemover : public HloPassInterface { +class ControlDepRemover : public HloModulePass { public: ControlDepRemover() = default; absl::string_view name() const override { return "control-dep-remover"; } diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h index 7be70add2f7..46dcc3a438c 100644 --- a/tensorflow/compiler/xla/service/despecializer.h +++ b/tensorflow/compiler/xla/service/despecializer.h @@ -30,7 +30,7 @@ namespace xla { // // Current despecialization passes are Defuser, ImplicitBroadcastRemover, // and BFloat16MixedPrecisionRemoval. 
-class Despecializer : public HloPassInterface { +class Despecializer : public HloModulePass { public: Despecializer(); absl::string_view name() const override { return "despecializer"; } diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h index fc38e317001..40e7a3b4c25 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.h +++ b/tensorflow/compiler/xla/service/dot_decomposer.h @@ -23,7 +23,7 @@ namespace xla { // DotDecomposer is a pass which decomposes batch Dot operations into a // sequence of smaller (R2) Dot operations. -class DotDecomposer : public HloPassInterface { +class DotDecomposer : public HloModulePass { public: // Decomposes batch Dot operations when 'decompose_batch_dot' is true. DotDecomposer(bool decompose_batch_dot = true) diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h index 3cccec9862e..986970f8862 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.h +++ b/tensorflow/compiler/xla/service/flatten_call_graph.h @@ -26,7 +26,7 @@ namespace xla { // Flattening associates each call site with a unique computation (for // sequential calling contexts) This simplifies buffer assignment and // points-to analysis (see b/36865746 for details). -class FlattenCallGraph : public HloPassInterface { +class FlattenCallGraph : public HloModulePass { public: absl::string_view name() const override { return "flatten-call-graph"; } diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h index 7bd9ea59841..2b39359aae9 100644 --- a/tensorflow/compiler/xla/service/gather_expander.h +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -23,7 +23,7 @@ namespace xla { // This pass rewrites gather operations into (roughly) while loops of dynamic // slices. This lets backends that don't support gather directly to // nevertheless have a minimum level of support. -class GatherExpander : public HloPassInterface { +class GatherExpander : public HloModulePass { public: absl::string_view name() const override { return "gather_expander"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h index 6e2e330edd4..c3f58508ddd 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h @@ -52,7 +52,7 @@ namespace gpu { // The GPU backend does not implement a lowering for the batchnorm HLOs -- it // expects them to be lowered to cudnn calls via this pass or to HLO soup via // BatchNormRewriter. -class CudnnBatchNormRewriter : public HloPassInterface { +class CudnnBatchNormRewriter : public HloModulePass { public: absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index f79b113f8fa..ce0189543c9 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -30,7 +30,7 @@ namespace gpu { // Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for // each and adding explicit scratch space to the CustomCalls. 
-class CudnnConvolutionAlgorithmPicker : public HloPassInterface { +class CudnnConvolutionAlgorithmPicker : public HloModulePass { public: // If the `allocator` parameter is not null, we will use it to allocate temp // memory while timing the various convolution algorithms. If it's null, diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h index fbe7e984945..8d7c6fdab51 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h @@ -24,7 +24,7 @@ namespace gpu { // Rewrites plain convolutions, backwards-filter convolutions, and // backwards-input convolutions into CustomCall HLOs that call into cuDNN. -class CudnnConvolutionRewriter : public HloPassInterface { +class CudnnConvolutionRewriter : public HloModulePass { public: absl::string_view name() const override { return "cudnn-convolution-rewriter"; diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h index 7e3f5775b8d..f19996edfe3 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h @@ -32,7 +32,7 @@ namespace gpu { // 2) The result of merging the fusion instruction into its users would not // increase bytes transferred. // -class FusionMerger : public HloPassInterface { +class FusionMerger : public HloModulePass { public: absl::string_view name() const override { return "fusion merger"; } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index 75f414e47fe..79c74e7e8bf 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -34,15 +34,6 @@ namespace xla { namespace gpu { -StatusOr GpuCopyInsertion::FindOrInsertCopy( - HloInstruction* hlo) { - HloInstruction*& copy = hlo_to_copy_map_[hlo]; - if (copy == nullptr) { - TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo)); - } - return copy; -} - StatusOr GpuCopyInsertion::Run(HloModule* module) { CopyInsertion generic_copy_insertion; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h index 8ffae18fe82..4c7e38ffeb6 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h @@ -25,20 +25,11 @@ namespace gpu { // Besides the modifications made by the generic xla::CopyInsertion, this // GPU-specific copy insertion also materializes operands of library calls by // inserting kCopy instructions. -class GpuCopyInsertion : public HloPassInterface { +class GpuCopyInsertion : public HloModulePass { public: absl::string_view name() const override { return "copy-insertion"; } StatusOr Run(HloModule* module) override; - - protected: - // Returns a copy of `hlo`. Looks in hlo_to_copy_map_ first to avoid making - // duplicate copies. - StatusOr FindOrInsertCopy(HloInstruction* hlo); - - // A map containing all copies inserted to materialize operands of library - // calls. The key is the copied instruction and the value is the copy. 
-  tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> hlo_to_copy_map_;
 };
 
 }  // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
index bbb3340760c..9c64b4d10c9 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
@@ -23,7 +23,7 @@ namespace xla {
 // This pass should run early in the HLO pipeline and checks for HLO constructs
 // which are not supported by the GPU backend and cannot be removed via HLO
 // transformations (eg, sparse layouts).
-class GpuHloSupportChecker : public HloPassInterface {
+class GpuHloSupportChecker : public HloModulePass {
  public:
   GpuHloSupportChecker() = default;
   ~GpuHloSupportChecker() override = default;
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
index 11dc56a64fd..e592a3774ec 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
@@ -30,7 +30,7 @@ namespace gpu {
 // targeting before running this pass.
 //
 // TODO(jlebar): Also pad dots.
-class PadForTensorCores : public HloPassInterface {
+class PadForTensorCores : public HloModulePass {
  public:
   absl::string_view name() const override { return "pad for tensor cores"; }
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
index a622e894ed9..25cdf64c4cf 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
@@ -24,7 +24,7 @@ namespace gpu {
 // An HLO pass that canonicalizes convolution instructions for GPU codegen. It
 // inserts Pad instructions before Convolution instructions with uncanonicalized
 // padding, so that they can be lowered to cuDNN convolution.
-class PadInsertion : public HloPassInterface {
+class PadInsertion : public HloModulePass {
  public:
   absl::string_view name() const override { return "pad insertion"; }
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h
index 4557983a9c0..4a624cc7b84 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.h
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h
@@ -23,7 +23,7 @@ namespace xla {
 // A pass which performs constant folding in order to avoid unnecessary
 // computation on constants.
-class HloConstantFolding : public HloPassInterface {
+class HloConstantFolding : public HloModulePass {
  public:
   absl::string_view name() const override { return "constant_folding"; }
diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h
index a28c03599a8..e4857fd3fdd 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.h
+++ b/tensorflow/compiler/xla/service/hlo_cse.h
@@ -25,7 +25,7 @@ namespace xla {
 // and identical instructions with the same operands are commoned. The pass
 // iterates over the instructions in topological order which enables the pass to
 // find arbitrarily large common expressions.
-class HloCSE : public HloPassInterface {
+class HloCSE : public HloModulePass {
  public:
   // If is_layout_sensitive is true, then the simplifier preserves layout during
   // transformation. Otherwise, layout is ignored.
diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h index 1fe69b13957..40120426728 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.h +++ b/tensorflow/compiler/xla/service/hlo_dce.h @@ -33,7 +33,7 @@ namespace xla { // // This pass does not remove dead parameter instructions, as parameter // instructions cannot be deleted. -class HloDCE : public HloPassInterface { +class HloDCE : public HloModulePass { public: ~HloDCE() override {} absl::string_view name() const override { return "dce"; } diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h index d36631fc2f1..c0bf1b9e16b 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -30,7 +30,7 @@ namespace xla { // used to break an HLO graph edge connecting two instructions with different // sharding. If a set of connected instructions have all the same sharding, no // kDomain instruction will be placed. -class HloDomainIsolator : public HloPassInterface { +class HloDomainIsolator : public HloModulePass { public: // Creates a new kDomain instruction for the edge between the use instruction // (the first HloInstruction argument), and the operand instruction (the diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h index 97bc8ef6040..0fc30fb86c3 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_remover.h +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h @@ -26,7 +26,7 @@ namespace xla { // Removes all the kDomain instructions of a given kind from the input module, // and calls the normalizer to propagate the properties on the possibly new born // instructions. -class HloDomainRemover : public HloPassInterface { +class HloDomainRemover : public HloModulePass { public: // Creates a new HloDomainRemover object tasked at removing all the kDomain // instructions of a given kind. diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h index 81d6d69a8c5..bea5cba38d0 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h @@ -29,7 +29,7 @@ namespace xla { // Verifies that the domain instructions are consistent, and the each domain is // surrounded by the same metadata. -class HloDomainVerifier : public HloPassInterface { +class HloDomainVerifier : public HloModulePass { public: HloDomainVerifier(std::vector kinds) : kinds_(std::move(kinds)) {} diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h index 44ded2c2faf..4d2a9429252 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.h +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h @@ -25,7 +25,7 @@ namespace xla { // inserting Convert ops. This allows a backend to support an element type while // only actually implementing the Convert op for that element type. This is // generally not the fastest approach, but it works. -class HloElementTypeConverter : public HloPassInterface { +class HloElementTypeConverter : public HloModulePass { public: // eliminate_type is the type to eliminate as the input or output of ops, // using Convert ops to replace it with replace_with_type. 
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
index 5e02868ebad..9964c6fdd7c 100644
--- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
@@ -90,7 +90,7 @@ StatusOr<HloInstructionSequence> ScheduleComputation(
 // A pass which schedules the HLO instructions in a module. The HloModule's
 // schedule field is set to the resulting HloSchedule using
 // HloModule::set_schedule.
-class HloMemoryScheduler : public HloPassInterface {
+class HloMemoryScheduler : public HloModulePass {
  public:
   // size_function is the function returning the number of bytes required for a
   // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not
@@ -109,7 +109,7 @@ class HloMemoryScheduler : public HloPassInterface {
 
 // A trivial pass which clears the schedule currently set on the
 // HloModule. After this pass runs HloModule::has_schedule will return false.
-class HloDescheduler : public HloPassInterface {
+class HloDescheduler : public HloModulePass {
  public:
   HloDescheduler() = default;
   ~HloDescheduler() override = default;
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h
index 12ca2340a6c..d472211d2af 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce.h
+++ b/tensorflow/compiler/xla/service/hlo_module_dce.h
@@ -28,7 +28,7 @@ namespace xla {
 // Sweeps through live instructions which cross computation boundaries (kWhile),
 // and removes code at dead shape indices.
 //
-class HloModuleDCE : public HloPassInterface {
+class HloModuleDCE : public HloModulePass {
  public:
   ~HloModuleDCE() override {}
   absl::string_view name() const override { return "hlo-module-dce"; }
diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h
index f1ad0f9b014..fdaac34386c 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_interface.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h
@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_
 
 #include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -25,15 +26,45 @@ namespace xla {
 
 // Base class for HLO passes. These are used with the HloPassPipeline to
-// organize a sequence of passes.
+// organize a sequence of passes. An HLO pass should not extend this class
+// directly; it should extend HloModulePass or HloModuleGroupPass.
 class HloPassInterface {
  public:
   virtual ~HloPassInterface() = default;
   virtual absl::string_view name() const = 0;
 
-  // Run the pass on the given HLO module. Return whether it modified the
+  // Run the pass on the given HLO module. Returns whether it modified the
   // module.
   virtual StatusOr<bool> Run(HloModule* module) = 0;
+
+  // Run the pass on the given HLO module group. Returns whether it modified the
+  // module group. Ideally, the module group variant would be named "Run" as
+  // well, but C++ does not handle overloaded virtual methods well.
+  virtual StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) = 0;
+};
+
+// Base class for passes which are module-scoped.
+class HloModulePass : public HloPassInterface {
+ public:
+  // Runs the pass on a module group by iterating through each module in the
+  // group.
+  StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override {
+    bool changed = false;
+    for (HloModule* module : module_group->modules()) {
+      TF_ASSIGN_OR_RETURN(bool module_changed, Run(module));
+      changed |= module_changed;
+    }
+    return changed;
+  }
+};
+
+// Base class for passes which are module-group scoped. These passes cannot run
+// on an HLO module.
+class HloModuleGroupPass : public HloPassInterface {
+ public:
+  StatusOr<bool> Run(HloModule* module) override {
+    return InternalError("Module group pass cannot be run on a module");
+  }
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index 6e4ed0de626..8c2f928ca10 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -17,7 +17,6 @@ limitations under the License.
 
 #include <functional>
 
-#include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/str_join.h"
 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
@@ -29,108 +28,128 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 
 namespace xla {
-namespace {
-
-using absl::StrAppend;
-using absl::StrCat;
+template <typename HloT>
+Status HloPassPipeline::RunInvariantCheckers(
+    HloT* hlo, absl::string_view after_pass_name) {
+  for (auto& invariant_checker : invariant_checkers_) {
+    VLOG(1) << "  Invariant checker " << invariant_checker->name();
+    StatusOr<bool> changed_status = RunHelper(invariant_checker.get(), hlo);
+    VLOG(1) << "  Invariant checker done " << invariant_checker->name();
+    if (!changed_status.ok()) {
+      VLOG(2) << "Failed invariant check:";
+      XLA_VLOG_LINES(2, hlo->ToString());
+      return Status(changed_status.status().code(),
+                    absl::StrCat(changed_status.status().error_message(),
+                                 "\n\nFailed after ", after_pass_name));
+    }
+    TF_RET_CHECK(!changed_status.ValueOrDie())
+        << "invariant checkers must not change the graph";
+  }
+  return Status::OK();
+}
 
-void DumpModuleGraph(const HloModule& module, const string& message) {
+template <typename HloT>
+StatusOr<bool> HloPassPipeline::RunPassesInternal(
+    HloT* hlo, absl::Span<HloPassInterface* const> passes) {
+  string last_pass_name = "pipeline-start";
+  TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name));
+  bool changed = false;
+  for (HloPassInterface* pass : passes) {
+    VLOG(1) << "  HLO pass " << pass->name();
+    MaybeDumpHlo(*hlo,
+                 /*after_pass_name=*/last_pass_name,
+                 /*before_pass_name=*/pass->name());
+    TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
+    changed |= pass_changed;
+    TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass->name()));
+    last_pass_name = string(pass->name());
+  }
+  MaybeDumpHlo(*hlo,
+               /*after_pass_name=*/last_pass_name,
+               /*before_pass_name=*/"pipeline-end");
+  return changed;
+}
+
+std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
+    const DebugOptions& debug_options) {
+  auto repeated_field = debug_options.xla_disable_hlo_passes();
+  tensorflow::gtl::FlatSet<string> disabled_pass_names(repeated_field.begin(),
+                                                       repeated_field.end());
+  if (!disabled_pass_names.empty()) {
+    VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
+            << absl::StrJoin(disabled_pass_names, ", ");
+  }
+
+  std::vector<HloPassInterface*> enabled_passes;
+  for (auto& pass : passes_) {
+    if (disabled_pass_names.count(string(pass->name())) == 0) {
+      enabled_passes.push_back(pass.get());
+    }
+  }
+  return enabled_passes;
+}
+
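// Illustration (not part of the patch): a minimal sketch of how the two new
// base classes are intended to be used, assuming only the interface declared
// in hlo_pass_interface.h above. The class names below are hypothetical.
//
//   class MyLocalCleanup : public HloModulePass {
//    public:
//     absl::string_view name() const override { return "my-local-cleanup"; }
//     StatusOr<bool> Run(HloModule* module) override {
//       // Rewrites one module at a time; the inherited RunOnModuleGroup
//       // simply applies Run to every module in the group.
//       return false;
//     }
//   };
//
//   class MyCrossModulePass : public HloModuleGroupPass {
//    public:
//     absl::string_view name() const override { return "my-cross-module"; }
//     StatusOr<bool> RunOnModuleGroup(HloModuleGroup* group) override {
//       // Sees all modules of the group at once (e.g. to reason about
//       // cross-module communication); the inherited Run(HloModule*) fails
//       // by design.
//       return false;
//     }
//   };
//
// Either kind can be added to an HloPassPipeline via AddPass<T>(); the
// RunHelper overloads declared in hlo_pass_pipeline.h below dispatch to Run
// or RunOnModuleGroup depending on whether the pipeline is driven with a
// module or a module group.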
+void HloPassPipeline::MaybeDumpHlo(const HloModule& module,
+                                   absl::string_view after_pass_name,
+                                   absl::string_view before_pass_name) {
+  const string& proto_dump_path =
+      module.config().debug_options().xla_dump_per_pass_hlo_proto_to();
+  if (!proto_dump_path.empty()) {
+    static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
+    static auto* const module_id_to_pass_number =
+        new tensorflow::gtl::FlatMap<int64, int64>();
+
+    tensorflow::mutex_lock lock(mu);
+    const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
+
+    const string filename = SanitizeFileName(
+        absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
+                        pass_number, name(), after_pass_name));
+
+    TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(
+        MakeHloProto(module), proto_dump_path, filename));
+  }
+
+  const string message =
+      StrCat("after ", after_pass_name, ", before ", before_pass_name);
   hlo_graph_dumper::MaybeDumpHloModule(module, message);
   VLOG(3) << "HLO " << message << ":";
   XLA_VLOG_LINES(3, module.ToString());
 }
 
-void DumpModuleProto(const HloModule& module, const string& dump_to,
-                     const string& pipeline_name, const string& pass_name) {
-  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
-  static auto* const module_id_to_pass_number =
-      new tensorflow::gtl::FlatMap<int64, int64>();
-
-  tensorflow::mutex_lock lock(mu);
-  const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
-
-  const string mod_name = SanitizeFileName(
-      absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
-                      pass_number, pipeline_name, pass_name));
-
-  TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module),
-                                                   dump_to, mod_name));
+void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group,
+                                   absl::string_view after_pass_name,
+                                   absl::string_view before_pass_name) {
+  for (const HloModule* module : module_group.modules()) {
+    MaybeDumpHlo(*module, after_pass_name, before_pass_name);
+  }
 }
-}  // namespace
 
 StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
   run_called_ = true;
 
-  VLOG(1) << "Running HLO pass pipeline " << name();
+  VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": "
+          << name();
 
-  auto repeated_field =
-      module->config().debug_options().xla_disable_hlo_passes();
-  tensorflow::gtl::FlatSet<string> disabled_passes(repeated_field.begin(),
-                                                   repeated_field.end());
-  if (!disabled_passes.empty()) {
-    VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
-            << absl::StrJoin(disabled_passes, ", ");
+  return RunPassesInternal(module,
+                           GetEnabledPasses(module->config().debug_options()));
+}
+
+StatusOr<bool> HloPassPipeline::RunOnModuleGroup(HloModuleGroup* module_group) {
+  run_called_ = true;
+
+  VLOG(1) << "Running HLO pass pipeline on module group "
+          << module_group->name() << ": " << name();
+
+  if (module_group->modules().empty()) {
+    VLOG(1) << "Module group is empty. Nothing to do.";
+    return false;
   }
 
-  auto run_invariant_checkers = [this,
-                                 module](const string& message) -> Status {
-    for (auto& invariant_checker : invariant_checkers_) {
-      VLOG(1) << "  Invariant checker " << invariant_checker->name();
-      StatusOr<bool> changed_status = invariant_checker->Run(module);
-      VLOG(1) << "  Invariant checker done " << invariant_checker->name();
-      if (!changed_status.ok()) {
-        VLOG(2) << "Module failed invariant check:";
-        XLA_VLOG_LINES(2, module->ToString());
-        return Status(changed_status.status().code(),
-                      StrCat(changed_status.status().error_message(),
-                             "\n\nFailed ", message));
-      }
-      TF_RET_CHECK(!changed_status.ValueOrDie())
-          << "invariant checkers must not change the graph";
-    }
-    return Status::OK();
-  };
-
-  string prefix = StrCat(name(), ": pipeline start");
-  bool changed = false;
-  string message;
-  TF_RETURN_IF_ERROR(
-      run_invariant_checkers(StrCat("before running pipeline: ", name())));
-  const string xla_dump_per_pass_hlo_proto_to =
-      module->config().debug_options().xla_dump_per_pass_hlo_proto_to();
-  if (!xla_dump_per_pass_hlo_proto_to.empty()) {
-    DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
-                    "pipeline_start");
-  }
-
-  for (auto& pass : passes_) {
-    if (disabled_passes.count(string(pass->name())) > 0) {
-      VLOG(1) << "  Skipping HLO pass " << pass->name()
-              << ", disabled by --xla_disable_hlo_passes";
-      continue;
-    }
-
-    VLOG(1) << "  HLO pass " << pass->name();
-
-    // Emit label containing: "after foo-pass, before bar-pass".
-    message.clear();
-    StrAppend(&message, prefix, ", before ", pass->name());
-    DumpModuleGraph(*module, message);
-
-    TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module));
-    TF_RETURN_IF_ERROR(
-        run_invariant_checkers(StrCat("after running pass: ", pass->name())));
-    if (!xla_dump_per_pass_hlo_proto_to.empty()) {
-      DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
-                      string(pass->name()));
-    }
-
-    changed |= changed_this_pass;
-    prefix.clear();
-    StrAppend(&prefix, name(), ": after ", pass->name());
-  }
-  DumpModuleGraph(*module, prefix + ", pipeline end");
-  return changed;
+  return RunPassesInternal(
+      module_group,
+      GetEnabledPasses(module_group->module(0).config().debug_options()));
 }
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
index 1d41a4dac1d..09e7033ea4e 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
@@ -22,6 +22,7 @@ limitations under the License.
 #include <vector>
 
 #include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -61,10 +62,45 @@ class HloPassPipeline : public HloPassInterface {
     return *pass;
   }
 
-  // Run all passes on the given HLO module.
   StatusOr<bool> Run(HloModule* module) override;
+  StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override;
 
  private:
+  // Returns the set of passes which are enabled. DebugOptions can selectively
+  // disable passes via --xla_disable_hlo_passes flag.
+  std::vector<HloPassInterface*> GetEnabledPasses(
+      const DebugOptions& debug_options);
+
+  // Maybe dumps the given module or module group depending on flag values
+  // contained in DebugOptions of module config.
+  void MaybeDumpHlo(const HloModuleGroup& module_group,
+                    absl::string_view after_pass_name,
+                    absl::string_view before_pass_name);
+  void MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name,
+                    absl::string_view before_pass_name);
+
+  // Runs the invariant checker on the given HLO. HloT can be either HloModule
+  // or HloModuleGroup.
+  template <typename HloT>
+  Status RunInvariantCheckers(HloT* hlo, absl::string_view after_pass_name);
+
+  // Helper which runs the given pass on the given HLO. HloT can be either
+  // HloModule or HloModuleGroup.
+  template <typename HloT>
+  StatusOr<bool> RunPassesInternal(HloT* hlo,
+                                   absl::Span<HloPassInterface* const> passes);
+
+  // Helpers which run the given passes on the given HLO construct. These
+  // helpers enable templating of the core of the pipeline logic by providing
+  // HloModule and HloModuleGroup specific methods with the same name.
+  static StatusOr<bool> RunHelper(HloPassInterface* pass, HloModule* module) {
+    return pass->Run(module);
+  }
+  static StatusOr<bool> RunHelper(HloPassInterface* pass,
+                                  HloModuleGroup* module_group) {
+    return pass->RunOnModuleGroup(module_group);
+  }
+
   const string name_;
   std::vector<std::unique_ptr<HloPassInterface>> passes_;
   std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc
new file mode 100644
index 00000000000..e16b4d4c0a4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc
@@ -0,0 +1,259 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class HloPassPipelineTest : public HloTestBase {
+ protected:
+  StatusOr<HloModuleGroup> ParseModuleGroup(
+      absl::Span<const string> hlo_strings) {
+    HloModuleGroup group(TestName());
+    for (const string& hlo_string : hlo_strings) {
+      TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
+                          ParseHloString(hlo_string));
+      group.push_back(std::move(module));
+    }
+    return std::move(group);
+  }
+};
+
+// A module pass which renames instructions named 'foo' to 'bar'.
+class FooToBarModulePass : public HloModulePass {
+  absl::string_view name() const override { return "foo2bar"; }
+
+  StatusOr<bool> Run(HloModule* module) override {
+    bool changed = false;
+    for (HloComputation* computation : module->computations()) {
+      for (HloInstruction* instruction : computation->instructions()) {
+        if (instruction->name() == "foo") {
+          instruction->SetAndSanitizeName("bar");
+          changed = true;
+        }
+      }
+    }
+    return changed;
+  }
+};
+
+// A module group pass which renames instructions named 'baz' to 'qux'.
+class BazToQuxModuleGroupPass : public HloModuleGroupPass {
+  absl::string_view name() const override { return "baz2qux"; }
+
+  StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override {
+    bool changed = false;
+    for (HloModule* module : module_group->modules()) {
+      for (HloComputation* computation : module->computations()) {
+        for (HloInstruction* instruction : computation->instructions()) {
+          if (instruction->name() == "baz") {
+            instruction->SetAndSanitizeName("qux");
+            changed = true;
+          }
+        }
+      }
+    }
+    return changed;
+  }
+};
+
+// An invariant checker pass which returns an error if there exists an
+// instruction named 'bar'.
+class BarBlowerUpper : public HloModulePass {
+  absl::string_view name() const override { return "bar-blower-upper"; }
+
+  StatusOr<bool> Run(HloModule* module) override {
+    for (HloComputation* computation : module->computations()) {
+      for (HloInstruction* instruction : computation->instructions()) {
+        if (instruction->name() == "bar") {
+          return InternalError("Module has instruction named bar");
+        }
+      }
+    }
+    return false;
+  }
+};
+
+TEST_F(HloPassPipelineTest, ModulePassChanged) {
+  // Test an HLO module pass which changes a module.
+  const string module_str = R"(
+HloModule ModulePassChanged
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT foo = f32[] multiply(a, b)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+  HloPassPipeline pipeline(TestName());
+  pipeline.AddPass<FooToBarModulePass>();
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_EQ(root->name(), "foo");
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_EQ(root->name(), "bar");
+}
+
+TEST_F(HloPassPipelineTest, ModulePassUnchanged) {
+  // Test an HLO module pass which does not change a module.
+  const string module_str = R"(
+HloModule ModulePassUnchanged
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT blahblah = f32[] multiply(a, b)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+  HloPassPipeline pipeline(TestName());
+  pipeline.AddPass<FooToBarModulePass>();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(HloPassPipelineTest, MixedPipeline) {
+  // Test a pipeline with both a module pass and a module group pass.
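  // A note on semantics (not part of the patch): when this pipeline is driven
  // through RunOnModuleGroup, the module pass reaches each module via its
  // inherited RunOnModuleGroup, while the module group pass sees the whole
  // group at once, so both renames below are expected in a single run.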
+  const string module_0_str = R"(
+HloModule MixedPipeline.1
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT baz = f32[] multiply(a, b)
+}
+)";
+  const string module_1_str = R"(
+HloModule MixedPipeline.0
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT foo = f32[] multiply(a, b)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup module_group,
+                          ParseModuleGroup({module_0_str, module_1_str}));
+
+  HloPassPipeline pipeline(TestName());
+  pipeline.AddPass<FooToBarModulePass>();
+  pipeline.AddPass<BazToQuxModuleGroupPass>();
+
+  HloInstruction* root0 =
+      module_group.module(0).entry_computation()->root_instruction();
+  HloInstruction* root1 =
+      module_group.module(1).entry_computation()->root_instruction();
+  EXPECT_EQ(root0->name(), "baz");
+  EXPECT_EQ(root1->name(), "foo");
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          pipeline.RunOnModuleGroup(&module_group));
+  EXPECT_TRUE(changed);
+
+  EXPECT_EQ(root0->name(), "qux");
+  EXPECT_EQ(root1->name(), "bar");
+}
+
+TEST_F(HloPassPipelineTest, InvariantChecker) {
+  const string module_str = R"(
+HloModule InvariantChecker
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT foo = f32[] multiply(a, b)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+  {
+    // Run a pipeline with just the invariant checker. It should not fail
+    // because there is no 'bar' instruction in the module.
+    HloPassPipeline pipeline(TestName());
+    pipeline.AddInvariantChecker<BarBlowerUpper>();
+
+    TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+    EXPECT_FALSE(changed);
+  }
+
+  {
+    // Run a pipeline which renames 'foo' to 'bar' then an invariant checker
+    // which fails if there is an instruction named 'bar'.
+    HloPassPipeline pipeline(TestName());
+    pipeline.AddInvariantChecker<BarBlowerUpper>();
+    pipeline.AddPass<FooToBarModulePass>();
+
+    Status status = pipeline.Run(module.get()).status();
+    ASSERT_IS_NOT_OK(status);
+    EXPECT_THAT(status.error_message(),
+                ::testing::HasSubstr("Module has instruction named bar"));
+    EXPECT_THAT(status.error_message(),
+                ::testing::HasSubstr("Failed after foo2bar"));
+  }
+
+  {
+    // Run the invariant-checker only pipeline again. It should fail this time.
+    HloPassPipeline pipeline(TestName());
+    pipeline.AddInvariantChecker<BarBlowerUpper>();
+
+    Status status = pipeline.Run(module.get()).status();
+    ASSERT_IS_NOT_OK(status);
+    EXPECT_THAT(status.error_message(),
+                ::testing::HasSubstr("Module has instruction named bar"));
+    EXPECT_THAT(status.error_message(),
+                ::testing::HasSubstr("Failed after pipeline-start"));
+  }
+}
+
+TEST_F(HloPassPipelineTest, ModuleGroupPassOnModule) {
+  // Running a module group pass on a module should produce an error.
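  // (This relies on HloModuleGroupPass::Run, defined above to return
  // InternalError unconditionally, so a pipeline driven with a bare HloModule
  // fails fast when it reaches the group pass.)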
+  const string module_str = R"(
+HloModule ModuleGroupPassOnModule
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT foo = f32[] multiply(a, b)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+  HloPassPipeline pipeline(TestName());
+  pipeline.AddPass<BazToQuxModuleGroupPass>();
+
+  Status status = pipeline.Run(module.get()).status();
+  ASSERT_IS_NOT_OK(status);
+  EXPECT_THAT(
+      status.error_message(),
+      ::testing::HasSubstr("Module group pass cannot be run on a module"));
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index bd6dd79b679..a4386719362 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -1198,6 +1198,12 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
           << HumanReadableNumBytes(memory_limit_bytes_);
   XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
 
+  // Initialize pass object state.
+  computation_peak_memory_.clear();
+  rematerialized_computations_.clear();
+  instructions_rematerialized_ = 0;
+  net_instructions_added_ = 0;
+
   TF_RET_CHECK(module->has_schedule());
   TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
 
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index e2aaf18b3e4..7330d73c09e 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -33,7 +33,7 @@ namespace xla {
 // CSE will undo the effects of this optimization and should not be run after
 // this pass. In general, this pass should be run very late, immediately before
 // code generation.
-class HloRematerialization : public HloPassInterface {
+class HloRematerialization : public HloModulePass {
  public:
   using ShapeSizeFunction = std::function<int64(const Shape&)>;
 
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
index d1cf644f827..fa34bddde1a 100644
--- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
+++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
@@ -22,7 +22,7 @@ namespace xla {
 // Unify subcomputations of a `HloModule`: if any computations are equal, choose
 // one arbitrarily to use and delete the others.
-class HloSubcomputationUnification : public HloPassInterface {
+class HloSubcomputationUnification : public HloModulePass {
  public:
   absl::string_view name() const override {
     return "subcomputation-unification";
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 42e3027bf14..0cde4a31af7 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -151,7 +151,7 @@ class ShapeVerifier : public DfsHloVisitor {
 // HLO pass that verifies invariants of HLO instructions for each computation in
 // the module.
-class HloVerifier : public HloPassInterface { +class HloVerifier : public HloModulePass { public: using ShapeVerifierFactory = std::function()>; diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h index 85bb4a8b245..9c48b7db613 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h @@ -25,7 +25,7 @@ namespace xla { // Pass which replaces all implicit broadcasts with their equivalent sequence of // explicit broadcast and reshape instructions. -class ImplicitBroadcastRemover : public HloPassInterface { +class ImplicitBroadcastRemover : public HloModulePass { public: ImplicitBroadcastRemover() {} ~ImplicitBroadcastRemover() override {} diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index df9cbab915c..3e238f97a03 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -366,7 +366,7 @@ class IndexedArrayAnalysis { // A pass that prints all non-trivial results returned by IndexedArrayAnalysis. // This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to // unconditionally add to the regular HLO pass pipeline. -class IndexedArrayAnalysisPrinterPass : public HloPassInterface { +class IndexedArrayAnalysisPrinterPass : public HloModulePass { public: absl::string_view name() const override; StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/inliner.h index efa8ed3abcc..e20af08fb73 100644 --- a/tensorflow/compiler/xla/service/inliner.h +++ b/tensorflow/compiler/xla/service/inliner.h @@ -24,7 +24,7 @@ namespace xla { // A pass which performs inlining. Which can result, for example, in functions // that were previously being mapped by Map instead directly applied to the // forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)). -class Inliner : public HloPassInterface { +class Inliner : public HloModulePass { public: ~Inliner() override = default; absl::string_view name() const override { return "inline"; } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index c1fde8ecfc0..7e1196fb7fb 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -56,7 +56,7 @@ class FusionQueue { // with the intent that the loops which compute their values will be fused in // code generation. Derived classes define ShouldFuse method to select which // instructions to fuse. -class InstructionFusion : public HloPassInterface { +class InstructionFusion : public HloModulePass { public: explicit InstructionFusion( std::function is_expensive, diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index cf545031d3c..e29c199c42a 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -281,7 +281,7 @@ class ChannelLayoutConstraints { // HLO pass which assigns layouts to all instructions in the HLO module while // satisfying all necessary invariants and minimizing cost. 
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index cf545031d3c..e29c199c42a 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -281,7 +281,7 @@ class ChannelLayoutConstraints {
 
 // HLO pass which assigns layouts to all instructions in the HLO module while
 // satisfying all necessary invariants and minimizing cost.
-class LayoutAssignment : public HloPassInterface {
+class LayoutAssignment : public HloModulePass {
  public:
   // entry_computation_layout is modified to populate a layout for the result in
   // the case that no particular layout is requested.
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index d2c52651c4f..0344626b26b 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -44,7 +44,7 @@ namespace xla {
 // Note that the reachability map is updated based on the original computation.
 // This works because the reachability is monotonically increasing with
 // instruction fusion.
-class MultiOutputFusion : public HloPassInterface {
+class MultiOutputFusion : public HloModulePass {
  public:
   MultiOutputFusion(int64 fuel) : fuel_(fuel) {}
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
index 256b231e3af..4bb22428f3d 100644
--- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
@@ -29,7 +29,7 @@ namespace xla {
 // HLO pass which inserts reduce-precision instructions into the HLO graph, for
 // purposes of experimenting with the effects of reduced-precision storage of
 // intermediate values.
-class ReducePrecisionInsertion : public HloPassInterface {
+class ReducePrecisionInsertion : public HloModulePass {
   using InstructionFilterFunction = std::function<bool(const HloInstruction*)>;
 
  public:
diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h
index 1e86a0823a5..a3db439e340 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.h
+++ b/tensorflow/compiler/xla/service/reshape_mover.h
@@ -24,7 +24,7 @@ namespace xla {
 // This now only moves them outputward across elementwise ops all of whose
 // operands are equivalent Reshapes or Transposes, but in the future it could
 // potentially move them inputward as well.
-class ReshapeMover : public HloPassInterface {
+class ReshapeMover : public HloModulePass {
  public:
   absl::string_view name() const override { return "reshape-mover"; }
diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h
index 14f062c89cf..559a85dccfe 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.h
+++ b/tensorflow/compiler/xla/service/scatter_expander.h
@@ -20,7 +20,7 @@ limitations under the License.
 
 namespace xla {
 
-class ScatterExpander : public HloPassInterface {
+class ScatterExpander : public HloModulePass {
  public:
   absl::string_view name() const override { return "scatter_expander"; }
   StatusOr<bool> Run(HloModule* module) override;
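[Editorial note] On the hlo_rematerialization.cc hunk near the top of this section: clearing computation_peak_memory_ and the other members at the start of Run() makes a single HloRematerialization instance safe to reuse. An illustrative scenario of what the reset guards against (constructor arguments elided; module_a and module_b are hypothetical):

// One pass object, run on two different modules.
HloRematerialization remat(/*...*/);
// The first run populates per-computation bookkeeping for module_a.
TF_CHECK_OK(remat.Run(module_a.get()).status());
// Without the reset added in this patch, the second run would observe stale
// entries left over from module_a; with it, Run starts from a clean slate.
TF_CHECK_OK(remat.Run(module_b.get()).status());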
diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h
index 3e5aa2db60e..f95f982eb89 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.h
+++ b/tensorflow/compiler/xla/service/transpose_folding.h
@@ -23,7 +23,7 @@ namespace xla {
 // HLO pass that folds transpose operators into Dot operators, where the Dot
 // operator is implemented by a GEMM kernel that can transpose its inputs.
-class TransposeFolding : public HloPassInterface {
+class TransposeFolding : public HloModulePass {
  public:
   using OperandIndices = std::vector<int64>;
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h
index 8c91d6e69de..e126a530234 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier.h
+++ b/tensorflow/compiler/xla/service/tuple_simplifier.h
@@ -25,7 +25,7 @@ namespace xla {
 
 // A pass which simplifies patterns of Tuple and GetTupleElement instructions in
 // the module.
-class TupleSimplifier : public HloPassInterface {
+class TupleSimplifier : public HloModulePass {
  public:
   TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {}
   explicit TupleSimplifier(bool exclude_entry_computation);
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
index 2dba7d7f757..577bad6c706 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
@@ -50,7 +50,7 @@ namespace xla {
 // conditions as well.
 //
 // TODO(b/79121449): We should also sink broadcasts of constants.
-class WhileLoopConstantSinking : public HloPassInterface {
+class WhileLoopConstantSinking : public HloModulePass {
  public:
   ~WhileLoopConstantSinking() override = default;
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
index 2cdf20ce803..3031899f71e 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
@@ -25,7 +25,7 @@ namespace xla {
 
 // HLO pass that rewrites while loops to hoist loop invariant instructions in
 // the while body into the computation that contains the while instruction.
-class WhileLoopInvariantCodeMotion : public HloPassInterface {
+class WhileLoopInvariantCodeMotion : public HloModulePass {
  public:
   // If `hoist_constants` is true then constants are always hoisted out of while
   // loop bodies. Otherwise they are only hoisted out if they enable other
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h
index 78024f14dc8..0bc5a0107bb 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.h
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h
@@ -30,7 +30,7 @@ namespace xla {
 // - Elements of a while loop's tuple that the loop doesn't use are removed
 //   from the tuple.
 //
-class WhileLoopSimplifier : public HloPassInterface {
+class WhileLoopSimplifier : public HloModulePass {
  public:
   ~WhileLoopSimplifier() override {}
   absl::string_view name() const override { return "simplify-while-loops"; }
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
index a7f0e207eb5..87294120d51 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
@@ -21,7 +21,7 @@ limitations under the License.
 
 // HLO pass that replaces zero-sized HLOs with a zero-sized constant literal.
 namespace xla {
 
-class ZeroSizedHloElimination : public HloPassInterface {
+class ZeroSizedHloElimination : public HloModulePass {
  public:
   StatusOr<bool> Run(HloModule* module) override;
   absl::string_view name() const override {