Add interface for HLO passes which run on HloModuleGroup.

Derive HloModulePass and HloModuleGroupPass from HloPassInterface; these run module-scoped and module-group-scoped, respectively. Replace all existing uses of HloPassInterface with HloModulePass, because all existing passes are module-scoped. Also rewrite HloPassPipeline to support both module-scoped and module-group-scoped passes.

PiperOrigin-RevId: 213629604
Mark Heffernan 2018-09-19 08:12:29 -07:00 committed by TensorFlower Gardener
parent e1db78697b
commit f8655c08cf
62 changed files with 520 additions and 166 deletions
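To make the new contract concrete, here is a minimal sketch (a hypothetical MyNoOpPass; only types and methods introduced in this diff) of how a pass is declared after this change, and how a pipeline can now be driven at either scope:

#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

namespace xla {

// A module-scoped pass: it overrides only Run(HloModule*). The
// RunOnModuleGroup() variant is inherited from HloModulePass, which loops
// over the modules in the group and calls Run() on each one.
class MyNoOpPass : public HloModulePass {  // MyNoOpPass is hypothetical.
 public:
  absl::string_view name() const override { return "my-no-op-pass"; }
  StatusOr<bool> Run(HloModule* module) override {
    return false;  // Report that nothing changed.
  }
};

StatusOr<bool> RunExamplePipeline(HloModuleGroup* module_group) {
  HloPassPipeline pipeline("example-pipeline");
  pipeline.AddPass<MyNoOpPass>();
  // New in this change: a pipeline can run over a whole module group; a
  // module-scoped pass is applied to each module in turn.
  return pipeline.RunOnModuleGroup(module_group);
}

}  // namespace xla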

@@ -2560,6 +2560,7 @@ cc_library(
],
deps = [
":hlo",
":hlo_module_group",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
@@ -2591,6 +2592,26 @@ cc_library(
],
)
tf_cc_test(
name = "hlo_pass_pipeline_test",
srcs = ["hlo_pass_pipeline_test.cc"],
deps = [
":hlo",
":hlo_parser",
":hlo_pass_pipeline",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
)
cc_library(
name = "hlo_cse",
srcs = ["hlo_cse.cc"],

@@ -24,7 +24,7 @@ limitations under the License.
namespace xla {
// A pass which performs algebraic simplifications.
class AlgebraicSimplifier : public HloPassInterface {
class AlgebraicSimplifier : public HloModulePass {
public:
// Given shapes 'from_shape' and 'to_shape', determines if it is valid to
// bitcast from 'from_shape' to 'to_shape' after considering platform

@@ -25,7 +25,7 @@ namespace xla {
// Normally these would live in the algebraic simplifier, but we want to run
// this to fixpoint (this pass reaches fixed point in one execution) before we
// run the DotDecomposer.
class BatchDotSimplification : public HloPassInterface {
class BatchDotSimplification : public HloModulePass {
public:
StatusOr<bool> Run(HloModule* module) override;
absl::string_view name() const override;

@@ -26,7 +26,7 @@ namespace xla {
// A pass which rewrites batch norm operations into a sequence of smaller
// operations. Breaking a big operation into smaller operations helps leverage
// our generic fusion logic.
class BatchNormExpander : public HloPassInterface {
class BatchNormExpander : public HloModulePass {
public:
// When use_fusion is set, a multi-output fusion node is created.
BatchNormExpander(bool rewrite_training_op = false,

@@ -31,7 +31,7 @@ namespace xla {
// optimization pipeline followed by a DCE pass. If other passes are needed
// after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the
// changes made by this pass.
class BFloat16ConversionFolding : public HloPassInterface {
class BFloat16ConversionFolding : public HloModulePass {
public:
explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support)
: bfloat16_support_(bfloat16_support) {}

@@ -25,7 +25,7 @@ namespace xla {
// A pass which adds F32 <-> BF16 conversions for HLO instructions that do not
// support BF16 input/output or mixed precision, according to the passed-in
// backend-specific BF16 support rules.
class BFloat16Normalization : public HloPassInterface {
class BFloat16Normalization : public HloModulePass {
public:
explicit BFloat16Normalization(const BFloat16Support* bfloat16_support)
: bfloat16_support_(bfloat16_support) {}
@@ -48,7 +48,7 @@ class BFloat16Normalization : public HloPassInterface {
// use mixed precision; it removes mixed precision even if the backend supports
// it. This pass is used to make the HLO module valid for other HLO passes which
// do not support mixed precision.
class BFloat16MixedPrecisionRemoval : public HloPassInterface {
class BFloat16MixedPrecisionRemoval : public HloModulePass {
public:
BFloat16MixedPrecisionRemoval() {}

@@ -58,7 +58,7 @@ namespace xla {
// BFloat16ConversionFolding. If other passes are needed after this pass, run
// BFloat16MixedPrecisionRemoval first to undo some of the changes made by this
// pass.
class BFloat16Propagation : public HloPassInterface {
class BFloat16Propagation : public HloModulePass {
public:
explicit BFloat16Propagation(const BFloat16Support* bfloat16_support);

@@ -25,7 +25,7 @@ namespace xla {
// For every kCall operation in the main computation, we inline the body of the
// called function, and proceed recursively.
class CallInliner : public HloPassInterface {
class CallInliner : public HloModulePass {
public:
using InlinedInstructionMap =
std::unordered_map<HloInstruction*, HloInstruction*>;

@@ -25,7 +25,7 @@ namespace xla {
// HLO pass that removes kConditional instructions with a constant predicate,
// replacing them with their true or false computation as appropriate.
class ConditionalSimplifier : public HloPassInterface {
class ConditionalSimplifier : public HloModulePass {
public:
absl::string_view name() const override { return "simplify-conditional"; }
StatusOr<bool> Run(HloModule* module) override;

@@ -25,7 +25,7 @@ namespace xla {
// A pass which rewrites convolutions with feature_group_count > 1 into
// convolutions with feature_group_count = 1.
class ConvolutionFeatureGroupConverter : public HloPassInterface {
class ConvolutionFeatureGroupConverter : public HloModulePass {
public:
ConvolutionFeatureGroupConverter() {}

@@ -43,7 +43,7 @@ namespace xla {
// (3) The buffer set of the root instruction of the entry computation must be
// unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous returns
// false and InstructionAliasSet::IsDistinct returns true.
class CopyInsertion : public HloPassInterface {
class CopyInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "copy-insertion"; }

@@ -31,7 +31,7 @@ namespace cpu {
// called canonical convolutions). This pass expands non-canonical convolutions
// into reshapes and canonical convolutions, so that these non-canonical
// convolutions can run faster.
class ConvCanonicalization : public HloPassInterface {
class ConvCanonicalization : public HloModulePass {
public:
explicit ConvCanonicalization(
const TargetMachineFeatures* target_machine_features)

@@ -30,7 +30,7 @@ namespace xla {
//
// TODO(b/62548313): Remove this when buffer assignment is smarter
// (module-scoped).
class CpuCopyInsertion : public HloPassInterface {
class CpuCopyInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "copy-insertion"; }

@@ -23,7 +23,7 @@ namespace xla {
// This pass should run early in the HLO pipeline and checks for HLO constructs
// which are not supported by the CPU backend and cannot be removed via HLO
// transformations (eg, sparse layouts).
class CpuHloSupportChecker : public HloPassInterface {
class CpuHloSupportChecker : public HloModulePass {
public:
CpuHloSupportChecker() = default;
~CpuHloSupportChecker() override = default;

@@ -60,7 +60,7 @@ class ParallelTaskAssignment {
// own embedded computation, which is compiled as a parallel compute function,
// and which is invoked from a kCall instruction that is lowered in codegen to
// a runtime parallel fork/join call.
class ParallelTaskAssigner : public HloPassInterface {
class ParallelTaskAssigner : public HloModulePass {
public:
// 'max_parallelism': the maximum parallel task count per instruction.
// 'shape_size': shape size function used by HloCostAnalysis during parallel

@@ -25,7 +25,7 @@ namespace xla {
// A pass which replaces all fusion instructions with the equivalent un-fused
// instructions.
class Defuser : public HloPassInterface {
class Defuser : public HloModulePass {
public:
Defuser() {}
~Defuser() override {}

@@ -24,7 +24,7 @@ namespace xla {
namespace {
// Pass which strips control dependencies from all instructions in the module.
class ControlDepRemover : public HloPassInterface {
class ControlDepRemover : public HloModulePass {
public:
ControlDepRemover() = default;
absl::string_view name() const override { return "control-dep-remover"; }

@@ -30,7 +30,7 @@ namespace xla {
//
// Current despecialization passes are Defuser, ImplicitBroadcastRemover,
// and BFloat16MixedPrecisionRemoval.
class Despecializer : public HloPassInterface {
class Despecializer : public HloModulePass {
public:
Despecializer();
absl::string_view name() const override { return "despecializer"; }

@@ -23,7 +23,7 @@ namespace xla {
// DotDecomposer is a pass which decomposes batch Dot operations into a
// sequence of smaller (R2) Dot operations.
class DotDecomposer : public HloPassInterface {
class DotDecomposer : public HloModulePass {
public:
// Decomposes batch Dot operations when 'decompose_batch_dot' is true.
DotDecomposer(bool decompose_batch_dot = true)

@@ -26,7 +26,7 @@ namespace xla {
// Flattening associates each call site with a unique computation (for
// sequential calling contexts). This simplifies buffer assignment and
// points-to analysis (see b/36865746 for details).
class FlattenCallGraph : public HloPassInterface {
class FlattenCallGraph : public HloModulePass {
public:
absl::string_view name() const override { return "flatten-call-graph"; }

@@ -23,7 +23,7 @@ namespace xla {
// This pass rewrites gather operations into (roughly) while loops of dynamic
// slices. This lets backends that don't support gather directly nevertheless
// have a minimum level of support.
class GatherExpander : public HloPassInterface {
class GatherExpander : public HloModulePass {
public:
absl::string_view name() const override { return "gather_expander"; }
StatusOr<bool> Run(HloModule* module) override;

@@ -52,7 +52,7 @@ namespace gpu {
// The GPU backend does not implement a lowering for the batchnorm HLOs -- it
// expects them to be lowered to cudnn calls via this pass or to HLO soup via
// BatchNormRewriter.
class CudnnBatchNormRewriter : public HloPassInterface {
class CudnnBatchNormRewriter : public HloModulePass {
public:
absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; }
StatusOr<bool> Run(HloModule* module) override;

@@ -30,7 +30,7 @@ namespace gpu {
// Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for
// each and adding explicit scratch space to the CustomCalls.
class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
class CudnnConvolutionAlgorithmPicker : public HloModulePass {
public:
// If the `allocator` parameter is not null, we will use it to allocate temp
// memory while timing the various convolution algorithms. If it's null,

@@ -24,7 +24,7 @@ namespace gpu {
// Rewrites plain convolutions, backwards-filter convolutions, and
// backwards-input convolutions into CustomCall HLOs that call into cuDNN.
class CudnnConvolutionRewriter : public HloPassInterface {
class CudnnConvolutionRewriter : public HloModulePass {
public:
absl::string_view name() const override {
return "cudnn-convolution-rewriter";

@@ -32,7 +32,7 @@ namespace gpu {
// 2) The result of merging the fusion instruction into its users would not
// increase bytes transferred.
//
class FusionMerger : public HloPassInterface {
class FusionMerger : public HloModulePass {
public:
absl::string_view name() const override { return "fusion merger"; }

@@ -34,15 +34,6 @@ namespace xla {
namespace gpu {
StatusOr<HloInstruction*> GpuCopyInsertion::FindOrInsertCopy(
HloInstruction* hlo) {
HloInstruction*& copy = hlo_to_copy_map_[hlo];
if (copy == nullptr) {
TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo));
}
return copy;
}
StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
CopyInsertion generic_copy_insertion;

@@ -25,20 +25,11 @@ namespace gpu {
// Besides the modifications made by the generic xla::CopyInsertion, this
// GPU-specific copy insertion also materializes operands of library calls by
// inserting kCopy instructions.
class GpuCopyInsertion : public HloPassInterface {
class GpuCopyInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "copy-insertion"; }
StatusOr<bool> Run(HloModule* module) override;
protected:
// Returns a copy of `hlo`. Looks in hlo_to_copy_map_ first to avoid making
// duplicate copies.
StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
// A map containing all copies inserted to materialize operands of library
// calls. The key is the copied instruction and the value is the copy.
tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> hlo_to_copy_map_;
};
} // namespace gpu

@@ -23,7 +23,7 @@ namespace xla {
// This pass should run early in the HLO pipeline and checks for HLO constructs
// which are not supported by the GPU backend and cannot be removed via HLO
// transformations (eg, sparse layouts).
class GpuHloSupportChecker : public HloPassInterface {
class GpuHloSupportChecker : public HloModulePass {
public:
GpuHloSupportChecker() = default;
~GpuHloSupportChecker() override = default;

@@ -30,7 +30,7 @@ namespace gpu {
// targeting before running this pass.
//
// TODO(jlebar): Also pad dots.
class PadForTensorCores : public HloPassInterface {
class PadForTensorCores : public HloModulePass {
public:
absl::string_view name() const override { return "pad for tensor cores"; }

@@ -24,7 +24,7 @@ namespace gpu {
// An HLO pass that canonicalizes convolution instructions for GPU codegen. It
// inserts Pad instructions before Convolution instructions with uncanonicalized
// padding, so that they can be lowered to cuDNN convolution.
class PadInsertion : public HloPassInterface {
class PadInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "pad insertion"; }

@@ -23,7 +23,7 @@ namespace xla {
// A pass which performs constant folding in order to avoid unnecessary
// computation on constants.
class HloConstantFolding : public HloPassInterface {
class HloConstantFolding : public HloModulePass {
public:
absl::string_view name() const override { return "constant_folding"; }

@@ -25,7 +25,7 @@ namespace xla {
// and identical instructions with the same operands are commoned. The pass
// iterates over the instructions in topological order, which enables the pass to
// find arbitrarily large common expressions.
class HloCSE : public HloPassInterface {
class HloCSE : public HloModulePass {
public:
// If is_layout_sensitive is true, then the simplifier preserves layout during
// transformation. Otherwise, layout is ignored.

@@ -33,7 +33,7 @@ namespace xla {
//
// This pass does not remove dead parameter instructions, as parameter
// instructions cannot be deleted.
class HloDCE : public HloPassInterface {
class HloDCE : public HloModulePass {
public:
~HloDCE() override {}
absl::string_view name() const override { return "dce"; }

@@ -30,7 +30,7 @@ namespace xla {
// used to break an HLO graph edge connecting two instructions with different
// sharding. If a set of connected instructions have all the same sharding, no
// kDomain instruction will be placed.
class HloDomainIsolator : public HloPassInterface {
class HloDomainIsolator : public HloModulePass {
public:
// Creates a new kDomain instruction for the edge between the use instruction
// (the first HloInstruction argument), and the operand instruction (the

@@ -26,7 +26,7 @@ namespace xla {
// Removes all the kDomain instructions of a given kind from the input module,
// and calls the normalizer to propagate the properties on the possibly newborn
// instructions.
class HloDomainRemover : public HloPassInterface {
class HloDomainRemover : public HloModulePass {
public:
// Creates a new HloDomainRemover object tasked at removing all the kDomain
// instructions of a given kind.

@@ -29,7 +29,7 @@ namespace xla {
// Verifies that the domain instructions are consistent, and that each domain is
// surrounded by the same metadata.
class HloDomainVerifier : public HloPassInterface {
class HloDomainVerifier : public HloModulePass {
public:
HloDomainVerifier(std::vector<string> kinds) : kinds_(std::move(kinds)) {}

@@ -25,7 +25,7 @@ namespace xla {
// inserting Convert ops. This allows a backend to support an element type while
// only actually implementing the Convert op for that element type. This is
// generally not the fastest approach, but it works.
class HloElementTypeConverter : public HloPassInterface {
class HloElementTypeConverter : public HloModulePass {
public:
// eliminate_type is the type to eliminate as the input or output of ops,
// using Convert ops to replace it with replace_with_type.

@@ -90,7 +90,7 @@ StatusOr<HloInstructionSequence> ScheduleComputation(
// A pass which schedules the HLO instructions in a module. The HloModule's
// schedule field is set to the resulting HloSchedule using
// HloModule::set_schedule.
class HloMemoryScheduler : public HloPassInterface {
class HloMemoryScheduler : public HloModulePass {
public:
// size_function is the function returning the number of bytes required for a
// LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not
@@ -109,7 +109,7 @@ class HloMemoryScheduler : public HloPassInterface {
// A trivial pass which clears the schedule currently set on the
// HloModule. After this pass runs, HloModule::has_schedule will return false.
class HloDescheduler : public HloPassInterface {
class HloDescheduler : public HloModulePass {
public:
HloDescheduler() = default;
~HloDescheduler() override = default;

@@ -28,7 +28,7 @@ namespace xla {
// Sweeps through live instructions which cross computation boundaries (kWhile),
// and removes code at dead shape indices.
//
class HloModuleDCE : public HloPassInterface {
class HloModuleDCE : public HloModulePass {
public:
~HloModuleDCE() override {}
absl::string_view name() const override { return "hlo-module-dce"; }

@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -25,15 +26,45 @@ limitations under the License.
namespace xla {
// Base class for HLO passes. These are used with the HloPassPipeline to
// organize a sequence of passes.
// organize a sequence of passes. An HLO pass should not extend this class
// directly; it should extend HloModulePass or HloModuleGroupPass.
class HloPassInterface {
public:
virtual ~HloPassInterface() = default;
virtual absl::string_view name() const = 0;
// Run the pass on the given HLO module. Return whether it modified the
// Run the pass on the given HLO module. Returns whether it modified the
// module.
virtual StatusOr<bool> Run(HloModule* module) = 0;
// Run the pass on the given HLO module group. Returns whether it modified the
// module group. Ideally, the module group variant would be named "Run" as
// well, but C++ does not handle overloaded virtual methods well.
virtual StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) = 0;
};
// Base class for passes which are module-scoped.
class HloModulePass : public HloPassInterface {
public:
// Runs the pass on a module group by iterating through each module in the
// group.
StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override {
bool changed = false;
for (HloModule* module : module_group->modules()) {
TF_ASSIGN_OR_RETURN(bool module_changed, Run(module));
changed |= module_changed;
}
return changed;
};
};
// Base class for passes which are module-group scoped. These passes cannot run
// on an HLO module.
class HloModuleGroupPass : public HloPassInterface {
public:
StatusOr<bool> Run(HloModule* module) override {
return InternalError("Module group pass cannot be run on a module");
}
};
} // namespace xla
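As an aside on the naming caveat in the comment above: if both virtuals were called Run, a derived class overriding only one of them would hide the other under C++ member-name lookup. A stripped-down, hypothetical sketch of that pitfall (placeholder types, not the actual XLA classes):

// Hypothetical illustration of C++ name hiding, the reason the module-group
// variant is named RunOnModuleGroup rather than Run.
struct Module {};
struct ModuleGroup {};

struct PassBase {
  virtual ~PassBase() = default;
  virtual bool Run(Module* module) { return false; }             // module variant
  virtual bool Run(ModuleGroup* module_group) { return false; }  // group variant
};

struct DerivedPass : PassBase {
  // Declaring any Run here hides *both* inherited overloads by name:
  bool Run(Module* module) override { return true; }
};

void Example() {
  DerivedPass pass;
  ModuleGroup group;
  // pass.Run(&group);  // Error: Run(ModuleGroup*) is hidden unless
  //                    // DerivedPass adds 'using PassBase::Run;'.
  // Distinct names (Run vs. RunOnModuleGroup) avoid this trap entirely.
  static_cast<PassBase&>(pass).Run(&group);  // Works only via the base type.
}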

@@ -17,7 +17,6 @@ limitations under the License.
#include <functional>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
@@ -29,108 +28,128 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace {
using absl::StrAppend;
using absl::StrCat;
template <typename HloT>
Status HloPassPipeline::RunInvariantCheckers(
HloT* hlo, absl::string_view after_pass_name) {
for (auto& invariant_checker : invariant_checkers_) {
VLOG(1) << " Invariant checker " << invariant_checker->name();
StatusOr<bool> changed_status = RunHelper(invariant_checker.get(), hlo);
VLOG(1) << " Invariant checker done " << invariant_checker->name();
if (!changed_status.ok()) {
VLOG(2) << "Failed invariant check:";
XLA_VLOG_LINES(2, hlo->ToString());
return Status(changed_status.status().code(),
absl::StrCat(changed_status.status().error_message(),
"\n\nFailed after ", after_pass_name));
}
TF_RET_CHECK(!changed_status.ValueOrDie())
<< "invariant checkers must not change the graph";
}
return Status::OK();
}
void DumpModuleGraph(const HloModule& module, const string& message) {
template <typename HloT>
StatusOr<bool> HloPassPipeline::RunPassesInternal(
HloT* hlo, absl::Span<HloPassInterface* const> passes) {
string last_pass_name = "pipeline-start";
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name));
bool changed = false;
for (HloPassInterface* pass : passes) {
VLOG(1) << " HLO pass " << pass->name();
MaybeDumpHlo(*hlo,
/*after_pass_name=*/last_pass_name,
/*before_pass_name=*/pass->name());
TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
changed |= pass_changed;
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass->name()));
last_pass_name = string(pass->name());
}
MaybeDumpHlo(*hlo,
/*after_pass_name=*/last_pass_name,
/*before_pass_name=*/"pipeline-end");
return changed;
}
std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
const DebugOptions& debug_options) {
auto repeated_field = debug_options.xla_disable_hlo_passes();
tensorflow::gtl::FlatSet<string> disabled_pass_names(repeated_field.begin(),
repeated_field.end());
if (!disabled_pass_names.empty()) {
VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
<< absl::StrJoin(disabled_pass_names, ", ");
}
std::vector<HloPassInterface*> enabled_passes;
for (auto& pass : passes_) {
if (disabled_pass_names.count(string(pass->name())) == 0) {
enabled_passes.push_back(pass.get());
}
}
return enabled_passes;
}
void HloPassPipeline::MaybeDumpHlo(const HloModule& module,
absl::string_view after_pass_name,
absl::string_view before_pass_name) {
const string& proto_dump_path =
module.config().debug_options().xla_dump_per_pass_hlo_proto_to();
if (!proto_dump_path.empty()) {
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
static auto* const module_id_to_pass_number =
new tensorflow::gtl::FlatMap<int64, int64>();
tensorflow::mutex_lock lock(mu);
const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
const string filename = SanitizeFileName(
absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
pass_number, name(), after_pass_name));
TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(
MakeHloProto(module), proto_dump_path, filename));
}
const string message =
StrCat("after ", after_pass_name, ", before ", before_pass_name);
hlo_graph_dumper::MaybeDumpHloModule(module, message);
VLOG(3) << "HLO " << message << ":";
XLA_VLOG_LINES(3, module.ToString());
}
void DumpModuleProto(const HloModule& module, const string& dump_to,
const string& pipeline_name, const string& pass_name) {
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
static auto* const module_id_to_pass_number =
new tensorflow::gtl::FlatMap<int64, int64>();
tensorflow::mutex_lock lock(mu);
const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
const string mod_name = SanitizeFileName(
absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
pass_number, pipeline_name, pass_name));
TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module),
dump_to, mod_name));
void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group,
absl::string_view after_pass_name,
absl::string_view before_pass_name) {
for (const HloModule* module : module_group.modules()) {
MaybeDumpHlo(*module, after_pass_name, before_pass_name);
}
}
} // namespace
StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
run_called_ = true;
VLOG(1) << "Running HLO pass pipeline " << name();
VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": "
<< name();
auto repeated_field =
module->config().debug_options().xla_disable_hlo_passes();
tensorflow::gtl::FlatSet<string> disabled_passes(repeated_field.begin(),
repeated_field.end());
if (!disabled_passes.empty()) {
VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
<< absl::StrJoin(disabled_passes, ", ");
return RunPassesInternal(module,
GetEnabledPasses(module->config().debug_options()));
}
StatusOr<bool> HloPassPipeline::RunOnModuleGroup(HloModuleGroup* module_group) {
run_called_ = true;
VLOG(1) << "Running HLO pass pipeline on module group "
<< module_group->name() << ": " << name();
if (module_group->modules().empty()) {
VLOG(1) << "Module group is empty. Nothing to do.";
return false;
}
auto run_invariant_checkers = [this,
module](const string& message) -> Status {
for (auto& invariant_checker : invariant_checkers_) {
VLOG(1) << " Invariant checker " << invariant_checker->name();
StatusOr<bool> changed_status = invariant_checker->Run(module);
VLOG(1) << " Invariant checker done " << invariant_checker->name();
if (!changed_status.ok()) {
VLOG(2) << "Module failed invariant check:";
XLA_VLOG_LINES(2, module->ToString());
return Status(changed_status.status().code(),
StrCat(changed_status.status().error_message(),
"\n\nFailed ", message));
}
TF_RET_CHECK(!changed_status.ValueOrDie())
<< "invariant checkers must not change the graph";
}
return Status::OK();
};
string prefix = StrCat(name(), ": pipeline start");
bool changed = false;
string message;
TF_RETURN_IF_ERROR(
run_invariant_checkers(StrCat("before running pipeline: ", name())));
const string xla_dump_per_pass_hlo_proto_to =
module->config().debug_options().xla_dump_per_pass_hlo_proto_to();
if (!xla_dump_per_pass_hlo_proto_to.empty()) {
DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
"pipeline_start");
}
for (auto& pass : passes_) {
if (disabled_passes.count(string(pass->name())) > 0) {
VLOG(1) << " Skipping HLO pass " << pass->name()
<< ", disabled by --xla_disable_hlo_passes";
continue;
}
VLOG(1) << " HLO pass " << pass->name();
// Emit label containing: "after foo-pass, before bar-pass".
message.clear();
StrAppend(&message, prefix, ", before ", pass->name());
DumpModuleGraph(*module, message);
TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module));
TF_RETURN_IF_ERROR(
run_invariant_checkers(StrCat("after running pass: ", pass->name())));
if (!xla_dump_per_pass_hlo_proto_to.empty()) {
DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
string(pass->name()));
}
changed |= changed_this_pass;
prefix.clear();
StrAppend(&prefix, name(), ": after ", pass->name());
}
DumpModuleGraph(*module, prefix + ", pipeline end");
return changed;
return RunPassesInternal(
module_group,
GetEnabledPasses(module_group->module(0).config().debug_options()));
}
} // namespace xla

@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -61,10 +62,45 @@ class HloPassPipeline : public HloPassInterface {
return *pass;
}
// Run all passes on the given HLO module.
StatusOr<bool> Run(HloModule* module) override;
StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override;
private:
// Returns the set of passes which are enabled. DebugOptions can selectively
// disable passes via --xla_disable_hlo_passes flag.
std::vector<HloPassInterface*> GetEnabledPasses(
const DebugOptions& debug_options);
// Maybe dumps the given module or module group depending on flag values
// contained in DebugOptions of module config.
void MaybeDumpHlo(const HloModuleGroup& module_group,
absl::string_view after_pass_name,
absl::string_view before_pass_name);
void MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name,
absl::string_view before_pass_name);
// Runs the invariant checker on the given HLO. HloT can be either HloModule
// or HloModuleGroup.
template <typename HloT>
Status RunInvariantCheckers(HloT* hlo, absl::string_view after_pass_name);
// Helper which runs the given passes on the given HLO. HloT can be either
// HloModule or HloModuleGroup.
template <typename HloT>
StatusOr<bool> RunPassesInternal(HloT* hlo,
absl::Span<HloPassInterface* const> passes);
// Helpers which run the given passes on the given HLO construct. These
// helpers enable templating of the core of the pipeline logic by providing
// HloModule and HloModuleGroup specific methods with the same name.
static StatusOr<bool> RunHelper(HloPassInterface* pass, HloModule* module) {
return pass->Run(module);
}
static StatusOr<bool> RunHelper(HloPassInterface* pass,
HloModuleGroup* module_group) {
return pass->RunOnModuleGroup(module_group);
}
const string name_;
std::vector<std::unique_ptr<HloPassInterface>> passes_;
std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;
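The two RunHelper overloads above are what let RunPassesInternal be a single template: the templated body calls one name, and ordinary overload resolution picks the module or module-group variant at compile time. A self-contained sketch of the same dispatch pattern, with placeholder types:

#include <iostream>

struct Module {};
struct ModuleGroup {};

// Stand-ins for pass->Run(module) and pass->RunOnModuleGroup(module_group).
static bool RunHelper(Module*) { return true; }
static bool RunHelper(ModuleGroup*) { return false; }

// One template body serves both types; the right RunHelper overload is
// selected when the template is instantiated.
template <typename HloT>
bool RunPassesInternal(HloT* hlo) {
  return RunHelper(hlo);
}

int main() {
  Module m;
  ModuleGroup g;
  std::cout << RunPassesInternal(&m) << " " << RunPassesInternal(&g) << "\n";
  return 0;
}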

@@ -0,0 +1,259 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
class HloPassPipelineTest : public HloTestBase {
protected:
StatusOr<HloModuleGroup> ParseModuleGroup(
absl::Span<const string> hlo_strings) {
HloModuleGroup group(TestName());
for (const string& hlo_string : hlo_strings) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
ParseHloString(hlo_string));
group.push_back(std::move(module));
}
return std::move(group);
}
};
// A module pass which renames instructions named 'foo' to 'bar'.
class FooToBarModulePass : public HloModulePass {
absl::string_view name() const override { return "foo2bar"; }
StatusOr<bool> Run(HloModule* module) override {
bool changed = false;
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->name() == "foo") {
instruction->SetAndSanitizeName("bar");
changed = true;
}
}
}
return changed;
}
};
// A module group pass which renames instructions named 'baz' to 'qux'.
class BazToQuxModuleGroupPass : public HloModuleGroupPass {
absl::string_view name() const override { return "baz2qux"; }
StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override {
bool changed = false;
for (HloModule* module : module_group->modules()) {
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->name() == "baz") {
instruction->SetAndSanitizeName("qux");
changed = true;
}
}
}
}
return changed;
}
};
// An invariant checker pass which returns an error if there exists an
// instruction named 'bar'.
class BarBlowerUpper : public HloModulePass {
absl::string_view name() const override { return "bar-blower-upper"; }
StatusOr<bool> Run(HloModule* module) override {
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->name() == "bar") {
return InternalError("Module has instruction named bar");
}
}
}
return false;
}
};
TEST_F(HloPassPipelineTest, ModulePassChanged) {
// Test an HLO module pass which changes a module.
const string module_str = R"(
HloModule ModulePassChanged
ENTRY main {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT foo = f32[] multiply(a, b)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseHloString(module_str));
HloPassPipeline pipeline(TestName());
pipeline.AddPass<FooToBarModulePass>();
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_EQ(root->name(), "foo");
TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_EQ(root->name(), "bar");
}
TEST_F(HloPassPipelineTest, ModulePassUnchanged) {
// Test an HLO module pass which does not change a module.
const string module_str = R"(
HloModule ModulePassUnchanged
ENTRY main {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT blahblah = f32[] multiply(a, b)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseHloString(module_str));
HloPassPipeline pipeline(TestName());
pipeline.AddPass<FooToBarModulePass>();
TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
EXPECT_FALSE(changed);
}
TEST_F(HloPassPipelineTest, MixedPipeline) {
// Test a pipeline with both a module pass and a module group pass.
const string module_0_str = R"(
HloModule MixedPipeline.1
ENTRY main {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT baz = f32[] multiply(a, b)
}
)";
const string module_1_str = R"(
HloModule MixedPipeline.0
ENTRY main {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT foo = f32[] multiply(a, b)
}
)";
TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup module_group,
ParseModuleGroup({module_0_str, module_1_str}));
HloPassPipeline pipeline(TestName());
pipeline.AddPass<BazToQuxModuleGroupPass>();
pipeline.AddPass<FooToBarModulePass>();
HloInstruction* root0 =
module_group.module(0).entry_computation()->root_instruction();
HloInstruction* root1 =
module_group.module(1).entry_computation()->root_instruction();
EXPECT_EQ(root0->name(), "baz");
EXPECT_EQ(root1->name(), "foo");
TF_ASSERT_OK_AND_ASSIGN(bool changed,
pipeline.RunOnModuleGroup(&module_group));
EXPECT_TRUE(changed);
EXPECT_EQ(root0->name(), "qux");
EXPECT_EQ(root1->name(), "bar");
}
TEST_F(HloPassPipelineTest, InvariantChecker) {
const string module_str = R"(
HloModule InvariantChecker
ENTRY main {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT foo = f32[] multiply(a, b)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseHloString(module_str));
{
// Run a pipeline with just the invariant checker. It should not fail
// because there is no 'bar' instruction in the module.
HloPassPipeline pipeline(TestName());
pipeline.AddInvariantChecker<BarBlowerUpper>();
TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
EXPECT_FALSE(changed);
}
{
// Run a pipeline which renames 'foo' to 'bar' then an invariant checker
// which fails if there is an instruction named 'bar'.
HloPassPipeline pipeline(TestName());
pipeline.AddInvariantChecker<BarBlowerUpper>();
pipeline.AddPass<FooToBarModulePass>();
Status status = pipeline.Run(module.get()).status();
ASSERT_IS_NOT_OK(status);
EXPECT_THAT(status.error_message(),
::testing::HasSubstr("Module has instruction named bar"));
EXPECT_THAT(status.error_message(),
::testing::HasSubstr("Failed after foo2bar"));
}
{
// Run the invariant-checker only pipeline again. It should fail this time.
HloPassPipeline pipeline(TestName());
pipeline.AddInvariantChecker<BarBlowerUpper>();
Status status = pipeline.Run(module.get()).status();
ASSERT_IS_NOT_OK(status);
EXPECT_THAT(status.error_message(),
::testing::HasSubstr("Module has instruction named bar"));
EXPECT_THAT(status.error_message(),
::testing::HasSubstr("Failed after pipeline-start"));
}
}
TEST_F(HloPassPipelineTest, ModuleGroupPassOnModule) {
// Running a module group pass on a module should produce an error.
const string module_str = R"(
HloModule ModuleGroupPassOnModule
ENTRY main {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT foo = f32[] multiply(a, b)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseHloString(module_str));
HloPassPipeline pipeline(TestName());
pipeline.AddPass<BazToQuxModuleGroupPass>();
Status status = pipeline.Run(module.get()).status();
ASSERT_IS_NOT_OK(status);
EXPECT_THAT(
status.error_message(),
::testing::HasSubstr("Module group pass cannot be run on a module"));
}
} // namespace
} // namespace xla

@@ -1198,6 +1198,12 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
<< HumanReadableNumBytes(memory_limit_bytes_);
XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
// Initialize pass object state.
computation_peak_memory_.clear();
rematerialized_computations_.clear();
instructions_rematerialized_ = 0;
net_instructions_added_ = 0;
TF_RET_CHECK(module->has_schedule());
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));

@@ -33,7 +33,7 @@ namespace xla {
// CSE will undo the effects of this optimization and should not be run after
// this pass. In general, this pass should be run very late, immediately before
// code generation.
class HloRematerialization : public HloPassInterface {
class HloRematerialization : public HloModulePass {
public:
using ShapeSizeFunction = std::function<int64(const Shape&)>;

@@ -22,7 +22,7 @@ namespace xla {
// Unify subcomputations of a `HloModule`: if any computations are equal, choose
// one arbitrarily to use and delete the others.
class HloSubcomputationUnification : public HloPassInterface {
class HloSubcomputationUnification : public HloModulePass {
public:
absl::string_view name() const override {
return "subcomputation-unification";

@@ -151,7 +151,7 @@ class ShapeVerifier : public DfsHloVisitor {
// HLO pass that verifies invariants of HLO instructions for each computation in
// the module.
class HloVerifier : public HloPassInterface {
class HloVerifier : public HloModulePass {
public:
using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;

@@ -25,7 +25,7 @@ namespace xla {
// Pass which replaces all implicit broadcasts with their equivalent sequence of
// explicit broadcast and reshape instructions.
class ImplicitBroadcastRemover : public HloPassInterface {
class ImplicitBroadcastRemover : public HloModulePass {
public:
ImplicitBroadcastRemover() {}
~ImplicitBroadcastRemover() override {}

@@ -366,7 +366,7 @@ class IndexedArrayAnalysis {
// A pass that prints all non-trivial results returned by IndexedArrayAnalysis.
// This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to
// unconditionally add to the regular HLO pass pipeline.
class IndexedArrayAnalysisPrinterPass : public HloPassInterface {
class IndexedArrayAnalysisPrinterPass : public HloModulePass {
public:
absl::string_view name() const override;
StatusOr<bool> Run(HloModule* module) override;

@@ -24,7 +24,7 @@ namespace xla {
// A pass which performs inlining. This can result, for example, in functions
// that were previously being mapped by Map instead being directly applied to
// the forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)).
class Inliner : public HloPassInterface {
class Inliner : public HloModulePass {
public:
~Inliner() override = default;
absl::string_view name() const override { return "inline"; }

@@ -56,7 +56,7 @@ class FusionQueue {
// with the intent that the loops which compute their values will be fused in
// code generation. Derived classes define the ShouldFuse method to select which
// instructions to fuse.
class InstructionFusion : public HloPassInterface {
class InstructionFusion : public HloModulePass {
public:
explicit InstructionFusion(
std::function<bool(const HloInstruction& instruction)> is_expensive,

@@ -281,7 +281,7 @@ class ChannelLayoutConstraints {
// HLO pass which assigns layouts to all instructions in the HLO module while
// satisfying all necessary invariants and minimizing cost.
class LayoutAssignment : public HloPassInterface {
class LayoutAssignment : public HloModulePass {
public:
// entry_computation_layout is modified to populate a layout for the result in
// the case that no particular layout is requested.

@@ -44,7 +44,7 @@ namespace xla {
// Note that the reachability map is updated based on the original computation.
// This works because the reachability is monotonically increasing with
// instruction fusion.
class MultiOutputFusion : public HloPassInterface {
class MultiOutputFusion : public HloModulePass {
public:
MultiOutputFusion(int64 fuel) : fuel_(fuel) {}

@@ -29,7 +29,7 @@ namespace xla {
// HLO pass which inserts reduce-precision instructions into the HLO graph, for
// purposes of experimenting with the effects of reduced-precision storage of
// intermediate values.
class ReducePrecisionInsertion : public HloPassInterface {
class ReducePrecisionInsertion : public HloModulePass {
using InstructionFilterFunction = std::function<bool(const HloInstruction*)>;
public:

@@ -24,7 +24,7 @@ namespace xla {
// This now only moves them outputward across elementwise ops all of whose
// operands are equivalent Reshapes or Transposes, but in the future it could
// potentially move them inputward also.
class ReshapeMover : public HloPassInterface {
class ReshapeMover : public HloModulePass {
public:
absl::string_view name() const override { return "reshape-mover"; }

@@ -20,7 +20,7 @@ limitations under the License.
namespace xla {
class ScatterExpander : public HloPassInterface {
class ScatterExpander : public HloModulePass {
public:
absl::string_view name() const override { return "scatter_expander"; }
StatusOr<bool> Run(HloModule* module) override;

@@ -23,7 +23,7 @@ namespace xla {
// HLO pass that folds transpose operators into Dot operators, where the Dot
// operator is implemented by a GEMM kernel that can transpose its inputs.
class TransposeFolding : public HloPassInterface {
class TransposeFolding : public HloModulePass {
public:
using OperandIndices = std::vector<int64>;

@@ -25,7 +25,7 @@ namespace xla {
// A pass which simplifies patterns of Tuple and GetTupleElement instructions in
// the module.
class TupleSimplifier : public HloPassInterface {
class TupleSimplifier : public HloModulePass {
public:
TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {}
explicit TupleSimplifier(bool exclude_entry_computation);

@@ -50,7 +50,7 @@ namespace xla {
// conditions as well.
//
// TODO(b/79121449): We should also sink broadcasts of constants.
class WhileLoopConstantSinking : public HloPassInterface {
class WhileLoopConstantSinking : public HloModulePass {
public:
~WhileLoopConstantSinking() override = default;

@@ -25,7 +25,7 @@ namespace xla {
// HLO pass that rewrites while loops to hoist loop invariant instructions in
// the while body into the computation that contains the while instruction.
class WhileLoopInvariantCodeMotion : public HloPassInterface {
class WhileLoopInvariantCodeMotion : public HloModulePass {
public:
// If `hoist_constants` is true then constants are always hoisted out of while
// loop bodies. Otherwise they are only hoisted out if they enable other

@@ -30,7 +30,7 @@ namespace xla {
// - Elements of a while loop's tuple that the loop doesn't use are removed
// from the tuple.
//
class WhileLoopSimplifier : public HloPassInterface {
class WhileLoopSimplifier : public HloModulePass {
public:
~WhileLoopSimplifier() override {}
absl::string_view name() const override { return "simplify-while-loops"; }

@@ -21,7 +21,7 @@ limitations under the License.
// HLO pass that replaces zero-sized HLOs with a zero-sized constant literal.
namespace xla {
class ZeroSizedHloElimination : public HloPassInterface {
class ZeroSizedHloElimination : public HloModulePass {
public:
StatusOr<bool> Run(HloModule* module) override;
absl::string_view name() const override {