Enable //third_party/tensorflow/python/keras:multi_gpu_utils_test_xla_gpu

//third_party/tensorflow/python/keras:multi_gpu_utils_test_xla_gpu seems to pass
now.  Not sure exactly what changed, could be that this was fixed by
cr/284627940.

PiperOrigin-RevId: 287712795
Change-Id: Ib3bc033b3a6ab1de61bd49bbf79e9f3c64cc51b9
This commit is contained in:
Sanjoy Das 2019-12-31 22:15:29 -08:00 committed by TensorFlower Gardener
parent 45e1e4598d
commit 31af0e31e1
2 changed files with 2 additions and 5 deletions

View File

@ -1418,7 +1418,7 @@ cuda_py_test(
"guitar",
"multi_gpu",
],
xla_enable_strict_auto_jit = False, # b/142744009
xla_enable_strict_auto_jit = True,
deps = [
":keras",
"//tensorflow/python:client_testlib",

View File

@ -44,10 +44,7 @@ class TestMultiGPUModel(test.TestCase):
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
super(TestMultiGPUModel, self).__init__(methodName)
gpu_devices = config.list_physical_devices('GPU')
xla_gpu_devices = config.list_physical_devices('XLA_GPU')
# NOTE: XLA devices don't support the set_logical_device_configuration
# codepaths.
if len(gpu_devices) == 1 and not xla_gpu_devices:
if len(gpu_devices) == 1:
# A GPU is available, simulate 2 instead.
config.set_logical_device_configuration(gpu_devices[0], [
context.LogicalDeviceConfiguration(500),