Add an option to enable MLIR bridge for tpu_py_test rule
If enable_mlir_bridge is True, a new test will be generated that runs with the MLIR bridge enabled. This option is off by default. PiperOrigin-RevId: 317173675 Change-Id: I332e1ae24cf82fceea20fd0aff2cec7c9b236a24
This commit is contained in:
parent
fc51511308
commit
c41f4652b4
@ -22,6 +22,7 @@ def distribute_py_test(
|
||||
full_precision = False,
|
||||
disable_v2 = False,
|
||||
disable_v3 = False,
|
||||
disable_mlir_bridge = True,
|
||||
**kwargs):
|
||||
"""Generates py_test targets for CPU and GPU.
|
||||
|
||||
@ -40,6 +41,7 @@ def distribute_py_test(
|
||||
full_precision: unused.
|
||||
disable_v2: whether tests for TPU version 2 should be generated.
|
||||
disable_v3: whether tests for TPU version 3 should be generated.
|
||||
disable_mlir_bridge: whether to also run this with the mlir bridge enabled.
|
||||
**kwargs: extra keyword arguments to the non-tpu test.
|
||||
"""
|
||||
|
||||
@ -77,6 +79,7 @@ def distribute_py_test(
|
||||
tags = tpu_tags,
|
||||
disable_v2 = disable_v2,
|
||||
disable_v3 = disable_v3,
|
||||
disable_mlir_bridge = disable_mlir_bridge,
|
||||
)
|
||||
|
||||
register_extension_info(
|
||||
|
@ -1933,6 +1933,9 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
# disable it here.
|
||||
pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True)
|
||||
|
||||
if is_mlir_bridge_enabled():
|
||||
context.context().enable_mlir_bridge = True
|
||||
|
||||
self._threads = []
|
||||
self._tempdir = None
|
||||
self._cached_session = None
|
||||
|
@ -25,6 +25,7 @@ def tpu_py_test(
|
||||
disable_v2 = False,
|
||||
disable_v3 = False,
|
||||
disable_experimental = False,
|
||||
disable_mlir_bridge = True,
|
||||
args = [],
|
||||
**kwargs):
|
||||
"""Generates identical unit test variants for various Cloud TPU versions.
|
||||
@ -37,6 +38,7 @@ def tpu_py_test(
|
||||
disable_v2: If true, don't generate TPU v2 tests.
|
||||
disable_v3: If true, don't generate TPU v3 tests.
|
||||
disable_experimental: Unused.
|
||||
disable_mlir_bridge: Unused.
|
||||
args: Arguments to apply to tests.
|
||||
**kwargs: Additional named arguments to apply to tests.
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user