Improve usability of TPUEstimator.
1) Log how many batches to enqueue. The old message is very confusing. 2) If input_pipeline has queue runner, generate a logging (legacy mode) or error out (new mode) 3) If input pipeline has summaries, generate a logging (legacy mode) or error out (new mode) PiperOrigin-RevId: 175073856
This commit is contained in:
parent
6488286b26
commit
8bb665ae1c
@ -535,13 +535,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
|
||||
session, self._dequeue_ops)
|
||||
|
||||
def before_run(self, run_context):
|
||||
logging.info('Enqueue next batch of data to infeed.')
|
||||
|
||||
iterations = run_context.session.run(self._iterations_per_loop_var)
|
||||
|
||||
logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
|
||||
|
||||
self._infeed_thd_controller.send_next_batch_signal(iterations)
|
||||
if self._dequeue_ops is not None:
|
||||
# TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop.
|
||||
logging.info('Dequeue next batch of data from outfeed.')
|
||||
logging.info(
|
||||
'Dequeue next (%d) batch(es) of data from outfeed.', iterations)
|
||||
self._outfeed_thd_controller.send_next_batch_signal(iterations)
|
||||
|
||||
def end(self, session):
|
||||
@ -842,6 +844,8 @@ class _InputPipeline(object):
|
||||
# structure is recorded.
|
||||
enqueue_ops = self._invoke_input_fn_and_record_structure()
|
||||
|
||||
self._validate_input_pipeline()
|
||||
|
||||
def dequeue_fn():
|
||||
"""dequeue_fn is used by TPU to retrieve the tensors."""
|
||||
values = self._infeed_queue.generate_dequeue_op()
|
||||
@ -920,6 +924,31 @@ class _InputPipeline(object):
|
||||
else:
|
||||
return enqueue_fn()
|
||||
|
||||
def _validate_input_pipeline(self):
|
||||
# Perform some sanity checks to log user friendly information. We should
|
||||
# error out to give users better error message. But, if
|
||||
# _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break
|
||||
# user code, so, log a warning.
|
||||
if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
|
||||
err_msg = ('Input pipeline contains one or more QueueRunners. '
|
||||
'These are not supported via TPUEstimator. You must convert '
|
||||
'your input pipeline to use `tf.data` instead (see '
|
||||
'https://www.tensorflow.org/programmers_guide/datasets for '
|
||||
'instructions.')
|
||||
if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
|
||||
raise RuntimeError(err_msg)
|
||||
else:
|
||||
logging.warn(err_msg)
|
||||
elif ops.get_default_graph().get_collection(ops.GraphKeys.SUMMARIES):
|
||||
# Queue Runner has summary Ops by default. So here we use elif to do
|
||||
# necessary checks for Dataset input pipeline only.
|
||||
err_msg = ('Input pipeline contains `tf.summary` operations. '
|
||||
'These are not currently supported.')
|
||||
if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
|
||||
raise RuntimeError(err_msg)
|
||||
else:
|
||||
logging.warn(err_msg)
|
||||
|
||||
|
||||
class _ModelFnWrapper(object):
|
||||
"""A `model_fn` wrapper.
|
||||
|
Loading…
Reference in New Issue
Block a user