[XLA:Python] Cache the backend in xla_client_test.
This is in preparation for removing backend caching logic from xla_client. PiperOrigin-RevId: 311551914 Change-Id: Ia791dc911bd7d9890dec111b8da69a9c619f061c
This commit is contained in:
parent
5d3c548620
commit
866e01f318
|
@ -2029,8 +2029,11 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||||
return tests
|
return tests
|
||||||
|
|
||||||
|
|
||||||
def InstantiateTests(globals_dict, backend, test_prefix="", **kw):
|
def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw):
|
||||||
for klass in TestFactory(backend, **kw):
|
# Avoid creating a new backend per test (this causes GPU OOM, and is probably
|
||||||
|
# inefficient).
|
||||||
|
backend_fn = functools.lru_cache(maxsize=None)(backend_fn)
|
||||||
|
for klass in TestFactory(backend_fn, **kw):
|
||||||
test = type(test_prefix + klass.__name__, (klass,), {})
|
test = type(test_prefix + klass.__name__, (klass,), {})
|
||||||
# Clean up the qualified names of the tests to not include the test factory.
|
# Clean up the qualified names of the tests to not include the test factory.
|
||||||
test.__qualname__ = test.__name__
|
test.__qualname__ = test.__name__
|
||||||
|
|
Loading…
Reference in New Issue