[XLA:Python] Cache the backend in xla_client_test.
This is in preparation for removing backend caching logic from xla_client. PiperOrigin-RevId: 311551914 Change-Id: Ia791dc911bd7d9890dec111b8da69a9c619f061c
This commit is contained in:
parent
5d3c548620
commit
866e01f318
|
@ -2029,8 +2029,11 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
|||
return tests
|
||||
|
||||
|
||||
def InstantiateTests(globals_dict, backend, test_prefix="", **kw):
|
||||
for klass in TestFactory(backend, **kw):
|
||||
def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw):
|
||||
# Avoid creating a new backend per test (this causes GPU OOM, and is probably
|
||||
# inefficient).
|
||||
backend_fn = functools.lru_cache(maxsize=None)(backend_fn)
|
||||
for klass in TestFactory(backend_fn, **kw):
|
||||
test = type(test_prefix + klass.__name__, (klass,), {})
|
||||
# Clean up the qualified names of the tests to not include the test factory.
|
||||
test.__qualname__ = test.__name__
|
||||
|
|
Loading…
Reference in New Issue