TPUEstimator: support host_call when use_tpu=False.
PiperOrigin-RevId: 184021299
This commit is contained in:
parent
adc9ee7150
commit
d418a14176
@ -410,14 +410,13 @@ class TPUEstimatorSpec(
|
||||
function should not capture any Tensors in `model_fn`.
|
||||
|
||||
`host_call` is a tuple of a `function` and a list or dictionary of `tensors`
|
||||
to pass to that function. `host_call` currently works for train() and
|
||||
evaluate(). The function's graph is executed on the CPU on every step, so
|
||||
there is communication overhead when sending tensors from TPU to CPU. To
|
||||
reduce the overhead, try reducing the size of the tensors. The `tensors` are
|
||||
concatenated along their major (batch) dimension, and so must be >= rank 1.
|
||||
The `host_call` is useful for writing summaries with
|
||||
@{tf.contrib.summary.create_file_writer}. Note that `host_call` does not
|
||||
currently work if `use_tpu` is set to False.
|
||||
to pass to that function and returns a list of Tensors. `host_call` currently
|
||||
works for train() and evaluate(). The Tensors returned by the function is
|
||||
executed on the CPU on every step, so there is communication overhead when
|
||||
sending tensors from TPU to CPU. To reduce the overhead, try reducing the
|
||||
size of the tensors. The `tensors` are concatenated along their major (batch)
|
||||
dimension, and so must be >= rank 1. The `host_call` is useful for writing
|
||||
summaries with @{tf.contrib.summary.create_file_writer}.
|
||||
"""
|
||||
|
||||
def __new__(cls,
|
||||
@ -449,10 +448,18 @@ class TPUEstimatorSpec(
|
||||
|
||||
def as_estimator_spec(self):
|
||||
"""Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
|
||||
host_calls = {}
|
||||
if self.eval_metrics is not None:
|
||||
host_calls['eval_metrics'] = self.eval_metrics
|
||||
if self.host_call is not None:
|
||||
host_calls['host_call'] = self.host_call
|
||||
host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls)
|
||||
eval_metric_ops = None
|
||||
if self.eval_metrics is not None:
|
||||
eval_metric_ops = _OutfeedHostCall.create_cpu_hostcall(
|
||||
{'eval_metrics': self.eval_metrics})['eval_metrics']
|
||||
eval_metric_ops = host_call_ret['eval_metrics']
|
||||
hooks = None
|
||||
if self.host_call is not None:
|
||||
hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])]
|
||||
scaffold = self.scaffold_fn() if self.scaffold_fn else None
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode=self.mode,
|
||||
@ -461,7 +468,10 @@ class TPUEstimatorSpec(
|
||||
train_op=self.train_op,
|
||||
eval_metric_ops=eval_metric_ops,
|
||||
export_outputs=self.export_outputs,
|
||||
scaffold=scaffold)
|
||||
scaffold=scaffold,
|
||||
training_hooks=hooks,
|
||||
evaluation_hooks=hooks,
|
||||
prediction_hooks=hooks)
|
||||
|
||||
|
||||
class _OpQueueContext(object):
|
||||
@ -1450,6 +1460,34 @@ class _OutfeedHostCall(object):
|
||||
return ret
|
||||
|
||||
|
||||
class _OutfeedHostCallHook(session_run_hook.SessionRunHook):
|
||||
"""Hook to run host calls when use_tpu=False."""
|
||||
|
||||
def __init__(self, tensors):
|
||||
self._tensors = tensors
|
||||
|
||||
def begin(self):
|
||||
# We duplicate this code from the TPUInfeedOutfeedSessionHook rather than
|
||||
# create a separate hook to guarantee execution order, because summaries
|
||||
# need to be initialized before the outfeed thread starts.
|
||||
# TODO(jhseu): Make a wrapper hook instead?
|
||||
self._init_ops = contrib_summary.summary_writer_initializer_op()
|
||||
# Get all the writer resources from the initializer, so we know what to
|
||||
# flush.
|
||||
self._finalize_ops = []
|
||||
for op in self._init_ops:
|
||||
self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0]))
|
||||
|
||||
def after_create_session(self, session, coord):
|
||||
session.run(self._init_ops)
|
||||
|
||||
def before_run(self, run_context):
|
||||
return basic_session_run_hooks.SessionRunArgs(self._tensors)
|
||||
|
||||
def end(self, session):
|
||||
session.run(self._finalize_ops)
|
||||
|
||||
|
||||
class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook):
|
||||
"""Count examples during runtime."""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user