diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index e54cfddd51a..731394a89da 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -43,6 +43,7 @@ cc_library( "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "//tensorflow/core:lib", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:MlirOptLib", diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 297b1459fc5..b31da15c35f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -30,6 +30,7 @@ limitations under the License. #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Threading.h" #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project @@ -74,6 +75,10 @@ constexpr char kTfLiteInputIndices[] = "_tflite_input_indices"; // Legalize operations in functions. class LegalizeTF : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: LegalizeTF() = default; LegalizeTF(const LegalizeTF&) {} diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc index 6202507ae91..8b54ca42dab 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc @@ -33,6 +33,10 @@ namespace { // cond and body regions. struct LegalizeWhile : public PassWrapper> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void RunOnFunction(FuncOp func); void runOnOperation() override { diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index d0b0e2d0063..172ce59ddd4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -110,6 +110,10 @@ class ConvertEmbeddedLookupFunc { class PrepareCompositeFunctionsPass : public PassWrapper> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: explicit PrepareCompositeFunctionsPass() {} diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 593d85d409c..115a5780e08 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1318,6 +1318,7 @@ cc_library( deps = [ ":convert_graphdef", ":mlir_roundtrip_flags", + ":tensorflow", "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index 6883d0358ec..2eaa511dbfe 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -115,6 +115,11 @@ static LogicalResult Verify(SessionInitializerOp session_initializer) { TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context) : Dialect(/*name=*/"tf_saved_model", context, TypeID::get()) { + // The TensorFlow Dialect is needed in the verifier and other routines + // associated to this dialect. It makes little sense anyway to use the + // SavedModel dialect without the TensorFlow Dialect. + context->loadDialect(); + addOperations< #define GET_OP_LIST #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc index 2b8ab85be38..e85058a1964 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc @@ -39,6 +39,10 @@ namespace { struct ClusterFormationPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override; }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index ad241ef9488..e64206d13d6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -615,6 +615,10 @@ class ConvertReduceOpToTfMin : public OpConversionPattern { }; class LegalizeHloToTf : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + public: LegalizeHloToTf() = default; LegalizeHloToTf(const LegalizeHloToTf &) {} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc index 527af0934ea..352604955c0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc @@ -39,6 +39,10 @@ namespace { struct ParallelizeEmbeddingParamsOpsPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override; }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc index 1e4caaf5dd6..52ac87ecf71 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/IR/Identifier.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" @@ -43,6 +44,10 @@ namespace tensorflow { class GraphOptPass : public mlir::PassWrapper> { + void getDependentDialects(mlir::DialectRegistry& registry) const override { + mlir::RegisterAllTensorFlowDialects(registry); + } + public: explicit GraphOptPass(std::vector passes) : passes_(std::move(passes)) {} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc index f5bdd08d980..13be3ad75b6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc @@ -78,6 +78,10 @@ using ClusterMap = llvm::SmallDenseMap { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction( FuncOp func, const TF::ResourceAliasAnalysis::Info& resource_alias_analysis); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc index 0a69987deb0..b65f07c39ac 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc @@ -43,6 +43,10 @@ namespace { class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass< BreakUpIslands, TF::SideEffectAnalysis> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: void runOnFunction(FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index c539ce9b468..b78f3112bdb 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -144,8 +144,9 @@ bool IsResourceOutputShapesAttribute(const AttrValue& attr_value, void LoadImporterDialects(mlir::MLIRContext& context) { // Load dialects involved in the conversion - mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry()); - context.getDialectRegistry().loadAll(&context); + mlir::DialectRegistry registry; + mlir::RegisterAllTensorFlowDialects(registry); + registry.loadAll(&context); } // This class is used to generate new MLIR function name strings that are both diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index 2a5caf35dd5..e48b14a6bc3 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/core/platform/init_main.h" int main(int argc, char **argv) { @@ -34,6 +35,7 @@ int main(int argc, char **argv) { mlir::mhlo::registerAllMhloDialects(registry); registry.insert(); registry.insert(); + registry.insert(); return failed( mlir::MlirOptMain(argc, argv, "TensorFlow pass driver\n", registry)); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc index ef07c801bc4..10a9d0c5515 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc @@ -67,6 +67,10 @@ class UnrankedTensorStoreTestOnlyPattern }; struct BufferizePass : public BufferizePassBase { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: void runOnOperation() override { OwningRewritePatternList patterns; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc index a0cfcae65d1..a26198e5871 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc @@ -36,6 +36,10 @@ static constexpr StringRef kTFEntry = "tf_entry"; // * std.dealloc becomes tf_framework.dealloc_raw. class EmbedTFFrameworkPass : public EmbedTFFrameworkPassBase { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: void runOnOperation() override { ModuleOp m = getOperation(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc index 28d3647bb63..d27f14d304d 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc @@ -38,6 +38,10 @@ namespace { struct ShapeToDescriptorsPass : public ShapeToDescriptorsPassBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + public: void runOnOperation() override { MLIRContext &ctx = getContext(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm_pass.cc index 42e89433dff..6e6d71f8365 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm_pass.cc @@ -33,6 +33,10 @@ namespace { class TestTFFrameworkToLLVMPass : public TestTFFrameworkLegalizeToLLVMPassBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + public: void runOnOperation() override { ModuleOp m = getOperation(); diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 262b526f11f..0332dbee589 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -72,7 +72,8 @@ constexpr char kShardingAttr[] = "mhlo.sharding"; class LegalizeTF : public PassWrapper { void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } public: diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc index 1f884b1bdea..6320ad2032b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc @@ -60,6 +60,10 @@ const char kXlaHostTransferOriginalTypeAttr[] = // ops other than certain control flow ops (`mhlo.if`, `mhlo.while`). class LegalizeTFCommunication : public PassWrapper> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: void runOnOperation() override; };