Fix XLATestCase enable mlir bridge logic

XLA has its own test case class separate from TensorFlowTestCase.
The XLATestCase was not differentiating between is_mlir_bridge_enabled()
returning False vs None. False indicates that the mlir bridge was
forcibly disabled. None indicates the user didn't specify a preference.

Update the XLATestCase class to only disable the mlir bridge when
is_mlir_bridge_enabled() is False.

PiperOrigin-RevId: 356538298
Change-Id: I13ef9fcaf725cd656780999c020e832a95c9556b
This commit is contained in:
Marissa Ikonomidis 2021-02-09 10:43:39 -08:00 committed by TensorFlower Gardener
parent e1b9195cae
commit 417c2d448c

View File

@ -85,7 +85,14 @@ class XLATestCase(test.TestCase):
super(XLATestCase, self).__init__(method_name)
if 'XLA' in FLAGS.test_device:
context.context().enable_xla_devices()
context.context().enable_mlir_bridge = test_util.is_mlir_bridge_enabled()
# Check if the mlir bridge has been explicitly enabled or disabled. If
# is_mlir_bridge_enabled() returns None, the user did not explictly enable
# or disable the bridge so do not update enable_mlir_bridge.
if test_util.is_mlir_bridge_enabled():
context.context().enable_mlir_bridge = True
elif test_util.is_mlir_bridge_enabled() is not None:
context.context().enable_mlir_bridge = False
self.device = FLAGS.test_device
self.has_custom_call = (self.device == 'XLA_CPU')