From 866e01f318188f15c00d77c2efb219a2c50eb96b Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Thu, 14 May 2020 09:55:16 -0700 Subject: [PATCH] [XLA:Python] Cache the backend in xla_client_test. This is in preparation for removing backend caching logic from xla_client. PiperOrigin-RevId: 311551914 Change-Id: Ia791dc911bd7d9890dec111b8da69a9c619f061c --- tensorflow/compiler/xla/python/xla_client_test.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 62b3fae018a..fbdd9921a40 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -2029,8 +2029,11 @@ def TestFactory(xla_backend, cloud_tpu=False): return tests -def InstantiateTests(globals_dict, backend, test_prefix="", **kw): - for klass in TestFactory(backend, **kw): +def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw): + # Avoid creating a new backend per test (this causes GPU OOM, and is probably + # inefficient). + backend_fn = functools.lru_cache(maxsize=None)(backend_fn) + for klass in TestFactory(backend_fn, **kw): test = type(test_prefix + klass.__name__, (klass,), {}) # Clean up the qualified names of the tests to not include the test factory. test.__qualname__ = test.__name__