Add explicit dependent dialects registration in TF MLIR
PiperOrigin-RevId: 328488901 Change-Id: Icb334bd4bf937a7634f6cd708f382026cae27051
This commit is contained in:
parent
e83b36531b
commit
0d336074b7
@ -43,6 +43,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
|
"//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
|
||||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||||
"//tensorflow/compiler/mlir/tensorflow",
|
"//tensorflow/compiler/mlir/tensorflow",
|
||||||
|
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||||
"@llvm-project//mlir:MlirOptLib",
|
"@llvm-project//mlir:MlirOptLib",
|
||||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/StringSwitch.h"
|
#include "llvm/ADT/StringSwitch.h"
|
||||||
#include "llvm/Support/Threading.h"
|
#include "llvm/Support/Threading.h"
|
||||||
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
|
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
|
||||||
|
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||||
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
|
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||||
#include "mlir/IR/Diagnostics.h" // from @llvm-project
|
#include "mlir/IR/Diagnostics.h" // from @llvm-project
|
||||||
@ -74,6 +75,10 @@ constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
|
|||||||
|
|
||||||
// Legalize operations in functions.
|
// Legalize operations in functions.
|
||||||
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<quant::QuantizationDialect, TFL::TensorFlowLiteDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
LegalizeTF() = default;
|
LegalizeTF() = default;
|
||||||
LegalizeTF(const LegalizeTF&) {}
|
LegalizeTF(const LegalizeTF&) {}
|
||||||
|
@ -33,6 +33,10 @@ namespace {
|
|||||||
// cond and body regions.
|
// cond and body regions.
|
||||||
struct LegalizeWhile
|
struct LegalizeWhile
|
||||||
: public PassWrapper<LegalizeWhile, OperationPass<ModuleOp>> {
|
: public PassWrapper<LegalizeWhile, OperationPass<ModuleOp>> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<TFL::TensorFlowLiteDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
void RunOnFunction(FuncOp func);
|
void RunOnFunction(FuncOp func);
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
|
@ -110,6 +110,10 @@ class ConvertEmbeddedLookupFunc {
|
|||||||
class PrepareCompositeFunctionsPass
|
class PrepareCompositeFunctionsPass
|
||||||
: public PassWrapper<PrepareCompositeFunctionsPass,
|
: public PassWrapper<PrepareCompositeFunctionsPass,
|
||||||
OperationPass<ModuleOp>> {
|
OperationPass<ModuleOp>> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<TFL::TensorFlowLiteDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit PrepareCompositeFunctionsPass() {}
|
explicit PrepareCompositeFunctionsPass() {}
|
||||||
|
|
||||||
|
@ -1318,6 +1318,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":convert_graphdef",
|
":convert_graphdef",
|
||||||
":mlir_roundtrip_flags",
|
":mlir_roundtrip_flags",
|
||||||
|
":tensorflow",
|
||||||
"//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
|
"//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
@ -115,6 +115,11 @@ static LogicalResult Verify(SessionInitializerOp session_initializer) {
|
|||||||
TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context)
|
TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context)
|
||||||
: Dialect(/*name=*/"tf_saved_model", context,
|
: Dialect(/*name=*/"tf_saved_model", context,
|
||||||
TypeID::get<TensorFlowSavedModelDialect>()) {
|
TypeID::get<TensorFlowSavedModelDialect>()) {
|
||||||
|
// The TensorFlow Dialect is needed in the verifier and other routines
|
||||||
|
// associated to this dialect. It makes little sense anyway to use the
|
||||||
|
// SavedModel dialect without the TensorFlow Dialect.
|
||||||
|
context->loadDialect<TF::TensorFlowDialect>();
|
||||||
|
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
|
||||||
|
@ -39,6 +39,10 @@ namespace {
|
|||||||
|
|
||||||
struct ClusterFormationPass
|
struct ClusterFormationPass
|
||||||
: public PassWrapper<ClusterFormationPass, FunctionPass> {
|
: public PassWrapper<ClusterFormationPass, FunctionPass> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<tf_device::TensorFlowDeviceDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -615,6 +615,10 @@ class ConvertReduceOpToTfMin : public OpConversionPattern<mhlo::ReduceOp> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
|
class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<TF::TensorFlowDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
LegalizeHloToTf() = default;
|
LegalizeHloToTf() = default;
|
||||||
LegalizeHloToTf(const LegalizeHloToTf &) {}
|
LegalizeHloToTf(const LegalizeHloToTf &) {}
|
||||||
|
@ -39,6 +39,10 @@ namespace {
|
|||||||
|
|
||||||
struct ParallelizeEmbeddingParamsOpsPass
|
struct ParallelizeEmbeddingParamsOpsPass
|
||||||
: public PassWrapper<ParallelizeEmbeddingParamsOpsPass, FunctionPass> {
|
: public PassWrapper<ParallelizeEmbeddingParamsOpsPass, FunctionPass> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<tf_device::TensorFlowDeviceDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include "mlir/IR/Identifier.h" // from @llvm-project
|
#include "mlir/IR/Identifier.h" // from @llvm-project
|
||||||
#include "mlir/IR/Location.h" // from @llvm-project
|
#include "mlir/IR/Location.h" // from @llvm-project
|
||||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||||
@ -43,6 +44,10 @@ namespace tensorflow {
|
|||||||
class GraphOptPass
|
class GraphOptPass
|
||||||
: public mlir::PassWrapper<GraphOptPass,
|
: public mlir::PassWrapper<GraphOptPass,
|
||||||
mlir::OperationPass<mlir::ModuleOp>> {
|
mlir::OperationPass<mlir::ModuleOp>> {
|
||||||
|
void getDependentDialects(mlir::DialectRegistry& registry) const override {
|
||||||
|
mlir::RegisterAllTensorFlowDialects(registry);
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit GraphOptPass(std::vector<tensorflow::GraphOptimizationPass*> passes)
|
explicit GraphOptPass(std::vector<tensorflow::GraphOptimizationPass*> passes)
|
||||||
: passes_(std::move(passes)) {}
|
: passes_(std::move(passes)) {}
|
||||||
|
@ -78,6 +78,10 @@ using ClusterMap = llvm::SmallDenseMap<llvm::StringRef,
|
|||||||
struct TPUClusterFormation
|
struct TPUClusterFormation
|
||||||
: public TF::PerFunctionAggregateAnalysisConsumerPass<
|
: public TF::PerFunctionAggregateAnalysisConsumerPass<
|
||||||
TPUClusterFormation, TF::ResourceAliasAnalysis> {
|
TPUClusterFormation, TF::ResourceAliasAnalysis> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<tf_device::TensorFlowDeviceDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
void runOnFunction(
|
void runOnFunction(
|
||||||
FuncOp func,
|
FuncOp func,
|
||||||
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis);
|
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis);
|
||||||
|
@ -43,6 +43,10 @@ namespace {
|
|||||||
|
|
||||||
class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass<
|
class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass<
|
||||||
BreakUpIslands, TF::SideEffectAnalysis> {
|
BreakUpIslands, TF::SideEffectAnalysis> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<tf_executor::TensorFlowExecutorDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
void runOnFunction(FuncOp func,
|
void runOnFunction(FuncOp func,
|
||||||
const TF::SideEffectAnalysis::Info& side_effect_analysis);
|
const TF::SideEffectAnalysis::Info& side_effect_analysis);
|
||||||
|
@ -144,8 +144,9 @@ bool IsResourceOutputShapesAttribute(const AttrValue& attr_value,
|
|||||||
|
|
||||||
void LoadImporterDialects(mlir::MLIRContext& context) {
|
void LoadImporterDialects(mlir::MLIRContext& context) {
|
||||||
// Load dialects involved in the conversion
|
// Load dialects involved in the conversion
|
||||||
mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
|
mlir::DialectRegistry registry;
|
||||||
context.getDialectRegistry().loadAll(&context);
|
mlir::RegisterAllTensorFlowDialects(registry);
|
||||||
|
registry.loadAll(&context);
|
||||||
}
|
}
|
||||||
|
|
||||||
// This class is used to generate new MLIR function name strings that are both
|
// This class is used to generate new MLIR function name strings that are both
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/init_mlir.h"
|
#include "tensorflow/compiler/mlir/init_mlir.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
|
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
|
||||||
#include "tensorflow/core/platform/init_main.h"
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
|
||||||
int main(int argc, char **argv) {
|
int main(int argc, char **argv) {
|
||||||
@ -34,6 +35,7 @@ int main(int argc, char **argv) {
|
|||||||
mlir::mhlo::registerAllMhloDialects(registry);
|
mlir::mhlo::registerAllMhloDialects(registry);
|
||||||
registry.insert<mlir::shape::ShapeDialect>();
|
registry.insert<mlir::shape::ShapeDialect>();
|
||||||
registry.insert<mlir::TFL::TensorFlowLiteDialect>();
|
registry.insert<mlir::TFL::TensorFlowLiteDialect>();
|
||||||
|
registry.insert<mlir::kernel_gen::tf_framework::TFFrameworkDialect>();
|
||||||
return failed(
|
return failed(
|
||||||
mlir::MlirOptMain(argc, argv, "TensorFlow pass driver\n", registry));
|
mlir::MlirOptMain(argc, argv, "TensorFlow pass driver\n", registry));
|
||||||
}
|
}
|
||||||
|
@ -67,6 +67,10 @@ class UnrankedTensorStoreTestOnlyPattern
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct BufferizePass : public BufferizePassBase<BufferizePass> {
|
struct BufferizePass : public BufferizePassBase<BufferizePass> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<lmhlo::LmhloDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
|
@ -36,6 +36,10 @@ static constexpr StringRef kTFEntry = "tf_entry";
|
|||||||
// * std.dealloc becomes tf_framework.dealloc_raw.
|
// * std.dealloc becomes tf_framework.dealloc_raw.
|
||||||
class EmbedTFFrameworkPass
|
class EmbedTFFrameworkPass
|
||||||
: public EmbedTFFrameworkPassBase<EmbedTFFrameworkPass> {
|
: public EmbedTFFrameworkPassBase<EmbedTFFrameworkPass> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<mlir::kernel_gen::tf_framework::TFFrameworkDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
ModuleOp m = getOperation();
|
ModuleOp m = getOperation();
|
||||||
|
@ -38,6 +38,10 @@ namespace {
|
|||||||
|
|
||||||
struct ShapeToDescriptorsPass
|
struct ShapeToDescriptorsPass
|
||||||
: public ShapeToDescriptorsPassBase<ShapeToDescriptorsPass> {
|
: public ShapeToDescriptorsPassBase<ShapeToDescriptorsPass> {
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<scf::SCFDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
MLIRContext &ctx = getContext();
|
MLIRContext &ctx = getContext();
|
||||||
|
@ -33,6 +33,10 @@ namespace {
|
|||||||
|
|
||||||
class TestTFFrameworkToLLVMPass
|
class TestTFFrameworkToLLVMPass
|
||||||
: public TestTFFrameworkLegalizeToLLVMPassBase<TestTFFrameworkToLLVMPass> {
|
: public TestTFFrameworkLegalizeToLLVMPassBase<TestTFFrameworkToLLVMPass> {
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<LLVM::LLVMDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
ModuleOp m = getOperation();
|
ModuleOp m = getOperation();
|
||||||
|
@ -72,7 +72,8 @@ constexpr char kShardingAttr[] = "mhlo.sharding";
|
|||||||
|
|
||||||
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect>();
|
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
|
||||||
|
shape::ShapeDialect, StandardOpsDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
@ -60,6 +60,10 @@ const char kXlaHostTransferOriginalTypeAttr[] =
|
|||||||
// ops other than certain control flow ops (`mhlo.if`, `mhlo.while`).
|
// ops other than certain control flow ops (`mhlo.if`, `mhlo.while`).
|
||||||
class LegalizeTFCommunication
|
class LegalizeTFCommunication
|
||||||
: public PassWrapper<LegalizeTFCommunication, OperationPass<ModuleOp>> {
|
: public PassWrapper<LegalizeTFCommunication, OperationPass<ModuleOp>> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<mhlo::MhloDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
void runOnOperation() override;
|
void runOnOperation() override;
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user