TPUEstimator: support host_call when use_tpu=False.
PiperOrigin-RevId: 184021299
This commit is contained in:
parent
adc9ee7150
commit
d418a14176
@ -410,14 +410,13 @@ class TPUEstimatorSpec(
|
|||||||
function should not capture any Tensors in `model_fn`.
|
function should not capture any Tensors in `model_fn`.
|
||||||
|
|
||||||
`host_call` is a tuple of a `function` and a list or dictionary of `tensors`
|
`host_call` is a tuple of a `function` and a list or dictionary of `tensors`
|
||||||
to pass to that function. `host_call` currently works for train() and
|
to pass to that function, which returns a list of Tensors. `host_call` currently
|
||||||
evaluate(). The function's graph is executed on the CPU on every step, so
|
works for train() and evaluate(). The Tensors returned by the function are
|
||||||
there is communication overhead when sending tensors from TPU to CPU. To
|
executed on the CPU on every step, so there is communication overhead when
|
||||||
reduce the overhead, try reducing the size of the tensors. The `tensors` are
|
sending tensors from TPU to CPU. To reduce the overhead, try reducing the
|
||||||
concatenated along their major (batch) dimension, and so must be >= rank 1.
|
size of the tensors. The `tensors` are concatenated along their major (batch)
|
||||||
The `host_call` is useful for writing summaries with
|
dimension, and so must be >= rank 1. The `host_call` is useful for writing
|
||||||
@{tf.contrib.summary.create_file_writer}. Note that `host_call` does not
|
summaries with @{tf.contrib.summary.create_file_writer}.
|
||||||
currently work if `use_tpu` is set to False.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls,
|
def __new__(cls,
|
||||||
@ -449,10 +448,18 @@ class TPUEstimatorSpec(
|
|||||||
|
|
||||||
def as_estimator_spec(self):
|
def as_estimator_spec(self):
|
||||||
"""Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
|
"""Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
|
||||||
|
host_calls = {}
|
||||||
|
if self.eval_metrics is not None:
|
||||||
|
host_calls['eval_metrics'] = self.eval_metrics
|
||||||
|
if self.host_call is not None:
|
||||||
|
host_calls['host_call'] = self.host_call
|
||||||
|
host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls)
|
||||||
eval_metric_ops = None
|
eval_metric_ops = None
|
||||||
if self.eval_metrics is not None:
|
if self.eval_metrics is not None:
|
||||||
eval_metric_ops = _OutfeedHostCall.create_cpu_hostcall(
|
eval_metric_ops = host_call_ret['eval_metrics']
|
||||||
{'eval_metrics': self.eval_metrics})['eval_metrics']
|
hooks = None
|
||||||
|
if self.host_call is not None:
|
||||||
|
hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])]
|
||||||
scaffold = self.scaffold_fn() if self.scaffold_fn else None
|
scaffold = self.scaffold_fn() if self.scaffold_fn else None
|
||||||
return model_fn_lib.EstimatorSpec(
|
return model_fn_lib.EstimatorSpec(
|
||||||
mode=self.mode,
|
mode=self.mode,
|
||||||
@ -461,7 +468,10 @@ class TPUEstimatorSpec(
|
|||||||
train_op=self.train_op,
|
train_op=self.train_op,
|
||||||
eval_metric_ops=eval_metric_ops,
|
eval_metric_ops=eval_metric_ops,
|
||||||
export_outputs=self.export_outputs,
|
export_outputs=self.export_outputs,
|
||||||
scaffold=scaffold)
|
scaffold=scaffold,
|
||||||
|
training_hooks=hooks,
|
||||||
|
evaluation_hooks=hooks,
|
||||||
|
prediction_hooks=hooks)
|
||||||
|
|
||||||
|
|
||||||
class _OpQueueContext(object):
|
class _OpQueueContext(object):
|
||||||
@ -1450,6 +1460,34 @@ class _OutfeedHostCall(object):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class _OutfeedHostCallHook(session_run_hook.SessionRunHook):
|
||||||
|
"""Hook to run host calls when use_tpu=False."""
|
||||||
|
|
||||||
|
def __init__(self, tensors):
|
||||||
|
self._tensors = tensors
|
||||||
|
|
||||||
|
def begin(self):
|
||||||
|
# We duplicate this code from the TPUInfeedOutfeedSessionHook rather than
|
||||||
|
# create a separate hook to guarantee execution order, because summaries
|
||||||
|
# need to be initialized before the outfeed thread starts.
|
||||||
|
# TODO(jhseu): Make a wrapper hook instead?
|
||||||
|
self._init_ops = contrib_summary.summary_writer_initializer_op()
|
||||||
|
# Get all the writer resources from the initializer, so we know what to
|
||||||
|
# flush.
|
||||||
|
self._finalize_ops = []
|
||||||
|
for op in self._init_ops:
|
||||||
|
self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0]))
|
||||||
|
|
||||||
|
def after_create_session(self, session, coord):
|
||||||
|
session.run(self._init_ops)
|
||||||
|
|
||||||
|
def before_run(self, run_context):
|
||||||
|
return basic_session_run_hooks.SessionRunArgs(self._tensors)
|
||||||
|
|
||||||
|
def end(self, session):
|
||||||
|
session.run(self._finalize_ops)
|
||||||
|
|
||||||
|
|
||||||
class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook):
|
class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook):
|
||||||
"""Count examples during runtime."""
|
"""Count examples during runtime."""
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user