Fix XLATestCase enable mlir bridge logic
XLA has its own test case class separate from TensorFlowTestCase. The XLATestCase was not differentiating between is_mlir_bridge_enabled() returning False vs None. False indicates that the mlir bridge was forcibly disabled. None indicates the user didn't specify a preference. Update the XLATestCase class to only disable the mlir bridge when is_mlir_bridge_enabled() is False. PiperOrigin-RevId: 356538298 Change-Id: I13ef9fcaf725cd656780999c020e832a95c9556b
This commit is contained in:
parent
e1b9195cae
commit
417c2d448c
@ -85,7 +85,14 @@ class XLATestCase(test.TestCase):
|
||||
super(XLATestCase, self).__init__(method_name)
|
||||
if 'XLA' in FLAGS.test_device:
|
||||
context.context().enable_xla_devices()
|
||||
context.context().enable_mlir_bridge = test_util.is_mlir_bridge_enabled()
|
||||
|
||||
# Check if the mlir bridge has been explicitly enabled or disabled. If
|
||||
# is_mlir_bridge_enabled() returns None, the user did not explictly enable
|
||||
# or disable the bridge so do not update enable_mlir_bridge.
|
||||
if test_util.is_mlir_bridge_enabled():
|
||||
context.context().enable_mlir_bridge = True
|
||||
elif test_util.is_mlir_bridge_enabled() is not None:
|
||||
context.context().enable_mlir_bridge = False
|
||||
|
||||
self.device = FLAGS.test_device
|
||||
self.has_custom_call = (self.device == 'XLA_CPU')
|
||||
|
Loading…
x
Reference in New Issue
Block a user