Add an option to enable MLIR bridge for tpu_py_test rule

If enable_mlir_bridge is True, a new test will be generated that runs with the MLIR bridge enabled.
This option is off by default.

PiperOrigin-RevId: 317173675
Change-Id: I332e1ae24cf82fceea20fd0aff2cec7c9b236a24
This commit is contained in:
Marissa Ikonomidis 2020-06-18 13:52:21 -07:00 committed by TensorFlower Gardener
parent fc51511308
commit c41f4652b4
3 changed files with 8 additions and 0 deletions

View File

@ -22,6 +22,7 @@ def distribute_py_test(
full_precision = False,
disable_v2 = False,
disable_v3 = False,
disable_mlir_bridge = True,
**kwargs):
"""Generates py_test targets for CPU and GPU.
@ -40,6 +41,7 @@ def distribute_py_test(
full_precision: unused.
disable_v2: whether tests for TPU version 2 should be generated.
disable_v3: whether tests for TPU version 3 should be generated.
disable_mlir_bridge: whether to also run this with the mlir bridge enabled.
**kwargs: extra keyword arguments to the non-tpu test.
"""
@ -77,6 +79,7 @@ def distribute_py_test(
tags = tpu_tags,
disable_v2 = disable_v2,
disable_v3 = disable_v3,
disable_mlir_bridge = disable_mlir_bridge,
)
register_extension_info(

View File

@ -1933,6 +1933,9 @@ class TensorFlowTestCase(googletest.TestCase):
# disable it here.
pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True)
if is_mlir_bridge_enabled():
context.context().enable_mlir_bridge = True
self._threads = []
self._tempdir = None
self._cached_session = None

View File

@ -25,6 +25,7 @@ def tpu_py_test(
disable_v2 = False,
disable_v3 = False,
disable_experimental = False,
disable_mlir_bridge = True,
args = [],
**kwargs):
"""Generates identical unit test variants for various Cloud TPU versions.
@ -37,6 +38,7 @@ def tpu_py_test(
disable_v2: If true, don't generate TPU v2 tests.
disable_v3: If true, don't generate TPU v3 tests.
disable_experimental: Unused.
disable_mlir_bridge: Unused.
args: Arguments to apply to tests.
**kwargs: Additional named arguments to apply to tests.
"""