diff --git a/tensorflow/python/distribute/combinations.py b/tensorflow/python/distribute/combinations.py index ffa03ee5329..9a479a3769b 100644 --- a/tensorflow/python/distribute/combinations.py +++ b/tensorflow/python/distribute/combinations.py @@ -22,6 +22,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import re import sys import types import unittest @@ -94,10 +95,10 @@ class NamedGPUCombination(combinations_lib.TestCombination): Attributes: GPU_TEST: The environment is considered to have GPU hardware available if - the name of the program contains "test_gpu". + the name of the program contains "test_gpu" or "test_xla_gpu". """ - GPU_TEST = "test_gpu" in sys.argv[0] + GPU_TEST = re.search(r"(test_gpu|test_xla_gpu)$", sys.argv[0]) def should_execute_combination(self, kwargs): distributions = [