[XLA] Make HloPass an interface, NFC
This will allow inheritance from both `HloPassInterface` and `DfsHloVisitor`, so various passes which include a visitor can have handler methods overridden per backend. Change: 145477041
This commit is contained in:
parent
cdcc0f58a2
commit
f9006d72eb
@ -576,6 +576,7 @@ cc_test(
|
||||
":algebraic_simplifier",
|
||||
":cpu_plugin",
|
||||
":hlo",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
@ -1006,7 +1007,10 @@ cc_test(
|
||||
|
||||
cc_library(
|
||||
name = "hlo_pass",
|
||||
hdrs = ["hlo_pass.h"],
|
||||
hdrs = [
|
||||
"hlo_pass_fix.h",
|
||||
"hlo_pass_interface.h",
|
||||
],
|
||||
deps = [
|
||||
":hlo",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
|
@ -19,12 +19,12 @@ limitations under the License.
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// A pass which performs algebraic simplifications.
|
||||
class AlgebraicSimplifier : public HloPass {
|
||||
class AlgebraicSimplifier : public HloPassInterface {
|
||||
public:
|
||||
// Given two shapes, determines if it is valid to bitcast between them after
|
||||
// considering platform dependent effects on layout like alignment
|
||||
@ -39,10 +39,10 @@ class AlgebraicSimplifier : public HloPass {
|
||||
// bitcasts.
|
||||
AlgebraicSimplifier(bool is_layout_sensitive,
|
||||
ValidBitcastCallback valid_bitcast_callback)
|
||||
: HloPass("algsimp"),
|
||||
is_layout_sensitive_(is_layout_sensitive),
|
||||
: is_layout_sensitive_(is_layout_sensitive),
|
||||
valid_bitcast_callback_(std::move(valid_bitcast_callback)) {}
|
||||
~AlgebraicSimplifier() override {}
|
||||
tensorflow::StringPiece name() const override { return "algsimp"; }
|
||||
|
||||
// Run algebraic simplification on the given computation. Returns whether the
|
||||
// computation was changed.
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -30,10 +30,10 @@ namespace xla {
|
||||
// constant or parameter instructions will be copied.
|
||||
// Copy insertion is necessary because constant and parameter arrays have
|
||||
// different lifetimes than computation results.
|
||||
class CopyInsertion : public HloPass {
|
||||
class CopyInsertion : public HloPassInterface {
|
||||
public:
|
||||
CopyInsertion() : HloPass("copy-insertion") {}
|
||||
~CopyInsertion() override {}
|
||||
tensorflow::StringPiece name() const override { return "copy-insertion"; }
|
||||
|
||||
// Run the pass on the given module. Returns whether the module was changed
|
||||
// (copies were inserted).
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
@ -30,10 +30,12 @@ namespace cpu {
|
||||
// called canonical convolutions). This pass expands non-canonical convolutions
|
||||
// into reshapes and canonical convolutions, so that these non-canonical
|
||||
// convolutions can run faster.
|
||||
class ConvCanonicalization : public HloPass {
|
||||
class ConvCanonicalization : public HloPassInterface {
|
||||
public:
|
||||
ConvCanonicalization() : HloPass("convolution-canonicalization") {}
|
||||
~ConvCanonicalization() override {}
|
||||
tensorflow::StringPiece name() const override {
|
||||
return "convolution-canonicalization";
|
||||
}
|
||||
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
};
|
||||
|
@ -62,7 +62,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
|
||||
#include "tensorflow/compiler/xla/service/inliner.h"
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
@ -30,10 +30,12 @@ namespace cpu {
|
||||
// computations. However, it could make sense to coarsen the parallelization to
|
||||
// improve cache locality. Also, we will need to do something to intelligently
|
||||
// handle While constructs.
|
||||
class ParallelizationPreparation : public HloPass {
|
||||
class ParallelizationPreparation : public HloPassInterface {
|
||||
public:
|
||||
explicit ParallelizationPreparation() : HloPass("cpu-parallel-prepare") {}
|
||||
~ParallelizationPreparation() override {}
|
||||
tensorflow::StringPiece name() const override {
|
||||
return "cpu-parallel-prepare";
|
||||
}
|
||||
|
||||
// Run this pass on the given computation. Returns whether the
|
||||
// computation was changed.
|
||||
|
@ -17,14 +17,17 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
class ConvolutionFolding : public HloPass {
|
||||
class ConvolutionFolding : public HloPassInterface {
|
||||
public:
|
||||
ConvolutionFolding() : HloPass("convolution-folding") {}
|
||||
tensorflow::StringPiece name() const override {
|
||||
return "convolution-folding";
|
||||
}
|
||||
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
};
|
||||
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
@ -32,9 +32,9 @@ namespace gpu {
|
||||
// 2) The result of merging the fusion instruction into its users would not
|
||||
// increase bytes transferred.
|
||||
//
|
||||
class FusionMerger : public HloPass {
|
||||
class FusionMerger : public HloPassInterface {
|
||||
public:
|
||||
FusionMerger() : HloPass("fusion merger") {}
|
||||
tensorflow::StringPiece name() const override { return "fusion merger"; }
|
||||
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
|
@ -48,7 +48,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_cse.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
@ -24,9 +24,9 @@ namespace gpu {
|
||||
// An HLO pass that canonicalizes convolution instructions for GPU codegen. It
|
||||
// inserts Pad instructions before Convolution instructions with uncanonicalized
|
||||
// padding, so that they can be lowered to cuDNN convolution.
|
||||
class PadInsertion : public HloPass {
|
||||
class PadInsertion : public HloPassInterface {
|
||||
public:
|
||||
PadInsertion() : HloPass("pad insertion") {}
|
||||
tensorflow::StringPiece name() const override { return "pad insertion"; }
|
||||
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CSE_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -25,13 +25,14 @@ namespace xla {
|
||||
// and identical instructions with the same operands are commoned. The pass
|
||||
// iterates over the instructions in topological order which enables the pass to
|
||||
// find arbitrarily large common expressions.
|
||||
class HloCSE : public HloPass {
|
||||
class HloCSE : public HloPassInterface {
|
||||
public:
|
||||
// If is_layout_sensitive is true, then the simplifier preserves layout during
|
||||
// transformation. Otherwise, layout is ignored.
|
||||
explicit HloCSE(bool is_layout_sensitive)
|
||||
: HloPass("cse"), is_layout_sensitive_(is_layout_sensitive) {}
|
||||
: is_layout_sensitive_(is_layout_sensitive) {}
|
||||
~HloCSE() override {}
|
||||
tensorflow::StringPiece name() const override { return "cse"; }
|
||||
|
||||
// Run CSE on the given module. Returns whether the module was changed (common
|
||||
// subexpressions were found and eliminated).
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
@ -28,10 +28,10 @@ namespace xla {
|
||||
// module. An instruction is dead if it is not reachable from the root. This
|
||||
// pass does not remove dead parameter instructions as parameter instructions
|
||||
// cannot be deleted, nor does the pass remove dead computations.
|
||||
class HloDCE : public HloPass {
|
||||
class HloDCE : public HloPassInterface {
|
||||
public:
|
||||
HloDCE() : HloPass("dce") {}
|
||||
~HloDCE() override {}
|
||||
tensorflow::StringPiece name() const override { return "dce"; }
|
||||
|
||||
// Run the pass on the given module. Returns whether the module was changed
|
||||
// (instructions were removed).
|
||||
|
@ -13,10 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_H_
|
||||
|
||||
#include <string>
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -26,25 +24,6 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Base class for HLO passes. These are used with the HloPassPipeline to
|
||||
// organize a sequence of passes.
|
||||
class HloPass {
|
||||
public:
|
||||
explicit HloPass(const string& name) : name_(name) {}
|
||||
virtual ~HloPass() {}
|
||||
|
||||
const string& name() const { return name_; }
|
||||
|
||||
// Run the pass on the given HLO module. Return whether it modified the
|
||||
// module.
|
||||
virtual StatusOr<bool> Run(HloModule* module) = 0;
|
||||
|
||||
private:
|
||||
const string name_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(HloPass);
|
||||
};
|
||||
|
||||
// Do an HLO pass to a fix point.
|
||||
template <typename Pass>
|
||||
class HloPassFix : public Pass {
|
||||
@ -65,4 +44,4 @@ class HloPassFix : public Pass {
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_
|
41
tensorflow/compiler/xla/service/hlo_pass_interface.h
Normal file
41
tensorflow/compiler/xla/service/hlo_pass_interface.h
Normal file
@ -0,0 +1,41 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Base class for HLO passes. These are used with the HloPassPipeline to
|
||||
// organize a sequence of passes.
|
||||
class HloPassInterface {
|
||||
public:
|
||||
// Virtual so that a pass can be destroyed through an HloPassInterface
// pointer (HloPassPipeline stores its passes as
// std::unique_ptr<HloPassInterface>).
virtual ~HloPassInterface() = default;
|
||||
// Returns the name of the pass. HloPassPipeline::Run compares this name
// against the comma-separated entries of the xla_disable_hlo_passes flag and
// skips the pass on a match.
// NOTE(review): StringPiece presumably does not own its characters, so
// implementations should return a view of storage that outlives the call
// (e.g. a string literal or a member string) -- confirm.
virtual tensorflow::StringPiece name() const = 0;
|
||||
|
||||
// Run the pass on the given HLO module. Return whether it modified the
|
||||
// module.
|
||||
virtual StatusOr<bool> Run(HloModule* module) = 0;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_
|
@ -35,11 +35,12 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
|
||||
tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ',');
|
||||
tensorflow::gtl::FlatSet<string> disabled_passes(tmp.begin(), tmp.end());
|
||||
|
||||
string prefix = name() + ": pipeline start";
|
||||
string prefix = name().ToString() + ": pipeline start";
|
||||
bool changed = false;
|
||||
string message;
|
||||
for (auto& pass : passes_) {
|
||||
if (!disabled_passes.empty() && disabled_passes.count(pass->name()) > 0) {
|
||||
if (!disabled_passes.empty() &&
|
||||
disabled_passes.count(pass->name().ToString()) > 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
@ -32,11 +32,12 @@ limitations under the License.
|
||||
namespace xla {
|
||||
|
||||
// Pipeline of HLO passes.
|
||||
class HloPassPipeline : public HloPass {
|
||||
class HloPassPipeline : public HloPassInterface {
|
||||
public:
|
||||
explicit HloPassPipeline(const string& name,
|
||||
const Compiler::HloDumper& dumper)
|
||||
: HloPass(name), dumper_(dumper) {}
|
||||
: name_(name), dumper_(dumper) {}
|
||||
tensorflow::StringPiece name() const override { return name_; }
|
||||
|
||||
// Add a pass to the pipeline. It should be called with the arguments for the
|
||||
// pass constructor:
|
||||
@ -55,8 +56,9 @@ class HloPassPipeline : public HloPass {
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
private:
|
||||
const string name_;
|
||||
Compiler::HloDumper dumper_;
|
||||
std::vector<std::unique_ptr<HloPass>> passes_;
|
||||
std::vector<std::unique_ptr<HloPassInterface>> passes_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(HloPassPipeline);
|
||||
};
|
||||
|
@ -16,15 +16,17 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SUBCOMPUTATION_UNIFICATION_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SUBCOMPUTATION_UNIFICATION_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Unify subcomputations of a `HloModule`: if any computations are equal, choose
|
||||
// one arbitrarily to use and delete the others.
|
||||
class HloSubcomputationUnification : public HloPass {
|
||||
class HloSubcomputationUnification : public HloPassInterface {
|
||||
public:
|
||||
HloSubcomputationUnification() : HloPass("subcomputation unification") {}
|
||||
tensorflow::StringPiece name() const override {
|
||||
return "subcomputation-unification";
|
||||
}
|
||||
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
};
|
||||
|
@ -17,17 +17,17 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// A pass which performs inlining. This can result, for example, in functions
|
||||
// that were previously being mapped by Map instead being directly applied to the
|
||||
// forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)).
|
||||
class Inliner : public HloPass {
|
||||
class Inliner : public HloPassInterface {
|
||||
public:
|
||||
Inliner() : HloPass("inline") {}
|
||||
~Inliner() override = default;
|
||||
tensorflow::StringPiece name() const override { return "inline"; }
|
||||
|
||||
// Run inlining on the given computation. Returns whether the computation was
|
||||
// changed.
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace xla {
|
||||
@ -38,11 +38,12 @@ bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer);
|
||||
// with the intent that the loops which compute their values will be fused in
|
||||
// code generation. Derived classes define ShouldFuse method to select which
|
||||
// instructions to fuse.
|
||||
class InstructionFusion : public HloPass {
|
||||
class InstructionFusion : public HloPassInterface {
|
||||
public:
|
||||
explicit InstructionFusion(bool may_duplicate = true)
|
||||
: HloPass("fusion"), may_duplicate_(may_duplicate) {}
|
||||
: may_duplicate_(may_duplicate) {}
|
||||
~InstructionFusion() override {}
|
||||
tensorflow::StringPiece name() const override { return "fusion"; }
|
||||
|
||||
// Run instruction fusion on the given computation. Returns whether the
|
||||
// computation was changed (instructions were fused).
|
||||
|
@ -638,8 +638,7 @@ Status CheckLayouts(
|
||||
} // namespace
|
||||
|
||||
LayoutAssignment::LayoutAssignment(ComputationLayout* entry_computation_layout)
|
||||
: HloPass("layout-assignment"),
|
||||
entry_computation_layout_(entry_computation_layout) {
|
||||
: entry_computation_layout_(entry_computation_layout) {
|
||||
VLOG(1) << "entry computation layout given to layout assignment: "
|
||||
<< entry_computation_layout_->ToString();
|
||||
// Layouts of all parameter instructions must be set.
|
||||
|
@ -29,7 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
#include "tensorflow/compiler/xla/service/logical_buffer.h"
|
||||
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
||||
#include "tensorflow/compiler/xla/shape_layout.h"
|
||||
@ -203,12 +203,13 @@ class LayoutConstraints {
|
||||
|
||||
// HLO pass which assigns layouts to all instructions in the HLO module while
|
||||
// satisfying all necessary invariants and minimizing cost.
|
||||
class LayoutAssignment : public HloPass {
|
||||
class LayoutAssignment : public HloPassInterface {
|
||||
public:
|
||||
// entry_computation_layout is modified to populate a layout for the result in
|
||||
// the case that no particular layout is requested.
|
||||
explicit LayoutAssignment(ComputationLayout* entry_computation_layout);
|
||||
~LayoutAssignment() override {}
|
||||
tensorflow::StringPiece name() const override { return "layout-assignment"; }
|
||||
|
||||
// Assign layouts to the given module. Returns whether the module was changed
|
||||
// (any layouts were changed).
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_RESHAPE_MOVER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_RESHAPE_MOVER_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -24,9 +24,9 @@ namespace xla {
|
||||
// This now only moves them outputward across elementwise ops all of whose operands
|
||||
// are equivalent Reshapes or Transposes, but in future could potentially move
|
||||
// them inputward also.
|
||||
class ReshapeMover : public HloPass {
|
||||
class ReshapeMover : public HloPassInterface {
|
||||
public:
|
||||
ReshapeMover() : HloPass("reshape motion") {}
|
||||
tensorflow::StringPiece name() const override { return "reshape-motion"; }
|
||||
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
};
|
||||
|
@ -82,8 +82,7 @@ bool FoldTransposeIntoDot(HloInstruction* dot, HloComputation* computation) {
|
||||
} // namespace
|
||||
|
||||
TransposeFolding::TransposeFolding(IsTransposableGemmFn is_transposable_gemm)
|
||||
: HloPass("transpose-folding"),
|
||||
is_transposable_gemm_(std::move(is_transposable_gemm)) {}
|
||||
: is_transposable_gemm_(std::move(is_transposable_gemm)) {}
|
||||
|
||||
StatusOr<bool> TransposeFolding::Run(HloModule* module) {
|
||||
// Modifying the graph while traversing is dangerous, so we find all folding
|
||||
|
@ -17,18 +17,19 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// HLO pass that folds transpose operators into Dot operators, where the Dot
|
||||
// operator is implemented by a GEMM kernel that can transpose its inputs.
|
||||
class TransposeFolding : public HloPass {
|
||||
class TransposeFolding : public HloPassInterface {
|
||||
public:
|
||||
// IsTransposableGemmFn should return true iff the instruction argument is
|
||||
// implemented as a GEMM kernel that supports transposing its arguments.
|
||||
typedef std::function<bool(const HloInstruction&)> IsTransposableGemmFn;
|
||||
explicit TransposeFolding(IsTransposableGemmFn is_transposable_gemm);
|
||||
tensorflow::StringPiece name() const override { return "transpose-folding"; }
|
||||
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user