Improve usability of TPUEstimator.

1) Log how many batches are being enqueued. The old message was confusing.
2) If the input pipeline has a QueueRunner, log a warning (legacy mode) or raise an error (new mode).
3) If the input pipeline has summaries, log a warning (legacy mode) or raise an error (new mode).
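As a quick user-side check, the same graph collections that the new validation inspects can be examined directly after building the input pipeline. A minimal sketch, assuming a TF 1.x graph-mode setup; `my_input_fn` and `params` are placeholders, not part of this change:

    import tensorflow as tf

    with tf.Graph().as_default() as g:
      my_input_fn(params)  # build only the input pipeline (placeholder call)
      if g.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        print('QueueRunners detected; convert this pipeline to tf.data.')
      if g.get_collection(tf.GraphKeys.SUMMARIES):
        print('tf.summary ops detected; remove them from the input_fn.')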

PiperOrigin-RevId: 175073856
Jianwei Xie 2017-11-08 15:19:12 -08:00 committed by TensorFlower Gardener
parent 6488286b26
commit 8bb665ae1c


@@ -535,13 +535,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
           session, self._dequeue_ops)
 
   def before_run(self, run_context):
-    logging.info('Enqueue next batch of data to infeed.')
     iterations = run_context.session.run(self._iterations_per_loop_var)
+    logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
     self._infeed_thd_controller.send_next_batch_signal(iterations)
 
     if self._dequeue_ops is not None:
       # TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop.
-      logging.info('Dequeue next batch of data from outfeed.')
+      logging.info(
+          'Dequeue next (%d) batch(es) of data from outfeed.', iterations)
       self._outfeed_thd_controller.send_next_batch_signal(iterations)
 
   def end(self, session):
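The count in the new log message comes from `_iterations_per_loop_var`, which mirrors the `iterations_per_loop` setting in the TPU run configuration. A minimal sketch of where that value is configured, assuming the TF 1.4-era `tf.contrib.tpu` API; the master address, model_dir, and `my_model_fn` are placeholders:

    import tensorflow as tf

    run_config = tf.contrib.tpu.RunConfig(
        master='grpc://10.0.0.1:8470',   # placeholder TPU master address
        model_dir='/tmp/tpu_model',      # placeholder checkpoint directory
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=100,     # reported as '(100) batch(es)' in the new log
            num_shards=8))

    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=my_model_fn,            # placeholder model_fn
        config=run_config,
        train_batch_size=1024)

Each enqueue signal then covers `iterations_per_loop` batches before control returns to the host.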
@@ -842,6 +844,8 @@ class _InputPipeline(object):
     # structure is recorded.
     enqueue_ops = self._invoke_input_fn_and_record_structure()
 
+    self._validate_input_pipeline()
+
     def dequeue_fn():
       """dequeue_fn is used by TPU to retrieve the tensors."""
       values = self._infeed_queue.generate_dequeue_op()
@@ -920,6 +924,31 @@ class _InputPipeline(object):
     else:
       return enqueue_fn()
 
+  def _validate_input_pipeline(self):
+    # Perform some sanity checks to give the user friendly log messages. We
+    # would prefer to error out with a better message, but if
+    # _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior) we cannot
+    # break existing user code, so we only log a warning in that case.
+    if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
+      err_msg = ('Input pipeline contains one or more QueueRunners. '
+                 'These are not supported via TPUEstimator. You must convert '
+                 'your input pipeline to use `tf.data` instead (see '
+                 'https://www.tensorflow.org/programmers_guide/datasets for '
+                 'instructions).')
+      if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
+        raise RuntimeError(err_msg)
+      else:
+        logging.warn(err_msg)
+    elif ops.get_default_graph().get_collection(ops.GraphKeys.SUMMARIES):
+      # QueueRunners register summary ops by default, so the `elif` restricts
+      # this check to Dataset-based input pipelines only.
+      err_msg = ('Input pipeline contains `tf.summary` operations. '
+                 'These are not currently supported.')
+      if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
+        raise RuntimeError(err_msg)
+      else:
+        logging.warn(err_msg)
+
 
 class _ModelFnWrapper(object):
   """A `model_fn` wrapper.