TPUEstimator: Log status every 60 seconds to indicate progress.
PiperOrigin-RevId: 238738425
This commit is contained in:
parent
b3991e03ab
commit
2a2c8a9679
@ -253,6 +253,18 @@ def _extract_key_names(tensor_or_dict):
|
||||
return []
|
||||
|
||||
|
||||
class PeriodicLogger(object):
|
||||
|
||||
def __init__(self, seconds):
|
||||
self._log_every_n_seconds = seconds
|
||||
self._last_log_time = 0
|
||||
|
||||
def log(self, msg, *args, **kw):
|
||||
if time.time() - self._last_log_time > self._log_every_n_seconds:
|
||||
self._last_log_time = time.time()
|
||||
logging.info(msg, *args, **kw)
|
||||
|
||||
|
||||
class _SIGNAL(object):
|
||||
"""Signal used to control the thread of infeed/outfeed.
|
||||
|
||||
@ -460,8 +472,6 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
|
||||
self._initial_infeed_sleep_secs = (
|
||||
ctx.config.tpu_config.initial_infeed_sleep_secs)
|
||||
|
||||
self._feed_error = None
|
||||
self._finished = False
|
||||
# When using model parallelism, the TPU is pre-initialized at startup to
|
||||
# fetch mesh information. We skip re-initializing it here to avoid
|
||||
# suspected issues due to the mesh layout changing on the second
|
||||
@ -505,11 +515,13 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
|
||||
|
||||
def _run_outfeed(self, queue_ctx, session):
|
||||
logging.info('Starting outfeed thread controller.')
|
||||
status_logger = PeriodicLogger(seconds=60)
|
||||
with self._rendezvous.catch_errors(source='outfeed', session=session):
|
||||
for count, steps in enumerate(queue_ctx.read_iteration_counts()):
|
||||
for i in xrange(steps):
|
||||
logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i)
|
||||
session.run(self._dequeue_ops)
|
||||
status_logger.log('Outfeed finished for iteration (%d, %d)', count, i)
|
||||
logging.info('Outfeed thread finished, shutting down.')
|
||||
|
||||
def _create_infeed_controller(self, name, target, args):
|
||||
@ -557,8 +569,6 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
|
||||
shutdown_timeout=watchdog_timeout)
|
||||
|
||||
def before_run(self, run_context):
|
||||
self._feed_error = None
|
||||
|
||||
iterations = run_context.session.run(self._iterations_per_loop_var)
|
||||
|
||||
logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
|
||||
@ -569,7 +579,6 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
|
||||
self._outfeed_controller.send_next_batch_signal(iterations)
|
||||
|
||||
def end(self, session):
|
||||
self._finished = True
|
||||
logging.info('Stop infeed thread controller')
|
||||
self._infeed_controller.join()
|
||||
self._rendezvous.record_done('infeed')
|
||||
@ -1490,6 +1499,14 @@ class _ModelFnWrapper(object):
|
||||
and estimator_spec.host_call is not None):
|
||||
host_call.record({'host_call': estimator_spec.host_call})
|
||||
host_call_outfeed_ops = host_call.create_enqueue_op()
|
||||
else:
|
||||
# Create a dummy outfeed for the loss to track execution progress
|
||||
host_call.record({
|
||||
'host_call': (lambda loss_t: loss_t,
|
||||
[array_ops.reshape(loss, [1])])
|
||||
})
|
||||
host_call_outfeed_ops = host_call.create_enqueue_op()
|
||||
|
||||
with ops.control_dependencies(host_call_outfeed_ops):
|
||||
return array_ops.identity(loss)
|
||||
|
||||
@ -2856,8 +2873,6 @@ class TPUEstimator(estimator_lib.Estimator):
|
||||
tpu_init_ops.extend(embedding_variables_and_ops.load_ops())
|
||||
|
||||
host_ops = host_call.create_tpu_hostcall()
|
||||
if host_ops is None:
|
||||
host_ops = []
|
||||
|
||||
shutdown_hooks = []
|
||||
shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE',
|
||||
|
Loading…
x
Reference in New Issue
Block a user