diff --git a/tensorflow/core/platform/default/distribute.bzl b/tensorflow/core/platform/default/distribute.bzl
index 46a5d826a79..b16d5e8cff7 100644
--- a/tensorflow/core/platform/default/distribute.bzl
+++ b/tensorflow/core/platform/default/distribute.bzl
@@ -22,6 +22,7 @@ def distribute_py_test(
         full_precision = False,
         disable_v2 = False,
         disable_v3 = False,
+        disable_mlir_bridge = True,
         **kwargs):
     """Generates py_test targets for CPU and GPU.
 
@@ -40,6 +41,7 @@ def distribute_py_test(
         full_precision: unused.
         disable_v2: whether tests for TPU version 2 should be generated.
         disable_v3: whether tests for TPU version 3 should be generated.
+        disable_mlir_bridge: whether to also run this with the mlir bridge enabled.
         **kwargs: extra keyword arguments to the non-tpu test.
     """
 
@@ -77,6 +79,7 @@ def distribute_py_test(
             tags = tpu_tags,
             disable_v2 = disable_v2,
             disable_v3 = disable_v3,
+            disable_mlir_bridge = disable_mlir_bridge,
         )
 
 register_extension_info(
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index a46bb7c9bda..8ddbcf34f3b 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -1933,6 +1933,9 @@ class TensorFlowTestCase(googletest.TestCase):
       # disable it here.
       pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True)
 
+    if is_mlir_bridge_enabled():
+      context.context().enable_mlir_bridge = True
+
     self._threads = []
     self._tempdir = None
     self._cached_session = None
diff --git a/tensorflow/python/tpu/tpu.bzl b/tensorflow/python/tpu/tpu.bzl
index 5453702d64d..3c26d9b49bf 100644
--- a/tensorflow/python/tpu/tpu.bzl
+++ b/tensorflow/python/tpu/tpu.bzl
@@ -25,6 +25,7 @@ def tpu_py_test(
         disable_v2 = False,
         disable_v3 = False,
         disable_experimental = False,
+        disable_mlir_bridge = True,
         args = [],
         **kwargs):
     """Generates identical unit test variants for various Cloud TPU versions.
@@ -37,6 +38,7 @@ def tpu_py_test(
         disable_v2: If true, don't generate TPU v2 tests.
         disable_v3: If true, don't generate TPU v3 tests.
         disable_experimental: Unused.
+        disable_mlir_bridge: Unused.
         args: Arguments to apply to tests.
         **kwargs: Additional named arguments to apply to tests.
     """