[XLA] Make HloPass an interface, NFC

This will allow inheritance from both `HloPassInterface` and `DfsHloVisitor`, so that passes which include a visitor can have their handler methods overridden per backend.
Change: 145477041
A. Unique TensorFlower 2017-01-24 15:35:34 -08:00 committed by TensorFlower Gardener
parent cdcc0f58a2
commit f9006d72eb
25 changed files with 120 additions and 81 deletions
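
To make the motivation concrete, here is a sketch (not part of this commit) of the kind of backend pass this unlocks. The class name, pass name, and handler choice are hypothetical; DfsHloVisitorWithDefault is XLA's visitor base whose DefaultAction handles any instruction without a dedicated handler.

#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

// Hypothetical backend pass: with HloPass reduced to a pure interface, a
// single class can derive from both the pass interface and a DFS visitor,
// and a backend subclass can override individual Handle* methods.
class BackendSimplifier : public HloPassInterface,
                          public DfsHloVisitorWithDefault {
 public:
  tensorflow::StringPiece name() const override { return "backend-simplify"; }

  // Visitor fallback; a backend overrides handlers such as HandleConvolution.
  Status DefaultAction(HloInstruction* /*hlo*/) override {
    return Status::OK();
  }

  StatusOr<bool> Run(HloModule* module) override {
    // Drive the visitor over the entry computation; report whether the
    // module was modified.
    TF_RETURN_IF_ERROR(module->entry_computation()->Accept(this));
    return changed_;
  }

 protected:
  bool changed_ = false;
};

}  // namespace xla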

View File

@@ -576,6 +576,7 @@ cc_test(
         ":algebraic_simplifier",
         ":cpu_plugin",
         ":hlo",
+        ":hlo_pass",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:test_helpers",
@@ -1006,7 +1007,10 @@ cc_test(
 cc_library(
     name = "hlo_pass",
-    hdrs = ["hlo_pass.h"],
+    hdrs = [
+        "hlo_pass_fix.h",
+        "hlo_pass_interface.h",
+    ],
     deps = [
         ":hlo",
         "//tensorflow/compiler/xla:status_macros",

View File

@@ -19,12 +19,12 @@ limitations under the License.
 #include <utility>
 
 #include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 
 namespace xla {
 
 // A pass which performs AlgebraicSimplications.
-class AlgebraicSimplifier : public HloPass {
+class AlgebraicSimplifier : public HloPassInterface {
  public:
   // Given two shapes, determines if it is valid to bitcast between them after
   // considering platform dependent effects on layout like alignment
@@ -39,10 +39,10 @@ class AlgebraicSimplifier : public HloPass {
   // bitcasts.
   AlgebraicSimplifier(bool is_layout_sensitive,
                       ValidBitcastCallback valid_bitcast_callback)
-      : HloPass("algsimp"),
-        is_layout_sensitive_(is_layout_sensitive),
+      : is_layout_sensitive_(is_layout_sensitive),
         valid_bitcast_callback_(std::move(valid_bitcast_callback)) {}
   ~AlgebraicSimplifier() override {}
+  tensorflow::StringPiece name() const override { return "algsimp"; }
 
   // Run algebraic simplification on the given computation. Returns whether the
   // computation was changed.

View File

@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"

View File

@@ -20,7 +20,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 
 namespace xla {
 
@@ -30,10 +30,10 @@ namespace xla {
 // constant or parameter instructions will be copied.
 // Copy insertion is necessary because constant and parameter arrays have
 // different lifetimes than computation results.
-class CopyInsertion : public HloPass {
+class CopyInsertion : public HloPassInterface {
  public:
-  CopyInsertion() : HloPass("copy-insertion") {}
   ~CopyInsertion() override {}
+  tensorflow::StringPiece name() const override { return "copy-insertion"; }
 
   // Run the pass on the given module. Returns whether the module was changed
   // (copies were inserted).

View File

@@ -17,7 +17,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_
 
 #include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 
 namespace xla {
 namespace cpu {
 
@@ -30,10 +30,12 @@ namespace cpu {
 // called canonical convolutions). This pass expands non-canonical convolutions
 // into reshapes and canonical convolutions, so that these non-canonical
 // convolutions can run faster.
-class ConvCanonicalization : public HloPass {
+class ConvCanonicalization : public HloPassInterface {
  public:
-  ConvCanonicalization() : HloPass("convolution-canonicalization") {}
   ~ConvCanonicalization() override {}
+  tensorflow::StringPiece name() const override {
+    return "convolution-canonicalization";
+  }
 
   StatusOr<bool> Run(HloModule* module) override;
 };

View File

@@ -62,7 +62,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_dce.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
 #include "tensorflow/compiler/xla/service/inliner.h"

View File

@@ -17,7 +17,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_
 
 #include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 
 namespace xla {
 namespace cpu {
 
@@ -30,10 +30,12 @@ namespace cpu {
 // computations. However, it could make sense to coarsen the parallelization to
 // improve cache locality. Also, we will need to do something to intelligently
 // handle While constructs.
-class ParallelizationPreparation : public HloPass {
+class ParallelizationPreparation : public HloPassInterface {
  public:
-  explicit ParallelizationPreparation() : HloPass("cpu-parallel-prepare") {}
   ~ParallelizationPreparation() override {}
+  tensorflow::StringPiece name() const override {
+    return "cpu-parallel-prepare";
+  }
 
   // Run instruction fusion on the given computation. Returns whether the
   // computation was changed.

View File

@@ -17,14 +17,17 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_
 
 #include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 
 namespace xla {
 namespace gpu {
 
-class ConvolutionFolding : public HloPass {
+class ConvolutionFolding : public HloPassInterface {
  public:
-  ConvolutionFolding() : HloPass("convolution-folding") {}
+  tensorflow::StringPiece name() const override {
+    return "convolution-folding";
+  }
 
   StatusOr<bool> Run(HloModule* module) override;
 };

View File

@@ -17,7 +17,7 @@ limitations under the License.
 #define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_
 
 #include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 
 namespace xla {
 namespace gpu {
 
@@ -32,9 +32,9 @@ namespace gpu {
 // 2) The result of merging the fusion instruction into its users would not
 //    increase bytes transferred.
 //
-class FusionMerger : public HloPass {
+class FusionMerger : public HloPassInterface {
  public:
-  FusionMerger() : HloPass("fusion merger") {}
+  tensorflow::StringPiece name() const override { return "fusion merger"; }
 
   StatusOr<bool> Run(HloModule* module) override;

View File

@@ -48,7 +48,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_cse.h"
 #include "tensorflow/compiler/xla/service/hlo_dce.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"

View File

@@ -16,7 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_
 
-#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 
 namespace xla {
 namespace gpu {
 
@@ -24,9 +24,9 @@ namespace gpu {
 // An HLO pass that canonicalizes convolution instructions for GPU codegen. It
 // inserts Pad instructions before Convolution instructions with uncanonicalized
 // padding, so that they can be lowered to cuDNN convolution.
-class PadInsertion : public HloPass {
+class PadInsertion : public HloPassInterface {
  public:
-  PadInsertion() : HloPass("pad insertion") {}
+  tensorflow::StringPiece name() const override { return "pad insertion"; }
 
   StatusOr<bool> Run(HloModule* module) override;

View File

@@ -17,7 +17,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CSE_H_
 
 #include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 
 namespace xla {
 
@@ -25,13 +25,14 @@ namespace xla {
 // and identical instructions with the same operands are commoned. The pass
 // iterates over the instructions in topological order which enables the pass to
 // find arbitrarily large common expressions.
-class HloCSE : public HloPass {
+class HloCSE : public HloPassInterface {
  public:
   // If is_layout_sensitive is true, then the simplifier preserves layout during
   // transformation. Otherwise, layout is ignored.
   explicit HloCSE(bool is_layout_sensitive)
-      : HloPass("cse"), is_layout_sensitive_(is_layout_sensitive) {}
+      : is_layout_sensitive_(is_layout_sensitive) {}
   ~HloCSE() override {}
+  tensorflow::StringPiece name() const override { return "cse"; }
 
   // Run CSE on the given module. Returns whether the module was changed (common
   // subexpressions were found and eliminated).

View File

@@ -19,7 +19,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 #include "tensorflow/compiler/xla/statusor.h"
 
 namespace xla {
 
@@ -28,10 +28,10 @@ namespace xla {
 // module. An instruction is dead if it is not reachable from the root. This
 // pass does not remove dead parameter instructions as parameter instructions
 // cannot be deleted, nor does the pass remove dead computations.
-class HloDCE : public HloPass {
+class HloDCE : public HloPassInterface {
  public:
-  HloDCE() : HloPass("dce") {}
   ~HloDCE() override {}
+  tensorflow::StringPiece name() const override { return "dce"; }
 
   // Run the pass on the given module. Returns whether the module was changed
   // (instructions were removed).

View File

@@ -13,10 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_H_
-
-#include <string>
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_
 
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -26,25 +24,6 @@ limitations under the License.
 
 namespace xla {
 
-// Base class for HLO passes. These are used with the HloPassPipeline to
-// organize a sequence of passes.
-class HloPass {
- public:
-  explicit HloPass(const string& name) : name_(name) {}
-  virtual ~HloPass() {}
-
-  const string& name() const { return name_; }
-
-  // Run the pass on the given HLO module. Return whether it modified the
-  // module.
-  virtual StatusOr<bool> Run(HloModule* module) = 0;
-
- private:
-  const string name_;
-
-  TF_DISALLOW_COPY_AND_ASSIGN(HloPass);
-};
-
 // Do an HLO pass to a fix point.
 template <typename Pass>
 class HloPassFix : public Pass {
@@ -65,4 +44,4 @@ class HloPassFix : public Pass {
 
 }  // namespace xla
 
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_H_
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_
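
HloPassFix survives the rename: it wraps any pass and reruns Run until an iteration reports no further change (its body is elided from the hunk above). A usage sketch, mirroring how the compiler files above use it; the pipeline variable and the simplifier arguments are illustrative:

// Illustrative: run AlgebraicSimplifier to a fixed point inside a pipeline.
// Constructor arguments are forwarded to the wrapped pass.
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
    /*is_layout_sensitive=*/false,
    /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { return false; });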

View File

@@ -0,0 +1,41 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+// Base class for HLO passes. These are used with the HloPassPipeline to
+// organize a sequence of passes.
+class HloPassInterface {
+ public:
+  virtual ~HloPassInterface() = default;
+  virtual tensorflow::StringPiece name() const = 0;
+
+  // Run the pass on the given HLO module. Return whether it modified the
+  // module.
+  virtual StatusOr<bool> Run(HloModule* module) = 0;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_
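
The new interface is deliberately small; a minimal (hypothetical) pass implementing it would look like this:

namespace xla {

// Hypothetical no-op pass: implements the two pure-virtual methods only.
class NopPass : public HloPassInterface {
 public:
  tensorflow::StringPiece name() const override { return "nop"; }
  // Never touches the module, so always reports "unchanged".
  StatusOr<bool> Run(HloModule* /*module*/) override { return false; }
};

}  // namespace xla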

View File

@@ -35,11 +35,12 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
       tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ',');
   tensorflow::gtl::FlatSet<string> disabled_passes(tmp.begin(), tmp.end());
 
-  string prefix = name() + ": pipeline start";
+  string prefix = name().ToString() + ": pipeline start";
   bool changed = false;
   string message;
   for (auto& pass : passes_) {
-    if (!disabled_passes.empty() && disabled_passes.count(pass->name()) > 0) {
+    if (!disabled_passes.empty() &&
+        disabled_passes.count(pass->name().ToString()) > 0) {
       continue;
     }

View File

@@ -24,7 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/platform/macros.h"
@@ -32,11 +32,12 @@ limitations under the License.
 namespace xla {
 
 // Pipeline of HLO passes.
-class HloPassPipeline : public HloPass {
+class HloPassPipeline : public HloPassInterface {
  public:
   explicit HloPassPipeline(const string& name,
                            const Compiler::HloDumper& dumper)
-      : HloPass(name), dumper_(dumper) {}
+      : name_(name), dumper_(dumper) {}
+  tensorflow::StringPiece name() const override { return name_; }
 
   // Add a pass to the pipeline. It should be called with the arguments for the
   // pass constructor:
@@ -55,8 +56,9 @@ class HloPassPipeline : public HloPass {
   StatusOr<bool> Run(HloModule* module) override;
 
  private:
+  const string name_;
   Compiler::HloDumper dumper_;
-  std::vector<std::unique_ptr<HloPass>> passes_;
+  std::vector<std::unique_ptr<HloPassInterface>> passes_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(HloPassPipeline);
 };
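
For reference, typical driving code looks like the sketch below; `dumper` and `module` are assumed to be in scope, and the snippet belongs inside a function that can use TF_ASSIGN_OR_RETURN. The passes shown are ones converted in this commit.

// Register passes by type plus constructor arguments, then run the
// pipeline itself as one composite pass over the module.
HloPassPipeline pipeline("example-pipeline", dumper);
pipeline.AddPass<HloDCE>();
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));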

View File

@@ -16,15 +16,17 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SUBCOMPUTATION_UNIFICATION_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SUBCOMPUTATION_UNIFICATION_H_
 
-#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 
 namespace xla {
 
 // Unify subcomputations of a `HloModule`: if any computations are equal, choose
 // one arbitrarily to use and delete the others.
-class HloSubcomputationUnification : public HloPass {
+class HloSubcomputationUnification : public HloPassInterface {
  public:
-  HloSubcomputationUnification() : HloPass("subcomputation unification") {}
+  tensorflow::StringPiece name() const override {
+    return "subcomputation-unification";
+  }
 
   StatusOr<bool> Run(HloModule* module) override;
 };

View File

@@ -17,17 +17,17 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
 
 #include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 
 namespace xla {
 
 // A pass which performs inlining. Which can result, for example, in functions
 // that were previously being mapped by Map instead directly applied to the
 // forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)).
-class Inliner : public HloPass {
+class Inliner : public HloPassInterface {
  public:
-  Inliner() : HloPass("inline") {}
   ~Inliner() override = default;
+  tensorflow::StringPiece name() const override { return "inline"; }
 
   // Run inlining on the given computation. Returns whether the computation was
   // changed.

View File

@@ -19,7 +19,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 #include "tensorflow/core/platform/macros.h"
 
 namespace xla {
 
@@ -38,11 +38,12 @@ bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer);
 // with the intent that the loops which compute their values will be fused in
 // code generation. Derived classes define ShouldFuse method to select which
 // instructions to fuse.
-class InstructionFusion : public HloPass {
+class InstructionFusion : public HloPassInterface {
  public:
   explicit InstructionFusion(bool may_duplicate = true)
-      : HloPass("fusion"), may_duplicate_(may_duplicate) {}
+      : may_duplicate_(may_duplicate) {}
   ~InstructionFusion() override {}
+  tensorflow::StringPiece name() const override { return "fusion"; }
 
   // Run instruction fusion on the given computation. Returns whether the
   // computation was changed (instructions were fused).
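
The extension point named in the comment in the hunk above is ShouldFuse; a hypothetical backend override might look like the sketch below (signature as remembered for this revision, policy purely illustrative):

// Hypothetical backend fusion pass: only the fusion policy is customized,
// everything else is inherited from InstructionFusion.
class MyBackendFusion : public InstructionFusion {
 protected:
  bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override {
    // Illustrative policy: fuse only into elementwise consumers.
    return consumer->IsElementwise();
  }
};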

View File

@@ -638,8 +638,7 @@ Status CheckLayouts(
 
 }  // namespace
 
 LayoutAssignment::LayoutAssignment(ComputationLayout* entry_computation_layout)
-    : HloPass("layout-assignment"),
-      entry_computation_layout_(entry_computation_layout) {
+    : entry_computation_layout_(entry_computation_layout) {
   VLOG(1) << "entry computation layout given to layout assignment: "
           << entry_computation_layout_->ToString();
   // Layouts of all parameter instructions must be set.

View File

@@ -29,7 +29,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 #include "tensorflow/compiler/xla/service/logical_buffer.h"
 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
 #include "tensorflow/compiler/xla/shape_layout.h"
@@ -203,12 +203,13 @@ class LayoutConstraints {
 
 // HLO pass which assigns layouts to all instructions in the HLO module while
 // satisfying all necessary invariants and minimizing cost.
-class LayoutAssignment : public HloPass {
+class LayoutAssignment : public HloPassInterface {
  public:
   // entry_computation_layout is modified to populate a layout for the result in
   // the case that no particular layout is requested.
   explicit LayoutAssignment(ComputationLayout* entry_computation_layout);
   ~LayoutAssignment() override {}
+  tensorflow::StringPiece name() const override { return "layout-assignment"; }
 
   // Assign layouts to the given module. Returns whether the module was changed
   // (any layouts were changed).

View File

@@ -16,7 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_RESHAPE_MOVER_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_RESHAPE_MOVER_H_
 
-#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 
 namespace xla {
 
@@ -24,9 +24,9 @@ namespace xla {
 // This now only moves them outputward across elementwise ops all whose operands
 // are equivalent Reshapes or Transposes, but in future could potentially move
 // them inputward also.
-class ReshapeMover : public HloPass {
+class ReshapeMover : public HloPassInterface {
  public:
-  ReshapeMover() : HloPass("reshape motion") {}
+  tensorflow::StringPiece name() const override { return "reshape-motion"; }
 
   StatusOr<bool> Run(HloModule* module) override;
 };

View File

@@ -82,8 +82,7 @@ bool FoldTransposeIntoDot(HloInstruction* dot, HloComputation* computation) {
 
 }  // namespace
 
 TransposeFolding::TransposeFolding(IsTransposableGemmFn is_transposable_gemm)
-    : HloPass("transpose-folding"),
-      is_transposable_gemm_(std::move(is_transposable_gemm)) {}
+    : is_transposable_gemm_(std::move(is_transposable_gemm)) {}
 
 StatusOr<bool> TransposeFolding::Run(HloModule* module) {
   // Modifying the graph while traversing is dangerous, so we find all folding

View File

@@ -17,18 +17,19 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_
 
 #include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 
 namespace xla {
 
 // HLO pass that folds transpose operators into Dot operators, where the Dot
 // operator is implemented by a GEMM kernel that can transpose its inputs.
-class TransposeFolding : public HloPass {
+class TransposeFolding : public HloPassInterface {
  public:
   // IsTransposableGemmFn should return true iff the instruction argument is
   // implemented as a GEMM kernel that supports transposing its arguments.
   typedef std::function<bool(const HloInstruction&)> IsTransposableGemmFn;
 
   explicit TransposeFolding(IsTransposableGemmFn is_transposable_gemm);
+  tensorflow::StringPiece name() const override { return "transpose-folding"; }
 
   StatusOr<bool> Run(HloModule* module) override;