Bump OSS LLVM to 82576d6fecfec71725eb900111c000d772002449
PiperOrigin-RevId: 305444589
Change-Id: I407c717e431b9b2514a5eee5ad0f4b51717e2e9a
parent 7a4e9467cb
commit 7908b844de
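The diff below is the mechanical API migration this LLVM bump requires in the MLIR passes: pass classes now derive through the PassWrapper<DerivedT, BaseT> CRTP helper instead of deriving directly from FunctionPass<DerivedT> or OperationPass<DerivedT, ModuleOp>, and pass factory functions return std::unique_ptr<OperationPass<OpT>> instead of std::unique_ptr<OpPassBase<OpT>>. The sketch that follows only illustrates the before/after shape of that pattern; it is not code from this commit. MyFunctionPass, MyModulePass, and the Create* factories are illustrative names, and the sketch assumes the MLIR headers at this LLVM revision.

// Illustrative sketch only (not part of this commit): the pass-definition
// pattern before and after the PassWrapper migration, assuming the MLIR
// headers at LLVM revision 82576d6fecfec71725eb900111c000d772002449.
#include <memory>

#include "mlir/Pass/Pass.h"  // FunctionPass, OperationPass, PassWrapper.

namespace mlir {

// Old style (pre-bump):
//   struct MyFunctionPass : public FunctionPass<MyFunctionPass> { ... };
//   struct MyModulePass : public OperationPass<MyModulePass, ModuleOp> { ... };
//   std::unique_ptr<OpPassBase<FuncOp>> CreateMyFunctionPass();

// New style (this bump): derive through the PassWrapper CRTP helper and
// return OperationPass<OpT> from the factory functions.
struct MyFunctionPass : public PassWrapper<MyFunctionPass, FunctionPass> {
  void runOnFunction() override {
    // The pass body itself is unchanged by the migration.
  }
};

struct MyModulePass
    : public PassWrapper<MyModulePass, OperationPass<ModuleOp>> {
  void runOnOperation() override {}
};

std::unique_ptr<OperationPass<FuncOp>> CreateMyFunctionPass() {
  return std::make_unique<MyFunctionPass>();
}

std::unique_ptr<OperationPass<ModuleOp>> CreateMyModulePass() {
  return std::make_unique<MyModulePass>();
}

}  // namespace mlir

PassWrapper supplies the per-pass boilerplate (pass ID, clone) that the old CRTP bases provided directly, so the pass bodies themselves are untouched by the change below.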
@ -55,7 +55,8 @@ namespace quant {
|
|||||||
using QuantParamsEntry = QuantizationInfo::QuantParams;
|
using QuantParamsEntry = QuantizationInfo::QuantParams;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
|
class ImportQuantStatsPass
|
||||||
|
: public PassWrapper<ImportQuantStatsPass, FunctionPass> {
|
||||||
public:
|
public:
|
||||||
explicit ImportQuantStatsPass(OperationToName op_to_name)
|
explicit ImportQuantStatsPass(OperationToName op_to_name)
|
||||||
: op_to_name_(op_to_name) {}
|
: op_to_name_(op_to_name) {}
|
||||||
@ -193,7 +194,7 @@ void ImportQuantStatsPass::runOnFunction() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Creates an instance of the default quant parameters pass.
|
// Creates an instance of the default quant parameters pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
|
std::unique_ptr<OperationPass<FuncOp>> CreateImportQuantStatsPass(
|
||||||
OperationToName op_to_name, const std::string &stats_str) {
|
OperationToName op_to_name, const std::string &stats_str) {
|
||||||
auto pass = absl::make_unique<ImportQuantStatsPass>(op_to_name);
|
auto pass = absl::make_unique<ImportQuantStatsPass>(op_to_name);
|
||||||
if (pass->ParseQuantStats(stats_str)) return nullptr;
|
if (pass->ParseQuantStats(stats_str)) return nullptr;
|
||||||
@ -203,7 +204,7 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
|
|||||||
// Creates an instance pass to import quantization stats to the operations in
|
// Creates an instance pass to import quantization stats to the operations in
|
||||||
// the function. A custom method to get the name from the op is used because
|
// the function. A custom method to get the name from the op is used because
|
||||||
// different dialect ops might have different ways to assign the name.
|
// different dialect ops might have different ways to assign the name.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>>
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
|
CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
|
||||||
auto get_name_func = [](Operation *op) {
|
auto get_name_func = [](Operation *op) {
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
|
|||||||
@ -27,13 +27,13 @@ using OperationToName = std::function<llvm::StringRef(Operation* op)>;
|
|||||||
// Creates an instance pass to import quantization stats to the operations in
|
// Creates an instance pass to import quantization stats to the operations in
|
||||||
// the function. A custom method to get the name from the op is used because
|
// the function. A custom method to get the name from the op is used because
|
||||||
// different dialect ops might have different ways to assign the name.
|
// different dialect ops might have different ways to assign the name.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
|
std::unique_ptr<OperationPass<FuncOp>> CreateImportQuantStatsPass(
|
||||||
OperationToName op_to_name, const std::string& stats_str);
|
OperationToName op_to_name, const std::string& stats_str);
|
||||||
|
|
||||||
// Creates an instance pass to import quantization stats to the operations in
|
// Creates an instance pass to import quantization stats to the operations in
|
||||||
// the function. A custom method to get the name from the op is used because
|
// the function. A custom method to get the name from the op is used because
|
||||||
// different dialect ops might have different ways to assign the name.
|
// different dialect ops might have different ways to assign the name.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>>
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
CreateImportQuantStatsPassForTFControlDialect(const std::string& stats_str);
|
CreateImportQuantStatsPassForTFControlDialect(const std::string& stats_str);
|
||||||
|
|
||||||
} // namespace quant
|
} // namespace quant
|
||||||
|
|||||||
@ -25,7 +25,7 @@ namespace mlir {
|
|||||||
namespace TF {
|
namespace TF {
|
||||||
|
|
||||||
// Legalize the tf ops to the quant ops, so the quantization passes can work.
|
// Legalize the tf ops to the quant ops, so the quantization passes can work.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass();
|
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFToQuantPass();
|
||||||
|
|
||||||
} // namespace TF
|
} // namespace TF
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|||||||
@ -27,7 +27,7 @@ namespace TF {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Legalize TF quantization emulation ops to that in Quant ops dialect.
|
// Legalize TF quantization emulation ops to that in Quant ops dialect.
|
||||||
struct LegalizeTFToQuant : public FunctionPass<LegalizeTFToQuant> {
|
struct LegalizeTFToQuant : public PassWrapper<LegalizeTFToQuant, FunctionPass> {
|
||||||
explicit LegalizeTFToQuant() = default;
|
explicit LegalizeTFToQuant() = default;
|
||||||
LegalizeTFToQuant(const LegalizeTFToQuant &) {}
|
LegalizeTFToQuant(const LegalizeTFToQuant &) {}
|
||||||
|
|
||||||
@ -151,7 +151,7 @@ void LegalizeTFToQuant::runOnFunction() {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow dialect to QuantOps dialect pass.
|
// Creates an instance of the TensorFlow dialect to QuantOps dialect pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass() {
|
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFToQuantPass() {
|
||||||
return std::make_unique<LegalizeTFToQuant>();
|
return std::make_unique<LegalizeTFToQuant>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -315,7 +315,8 @@ llvm::SmallVector<Value, 0> FuseOps(PatternRewriter* rewriter,
|
|||||||
return new_values;
|
return new_values;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct CpuKernelFusionPass : public FunctionPass<CpuKernelFusionPass> {
|
struct CpuKernelFusionPass
|
||||||
|
: public PassWrapper<CpuKernelFusionPass, FunctionPass> {
|
||||||
explicit CpuKernelFusionPass() = default;
|
explicit CpuKernelFusionPass() = default;
|
||||||
CpuKernelFusionPass(const CpuKernelFusionPass&) {}
|
CpuKernelFusionPass(const CpuKernelFusionPass&) {}
|
||||||
|
|
||||||
@ -335,7 +336,7 @@ void CpuKernelFusionPass::runOnFunction() {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Creates an instance of the xla_hlo cpu kernel fusion pass.
|
// Creates an instance of the xla_hlo cpu kernel fusion pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateCpuKernelFusionPass() {
|
std::unique_ptr<OperationPass<FuncOp>> CreateCpuKernelFusionPass() {
|
||||||
return std::make_unique<CpuKernelFusionPass>();
|
return std::make_unique<CpuKernelFusionPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -141,7 +141,8 @@ class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Materialize the quantization results by hlo primitive ops.
|
// Materialize the quantization results by hlo primitive ops.
|
||||||
struct MaterializeToXlaPass : public FunctionPass<MaterializeToXlaPass> {
|
struct MaterializeToXlaPass
|
||||||
|
: public PassWrapper<MaterializeToXlaPass, FunctionPass> {
|
||||||
explicit MaterializeToXlaPass() = default;
|
explicit MaterializeToXlaPass() = default;
|
||||||
MaterializeToXlaPass(const MaterializeToXlaPass &) {}
|
MaterializeToXlaPass(const MaterializeToXlaPass &) {}
|
||||||
|
|
||||||
@ -162,7 +163,7 @@ void MaterializeToXlaPass::runOnFunction() {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Creates an instance of the xla_hlo dialect quantization propagation pass.
|
// Creates an instance of the xla_hlo dialect quantization propagation pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializeToXlaPass() {
|
std::unique_ptr<OperationPass<FuncOp>> CreateMaterializeToXlaPass() {
|
||||||
return std::make_unique<MaterializeToXlaPass>();
|
return std::make_unique<MaterializeToXlaPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -26,10 +26,10 @@ namespace xla_hlo {
|
|||||||
|
|
||||||
// Propagate the quantization information to all the tensors according to the
|
// Propagate the quantization information to all the tensors according to the
|
||||||
// op quant spec.
|
// op quant spec.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreatePropagateQuantPass();
|
std::unique_ptr<OperationPass<FuncOp>> CreatePropagateQuantPass();
|
||||||
|
|
||||||
// Rewrite the graph and quantize the constant.
|
// Rewrite the graph and quantize the constant.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializeToXlaPass();
|
std::unique_ptr<OperationPass<FuncOp>> CreateMaterializeToXlaPass();
|
||||||
|
|
||||||
} // namespace xla_hlo
|
} // namespace xla_hlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|||||||
@ -50,7 +50,8 @@ namespace {
|
|||||||
// - The quantization spec for the ops
|
// - The quantization spec for the ops
|
||||||
// The propagation results should assign quantization types to all the tensors
|
// The propagation results should assign quantization types to all the tensors
|
||||||
// and the two restrictions are respected.
|
// and the two restrictions are respected.
|
||||||
struct PropagateQuantPass : public FunctionPass<PropagateQuantPass> {
|
struct PropagateQuantPass
|
||||||
|
: public PassWrapper<PropagateQuantPass, FunctionPass> {
|
||||||
explicit PropagateQuantPass() = default;
|
explicit PropagateQuantPass() = default;
|
||||||
PropagateQuantPass(const PropagateQuantPass &) {}
|
PropagateQuantPass(const PropagateQuantPass &) {}
|
||||||
|
|
||||||
@ -96,7 +97,7 @@ void PropagateQuantPass::runOnFunction() {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Creates an instance of the xla_hlo dialect quantization propagation pass.
|
// Creates an instance of the xla_hlo dialect quantization propagation pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreatePropagateQuantPass() {
|
std::unique_ptr<OperationPass<FuncOp>> CreatePropagateQuantPass() {
|
||||||
return std::make_unique<PropagateQuantPass>();
|
return std::make_unique<PropagateQuantPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -29,7 +29,7 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
/// Create a pass to convert from the TFExecutor to the TF control dialect.
|
/// Create a pass to convert from the TFExecutor to the TF control dialect.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>>
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
CreateTFExecutorToControlDialectConversion();
|
CreateTFExecutorToControlDialectConversion();
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|||||||
@ -40,7 +40,7 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
/// Create a pass to convert from the TFExecutor to the TF control dialect.
|
/// Create a pass to convert from the TFExecutor to the TF control dialect.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>>
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
CreateTFExecutorToControlDialectConversion();
|
CreateTFExecutorToControlDialectConversion();
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|||||||
@ -44,7 +44,8 @@ namespace TFL {
|
|||||||
#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"
|
#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DefaultQuantParamsPass : public FunctionPass<DefaultQuantParamsPass> {
|
class DefaultQuantParamsPass
|
||||||
|
: public PassWrapper<DefaultQuantParamsPass, FunctionPass> {
|
||||||
public:
|
public:
|
||||||
explicit DefaultQuantParamsPass(double default_min, double default_max)
|
explicit DefaultQuantParamsPass(double default_min, double default_max)
|
||||||
: default_min_(default_min), default_max_(default_max) {}
|
: default_min_(default_min), default_max_(default_max) {}
|
||||||
@ -220,7 +221,7 @@ quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Creates an instance of the default quant parameters pass.
|
// Creates an instance of the default quant parameters pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
|
std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
|
||||||
double default_min, double default_max) {
|
double default_min, double default_max) {
|
||||||
return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max);
|
return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -29,7 +29,7 @@ namespace TFL {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct DenseToSparse : public FunctionPass<DenseToSparse> {
|
struct DenseToSparse : public PassWrapper<DenseToSparse, FunctionPass> {
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -63,7 +63,7 @@ void DenseToSparse::runOnFunction() {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect DenseToSparse pass.
|
// Creates an instance of the TensorFlow Lite dialect DenseToSparse pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateDenseToSparsePass() {
|
std::unique_ptr<OperationPass<FuncOp>> CreateDenseToSparsePass() {
|
||||||
return absl::make_unique<DenseToSparse>();
|
return absl::make_unique<DenseToSparse>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,8 @@ namespace mlir {
|
|||||||
namespace TFL {
|
namespace TFL {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct IdentifyDilatedConvPass : public FunctionPass<IdentifyDilatedConvPass> {
|
struct IdentifyDilatedConvPass
|
||||||
|
: public PassWrapper<IdentifyDilatedConvPass, FunctionPass> {
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -679,7 +679,8 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ExtractOphintPass : public OperationPass<ExtractOphintPass, ModuleOp> {
|
struct ExtractOphintPass
|
||||||
|
: public PassWrapper<ExtractOphintPass, OperationPass<ModuleOp>> {
|
||||||
void runOnOperation() override;
|
void runOnOperation() override;
|
||||||
void Verify();
|
void Verify();
|
||||||
|
|
||||||
@ -752,7 +753,7 @@ void ExtractOphintPass::Verify() {
|
|||||||
|
|
||||||
/// Creates an instance of the TensorFlow Lite dialect ExtractOphintPass
|
/// Creates an instance of the TensorFlow Lite dialect ExtractOphintPass
|
||||||
/// pass.
|
/// pass.
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateExtractOphintPass() {
|
std::unique_ptr<OperationPass<ModuleOp>> CreateExtractOphintPass() {
|
||||||
return std::make_unique<ExtractOphintPass>();
|
return std::make_unique<ExtractOphintPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -69,7 +69,7 @@ constexpr char kUnidirectionalSequenceLstm[] = "UnidirectionalSequenceLstm";
|
|||||||
// |
|
// |
|
||||||
// OutputOp1
|
// OutputOp1
|
||||||
struct LegalizeOphintFuncOpPass
|
struct LegalizeOphintFuncOpPass
|
||||||
: public OperationPass<LegalizeOphintFuncOpPass, ModuleOp> {
|
: public PassWrapper<LegalizeOphintFuncOpPass, OperationPass<ModuleOp>> {
|
||||||
void runOnOperation() override;
|
void runOnOperation() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -284,7 +284,7 @@ void LegalizeOphintFuncOpPass::runOnOperation() {
|
|||||||
|
|
||||||
/// Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass
|
/// Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass
|
||||||
/// pass.
|
/// pass.
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeOphintFuncOpPass() {
|
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeOphintFuncOpPass() {
|
||||||
return std::make_unique<LegalizeOphintFuncOpPass>();
|
return std::make_unique<LegalizeOphintFuncOpPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -70,7 +70,7 @@ constexpr char kUnidirectionalSequenceRnn[] = "tf.UnidirectionalSequenceRnn";
|
|||||||
constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
|
constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
|
||||||
|
|
||||||
// Legalize operations in functions.
|
// Legalize operations in functions.
|
||||||
struct LegalizeTF : public FunctionPass<LegalizeTF> {
|
struct LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -763,7 +763,7 @@ void LegalizeTF::runOnFunction() {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
|
// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFPass() {
|
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass() {
|
||||||
return std::make_unique<LegalizeTF>();
|
return std::make_unique<LegalizeTF>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -31,7 +31,8 @@ namespace {
|
|||||||
|
|
||||||
// Legalize TF While to TFL While with calls to the original functions from the
|
// Legalize TF While to TFL While with calls to the original functions from the
|
||||||
// cond and body regions.
|
// cond and body regions.
|
||||||
struct LegalizeWhile : public OperationPass<LegalizeWhile, ModuleOp> {
|
struct LegalizeWhile
|
||||||
|
: public PassWrapper<LegalizeWhile, OperationPass<ModuleOp>> {
|
||||||
void RunOnFunction(FuncOp func);
|
void RunOnFunction(FuncOp func);
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
@ -76,7 +77,7 @@ void LegalizeWhile::RunOnFunction(FuncOp func) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow While to TFLite While pass.
|
// Creates an instance of the TensorFlow While to TFLite While pass.
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeTFWhilePass() {
|
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFWhilePass() {
|
||||||
return std::make_unique<LegalizeWhile>();
|
return std::make_unique<LegalizeWhile>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -42,7 +42,8 @@ namespace {
|
|||||||
// AnyQuantizedType, thus bitwidth, narrow_range, etc are included. The op also
|
// AnyQuantizedType, thus bitwidth, narrow_range, etc are included. The op also
|
||||||
// defines the op quantization traits, which are used to propagate the
|
// defines the op quantization traits, which are used to propagate the
|
||||||
// quantization parameters by the following passes.
|
// quantization parameters by the following passes.
|
||||||
struct LoadQuantizationRecipe : public FunctionPass<LoadQuantizationRecipe> {
|
struct LoadQuantizationRecipe
|
||||||
|
: public PassWrapper<LoadQuantizationRecipe, FunctionPass> {
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -215,7 +216,7 @@ void LoadQuantizationRecipe::runOnFunction() {
|
|||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect LoadQuantizationRecipe
|
// Creates an instance of the TensorFlow Lite dialect LoadQuantizationRecipe
|
||||||
// pass.
|
// pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateLoadQuantizationRecipePass() {
|
std::unique_ptr<OperationPass<FuncOp>> CreateLoadQuantizationRecipePass() {
|
||||||
return absl::make_unique<LoadQuantizationRecipe>();
|
return absl::make_unique<LoadQuantizationRecipe>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -82,7 +82,7 @@ class TensorListPatternRewriter : public PatternRewriter {
|
|||||||
|
|
||||||
/// Lower TensorList ops in functions for subsequent legalization.
|
/// Lower TensorList ops in functions for subsequent legalization.
|
||||||
struct LowerStaticTensorListPass
|
struct LowerStaticTensorListPass
|
||||||
: public OperationPass<LowerStaticTensorListPass, ModuleOp> {
|
: public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
|
||||||
void runOnOperation() override;
|
void runOnOperation() override;
|
||||||
|
|
||||||
// Apply type and op changes within a function.
|
// Apply type and op changes within a function.
|
||||||
@ -906,7 +906,8 @@ void LowerStaticTensorListPass::runOnOperation() {
|
|||||||
|
|
||||||
/// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
|
/// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
|
||||||
/// pass.
|
/// pass.
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> TFL::CreateLowerStaticTensorListPass() {
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
TFL::CreateLowerStaticTensorListPass() {
|
||||||
return std::make_unique<LowerStaticTensorListPass>();
|
return std::make_unique<LowerStaticTensorListPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -74,7 +74,7 @@ bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
|
|||||||
using ::llvm::cast;
|
using ::llvm::cast;
|
||||||
|
|
||||||
// Optimize TFLite operations in functions.
|
// Optimize TFLite operations in functions.
|
||||||
struct Optimize : public FunctionPass<Optimize> {
|
struct Optimize : public PassWrapper<Optimize, FunctionPass> {
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -725,7 +725,7 @@ void Optimize::runOnFunction() {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect Optimize pass.
|
// Creates an instance of the TensorFlow Lite dialect Optimize pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateOptimizePass() {
|
std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass() {
|
||||||
return std::make_unique<Optimize>();
|
return std::make_unique<Optimize>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -36,7 +36,7 @@ using FuncSet = llvm::SmallSet<FuncOp, 4>;
|
|||||||
|
|
||||||
// Module pass to optimize TensorFlow functional ops.
|
// Module pass to optimize TensorFlow functional ops.
|
||||||
struct OptimizeFunctionalOpsPass
|
struct OptimizeFunctionalOpsPass
|
||||||
: public OperationPass<OptimizeFunctionalOpsPass, ModuleOp> {
|
: public PassWrapper<OptimizeFunctionalOpsPass, OperationPass<ModuleOp>> {
|
||||||
void runOnOperation() override;
|
void runOnOperation() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -198,7 +198,7 @@ void OptimizeFunctionalOpsPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateOptimizeFunctionalOpsPass() {
|
std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass() {
|
||||||
return std::make_unique<OptimizeFunctionalOpsPass>();
|
return std::make_unique<OptimizeFunctionalOpsPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -24,75 +24,75 @@ namespace mlir {
|
|||||||
class FuncOp;
|
class FuncOp;
|
||||||
class ModuleOp;
|
class ModuleOp;
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class OpPassBase;
|
class OperationPass;
|
||||||
|
|
||||||
namespace TFL {
|
namespace TFL {
|
||||||
class QuantizationSpecs;
|
class QuantizationSpecs;
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
|
// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFPass();
|
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass();
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect Optimize pass.
|
// Creates an instance of the TensorFlow Lite dialect Optimize pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateOptimizePass();
|
std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass();
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
|
// Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareTFPass(
|
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
|
||||||
bool unfold_batch_matmul);
|
bool unfold_batch_matmul);
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
|
// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
|
||||||
// pass.
|
// pass.
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateLowerStaticTensorListPass();
|
std::unique_ptr<OperationPass<ModuleOp>> CreateLowerStaticTensorListPass();
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect Quantize pass.
|
// Creates an instance of the TensorFlow Lite dialect Quantize pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateQuantizePass();
|
std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass();
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
|
// Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareQuantizePass(
|
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareQuantizePass(
|
||||||
const QuantizationSpecs& quant_specs);
|
const QuantizationSpecs& quant_specs);
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
|
// Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreatePostQuantizePass(
|
std::unique_ptr<OperationPass<FuncOp>> CreatePostQuantizePass(
|
||||||
bool emit_quant_adaptor_ops);
|
bool emit_quant_adaptor_ops);
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect TrimFunctions
|
// Creates an instance of the TensorFlow Lite dialect TrimFunctions
|
||||||
// pass.
|
// pass.
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateTrimFunctionsPass(
|
std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
|
||||||
llvm::ArrayRef<std::string> trim_funcs_whitelist);
|
llvm::ArrayRef<std::string> trim_funcs_whitelist);
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect PrepareCompositeFunctions
|
// Creates an instance of the TensorFlow Lite dialect PrepareCompositeFunctions
|
||||||
// pass.
|
// pass.
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreatePrepareCompositeFunctionsPass();
|
std::unique_ptr<OperationPass<ModuleOp>> CreatePrepareCompositeFunctionsPass();
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect ExtractOphint pass.
|
// Creates an instance of the TensorFlow Lite dialect ExtractOphint pass.
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateExtractOphintPass();
|
std::unique_ptr<OperationPass<ModuleOp>> CreateExtractOphintPass();
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass
|
// Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass
|
||||||
// pass. The composite op is created from the ophint extraction pass.
|
// pass. The composite op is created from the ophint extraction pass.
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeOphintFuncOpPass();
|
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeOphintFuncOpPass();
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect SplitMergedOperandsPass.
|
// Creates an instance of the TensorFlow Lite dialect SplitMergedOperandsPass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateSplitMergedOperandsPass();
|
std::unique_ptr<OperationPass<FuncOp>> CreateSplitMergedOperandsPass();
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect OptimizeFunctionalOpsPass.
|
// Creates an instance of the TensorFlow Lite dialect OptimizeFunctionalOpsPass.
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateOptimizeFunctionalOpsPass();
|
std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass();
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect pass to add default
|
// Creates an instance of the TensorFlow Lite dialect pass to add default
|
||||||
// quantization parameters.
|
// quantization parameters.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
|
std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
|
||||||
double default_min, double default_max);
|
double default_min, double default_max);
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect pass to convert dense
|
// Creates an instance of the TensorFlow Lite dialect pass to convert dense
|
||||||
// tensor to sparse format.
|
// tensor to sparse format.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateDenseToSparsePass();
|
std::unique_ptr<OperationPass<FuncOp>> CreateDenseToSparsePass();
|
||||||
|
|
||||||
// Creates function pass to legalize TF While to TFL While.
|
// Creates function pass to legalize TF While to TFL While.
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeTFWhilePass();
|
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFWhilePass();
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
|
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateWhileOutlinePass();
|
std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
|
||||||
|
|
||||||
// Verifies runtime supports types used.
|
// Verifies runtime supports types used.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateRuntimeTypeVerifyPass();
|
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeTypeVerifyPass();
|
||||||
|
|
||||||
} // namespace TFL
|
} // namespace TFL
|
||||||
|
|
||||||
|
|||||||
@ -30,7 +30,7 @@ namespace TFL {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Applies all the clean up steps after quantization.
|
// Applies all the clean up steps after quantization.
|
||||||
class PostQuantizePass : public FunctionPass<PostQuantizePass> {
|
class PostQuantizePass : public PassWrapper<PostQuantizePass, FunctionPass> {
|
||||||
public:
|
public:
|
||||||
// Constructor used by the PassRegistration. This will remove the adaptor ops.
|
// Constructor used by the PassRegistration. This will remove the adaptor ops.
|
||||||
explicit PostQuantizePass() : emit_quant_adaptor_ops_(false) {}
|
explicit PostQuantizePass() : emit_quant_adaptor_ops_(false) {}
|
||||||
@ -135,7 +135,7 @@ void PostQuantizePass::runOnFunction() {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
|
// Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreatePostQuantizePass(
|
std::unique_ptr<OperationPass<FuncOp>> CreatePostQuantizePass(
|
||||||
bool emit_quant_adaptor_ops) {
|
bool emit_quant_adaptor_ops) {
|
||||||
return std::make_unique<PostQuantizePass>(emit_quant_adaptor_ops);
|
return std::make_unique<PostQuantizePass>(emit_quant_adaptor_ops);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -94,7 +94,8 @@ class ConvertEmbeddedLookupFunc {
|
|||||||
// body with the corresponding fused TFLite op. The replacement need not always
|
// body with the corresponding fused TFLite op. The replacement need not always
|
||||||
// be a fused op, though that is the primary use case.
|
// be a fused op, though that is the primary use case.
|
||||||
class PrepareCompositeFunctionsPass
|
class PrepareCompositeFunctionsPass
|
||||||
: public OperationPass<PrepareCompositeFunctionsPass, ModuleOp> {
|
: public PassWrapper<PrepareCompositeFunctionsPass,
|
||||||
|
OperationPass<ModuleOp>> {
|
||||||
public:
|
public:
|
||||||
explicit PrepareCompositeFunctionsPass() {}
|
explicit PrepareCompositeFunctionsPass() {}
|
||||||
|
|
||||||
@ -211,7 +212,7 @@ void PrepareCompositeFunctionsPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreatePrepareCompositeFunctionsPass() {
|
std::unique_ptr<OperationPass<ModuleOp>> CreatePrepareCompositeFunctionsPass() {
|
||||||
return std::make_unique<PrepareCompositeFunctionsPass>();
|
return std::make_unique<PrepareCompositeFunctionsPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -66,7 +66,8 @@ namespace {
|
|||||||
// across ops. This step is necessary for post-training quantization and also
|
// across ops. This step is necessary for post-training quantization and also
|
||||||
// making the quantization rule for some operations in the quantization-aware
|
// making the quantization rule for some operations in the quantization-aware
|
||||||
// training quantization simpler.
|
// training quantization simpler.
|
||||||
class PrepareQuantizePass : public FunctionPass<PrepareQuantizePass> {
|
class PrepareQuantizePass
|
||||||
|
: public PassWrapper<PrepareQuantizePass, FunctionPass> {
|
||||||
public:
|
public:
|
||||||
// Constructor used by the PassRegistration and enforce uint8 quantization.
|
// Constructor used by the PassRegistration and enforce uint8 quantization.
|
||||||
explicit PrepareQuantizePass() {
|
explicit PrepareQuantizePass() {
|
||||||
@ -281,7 +282,7 @@ void PrepareQuantizePass::runOnFunction() {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
|
// Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareQuantizePass(
|
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareQuantizePass(
|
||||||
const QuantizationSpecs& quant_specs) {
|
const QuantizationSpecs& quant_specs) {
|
||||||
return std::make_unique<PrepareQuantizePass>(quant_specs);
|
return std::make_unique<PrepareQuantizePass>(quant_specs);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -71,7 +71,7 @@ namespace TFL {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Prepare TF operations in functions for subsequent legalization.
|
// Prepare TF operations in functions for subsequent legalization.
|
||||||
class PrepareTFPass : public FunctionPass<PrepareTFPass> {
|
class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
|
||||||
public:
|
public:
|
||||||
explicit PrepareTFPass() : unfold_batch_matmul_(true) {}
|
explicit PrepareTFPass() : unfold_batch_matmul_(true) {}
|
||||||
explicit PrepareTFPass(bool unfold_batch_matmul)
|
explicit PrepareTFPass(bool unfold_batch_matmul)
|
||||||
@ -652,7 +652,7 @@ void PrepareTFPass::runOnFunction() {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
|
// Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareTFPass(
|
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
|
||||||
bool unfold_batch_matmul) {
|
bool unfold_batch_matmul) {
|
||||||
return std::make_unique<PrepareTFPass>(unfold_batch_matmul);
|
return std::make_unique<PrepareTFPass>(unfold_batch_matmul);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -75,7 +75,7 @@ struct TFLFullQuantization
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Applies quantization on the model in TFL dialect.
|
// Applies quantization on the model in TFL dialect.
|
||||||
struct QuantizePass : public FunctionPass<QuantizePass> {
|
struct QuantizePass : public PassWrapper<QuantizePass, FunctionPass> {
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -93,7 +93,7 @@ void QuantizePass::runOnFunction() {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass.
|
// Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateQuantizePass() {
|
std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass() {
|
||||||
return std::make_unique<QuantizePass>();
|
return std::make_unique<QuantizePass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -24,7 +24,8 @@ namespace {
|
|||||||
|
|
||||||
// This pass verifies that the operands and results types are supported by
|
// This pass verifies that the operands and results types are supported by
|
||||||
// TFLite runtime.
|
// TFLite runtime.
|
||||||
class RuntimeTypeVerifyPass : public mlir::FunctionPass<RuntimeTypeVerifyPass> {
|
class RuntimeTypeVerifyPass
|
||||||
|
: public mlir::PassWrapper<RuntimeTypeVerifyPass, FunctionPass> {
|
||||||
public:
|
public:
|
||||||
explicit RuntimeTypeVerifyPass() {}
|
explicit RuntimeTypeVerifyPass() {}
|
||||||
|
|
||||||
@ -43,7 +44,7 @@ void RuntimeTypeVerifyPass::runOnFunction() {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Verifies runtime supports types used.
|
// Verifies runtime supports types used.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateRuntimeTypeVerifyPass() {
|
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeTypeVerifyPass() {
|
||||||
return std::make_unique<RuntimeTypeVerifyPass>();
|
return std::make_unique<RuntimeTypeVerifyPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -66,7 +66,8 @@ namespace mlir {
|
|||||||
namespace TFL {
|
namespace TFL {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct SplitMergedOperandsPass : public FunctionPass<SplitMergedOperandsPass> {
|
struct SplitMergedOperandsPass
|
||||||
|
: public PassWrapper<SplitMergedOperandsPass, FunctionPass> {
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -119,7 +120,7 @@ void SplitMergedOperandsPass::runOnFunction() {
|
|||||||
|
|
||||||
/// Creates an instance of the TensorFlow Lite dialect SplitMergedOperands
|
/// Creates an instance of the TensorFlow Lite dialect SplitMergedOperands
|
||||||
/// pass.
|
/// pass.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateSplitMergedOperandsPass() {
|
std::unique_ptr<OperationPass<FuncOp>> CreateSplitMergedOperandsPass() {
|
||||||
return std::make_unique<SplitMergedOperandsPass>();
|
return std::make_unique<SplitMergedOperandsPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -45,7 +45,7 @@ namespace {
|
|||||||
// The pass to trim functions before we legalize to TFL
|
// The pass to trim functions before we legalize to TFL
|
||||||
// dialect using the specified whitelist.
|
// dialect using the specified whitelist.
|
||||||
class TrimFunctionsPass
|
class TrimFunctionsPass
|
||||||
: public mlir::OperationPass<TrimFunctionsPass, ModuleOp> {
|
: public mlir::PassWrapper<TrimFunctionsPass, OperationPass<ModuleOp>> {
|
||||||
public:
|
public:
|
||||||
explicit TrimFunctionsPass() : trim_funcs_whitelist_(trim_funcs_whitelist) {}
|
explicit TrimFunctionsPass() : trim_funcs_whitelist_(trim_funcs_whitelist) {}
|
||||||
explicit TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_whitelist)
|
explicit TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_whitelist)
|
||||||
@ -120,7 +120,7 @@ void TrimFunctionsPass::Verify() {
|
|||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect TrimFunctions
|
// Creates an instance of the TensorFlow Lite dialect TrimFunctions
|
||||||
/// pass.
|
/// pass.
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateTrimFunctionsPass(
|
std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
|
||||||
llvm::ArrayRef<std::string> trim_funcs_whitelist) {
|
llvm::ArrayRef<std::string> trim_funcs_whitelist) {
|
||||||
return std::make_unique<TrimFunctionsPass>(trim_funcs_whitelist);
|
return std::make_unique<TrimFunctionsPass>(trim_funcs_whitelist);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -38,7 +38,7 @@ namespace {
|
|||||||
// This pass outlines the cond/body region of the TFL WhileOp into functions and
|
// This pass outlines the cond/body region of the TFL WhileOp into functions and
|
||||||
// replaces the regions with calls to these outlined functions.
|
// replaces the regions with calls to these outlined functions.
|
||||||
class WhileOutlinePass
|
class WhileOutlinePass
|
||||||
: public mlir::OperationPass<WhileOutlinePass, ModuleOp> {
|
: public mlir::PassWrapper<WhileOutlinePass, OperationPass<ModuleOp>> {
|
||||||
public:
|
public:
|
||||||
explicit WhileOutlinePass() {}
|
explicit WhileOutlinePass() {}
|
||||||
|
|
||||||
@ -241,7 +241,7 @@ void WhileOutlinePass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
|
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateWhileOutlinePass() {
|
std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass() {
|
||||||
return std::make_unique<WhileOutlinePass>();
|
return std::make_unique<WhileOutlinePass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -39,7 +39,8 @@ constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";
|
|||||||
// Analyzes the inputs to LaunchFuncOps in the module, and annotates their
|
// Analyzes the inputs to LaunchFuncOps in the module, and annotates their
|
||||||
// invoked functions whether each input has the same data across replicas.
|
// invoked functions whether each input has the same data across replicas.
|
||||||
struct AnnotateParameterReplication
|
struct AnnotateParameterReplication
|
||||||
: public OperationPass<AnnotateParameterReplication, ModuleOp> {
|
: public PassWrapper<AnnotateParameterReplication,
|
||||||
|
OperationPass<ModuleOp>> {
|
||||||
void runOnOperation() override;
|
void runOnOperation() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -90,7 +91,8 @@ void AnnotateParameterReplication::runOnOperation() {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateAnnotateParameterReplicationPass() {
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
CreateAnnotateParameterReplicationPass() {
|
||||||
return std::make_unique<AnnotateParameterReplication>();
|
return std::make_unique<AnnotateParameterReplication>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -43,7 +43,8 @@ namespace TF {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Replace TF BatchMatMul by TF Einsum
|
// Replace TF BatchMatMul by TF Einsum
|
||||||
struct BatchMatMulToEinsumPass : public FunctionPass<BatchMatMulToEinsumPass> {
|
struct BatchMatMulToEinsumPass
|
||||||
|
: public PassWrapper<BatchMatMulToEinsumPass, FunctionPass> {
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -117,7 +118,7 @@ static PassRegistration<BatchMatMulToEinsumPass> pass(
|
|||||||
"tf-batch-matmul-to-tf-einsum",
|
"tf-batch-matmul-to-tf-einsum",
|
||||||
"Replace TF BatchMatMul op by TF Einsum op.");
|
"Replace TF BatchMatMul op by TF Einsum op.");
|
||||||
|
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateBatchMatMulToEinsumPass() {
|
std::unique_ptr<OperationPass<FuncOp>> CreateBatchMatMulToEinsumPass() {
|
||||||
return std::make_unique<BatchMatMulToEinsumPass>();
|
return std::make_unique<BatchMatMulToEinsumPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -37,7 +37,8 @@ namespace TFDevice {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct ClusterFormationPass : public FunctionPass<ClusterFormationPass> {
|
struct ClusterFormationPass
|
||||||
|
: public PassWrapper<ClusterFormationPass, FunctionPass> {
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -229,7 +230,7 @@ void ClusterFormationPass::runOnFunction() {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateClusterFormationPass() {
|
std::unique_ptr<OperationPass<FuncOp>> CreateClusterFormationPass() {
|
||||||
return std::make_unique<ClusterFormationPass>();
|
return std::make_unique<ClusterFormationPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -39,7 +39,7 @@ constexpr char kDeviceAttr[] = "device";
|
|||||||
constexpr char kFuncAttr[] = "func";
|
constexpr char kFuncAttr[] = "func";
|
||||||
|
|
||||||
struct ClusterOutliningPass
|
struct ClusterOutliningPass
|
||||||
: public OperationPass<ClusterOutliningPass, ModuleOp> {
|
: public PassWrapper<ClusterOutliningPass, OperationPass<ModuleOp>> {
|
||||||
void runOnOperation() override;
|
void runOnOperation() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -132,7 +132,7 @@ void ClusterOutliningPass::runOnOperation() {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateClusterOutliningPass() {
|
std::unique_ptr<OperationPass<ModuleOp>> CreateClusterOutliningPass() {
|
||||||
return std::make_unique<ClusterOutliningPass>();
|
return std::make_unique<ClusterOutliningPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -52,7 +52,7 @@ bool DecodeOpaqueValueInConstantOp(Operation *op) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// A pass to decode opaque constant values into readable ones.
|
// A pass to decode opaque constant values into readable ones.
|
||||||
struct DecodeConstant : public FunctionPass<DecodeConstant> {
|
struct DecodeConstant : public PassWrapper<DecodeConstant, FunctionPass> {
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
auto walk_result = getFunction().walk([](Operation *op) {
|
auto walk_result = getFunction().walk([](Operation *op) {
|
||||||
return DecodeOpaqueValueInConstantOp(op) ? WalkResult::advance()
|
return DecodeOpaqueValueInConstantOp(op) ? WalkResult::advance()
|
||||||
@ -64,7 +64,7 @@ struct DecodeConstant : public FunctionPass<DecodeConstant> {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateDecodeConstantPass() {
|
std::unique_ptr<OperationPass<FuncOp>> CreateDecodeConstantPass() {
|
||||||
return std::make_unique<DecodeConstant>();
|
return std::make_unique<DecodeConstant>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -23,7 +23,7 @@ namespace TF {
|
|||||||
// Creates a pass to decode and reset opaque values in constant ops into
|
// Creates a pass to decode and reset opaque values in constant ops into
|
||||||
// readable values.
|
// readable values.
|
||||||
// Note that this pass assumes RaiseTFControlFlow pass has already been run.
|
// Note that this pass assumes RaiseTFControlFlow pass has already been run.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateDecodeConstantPass();
|
std::unique_ptr<OperationPass<FuncOp>> CreateDecodeConstantPass();
|
||||||
} // namespace TF
|
} // namespace TF
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|||||||
@ -38,7 +38,8 @@ namespace {
|
|||||||
// NOTE: This pass does not support `use_locking=true` for a lot of resource
|
// NOTE: This pass does not support `use_locking=true` for a lot of resource
|
||||||
// operations. So decomposition may not be correct outside of backends like XLA,
|
// operations. So decomposition may not be correct outside of backends like XLA,
|
||||||
// which automatically locks all resource variables.
|
// which automatically locks all resource variables.
|
||||||
struct DecomposeResourceOps : public FunctionPass<DecomposeResourceOps> {
|
struct DecomposeResourceOps
|
||||||
|
: public PassWrapper<DecomposeResourceOps, FunctionPass> {
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
// Add lowering patterns to the list.
|
// Add lowering patterns to the list.
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
@ -50,7 +51,7 @@ struct DecomposeResourceOps : public FunctionPass<DecomposeResourceOps> {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateDecomposeResourceOpsPass() {
|
std::unique_ptr<OperationPass<FuncOp>> CreateDecomposeResourceOpsPass() {
|
||||||
return std::make_unique<DecomposeResourceOps>();
|
return std::make_unique<DecomposeResourceOps>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -354,7 +354,8 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Transform Einsum to other TF Ops for the supported variants.
|
// Transform Einsum to other TF Ops for the supported variants.
|
||||||
struct TransformEinsumPass : public FunctionPass<TransformEinsumPass> {
|
struct TransformEinsumPass
|
||||||
|
: public PassWrapper<TransformEinsumPass, FunctionPass> {
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -57,7 +57,7 @@ struct IslandResult {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct ExecutorIslandCoarsening
|
struct ExecutorIslandCoarsening
|
||||||
: public FunctionPass<ExecutorIslandCoarsening> {
|
: public PassWrapper<ExecutorIslandCoarsening, FunctionPass> {
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -346,7 +346,7 @@ void ExecutorIslandCoarsening::runOnFunction() {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorIslandCoarseningPass() {
|
std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorIslandCoarseningPass() {
|
||||||
return std::make_unique<ExecutorIslandCoarsening>();
|
return std::make_unique<ExecutorIslandCoarsening>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -43,7 +43,8 @@ constexpr llvm::StringRef kNestedModule = "_tpu_v1_compat_outlined";
|
|||||||
// Inlining the islands calling into the nested module that was outlined.
|
// Inlining the islands calling into the nested module that was outlined.
|
||||||
// This is the end of the TPU bridge in V1 compatibility mode.
|
// This is the end of the TPU bridge in V1 compatibility mode.
|
||||||
struct TPUBridgeExecutorIslandInlining
|
struct TPUBridgeExecutorIslandInlining
|
||||||
: public OperationPass<TPUBridgeExecutorIslandInlining, ModuleOp> {
|
: public PassWrapper<TPUBridgeExecutorIslandInlining,
|
||||||
|
OperationPass<ModuleOp>> {
|
||||||
void runOnOperation() override;
|
void runOnOperation() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -95,7 +96,7 @@ PassRegistration<TPUBridgeExecutorIslandInlining> tpu_pass(
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
CreateTFExecutorTPUV1IslandInliningPass() {
|
CreateTFExecutorTPUV1IslandInliningPass() {
|
||||||
return std::make_unique<TPUBridgeExecutorIslandInlining>();
|
return std::make_unique<TPUBridgeExecutorIslandInlining>();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -59,7 +59,8 @@ constexpr llvm::StringRef kTpuStatusAttr = "_tpu_compilation_status";
|
|||||||
// TPU-annotated operations and intended to preserve backward compatibility with
|
// TPU-annotated operations and intended to preserve backward compatibility with
|
||||||
// TFv1.
|
// TFv1.
|
||||||
struct TpuV1BridgeExecutorIslandCoarsening
|
struct TpuV1BridgeExecutorIslandCoarsening
|
||||||
: public OperationPass<TpuV1BridgeExecutorIslandCoarsening, ModuleOp> {
|
: public PassWrapper<TpuV1BridgeExecutorIslandCoarsening,
|
||||||
|
OperationPass<ModuleOp>> {
|
||||||
void runOnOperation() override;
|
void runOnOperation() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -322,7 +323,7 @@ void TpuV1BridgeExecutorIslandCoarsening::runOnOperation() {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
CreateTFExecutorTPUV1IslandCoarseningPass() {
|
CreateTFExecutorTPUV1IslandCoarseningPass() {
|
||||||
return std::make_unique<TpuV1BridgeExecutorIslandCoarsening>();
|
return std::make_unique<TpuV1BridgeExecutorIslandCoarsening>();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -44,7 +44,8 @@ constexpr llvm::StringRef kOutlinedFuncPrefix = "_tpu_v1_compat_outlined_func";
|
|||||||
// This is only intended for V1 compatibility mode where the bridge runs without
|
// This is only intended for V1 compatibility mode where the bridge runs without
|
||||||
// feed/fetches on session create/extend.
|
// feed/fetches on session create/extend.
|
||||||
struct TPUBridgeExecutorIslandOutlining
|
struct TPUBridgeExecutorIslandOutlining
|
||||||
: public OperationPass<TPUBridgeExecutorIslandOutlining, ModuleOp> {
|
: public PassWrapper<TPUBridgeExecutorIslandOutlining,
|
||||||
|
OperationPass<ModuleOp>> {
|
||||||
void runOnOperation() override;
|
void runOnOperation() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -160,7 +161,7 @@ PassRegistration<TPUBridgeExecutorIslandOutlining> tpu_pass(
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
CreateTFExecutorTPUV1IslandOutliningPass() {
|
CreateTFExecutorTPUV1IslandOutliningPass() {
|
||||||
return std::make_unique<TPUBridgeExecutorIslandOutlining>();
|
return std::make_unique<TPUBridgeExecutorIslandOutlining>();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -58,7 +58,7 @@ limitations under the License.
|
|||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
class SwitchFoldPass : public mlir::FunctionPass<SwitchFoldPass> {
|
class SwitchFoldPass : public mlir::PassWrapper<SwitchFoldPass, FunctionPass> {
|
||||||
public:
|
public:
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
@ -279,7 +279,7 @@ void SwitchFoldPass::runOnFunction() {
|
|||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
namespace tf_executor {
|
namespace tf_executor {
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateSwitchFoldPass() {
|
std::unique_ptr<OperationPass<FuncOp>> CreateSwitchFoldPass() {
|
||||||
return std::make_unique<SwitchFoldPass>();
|
return std::make_unique<SwitchFoldPass>();
|
||||||
}
|
}
|
||||||
} // namespace tf_executor
|
} // namespace tf_executor
|
||||||
|
|||||||
@ -42,7 +42,7 @@ namespace {
|
|||||||
// support resources/variables . Further, this contract also ensures that this
|
// support resources/variables . Further, this contract also ensures that this
|
||||||
// pass lowers from saved model to pure TF. Hence it fails, if it cannot lower.
|
// pass lowers from saved model to pure TF. Hence it fails, if it cannot lower.
|
||||||
struct FreezeGlobalTensorsPass
|
struct FreezeGlobalTensorsPass
|
||||||
: public OperationPass<FreezeGlobalTensorsPass, ModuleOp> {
|
: public PassWrapper<FreezeGlobalTensorsPass, OperationPass<ModuleOp>> {
|
||||||
void runOnOperation() override;
|
void runOnOperation() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -113,7 +113,7 @@ static PassRegistration<FreezeGlobalTensorsPass> pass(
|
|||||||
"tf-saved-model-freeze-global-tensors",
|
"tf-saved-model-freeze-global-tensors",
|
||||||
"Freeze tf_saved_model.global_tensor's in func bodies.");
|
"Freeze tf_saved_model.global_tensor's in func bodies.");
|
||||||
|
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateFreezeGlobalTensorsPass() {
|
std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeGlobalTensorsPass() {
|
||||||
return std::make_unique<FreezeGlobalTensorsPass>();
|
return std::make_unique<FreezeGlobalTensorsPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -34,7 +34,7 @@ namespace TF {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct FunctionalControlFlowToCFG
|
struct FunctionalControlFlowToCFG
|
||||||
: public FunctionPass<FunctionalControlFlowToCFG> {
|
: public PassWrapper<FunctionalControlFlowToCFG, FunctionPass> {
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -312,7 +312,7 @@ void FunctionalControlFlowToCFG::runOnFunction() {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateTFFunctionalControlFlowToCFG() {
|
std::unique_ptr<OperationPass<FuncOp>> CreateTFFunctionalControlFlowToCFG() {
|
||||||
return std::make_unique<FunctionalControlFlowToCFG>();
|
return std::make_unique<FunctionalControlFlowToCFG>();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,7 +35,7 @@ namespace {
 // GpuOpFusionPass is a pass performing fusion specific to GPU targets.
 // This is an ad-hoc pass for now, but should be integrated with some notion
 // of "target" in the MLIR pipeline in the future.
-class GpuOpFusionPass : public FunctionPass<GpuOpFusionPass> {
+class GpuOpFusionPass : public PassWrapper<GpuOpFusionPass, FunctionPass> {
  public:
   void runOnFunction() final;
 };
@@ -123,7 +123,7 @@ void GpuOpFusionPass::runOnFunction() {
 
 } // namespace
 
-std::unique_ptr<OpPassBase<FuncOp>> CreateGpuOpFusionPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateGpuOpFusionPass() {
   return std::make_unique<GpuOpFusionPass>();
 }
 
@@ -84,7 +84,7 @@ void PruneGraph(GraphOp graph) {
 namespace {
 
 // This transformation pass prunes a TF graph eliminating dead-nodes.
-struct GraphPruning : public FunctionPass<GraphPruning> {
+struct GraphPruning : public PassWrapper<GraphPruning, FunctionPass> {
   void runOnFunction() override {
     getFunction().walk([](tf_executor::GraphOp graph) {
       // For TensorFlow V1.0 compatibility: when importing a graph without
@@ -100,7 +100,7 @@ struct GraphPruning : public FunctionPass<GraphPruning> {
 
 } // namespace
 
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorGraphPruningPass() {
   return std::make_unique<GraphPruning>();
 }
 
@@ -57,7 +57,7 @@ namespace {
 constexpr char kDeviceAttr[] = "device";
 
 struct LaunchToDeviceAttributePass
-    : public FunctionPass<LaunchToDeviceAttributePass> {
+    : public PassWrapper<LaunchToDeviceAttributePass, FunctionPass> {
   void runOnFunction() override;
 };
 
@@ -122,7 +122,7 @@ void LaunchToDeviceAttributePass::runOnFunction() {
 
 } // anonymous namespace
 
-std::unique_ptr<OpPassBase<FuncOp>> CreateLaunchToDeviceAttributePass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateLaunchToDeviceAttributePass() {
   return std::make_unique<LaunchToDeviceAttributePass>();
 }
 
@@ -36,7 +36,8 @@ namespace {
 
 // LayoutAssignmentPass assigns optimal data layout (data format) for all
 // layout sensitive operations.
-class LayoutAssignmentPass : public FunctionPass<LayoutAssignmentPass> {
+class LayoutAssignmentPass
+    : public PassWrapper<LayoutAssignmentPass, FunctionPass> {
  public:
   LayoutAssignmentPass() = default;
   explicit LayoutAssignmentPass(const std::string& force_data_format) {
@@ -57,7 +58,8 @@ class LayoutAssignmentPass : public FunctionPass<LayoutAssignmentPass> {
 // MoveTransposesPass moves all Transpose ops to the beginning or to the end of
 // the basic block where they are defined. This will allow canonicalzer to
 // delete redundant transposes.
-class MoveTransposesPass : public FunctionPass<MoveTransposesPass> {
+class MoveTransposesPass
+    : public PassWrapper<MoveTransposesPass, FunctionPass> {
  public:
   enum class Direction { kBegin, kEnd };
 
@@ -31,7 +31,7 @@ namespace mlir {
 namespace TF {
 namespace {
 
-class LegalizeHloToTf : public FunctionPass<LegalizeHloToTf> {
+class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
  public:
   LegalizeHloToTf() = default;
   LegalizeHloToTf(const LegalizeHloToTf &) {}
@@ -76,7 +76,7 @@ static PassRegistration<LegalizeHloToTf> pass(
 
 } // end namespace
 
-std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeHloToTfPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass() {
   return std::make_unique<LegalizeHloToTf>();
 }
 
@@ -23,7 +23,7 @@ namespace {
 
 // Lowers some of the TensorFlow operations that can be represented using other
 // TensorFlow operations.
-struct LowerTF : public FunctionPass<LowerTF> {
+struct LowerTF : public PassWrapper<LowerTF, FunctionPass> {
   void runOnFunction() override {
     // Add lowering patterns to the list.
     OwningRewritePatternList patterns;
@@ -74,8 +74,9 @@ LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification(
 
 namespace {
 struct MarkFunctionVisibilityUsingEntryFunctionSpecificationPass
-    : public OperationPass<
-          MarkFunctionVisibilityUsingEntryFunctionSpecificationPass, ModuleOp> {
+    : public PassWrapper<
+          MarkFunctionVisibilityUsingEntryFunctionSpecificationPass,
+          OperationPass<ModuleOp>> {
   void runOnOperation() override {
     if (failed(MarkFunctionVisibilityUsingEntryFunctionSpecification(
             getOperation()))) {
@@ -90,7 +91,7 @@ static PassRegistration<
     pass("tf-mark-func-visibility",
          "Use tf.entry_function to mark function visibility.");
 
-std::unique_ptr<OpPassBase<ModuleOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass() {
   return std::make_unique<
       MarkFunctionVisibilityUsingEntryFunctionSpecificationPass>();
@@ -110,8 +111,8 @@ static LogicalResult MarkFunctionVisibilityUsingSavedModelLinkage(
 
 namespace {
 struct MarkFunctionVisibilityUsingSavedModelLinkagePass
-    : public OperationPass<MarkFunctionVisibilityUsingSavedModelLinkagePass,
-                           ModuleOp> {
+    : public PassWrapper<MarkFunctionVisibilityUsingSavedModelLinkagePass,
+                         OperationPass<ModuleOp>> {
   void runOnOperation() override {
     if (failed(MarkFunctionVisibilityUsingSavedModelLinkage(getOperation()))) {
       signalPassFailure();
@@ -124,7 +125,7 @@ static PassRegistration<MarkFunctionVisibilityUsingSavedModelLinkagePass> pass(
     "tf-saved-model-mark-func-visibility",
     "Use tf_saved_model linkage information to mark function visibility.");
 
-std::unique_ptr<OpPassBase<ModuleOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateMarkFunctionVisibilityUsingSavedModelLinkagePass() {
   return std::make_unique<MarkFunctionVisibilityUsingSavedModelLinkagePass>();
 }
@@ -35,7 +35,7 @@ namespace mlir {
 namespace {
 
 class MaterializePassthroughOpPass
-    : public FunctionPass<MaterializePassthroughOpPass> {
+    : public PassWrapper<MaterializePassthroughOpPass, FunctionPass> {
  public:
   void runOnFunction() override;
 };
@@ -96,7 +96,7 @@ void MaterializePassthroughOpPass::runOnFunction() {
 } // namespace
 
 namespace TF {
-std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializePassthroughOpPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateMaterializePassthroughOpPass() {
   return std::make_unique<MaterializePassthroughOpPass>();
 }
 } // namespace TF
@@ -33,7 +33,7 @@ namespace {
 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_optimize.inc"
 
 // Canonicalize operations in functions.
-struct TFOptimizePass : public FunctionPass<TFOptimizePass> {
+struct TFOptimizePass : public PassWrapper<TFOptimizePass, FunctionPass> {
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     auto func = getFunction();
@@ -71,7 +71,7 @@ void CreateTFStandardPipeline(OpPassManager &pm,
   pm.addNestedPass<FuncOp>(createCSEPass());
 }
 
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateTFOptimizePass() {
   return std::make_unique<TFOptimizePass>();
 }
 
@@ -41,7 +41,7 @@ namespace mlir {
 namespace tf_saved_model {
 namespace {
 struct OptimizeGlobalTensorsPass
-    : public OperationPass<OptimizeGlobalTensorsPass, ModuleOp> {
+    : public PassWrapper<OptimizeGlobalTensorsPass, OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
 
@@ -296,7 +296,7 @@ static PassRegistration<OptimizeGlobalTensorsPass> pass(
     "tf-saved-model-optimize-global-tensors",
     "Optimize tf_saved_model.global_tensor's.");
 
-std::unique_ptr<OpPassBase<ModuleOp>> CreateOptimizeGlobalTensorsPass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeGlobalTensorsPass() {
   return std::make_unique<OptimizeGlobalTensorsPass>();
 }
 
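Module-level passes follow the same shape, with the operated-on op type moving into the wrapped base class. Below is a sketch with a hypothetical pass name and registration string; the PassRegistration call mirrors the unchanged registration lines in the hunk above.

#include "mlir/IR/Module.h"
#include "mlir/Pass/Pass.h"

namespace {
// Hypothetical module pass in the new style: PassWrapper<Derived,
// OperationPass<ModuleOp>> replaces OperationPass<Derived, ModuleOp>.
struct ExampleModulePass
    : public mlir::PassWrapper<ExampleModulePass,
                               mlir::OperationPass<mlir::ModuleOp>> {
  void runOnOperation() override {
    mlir::ModuleOp module = getOperation();
    (void)module;  // A real pass would transform the module here.
  }
};

// Static registration is untouched by the migration.
static mlir::PassRegistration<ExampleModulePass> pass(
    "example-module-pass", "Illustrative pass; not part of this change.");
}  // namespace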
@@ -83,7 +83,7 @@ namespace TFDevice {
 namespace {
 
 struct ParallelExecuteToIslandsPass
-    : public FunctionPass<ParallelExecuteToIslandsPass> {
+    : public PassWrapper<ParallelExecuteToIslandsPass, FunctionPass> {
   void runOnFunction() override;
 };
 
@@ -251,7 +251,7 @@ void ParallelExecuteToIslandsPass::runOnFunction() {
 }
 } // anonymous namespace
 
-std::unique_ptr<OpPassBase<FuncOp>> CreateParallelExecuteToIslandsPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateParallelExecuteToIslandsPass() {
   return std::make_unique<ParallelExecuteToIslandsPass>();
 }
 
@@ -24,36 +24,36 @@ namespace mlir {
 
 // Creates a pass that breaks up an island with multiple ops into multiple
 // islands, each with a single op.
-std::unique_ptr<OpPassBase<FuncOp>> CreateBreakUpIslandsPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateBreakUpIslandsPass();
 
 // Creates a pass that converts mlir functions consisting of mlir ops into a
 // tf_executor dialect as a single island.
-std::unique_ptr<OpPassBase<FuncOp>>
+std::unique_ptr<OperationPass<FuncOp>>
 CreateFunctionalToExecutorDialectConversionPass();
 
 namespace TF {
 // Transforms functional control flow operations in the standard TensorFlow
 // dialect to MLIR Control Flow Graph (CFG) form.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFFunctionalControlFlowToCFG();
+std::unique_ptr<OperationPass<FuncOp>> CreateTFFunctionalControlFlowToCFG();
 
 // Materialize the MlirPassthroughOp by replacing it with the MLIR module
 // attached as an attribute.
-std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializePassthroughOpPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateMaterializePassthroughOpPass();
 
 // Performs Shape Inference on the TensorFlow dialect using the global registry.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTFShapeInferencePass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateTFShapeInferencePass();
 
 // Optional pass which will unroll BatchMatMul and use only MatMul
-std::unique_ptr<OpPassBase<FuncOp>> CreateUnrollBatchMatMulPassPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateUnrollBatchMatMulPassPass();
 
 // Optional pass which will map TF BatchMatMul to TF Einsum
-std::unique_ptr<OpPassBase<FuncOp>> CreateBatchMatMulToEinsumPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateBatchMatMulToEinsumPass();
 
 // Optimizes Tensorflow graph.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass();
+std::unique_ptr<OperationPass<FuncOp>> CreateTFOptimizePass();
 
 // Performs specific fusion for GPU targets.
-std::unique_ptr<OpPassBase<FuncOp>> CreateGpuOpFusionPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateGpuOpFusionPass();
 
 struct LayoutOptimizationPipelineOptions
     : public PassPipelineOptions<LayoutOptimizationPipelineOptions> {
@@ -82,14 +82,14 @@ void CreateTFStandardPipeline(OpPassManager& pm,
                               const StandardPipelineOptions& options);
 
 // Propagates device attributes of resources from callers to callees.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateResourceDeviceInferencePass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateResourceDeviceInferencePass();
 
 // Creates a pass that promotes resource reads/writes in the main function to
 // inputs and outputs of the main function, assuming that resource operations
 // have already been decomposed and function calls have already been inlined.
 // The pass also annotates the input arguments for resources with the indices
 // of their aliasing output arguments.
-std::unique_ptr<OpPassBase<ModuleOp>> CreatePromoteResourcesToArgsPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass();
 
 // Marks function visibility using tf.entry_function specification. That is,
 // functions with tf.entry_function attributes are marked with public
@@ -98,11 +98,11 @@ LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification(
     ModuleOp module);
 // Creates a pass that uses tf.entry_function specification to mark function
 // visibility.
-std::unique_ptr<OpPassBase<ModuleOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass();
 
 // Creates a simple device assignment pass on TF dialect for CoreRT use case.
-std::unique_ptr<OpPassBase<FuncOp>> CreateSimpleTFDeviceAssignmentPass(
+std::unique_ptr<OperationPass<FuncOp>> CreateSimpleTFDeviceAssignmentPass(
     llvm::StringRef default_device);
 
 // Performs resource lifting on the function body to hoist resource variable
@@ -112,25 +112,26 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function);
 // Converts stack ops into operations on local variables, which can later be
 // removed by resource lifting. Requires known maximum sizes of stacks and
 // known element shapes of push ops.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateStackOpsDecompositionPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateStackOpsDecompositionPass();
 
 // Converts tensor list operations into operations on buffers and sizes. Needs
 // static shapes and known max element count.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTensorListOpsDecompositionPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateTensorListOpsDecompositionPass();
 
 // Converts tensor array ops into operations on local variables, which can later
 // be removed by resource lifting. Requires known sizes and known element shapes
 // (either defined in TensorArrayV3 or implied in the first write).
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTensorArrayOpsDecompositionPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateTensorArrayOpsDecompositionPass();
 
 // Create a pass that legalize HLO to TF dialect.
-std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeHloToTfPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();
 } // namespace TF
 
 namespace TFControlFlow {
 // Raises from the "TensorFlow Control Flow" dialect to the standard TensorFlow
 // dialect.
-std::unique_ptr<OpPassBase<FuncOp>> CreateRaiseTFControlFlowPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateRaiseTFControlFlowPass();
 
 } // namespace TFControlFlow
 
@@ -138,29 +139,30 @@ namespace tf_executor {
 class GraphOp;
 
 // Returns a pass that folds switch nodes with constant predicates.
-std::unique_ptr<OpPassBase<FuncOp>> CreateSwitchFoldPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateSwitchFoldPass();
 
 // Creates a pass to merge IslandOps from TFExecutor dialect.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorIslandCoarseningPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorIslandCoarseningPass();
 
 // Creates a pass to merge IslandOps for operation marked for execution on TPU.
 // This is a V1 backward compatibility.
-std::unique_ptr<OpPassBase<ModuleOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateTFExecutorTPUV1IslandCoarseningPass();
 
 // Creates a pass to outlining TPU clusters from single IslandOp into a nested
 // module suitable for being processed as-if it was a V2 module.
 // This is a V1 backward compatibility.
-std::unique_ptr<OpPassBase<ModuleOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateTFExecutorTPUV1IslandOutliningPass();
 
 // Creates a pass to inline calls to the nested TPU module, this reverses the
 // effect of the `TFExecutorTPUV1IslandOutlining` pass above.
 // This is a V1 backward compatibility.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTFExecutorTPUV1IslandInliningPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateTFExecutorTPUV1IslandInliningPass();
 
 // Creates a pass to prune tf_executor.graph from dead nodes.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorGraphPruningPass();
 
 // Prunes unreachable operations of a tf_executor.graph operation.
 void PruneGraph(GraphOp graph);
@@ -168,29 +170,29 @@ void PruneGraph(GraphOp graph);
 // Sink `tf.Const` operations in the LaunchOp region using them. This is
 // performed in order to limit the number of values implicitly captured in this
 // region before outlining.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorConstantSinkingPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorConstantSinkingPass();
 
 } // namespace tf_executor
 
 namespace TFDevice {
 // Creates a pass that forms clusters from instructions that are assigned to
 // same device.
-std::unique_ptr<OpPassBase<FuncOp>> CreateClusterFormationPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateClusterFormationPass();
 
 // Creates a pass that outlines regions of tf_device.launch operations.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateClusterOutliningPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateClusterOutliningPass();
 
 // A pass that decomposes composite resource operations into primitive ones like
 // ReadVariableOp, AssignVariableOp and other computations to facilitate
 // transformations like resource op lifting.
-std::unique_ptr<OpPassBase<FuncOp>> CreateDecomposeResourceOpsPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateDecomposeResourceOpsPass();
 
 // Creates a pass that lifts operations on external resource variables from
 // device computation nested in `tf_device::LaunchOp` out so that resource
 // variable load operations are all before device computation while resource
 // variable store operations are all after device computation. After this pass,
 // device computation no longer interacts with external resource variables.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateResourceOpLiftingPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateResourceOpLiftingPass();
 
 // Lifts resource operations from tf_device.launch_func ops nested in `op`
 // outside. Returns a failure if there are remaining resource-type values that
@@ -198,55 +200,56 @@ std::unique_ptr<OpPassBase<ModuleOp>> CreateResourceOpLiftingPass();
 LogicalResult LiftResourceOps(Operation* op);
 
 // Creates a pass that hoists invariant operations in a `tf_device.replicate`.
-std::unique_ptr<OpPassBase<FuncOp>> CreateReplicateInvariantOpHoistingPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateReplicateInvariantOpHoistingPass();
 
 // Creates a pass that forms replica `tf_executor.island` from a single
 // `tf_device.replicate` island.
-std::unique_ptr<OpPassBase<FuncOp>> CreateReplicateToIslandPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateReplicateToIslandPass();
 
 // Creates a pass that creates `tf_executor.island` from a single
 // `tf_device.parallel_execute` island.
-std::unique_ptr<OpPassBase<FuncOp>> CreateParallelExecuteToIslandsPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateParallelExecuteToIslandsPass();
 
 // Creates a pass that annotates whether a LaunchFuncOp's parameters have the
 // same data across replicas.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateAnnotateParameterReplicationPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateAnnotateParameterReplicationPass();
 
 // Creates a pass that hoists a `tf_device.launch` body and assigns a `device`
 // attribute to each TensorFlow dialect op in the body based on the `device`
 // attribute on the `tf_device.launch`.
-std::unique_ptr<OpPassBase<FuncOp>> CreateLaunchToDeviceAttributePass();
+std::unique_ptr<OperationPass<FuncOp>> CreateLaunchToDeviceAttributePass();
 } // namespace TFDevice
 
 namespace TFTPU {
 // Creates a pass that forms clusters from operations of the same
 // `_tpu_replicate` attribute.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTPUClusterFormationPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateTPUClusterFormationPass();
 
 // Creates a pass that allows TPU program inputs to have layouts determined at
 // run time.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTPUDynamicLayoutPass();
+std::unique_ptr<OperationPass<FuncOp>> CreateTPUDynamicLayoutPass();
 
 // Creates a pass that remaps and assigns padding map from a
 // `tf_device.launch_func` `padding_map` attribute to its encapsulated function.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUDynamicPaddingMapperPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicPaddingMapperPass();
 
 // Creates a pass that rewrites `tf_device.launch_func` on TPUs into TPU runtime
 // ops.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTPURewritePass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateTPURewritePass();
 
 // Creates a pass that identifies XLASharding ops in launch op for TPU
 // computation.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUShardingIdentificationPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateTPUShardingIdentificationPass();
 
 // Creates a pass that merges device variable reads/updates into the surrounded
 // TPUExecute node. This allows the execute node to perform in-place variable
 // updates.
-std::unique_ptr<OpPassBase<FuncOp>> CreateTPUMergeVariablesWithExecutePass();
+std::unique_ptr<OperationPass<FuncOp>> CreateTPUMergeVariablesWithExecutePass();
 
 // Creates a pass that adds ops which perform formatting on variables at
 // run-time according to compilation result.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUVariableReformattingPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateTPUVariableReformattingPass();
 
 // Populates the supplied passmanager with the passes required to run the
 void CreateTPUBridgePipeline(OpPassManager& pm);
@@ -260,16 +263,16 @@ void CreateTPUBridgePipelineV1(OpPassManager& pm);
 namespace tf_saved_model {
 
 // Creates a pass that optimizes tf_saved_model.global_tensor ops.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateOptimizeGlobalTensorsPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeGlobalTensorsPass();
 
 // Creates a pass that freezes tf_saved_model.global_tensor ops.
-std::unique_ptr<OpPassBase<ModuleOp>> CreateFreezeGlobalTensorsPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeGlobalTensorsPass();
 
 // Creates a pass that uses tf_saved_model dialect linkage information
 // to mark function visibility. That is, exported functions are marked with
 // public visibility while the other functions are marked with private
 // visibility.
-std::unique_ptr<OpPassBase<ModuleOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateMarkFunctionVisibilityUsingSavedModelLinkagePass();
 
 } // namespace tf_saved_model
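As the header above shows, only the declared return types change from OpPassBase to OperationPass; call sites are unaffected because pass managers take ownership through std::unique_ptr<Pass>. Below is a small illustrative pipeline-builder using three of the factories declared above; the helper name and the include path for this header are assumptions, not part of the change.

#include "mlir/IR/Function.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"  // path assumed

// Hypothetical helper: wires a few of the passes declared above into a
// pass manager. addPass/addNestedPass accept the new return type directly.
void BuildExamplePipeline(mlir::OpPassManager &pm) {
  pm.addNestedPass<mlir::FuncOp>(mlir::TF::CreateTFOptimizePass());
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  pm.addNestedPass<mlir::FuncOp>(
      mlir::tf_executor::CreateTFExecutorGraphPruningPass());
}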
@@ -258,7 +258,7 @@ LogicalResult PromoteResourcesToArguments(FuncOp function) {
 }
 
 class PromoteResourcesToArgsPass
-    : public OperationPass<PromoteResourcesToArgsPass, ModuleOp> {
+    : public PassWrapper<PromoteResourcesToArgsPass, OperationPass<ModuleOp>> {
  public:
   void runOnOperation() override;
 };
@@ -285,7 +285,7 @@ void PromoteResourcesToArgsPass::runOnOperation() {
 
 } // namespace
 
-std::unique_ptr<OpPassBase<ModuleOp>> CreatePromoteResourcesToArgsPass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass() {
   return std::make_unique<PromoteResourcesToArgsPass>();
 }
 
@@ -32,7 +32,8 @@ namespace mlir {
 namespace TFControlFlow {
 
 namespace {
-struct RaiseTFControlFlow : public FunctionPass<RaiseTFControlFlow> {
+struct RaiseTFControlFlow
+    : public PassWrapper<RaiseTFControlFlow, FunctionPass> {
   void runOnFunction() {
     // First start by recognizing loops and reconstructing a loop tree.
     buildLoopNests();
@@ -145,7 +146,7 @@ void RaiseTFControlFlow::rewriteOps() {
 
 } // namespace
 
-std::unique_ptr<OpPassBase<FuncOp>> CreateRaiseTFControlFlowPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateRaiseTFControlFlowPass() {
   return std::make_unique<RaiseTFControlFlow>();
 }
 
@@ -37,7 +37,7 @@ namespace {
 constexpr char kDeviceAttr[] = "device";
 
 struct ReplicateInvariantOpHoistingPass
-    : public FunctionPass<ReplicateInvariantOpHoistingPass> {
+    : public PassWrapper<ReplicateInvariantOpHoistingPass, FunctionPass> {
   void runOnFunction() override;
 };
 
@@ -178,7 +178,8 @@ void ReplicateInvariantOpHoistingPass::runOnFunction() {
 }
 } // anonymous namespace
 
-std::unique_ptr<OpPassBase<FuncOp>> CreateReplicateInvariantOpHoistingPass() {
+std::unique_ptr<OperationPass<FuncOp>>
+CreateReplicateInvariantOpHoistingPass() {
   return std::make_unique<ReplicateInvariantOpHoistingPass>();
 }
 
@@ -43,7 +43,8 @@ namespace TFDevice {
 namespace {
 constexpr char kDeviceAttr[] = "device";
 
-struct ReplicateToIslandPass : public FunctionPass<ReplicateToIslandPass> {
+struct ReplicateToIslandPass
+    : public PassWrapper<ReplicateToIslandPass, FunctionPass> {
   void runOnFunction() override;
 };
 
@@ -237,7 +238,7 @@ void ReplicateToIslandPass::runOnFunction() {
 }
 } // anonymous namespace
 
-std::unique_ptr<OpPassBase<FuncOp>> CreateReplicateToIslandPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateReplicateToIslandPass() {
   return std::make_unique<ReplicateToIslandPass>();
 }
 
@@ -54,7 +54,7 @@ constexpr char kFuncDeviceAttr[] = "tf.device";
 // This pass changes the module by adding "tf.device" attribute to function
 // arguments and adding "device" attribute to TF ops.
 struct ResourceDeviceInference
-    : public OperationPass<ResourceDeviceInference, ModuleOp> {
+    : public PassWrapper<ResourceDeviceInference, OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
 
@@ -266,7 +266,7 @@ void ResourceDeviceInference::runOnOperation() {
 
 } // namespace
 
-std::unique_ptr<OpPassBase<ModuleOp>> CreateResourceDeviceInferencePass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreateResourceDeviceInferencePass() {
   return std::make_unique<ResourceDeviceInference>();
 }
 
@@ -132,7 +132,7 @@ namespace {
 // }
 //
 struct ResourceOpLiftingPass
-    : public OperationPass<ResourceOpLiftingPass, ModuleOp> {
+    : public PassWrapper<ResourceOpLiftingPass, OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
 
@@ -1071,7 +1071,8 @@ void ResourceOpLiftingPass::runOnOperation() {
 }
 
 struct ResourceOpLiftingForMainFunctionPass
-    : public OperationPass<ResourceOpLiftingForMainFunctionPass, ModuleOp> {
+    : public PassWrapper<ResourceOpLiftingForMainFunctionPass,
+                         OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
 
@@ -1100,7 +1101,7 @@ static PassRegistration<ResourceOpLiftingPass> pass(
 } // namespace
 
 namespace TFDevice {
-std::unique_ptr<OpPassBase<ModuleOp>> CreateResourceOpLiftingPass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreateResourceOpLiftingPass() {
   return std::make_unique<ResourceOpLiftingPass>();
 }
 } // namespace TFDevice
@@ -47,7 +47,8 @@ namespace {
 
 // This transformation pass propagate shapes on the TensorFlow graph.
 // It is a ModulePass in order to be able to change function types.
-struct ShapeInference : public OperationPass<ShapeInference, ModuleOp> {
+struct ShapeInference
+    : public PassWrapper<ShapeInference, OperationPass<ModuleOp>> {
   void runOnOperation() override {
     auto module = getOperation();
     auto producer_or = tensorflow::GetTfGraphProducerVersion(module);
@@ -70,7 +71,7 @@ PassRegistration<ShapeInference> pass(
 
 } // namespace
 
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTFShapeInferencePass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreateTFShapeInferencePass() {
   return std::make_unique<ShapeInference>();
 }
 
@@ -39,7 +39,7 @@ namespace {
 using ::mlir::TF::ConstOp;
 
 class ExecutorConstantSinking
-    : public mlir::FunctionPass<ExecutorConstantSinking> {
+    : public mlir::PassWrapper<ExecutorConstantSinking, FunctionPass> {
   void runOnFunction() override {
     getFunction().walk([](tf_device::LaunchOp launch) {
       LLVM_DEBUG(llvm::dbgs() << "Visit " << *launch.getOperation() << "\n");
@@ -89,7 +89,7 @@ static mlir::PassRegistration<ExecutorConstantSinking> pass(
 
 } // anonymous namespace
 
-std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorConstantSinkingPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorConstantSinkingPass() {
   return std::make_unique<ExecutorConstantSinking>();
 }
 
@@ -85,7 +85,7 @@ namespace cutil = TF::collection_ops_util;
 //
 // The pass also works across control flow and functional calls.
 struct StackOpsDecompositionPass
-    : public OperationPass<StackOpsDecompositionPass, ModuleOp> {
+    : public PassWrapper<StackOpsDecompositionPass, OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
 
@@ -568,7 +568,7 @@ static PassRegistration<StackOpsDecompositionPass> pass(
 } // namespace
 
 namespace TF {
-std::unique_ptr<OpPassBase<ModuleOp>> CreateStackOpsDecompositionPass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreateStackOpsDecompositionPass() {
   return std::make_unique<StackOpsDecompositionPass>();
 }
 
@@ -68,7 +68,8 @@ using std::string;
 // shape.
 //
 struct TensorArrayOpsDecompositionPass
-    : public OperationPass<TensorArrayOpsDecompositionPass, ModuleOp> {
+    : public PassWrapper<TensorArrayOpsDecompositionPass,
+                         OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
 
@@ -893,7 +894,8 @@ static PassRegistration<TensorArrayOpsDecompositionPass> pass(
 } // namespace
 
 namespace TF {
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTensorArrayOpsDecompositionPass() {
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateTensorArrayOpsDecompositionPass() {
   return std::make_unique<TensorArrayOpsDecompositionPass>();
 }
 
@@ -62,7 +62,8 @@ namespace cutil = TF::collection_ops_util;
 //
 // The pass also works across control flow and functional calls.
 struct TensorListOpsDecompositionPass
-    : public OperationPass<TensorListOpsDecompositionPass, ModuleOp> {
+    : public PassWrapper<TensorListOpsDecompositionPass,
+                         OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
 
@@ -728,7 +729,8 @@ static PassRegistration<TensorListOpsDecompositionPass> pass(
 } // namespace
 
 namespace TF {
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTensorListOpsDecompositionPass() {
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateTensorListOpsDecompositionPass() {
   return std::make_unique<TensorListOpsDecompositionPass>();
 }
 } // namespace TF
@@ -39,7 +39,7 @@ namespace {
 // A pass that adds "Predecessors" and "Successors" remarks for each op based on
 // SideEffectAnalysis result. For testing purpose only.
 struct TestSideEffectAnalysis
-    : public mlir::FunctionPass<TestSideEffectAnalysis> {
+    : public mlir::PassWrapper<TestSideEffectAnalysis, FunctionPass> {
   void runOnFunction() override {
     int64_t next_id = 0;
     llvm::SmallDenseMap<Operation*, int64_t, 8> ids;
@@ -24,7 +24,7 @@ namespace TF {
 namespace {
 
 class SimpleTFDeviceAssignmentPass
-    : public FunctionPass<SimpleTFDeviceAssignmentPass> {
+    : public PassWrapper<SimpleTFDeviceAssignmentPass, FunctionPass> {
  public:
   SimpleTFDeviceAssignmentPass() = default;
   SimpleTFDeviceAssignmentPass(const SimpleTFDeviceAssignmentPass&) {}
@@ -57,7 +57,7 @@ class SimpleTFDeviceAssignmentPass
 
 } // namespace
 
-std::unique_ptr<OpPassBase<FuncOp>> CreateSimpleTFDeviceAssignmentPass(
+std::unique_ptr<OperationPass<FuncOp>> CreateSimpleTFDeviceAssignmentPass(
     llvm::StringRef default_device) {
   return std::make_unique<SimpleTFDeviceAssignmentPass>(default_device);
 }
@@ -40,7 +40,9 @@ namespace tensorflow {
 // Optimization Passes and convert back to MLIR.
 // Constraints: This pass expects that all operations in the MLIR module either
 // belong to 'tf' or '_tf' dialect. The output is in '_tf' dialect.
-class GraphOptPass : public mlir::OperationPass<GraphOptPass, mlir::ModuleOp> {
+class GraphOptPass
+    : public mlir::PassWrapper<GraphOptPass,
+                               mlir::OperationPass<mlir::ModuleOp>> {
  public:
   explicit GraphOptPass(std::vector<tensorflow::GraphOptimizationPass*> passes)
       : passes_(std::move(passes)) {}
@@ -166,13 +168,13 @@ class GraphOptByNamePass : public GraphOptPass {
 
 } // namespace tensorflow
 
-std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
 tensorflow::CreateTensorFlowGraphOptimizationPass(
     std::vector<tensorflow::GraphOptimizationPass*> tf_passes) {
   return std::make_unique<GraphOptPass>(std::move(tf_passes));
 }
 
-std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
 tensorflow::CreateTensorFlowGraphOptimizationPass(
     const std::vector<std::string>& pass_names) {
   return std::make_unique<GraphOptByNamePass>(pass_names);
@@ -24,7 +24,7 @@ namespace tensorflow {
 // Create a module pass that will execute the given TF GraphOptimization passes
 // in sequence.
 // Pass requires that the module ran on is convertible to TF Graph.
-std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
 CreateTensorFlowGraphOptimizationPass(
     std::vector<tensorflow::GraphOptimizationPass*> tf_passes);
 
@@ -32,7 +32,7 @@ CreateTensorFlowGraphOptimizationPass(
 // passes are queried, if a TF graph optimization pass is not found in registry
 // then the pass fails.
 // Pass requires that the module ran on is convertible to TF Graph.
-std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
 CreateTensorFlowGraphOptimizationPass(
     const std::vector<std::string>& pass_names);
 
@@ -71,7 +71,8 @@ using MetadataMap = llvm::SmallDenseMap<llvm::StringRef, NamedAttributeList, 8>;
 using ClusterMap = llvm::SmallDenseMap<llvm::StringRef,
                                        llvm::SmallSetVector<Operation*, 8>, 8>;
 
-struct TPUClusterFormation : public FunctionPass<TPUClusterFormation> {
+struct TPUClusterFormation
+    : public PassWrapper<TPUClusterFormation, FunctionPass> {
   void runOnFunction() override;
 };
 
@@ -502,7 +503,7 @@ void TPUClusterFormation::runOnFunction() {
 }
 } // anonymous namespace
 
-std::unique_ptr<OpPassBase<FuncOp>> CreateTPUClusterFormationPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateTPUClusterFormationPass() {
   return std::make_unique<TPUClusterFormation>();
 }
 
@@ -73,7 +73,8 @@ constexpr char kDeviceAttr[] = "device";
 // %copy_to_device. There will not be send/recv ops added by later passes,
 // because tf.TPUCopyWithLayout accepts a host input and produces a device
 // output.
-struct TPUDynamicLayoutPass : public FunctionPass<TPUDynamicLayoutPass> {
+struct TPUDynamicLayoutPass
+    : public PassWrapper<TPUDynamicLayoutPass, FunctionPass> {
   void runOnFunction() override;
 };
 
@@ -256,7 +257,7 @@ void TPUDynamicLayoutPass::runOnFunction() {
 
 } // namespace
 
-std::unique_ptr<OpPassBase<FuncOp>> CreateTPUDynamicLayoutPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateTPUDynamicLayoutPass() {
   return std::make_unique<TPUDynamicLayoutPass>();
 }
 
@@ -49,7 +49,7 @@ constexpr char kPaddingMapAttr[] = "padding_map";
 
 namespace {
 struct TPUDynamicPaddingMapper
-    : public OperationPass<TPUDynamicPaddingMapper, ModuleOp> {
+    : public PassWrapper<TPUDynamicPaddingMapper, OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };
 
@@ -200,7 +200,7 @@ void TPUDynamicPaddingMapper::runOnOperation() {
 }
 } // anonymous namespace
 
-std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUDynamicPaddingMapperPass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicPaddingMapperPass() {
   return std::make_unique<TPUDynamicPaddingMapper>();
 }
 
@@ -75,7 +75,7 @@ constexpr char kFuncDeviceAttr[] = "tf.device";
 // the TPUExecute op.
 
 struct TPUMergeVariablesWithExecutePass
-    : public FunctionPass<TPUMergeVariablesWithExecutePass> {
+    : public PassWrapper<TPUMergeVariablesWithExecutePass, FunctionPass> {
   void runOnFunction() override;
 };
 
@@ -531,7 +531,8 @@ void TPUMergeVariablesWithExecutePass::runOnFunction() {
 
 } // namespace
 
-std::unique_ptr<OpPassBase<FuncOp>> CreateTPUMergeVariablesWithExecutePass() {
+std::unique_ptr<OperationPass<FuncOp>>
+CreateTPUMergeVariablesWithExecutePass() {
   return std::make_unique<TPUMergeVariablesWithExecutePass>();
 }
 
|
|||||||
@ -98,7 +98,8 @@ constexpr char kBadArrayAttrLengthMsg[] =
 // %4 = "tf.SomeOp"(%3)

 namespace {
-struct TPURewritePass : public OperationPass<TPURewritePass, ModuleOp> {
+struct TPURewritePass
+    : public PassWrapper<TPURewritePass, OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };

@ -770,7 +771,7 @@ void TPURewritePass::runOnOperation() {

 }  // namespace

-std::unique_ptr<OpPassBase<ModuleOp>> CreateTPURewritePass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreateTPURewritePass() {
   return std::make_unique<TPURewritePass>();
 }

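Module-level passes follow the same recipe with a different wrapped base: OperationPass<Derived, ModuleOp> becomes PassWrapper<Derived, OperationPass<ModuleOp>>, and factories return OperationPass<ModuleOp>. A sketch with hypothetical names and approximate include paths:

    #include <memory>

    #include "mlir/IR/Module.h"  // mlir::ModuleOp
    #include "mlir/Pass/Pass.h"

    namespace {
    struct ExampleModulePass
        : public mlir::PassWrapper<ExampleModulePass,
                                   mlir::OperationPass<mlir::ModuleOp>> {
      void runOnOperation() override {
        // getOperation() still yields the anchored op type, here a ModuleOp.
        mlir::ModuleOp module = getOperation();
        (void)module;
      }
    };
    }  // namespace

    std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
    CreateExampleModulePass() {
      return std::make_unique<ExampleModulePass>();
    }
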
@ -40,7 +40,8 @@ namespace {
 constexpr char kShardingAttr[] = "xla_hlo.sharding";

 struct TPUShardingIdentificationPass
-    : public OperationPass<TPUShardingIdentificationPass, ModuleOp> {
+    : public PassWrapper<TPUShardingIdentificationPass,
+                         OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };

@ -185,7 +186,7 @@ void TPUShardingIdentificationPass::runOnOperation() {

 }  // anonymous namespace

-std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUShardingIdentificationPass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreateTPUShardingIdentificationPass() {
   return std::make_unique<TPUShardingIdentificationPass>();
 }

@ -116,7 +116,8 @@ std::string GetRandomStateVariableName() {
 //   tf.TPUReshardVariablesOp(%rvar, %default_format, %rstate)
 // }
 struct TPUVariableRuntimeReformattingPass
-    : public OperationPass<TPUVariableRuntimeReformattingPass, ModuleOp> {
+    : public PassWrapper<TPUVariableRuntimeReformattingPass,
+                         OperationPass<ModuleOp>> {
   void runOnOperation() override;
 };

@ -575,7 +576,7 @@ void TPUVariableRuntimeReformattingPass::runOnOperation() {

 }  // namespace

-std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUVariableReformattingPass() {
+std::unique_ptr<OperationPass<ModuleOp>> CreateTPUVariableReformattingPass() {
   return std::make_unique<TPUVariableRuntimeReformattingPass>();
 }

@ -44,7 +44,8 @@ namespace {
 // Unrolls a BatchMatMul on the batch dimension. We need to slice each batch out
 // of the inputs, matmul them individually, then stack them all back together at
 // the end.
-struct UnrollBatchMatMulPass : public FunctionPass<UnrollBatchMatMulPass> {
+struct UnrollBatchMatMulPass
+    : public PassWrapper<UnrollBatchMatMulPass, FunctionPass> {
   void runOnFunction() override;
 };

@ -309,7 +310,7 @@ static PassRegistration<UnrollBatchMatMulPass> pass(
     "tf-unroll-batch-matmul",
     "Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops.");

-std::unique_ptr<OpPassBase<FuncOp>> CreateUnrollBatchMatMulPassPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateUnrollBatchMatMulPassPass() {
   return std::make_unique<UnrollBatchMatMulPass>();
 }

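The static registration idiom visible in the context above is unaffected by the new base class; only the pass's inheritance changes. A sketch with an invented flag and description (placeholders, not a real TensorFlow pass):

    #include "mlir/Pass/Pass.h"

    namespace {
    struct ExampleRegisteredPass
        : public mlir::PassWrapper<ExampleRegisteredPass, mlir::FunctionPass> {
      void runOnFunction() override {}
    };
    }  // namespace

    // Same PassRegistration shape as before the bump.
    static mlir::PassRegistration<ExampleRegisteredPass> example_registration(
        "example-hypothetical-flag", "Illustrative registration only.");
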
@ -42,7 +42,7 @@ namespace mlir {

 namespace {

-struct BreakUpIslands : FunctionPass<BreakUpIslands> {
+struct BreakUpIslands : PassWrapper<BreakUpIslands, FunctionPass> {
   void runOnFunction() final;

   void BreakUpIsland(tf_executor::IslandOp island_op,

@ -325,7 +325,7 @@ void BreakUpIslands::BreakUpIsland(

 }  // namespace

-std::unique_ptr<OpPassBase<FuncOp>> CreateBreakUpIslandsPass() {
+std::unique_ptr<OperationPass<FuncOp>> CreateBreakUpIslandsPass() {
   return std::make_unique<BreakUpIslands>();
 }

@ -45,7 +45,7 @@ namespace {
 // otherwise _tf operations are wrapped in an island and the _ prefix is
 // removed. Control dependencies are moved to be handled by the island itself.
 struct ControlToExecutorDialectConversion
-    : public FunctionPass<ControlToExecutorDialectConversion> {
+    : public PassWrapper<ControlToExecutorDialectConversion, FunctionPass> {
   void runOnFunction() override;

  private:

@ -237,7 +237,7 @@ void ControlToExecutorDialectConversion::runOnFunction() {
 }
 }

-OpPassBase<FuncOp> *CreateTFControlToExecutorDialectConversion() {
+OperationPass<FuncOp> *CreateTFControlToExecutorDialectConversion() {
   return new ControlToExecutorDialectConversion();
 }

@ -39,7 +39,7 @@ namespace mlir {

 namespace {
 struct ExecutorToControlDialectConversion
-    : public FunctionPass<ExecutorToControlDialectConversion> {
+    : public PassWrapper<ExecutorToControlDialectConversion, FunctionPass> {
   void runOnFunction() override;
 };
 }  // end anonymous namespace

@ -230,7 +230,7 @@ void ExecutorToControlDialectConversion::runOnFunction() {
   graph.erase();
 }

-std::unique_ptr<OpPassBase<FuncOp>>
+std::unique_ptr<OperationPass<FuncOp>>
 CreateTFExecutorToControlDialectConversion() {
   return std::make_unique<ExecutorToControlDialectConversion>();
 }

@ -40,7 +40,7 @@ namespace {
 //    return %graph_results#...
 // }
 struct FunctionalToExecutorDialectConversion
-    : public FunctionPass<FunctionalToExecutorDialectConversion> {
+    : public PassWrapper<FunctionalToExecutorDialectConversion, FunctionPass> {
   void runOnFunction() override;
 };
 }  // end anonymous namespace

@ -95,7 +95,7 @@ void FunctionalToExecutorDialectConversion::runOnFunction() {
 }
 }

-std::unique_ptr<OpPassBase<FuncOp>>
+std::unique_ptr<OperationPass<FuncOp>>
 CreateFunctionalToExecutorDialectConversionPass() {
   return std::make_unique<FunctionalToExecutorDialectConversion>();
 }

@ -343,7 +343,8 @@ class BufferAssignmentAnalysis {
 /// the right positions. It uses the algorithm described at the top of the file.
 // TODO(dfki): create a templated version that allows to match dialect-specific
 // alloc/dealloc nodes and to insert dialect-specific dealloc node.
-struct BufferAssignmentPass : mlir::FunctionPass<BufferAssignmentPass> {
+struct BufferAssignmentPass
+    : mlir::PassWrapper<BufferAssignmentPass, FunctionPass> {
   void runOnFunction() override {
     // Get required analysis information first.
     auto& analysis = getAnalysis<BufferAssignmentAnalysis>();

@ -471,7 +472,7 @@ void FunctionAndBlockSignatureConverter::addDynamicallyLegalFuncOp(
 // Buffer assignment pass registrations
 //===----------------------------------------------------------------------===//

-std::unique_ptr<OpPassBase<FuncOp>> createBufferAssignmentPass() {
+std::unique_ptr<OperationPass<FuncOp>> createBufferAssignmentPass() {
   return absl::make_unique<BufferAssignmentPass>();
 }

@ -482,14 +483,15 @@ static PassRegistration<BufferAssignmentPass> buffer_assignment_pass(

 /// A simple pass to print debug/test information for the buffer assignment
 /// analysis.
-struct BufferAssignmentTestPass : mlir::FunctionPass<BufferAssignmentTestPass> {
+struct BufferAssignmentTestPass
+    : mlir::PassWrapper<BufferAssignmentTestPass, FunctionPass> {
   void runOnFunction() override {
     llvm::outs() << "Testing : " << getFunction().getName() << "\n";
     getAnalysis<BufferAssignmentAnalysis>().print(llvm::outs());
   };
 };

-std::unique_ptr<OpPassBase<FuncOp>> createBufferAssignmentTestPass() {
+std::unique_ptr<OperationPass<FuncOp>> createBufferAssignmentTestPass() {
   return absl::make_unique<BufferAssignmentTestPass>();
 }

@ -324,7 +324,8 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern {
 //   "xla_lhlo.terminator"() : () -> ()
 // }

-struct HloLegalizeToLhlo : public OperationPass<HloLegalizeToLhlo, ModuleOp> {
+struct HloLegalizeToLhlo
+    : public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
   void runOnOperation() override {
     OwningRewritePatternList patterns;
     auto& context = getContext();

@ -473,7 +474,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
   // clang-format on
 }

-std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeToLhloPass() {
+std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
   return absl::make_unique<HloLegalizeToLhlo>();
 }

@ -37,7 +37,8 @@ using mlir::PassRegistration;
 namespace mlir {
 namespace xla_hlo {
 namespace {
-struct LegalizeControlFlow : public mlir::FunctionPass<LegalizeControlFlow> {
+struct LegalizeControlFlow
+    : public mlir::PassWrapper<LegalizeControlFlow, FunctionPass> {
   // Perform the lowering to MLIR control flow.
   void runOnFunction() override;
 };

@ -227,7 +228,7 @@ void LegalizeControlFlow::runOnFunction() {
 }  // namespace xla_hlo
 }  // namespace mlir

-std::unique_ptr<mlir::OpPassBase<mlir::FuncOp>>
+std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
 mlir::xla_hlo::createLegalizeControlFlowPass() {
   return std::make_unique<LegalizeControlFlow>();
 }

@ -55,7 +55,7 @@ namespace mlir {
 namespace xla_hlo {
 namespace {

-class LegalizeTF : public FunctionPass<LegalizeTF> {
+class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
  public:
   LegalizeTF() = default;
   LegalizeTF(const LegalizeTF &) {}

@ -3829,7 +3829,7 @@ static PassRegistration<LegalizeTF> pass(

 }  // end namespace

-std::unique_ptr<OpPassBase<FuncOp>> createLegalizeTFPass(
+std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
     bool allow_partial_conversion) {
   return std::make_unique<LegalizeTF>(allow_partial_conversion);
 }

@ -52,13 +52,13 @@ namespace mlir {
 namespace xla_hlo {
 namespace {
 class LegalizeTFControlFlow
-    : public OperationPass<LegalizeTFControlFlow, ModuleOp> {
+    : public PassWrapper<LegalizeTFControlFlow, OperationPass<ModuleOp>> {
  public:
   void runOnOperation() override;
 };
 }  // namespace

-std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
 createLegalizeTFControlFlowPass() {
   return std::make_unique<LegalizeTFControlFlow>();
 }

@ -333,7 +333,7 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
   return success();
 }

-class LegalizeTF : public FunctionPass<LegalizeTF> {
+class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
  public:
   LegalizeTF() = default;

@ -177,13 +177,14 @@ class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
 }  // end anonymous namespace

 namespace {
-struct LegalizeToStandard : public FunctionPass<LegalizeToStandard> {
+struct LegalizeToStandard
+    : public PassWrapper<LegalizeToStandard, FunctionPass> {
   /// Perform the lowering to Standard dialect.
   void runOnFunction() override;
 };
 }  // end anonymous namespace

-std::unique_ptr<mlir::OpPassBase<mlir::FuncOp>> createLegalizeToStdPass() {
+std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> createLegalizeToStdPass() {
   return std::make_unique<LegalizeToStandard>();
 }

@ -30,7 +30,7 @@ namespace {
 // arguments. All uses of each buffer are replaced with the corresponding block
 // argument and the buffer is freed. Note that this pass only works in regions
 // with a single block.
-struct LhloCopyRemoval : mlir::OperationPass<LhloCopyRemoval> {
+struct LhloCopyRemoval : mlir::PassWrapper<LhloCopyRemoval, OperationPass<>> {
   void runOnOperation() override {
     llvm::SmallVector<mlir::Operation*, 2> eraseList;
     auto operation = getOperation();

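LhloCopyRemoval above shows the third variant: a pass that is not anchored to a specific op. An unparameterized OperationPass<Derived> becomes PassWrapper<Derived, OperationPass<>>, and getOperation() hands back a plain Operation*. A sketch with a hypothetical name:

    #include "mlir/IR/Operation.h"
    #include "mlir/Pass/Pass.h"

    namespace {
    struct ExampleGenericPass
        : public mlir::PassWrapper<ExampleGenericPass, mlir::OperationPass<>> {
      void runOnOperation() override {
        // No anchored op type, so this is just an Operation*.
        mlir::Operation *op = getOperation();
        (void)op;
      }
    };
    }  // namespace
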
@ -30,7 +30,7 @@ namespace {

 using linalg::LinalgOp;

-class LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
+class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
  public:
   LhloFuseLinalg() = default;
   LhloFuseLinalg(const LhloFuseLinalg&) {}

@ -123,7 +123,7 @@ class LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {

 }  // namespace

-std::unique_ptr<OpPassBase<FuncOp>> createLhloFuseLinalg(
+std::unique_ptr<OperationPass<FuncOp>> createLhloFuseLinalg(
     bool use_parallel_loops, ArrayRef<unsigned> tile_sizes) {
   return absl::make_unique<LhloFuseLinalg>(use_parallel_loops, tile_sizes);
 }

@ -81,7 +81,8 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
   // clang-format on
 }

-struct LhloLegalizeToAffine : public FunctionPass<LhloLegalizeToAffine> {
+struct LhloLegalizeToAffine
+    : public PassWrapper<LhloLegalizeToAffine, FunctionPass> {
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     auto func = getFunction();

@ -92,7 +93,7 @@ struct LhloLegalizeToAffine : public FunctionPass<LhloLegalizeToAffine> {

 }  // namespace

-std::unique_ptr<OpPassBase<FuncOp>> createLegalizeToAffinePass() {
+std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass() {
   return absl::make_unique<LhloLegalizeToAffine>();
 }

@ -168,7 +168,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
   };
 };

-struct LhloLegalizeToGpu : public FunctionPass<LhloLegalizeToGpu> {
+struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> {
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     ConversionTarget target(getContext());

@ -186,7 +186,7 @@ struct LhloLegalizeToGpu : public FunctionPass<LhloLegalizeToGpu> {

 }  // namespace

-std::unique_ptr<OpPassBase<FuncOp>> createLegalizeToGpuPass() {
+std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass() {
   return absl::make_unique<LhloLegalizeToGpu>();
 }

@ -452,7 +452,7 @@ class ReduceWindowOpConverter
 };

 struct LhloLegalizeToParallelLoops
-    : public FunctionPass<LhloLegalizeToParallelLoops> {
+    : public PassWrapper<LhloLegalizeToParallelLoops, FunctionPass> {
   void runOnFunction() override {
     auto func = getFunction();

@ -478,7 +478,7 @@ struct LhloLegalizeToParallelLoops

 }  // namespace

-std::unique_ptr<OpPassBase<FuncOp>> createLegalizeLhloToParallelLoopsPass() {
+std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass() {
   return absl::make_unique<LhloLegalizeToParallelLoops>();
 }

@ -38,11 +38,12 @@ limitations under the License.
 using mlir::FunctionPass;
 using mlir::OwningRewritePatternList;
 using mlir::PassRegistration;
+using mlir::PassWrapper;

 namespace {
-class LowerComplex : public FunctionPass<LowerComplex> {
+class LowerComplex : public PassWrapper<LowerComplex, FunctionPass> {
  public:
-  explicit LowerComplex() : FunctionPass<LowerComplex>() {}
+  explicit LowerComplex() : PassWrapper<LowerComplex, FunctionPass>() {}

   /// Performs the lowering to XLA dialect.
   void runOnFunction() override;

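Files that pull pass classes in with using-declarations, as the hunk above does, need two extra touches: a using mlir::PassWrapper; alongside the existing ones, and any constructor that delegated to FunctionPass<T>() now delegates to PassWrapper<T, FunctionPass>(). A sketch with hypothetical names:

    #include "mlir/Pass/Pass.h"

    using mlir::FunctionPass;
    using mlir::PassWrapper;

    namespace {
    class ExampleLoweringPass
        : public PassWrapper<ExampleLoweringPass, FunctionPass> {
     public:
      // Previously delegated to FunctionPass<ExampleLoweringPass>().
      explicit ExampleLoweringPass()
          : PassWrapper<ExampleLoweringPass, FunctionPass>() {}

      void runOnFunction() override {}
    };
    }  // namespace
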
@ -39,6 +39,7 @@ using mlir::MLIRContext;
 using mlir::OpRewritePattern;
 using mlir::OwningRewritePatternList;
 using mlir::PassRegistration;
+using mlir::PassWrapper;
 using mlir::PatternRewriter;
 using mlir::RankedTensorType;
 using mlir::success;

@ -170,7 +171,8 @@ struct GeneralDotConvert
   }
 };

-struct LegalizeGeneralDot : public FunctionPass<LegalizeGeneralDot> {
+struct LegalizeGeneralDot
+    : public PassWrapper<LegalizeGeneralDot, FunctionPass> {
   /// Lower all general dots that can be represented as a non-batched matmul.
   void runOnFunction() override {
     OwningRewritePatternList patterns;

@ -28,7 +28,7 @@ namespace xla_hlo {
 namespace {

 struct TestMaterializeBroadcastsPass
-    : public FunctionPass<TestMaterializeBroadcastsPass> {
+    : public PassWrapper<TestMaterializeBroadcastsPass, FunctionPass> {
   void runOnFunction() override {
     ConversionTarget conversionTarget(getContext());
     OwningRewritePatternList conversionPatterns;