Bump OSS LLVM to 82576d6fecfec71725eb900111c000d772002449

PiperOrigin-RevId: 305444589
Change-Id: I407c717e431b9b2514a5eee5ad0f4b51717e2e9a
Benjamin Kramer authored 2020-04-08 03:30:48 -07:00, committed by TensorFlower Gardener
parent 7a4e9467cb
commit 7908b844de
108 changed files with 384 additions and 305 deletions
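Every change below is the same mechanical migration from the MLIR pass API rework that ships with this LLVM revision: pass classes no longer derive from FunctionPass<Derived> or OperationPass<Derived, ModuleOp> directly but go through the PassWrapper CRTP adapter, and pass factories return OperationPass<FuncOp> / OperationPass<ModuleOp> now that OpPassBase is gone. A minimal sketch of both shapes, assuming the MLIR headers at this revision; the pass names and bodies are hypothetical, only the base classes and return types mirror the diff:

#include <memory>

#include "mlir/IR/Function.h"  // FuncOp at this LLVM revision
#include "mlir/Pass/Pass.h"

namespace mlir {

// Old: struct MyFuncPass : public FunctionPass<MyFuncPass> {...};
struct MyFuncPass : public PassWrapper<MyFuncPass, FunctionPass> {
  void runOnFunction() override {
    // Hypothetical per-function logic; unchanged by the migration.
  }
};

// Old: struct MyModulePass : public OperationPass<MyModulePass, ModuleOp> {...};
struct MyModulePass
    : public PassWrapper<MyModulePass, OperationPass<ModuleOp>> {
  void runOnOperation() override {
    // Hypothetical whole-module logic; unchanged by the migration.
  }
};

// Factories now return OperationPass<OpT>; OpPassBase<OpT> no longer exists,
// so headers forward-declare OperationPass instead.
std::unique_ptr<OperationPass<FuncOp>> CreateMyFuncPass() {
  return std::make_unique<MyFuncPass>();
}

}  // namespace mlir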

View File

@@ -55,7 +55,8 @@ namespace quant {
 using QuantParamsEntry = QuantizationInfo::QuantParams;
 namespace {
-class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
+class ImportQuantStatsPass
+    : public PassWrapper<ImportQuantStatsPass, FunctionPass> {
  public:
   explicit ImportQuantStatsPass(OperationToName op_to_name)
       : op_to_name_(op_to_name) {}
@@ -193,7 +194,7 @@ void ImportQuantStatsPass::runOnFunction() {
 }
 // Creates an instance of the default quant parameters pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
+std::unique_ptr<OperationPass<FuncOp>> CreateImportQuantStatsPass(
     OperationToName op_to_name, const std::string &stats_str) {
   auto pass = absl::make_unique<ImportQuantStatsPass>(op_to_name);
   if (pass->ParseQuantStats(stats_str)) return nullptr;
@@ -203,7 +204,7 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
 // Creates an instance pass to import quantization stats to the operations in
 // the function. A custom method to get the name from the op is used because
 // different dialect ops might have different ways to assign the name.
-std::unique_ptr<OpPassBase<FuncOp>>
+std::unique_ptr<OperationPass<FuncOp>>
 CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
   auto get_name_func = [](Operation *op) {
     Location loc = op->getLoc();

View File

@@ -27,13 +27,13 @@ using OperationToName = std::function<llvm::StringRef(Operation* op)>;
 // Creates an instance pass to import quantization stats to the operations in
 // the function. A custom method to get the name from the op is used because
 // different dialect ops might have different ways to assign the name.
-std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
+std::unique_ptr<OperationPass<FuncOp>> CreateImportQuantStatsPass(
     OperationToName op_to_name, const std::string& stats_str);
 // Creates an instance pass to import quantization stats to the operations in
 // the function. A custom method to get the name from the op is used because
 // different dialect ops might have different ways to assign the name.
-std::unique_ptr<OpPassBase<FuncOp>>
+std::unique_ptr<OperationPass<FuncOp>>
 CreateImportQuantStatsPassForTFControlDialect(const std::string& stats_str);
 } // namespace quant
} // namespace quant

View File

@@ -25,7 +25,7 @@ namespace mlir {
 namespace TF {
 // Legalize the tf ops to the quant ops, so the quantization passes can work.
-std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFToQuantPass();
 } // namespace TF
 } // namespace mlir

View File

@@ -27,7 +27,7 @@ namespace TF {
 namespace {
 // Legalize TF quantization emulation ops to that in Quant ops dialect.
-struct LegalizeTFToQuant : public FunctionPass<LegalizeTFToQuant> {
+struct LegalizeTFToQuant : public PassWrapper<LegalizeTFToQuant, FunctionPass> {
   explicit LegalizeTFToQuant() = default;
   LegalizeTFToQuant(const LegalizeTFToQuant &) {}
@@ -151,7 +151,7 @@ void LegalizeTFToQuant::runOnFunction() {
 } // namespace
 // Creates an instance of the TensorFlow dialect to QuantOps dialect pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFToQuantPass() {
   return std::make_unique<LegalizeTFToQuant>();
 }

View File

@@ -315,7 +315,8 @@ llvm::SmallVector<Value, 0> FuseOps(PatternRewriter* rewriter,
   return new_values;
 }
-struct CpuKernelFusionPass : public FunctionPass<CpuKernelFusionPass> {
+struct CpuKernelFusionPass
+    : public PassWrapper<CpuKernelFusionPass, FunctionPass> {
   explicit CpuKernelFusionPass() = default;
   CpuKernelFusionPass(const CpuKernelFusionPass&) {}
@@ -335,7 +336,7 @@ void CpuKernelFusionPass::runOnFunction() {
 } // namespace
 // Creates an instance of the xla_hlo cpu kernel fusion pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreateCpuKernelFusionPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateCpuKernelFusionPass() {
   return std::make_unique<CpuKernelFusionPass>();
 }

View File

@@ -141,7 +141,8 @@ class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
 };
 // Materialize the quantization results by hlo primitive ops.
-struct MaterializeToXlaPass : public FunctionPass<MaterializeToXlaPass> {
+struct MaterializeToXlaPass
+    : public PassWrapper<MaterializeToXlaPass, FunctionPass> {
   explicit MaterializeToXlaPass() = default;
   MaterializeToXlaPass(const MaterializeToXlaPass &) {}
@@ -162,7 +163,7 @@ void MaterializeToXlaPass::runOnFunction() {
 } // namespace
 // Creates an instance of the xla_hlo dialect quantization propagation pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializeToXlaPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateMaterializeToXlaPass() {
   return std::make_unique<MaterializeToXlaPass>();
 }

View File

@@ -26,10 +26,10 @@ namespace xla_hlo {
 // Propagate the quantization information to all the tensors according to the
 // op quant spec.
-std::unique_ptr<OpPassBase<FuncOp>> CreatePropagateQuantPass();
+std::unique_ptr<OperationPass<FuncOp>> CreatePropagateQuantPass();
 // Rewrite the graph and quantize the constant.
-std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializeToXlaPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateMaterializeToXlaPass();
 } // namespace xla_hlo
 } // namespace mlir

View File

@@ -50,7 +50,8 @@ namespace {
 // - The quantization spec for the ops
 // The propagation results should assign quantization types to all the tensors
 // and the two restrictions are respected.
-struct PropagateQuantPass : public FunctionPass<PropagateQuantPass> {
+struct PropagateQuantPass
+    : public PassWrapper<PropagateQuantPass, FunctionPass> {
   explicit PropagateQuantPass() = default;
   PropagateQuantPass(const PropagateQuantPass &) {}
@@ -96,7 +97,7 @@ void PropagateQuantPass::runOnFunction() {
 } // namespace
 // Creates an instance of the xla_hlo dialect quantization propagation pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreatePropagateQuantPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreatePropagateQuantPass() {
   return std::make_unique<PropagateQuantPass>();
 }

View File

@@ -29,7 +29,7 @@ limitations under the License.
 namespace mlir {
 /// Create a pass to convert from the TFExecutor to the TF control dialect.
-std::unique_ptr<OpPassBase<FuncOp>>
+std::unique_ptr<OperationPass<FuncOp>>
 CreateTFExecutorToControlDialectConversion();
 } // namespace mlir

View File

@@ -40,7 +40,7 @@ limitations under the License.
 namespace mlir {
 /// Create a pass to convert from the TFExecutor to the TF control dialect.
-std::unique_ptr<OpPassBase<FuncOp>>
+std::unique_ptr<OperationPass<FuncOp>>
 CreateTFExecutorToControlDialectConversion();
 } // namespace mlir

View File

@@ -44,7 +44,8 @@ namespace TFL {
 #include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"
 namespace {
-class DefaultQuantParamsPass : public FunctionPass<DefaultQuantParamsPass> {
+class DefaultQuantParamsPass
+    : public PassWrapper<DefaultQuantParamsPass, FunctionPass> {
  public:
   explicit DefaultQuantParamsPass(double default_min, double default_max)
       : default_min_(default_min), default_max_(default_max) {}
@@ -220,7 +221,7 @@ quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
 }
 // Creates an instance of the default quant parameters pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
+std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
     double default_min, double default_max) {
   return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max);
 }

View File

@@ -29,7 +29,7 @@ namespace TFL {
 namespace {
-struct DenseToSparse : public FunctionPass<DenseToSparse> {
+struct DenseToSparse : public PassWrapper<DenseToSparse, FunctionPass> {
   void runOnFunction() override;
 };
@@ -63,7 +63,7 @@ void DenseToSparse::runOnFunction() {
 } // namespace
 // Creates an instance of the TensorFlow Lite dialect DenseToSparse pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreateDenseToSparsePass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateDenseToSparsePass() {
   return absl::make_unique<DenseToSparse>();
 }

View File

@@ -18,7 +18,8 @@ namespace mlir {
 namespace TFL {
 namespace {
-struct IdentifyDilatedConvPass : public FunctionPass<IdentifyDilatedConvPass> {
+struct IdentifyDilatedConvPass
+    : public PassWrapper<IdentifyDilatedConvPass, FunctionPass> {
   void runOnFunction() override;
 };

View File

@@ -679,7 +679,8 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
   return success();
 }
-struct ExtractOphintPass : public OperationPass<ExtractOphintPass, ModuleOp> {
+struct ExtractOphintPass
+    : public PassWrapper<ExtractOphintPass, OperationPass<ModuleOp>> {
   void runOnOperation() override;
   void Verify();
@@ -752,7 +753,7 @@ void ExtractOphintPass::Verify() {
 /// Creates an instance of the TensorFlow Lite dialect ExtractOphintPass
 /// pass.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateExtractOphintPass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreateExtractOphintPass() {
   return std::make_unique<ExtractOphintPass>();
 }

View File

@@ -69,7 +69,7 @@ constexpr char kUnidirectionalSequenceLstm[] = "UnidirectionalSequenceLstm";
 // |
 // OutputOp1
 struct LegalizeOphintFuncOpPass
-    : public OperationPass<LegalizeOphintFuncOpPass, ModuleOp> {
+    : public PassWrapper<LegalizeOphintFuncOpPass, OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
@@ -284,7 +284,7 @@ void LegalizeOphintFuncOpPass::runOnOperation() {
 /// Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass
 /// pass.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeOphintFuncOpPass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeOphintFuncOpPass() {
   return std::make_unique<LegalizeOphintFuncOpPass>();
 }

View File

@@ -70,7 +70,7 @@ constexpr char kUnidirectionalSequenceRnn[] = "tf.UnidirectionalSequenceRnn";
 constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
 // Legalize operations in functions.
-struct LegalizeTF : public FunctionPass<LegalizeTF> {
+struct LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
   void runOnFunction() override;
 };
@@ -763,7 +763,7 @@ void LegalizeTF::runOnFunction() {
 } // namespace
 // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass() {
   return std::make_unique<LegalizeTF>();
 }

View File

@@ -31,7 +31,8 @@ namespace {
 // Legalize TF While to TFL While with calls to the original functions from the
 // cond and body regions.
-struct LegalizeWhile : public OperationPass<LegalizeWhile, ModuleOp> {
+struct LegalizeWhile
+    : public PassWrapper<LegalizeWhile, OperationPass<ModuleOp>> {
   void RunOnFunction(FuncOp func);
   void runOnOperation() override {
@@ -76,7 +77,7 @@ void LegalizeWhile::RunOnFunction(FuncOp func) {
 }
 // Creates an instance of the TensorFlow While to TFLite While pass.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeTFWhilePass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFWhilePass() {
   return std::make_unique<LegalizeWhile>();
 }

View File

@@ -42,7 +42,8 @@ namespace {
 // AnyQuantizedType, thus bitwidth, narrow_range, etc are included. The op also
 // defines the op quantization traits, which are used to propagate the
 // quantization parameters by the following passes.
-struct LoadQuantizationRecipe : public FunctionPass<LoadQuantizationRecipe> {
+struct LoadQuantizationRecipe
+    : public PassWrapper<LoadQuantizationRecipe, FunctionPass> {
   void runOnFunction() override;
  private:
@@ -215,7 +216,7 @@ void LoadQuantizationRecipe::runOnFunction() {
 // Creates an instance of the TensorFlow Lite dialect LoadQuantizationRecipe
 // pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreateLoadQuantizationRecipePass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateLoadQuantizationRecipePass() {
   return absl::make_unique<LoadQuantizationRecipe>();
 }

View File

@@ -82,7 +82,7 @@ class TensorListPatternRewriter : public PatternRewriter {
 /// Lower TensorList ops in functions for subsequent legalization.
 struct LowerStaticTensorListPass
-    : public OperationPass<LowerStaticTensorListPass, ModuleOp> {
+    : public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
   void runOnOperation() override;
   // Apply type and op changes within a function.
@@ -906,7 +906,8 @@ void LowerStaticTensorListPass::runOnOperation() {
 /// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
 /// pass.
-std::unique_ptr<OpPassBase<ModuleOp>> TFL::CreateLowerStaticTensorListPass() {
+std::unique_ptr<OperationPass<ModuleOp>>
+TFL::CreateLowerStaticTensorListPass() {
   return std::make_unique<LowerStaticTensorListPass>();
 }

View File

@@ -74,7 +74,7 @@ bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
 using ::llvm::cast;
 // Optimize TFLite operations in functions.
-struct Optimize : public FunctionPass<Optimize> {
+struct Optimize : public PassWrapper<Optimize, FunctionPass> {
   void runOnFunction() override;
 };
@@ -725,7 +725,7 @@ void Optimize::runOnFunction() {
 } // namespace
 // Creates an instance of the TensorFlow Lite dialect Optimize pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreateOptimizePass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass() {
   return std::make_unique<Optimize>();
 }

View File

@@ -36,7 +36,7 @@ using FuncSet = llvm::SmallSet<FuncOp, 4>;
 // Module pass to optimize TensorFlow functional ops.
 struct OptimizeFunctionalOpsPass
-    : public OperationPass<OptimizeFunctionalOpsPass, ModuleOp> {
+    : public PassWrapper<OptimizeFunctionalOpsPass, OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
@@ -198,7 +198,7 @@ void OptimizeFunctionalOpsPass::runOnOperation() {
 }
 } // namespace
-std::unique_ptr<OpPassBase<ModuleOp>> CreateOptimizeFunctionalOpsPass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass() {
   return std::make_unique<OptimizeFunctionalOpsPass>();
 }

View File

@@ -24,75 +24,75 @@ namespace mlir {
 class FuncOp;
 class ModuleOp;
 template <typename T>
-class OpPassBase;
+class OperationPass;
 namespace TFL {
 class QuantizationSpecs;
 // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass();
 // Creates an instance of the TensorFlow Lite dialect Optimize pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreateOptimizePass();
+std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass();
 // Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareTFPass(
+std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
     bool unfold_batch_matmul);
 // Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
 // pass.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateLowerStaticTensorListPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateLowerStaticTensorListPass();
 // Creates an instance of the TensorFlow Lite dialect Quantize pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreateQuantizePass();
+std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass();
 // Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareQuantizePass(
+std::unique_ptr<OperationPass<FuncOp>> CreatePrepareQuantizePass(
     const QuantizationSpecs& quant_specs);
 // Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreatePostQuantizePass(
+std::unique_ptr<OperationPass<FuncOp>> CreatePostQuantizePass(
    bool emit_quant_adaptor_ops);
 // Creates an instance of the TensorFlow Lite dialect TrimFunctions
 // pass.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTrimFunctionsPass(
+std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
    llvm::ArrayRef<std::string> trim_funcs_whitelist);
 // Creates an instance of the TensorFlow Lite dialect PrepareCompositeFunctions
 // pass.
-std::unique_ptr<OpPassBase<ModuleOp>> CreatePrepareCompositeFunctionsPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreatePrepareCompositeFunctionsPass();
 // Creates an instance of the TensorFlow Lite dialect ExtractOphint pass.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateExtractOphintPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateExtractOphintPass();
 // Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass
 // pass. The composite op is created from the ophint extraction pass.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeOphintFuncOpPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeOphintFuncOpPass();
 // Creates an instance of the TensorFlow Lite dialect SplitMergedOperandsPass.
-std::unique_ptr<OpPassBase<FuncOp>> CreateSplitMergedOperandsPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateSplitMergedOperandsPass();
 // Creates an instance of the TensorFlow Lite dialect OptimizeFunctionalOpsPass.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateOptimizeFunctionalOpsPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass();
 // Creates an instance of the TensorFlow Lite dialect pass to add default
 // quantization parameters.
-std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
+std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
     double default_min, double default_max);
 // Creates an instance of the TensorFlow Lite dialect pass to convert dense
 // tensor to sparse format.
-std::unique_ptr<OpPassBase<FuncOp>> CreateDenseToSparsePass();
+std::unique_ptr<OperationPass<FuncOp>> CreateDenseToSparsePass();
 // Creates function pass to legalize TF While to TFL While.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeTFWhilePass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFWhilePass();
 // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateWhileOutlinePass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
 // Verifies runtime supports types used.
-std::unique_ptr<OpPassBase<FuncOp>> CreateRuntimeTypeVerifyPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeTypeVerifyPass();
 } // namespace TFL

View File

@@ -30,7 +30,7 @@ namespace TFL {
 namespace {
 // Applies all the clean up steps after quantization.
-class PostQuantizePass : public FunctionPass<PostQuantizePass> {
+class PostQuantizePass : public PassWrapper<PostQuantizePass, FunctionPass> {
  public:
  // Constructor used by the PassRegistration. This will remove the adaptor ops.
  explicit PostQuantizePass() : emit_quant_adaptor_ops_(false) {}
@@ -135,7 +135,7 @@ void PostQuantizePass::runOnFunction() {
 } // namespace
 // Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreatePostQuantizePass(
+std::unique_ptr<OperationPass<FuncOp>> CreatePostQuantizePass(
     bool emit_quant_adaptor_ops) {
   return std::make_unique<PostQuantizePass>(emit_quant_adaptor_ops);
 }

View File

@@ -94,7 +94,8 @@ class ConvertEmbeddedLookupFunc {
 // body with the corresponding fused TFLite op. The replacement need not always
 // be a fused op, though that is the primary use case.
 class PrepareCompositeFunctionsPass
-    : public OperationPass<PrepareCompositeFunctionsPass, ModuleOp> {
+    : public PassWrapper<PrepareCompositeFunctionsPass,
+                         OperationPass<ModuleOp>> {
  public:
  explicit PrepareCompositeFunctionsPass() {}
@@ -211,7 +212,7 @@ void PrepareCompositeFunctionsPass::runOnOperation() {
 }
 } // namespace
-std::unique_ptr<OpPassBase<ModuleOp>> CreatePrepareCompositeFunctionsPass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreatePrepareCompositeFunctionsPass() {
   return std::make_unique<PrepareCompositeFunctionsPass>();
 }

View File

@@ -66,7 +66,8 @@ namespace {
 // across ops. This step is necessary for post-training quantization and also
 // making the quantization rule for some operations in the quantization-aware
 // training quantization simpler.
-class PrepareQuantizePass : public FunctionPass<PrepareQuantizePass> {
+class PrepareQuantizePass
+    : public PassWrapper<PrepareQuantizePass, FunctionPass> {
  public:
  // Constructor used by the PassRegistration and enforce uint8 quantization.
  explicit PrepareQuantizePass() {
@@ -281,7 +282,7 @@ void PrepareQuantizePass::runOnFunction() {
 } // namespace
 // Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareQuantizePass(
+std::unique_ptr<OperationPass<FuncOp>> CreatePrepareQuantizePass(
     const QuantizationSpecs& quant_specs) {
   return std::make_unique<PrepareQuantizePass>(quant_specs);
 }

View File

@@ -71,7 +71,7 @@ namespace TFL {
 namespace {
 // Prepare TF operations in functions for subsequent legalization.
-class PrepareTFPass : public FunctionPass<PrepareTFPass> {
+class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
  public:
  explicit PrepareTFPass() : unfold_batch_matmul_(true) {}
  explicit PrepareTFPass(bool unfold_batch_matmul)
@@ -652,7 +652,7 @@ void PrepareTFPass::runOnFunction() {
 } // namespace
 // Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareTFPass(
+std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
     bool unfold_batch_matmul) {
   return std::make_unique<PrepareTFPass>(unfold_batch_matmul);
 }

View File

@@ -75,7 +75,7 @@ struct TFLFullQuantization
 };
 // Applies quantization on the model in TFL dialect.
-struct QuantizePass : public FunctionPass<QuantizePass> {
+struct QuantizePass : public PassWrapper<QuantizePass, FunctionPass> {
   void runOnFunction() override;
 };
@@ -93,7 +93,7 @@ void QuantizePass::runOnFunction() {
 } // namespace
 // Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreateQuantizePass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass() {
   return std::make_unique<QuantizePass>();
 }

View File

@@ -24,7 +24,8 @@ namespace {
 // This pass verifies that the operands and results types are supported by
 // TFLite runtime.
-class RuntimeTypeVerifyPass : public mlir::FunctionPass<RuntimeTypeVerifyPass> {
+class RuntimeTypeVerifyPass
+    : public mlir::PassWrapper<RuntimeTypeVerifyPass, FunctionPass> {
  public:
  explicit RuntimeTypeVerifyPass() {}
@@ -43,7 +44,7 @@ void RuntimeTypeVerifyPass::runOnFunction() {
 } // namespace
 // Verifies runtime supports types used.
-std::unique_ptr<OpPassBase<FuncOp>> CreateRuntimeTypeVerifyPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeTypeVerifyPass() {
   return std::make_unique<RuntimeTypeVerifyPass>();
 }

View File

@@ -66,7 +66,8 @@ namespace mlir {
 namespace TFL {
 namespace {
-struct SplitMergedOperandsPass : public FunctionPass<SplitMergedOperandsPass> {
+struct SplitMergedOperandsPass
+    : public PassWrapper<SplitMergedOperandsPass, FunctionPass> {
   void runOnFunction() override;
 };
@@ -119,7 +120,7 @@ void SplitMergedOperandsPass::runOnFunction() {
 /// Creates an instance of the TensorFlow Lite dialect SplitMergedOperands
 /// pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreateSplitMergedOperandsPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateSplitMergedOperandsPass() {
   return std::make_unique<SplitMergedOperandsPass>();
 }

View File

@@ -45,7 +45,7 @@ namespace {
 // The pass to trim functions before we legalize to TFL
 // dialect using the specified whitelist.
 class TrimFunctionsPass
-    : public mlir::OperationPass<TrimFunctionsPass, ModuleOp> {
+    : public mlir::PassWrapper<TrimFunctionsPass, OperationPass<ModuleOp>> {
  public:
  explicit TrimFunctionsPass() : trim_funcs_whitelist_(trim_funcs_whitelist) {}
  explicit TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_whitelist)
@@ -120,7 +120,7 @@ void TrimFunctionsPass::Verify() {
 // Creates an instance of the TensorFlow Lite dialect TrimFunctions
 /// pass.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTrimFunctionsPass(
+std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
     llvm::ArrayRef<std::string> trim_funcs_whitelist) {
   return std::make_unique<TrimFunctionsPass>(trim_funcs_whitelist);
 }

View File

@@ -38,7 +38,7 @@ namespace {
 // This pass outlines the cond/body region of the TFL WhileOp into functions and
 // replaces the regions with calls to these outlined functions.
 class WhileOutlinePass
-    : public mlir::OperationPass<WhileOutlinePass, ModuleOp> {
+    : public mlir::PassWrapper<WhileOutlinePass, OperationPass<ModuleOp>> {
  public:
  explicit WhileOutlinePass() {}
@@ -241,7 +241,7 @@ void WhileOutlinePass::runOnOperation() {
 }
 // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateWhileOutlinePass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass() {
   return std::make_unique<WhileOutlinePass>();
 }

View File

@@ -39,7 +39,8 @@ constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";
 // Analyzes the inputs to LaunchFuncOps in the module, and annotates their
 // invoked functions whether each input has the same data across replicas.
 struct AnnotateParameterReplication
-    : public OperationPass<AnnotateParameterReplication, ModuleOp> {
+    : public PassWrapper<AnnotateParameterReplication,
+                         OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
@@ -90,7 +91,8 @@ void AnnotateParameterReplication::runOnOperation() {
 } // namespace
-std::unique_ptr<OpPassBase<ModuleOp>> CreateAnnotateParameterReplicationPass() {
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateAnnotateParameterReplicationPass() {
   return std::make_unique<AnnotateParameterReplication>();
 }

View File

@@ -43,7 +43,8 @@ namespace TF {
 namespace {
 // Replace TF BatchMatMul by TF Einsum
-struct BatchMatMulToEinsumPass : public FunctionPass<BatchMatMulToEinsumPass> {
+struct BatchMatMulToEinsumPass
+    : public PassWrapper<BatchMatMulToEinsumPass, FunctionPass> {
   void runOnFunction() override;
 };
@@ -117,7 +118,7 @@ static PassRegistration<BatchMatMulToEinsumPass> pass(
     "tf-batch-matmul-to-tf-einsum",
     "Replace TF BatchMatMul op by TF Einsum op.");
-std::unique_ptr<OpPassBase<FuncOp>> CreateBatchMatMulToEinsumPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateBatchMatMulToEinsumPass() {
   return std::make_unique<BatchMatMulToEinsumPass>();
 }

View File

@@ -37,7 +37,8 @@ namespace TFDevice {
 namespace {
-struct ClusterFormationPass : public FunctionPass<ClusterFormationPass> {
+struct ClusterFormationPass
+    : public PassWrapper<ClusterFormationPass, FunctionPass> {
   void runOnFunction() override;
 };
@@ -229,7 +230,7 @@ void ClusterFormationPass::runOnFunction() {
 } // namespace
-std::unique_ptr<OpPassBase<FuncOp>> CreateClusterFormationPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateClusterFormationPass() {
   return std::make_unique<ClusterFormationPass>();
 }

View File

@@ -39,7 +39,7 @@ constexpr char kDeviceAttr[] = "device";
 constexpr char kFuncAttr[] = "func";
 struct ClusterOutliningPass
-    : public OperationPass<ClusterOutliningPass, ModuleOp> {
+    : public PassWrapper<ClusterOutliningPass, OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
@@ -132,7 +132,7 @@ void ClusterOutliningPass::runOnOperation() {
 } // namespace
-std::unique_ptr<OpPassBase<ModuleOp>> CreateClusterOutliningPass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreateClusterOutliningPass() {
   return std::make_unique<ClusterOutliningPass>();
 }

View File

@@ -52,7 +52,7 @@ bool DecodeOpaqueValueInConstantOp(Operation *op) {
 }
 // A pass to decode opaque constant values into readable ones.
-struct DecodeConstant : public FunctionPass<DecodeConstant> {
+struct DecodeConstant : public PassWrapper<DecodeConstant, FunctionPass> {
   void runOnFunction() override {
     auto walk_result = getFunction().walk([](Operation *op) {
       return DecodeOpaqueValueInConstantOp(op) ? WalkResult::advance()
@@ -64,7 +64,7 @@ struct DecodeConstant : public FunctionPass<DecodeConstant> {
 } // namespace
-std::unique_ptr<OpPassBase<FuncOp>> CreateDecodeConstantPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateDecodeConstantPass() {
   return std::make_unique<DecodeConstant>();
 }

View File

@@ -23,7 +23,7 @@ namespace TF {
 // Creates a pass to decode and reset opaque values in constant ops into
 // readable values.
 // Note that this pass assumes RaiseTFControlFlow pass has already been run.
-std::unique_ptr<OpPassBase<FuncOp>> CreateDecodeConstantPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateDecodeConstantPass();
 } // namespace TF
 } // namespace mlir
} // namespace mlir

View File

@@ -38,7 +38,8 @@ namespace {
 // NOTE: This pass does not support `use_locking=true` for a lot of resource
 // operations. So decomposition may not be correct outside of backends like XLA,
 // which automatically locks all resource variables.
-struct DecomposeResourceOps : public FunctionPass<DecomposeResourceOps> {
+struct DecomposeResourceOps
+    : public PassWrapper<DecomposeResourceOps, FunctionPass> {
   void runOnFunction() override {
     // Add lowering patterns to the list.
     OwningRewritePatternList patterns;
@@ -50,7 +51,7 @@ struct DecomposeResourceOps : public FunctionPass<DecomposeResourceOps> {
 } // namespace
-std::unique_ptr<OpPassBase<FuncOp>> CreateDecomposeResourceOpsPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateDecomposeResourceOpsPass() {
   return std::make_unique<DecomposeResourceOps>();
 }

View File

@@ -354,7 +354,8 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite(
 }
 // Transform Einsum to other TF Ops for the supported variants.
-struct TransformEinsumPass : public FunctionPass<TransformEinsumPass> {
+struct TransformEinsumPass
+    : public PassWrapper<TransformEinsumPass, FunctionPass> {
   void runOnFunction() override;
 };

View File

@@ -57,7 +57,7 @@ struct IslandResult {
 };
 struct ExecutorIslandCoarsening
-    : public FunctionPass<ExecutorIslandCoarsening> {
+    : public PassWrapper<ExecutorIslandCoarsening, FunctionPass> {
   void runOnFunction() override;
 };
@@ -346,7 +346,7 @@ void ExecutorIslandCoarsening::runOnFunction() {
 } // namespace
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorIslandCoarseningPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorIslandCoarseningPass() {
   return std::make_unique<ExecutorIslandCoarsening>();
 }

View File

@@ -43,7 +43,8 @@ constexpr llvm::StringRef kNestedModule = "_tpu_v1_compat_outlined";
 // Inlining the islands calling into the nested module that was outlined.
 // This is the end of the TPU bridge in V1 compatibility mode.
 struct TPUBridgeExecutorIslandInlining
-    : public OperationPass<TPUBridgeExecutorIslandInlining, ModuleOp> {
+    : public PassWrapper<TPUBridgeExecutorIslandInlining,
+                         OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
@@ -95,7 +96,7 @@ PassRegistration<TPUBridgeExecutorIslandInlining> tpu_pass(
 } // namespace
-std::unique_ptr<OpPassBase<ModuleOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateTFExecutorTPUV1IslandInliningPass() {
   return std::make_unique<TPUBridgeExecutorIslandInlining>();
 }

View File

@@ -59,7 +59,8 @@ constexpr llvm::StringRef kTpuStatusAttr = "_tpu_compilation_status";
 // TPU-annotated operations and intended to preserve backward compatibility with
 // TFv1.
 struct TpuV1BridgeExecutorIslandCoarsening
-    : public OperationPass<TpuV1BridgeExecutorIslandCoarsening, ModuleOp> {
+    : public PassWrapper<TpuV1BridgeExecutorIslandCoarsening,
+                         OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
@@ -322,7 +323,7 @@ void TpuV1BridgeExecutorIslandCoarsening::runOnOperation() {
 } // namespace
-std::unique_ptr<OpPassBase<ModuleOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateTFExecutorTPUV1IslandCoarseningPass() {
   return std::make_unique<TpuV1BridgeExecutorIslandCoarsening>();
 }

View File

@@ -44,7 +44,8 @@ constexpr llvm::StringRef kOutlinedFuncPrefix = "_tpu_v1_compat_outlined_func";
 // This is only intended for V1 compatibility mode where the bridge runs without
 // feed/fetches on session create/extend.
 struct TPUBridgeExecutorIslandOutlining
-    : public OperationPass<TPUBridgeExecutorIslandOutlining, ModuleOp> {
+    : public PassWrapper<TPUBridgeExecutorIslandOutlining,
+                         OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
@@ -160,7 +161,7 @@ PassRegistration<TPUBridgeExecutorIslandOutlining> tpu_pass(
 } // namespace
-std::unique_ptr<OpPassBase<ModuleOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateTFExecutorTPUV1IslandOutliningPass() {
   return std::make_unique<TPUBridgeExecutorIslandOutlining>();
 }

View File

@@ -58,7 +58,7 @@ limitations under the License.
 namespace mlir {
 namespace {
-class SwitchFoldPass : public mlir::FunctionPass<SwitchFoldPass> {
+class SwitchFoldPass : public mlir::PassWrapper<SwitchFoldPass, FunctionPass> {
  public:
   void runOnFunction() override;
 };
@@ -279,7 +279,7 @@ void SwitchFoldPass::runOnFunction() {
 } // namespace mlir
 namespace tf_executor {
-std::unique_ptr<OpPassBase<FuncOp>> CreateSwitchFoldPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateSwitchFoldPass() {
   return std::make_unique<SwitchFoldPass>();
 }
 } // namespace tf_executor

View File

@@ -42,7 +42,7 @@ namespace {
 // support resources/variables . Further, this contract also ensures that this
 // pass lowers from saved model to pure TF. Hence it fails, if it cannot lower.
 struct FreezeGlobalTensorsPass
-    : public OperationPass<FreezeGlobalTensorsPass, ModuleOp> {
+    : public PassWrapper<FreezeGlobalTensorsPass, OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
@@ -113,7 +113,7 @@ static PassRegistration<FreezeGlobalTensorsPass> pass(
     "tf-saved-model-freeze-global-tensors",
     "Freeze tf_saved_model.global_tensor's in func bodies.");
-std::unique_ptr<OpPassBase<ModuleOp>> CreateFreezeGlobalTensorsPass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeGlobalTensorsPass() {
   return std::make_unique<FreezeGlobalTensorsPass>();
 }

View File

@@ -34,7 +34,7 @@ namespace TF {
 namespace {
 struct FunctionalControlFlowToCFG
-    : public FunctionPass<FunctionalControlFlowToCFG> {
+    : public PassWrapper<FunctionalControlFlowToCFG, FunctionPass> {
   void runOnFunction() override;
 };
@@ -312,7 +312,7 @@ void FunctionalControlFlowToCFG::runOnFunction() {
 } // namespace
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFFunctionalControlFlowToCFG() {
+std::unique_ptr<OperationPass<FuncOp>> CreateTFFunctionalControlFlowToCFG() {
   return std::make_unique<FunctionalControlFlowToCFG>();
 }

View File

@@ -35,7 +35,7 @@ namespace {
 // GpuOpFusionPass is a pass performing fusion specific to GPU targets.
 // This is an ad-hoc pass for now, but should be integrated with some notion
 // of "target" in the MLIR pipeline in the future.
-class GpuOpFusionPass : public FunctionPass<GpuOpFusionPass> {
+class GpuOpFusionPass : public PassWrapper<GpuOpFusionPass, FunctionPass> {
  public:
   void runOnFunction() final;
 };
@@ -123,7 +123,7 @@ void GpuOpFusionPass::runOnFunction() {
 } // namespace
-std::unique_ptr<OpPassBase<FuncOp>> CreateGpuOpFusionPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateGpuOpFusionPass() {
   return std::make_unique<GpuOpFusionPass>();
 }

View File

@@ -84,7 +84,7 @@ void PruneGraph(GraphOp graph) {
 namespace {
 // This transformation pass prunes a TF graph eliminating dead-nodes.
-struct GraphPruning : public FunctionPass<GraphPruning> {
+struct GraphPruning : public PassWrapper<GraphPruning, FunctionPass> {
   void runOnFunction() override {
     getFunction().walk([](tf_executor::GraphOp graph) {
       // For TensorFlow V1.0 compatibility: when importing a graph without
@@ -100,7 +100,7 @@ struct GraphPruning : public FunctionPass<GraphPruning> {
 } // namespace
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorGraphPruningPass() {
   return std::make_unique<GraphPruning>();
 }

View File

@@ -57,7 +57,7 @@ namespace {
 constexpr char kDeviceAttr[] = "device";
 struct LaunchToDeviceAttributePass
-    : public FunctionPass<LaunchToDeviceAttributePass> {
+    : public PassWrapper<LaunchToDeviceAttributePass, FunctionPass> {
   void runOnFunction() override;
 };
@@ -122,7 +122,7 @@ void LaunchToDeviceAttributePass::runOnFunction() {
 } // anonymous namespace
-std::unique_ptr<OpPassBase<FuncOp>> CreateLaunchToDeviceAttributePass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateLaunchToDeviceAttributePass() {
   return std::make_unique<LaunchToDeviceAttributePass>();
 }

View File

@@ -36,7 +36,8 @@ namespace {
 // LayoutAssignmentPass assigns optimal data layout (data format) for all
 // layout sensitive operations.
-class LayoutAssignmentPass : public FunctionPass<LayoutAssignmentPass> {
+class LayoutAssignmentPass
+    : public PassWrapper<LayoutAssignmentPass, FunctionPass> {
  public:
  LayoutAssignmentPass() = default;
  explicit LayoutAssignmentPass(const std::string& force_data_format) {
@@ -57,7 +58,8 @@ class LayoutAssignmentPass : public FunctionPass<LayoutAssignmentPass> {
 // MoveTransposesPass moves all Transpose ops to the beginning or to the end of
 // the basic block where they are defined. This will allow canonicalzer to
 // delete redundant transposes.
-class MoveTransposesPass : public FunctionPass<MoveTransposesPass> {
+class MoveTransposesPass
+    : public PassWrapper<MoveTransposesPass, FunctionPass> {
  public:
  enum class Direction { kBegin, kEnd };

View File

@@ -31,7 +31,7 @@ namespace mlir {
 namespace TF {
 namespace {
-class LegalizeHloToTf : public FunctionPass<LegalizeHloToTf> {
+class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
  public:
  LegalizeHloToTf() = default;
  LegalizeHloToTf(const LegalizeHloToTf &) {}
@@ -76,7 +76,7 @@ static PassRegistration<LegalizeHloToTf> pass(
 } // end namespace
-std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeHloToTfPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass() {
   return std::make_unique<LegalizeHloToTf>();
 }

View File

@@ -23,7 +23,7 @@ namespace {
 // Lowers some of the TensorFlow operations that can be represented using other
 // TensorFlow operations.
-struct LowerTF : public FunctionPass<LowerTF> {
+struct LowerTF : public PassWrapper<LowerTF, FunctionPass> {
   void runOnFunction() override {
     // Add lowering patterns to the list.
     OwningRewritePatternList patterns;

View File

@@ -74,8 +74,9 @@ LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification(
 namespace {
 struct MarkFunctionVisibilityUsingEntryFunctionSpecificationPass
-    : public OperationPass<
-          MarkFunctionVisibilityUsingEntryFunctionSpecificationPass, ModuleOp> {
+    : public PassWrapper<
+          MarkFunctionVisibilityUsingEntryFunctionSpecificationPass,
+          OperationPass<ModuleOp>> {
   void runOnOperation() override {
     if (failed(MarkFunctionVisibilityUsingEntryFunctionSpecification(
             getOperation()))) {
@@ -90,7 +91,7 @@ static PassRegistration<
     pass("tf-mark-func-visibility",
          "Use tf.entry_function to mark function visibility.");
-std::unique_ptr<OpPassBase<ModuleOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass() {
   return std::make_unique<
       MarkFunctionVisibilityUsingEntryFunctionSpecificationPass>();
@@ -110,8 +111,8 @@ static LogicalResult MarkFunctionVisibilityUsingSavedModelLinkage(
 namespace {
 struct MarkFunctionVisibilityUsingSavedModelLinkagePass
-    : public OperationPass<MarkFunctionVisibilityUsingSavedModelLinkagePass,
-                           ModuleOp> {
+    : public PassWrapper<MarkFunctionVisibilityUsingSavedModelLinkagePass,
+                         OperationPass<ModuleOp>> {
   void runOnOperation() override {
     if (failed(MarkFunctionVisibilityUsingSavedModelLinkage(getOperation()))) {
       signalPassFailure();
@@ -124,7 +125,7 @@ static PassRegistration<MarkFunctionVisibilityUsingSavedModelLinkagePass> pass(
     "tf-saved-model-mark-func-visibility",
     "Use tf_saved_model linkage information to mark function visibility.");
-std::unique_ptr<OpPassBase<ModuleOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateMarkFunctionVisibilityUsingSavedModelLinkagePass() {
   return std::make_unique<MarkFunctionVisibilityUsingSavedModelLinkagePass>();
 }

View File

@@ -35,7 +35,7 @@ namespace mlir {
 namespace {
 class MaterializePassthroughOpPass
-    : public FunctionPass<MaterializePassthroughOpPass> {
+    : public PassWrapper<MaterializePassthroughOpPass, FunctionPass> {
  public:
   void runOnFunction() override;
 };
@@ -96,7 +96,7 @@ void MaterializePassthroughOpPass::runOnFunction() {
 } // namespace
 namespace TF {
-std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializePassthroughOpPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateMaterializePassthroughOpPass() {
   return std::make_unique<MaterializePassthroughOpPass>();
 }
 } // namespace TF

View File

@@ -33,7 +33,7 @@ namespace {
 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_optimize.inc"
 // Canonicalize operations in functions.
-struct TFOptimizePass : public FunctionPass<TFOptimizePass> {
+struct TFOptimizePass : public PassWrapper<TFOptimizePass, FunctionPass> {
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     auto func = getFunction();
@@ -71,7 +71,7 @@ void CreateTFStandardPipeline(OpPassManager &pm,
   pm.addNestedPass<FuncOp>(createCSEPass());
 }
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateTFOptimizePass() {
   return std::make_unique<TFOptimizePass>();
 }

View File

@@ -41,7 +41,7 @@ namespace mlir {
 namespace tf_saved_model {
 namespace {
 struct OptimizeGlobalTensorsPass
-    : public OperationPass<OptimizeGlobalTensorsPass, ModuleOp> {
+    : public PassWrapper<OptimizeGlobalTensorsPass, OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
@@ -296,7 +296,7 @@ static PassRegistration<OptimizeGlobalTensorsPass> pass(
     "tf-saved-model-optimize-global-tensors",
     "Optimize tf_saved_model.global_tensor's.");
-std::unique_ptr<OpPassBase<ModuleOp>> CreateOptimizeGlobalTensorsPass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeGlobalTensorsPass() {
   return std::make_unique<OptimizeGlobalTensorsPass>();
 }

View File

@@ -83,7 +83,7 @@ namespace TFDevice {
 namespace {
 struct ParallelExecuteToIslandsPass
-    : public FunctionPass<ParallelExecuteToIslandsPass> {
+    : public PassWrapper<ParallelExecuteToIslandsPass, FunctionPass> {
   void runOnFunction() override;
 };
@@ -251,7 +251,7 @@ void ParallelExecuteToIslandsPass::runOnFunction() {
 }
 } // anonymous namespace
-std::unique_ptr<OpPassBase<FuncOp>> CreateParallelExecuteToIslandsPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateParallelExecuteToIslandsPass() {
   return std::make_unique<ParallelExecuteToIslandsPass>();
 }

View File

@@ -24,36 +24,36 @@ namespace mlir {
 // Creates a pass that breaks up an island with multiple ops into multiple
 // islands, each with a single op.
-std::unique_ptr<OpPassBase<FuncOp>> CreateBreakUpIslandsPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateBreakUpIslandsPass();
 // Creates a pass that converts mlir functions consisting of mlir ops into a
 // tf_executor dialect as a single island.
-std::unique_ptr<OpPassBase<FuncOp>>
+std::unique_ptr<OperationPass<FuncOp>>
 CreateFunctionalToExecutorDialectConversionPass();
 namespace TF {
 // Transforms functional control flow operations in the standard TensorFlow
 // dialect to MLIR Control Flow Graph (CFG) form.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFFunctionalControlFlowToCFG();
+std::unique_ptr<OperationPass<FuncOp>> CreateTFFunctionalControlFlowToCFG();
 // Materialize the MlirPassthroughOp by replacing it with the MLIR module
 // attached as an attribute.
-std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializePassthroughOpPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateMaterializePassthroughOpPass();
 // Performs Shape Inference on the TensorFlow dialect using the global registry.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTFShapeInferencePass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateTFShapeInferencePass();
 // Optional pass which will unroll BatchMatMul and use only MatMul
-std::unique_ptr<OpPassBase<FuncOp>> CreateUnrollBatchMatMulPassPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateUnrollBatchMatMulPassPass();
 // Optional pass which will map TF BatchMatMul to TF Einsum
-std::unique_ptr<OpPassBase<FuncOp>> CreateBatchMatMulToEinsumPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateBatchMatMulToEinsumPass();
 // Optimizes Tensorflow graph.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass();
+std::unique_ptr<OperationPass<FuncOp>> CreateTFOptimizePass();
 // Performs specific fusion for GPU targets.
-std::unique_ptr<OpPassBase<FuncOp>> CreateGpuOpFusionPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateGpuOpFusionPass();
 struct LayoutOptimizationPipelineOptions
     : public PassPipelineOptions<LayoutOptimizationPipelineOptions> {
@@ -82,14 +82,14 @@ void CreateTFStandardPipeline(OpPassManager& pm,
                               const StandardPipelineOptions& options);
 // Propagates device attributes of resources from callers to callees.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateResourceDeviceInferencePass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateResourceDeviceInferencePass();
 // Creates a pass that promotes resource reads/writes in the main function to
 // inputs and outputs of the main function, assuming that resource operations
 // have already been decomposed and function calls have already been inlined.
 // The pass also annotates the input arguments for resources with the indices
 // of their aliasing output arguments.
-std::unique_ptr<OpPassBase<ModuleOp>> CreatePromoteResourcesToArgsPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass();
 // Marks function visibility using tf.entry_function specification. That is,
 // functions with tf.entry_function attributes are marked with public
@@ -98,11 +98,11 @@ LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification(
     ModuleOp module);
 // Creates a pass that uses tf.entry_function specification to mark function
 // visibility.
-std::unique_ptr<OpPassBase<ModuleOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass();
 // Creates a simple device assignment pass on TF dialect for CoreRT use case.
-std::unique_ptr<OpPassBase<FuncOp>> CreateSimpleTFDeviceAssignmentPass(
+std::unique_ptr<OperationPass<FuncOp>> CreateSimpleTFDeviceAssignmentPass(
     llvm::StringRef default_device);
 // Performs resource lifting on the function body to hoist resource variable
@@ -112,25 +112,26 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function);
 // Converts stack ops into operations on local variables, which can later be
 // removed by resource lifting. Requires known maximum sizes of stacks and
 // known element shapes of push ops.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateStackOpsDecompositionPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateStackOpsDecompositionPass();
 // Converts tensor list operations into operations on buffers and sizes. Needs
 // static shapes and known max element count.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTensorListOpsDecompositionPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateTensorListOpsDecompositionPass();
 // Converts tensor array ops into operations on local variables, which can later
 // be removed by resource lifting. Requires known sizes and known element shapes
 // (either defined in TensorArrayV3 or implied in the first write).
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTensorArrayOpsDecompositionPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateTensorArrayOpsDecompositionPass();
 // Create a pass that legalize HLO to TF dialect.
-std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeHloToTfPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();
 } // namespace TF
 namespace TFControlFlow {
 // Raises from the "TensorFlow Control Flow" dialect to the standard TensorFlow
 // dialect.
-std::unique_ptr<OpPassBase<FuncOp>> CreateRaiseTFControlFlowPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateRaiseTFControlFlowPass();
 } // namespace TFControlFlow
@@ -138,29 +139,30 @@ namespace tf_executor {
 class GraphOp;
 // Returns a pass that folds switch nodes with constant predicates.
-std::unique_ptr<OpPassBase<FuncOp>> CreateSwitchFoldPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateSwitchFoldPass();
 // Creates a pass to merge IslandOps from TFExecutor dialect.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorIslandCoarseningPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorIslandCoarseningPass();
 // Creates a pass to merge IslandOps for operation marked for execution on TPU.
 // This is a V1 backward compatibility.
-std::unique_ptr<OpPassBase<ModuleOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateTFExecutorTPUV1IslandCoarseningPass();
 // Creates a pass to outlining TPU clusters from single IslandOp into a nested
 // module suitable for being processed as-if it was a V2 module.
 // This is a V1 backward compatibility.
-std::unique_ptr<OpPassBase<ModuleOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateTFExecutorTPUV1IslandOutliningPass();
 // Creates a pass to inline calls to the nested TPU module, this reverses the
 // effect of the `TFExecutorTPUV1IslandOutlining` pass above.
 // This is a V1 backward compatibility.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTFExecutorTPUV1IslandInliningPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateTFExecutorTPUV1IslandInliningPass();
 // Creates a pass to prune tf_executor.graph from dead nodes.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorGraphPruningPass();
 // Prunes unreachable operations of a tf_executor.graph operation.
 void PruneGraph(GraphOp graph);
@@ -168,29 +170,29 @@ void PruneGraph(GraphOp graph);
 // Sink `tf.Const` operations in the LaunchOp region using them. This is
 // performed in order to limit the number of values implicitly captured in this
 // region before outlining.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorConstantSinkingPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorConstantSinkingPass();
 } // namespace tf_executor
 namespace TFDevice {
 // Creates a pass that forms clusters from instructions that are assigned to
 // same device.
-std::unique_ptr<OpPassBase<FuncOp>> CreateClusterFormationPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateClusterFormationPass();
 // Creates a pass that outlines regions of tf_device.launch operations.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateClusterOutliningPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateClusterOutliningPass();
 // A pass that decomposes composite resource operations into primitive ones like
 // ReadVariableOp, AssignVariableOp and other computations to facilitate
 // transformations like resource op lifting.
-std::unique_ptr<OpPassBase<FuncOp>> CreateDecomposeResourceOpsPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateDecomposeResourceOpsPass();
 // Creates a pass that lifts operations on external resource variables from
 // device computation nested in `tf_device::LaunchOp` out so that resource
 // variable load operations are all before device computation while resource
 // variable store operations are all after device computation. After this pass,
 // device computation no longer interacts with external resource variables.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateResourceOpLiftingPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateResourceOpLiftingPass();
 // Lifts resource operations from tf_device.launch_func ops nested in `op`
 // outside. Returns a failure if there are remaining resource-type values that
@@ -198,55 +200,56 @@ std::unique_ptr<OpPassBase<ModuleOp>> CreateResourceOpLiftingPass();
 LogicalResult LiftResourceOps(Operation* op);
 // Creates a pass that hoists invariant operations in a `tf_device.replicate`.
-std::unique_ptr<OpPassBase<FuncOp>> CreateReplicateInvariantOpHoistingPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateReplicateInvariantOpHoistingPass();
 // Creates a pass that forms replica `tf_executor.island` from a single
 // `tf_device.replicate` island.
-std::unique_ptr<OpPassBase<FuncOp>> CreateReplicateToIslandPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateReplicateToIslandPass();
 // Creates a pass that creates `tf_executor.island` from a single
 // `tf_device.parallel_execute` island.
-std::unique_ptr<OpPassBase<FuncOp>> CreateParallelExecuteToIslandsPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateParallelExecuteToIslandsPass();
 // Creates a pass that annotates whether a LaunchFuncOp's parameters have the
 // same data across replicas.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateAnnotateParameterReplicationPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateAnnotateParameterReplicationPass();
 // Creates a pass that hoists a `tf_device.launch` body and assigns a `device`
 // attribute to each TensorFlow dialect op in the body based on the `device`
 // attribute on the `tf_device.launch`.
-std::unique_ptr<OpPassBase<FuncOp>> CreateLaunchToDeviceAttributePass();
+std::unique_ptr<OperationPass<FuncOp>> CreateLaunchToDeviceAttributePass();
 } // namespace TFDevice
 namespace TFTPU {
 // Creates a pass that forms clusters from operations of the same
 // `_tpu_replicate` attribute.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTPUClusterFormationPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateTPUClusterFormationPass();
 // Creates a pass that allows TPU program inputs to have layouts determined at
 // run time.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTPUDynamicLayoutPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateTPUDynamicLayoutPass();
 // Creates a pass that remaps and assigns padding map from a
 // `tf_device.launch_func` `padding_map` attribute to its encapsulated function.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUDynamicPaddingMapperPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicPaddingMapperPass();
 // Creates a pass that rewrites `tf_device.launch_func` on TPUs into TPU runtime
 // ops.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTPURewritePass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateTPURewritePass();
 // Creates a pass that identifies XLASharding ops in launch op for TPU
 // computation.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUShardingIdentificationPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateTPUShardingIdentificationPass();
 // Creates a pass that merges device variable reads/updates into the surrounded
 // TPUExecute node. This allows the execute node to perform in-place variable
 // updates.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTPUMergeVariablesWithExecutePass();
+std::unique_ptr<OperationPass<FuncOp>> CreateTPUMergeVariablesWithExecutePass();
 // Creates a pass that adds ops which perform formatting on variables at
 // run-time according to compilation result.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUVariableReformattingPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateTPUVariableReformattingPass();
 // Populates the supplied passmanager with the passes required to run the
 void CreateTPUBridgePipeline(OpPassManager& pm);
@@ -260,16 +263,16 @@ void CreateTPUBridgePipelineV1(OpPassManager& pm);
 namespace tf_saved_model {
 // Creates a pass that optimizes tf_saved_model.global_tensor ops.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateOptimizeGlobalTensorsPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeGlobalTensorsPass();
 // Creates a pass that freezes tf_saved_model.global_tensor ops.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateFreezeGlobalTensorsPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeGlobalTensorsPass();
 // Creates a pass that uses tf_saved_model dialect linkage information
 // to mark function visibility. That is, exported functions are marked with
 // public visibility while the other functions are marked with private
 // visibility.
-std::unique_ptr<OpPassBase<ModuleOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateMarkFunctionVisibilityUsingSavedModelLinkagePass();
 } // namespace tf_saved_model

View File

@@ -258,7 +258,7 @@ LogicalResult PromoteResourcesToArguments(FuncOp function) {
 }
 class PromoteResourcesToArgsPass
-    : public OperationPass<PromoteResourcesToArgsPass, ModuleOp> {
+    : public PassWrapper<PromoteResourcesToArgsPass, OperationPass<ModuleOp>> {
  public:
   void runOnOperation() override;
 };
@@ -285,7 +285,7 @@ void PromoteResourcesToArgsPass::runOnOperation() {
 } // namespace
-std::unique_ptr<OpPassBase<ModuleOp>> CreatePromoteResourcesToArgsPass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass() {
   return std::make_unique<PromoteResourcesToArgsPass>();
 }

View File

@@ -32,7 +32,8 @@ namespace mlir {
 namespace TFControlFlow {
 namespace {
-struct RaiseTFControlFlow : public FunctionPass<RaiseTFControlFlow> {
+struct RaiseTFControlFlow
+    : public PassWrapper<RaiseTFControlFlow, FunctionPass> {
   void runOnFunction() {
     // First start by recognizing loops and reconstructing a loop tree.
     buildLoopNests();
@@ -145,7 +146,7 @@ void RaiseTFControlFlow::rewriteOps() {
 } // namespace
-std::unique_ptr<OpPassBase<FuncOp>> CreateRaiseTFControlFlowPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateRaiseTFControlFlowPass() {
   return std::make_unique<RaiseTFControlFlow>();
 }

View File

@@ -37,7 +37,7 @@ namespace {
 constexpr char kDeviceAttr[] = "device";
 struct ReplicateInvariantOpHoistingPass
-    : public FunctionPass<ReplicateInvariantOpHoistingPass> {
+    : public PassWrapper<ReplicateInvariantOpHoistingPass, FunctionPass> {
   void runOnFunction() override;
 };
@@ -178,7 +178,8 @@ void ReplicateInvariantOpHoistingPass::runOnFunction() {
 }
 } // anonymous namespace
-std::unique_ptr<OpPassBase<FuncOp>> CreateReplicateInvariantOpHoistingPass() {
+std::unique_ptr<OperationPass<FuncOp>>
+CreateReplicateInvariantOpHoistingPass() {
   return std::make_unique<ReplicateInvariantOpHoistingPass>();
 }


@ -43,7 +43,8 @@ namespace TFDevice {
namespace {
constexpr char kDeviceAttr[] = "device";
struct ReplicateToIslandPass : public FunctionPass<ReplicateToIslandPass> {
struct ReplicateToIslandPass
: public PassWrapper<ReplicateToIslandPass, FunctionPass> {
void runOnFunction() override;
};
@ -237,7 +238,7 @@ void ReplicateToIslandPass::runOnFunction() {
}
} // anonymous namespace
std::unique_ptr<OpPassBase<FuncOp>> CreateReplicateToIslandPass() {
std::unique_ptr<OperationPass<FuncOp>> CreateReplicateToIslandPass() {
return std::make_unique<ReplicateToIslandPass>();
}


@ -54,7 +54,7 @@ constexpr char kFuncDeviceAttr[] = "tf.device";
// This pass changes the module by adding "tf.device" attribute to function
// arguments and adding "device" attribute to TF ops.
struct ResourceDeviceInference
: public OperationPass<ResourceDeviceInference, ModuleOp> {
: public PassWrapper<ResourceDeviceInference, OperationPass<ModuleOp>> {
void runOnOperation() override;
};
@ -266,7 +266,7 @@ void ResourceDeviceInference::runOnOperation() {
} // namespace
std::unique_ptr<OpPassBase<ModuleOp>> CreateResourceDeviceInferencePass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateResourceDeviceInferencePass() {
return std::make_unique<ResourceDeviceInference>();
}


@ -132,7 +132,7 @@ namespace {
// }
//
struct ResourceOpLiftingPass
: public OperationPass<ResourceOpLiftingPass, ModuleOp> {
: public PassWrapper<ResourceOpLiftingPass, OperationPass<ModuleOp>> {
void runOnOperation() override;
};
@ -1071,7 +1071,8 @@ void ResourceOpLiftingPass::runOnOperation() {
}
struct ResourceOpLiftingForMainFunctionPass
: public OperationPass<ResourceOpLiftingForMainFunctionPass, ModuleOp> {
: public PassWrapper<ResourceOpLiftingForMainFunctionPass,
OperationPass<ModuleOp>> {
void runOnOperation() override;
};
@ -1100,7 +1101,7 @@ static PassRegistration<ResourceOpLiftingPass> pass(
} // namespace
namespace TFDevice {
std::unique_ptr<OpPassBase<ModuleOp>> CreateResourceOpLiftingPass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateResourceOpLiftingPass() {
return std::make_unique<ResourceOpLiftingPass>();
}
} // namespace TFDevice


@ -47,7 +47,8 @@ namespace {
// This transformation pass propagates shapes on the TensorFlow graph.
// It is a ModulePass in order to be able to change function types.
struct ShapeInference : public OperationPass<ShapeInference, ModuleOp> {
struct ShapeInference
: public PassWrapper<ShapeInference, OperationPass<ModuleOp>> {
void runOnOperation() override {
auto module = getOperation();
auto producer_or = tensorflow::GetTfGraphProducerVersion(module);
@ -70,7 +71,7 @@ PassRegistration<ShapeInference> pass(
} // namespace
std::unique_ptr<OpPassBase<ModuleOp>> CreateTFShapeInferencePass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateTFShapeInferencePass() {
return std::make_unique<ShapeInference>();
}
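As the comment above notes, rewriting function types requires module scope; a rough sketch of that kind of update (getType/setType as in upstream MLIR; the no-op "refinement" is a placeholder):

// Sketch: a module pass may rewrite function signatures in place.
void RefineFunctionTypes(mlir::ModuleOp module) {
  module.walk([](mlir::FuncOp func) {
    mlir::FunctionType refined = func.getType();  // placeholder refinement
    func.setType(refined);  // legal because the pass owns the whole module
  });
}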


@ -39,7 +39,7 @@ namespace {
using ::mlir::TF::ConstOp;
class ExecutorConstantSinking
: public mlir::FunctionPass<ExecutorConstantSinking> {
: public mlir::PassWrapper<ExecutorConstantSinking, FunctionPass> {
void runOnFunction() override {
getFunction().walk([](tf_device::LaunchOp launch) {
LLVM_DEBUG(llvm::dbgs() << "Visit " << *launch.getOperation() << "\n");
@ -89,7 +89,7 @@ static mlir::PassRegistration<ExecutorConstantSinking> pass(
} // anonymous namespace
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorConstantSinkingPass() {
std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorConstantSinkingPass() {
return std::make_unique<ExecutorConstantSinking>();
}


@ -85,7 +85,7 @@ namespace cutil = TF::collection_ops_util;
//
// The pass also works across control flow and functional calls.
struct StackOpsDecompositionPass
: public OperationPass<StackOpsDecompositionPass, ModuleOp> {
: public PassWrapper<StackOpsDecompositionPass, OperationPass<ModuleOp>> {
void runOnOperation() override;
};
@ -568,7 +568,7 @@ static PassRegistration<StackOpsDecompositionPass> pass(
} // namespace
namespace TF {
std::unique_ptr<OpPassBase<ModuleOp>> CreateStackOpsDecompositionPass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateStackOpsDecompositionPass() {
return std::make_unique<StackOpsDecompositionPass>();
}


@ -68,7 +68,8 @@ using std::string;
// shape.
//
struct TensorArrayOpsDecompositionPass
: public OperationPass<TensorArrayOpsDecompositionPass, ModuleOp> {
: public PassWrapper<TensorArrayOpsDecompositionPass,
OperationPass<ModuleOp>> {
void runOnOperation() override;
};
@ -893,7 +894,8 @@ static PassRegistration<TensorArrayOpsDecompositionPass> pass(
} // namespace
namespace TF {
std::unique_ptr<OpPassBase<ModuleOp>> CreateTensorArrayOpsDecompositionPass() {
std::unique_ptr<OperationPass<ModuleOp>>
CreateTensorArrayOpsDecompositionPass() {
return std::make_unique<TensorArrayOpsDecompositionPass>();
}


@ -62,7 +62,8 @@ namespace cutil = TF::collection_ops_util;
//
// The pass also works across control flow and functional calls.
struct TensorListOpsDecompositionPass
: public OperationPass<TensorListOpsDecompositionPass, ModuleOp> {
: public PassWrapper<TensorListOpsDecompositionPass,
OperationPass<ModuleOp>> {
void runOnOperation() override;
};
@ -728,7 +729,8 @@ static PassRegistration<TensorListOpsDecompositionPass> pass(
} // namespace
namespace TF {
std::unique_ptr<OpPassBase<ModuleOp>> CreateTensorListOpsDecompositionPass() {
std::unique_ptr<OperationPass<ModuleOp>>
CreateTensorListOpsDecompositionPass() {
return std::make_unique<TensorListOpsDecompositionPass>();
}
} // namespace TF


@ -39,7 +39,7 @@ namespace {
// A pass that adds "Predecessors" and "Successors" remarks for each op based on
// SideEffectAnalysis result. For testing purposes only.
struct TestSideEffectAnalysis
: public mlir::FunctionPass<TestSideEffectAnalysis> {
: public mlir::PassWrapper<TestSideEffectAnalysis, FunctionPass> {
void runOnFunction() override {
int64_t next_id = 0;
llvm::SmallDenseMap<Operation*, int64_t, 8> ids;


@ -24,7 +24,7 @@ namespace TF {
namespace {
class SimpleTFDeviceAssignmentPass
: public FunctionPass<SimpleTFDeviceAssignmentPass> {
: public PassWrapper<SimpleTFDeviceAssignmentPass, FunctionPass> {
public:
SimpleTFDeviceAssignmentPass() = default;
SimpleTFDeviceAssignmentPass(const SimpleTFDeviceAssignmentPass&) {}
@ -57,7 +57,7 @@ class SimpleTFDeviceAssignmentPass
} // namespace
std::unique_ptr<OpPassBase<FuncOp>> CreateSimpleTFDeviceAssignmentPass(
std::unique_ptr<OperationPass<FuncOp>> CreateSimpleTFDeviceAssignmentPass(
llvm::StringRef default_device) {
return std::make_unique<SimpleTFDeviceAssignmentPass>(default_device);
}
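A brief usage sketch for this parameterized factory, given a PassManager pm as in the earlier sketch (the device string is an assumption for illustration):

// Assign a default device to ops that lack one (device name assumed).
pm.addNestedPass<mlir::FuncOp>(mlir::TF::CreateSimpleTFDeviceAssignmentPass(
    "/job:localhost/replica:0/task:0/device:CPU:0"));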


@ -40,7 +40,9 @@ namespace tensorflow {
// Optimization Passes and convert back to MLIR.
// Constraints: This pass expects that all operations in the MLIR module belong
// to either the 'tf' or the '_tf' dialect. The output is in the '_tf' dialect.
class GraphOptPass : public mlir::OperationPass<GraphOptPass, mlir::ModuleOp> {
class GraphOptPass
: public mlir::PassWrapper<GraphOptPass,
mlir::OperationPass<mlir::ModuleOp>> {
public:
explicit GraphOptPass(std::vector<tensorflow::GraphOptimizationPass*> passes)
: passes_(std::move(passes)) {}
@ -166,13 +168,13 @@ class GraphOptByNamePass : public GraphOptPass {
} // namespace tensorflow
std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
tensorflow::CreateTensorFlowGraphOptimizationPass(
std::vector<tensorflow::GraphOptimizationPass*> tf_passes) {
return std::make_unique<GraphOptPass>(std::move(tf_passes));
}
std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
tensorflow::CreateTensorFlowGraphOptimizationPass(
const std::vector<std::string>& pass_names) {
return std::make_unique<GraphOptByNamePass>(pass_names);


@ -24,7 +24,7 @@ namespace tensorflow {
// Creates a module pass that executes the given TF GraphOptimization passes
// in sequence.
// The pass requires that the module it runs on is convertible to a TF Graph.
std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateTensorFlowGraphOptimizationPass(
std::vector<tensorflow::GraphOptimizationPass*> tf_passes);
@ -32,7 +32,7 @@ CreateTensorFlowGraphOptimizationPass(
// passes are queried; if a TF graph optimization pass is not found in the
// registry, the pass fails.
// The pass requires that the module it runs on is convertible to a TF Graph.
std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateTensorFlowGraphOptimizationPass(
const std::vector<std::string>& pass_names);
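A short usage sketch for the by-name overload (the pass name below is hypothetical; per the comment above, an unregistered name makes the pass fail):

// Build a module pass that runs registered TF GraphOptimization passes.
std::vector<std::string> names = {"ExampleGraphOptimizationPass"};  // hypothetical
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> graph_opt =
    tensorflow::CreateTensorFlowGraphOptimizationPass(names);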


@ -71,7 +71,8 @@ using MetadataMap = llvm::SmallDenseMap<llvm::StringRef, NamedAttributeList, 8>;
using ClusterMap = llvm::SmallDenseMap<llvm::StringRef,
llvm::SmallSetVector<Operation*, 8>, 8>;
struct TPUClusterFormation : public FunctionPass<TPUClusterFormation> {
struct TPUClusterFormation
: public PassWrapper<TPUClusterFormation, FunctionPass> {
void runOnFunction() override;
};
@ -502,7 +503,7 @@ void TPUClusterFormation::runOnFunction() {
}
} // anonymous namespace
std::unique_ptr<OpPassBase<FuncOp>> CreateTPUClusterFormationPass() {
std::unique_ptr<OperationPass<FuncOp>> CreateTPUClusterFormationPass() {
return std::make_unique<TPUClusterFormation>();
}


@ -73,7 +73,8 @@ constexpr char kDeviceAttr[] = "device";
// %copy_to_device. There will not be send/recv ops added by later passes,
// because tf.TPUCopyWithLayout accepts a host input and produces a device
// output.
struct TPUDynamicLayoutPass : public FunctionPass<TPUDynamicLayoutPass> {
struct TPUDynamicLayoutPass
: public PassWrapper<TPUDynamicLayoutPass, FunctionPass> {
void runOnFunction() override;
};
@ -256,7 +257,7 @@ void TPUDynamicLayoutPass::runOnFunction() {
} // namespace
std::unique_ptr<OpPassBase<FuncOp>> CreateTPUDynamicLayoutPass() {
std::unique_ptr<OperationPass<FuncOp>> CreateTPUDynamicLayoutPass() {
return std::make_unique<TPUDynamicLayoutPass>();
}


@ -49,7 +49,7 @@ constexpr char kPaddingMapAttr[] = "padding_map";
namespace {
struct TPUDynamicPaddingMapper
: public OperationPass<TPUDynamicPaddingMapper, ModuleOp> {
: public PassWrapper<TPUDynamicPaddingMapper, OperationPass<ModuleOp>> {
void runOnOperation() override;
};
@ -200,7 +200,7 @@ void TPUDynamicPaddingMapper::runOnOperation() {
}
} // anonymous namespace
std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUDynamicPaddingMapperPass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicPaddingMapperPass() {
return std::make_unique<TPUDynamicPaddingMapper>();
}


@ -75,7 +75,7 @@ constexpr char kFuncDeviceAttr[] = "tf.device";
// the TPUExecute op.
struct TPUMergeVariablesWithExecutePass
: public FunctionPass<TPUMergeVariablesWithExecutePass> {
: public PassWrapper<TPUMergeVariablesWithExecutePass, FunctionPass> {
void runOnFunction() override;
};
@ -531,7 +531,8 @@ void TPUMergeVariablesWithExecutePass::runOnFunction() {
} // namespace
std::unique_ptr<OpPassBase<FuncOp>> CreateTPUMergeVariablesWithExecutePass() {
std::unique_ptr<OperationPass<FuncOp>>
CreateTPUMergeVariablesWithExecutePass() {
return std::make_unique<TPUMergeVariablesWithExecutePass>();
}


@ -98,7 +98,8 @@ constexpr char kBadArrayAttrLengthMsg[] =
// %4 = "tf.SomeOp"(%3)
namespace {
struct TPURewritePass : public OperationPass<TPURewritePass, ModuleOp> {
struct TPURewritePass
: public PassWrapper<TPURewritePass, OperationPass<ModuleOp>> {
void runOnOperation() override;
};
@ -770,7 +771,7 @@ void TPURewritePass::runOnOperation() {
} // namespace
std::unique_ptr<OpPassBase<ModuleOp>> CreateTPURewritePass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateTPURewritePass() {
return std::make_unique<TPURewritePass>();
}


@ -40,7 +40,8 @@ namespace {
constexpr char kShardingAttr[] = "xla_hlo.sharding";
struct TPUShardingIdentificationPass
: public OperationPass<TPUShardingIdentificationPass, ModuleOp> {
: public PassWrapper<TPUShardingIdentificationPass,
OperationPass<ModuleOp>> {
void runOnOperation() override;
};
@ -185,7 +186,7 @@ void TPUShardingIdentificationPass::runOnOperation() {
} // anonymous namespace
std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUShardingIdentificationPass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUShardingIdentificationPass() {
return std::make_unique<TPUShardingIdentificationPass>();
}


@ -116,7 +116,8 @@ std::string GetRandomStateVariableName() {
// tf.TPUReshardVariablesOp(%rvar, %default_format, %rstate)
// }
struct TPUVariableRuntimeReformattingPass
: public OperationPass<TPUVariableRuntimeReformattingPass, ModuleOp> {
: public PassWrapper<TPUVariableRuntimeReformattingPass,
OperationPass<ModuleOp>> {
void runOnOperation() override;
};
@ -575,7 +576,7 @@ void TPUVariableRuntimeReformattingPass::runOnOperation() {
} // namespace
std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUVariableReformattingPass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUVariableReformattingPass() {
return std::make_unique<TPUVariableRuntimeReformattingPass>();
}


@ -44,7 +44,8 @@ namespace {
// Unrolls a BatchMatMul along the batch dimension. We need to slice each batch out
// of the inputs, matmul them individually, then stack them all back together at
// the end.
struct UnrollBatchMatMulPass : public FunctionPass<UnrollBatchMatMulPass> {
struct UnrollBatchMatMulPass
: public PassWrapper<UnrollBatchMatMulPass, FunctionPass> {
void runOnFunction() override;
};
@ -309,7 +310,7 @@ static PassRegistration<UnrollBatchMatMulPass> pass(
"tf-unroll-batch-matmul",
"Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops.");
std::unique_ptr<OpPassBase<FuncOp>> CreateUnrollBatchMatMulPassPass() {
std::unique_ptr<OperationPass<FuncOp>> CreateUnrollBatchMatMulPassPass() {
return std::make_unique<UnrollBatchMatMulPass>();
}
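The registration above also exposes this pass as -tf-unroll-batch-matmul to pass-pipeline tooling. As a reference for the semantics the unrolling preserves, a plain C++ sketch (illustrative only; the pass itself emits Reshape, Slice, MatMul, and Pack ops):

// C[b] = A[b] * B[b] for each batch b; per-batch results sit contiguously
// in C, mirroring the final Pack.
void ReferenceBatchMatMul(const float* A, const float* B, float* C,
                          int batch, int m, int k, int n) {
  for (int b = 0; b < batch; ++b)  // one MatMul per sliced batch
    for (int i = 0; i < m; ++i)
      for (int j = 0; j < n; ++j) {
        float acc = 0.f;
        for (int p = 0; p < k; ++p)
          acc += A[(b * m + i) * k + p] * B[(b * k + p) * n + j];
        C[(b * m + i) * n + j] = acc;
      }
}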


@ -42,7 +42,7 @@ namespace mlir {
namespace {
struct BreakUpIslands : FunctionPass<BreakUpIslands> {
struct BreakUpIslands : PassWrapper<BreakUpIslands, FunctionPass> {
void runOnFunction() final;
void BreakUpIsland(tf_executor::IslandOp island_op,
@ -325,7 +325,7 @@ void BreakUpIslands::BreakUpIsland(
} // namespace
std::unique_ptr<OpPassBase<FuncOp>> CreateBreakUpIslandsPass() {
std::unique_ptr<OperationPass<FuncOp>> CreateBreakUpIslandsPass() {
return std::make_unique<BreakUpIslands>();
}


@ -45,7 +45,7 @@ namespace {
// otherwise _tf operations are wrapped in an island and the _ prefix is
// removed. Control dependencies are moved to be handled by the island itself.
struct ControlToExecutorDialectConversion
: public FunctionPass<ControlToExecutorDialectConversion> {
: public PassWrapper<ControlToExecutorDialectConversion, FunctionPass> {
void runOnFunction() override;
private:
@ -237,7 +237,7 @@ void ControlToExecutorDialectConversion::runOnFunction() {
}
}
OpPassBase<FuncOp> *CreateTFControlToExecutorDialectConversion() {
OperationPass<FuncOp> *CreateTFControlToExecutorDialectConversion() {
return new ControlToExecutorDialectConversion();
}


@ -39,7 +39,7 @@ namespace mlir {
namespace {
struct ExecutorToControlDialectConversion
: public FunctionPass<ExecutorToControlDialectConversion> {
: public PassWrapper<ExecutorToControlDialectConversion, FunctionPass> {
void runOnFunction() override;
};
} // end anonymous namespace
@ -230,7 +230,7 @@ void ExecutorToControlDialectConversion::runOnFunction() {
graph.erase();
}
std::unique_ptr<OpPassBase<FuncOp>>
std::unique_ptr<OperationPass<FuncOp>>
CreateTFExecutorToControlDialectConversion() {
return std::make_unique<ExecutorToControlDialectConversion>();
}


@ -40,7 +40,7 @@ namespace {
// return %graph_results#...
// }
struct FunctionalToExecutorDialectConversion
: public FunctionPass<FunctionalToExecutorDialectConversion> {
: public PassWrapper<FunctionalToExecutorDialectConversion, FunctionPass> {
void runOnFunction() override;
};
} // end anonymous namespace
@ -95,7 +95,7 @@ void FunctionalToExecutorDialectConversion::runOnFunction() {
}
}
std::unique_ptr<OpPassBase<FuncOp>>
std::unique_ptr<OperationPass<FuncOp>>
CreateFunctionalToExecutorDialectConversionPass() {
return std::make_unique<FunctionalToExecutorDialectConversion>();
}


@ -343,7 +343,8 @@ class BufferAssignmentAnalysis {
/// the right positions. It uses the algorithm described at the top of the file.
// TODO(dfki): create a templated version that allows matching dialect-specific
// alloc/dealloc nodes and inserting dialect-specific dealloc nodes.
struct BufferAssignmentPass : mlir::FunctionPass<BufferAssignmentPass> {
struct BufferAssignmentPass
: mlir::PassWrapper<BufferAssignmentPass, FunctionPass> {
void runOnFunction() override {
// Get required analysis information first.
auto& analysis = getAnalysis<BufferAssignmentAnalysis>();
@ -471,7 +472,7 @@ void FunctionAndBlockSignatureConverter::addDynamicallyLegalFuncOp(
// Buffer assignment pass registrations
//===----------------------------------------------------------------------===//
std::unique_ptr<OpPassBase<FuncOp>> createBufferAssignmentPass() {
std::unique_ptr<OperationPass<FuncOp>> createBufferAssignmentPass() {
return absl::make_unique<BufferAssignmentPass>();
}
@ -482,14 +483,15 @@ static PassRegistration<BufferAssignmentPass> buffer_assignment_pass(
/// A simple pass to print debug/test information for the buffer assignment
/// analysis.
struct BufferAssignmentTestPass : mlir::FunctionPass<BufferAssignmentTestPass> {
struct BufferAssignmentTestPass
: mlir::PassWrapper<BufferAssignmentTestPass, FunctionPass> {
void runOnFunction() override {
llvm::outs() << "Testing : " << getFunction().getName() << "\n";
getAnalysis<BufferAssignmentAnalysis>().print(llvm::outs());
}
};
std::unique_ptr<OpPassBase<FuncOp>> createBufferAssignmentTestPass() {
std::unique_ptr<OperationPass<FuncOp>> createBufferAssignmentTestPass() {
return absl::make_unique<BufferAssignmentTestPass>();
}


@ -324,7 +324,8 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern {
// "xla_lhlo.terminator"() : () -> ()
// }
struct HloLegalizeToLhlo : public OperationPass<HloLegalizeToLhlo, ModuleOp> {
struct HloLegalizeToLhlo
: public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
void runOnOperation() override {
OwningRewritePatternList patterns;
auto& context = getContext();
@ -473,7 +474,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
// clang-format on
}
std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeToLhloPass() {
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
return absl::make_unique<HloLegalizeToLhlo>();
}
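The conversion passes in this commit share one driver shape; a hedged sketch of it (target setup assumed rather than copied from the pass; includes elided):

// Typical partial-conversion driver: callers mark the source dialect
// illegal on `target` before invoking this.
mlir::LogicalResult RunConversion(mlir::Operation* op,
                                  mlir::ConversionTarget& target,
                                  mlir::OwningRewritePatternList& patterns) {
  return mlir::applyPartialConversion(op, target, patterns);
}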


@ -37,7 +37,8 @@ using mlir::PassRegistration;
namespace mlir {
namespace xla_hlo {
namespace {
struct LegalizeControlFlow : public mlir::FunctionPass<LegalizeControlFlow> {
struct LegalizeControlFlow
: public mlir::PassWrapper<LegalizeControlFlow, FunctionPass> {
// Perform the lowering to MLIR control flow.
void runOnFunction() override;
};
@ -227,7 +228,7 @@ void LegalizeControlFlow::runOnFunction() {
} // namespace xla_hlo
} // namespace mlir
std::unique_ptr<mlir::OpPassBase<mlir::FuncOp>>
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
mlir::xla_hlo::createLegalizeControlFlowPass() {
return std::make_unique<LegalizeControlFlow>();
}


@ -55,7 +55,7 @@ namespace mlir {
namespace xla_hlo {
namespace {
class LegalizeTF : public FunctionPass<LegalizeTF> {
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
public:
LegalizeTF() = default;
LegalizeTF(const LegalizeTF &) {}
@ -3829,7 +3829,7 @@ static PassRegistration<LegalizeTF> pass(
} // end namespace
std::unique_ptr<OpPassBase<FuncOp>> createLegalizeTFPass(
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
bool allow_partial_conversion) {
return std::make_unique<LegalizeTF>(allow_partial_conversion);
}
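LegalizeTF pairs a defaulted constructor with an empty copy constructor because it carries a pass Option; under PassWrapper that idiom looks roughly like this (sketch; the option name is hypothetical):

#include "llvm/Support/CommandLine.h"
#include "mlir/Pass/Pass.h"

class ExampleOptionPass
    : public mlir::PassWrapper<ExampleOptionPass, mlir::FunctionPass> {
 public:
  ExampleOptionPass() = default;
  // Options are not copyable, so cloning a pass relies on this shallow copy.
  ExampleOptionPass(const ExampleOptionPass&) {}
  void runOnFunction() override {}

 private:
  Option<bool> allow_partial_{*this, "allow-partial-conversion",
                              llvm::cl::desc("Tolerate unconverted ops"),
                              llvm::cl::init(false)};
};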


@ -52,13 +52,13 @@ namespace mlir {
namespace xla_hlo {
namespace {
class LegalizeTFControlFlow
: public OperationPass<LegalizeTFControlFlow, ModuleOp> {
: public PassWrapper<LegalizeTFControlFlow, OperationPass<ModuleOp>> {
public:
void runOnOperation() override;
};
} // namespace
std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createLegalizeTFControlFlowPass() {
return std::make_unique<LegalizeTFControlFlow>();
}


@ -333,7 +333,7 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
return success();
}
class LegalizeTF : public FunctionPass<LegalizeTF> {
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
public:
LegalizeTF() = default;


@ -177,13 +177,14 @@ class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
} // end anonymous namespace
namespace {
struct LegalizeToStandard : public FunctionPass<LegalizeToStandard> {
struct LegalizeToStandard
: public PassWrapper<LegalizeToStandard, FunctionPass> {
/// Perform the lowering to Standard dialect.
void runOnFunction() override;
};
} // end anonymous namespace
std::unique_ptr<mlir::OpPassBase<mlir::FuncOp>> createLegalizeToStdPass() {
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> createLegalizeToStdPass() {
return std::make_unique<LegalizeToStandard>();
}


@ -30,7 +30,7 @@ namespace {
// arguments. All uses of each buffer are replaced with the corresponding block
// argument and the buffer is freed. Note that this pass only works in regions
// with a single block.
struct LhloCopyRemoval : mlir::OperationPass<LhloCopyRemoval> {
struct LhloCopyRemoval : mlir::PassWrapper<LhloCopyRemoval, OperationPass<>> {
void runOnOperation() override {
llvm::SmallVector<mlir::Operation*, 2> eraseList;
auto operation = getOperation();


@ -30,7 +30,7 @@ namespace {
using linalg::LinalgOp;
class LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
public:
LhloFuseLinalg() = default;
LhloFuseLinalg(const LhloFuseLinalg&) {}
@ -123,7 +123,7 @@ class LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
} // namespace
std::unique_ptr<OpPassBase<FuncOp>> createLhloFuseLinalg(
std::unique_ptr<OperationPass<FuncOp>> createLhloFuseLinalg(
bool use_parallel_loops, ArrayRef<unsigned> tile_sizes) {
return absl::make_unique<LhloFuseLinalg>(use_parallel_loops, tile_sizes);
}
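A short call sketch for this factory (the tile sizes are illustrative, not taken from the code):

// Fuse linalg ops using parallel loops and assumed 16x16 tiles.
auto fuse_pass = createLhloFuseLinalg(/*use_parallel_loops=*/true, {16, 16});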


@ -81,7 +81,8 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
// clang-format on
}
struct LhloLegalizeToAffine : public FunctionPass<LhloLegalizeToAffine> {
struct LhloLegalizeToAffine
: public PassWrapper<LhloLegalizeToAffine, FunctionPass> {
void runOnFunction() override {
OwningRewritePatternList patterns;
auto func = getFunction();
@ -92,7 +93,7 @@ struct LhloLegalizeToAffine : public FunctionPass<LhloLegalizeToAffine> {
} // namespace
std::unique_ptr<OpPassBase<FuncOp>> createLegalizeToAffinePass() {
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass() {
return absl::make_unique<LhloLegalizeToAffine>();
}


@ -168,7 +168,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
};
};
struct LhloLegalizeToGpu : public FunctionPass<LhloLegalizeToGpu> {
struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> {
void runOnFunction() override {
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
@ -186,7 +186,7 @@ struct LhloLegalizeToGpu : public FunctionPass<LhloLegalizeToGpu> {
} // namespace
std::unique_ptr<OpPassBase<FuncOp>> createLegalizeToGpuPass() {
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass() {
return absl::make_unique<LhloLegalizeToGpu>();
}


@ -452,7 +452,7 @@ class ReduceWindowOpConverter
};
struct LhloLegalizeToParallelLoops
: public FunctionPass<LhloLegalizeToParallelLoops> {
: public PassWrapper<LhloLegalizeToParallelLoops, FunctionPass> {
void runOnFunction() override {
auto func = getFunction();
@ -478,7 +478,7 @@ struct LhloLegalizeToParallelLoops
} // namespace
std::unique_ptr<OpPassBase<FuncOp>> createLegalizeLhloToParallelLoopsPass() {
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass() {
return absl::make_unique<LhloLegalizeToParallelLoops>();
}


@ -38,11 +38,12 @@ limitations under the License.
using mlir::FunctionPass;
using mlir::OwningRewritePatternList;
using mlir::PassRegistration;
using mlir::PassWrapper;
namespace {
class LowerComplex : public FunctionPass<LowerComplex> {
class LowerComplex : public PassWrapper<LowerComplex, FunctionPass> {
public:
explicit LowerComplex() : FunctionPass<LowerComplex>() {}
explicit LowerComplex() : PassWrapper<LowerComplex, FunctionPass>() {}
/// Performs the lowering to XLA dialect.
void runOnFunction() override;


@ -39,6 +39,7 @@ using mlir::MLIRContext;
using mlir::OpRewritePattern;
using mlir::OwningRewritePatternList;
using mlir::PassRegistration;
using mlir::PassWrapper;
using mlir::PatternRewriter;
using mlir::RankedTensorType;
using mlir::success;
@ -170,7 +171,8 @@ struct GeneralDotConvert
}
};
struct LegalizeGeneralDot : public FunctionPass<LegalizeGeneralDot> {
struct LegalizeGeneralDot
: public PassWrapper<LegalizeGeneralDot, FunctionPass> {
/// Lower all general dots that can be represented as a non-batched matmul.
void runOnFunction() override {
OwningRewritePatternList patterns;


@ -28,7 +28,7 @@ namespace xla_hlo {
namespace {
struct TestMaterializeBroadcastsPass
: public FunctionPass<TestMaterializeBroadcastsPass> {
: public PassWrapper<TestMaterializeBroadcastsPass, FunctionPass> {
void runOnFunction() override {
ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns;
