diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 23960bb0300..c7008533f3a 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -410,14 +410,14 @@ class TPUEstimatorSpec(
   function should not capture any Tensors in `model_fn`.
 
   `host_call` is a tuple of a `function` and a list or dictionary of `tensors`
-  to pass to that function. `host_call` currently works for train() and
-  evaluate(). The function's graph is executed on the CPU on every step, so
-  there is communication overhead when sending tensors from TPU to CPU. To
-  reduce the overhead, try reducing the size of the tensors. The `tensors` are
-  concatenated along their major (batch) dimension, and so must be >= rank 1.
-  The `host_call` is useful for writing summaries with
-  @{tf.contrib.summary.create_file_writer}. Note that `host_call` does not
-  currently work if `use_tpu` is set to False.
+  to pass to that function; the function must return a list of Tensors.
+  `host_call` currently works for train() and evaluate(). The Tensors
+  returned by the function are evaluated on the CPU on every step, so there
+  is communication overhead when sending tensors from TPU to CPU. To reduce
+  the overhead, try reducing the size of the tensors. The `tensors` are
+  concatenated along their major (batch) dimension, and so must be >= rank 1.
+  The `host_call` is useful for writing summaries with
+  @{tf.contrib.summary.create_file_writer}.
   """
 
   def __new__(cls,
@@ -449,10 +448,18 @@ class TPUEstimatorSpec(
 
   def as_estimator_spec(self):
     """Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
+    host_calls = {}
+    if self.eval_metrics is not None:
+      host_calls['eval_metrics'] = self.eval_metrics
+    if self.host_call is not None:
+      host_calls['host_call'] = self.host_call
+    host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls)
     eval_metric_ops = None
     if self.eval_metrics is not None:
-      eval_metric_ops = _OutfeedHostCall.create_cpu_hostcall(
-          {'eval_metrics': self.eval_metrics})['eval_metrics']
+      eval_metric_ops = host_call_ret['eval_metrics']
+    hooks = None
+    if self.host_call is not None:
+      hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])]
     scaffold = self.scaffold_fn() if self.scaffold_fn else None
     return model_fn_lib.EstimatorSpec(
         mode=self.mode,
@@ -461,7 +468,10 @@ class TPUEstimatorSpec(
         train_op=self.train_op,
         eval_metric_ops=eval_metric_ops,
         export_outputs=self.export_outputs,
-        scaffold=scaffold)
+        scaffold=scaffold,
+        training_hooks=hooks,
+        evaluation_hooks=hooks,
+        prediction_hooks=hooks)
 
 
 class _OpQueueContext(object):
@@ -1450,6 +1460,34 @@ class _OutfeedHostCall(object):
     return ret
 
 
+class _OutfeedHostCallHook(session_run_hook.SessionRunHook):
+  """Hook to run host calls when use_tpu=False."""
+
+  def __init__(self, tensors):
+    self._tensors = tensors
+
+  def begin(self):
+    # We duplicate this code from the TPUInfeedOutfeedSessionHook rather than
+    # create a separate hook to guarantee execution order, because summaries
+    # need to be initialized before the outfeed thread starts.
+    # TODO(jhseu): Make a wrapper hook instead?
+    self._init_ops = contrib_summary.summary_writer_initializer_op()
+    # Get all the writer resources from the initializer, so we know what to
+    # flush.
+    self._finalize_ops = []
+    for op in self._init_ops:
+      self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0]))
+
+  def after_create_session(self, session, coord):
+    session.run(self._init_ops)
+
+  def before_run(self, run_context):
+    return basic_session_run_hooks.SessionRunArgs(self._tensors)
+
+  def end(self, session):
+    session.run(self._finalize_ops)
+
+
 class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook):
   """Count examples during runtime."""
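
For context, a minimal sketch of the `model_fn` pattern this change enables: a `host_call` that writes summaries and, via the new `_OutfeedHostCallHook`, now also runs when `use_tpu=False`. The toy model, summary directory, and names like `my_model_fn` are illustrative assumptions, not part of this patch:

```python
# Illustrative sketch only -- the toy model, path, and names are assumptions.
import tensorflow as tf
from tensorflow.contrib import summary as contrib_summary
from tensorflow.contrib.tpu.python.tpu import tpu_estimator


def my_model_fn(features, labels, mode, params):
  """Hypothetical model_fn demonstrating host_call."""
  logits = tf.layers.dense(features, 1)
  loss = tf.losses.mean_squared_error(labels, logits)
  global_step = tf.train.get_or_create_global_step()
  train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
      loss, global_step=global_step)

  def host_call_fn(gs, loss_t):
    # Runs on the host CPU every step; after this patch it is also executed
    # on CPU (through _OutfeedHostCallHook) when use_tpu=False.
    with contrib_summary.create_file_writer('/tmp/summaries').as_default():
      with contrib_summary.always_record_summaries():
        contrib_summary.scalar('loss', loss_t[0], step=gs[0])
        return contrib_summary.all_summary_ops()

  # host_call tensors are concatenated along their major (batch) dimension,
  # so each must be at least rank 1 -- hence the reshapes.
  return tpu_estimator.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      host_call=(host_call_fn,
                 [tf.reshape(global_step, [1]), tf.reshape(loss, [1])]))
```

With `use_tpu=False`, `as_estimator_spec()` now routes the host call's tensors through `_OutfeedHostCallHook` instead of dropping them, which is why the docstring no longer carries the "does not work if `use_tpu` is set to False" caveat.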