diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 3c095469927..6e51b84a1d1 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import atexit
 import collections
 import contextlib
 import copy
@@ -327,6 +328,11 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
 
     self._logical_device_stack = [0]
 
+    if context.executing_eagerly():
+      # In async remote eager, we want to sync the executors before exiting the
+      # program.
+      atexit.register(context.async_wait)
+
     # TODO(bfontain): Remove once a proper dataset API exists for prefetching
     # a dataset to multiple devices exists.
     # If value is true, this forces prefetch of data to the host's memeory rather
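
The change above relies on Python's standard atexit hook to flush the async remote eager executors (via context.async_wait) when the interpreter shuts down. The standalone sketch below illustrates that shutdown-hook pattern without touching TensorFlow; FakeAsyncExecutor, enqueue, and wait are hypothetical stand-ins for the real executors and context.async_wait, not part of the TensorFlow API.

import atexit


class FakeAsyncExecutor(object):
  """Toy executor that queues work and only runs it when waited on."""

  def __init__(self):
    self._pending = []

  def enqueue(self, fn):
    # In async eager execution, ops return immediately and run later; here we
    # just record the callback instead of running it.
    self._pending.append(fn)

  def wait(self):
    # Drain everything still queued, analogous to context.async_wait().
    while self._pending:
      self._pending.pop(0)()


executor = FakeAsyncExecutor()

# Mirror the diff: register the sync call so pending work is flushed at
# interpreter shutdown even if user code never waits explicitly.
atexit.register(executor.wait)

if __name__ == "__main__":
  for i in range(3):
    executor.enqueue(lambda i=i: print("flushed pending op", i))
  # No explicit executor.wait() here; the atexit hook runs it on exit.

Registering the hook only when context.executing_eagerly() is true, as the diff does, keeps graph-mode programs unaffected, since there are no outstanding async eager operations to wait on in that case.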