From 417c2d448c11709696cf69a1d241b38955120f5f Mon Sep 17 00:00:00 2001 From: Marissa Ikonomidis Date: Tue, 9 Feb 2021 10:43:39 -0800 Subject: [PATCH] Fix XLATestCase enable mlir bridge logic XLA has its own test case class separate from TensorFlowTestCase. The XLATestCase was not differentiating between is_mlir_bridge_enabled() returning False vs None. False indicates that the mlir bridge was forcibly disabled. None indicates the user didn't specify a preference. Update the XLATestCase class to only disable the mlir bridge when is_mlir_bridge_enabled() is False. PiperOrigin-RevId: 356538298 Change-Id: I13ef9fcaf725cd656780999c020e832a95c9556b --- tensorflow/compiler/tests/xla_test.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index de97c6ff210..037016abcb5 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -85,7 +85,14 @@ class XLATestCase(test.TestCase): super(XLATestCase, self).__init__(method_name) if 'XLA' in FLAGS.test_device: context.context().enable_xla_devices() - context.context().enable_mlir_bridge = test_util.is_mlir_bridge_enabled() + + # Check if the mlir bridge has been explicitly enabled or disabled. If + # is_mlir_bridge_enabled() returns None, the user did not explictly enable + # or disable the bridge so do not update enable_mlir_bridge. + if test_util.is_mlir_bridge_enabled(): + context.context().enable_mlir_bridge = True + elif test_util.is_mlir_bridge_enabled() is not None: + context.context().enable_mlir_bridge = False self.device = FLAGS.test_device self.has_custom_call = (self.device == 'XLA_CPU')