TPUEstimator: support host_call when use_tpu=False.

PiperOrigin-RevId: 184021299
This commit is contained in:
Jonathan Hseu 2018-01-31 11:30:36 -08:00 committed by Michael Case
parent adc9ee7150
commit d418a14176

View File

@ -410,14 +410,13 @@ class TPUEstimatorSpec(
function should not capture any Tensors in `model_fn`. function should not capture any Tensors in `model_fn`.
`host_call` is a tuple of a `function` and a list or dictionary of `tensors` `host_call` is a tuple of a `function` and a list or dictionary of `tensors`
to pass to that function. `host_call` currently works for train() and to pass to that function and returns a list of Tensors. `host_call` currently
evaluate(). The function's graph is executed on the CPU on every step, so works for train() and evaluate(). The Tensors returned by the function are
there is communication overhead when sending tensors from TPU to CPU. To executed on the CPU on every step, so there is communication overhead when
reduce the overhead, try reducing the size of the tensors. The `tensors` are sending tensors from TPU to CPU. To reduce the overhead, try reducing the
concatenated along their major (batch) dimension, and so must be >= rank 1. size of the tensors. The `tensors` are concatenated along their major (batch)
The `host_call` is useful for writing summaries with dimension, and so must be >= rank 1. The `host_call` is useful for writing
@{tf.contrib.summary.create_file_writer}. Note that `host_call` does not summaries with @{tf.contrib.summary.create_file_writer}.
currently work if `use_tpu` is set to False.
""" """
def __new__(cls, def __new__(cls,
@ -449,10 +448,18 @@ class TPUEstimatorSpec(
def as_estimator_spec(self): def as_estimator_spec(self):
"""Creates an equivalent `EstimatorSpec` used by CPU train/eval.""" """Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
host_calls = {}
if self.eval_metrics is not None:
host_calls['eval_metrics'] = self.eval_metrics
if self.host_call is not None:
host_calls['host_call'] = self.host_call
host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls)
eval_metric_ops = None eval_metric_ops = None
if self.eval_metrics is not None: if self.eval_metrics is not None:
eval_metric_ops = _OutfeedHostCall.create_cpu_hostcall( eval_metric_ops = host_call_ret['eval_metrics']
{'eval_metrics': self.eval_metrics})['eval_metrics'] hooks = None
if self.host_call is not None:
hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])]
scaffold = self.scaffold_fn() if self.scaffold_fn else None scaffold = self.scaffold_fn() if self.scaffold_fn else None
return model_fn_lib.EstimatorSpec( return model_fn_lib.EstimatorSpec(
mode=self.mode, mode=self.mode,
@ -461,7 +468,10 @@ class TPUEstimatorSpec(
train_op=self.train_op, train_op=self.train_op,
eval_metric_ops=eval_metric_ops, eval_metric_ops=eval_metric_ops,
export_outputs=self.export_outputs, export_outputs=self.export_outputs,
scaffold=scaffold) scaffold=scaffold,
training_hooks=hooks,
evaluation_hooks=hooks,
prediction_hooks=hooks)
class _OpQueueContext(object): class _OpQueueContext(object):
@ -1450,6 +1460,34 @@ class _OutfeedHostCall(object):
return ret return ret
class _OutfeedHostCallHook(session_run_hook.SessionRunHook):
  """Session hook that evaluates the host-call tensors when use_tpu=False."""

  def __init__(self, tensors):
    """Remembers the tensors to request on every `session.run` call.

    Args:
      tensors: the fetches handed to `SessionRunArgs` in `before_run`.
    """
    self._tensors = tensors

  def begin(self):
    """Builds the summary-writer initializer ops and their flush ops."""
    # This logic is copied from TPUInfeedOutfeedSessionHook instead of being
    # split out into its own hook so that execution order is guaranteed:
    # summaries must be initialized before the outfeed thread starts.
    # TODO(jhseu): Make a wrapper hook instead?
    initializers = contrib_summary.summary_writer_initializer_op()
    self._init_ops = initializers
    # The first input of each initializer op is used as the writer to flush,
    # which gives one flush op per initializer.
    self._finalize_ops = [
        contrib_summary.flush(writer=init_op.inputs[0])
        for init_op in initializers
    ]

  def after_create_session(self, session, coord):
    """Runs the initializer ops collected in `begin`."""
    session.run(self._init_ops)

  def before_run(self, run_context):
    """Asks the session to also fetch the stored host-call tensors."""
    return basic_session_run_hooks.SessionRunArgs(self._tensors)

  def end(self, session):
    """Runs the flush ops collected in `begin`."""
    session.run(self._finalize_ops)
class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook): class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook):
"""Count examples during runtime.""" """Count examples during runtime."""