Clear kernel cache when resetting TPU system.

PiperOrigin-RevId: 350838785
Change-Id: I3b49d1e52c1e1171a300a9e218c55da538f64bc1
This commit is contained in:
Xiao Yu 2021-01-08 14:36:39 -08:00 committed by TensorFlower Gardener
parent 5f5041fc9c
commit 52a2da52ad
2 changed files with 13 additions and 0 deletions

View File

@ -661,6 +661,17 @@ class Context(object):
else:
raise ValueError("Context is not initialized.")
def clear_kernel_cache(self):
  """Clears the kernel cache and resets all stateful kernels.

  Passes this context's handle to `pywrap_tfe.TFE_ContextClearCaches`.

  Raises:
    ValueError: If the context is not initialized, i.e. its handle is
      still `None`.
  """
  handle = self._context_handle
  # Guard first: there is nothing to clear without a context handle.
  if handle is None:
    raise ValueError("Context is not initialized.")
  pywrap_tfe.TFE_ContextClearCaches(handle)
def enable_collective_ops(self, server_def):
"""Enable distributed collective ops with an appropriate server_def.

View File

@ -108,6 +108,7 @@ def initialize_tpu_system(cluster_resolver=None):
# Clear out the eager context caches since the memory is invalid now.
logging.info("Clearing out eager caches")
context.context()._clear_caches() # pylint: disable=protected-access
context.context().clear_kernel_cache()
serialized_topology = output.numpy()
elif not ops.executing_eagerly_outside_functions():
@ -196,6 +197,7 @@ def shutdown_tpu_system(cluster_resolver=None):
# Clear out the eager context caches since the memory is invalid now.
logging.info("Clearing out eager caches")
context.context()._clear_caches() # pylint: disable=protected-access
context.context().clear_kernel_cache()
elif not ops.executing_eagerly_outside_functions():
master = cluster_resolver.master()
cluster_spec = cluster_resolver.cluster_spec()