TPUEstimator: support host_call when use_tpu=False.

PiperOrigin-RevId: 184021299
Author: Jonathan Hseu (2018-01-31 11:30:36 -08:00), committed by Michael Case
parent adc9ee7150
commit d418a14176


@@ -410,14 +410,13 @@ class TPUEstimatorSpec(
   function should not capture any Tensors in `model_fn`.
 
   `host_call` is a tuple of a `function` and a list or dictionary of `tensors`
-  to pass to that function. `host_call` currently works for train() and
-  evaluate(). The function's graph is executed on the CPU on every step, so
-  there is communication overhead when sending tensors from TPU to CPU. To
-  reduce the overhead, try reducing the size of the tensors. The `tensors` are
-  concatenated along their major (batch) dimension, and so must be >= rank 1.
-  The `host_call` is useful for writing summaries with
-  @{tf.contrib.summary.create_file_writer}. Note that `host_call` does not
-  currently work if `use_tpu` is set to False.
+  to pass to that function; the function returns a list of Tensors. `host_call`
+  currently works for train() and evaluate(). The Tensors returned by the
+  function are executed on the CPU on every step, so there is communication
+  overhead when sending tensors from TPU to CPU. To reduce the overhead, try
+  reducing the size of the tensors. The `tensors` are concatenated along their
+  major (batch) dimension, and so must be >= rank 1. The `host_call` is useful
+  for writing summaries with @{tf.contrib.summary.create_file_writer}.
   """
 
   def __new__(cls,
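
For context, the pattern the revised docstring describes looks roughly like the sketch below (illustrative only: the writer path, `host_call_fn`, and the surrounding `model_fn` wiring are not part of this change; `summary` stands for `tf.contrib.summary`). Scalars must be reshaped to rank 1 so they can be concatenated along the batch dimension:

    def host_call_fn(gs, loss):
      # Each tensor arrives with a leading step/batch dimension, hence the
      # rank >= 1 requirement; take the first element to recover the scalar.
      gs = gs[0]
      with summary.create_file_writer('/tmp/summaries').as_default():
        with summary.always_record_summaries():
          summary.scalar('loss', loss[0], step=gs)
          return summary.all_summary_ops()

    # Inside model_fn: reshape scalars to rank 1 before handing them off.
    gs_t = tf.reshape(tf.train.get_or_create_global_step(), [1])
    loss_t = tf.reshape(loss, [1])
    spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, train_op=train_op,
        host_call=(host_call_fn, [gs_t, loss_t]))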
@@ -449,10 +448,18 @@
 
   def as_estimator_spec(self):
     """Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
+    host_calls = {}
+    if self.eval_metrics is not None:
+      host_calls['eval_metrics'] = self.eval_metrics
+    if self.host_call is not None:
+      host_calls['host_call'] = self.host_call
+    host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls)
     eval_metric_ops = None
     if self.eval_metrics is not None:
-      eval_metric_ops = _OutfeedHostCall.create_cpu_hostcall(
-          {'eval_metrics': self.eval_metrics})['eval_metrics']
+      eval_metric_ops = host_call_ret['eval_metrics']
+    hooks = None
+    if self.host_call is not None:
+      hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])]
     scaffold = self.scaffold_fn() if self.scaffold_fn else None
     return model_fn_lib.EstimatorSpec(
         mode=self.mode,
@@ -461,7 +468,10 @@
         train_op=self.train_op,
         eval_metric_ops=eval_metric_ops,
         export_outputs=self.export_outputs,
-        scaffold=scaffold)
+        scaffold=scaffold,
+        training_hooks=hooks,
+        evaluation_hooks=hooks,
+        prediction_hooks=hooks)
 
 
 class _OpQueueContext(object):
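
Net effect of the two hunks above: on CPU, the same `host_call` tuple is expanded by `create_cpu_hostcall` and attached to the returned `EstimatorSpec` as training/evaluation/prediction hooks. So a setup along these lines (a sketch; `my_model_fn`, `my_input_fn`, and the paths are hypothetical) now runs the host call even with `use_tpu=False`:

    run_config = tf.contrib.tpu.RunConfig(model_dir='/tmp/model')
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=my_model_fn,   # returns the TPUEstimatorSpec sketched above
        config=run_config,
        use_tpu=False,          # host_call now still fires, via the new hook
        train_batch_size=32)
    estimator.train(input_fn=my_input_fn, max_steps=1000)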
@@ -1450,6 +1460,34 @@ class _OutfeedHostCall(object):
     return ret
 
 
+class _OutfeedHostCallHook(session_run_hook.SessionRunHook):
+  """Hook to run host calls when use_tpu=False."""
+
+  def __init__(self, tensors):
+    self._tensors = tensors
+
+  def begin(self):
+    # We duplicate this code from the TPUInfeedOutfeedSessionHook rather than
+    # create a separate hook to guarantee execution order, because summaries
+    # need to be initialized before the outfeed thread starts.
+    # TODO(jhseu): Make a wrapper hook instead?
+    self._init_ops = contrib_summary.summary_writer_initializer_op()
+    # Get all the writer resources from the initializer, so we know what to
+    # flush.
+    self._finalize_ops = []
+    for op in self._init_ops:
+      self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0]))
+
+  def after_create_session(self, session, coord):
+    session.run(self._init_ops)
+
+  def before_run(self, run_context):
+    return basic_session_run_hooks.SessionRunArgs(self._tensors)
+
+  def end(self, session):
+    session.run(self._finalize_ops)
+
+
 class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook):
   """Count examples during runtime."""