[XLA:Python] Cache the backend in xla_client_test.

This is in preparation for removing backend caching logic from xla_client.

PiperOrigin-RevId: 311551914
Change-Id: Ia791dc911bd7d9890dec111b8da69a9c619f061c
This commit is contained in:
Skye Wanderman-Milne 2020-05-14 09:55:16 -07:00 committed by TensorFlower Gardener
parent 5d3c548620
commit 866e01f318
1 changed files with 5 additions and 2 deletions

View File

@ -2029,8 +2029,11 @@ def TestFactory(xla_backend, cloud_tpu=False):
return tests return tests
def InstantiateTests(globals_dict, backend, test_prefix="", **kw): def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw):
for klass in TestFactory(backend, **kw): # Avoid creating a new backend per test (this causes GPU OOM, and is probably
# inefficient).
backend_fn = functools.lru_cache(maxsize=None)(backend_fn)
for klass in TestFactory(backend_fn, **kw):
test = type(test_prefix + klass.__name__, (klass,), {}) test = type(test_prefix + klass.__name__, (klass,), {})
# Clean up the qualified names of the tests to not include the test factory. # Clean up the qualified names of the tests to not include the test factory.
test.__qualname__ = test.__name__ test.__qualname__ = test.__name__