diff --git a/tensorflow/core/platform/default/distribute.bzl b/tensorflow/core/platform/default/distribute.bzl index 46a5d826a79..b16d5e8cff7 100644 --- a/tensorflow/core/platform/default/distribute.bzl +++ b/tensorflow/core/platform/default/distribute.bzl @@ -22,6 +22,7 @@ def distribute_py_test( full_precision = False, disable_v2 = False, disable_v3 = False, + disable_mlir_bridge = True, **kwargs): """Generates py_test targets for CPU and GPU. @@ -40,6 +41,7 @@ def distribute_py_test( full_precision: unused. disable_v2: whether tests for TPU version 2 should be generated. disable_v3: whether tests for TPU version 3 should be generated. + disable_mlir_bridge: whether to also run this with the mlir bridge enabled. **kwargs: extra keyword arguments to the non-tpu test. """ @@ -77,6 +79,7 @@ def distribute_py_test( tags = tpu_tags, disable_v2 = disable_v2, disable_v3 = disable_v3, + disable_mlir_bridge = disable_mlir_bridge, ) register_extension_info( diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index a46bb7c9bda..8ddbcf34f3b 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -1933,6 +1933,9 @@ class TensorFlowTestCase(googletest.TestCase): # disable it here. pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True) + if is_mlir_bridge_enabled(): + context.context().enable_mlir_bridge = True + self._threads = [] self._tempdir = None self._cached_session = None diff --git a/tensorflow/python/tpu/tpu.bzl b/tensorflow/python/tpu/tpu.bzl index 5453702d64d..3c26d9b49bf 100644 --- a/tensorflow/python/tpu/tpu.bzl +++ b/tensorflow/python/tpu/tpu.bzl @@ -25,6 +25,7 @@ def tpu_py_test( disable_v2 = False, disable_v3 = False, disable_experimental = False, + disable_mlir_bridge = True, args = [], **kwargs): """Generates identical unit test variants for various Cloud TPU versions. @@ -37,6 +38,7 @@ def tpu_py_test( disable_v2: If true, don't generate TPU v2 tests. disable_v3: If true, don't generate TPU v3 tests. disable_experimental: Unused. + disable_mlir_bridge: Unused. args: Arguments to apply to tests. **kwargs: Additional named arguments to apply to tests. """