From 8bb665ae1c8f2aedd479b5bfe2403ac54e37319e Mon Sep 17 00:00:00 2001
From: Jianwei Xie
Date: Wed, 8 Nov 2017 15:19:12 -0800
Subject: [PATCH] Improve usability of TPUEstimator.

1) Log how many batches are being enqueued; the old message was confusing.
2) If the input pipeline has a QueueRunner, log a warning (legacy mode) or
   error out (new mode).
3) If the input pipeline has summaries, log a warning (legacy mode) or
   error out (new mode).

PiperOrigin-RevId: 175073856
---
 .../contrib/tpu/python/tpu/tpu_estimator.py | 35 +++++++++++++++++--
 1 file changed, 32 insertions(+), 3 deletions(-)

diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 5a3b8314291..16d712af9e2 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -535,13 +535,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
         session, self._dequeue_ops)
 
   def before_run(self, run_context):
-    logging.info('Enqueue next batch of data to infeed.')
-
     iterations = run_context.session.run(self._iterations_per_loop_var)
+
+    logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
+
     self._infeed_thd_controller.send_next_batch_signal(iterations)
     if self._dequeue_ops is not None:
       # TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop.
-      logging.info('Dequeue next batch of data from outfeed.')
+      logging.info(
+          'Dequeue next (%d) batch(es) of data from outfeed.', iterations)
       self._outfeed_thd_controller.send_next_batch_signal(iterations)
 
   def end(self, session):
@@ -842,6 +844,8 @@ class _InputPipeline(object):
     # structure is recorded.
     enqueue_ops = self._invoke_input_fn_and_record_structure()
 
+    self._validate_input_pipeline()
+
     def dequeue_fn():
       """dequeue_fn is used by TPU to retrieve the tensors."""
       values = self._infeed_queue.generate_dequeue_op()
@@ -920,6 +924,31 @@ class _InputPipeline(object):
     else:
       return enqueue_fn()
 
+  def _validate_input_pipeline(self):
+    # Perform some sanity checks and log a user-friendly message. Ideally we
+    # would error out to give users a clearer error, but when
+    # _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior) we cannot
+    # break user code, so we only log a warning.
+    if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
+      err_msg = ('Input pipeline contains one or more QueueRunners. '
+                 'These are not supported via TPUEstimator. You must convert '
+                 'your input pipeline to use `tf.data` instead (see '
+                 'https://www.tensorflow.org/programmers_guide/datasets for '
+                 'instructions).')
+      if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
+        raise RuntimeError(err_msg)
+      else:
+        logging.warn(err_msg)
+    elif ops.get_default_graph().get_collection(ops.GraphKeys.SUMMARIES):
+      # QueueRunners add summary ops by default, so the elif above runs this
+      # check only for Dataset-based input pipelines.
+      err_msg = ('Input pipeline contains `tf.summary` operations. '
+                 'These are not currently supported.')
+      if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
+        raise RuntimeError(err_msg)
+      else:
+        logging.warn(err_msg)
+
+
 class _ModelFnWrapper(object):
   """A `model_fn` wrapper.
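
Note (not part of the patch itself): the QueueRunner error message above
points users at `tf.data`. A minimal sketch of such an input_fn for
TPUEstimator, under the TF 1.4-era API, might look like the following; the
in-memory data, shapes, and feature layout are purely illustrative, and the
per-shard batch size is read from `params` as TPUEstimator supplies it.

    import numpy as np
    import tensorflow as tf

    def input_fn(params):
      """A tf.data-based input_fn; creates no QueueRunners or summary ops."""
      # TPUEstimator passes the per-shard batch size through `params`.
      batch_size = params['batch_size']
      # Toy in-memory data; a real pipeline would typically read TFRecord files.
      features = np.random.rand(1024, 32).astype(np.float32)
      labels = np.random.randint(0, 2, size=(1024,)).astype(np.int32)
      dataset = tf.data.Dataset.from_tensor_slices((features, labels))
      # repeat() before batch() keeps every batch full, which matches the
      # TPU's fixed-shape requirement.
      dataset = dataset.shuffle(1024).repeat().batch(batch_size)
      return dataset.make_one_shot_iterator().get_next()

Because the pipeline is built entirely from `tf.data` ops, the new
_validate_input_pipeline check finds neither QUEUE_RUNNERS nor SUMMARIES in
the graph collections, so it neither warns nor raises.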