Add the run function from the revised Distribution Strategy proposal.

PiperOrigin-RevId: 225028975
This commit is contained in:
Peter Buchlovsky 2018-12-11 10:37:05 -08:00 committed by TensorFlower Gardener
parent 5741f4b940
commit 90a840fbcb
5 changed files with 52 additions and 0 deletions

View File

@ -422,6 +422,42 @@ class DistributionStrategy(object):
return self.extended._make_input_fn_iterator( # pylint: disable=protected-access
input_fn, replication_mode=replication_mode)
def experimental_run(self, fn, input_iterator=None):
"""Runs ops in `fn` on each replica, with inputs from `input_iterator`.
When eager execution is enabled, executes ops specified by `fn` on each
replica. Otherwise, builds a graph to execute the ops on each replica.
Each replica will take a single, different input from the inputs provided by
one `get_next` call on the input iterator.
`fn` may call `tf.distribute.get_replica_context()` to access members such
as `replica_id_in_sync_group`.
IMPORTANT: Depending on the `DistributionStrategy` being used, and whether
eager execution is enabled, `fn` may be called one or more times (once for
each replica).
Args:
fn: function to run. The inputs to the function must match the outputs of
`input_iterator.get_next()`. The output must be a `tf.nest` of
`Tensor`s.
input_iterator: (Optional) input iterator from which the inputs are taken.
Returns:
Merged return value of `fn` across replicas. The structure of the return
value is the same as the return value from `fn`. Each element in the
structure can either be `PerReplica` (if the values are unsynchronized),
`Mirrored` (if the values are kept in sync), or `Tensor` (if running on a
single replica).
"""
with self.scope():
if input_iterator is None:
return self._extended.call_for_each_replica(fn)
else:
inputs = input_iterator.get_next()
return self._extended.call_for_each_replica(fn, args=(inputs,))
@doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
def broadcast(self, tensor, destinations=None):
"""DEPRECATED: use extended.broadcast_to() instead."""

View File

@ -75,6 +75,10 @@ tf_class {
name: "experimental_initialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "finalize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -74,6 +74,10 @@ tf_class {
name: "experimental_initialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "finalize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -75,6 +75,10 @@ tf_class {
name: "experimental_initialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "finalize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -74,6 +74,10 @@ tf_class {
name: "experimental_initialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "finalize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"