diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 62b3fae018a..fbdd9921a40 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -2029,8 +2029,11 @@ def TestFactory(xla_backend, cloud_tpu=False): return tests -def InstantiateTests(globals_dict, backend, test_prefix="", **kw): - for klass in TestFactory(backend, **kw): +def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw): + # Avoid creating a new backend per test (this causes GPU OOM, and is probably + # inefficient). + backend_fn = functools.lru_cache(maxsize=None)(backend_fn) + for klass in TestFactory(backend_fn, **kw): test = type(test_prefix + klass.__name__, (klass,), {}) # Clean up the qualified names of the tests to not include the test factory. test.__qualname__ = test.__name__