Add interface for HLO passes which run on HloModuleGroup.
Derive HloModulePass and HloModuleGroupPass from HloPassInterface; these run module-scoped and module-group-scoped, respectively. Replace all existing uses of HloPassInterface with HloModulePass, because all existing passes are module-scoped. Also rewrite HloPassPipeline to support both module-scoped and module-group-scoped passes.

PiperOrigin-RevId: 213629604
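In outline, the new hierarchy (condensed from the hlo_pass_interface.h changes in this diff; bodies elided):

class HloPassInterface {  // Base class; a pass should not extend this directly.
 public:
  virtual StatusOr<bool> Run(HloModule* module) = 0;
  virtual StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) = 0;
};

class HloModulePass : public HloPassInterface {
  // Module-scoped: RunOnModuleGroup() runs Run() on each module in the group.
};

class HloModuleGroupPass : public HloPassInterface {
  // Module-group-scoped: Run() on a single module returns an InternalError.
};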
parent e1db78697b
commit f8655c08cf
@@ -2560,6 +2560,7 @@ cc_library(
     ],
     deps = [
         ":hlo",
+        ":hlo_module_group",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
@@ -2591,6 +2592,26 @@ cc_library(
     ],
 )
 
+tf_cc_test(
+    name = "hlo_pass_pipeline_test",
+    srcs = ["hlo_pass_pipeline_test.cc"],
+    deps = [
+        ":hlo",
+        ":hlo_parser",
+        ":hlo_pass_pipeline",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:test_helpers",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:test_utils",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+    ],
+)
+
 cc_library(
     name = "hlo_cse",
     srcs = ["hlo_cse.cc"],
@@ -24,7 +24,7 @@ limitations under the License.
 namespace xla {
 
 // A pass which performs algebraic simplifications.
-class AlgebraicSimplifier : public HloPassInterface {
+class AlgebraicSimplifier : public HloModulePass {
  public:
   // Given shapes 'from_shape' and 'to_shape', determines if it is valid to
   // bitcast from 'from_shape' to 'to_shape' after considering platform
@@ -25,7 +25,7 @@ namespace xla {
 // Normally these would live in the algebraic simplifier, but we want to run
 // this to fixpoint (this pass reaches fixed point in one execution) before we
 // run the DotDecomposer.
-class BatchDotSimplification : public HloPassInterface {
+class BatchDotSimplification : public HloModulePass {
  public:
   StatusOr<bool> Run(HloModule* module) override;
   absl::string_view name() const override;
@@ -26,7 +26,7 @@ namespace xla {
 // A pass which rewrites batch norm operations into more operations. Breaking a
 // big operation into smaller operations helps leverage our generic fusion
 // logic.
-class BatchNormExpander : public HloPassInterface {
+class BatchNormExpander : public HloModulePass {
  public:
   // When use_fusion is set, a multi-output fusion node is created.
   BatchNormExpander(bool rewrite_training_op = false,
@@ -31,7 +31,7 @@ namespace xla {
 // optimization pipeline followed by a DCE pass. If other passes are needed
 // after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the
 // changes made by this pass.
-class BFloat16ConversionFolding : public HloPassInterface {
+class BFloat16ConversionFolding : public HloModulePass {
  public:
   explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support)
       : bfloat16_support_(bfloat16_support) {}
@@ -25,7 +25,7 @@ namespace xla {
 // A pass which adds F32 <-> BF16 conversions for HLO instructions that do not
 // support BF16 input/output or mixed precision, according to the passed-in
 // backend-specific BF16 support rules.
-class BFloat16Normalization : public HloPassInterface {
+class BFloat16Normalization : public HloModulePass {
  public:
   explicit BFloat16Normalization(const BFloat16Support* bfloat16_support)
       : bfloat16_support_(bfloat16_support) {}
@@ -48,7 +48,7 @@ class BFloat16Normalization : public HloPassInterface {
 // use mixed precision; it removes mixed precision even if the backend supports
 // it. This pass is used to make the HLO module valid for other HLO passes which
 // do not support mixed precision.
-class BFloat16MixedPrecisionRemoval : public HloPassInterface {
+class BFloat16MixedPrecisionRemoval : public HloModulePass {
  public:
   BFloat16MixedPrecisionRemoval() {}
 
@@ -58,7 +58,7 @@ namespace xla {
 // BFloat16ConversionFolding. If other passes are needed after this pass, run
 // BFloat16MixedPrecisionRemoval first to undo some of the changes made by this
 // pass.
-class BFloat16Propagation : public HloPassInterface {
+class BFloat16Propagation : public HloModulePass {
  public:
   explicit BFloat16Propagation(const BFloat16Support* bfloat16_support);
 
@@ -25,7 +25,7 @@ namespace xla {
 
 // For every kCall operation in the main computation, we inline the body of the
 // called function, and proceed recursively.
-class CallInliner : public HloPassInterface {
+class CallInliner : public HloModulePass {
  public:
   using InlinedInstructionMap =
       std::unordered_map<HloInstruction*, HloInstruction*>;
@@ -25,7 +25,7 @@ namespace xla {
 
 // HLO pass that removes kConditional with a constant predicate, replacing them
 // with their true or false computation as appropriate.
-class ConditionalSimplifier : public HloPassInterface {
+class ConditionalSimplifier : public HloModulePass {
  public:
   absl::string_view name() const override { return "simplify-conditional"; }
   StatusOr<bool> Run(HloModule* module) override;
@@ -25,7 +25,7 @@ namespace xla {
 
 // A pass which rewrites convolutions with feature_group_count > 1 into
 // convolutions with feature_group_count = 1.
-class ConvolutionFeatureGroupConverter : public HloPassInterface {
+class ConvolutionFeatureGroupConverter : public HloModulePass {
  public:
   ConvolutionFeatureGroupConverter() {}
 
@@ -43,7 +43,7 @@ namespace xla {
 // (3) The buffer set of the root instruction of the entry computation must be
 // unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous and
 // InstructionAliasSet::IsDistinct return true.
-class CopyInsertion : public HloPassInterface {
+class CopyInsertion : public HloModulePass {
  public:
   absl::string_view name() const override { return "copy-insertion"; }
 
@@ -31,7 +31,7 @@ namespace cpu {
 // called canonical convolutions). This pass expands non-canonical convolutions
 // into reshapes and canonical convolutions, so that these non-canonical
 // convolutions can run faster.
-class ConvCanonicalization : public HloPassInterface {
+class ConvCanonicalization : public HloModulePass {
  public:
   explicit ConvCanonicalization(
       const TargetMachineFeatures* target_machine_features)
@@ -30,7 +30,7 @@ namespace xla {
 //
 // TODO(b/62548313): Remove this when buffer assignment is smarter
 // (module-scoped).
-class CpuCopyInsertion : public HloPassInterface {
+class CpuCopyInsertion : public HloModulePass {
  public:
   absl::string_view name() const override { return "copy-insertion"; }
 
@@ -23,7 +23,7 @@ namespace xla {
 // This pass should run early in the HLO pipeline and checks for HLO constructs
 // which are not supported by the CPU backend and cannot be removed via HLO
 // transformations (eg, sparse layouts).
-class CpuHloSupportChecker : public HloPassInterface {
+class CpuHloSupportChecker : public HloModulePass {
  public:
   CpuHloSupportChecker() = default;
   ~CpuHloSupportChecker() override = default;
@@ -60,7 +60,7 @@ class ParallelTaskAssignment {
 // own embedded computation, which is compiled as a parallel compute function,
 // and which is invoked from a kCall instruction that is lowered in codegen to
 // a runtime parallel fork/join call.
-class ParallelTaskAssigner : public HloPassInterface {
+class ParallelTaskAssigner : public HloModulePass {
  public:
   // 'max_parallelism': the maximum parallel task count per instruction.
   // 'shape_size': shape size function used by HloCostAnalysis during parallel
@@ -25,7 +25,7 @@ namespace xla {
 
 // A pass which replaces all fusion instructions with the equivalent un-fused
 // instructions.
-class Defuser : public HloPassInterface {
+class Defuser : public HloModulePass {
  public:
   Defuser() {}
   ~Defuser() override {}
@@ -24,7 +24,7 @@ namespace xla {
 namespace {
 
 // Pass which strips control dependencies from all instructions in the module.
-class ControlDepRemover : public HloPassInterface {
+class ControlDepRemover : public HloModulePass {
  public:
   ControlDepRemover() = default;
   absl::string_view name() const override { return "control-dep-remover"; }
@@ -30,7 +30,7 @@ namespace xla {
 //
 // Current despecialization passes are Defuser, ImplicitBroadcastRemover,
 // and BFloat16MixedPrecisionRemoval.
-class Despecializer : public HloPassInterface {
+class Despecializer : public HloModulePass {
  public:
   Despecializer();
   absl::string_view name() const override { return "despecializer"; }
@@ -23,7 +23,7 @@ namespace xla {
 
 // DotDecomposer is a pass which decomposes batch Dot operations into a
 // sequence of smaller (R2) Dot operations.
-class DotDecomposer : public HloPassInterface {
+class DotDecomposer : public HloModulePass {
  public:
   // Decomposes batch Dot operations when 'decompose_batch_dot' is true.
   DotDecomposer(bool decompose_batch_dot = true)
@@ -26,7 +26,7 @@ namespace xla {
 // Flattening associates each call site with a unique computation (for
 // sequential calling contexts) This simplifies buffer assignment and
 // points-to analysis (see b/36865746 for details).
-class FlattenCallGraph : public HloPassInterface {
+class FlattenCallGraph : public HloModulePass {
  public:
   absl::string_view name() const override { return "flatten-call-graph"; }
 
@@ -23,7 +23,7 @@ namespace xla {
 // This pass rewrites gather operations into (roughly) while loops of dynamic
 // slices. This lets backends that don't support gather directly to
 // nevertheless have a minimum level of support.
-class GatherExpander : public HloPassInterface {
+class GatherExpander : public HloModulePass {
  public:
   absl::string_view name() const override { return "gather_expander"; }
   StatusOr<bool> Run(HloModule* module) override;
@@ -52,7 +52,7 @@ namespace gpu {
 // The GPU backend does not implement a lowering for the batchnorm HLOs -- it
 // expects them to be lowered to cudnn calls via this pass or to HLO soup via
 // BatchNormRewriter.
-class CudnnBatchNormRewriter : public HloPassInterface {
+class CudnnBatchNormRewriter : public HloModulePass {
  public:
   absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; }
   StatusOr<bool> Run(HloModule* module) override;
@@ -30,7 +30,7 @@ namespace gpu {
 
 // Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for
 // each and adding explicit scratch space to the CustomCalls.
-class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
+class CudnnConvolutionAlgorithmPicker : public HloModulePass {
  public:
   // If the `allocator` parameter is not null, we will use it to allocate temp
   // memory while timing the various convolution algorithms. If it's null,
@@ -24,7 +24,7 @@ namespace gpu {
 
 // Rewrites plain convolutions, backwards-filter convolutions, and
 // backwards-input convolutions into CustomCall HLOs that call into cuDNN.
-class CudnnConvolutionRewriter : public HloPassInterface {
+class CudnnConvolutionRewriter : public HloModulePass {
  public:
   absl::string_view name() const override {
     return "cudnn-convolution-rewriter";
@@ -32,7 +32,7 @@ namespace gpu {
 // 2) The result of merging the fusion instruction into its users would not
 // increase bytes transferred.
 //
-class FusionMerger : public HloPassInterface {
+class FusionMerger : public HloModulePass {
  public:
   absl::string_view name() const override { return "fusion merger"; }
 
@@ -34,15 +34,6 @@ namespace xla {
 
 namespace gpu {
 
-StatusOr<HloInstruction*> GpuCopyInsertion::FindOrInsertCopy(
-    HloInstruction* hlo) {
-  HloInstruction*& copy = hlo_to_copy_map_[hlo];
-  if (copy == nullptr) {
-    TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo));
-  }
-  return copy;
-}
-
 StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
   CopyInsertion generic_copy_insertion;
 
@@ -25,20 +25,11 @@ namespace gpu {
 // Besides the modifications made by the generic xla::CopyInsertion, this
 // GPU-specific copy insertion also materializes operands of library calls by
 // inserting kCopy instructions.
-class GpuCopyInsertion : public HloPassInterface {
+class GpuCopyInsertion : public HloModulePass {
  public:
   absl::string_view name() const override { return "copy-insertion"; }
 
   StatusOr<bool> Run(HloModule* module) override;
-
- protected:
-  // Returns a copy of `hlo`. Looks in hlo_to_copy_map_ first to avoid making
-  // duplicate copies.
-  StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
-
-  // A map containing all copies inserted to materialize operands of library
-  // calls. The key is the copied instruction and the value is the copy.
-  tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> hlo_to_copy_map_;
 };
 
 } // namespace gpu
@@ -23,7 +23,7 @@ namespace xla {
 // This pass should run early in the HLO pipeline and checks for HLO constructs
 // which are not supported by the GPU backend and cannot be removed via HLO
 // transformations (eg, sparse layouts).
-class GpuHloSupportChecker : public HloPassInterface {
+class GpuHloSupportChecker : public HloModulePass {
  public:
   GpuHloSupportChecker() = default;
   ~GpuHloSupportChecker() override = default;
@@ -30,7 +30,7 @@ namespace gpu {
 // targeting before running this pass.
 //
 // TODO(jlebar): Also pad dots.
-class PadForTensorCores : public HloPassInterface {
+class PadForTensorCores : public HloModulePass {
  public:
   absl::string_view name() const override { return "pad for tensor cores"; }
 
@@ -24,7 +24,7 @@ namespace gpu {
 // An HLO pass that canonicalizes convolution instructions for GPU codegen. It
 // inserts Pad instructions before Convolution instructions with uncanonicalized
 // padding, so that they can be lowered to cuDNN convolution.
-class PadInsertion : public HloPassInterface {
+class PadInsertion : public HloModulePass {
  public:
   absl::string_view name() const override { return "pad insertion"; }
 
@@ -23,7 +23,7 @@ namespace xla {
 
 // A pass which performs constant folding in order to avoid unnecessary
 // computation on constants.
-class HloConstantFolding : public HloPassInterface {
+class HloConstantFolding : public HloModulePass {
  public:
   absl::string_view name() const override { return "constant_folding"; }
 
@@ -25,7 +25,7 @@ namespace xla {
 // and identical instructions with the same operands are commoned. The pass
 // iterates over the instructions in topological order which enables the pass to
 // find arbitrarily large common expressions.
-class HloCSE : public HloPassInterface {
+class HloCSE : public HloModulePass {
  public:
   // If is_layout_sensitive is true, then the simplifier preserves layout during
   // transformation. Otherwise, layout is ignored.
@@ -33,7 +33,7 @@ namespace xla {
 //
 // This pass does not remove dead parameter instructions, as parameter
 // instructions cannot be deleted.
-class HloDCE : public HloPassInterface {
+class HloDCE : public HloModulePass {
  public:
   ~HloDCE() override {}
   absl::string_view name() const override { return "dce"; }
@@ -30,7 +30,7 @@ namespace xla {
 // used to break an HLO graph edge connecting two instructions with different
 // sharding. If a set of connected instructions have all the same sharding, no
 // kDomain instruction will be placed.
-class HloDomainIsolator : public HloPassInterface {
+class HloDomainIsolator : public HloModulePass {
  public:
   // Creates a new kDomain instruction for the edge between the use instruction
   // (the first HloInstruction argument), and the operand instruction (the
@@ -26,7 +26,7 @@ namespace xla {
 // Removes all the kDomain instructions of a given kind from the input module,
 // and calls the normalizer to propagate the properties on the possibly new born
 // instructions.
-class HloDomainRemover : public HloPassInterface {
+class HloDomainRemover : public HloModulePass {
  public:
   // Creates a new HloDomainRemover object tasked at removing all the kDomain
   // instructions of a given kind.
@@ -29,7 +29,7 @@ namespace xla {
 
 // Verifies that the domain instructions are consistent, and each domain is
 // surrounded by the same metadata.
-class HloDomainVerifier : public HloPassInterface {
+class HloDomainVerifier : public HloModulePass {
  public:
   HloDomainVerifier(std::vector<string> kinds) : kinds_(std::move(kinds)) {}
 
@@ -25,7 +25,7 @@ namespace xla {
 // inserting Convert ops. This allows a backend to support an element type while
 // only actually implementing the Convert op for that element type. This is
 // generally not the fastest approach, but it works.
-class HloElementTypeConverter : public HloPassInterface {
+class HloElementTypeConverter : public HloModulePass {
  public:
   // eliminate_type is the type to eliminate as the input or output of ops,
   // using Convert ops to replace it with replace_with_type.
@@ -90,7 +90,7 @@ StatusOr<HloInstructionSequence> ScheduleComputation(
 // A pass which schedules the HLO instructions in a module. The HloModule's
 // schedule field is set to the resulting HloSchedule using
 // HloModule::set_schedule.
-class HloMemoryScheduler : public HloPassInterface {
+class HloMemoryScheduler : public HloModulePass {
  public:
   // size_function is the function returning the number of bytes required for a
   // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not
@@ -109,7 +109,7 @@ class HloMemoryScheduler : public HloPassInterface {
 
 // A trivial pass which clears the schedule currently set on the
 // HloModule. After this pass runs HloModule::has_schedule will return false.
-class HloDescheduler : public HloPassInterface {
+class HloDescheduler : public HloModulePass {
  public:
   HloDescheduler() = default;
   ~HloDescheduler() override = default;
@@ -28,7 +28,7 @@ namespace xla {
 // Sweeps through live instructions which cross computation boundaries (kWhile),
 // and removes code at dead shape indices.
 //
-class HloModuleDCE : public HloPassInterface {
+class HloModuleDCE : public HloModulePass {
  public:
   ~HloModuleDCE() override {}
   absl::string_view name() const override { return "hlo-module-dce"; }
@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_
 
 #include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -25,15 +26,45 @@ limitations under the License.
 namespace xla {
 
 // Base class for HLO passes. These are used with the HloPassPipeline to
-// organize a sequence of passes.
+// organize a sequence of passes. An HLO pass should not extend this class
+// directly; it should extend HloModulePass or HloModuleGroupPass.
 class HloPassInterface {
  public:
   virtual ~HloPassInterface() = default;
   virtual absl::string_view name() const = 0;
 
-  // Run the pass on the given HLO module. Return whether it modified the
+  // Run the pass on the given HLO module. Returns whether it modified the
   // module.
   virtual StatusOr<bool> Run(HloModule* module) = 0;
+
+  // Run the pass on the given HLO module group. Returns whether it modified the
+  // module group. Ideally, the module group variant would be named "Run" as
+  // well, but C++ does not handle overloaded virtual methods well.
+  virtual StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) = 0;
 };
 
+// Base class for passes which are module-scoped.
+class HloModulePass : public HloPassInterface {
+ public:
+  // Runs the pass on a module group by iterating through each module in the
+  // group.
+  StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override {
+    bool changed = false;
+    for (HloModule* module : module_group->modules()) {
+      TF_ASSIGN_OR_RETURN(bool module_changed, Run(module));
+      changed |= module_changed;
+    }
+    return changed;
+  };
+};
+
+// Base class for passes which are module-group scoped. These passes cannot run
+// on an HLO module.
+class HloModuleGroupPass : public HloPassInterface {
+ public:
+  StatusOr<bool> Run(HloModule* module) override {
+    return InternalError("Module group pass cannot be run on a module");
+  }
+};
+
 } // namespace xla
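Under the new interface, a module-scoped pass overrides only the module-scoped Run and inherits the group iteration from HloModulePass; the new hlo_pass_pipeline_test.cc later in this diff follows the same pattern. A minimal sketch (FooPass is hypothetical, for illustration; it is not part of this commit):

// Hypothetical module-scoped pass written against the new interface.
class FooPass : public HloModulePass {
 public:
  absl::string_view name() const override { return "foo-pass"; }

  // Only Run(HloModule*) needs an implementation. Running this pass on an
  // HloModuleGroup falls back to HloModulePass::RunOnModuleGroup, which
  // invokes Run on each module in the group and ORs the results together.
  StatusOr<bool> Run(HloModule* module) override {
    bool changed = false;
    // ... inspect and mutate `module` here ...
    return changed;
  }
};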
@@ -17,7 +17,6 @@ limitations under the License.
 
-#include <functional>
 
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/str_join.h"
 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
@@ -29,108 +28,128 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 
 namespace xla {
-namespace {
 
 using absl::StrAppend;
 using absl::StrCat;
+
+template <typename HloT>
+Status HloPassPipeline::RunInvariantCheckers(
+    HloT* hlo, absl::string_view after_pass_name) {
+  for (auto& invariant_checker : invariant_checkers_) {
+    VLOG(1) << "    Invariant checker " << invariant_checker->name();
+    StatusOr<bool> changed_status = RunHelper(invariant_checker.get(), hlo);
+    VLOG(1) << "    Invariant checker done " << invariant_checker->name();
+    if (!changed_status.ok()) {
+      VLOG(2) << "Failed invariant check:";
+      XLA_VLOG_LINES(2, hlo->ToString());
+      return Status(changed_status.status().code(),
+                    absl::StrCat(changed_status.status().error_message(),
+                                 "\n\nFailed after ", after_pass_name));
+    }
+    TF_RET_CHECK(!changed_status.ValueOrDie())
+        << "invariant checkers must not change the graph";
+  }
+  return Status::OK();
+}
 
-void DumpModuleGraph(const HloModule& module, const string& message) {
+template <typename HloT>
+StatusOr<bool> HloPassPipeline::RunPassesInternal(
+    HloT* hlo, absl::Span<HloPassInterface* const> passes) {
+  string last_pass_name = "pipeline-start";
+  TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name));
+  bool changed = false;
+  for (HloPassInterface* pass : passes) {
+    VLOG(1) << "  HLO pass " << pass->name();
+    MaybeDumpHlo(*hlo,
+                 /*after_pass_name=*/last_pass_name,
+                 /*before_pass_name=*/pass->name());
+    TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
+    changed |= pass_changed;
+    TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass->name()));
+    last_pass_name = string(pass->name());
+  }
+  MaybeDumpHlo(*hlo,
+               /*after_pass_name=*/last_pass_name,
+               /*before_pass_name=*/"pipeline-end");
+  return changed;
+}
+
+std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
+    const DebugOptions& debug_options) {
+  auto repeated_field = debug_options.xla_disable_hlo_passes();
+  tensorflow::gtl::FlatSet<string> disabled_pass_names(repeated_field.begin(),
+                                                       repeated_field.end());
+  if (!disabled_pass_names.empty()) {
+    VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
+            << absl::StrJoin(disabled_pass_names, ", ");
+  }
+
+  std::vector<HloPassInterface*> enabled_passes;
+  for (auto& pass : passes_) {
+    if (disabled_pass_names.count(string(pass->name())) == 0) {
+      enabled_passes.push_back(pass.get());
+    }
+  }
+  return enabled_passes;
+}
+
+void HloPassPipeline::MaybeDumpHlo(const HloModule& module,
+                                   absl::string_view after_pass_name,
+                                   absl::string_view before_pass_name) {
+  const string& proto_dump_path =
+      module.config().debug_options().xla_dump_per_pass_hlo_proto_to();
+  if (!proto_dump_path.empty()) {
+    static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
+    static auto* const module_id_to_pass_number =
+        new tensorflow::gtl::FlatMap<int64, int64>();
+
+    tensorflow::mutex_lock lock(mu);
+    const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
+
+    const string filename = SanitizeFileName(
+        absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
+                        pass_number, name(), after_pass_name));
+
+    TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(
+        MakeHloProto(module), proto_dump_path, filename));
+  }
+
+  const string message =
+      StrCat("after ", after_pass_name, ", before ", before_pass_name);
   hlo_graph_dumper::MaybeDumpHloModule(module, message);
   VLOG(3) << "HLO " << message << ":";
   XLA_VLOG_LINES(3, module.ToString());
 }
 
-void DumpModuleProto(const HloModule& module, const string& dump_to,
-                     const string& pipeline_name, const string& pass_name) {
-  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
-  static auto* const module_id_to_pass_number =
-      new tensorflow::gtl::FlatMap<int64, int64>();
-
-  tensorflow::mutex_lock lock(mu);
-  const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
-
-  const string mod_name = SanitizeFileName(
-      absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
-                      pass_number, pipeline_name, pass_name));
-
-  TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module),
-                                                   dump_to, mod_name));
+void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group,
+                                   absl::string_view after_pass_name,
+                                   absl::string_view before_pass_name) {
+  for (const HloModule* module : module_group.modules()) {
+    MaybeDumpHlo(*module, after_pass_name, before_pass_name);
+  }
 }
-} // namespace
 
 StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
   run_called_ = true;
 
-  VLOG(1) << "Running HLO pass pipeline " << name();
+  VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": "
+          << name();
 
-  auto repeated_field =
-      module->config().debug_options().xla_disable_hlo_passes();
-  tensorflow::gtl::FlatSet<string> disabled_passes(repeated_field.begin(),
-                                                   repeated_field.end());
-  if (!disabled_passes.empty()) {
-    VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
-            << absl::StrJoin(disabled_passes, ", ");
+  return RunPassesInternal(module,
+                           GetEnabledPasses(module->config().debug_options()));
 }
 
+StatusOr<bool> HloPassPipeline::RunOnModuleGroup(HloModuleGroup* module_group) {
+  run_called_ = true;
+
+  VLOG(1) << "Running HLO pass pipeline on module group "
+          << module_group->name() << ": " << name();
+
+  if (module_group->modules().empty()) {
+    VLOG(1) << "Module group is empty. Nothing to do.";
+    return false;
+  }
+
-  auto run_invariant_checkers = [this,
-                                 module](const string& message) -> Status {
-    for (auto& invariant_checker : invariant_checkers_) {
-      VLOG(1) << "    Invariant checker " << invariant_checker->name();
-      StatusOr<bool> changed_status = invariant_checker->Run(module);
-      VLOG(1) << "    Invariant checker done " << invariant_checker->name();
-      if (!changed_status.ok()) {
-        VLOG(2) << "Module failed invariant check:";
-        XLA_VLOG_LINES(2, module->ToString());
-        return Status(changed_status.status().code(),
-                      StrCat(changed_status.status().error_message(),
-                             "\n\nFailed ", message));
-      }
-      TF_RET_CHECK(!changed_status.ValueOrDie())
-          << "invariant checkers must not change the graph";
-    }
-    return Status::OK();
-  };
-
-  string prefix = StrCat(name(), ": pipeline start");
-  bool changed = false;
-  string message;
-  TF_RETURN_IF_ERROR(
-      run_invariant_checkers(StrCat("before running pipeline: ", name())));
-  const string xla_dump_per_pass_hlo_proto_to =
-      module->config().debug_options().xla_dump_per_pass_hlo_proto_to();
-  if (!xla_dump_per_pass_hlo_proto_to.empty()) {
-    DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
-                    "pipeline_start");
-  }
-
-  for (auto& pass : passes_) {
-    if (disabled_passes.count(string(pass->name())) > 0) {
-      VLOG(1) << "  Skipping HLO pass " << pass->name()
-              << ", disabled by --xla_disable_hlo_passes";
-      continue;
-    }
-
-    VLOG(1) << "  HLO pass " << pass->name();
-
-    // Emit label containing: "after foo-pass, before bar-pass".
-    message.clear();
-    StrAppend(&message, prefix, ", before ", pass->name());
-    DumpModuleGraph(*module, message);
-
-    TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module));
-    TF_RETURN_IF_ERROR(
-        run_invariant_checkers(StrCat("after running pass: ", pass->name())));
-    if (!xla_dump_per_pass_hlo_proto_to.empty()) {
-      DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
-                      string(pass->name()));
-    }
-
-    changed |= changed_this_pass;
-    prefix.clear();
-    StrAppend(&prefix, name(), ": after ", pass->name());
-  }
-  DumpModuleGraph(*module, prefix + ", pipeline end");
-  return changed;
+  return RunPassesInternal(
+      module_group,
+      GetEnabledPasses(module_group->module(0).config().debug_options()));
 }
 
 } // namespace xla
@@ -22,6 +22,7 @@ limitations under the License.
 #include <vector>
 
+#include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -61,10 +62,45 @@ class HloPassPipeline : public HloPassInterface {
     return *pass;
   }
 
   // Run all passes on the given HLO module.
   StatusOr<bool> Run(HloModule* module) override;
+  StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override;
 
  private:
+  // Returns the set of passes which are enabled. DebugOptions can selectively
+  // disable passes via --xla_disable_hlo_passes flag.
+  std::vector<HloPassInterface*> GetEnabledPasses(
+      const DebugOptions& debug_options);
+
+  // Maybe dumps the given module or module group depending on flag values
+  // contained in DebugOptions of module config.
+  void MaybeDumpHlo(const HloModuleGroup& module_group,
+                    absl::string_view after_pass_name,
+                    absl::string_view before_pass_name);
+  void MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name,
+                    absl::string_view before_pass_name);
+
+  // Runs the invariant checker on the given HLO. HloT can be either HloModule
+  // or HloModuleGroup.
+  template <typename HloT>
+  Status RunInvariantCheckers(HloT* hlo, absl::string_view after_pass_name);
+
+  // Helper which runs the given pass on the given HLO. HloT can be either
+  // HloModule or HloModuleGroup.
+  template <typename HloT>
+  StatusOr<bool> RunPassesInternal(HloT* hlo,
+                                   absl::Span<HloPassInterface* const> passes);
+
+  // Helpers which run the given passes on the given HLO construct. These
+  // helpers enable templating of the core of the pipeline logic by providing
+  // HloModule and HloModuleGroup specific methods with the same name.
+  static StatusOr<bool> RunHelper(HloPassInterface* pass, HloModule* module) {
+    return pass->Run(module);
+  }
+  static StatusOr<bool> RunHelper(HloPassInterface* pass,
+                                  HloModuleGroup* module_group) {
+    return pass->RunOnModuleGroup(module_group);
+  }
+
   const string name_;
   std::vector<std::unique_ptr<HloPassInterface>> passes_;
   std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;
tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc (new file, 259 lines)
@@ -0,0 +1,259 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class HloPassPipelineTest : public HloTestBase {
+ protected:
+  StatusOr<HloModuleGroup> ParseModuleGroup(
+      absl::Span<const string> hlo_strings) {
+    HloModuleGroup group(TestName());
+    for (const string& hlo_string : hlo_strings) {
+      TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
+                          ParseHloString(hlo_string));
+      group.push_back(std::move(module));
+    }
+    return std::move(group);
+  }
+};
+
+// A module pass which renames instructions named 'foo' to 'bar'.
+class FooToBarModulePass : public HloModulePass {
+  absl::string_view name() const override { return "foo2bar"; }
+
+  StatusOr<bool> Run(HloModule* module) override {
+    bool changed = false;
+    for (HloComputation* computation : module->computations()) {
+      for (HloInstruction* instruction : computation->instructions()) {
+        if (instruction->name() == "foo") {
+          instruction->SetAndSanitizeName("bar");
+          changed = true;
+        }
+      }
+    }
+    return changed;
+  }
+};
+
+// A module group pass which renames instructions named 'baz' to 'qux'.
+class BazToQuxModuleGroupPass : public HloModuleGroupPass {
+  absl::string_view name() const override { return "baz2qux"; }
+
+  StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override {
+    bool changed = false;
+    for (HloModule* module : module_group->modules()) {
+      for (HloComputation* computation : module->computations()) {
+        for (HloInstruction* instruction : computation->instructions()) {
+          if (instruction->name() == "baz") {
+            instruction->SetAndSanitizeName("qux");
+            changed = true;
+          }
+        }
+      }
+    }
+    return changed;
+  }
+};
+
+// An invariant checker pass which returns an error if there exists an
+// instruction named 'bar'.
+class BarBlowerUpper : public HloModulePass {
+  absl::string_view name() const override { return "bar-blower-upper"; }
+
+  StatusOr<bool> Run(HloModule* module) override {
+    for (HloComputation* computation : module->computations()) {
+      for (HloInstruction* instruction : computation->instructions()) {
+        if (instruction->name() == "bar") {
+          return InternalError("Module has instruction named bar");
+        }
+      }
+    }
+    return false;
+  }
+};
+
+TEST_F(HloPassPipelineTest, ModulePassChanged) {
+  // Test an HLO module pass which changes a module.
+  const string module_str = R"(
+HloModule ModulePassChanged
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT foo = f32[] multiply(a, b)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+  HloPassPipeline pipeline(TestName());
+  pipeline.AddPass<FooToBarModulePass>();
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_EQ(root->name(), "foo");
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_EQ(root->name(), "bar");
+}
+
+TEST_F(HloPassPipelineTest, ModulePassUnchanged) {
+  // Test an HLO module pass which does not change a module.
+  const string module_str = R"(
+HloModule ModulePassUnchanged
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT blahblah = f32[] multiply(a, b)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+  HloPassPipeline pipeline(TestName());
+  pipeline.AddPass<FooToBarModulePass>();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(HloPassPipelineTest, MixedPipeline) {
+  // Test a pipeline with both a module pass and a module group pass.
+  const string module_0_str = R"(
+HloModule MixedPipeline.1
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT baz = f32[] multiply(a, b)
+}
+)";
+  const string module_1_str = R"(
+HloModule MixedPipeline.0
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT foo = f32[] multiply(a, b)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup module_group,
+                          ParseModuleGroup({module_0_str, module_1_str}));
+
+  HloPassPipeline pipeline(TestName());
+  pipeline.AddPass<BazToQuxModuleGroupPass>();
+  pipeline.AddPass<FooToBarModulePass>();
+
+  HloInstruction* root0 =
+      module_group.module(0).entry_computation()->root_instruction();
+  HloInstruction* root1 =
+      module_group.module(1).entry_computation()->root_instruction();
+  EXPECT_EQ(root0->name(), "baz");
+  EXPECT_EQ(root1->name(), "foo");
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          pipeline.RunOnModuleGroup(&module_group));
+  EXPECT_TRUE(changed);
+
+  EXPECT_EQ(root0->name(), "qux");
+  EXPECT_EQ(root1->name(), "bar");
+}
+
+TEST_F(HloPassPipelineTest, InvariantChecker) {
+  const string module_str = R"(
+HloModule InvariantChecker
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT foo = f32[] multiply(a, b)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+  {
+    // Run a pipeline with just the invariant checker. It should not fail
+    // because there is no 'bar' instruction in the module.
+    HloPassPipeline pipeline(TestName());
+    pipeline.AddInvariantChecker<BarBlowerUpper>();
+
+    TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+    EXPECT_FALSE(changed);
+  }
+
+  {
+    // Run a pipeline which renames 'foo' to 'bar' then an invariant checker
+    // which fails if there is an instruction named 'bar'.
+    HloPassPipeline pipeline(TestName());
+    pipeline.AddInvariantChecker<BarBlowerUpper>();
+    pipeline.AddPass<FooToBarModulePass>();
+
+    Status status = pipeline.Run(module.get()).status();
+    ASSERT_IS_NOT_OK(status);
+    EXPECT_THAT(status.error_message(),
+                ::testing::HasSubstr("Module has instruction named bar"));
+    EXPECT_THAT(status.error_message(),
+                ::testing::HasSubstr("Failed after foo2bar"));
+  }
+
+  {
+    // Run the invariant-checker only pipeline again. It should fail this time.
+    HloPassPipeline pipeline(TestName());
+    pipeline.AddInvariantChecker<BarBlowerUpper>();
+
+    Status status = pipeline.Run(module.get()).status();
+    ASSERT_IS_NOT_OK(status);
+    EXPECT_THAT(status.error_message(),
+                ::testing::HasSubstr("Module has instruction named bar"));
+    EXPECT_THAT(status.error_message(),
+                ::testing::HasSubstr("Failed after pipeline-start"));
+  }
+}
+
+TEST_F(HloPassPipelineTest, ModuleGroupPassOnModule) {
+  // Running a module group pass on a module should produce an error.
+  const string module_str = R"(
+HloModule ModuleGroupPassOnModule
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT foo = f32[] multiply(a, b)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+  HloPassPipeline pipeline(TestName());
+  pipeline.AddPass<BazToQuxModuleGroupPass>();
+
+  Status status = pipeline.Run(module.get()).status();
+  ASSERT_IS_NOT_OK(status);
+  EXPECT_THAT(
+      status.error_message(),
+      ::testing::HasSubstr("Module group pass cannot be run on a module"));
+}
+
+} // namespace
+} // namespace xla
@@ -1198,6 +1198,12 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
           << HumanReadableNumBytes(memory_limit_bytes_);
   XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
 
+  // Initialize pass object state.
+  computation_peak_memory_.clear();
+  rematerialized_computations_.clear();
+  instructions_rematerialized_ = 0;
+  net_instructions_added_ = 0;
+
   TF_RET_CHECK(module->has_schedule());
   TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
 
@@ -33,7 +33,7 @@ namespace xla {
 // CSE will undo the effects of this optimization and should not be run after
 // this pass. In general, this pass should be run very late, immediately before
 // code generation.
-class HloRematerialization : public HloPassInterface {
+class HloRematerialization : public HloModulePass {
  public:
   using ShapeSizeFunction = std::function<int64(const Shape&)>;
 
@@ -22,7 +22,7 @@ namespace xla {
 
 // Unify subcomputations of a `HloModule`: if any computations are equal, choose
 // one arbitrarily to use and delete the others.
-class HloSubcomputationUnification : public HloPassInterface {
+class HloSubcomputationUnification : public HloModulePass {
  public:
   absl::string_view name() const override {
     return "subcomputation-unification";
@@ -151,7 +151,7 @@ class ShapeVerifier : public DfsHloVisitor {
 
 // HLO pass that verifies invariants of HLO instructions for each computation in
 // the module.
-class HloVerifier : public HloPassInterface {
+class HloVerifier : public HloModulePass {
  public:
   using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
 
@@ -25,7 +25,7 @@ namespace xla {
 
 // Pass which replaces all implicit broadcasts with their equivalent sequence of
 // explicit broadcast and reshape instructions.
-class ImplicitBroadcastRemover : public HloPassInterface {
+class ImplicitBroadcastRemover : public HloModulePass {
  public:
   ImplicitBroadcastRemover() {}
   ~ImplicitBroadcastRemover() override {}
@@ -366,7 +366,7 @@ class IndexedArrayAnalysis {
 // A pass that prints all non-trivial results returned by IndexedArrayAnalysis.
 // This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to
 // unconditionally add to the regular HLO pass pipeline.
-class IndexedArrayAnalysisPrinterPass : public HloPassInterface {
+class IndexedArrayAnalysisPrinterPass : public HloModulePass {
  public:
   absl::string_view name() const override;
   StatusOr<bool> Run(HloModule* module) override;
@@ -24,7 +24,7 @@ namespace xla {
 // A pass which performs inlining. Which can result, for example, in functions
 // that were previously being mapped by Map instead directly applied to the
 // forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)).
-class Inliner : public HloPassInterface {
+class Inliner : public HloModulePass {
  public:
   ~Inliner() override = default;
   absl::string_view name() const override { return "inline"; }
@@ -56,7 +56,7 @@ class FusionQueue {
 // with the intent that the loops which compute their values will be fused in
 // code generation. Derived classes define ShouldFuse method to select which
 // instructions to fuse.
-class InstructionFusion : public HloPassInterface {
+class InstructionFusion : public HloModulePass {
  public:
   explicit InstructionFusion(
       std::function<bool(const HloInstruction& instruction)> is_expensive,
@@ -281,7 +281,7 @@ class ChannelLayoutConstraints {
 
 // HLO pass which assigns layouts to all instructions in the HLO module while
 // satisfying all necessary invariants and minimizing cost.
-class LayoutAssignment : public HloPassInterface {
+class LayoutAssignment : public HloModulePass {
  public:
   // entry_computation_layout is modified to populate a layout for the result in
   // the case that no particular layout is requested.
@@ -44,7 +44,7 @@ namespace xla {
 // Note that the reachability map is updated based on the original computation.
 // This works because the reachability is monotonically increasing with
 // instruction fusion.
-class MultiOutputFusion : public HloPassInterface {
+class MultiOutputFusion : public HloModulePass {
  public:
   MultiOutputFusion(int64 fuel) : fuel_(fuel) {}
 
@@ -29,7 +29,7 @@ namespace xla {
 // HLO pass which inserts reduce-precision instructions into the HLO graph, for
 // purposes of experimenting with the effects of reduced-precision storage of
 // intermediate values.
-class ReducePrecisionInsertion : public HloPassInterface {
+class ReducePrecisionInsertion : public HloModulePass {
   using InstructionFilterFunction = std::function<bool(const HloInstruction*)>;
 
  public:
@@ -24,7 +24,7 @@ namespace xla {
 // This now only moves them outputward across elementwise ops all whose operands
 // are equivalent Reshapes or Transposes, but in future could potentially move
 // them inputward also.
-class ReshapeMover : public HloPassInterface {
+class ReshapeMover : public HloModulePass {
  public:
   absl::string_view name() const override { return "reshape-mover"; }
 
@@ -20,7 +20,7 @@ limitations under the License.
 
 namespace xla {
 
-class ScatterExpander : public HloPassInterface {
+class ScatterExpander : public HloModulePass {
  public:
   absl::string_view name() const override { return "scatter_expander"; }
   StatusOr<bool> Run(HloModule* module) override;
@@ -23,7 +23,7 @@ namespace xla {
 
 // HLO pass that folds transpose operators into Dot operators, where the Dot
 // operator is implemented by a GEMM kernel that can transpose its inputs.
-class TransposeFolding : public HloPassInterface {
+class TransposeFolding : public HloModulePass {
  public:
   using OperandIndices = std::vector<int64>;
 
@@ -25,7 +25,7 @@ namespace xla {
 
 // A pass which simplifies patterns of Tuple and GetTupleElement instructions in
 // the module.
-class TupleSimplifier : public HloPassInterface {
+class TupleSimplifier : public HloModulePass {
  public:
   TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {}
   explicit TupleSimplifier(bool exclude_entry_computation);
@@ -50,7 +50,7 @@ namespace xla {
 // conditions as well.
 //
 // TODO(b/79121449): We should also sink broadcasts of constants.
-class WhileLoopConstantSinking : public HloPassInterface {
+class WhileLoopConstantSinking : public HloModulePass {
  public:
   ~WhileLoopConstantSinking() override = default;
 
@@ -25,7 +25,7 @@ namespace xla {
 // HLO pass that rewrites while loops to hoist loop invariant instructions in
 // the while body into the computation that contains the while instruction.
 
-class WhileLoopInvariantCodeMotion : public HloPassInterface {
+class WhileLoopInvariantCodeMotion : public HloModulePass {
  public:
   // If `hoist_constants` is true then constants are always hoisted out of while
   // loop bodies. Otherwise they are only hoisted out if they enable other
@@ -30,7 +30,7 @@ namespace xla {
 // - Elements of a while loop's tuple that the loop doesn't use are removed
 // from the tuple.
 //
-class WhileLoopSimplifier : public HloPassInterface {
+class WhileLoopSimplifier : public HloModulePass {
  public:
   ~WhileLoopSimplifier() override {}
   absl::string_view name() const override { return "simplify-while-loops"; }
@@ -21,7 +21,7 @@ limitations under the License.
 
 // HLO pass that replaces zero sized Hlos with a zero sized constant literal.
 namespace xla {
-class ZeroSizedHloElimination : public HloPassInterface {
+class ZeroSizedHloElimination : public HloModulePass {
  public:
   StatusOr<bool> Run(HloModule* module) override;
   absl::string_view name() const override {