TPUEstimator: Log status every 60 seconds to indicate progress.

PiperOrigin-RevId: 238738425
This commit is contained in:
Russell Power 2019-03-15 17:12:23 -07:00 committed by TensorFlower Gardener
parent b3991e03ab
commit 2a2c8a9679

View File

@ -253,6 +253,18 @@ def _extract_key_names(tensor_or_dict):
return []
class PeriodicLogger(object):
def __init__(self, seconds):
self._log_every_n_seconds = seconds
self._last_log_time = 0
def log(self, msg, *args, **kw):
if time.time() - self._last_log_time > self._log_every_n_seconds:
self._last_log_time = time.time()
logging.info(msg, *args, **kw)
class _SIGNAL(object):
"""Signal used to control the thread of infeed/outfeed.
@ -460,8 +472,6 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
self._initial_infeed_sleep_secs = (
ctx.config.tpu_config.initial_infeed_sleep_secs)
self._feed_error = None
self._finished = False
# When using model parallelism, the TPU is pre-initialized at startup to
# fetch mesh information. We skip re-initializing it here to avoid
# suspected issues due to the mesh layout changing on the second
@ -505,11 +515,13 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
def _run_outfeed(self, queue_ctx, session):
logging.info('Starting outfeed thread controller.')
status_logger = PeriodicLogger(seconds=60)
with self._rendezvous.catch_errors(source='outfeed', session=session):
for count, steps in enumerate(queue_ctx.read_iteration_counts()):
for i in xrange(steps):
logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i)
session.run(self._dequeue_ops)
status_logger.log('Outfeed finished for iteration (%d, %d)', count, i)
logging.info('Outfeed thread finished, shutting down.')
def _create_infeed_controller(self, name, target, args):
@ -557,8 +569,6 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
shutdown_timeout=watchdog_timeout)
def before_run(self, run_context):
self._feed_error = None
iterations = run_context.session.run(self._iterations_per_loop_var)
logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
@ -569,7 +579,6 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
self._outfeed_controller.send_next_batch_signal(iterations)
def end(self, session):
self._finished = True
logging.info('Stop infeed thread controller')
self._infeed_controller.join()
self._rendezvous.record_done('infeed')
@ -1490,6 +1499,14 @@ class _ModelFnWrapper(object):
and estimator_spec.host_call is not None):
host_call.record({'host_call': estimator_spec.host_call})
host_call_outfeed_ops = host_call.create_enqueue_op()
else:
# Create a dummy outfeed for the loss to track execution progress
host_call.record({
'host_call': (lambda loss_t: loss_t,
[array_ops.reshape(loss, [1])])
})
host_call_outfeed_ops = host_call.create_enqueue_op()
with ops.control_dependencies(host_call_outfeed_ops):
return array_ops.identity(loss)
@ -2856,8 +2873,6 @@ class TPUEstimator(estimator_lib.Estimator):
tpu_init_ops.extend(embedding_variables_and_ops.load_ops())
host_ops = host_call.create_tpu_hostcall()
if host_ops is None:
host_ops = []
shutdown_hooks = []
shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE',