diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD index 5bbfba773a3..6a47be332d0 100644 --- a/tensorflow/compiler/mlir/python/BUILD +++ b/tensorflow/compiler/mlir/python/BUILD @@ -10,6 +10,7 @@ cc_library( deps = [ "//tensorflow/c:tf_status", "//tensorflow/c:tf_status_helper", + "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes", @@ -35,6 +36,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index f1f6c43d3b3..8bec288cda5 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -16,11 +16,13 @@ limitations under the License. #include <string> #include "llvm/Support/raw_ostream.h" +#include "mlir/InitAllPasses.h" // from @llvm-project #include "mlir/Parser.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" @@ -41,7 +43,6 @@ std::string ImportGraphDef(const std::string &proto, GraphDebugInfo debug_info; GraphImportConfig specs; mlir::MLIRContext context; - context.loadAllGloballyRegisteredDialects(); auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context); if (!module.ok()) { Set_TF_Status_from_Status(status, module.status()); @@ -86,7 +87,6 @@ std::string ExperimentalConvertSavedModelToMlir( std::vector<string> exported_names = absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); mlir::MLIRContext context; - context.loadAllGloballyRegisteredDialects(); auto module_or = ConvertSavedModelToMlir( &bundle, &context, absl::Span<std::string>(exported_names)); if (!module_or.status().ok()) { @@ -117,7 +117,6 @@ std::string ExperimentalConvertSavedModelV1ToMlir( // Convert the SavedModelBundle to an MLIR module. mlir::MLIRContext context; - context.loadAllGloballyRegisteredDialects(); auto module_or = ConvertSavedModelV1ToMlir(bundle, {}, &context, upgrade_legacy); if (!module_or.status().ok()) { @@ -153,6 +152,7 @@ std::string ExperimentalRunPassPipeline(const std::string &mlir_txt, bool show_debug_info, TF_Status *status) { mlir::MLIRContext context; + mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry()); mlir::OwningModuleRef module; { mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); @@ -167,6 +167,7 @@ std::string ExperimentalRunPassPipeline(const std::string &mlir_txt, mlir::PassManager pm(&context); std::string error; llvm::raw_string_ostream error_stream(error); + mlir::registerAllPasses(); if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) { TF_SetStatus(status, TF_INVALID_ARGUMENT, ("Invalid pass_pipeline: " + error_stream.str()).c_str()); diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc index 4152b576e71..6cd49cf368d 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc @@ -22,23 +22,25 @@ limitations under the License. #include "mlir/Parser.h" // from @llvm-project #include "pybind11/pybind11.h" #include "pybind11/stl.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/python/lib/core/pybind11_lib.h" #include "tensorflow/python/lib/core/pybind11_status.h" PYBIND11_MODULE(mlir_wrapper, m) { - m.def("registerDialects", []() { - mlir::registerDialect<mlir::TF::TensorFlowDialect>(); - mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>(); - mlir::registerDialect<mlir::StandardOpsDialect>(); + m.def("preloadTensorFlowDialects", [](mlir::MLIRContext &context) { + mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry()); + context.getDialectRegistry().loadAll(&context); }); + m.def("verify", [](std::string input) { llvm::SourceMgr SM = llvm::SourceMgr(); SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input), llvm::SMLoc()); mlir::MLIRContext ctx; - ctx.loadAllGloballyRegisteredDialects(); + mlir::RegisterAllTensorFlowDialects(ctx.getDialectRegistry()); + ctx.getDialectRegistry().loadAll(&ctx); auto module = mlir::parseSourceFile(SM, &ctx); if (!module) { return false; diff --git a/tensorflow/python/tf_program/pywrap_tfd.py b/tensorflow/python/tf_program/pywrap_tfd.py index 0d9a236f5d3..a7a30b71f4e 100644 --- a/tensorflow/python/tf_program/pywrap_tfd.py +++ b/tensorflow/python/tf_program/pywrap_tfd.py @@ -137,8 +137,8 @@ class TFProgram(object): """Python wrap for a Tensorflow Program (essentially an mlir Module).""" def __init__(self): - mlir.registerDialects() self.ctx = mlir.MLIRContext() + mlir.preloadTensorFlowDialects(self.ctx) self.builder = mlir.Builder(self.ctx) self.module = mlir.ModuleOp.create(mlir.UnknownLoc.get(self.ctx)) self.curr_func = None