diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index de97c6ff210..037016abcb5 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -85,7 +85,14 @@ class XLATestCase(test.TestCase): super(XLATestCase, self).__init__(method_name) if 'XLA' in FLAGS.test_device: context.context().enable_xla_devices() - context.context().enable_mlir_bridge = test_util.is_mlir_bridge_enabled() + + # Check if the mlir bridge has been explicitly enabled or disabled. If + # is_mlir_bridge_enabled() returns None, the user did not explictly enable + # or disable the bridge so do not update enable_mlir_bridge. + if test_util.is_mlir_bridge_enabled(): + context.context().enable_mlir_bridge = True + elif test_util.is_mlir_bridge_enabled() is not None: + context.context().enable_mlir_bridge = False self.device = FLAGS.test_device self.has_custom_call = (self.device == 'XLA_CPU')