diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 0134d6a583e..e52573da4af 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -1418,7 +1418,7 @@ cuda_py_test( "guitar", "multi_gpu", ], - xla_enable_strict_auto_jit = False, # b/142744009 + xla_enable_strict_auto_jit = True, deps = [ ":keras", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py index 7e9ec9358b3..465ace7f264 100644 --- a/tensorflow/python/keras/utils/multi_gpu_utils_test.py +++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py @@ -44,10 +44,7 @@ class TestMultiGPUModel(test.TestCase): def __init__(self, methodName='runTest'): # pylint: disable=invalid-name super(TestMultiGPUModel, self).__init__(methodName) gpu_devices = config.list_physical_devices('GPU') - xla_gpu_devices = config.list_physical_devices('XLA_GPU') - # NOTE: XLA devices don't support the set_logical_device_configuration - # codepaths. - if len(gpu_devices) == 1 and not xla_gpu_devices: + if len(gpu_devices) == 1: # A GPU is available, simulate 2 instead. config.set_logical_device_configuration(gpu_devices[0], [ context.LogicalDeviceConfiguration(500),