diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD
index 5bbfba773a3..6a47be332d0 100644
--- a/tensorflow/compiler/mlir/python/BUILD
+++ b/tensorflow/compiler/mlir/python/BUILD
@@ -10,6 +10,7 @@ cc_library(
     deps = [
         "//tensorflow/c:tf_status",
         "//tensorflow/c:tf_status_helper",
+        "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
         "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes",
@@ -35,6 +36,7 @@ cc_library(
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Parser",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc
index f1f6c43d3b3..8bec288cda5 100644
--- a/tensorflow/compiler/mlir/python/mlir.cc
+++ b/tensorflow/compiler/mlir/python/mlir.cc
@@ -16,11 +16,13 @@ limitations under the License.
 #include <string>
 
 #include "llvm/Support/raw_ostream.h"
+#include "mlir/InitAllPasses.h"  // from @llvm-project
 #include "mlir/Parser.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
 #include "tensorflow/c/tf_status.h"
 #include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
@@ -41,7 +43,6 @@ std::string ImportGraphDef(const std::string &proto,
   GraphDebugInfo debug_info;
   GraphImportConfig specs;
   mlir::MLIRContext context;
-  context.loadAllGloballyRegisteredDialects();
   auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
   if (!module.ok()) {
     Set_TF_Status_from_Status(status, module.status());
@@ -86,7 +87,6 @@ std::string ExperimentalConvertSavedModelToMlir(
   std::vector<string> exported_names =
       absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
   mlir::MLIRContext context;
-  context.loadAllGloballyRegisteredDialects();
   auto module_or = ConvertSavedModelToMlir(
       &bundle, &context, absl::Span<std::string>(exported_names));
   if (!module_or.status().ok()) {
@@ -117,7 +117,6 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
   // Convert the SavedModelBundle to an MLIR module.
 
   mlir::MLIRContext context;
-  context.loadAllGloballyRegisteredDialects();
   auto module_or =
       ConvertSavedModelV1ToMlir(bundle, {}, &context, upgrade_legacy);
   if (!module_or.status().ok()) {
@@ -153,6 +152,7 @@ std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
                                         bool show_debug_info,
                                         TF_Status *status) {
   mlir::MLIRContext context;
+  mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
   mlir::OwningModuleRef module;
   {
     mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
@@ -167,6 +167,7 @@ std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
   mlir::PassManager pm(&context);
   std::string error;
   llvm::raw_string_ostream error_stream(error);
+  mlir::registerAllPasses();
   if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
     TF_SetStatus(status, TF_INVALID_ARGUMENT,
                  ("Invalid pass_pipeline: " + error_stream.str()).c_str());
diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc
index 4152b576e71..6cd49cf368d 100644
--- a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc
+++ b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc
@@ -22,23 +22,25 @@ limitations under the License.
 #include "mlir/Parser.h"  // from @llvm-project
 #include "pybind11/pybind11.h"
 #include "pybind11/stl.h"
+#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/python/lib/core/pybind11_lib.h"
 #include "tensorflow/python/lib/core/pybind11_status.h"
 
 PYBIND11_MODULE(mlir_wrapper, m) {
-  m.def("registerDialects", []() {
-    mlir::registerDialect<mlir::TF::TensorFlowDialect>();
-    mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
-    mlir::registerDialect<mlir::StandardOpsDialect>();
+  m.def("preloadTensorFlowDialects", [](mlir::MLIRContext &context) {
+    mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
+    context.getDialectRegistry().loadAll(&context);
   });
+
   m.def("verify", [](std::string input) {
     llvm::SourceMgr SM = llvm::SourceMgr();
     SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
                           llvm::SMLoc());
     mlir::MLIRContext ctx;
-    ctx.loadAllGloballyRegisteredDialects();
+    mlir::RegisterAllTensorFlowDialects(ctx.getDialectRegistry());
+    ctx.getDialectRegistry().loadAll(&ctx);
     auto module = mlir::parseSourceFile(SM, &ctx);
     if (!module) {
       return false;
diff --git a/tensorflow/python/tf_program/pywrap_tfd.py b/tensorflow/python/tf_program/pywrap_tfd.py
index 0d9a236f5d3..a7a30b71f4e 100644
--- a/tensorflow/python/tf_program/pywrap_tfd.py
+++ b/tensorflow/python/tf_program/pywrap_tfd.py
@@ -137,8 +137,8 @@ class TFProgram(object):
   """Python wrap for a Tensorflow Program (essentially an mlir Module)."""
 
   def __init__(self):
-    mlir.registerDialects()
     self.ctx = mlir.MLIRContext()
+    mlir.preloadTensorFlowDialects(self.ctx)
     self.builder = mlir.Builder(self.ctx)
     self.module = mlir.ModuleOp.create(mlir.UnknownLoc.get(self.ctx))
     self.curr_func = None