Add the run function from the revised Distribution Strategy proposal.
PiperOrigin-RevId: 225028975
This commit is contained in:
parent
5741f4b940
commit
90a840fbcb
@ -422,6 +422,42 @@ class DistributionStrategy(object):
|
||||
return self.extended._make_input_fn_iterator( # pylint: disable=protected-access
|
||||
input_fn, replication_mode=replication_mode)
|
||||
|
||||
def experimental_run(self, fn, input_iterator=None):
|
||||
"""Runs ops in `fn` on each replica, with inputs from `input_iterator`.
|
||||
|
||||
When eager execution is enabled, executes ops specified by `fn` on each
|
||||
replica. Otherwise, builds a graph to execute the ops on each replica.
|
||||
|
||||
Each replica will take a single, different input from the inputs provided by
|
||||
one `get_next` call on the input iterator.
|
||||
|
||||
`fn` may call `tf.distribute.get_replica_context()` to access members such
|
||||
as `replica_id_in_sync_group`.
|
||||
|
||||
IMPORTANT: Depending on the `DistributionStrategy` being used, and whether
|
||||
eager execution is enabled, `fn` may be called one or more times (once for
|
||||
each replica).
|
||||
|
||||
Args:
|
||||
fn: function to run. The inputs to the function must match the outputs of
|
||||
`input_iterator.get_next()`. The output must be a `tf.nest` of
|
||||
`Tensor`s.
|
||||
input_iterator: (Optional) input iterator from which the inputs are taken.
|
||||
|
||||
Returns:
|
||||
Merged return value of `fn` across replicas. The structure of the return
|
||||
value is the same as the return value from `fn`. Each element in the
|
||||
structure can either be `PerReplica` (if the values are unsynchronized),
|
||||
`Mirrored` (if the values are kept in sync), or `Tensor` (if running on a
|
||||
single replica).
|
||||
"""
|
||||
with self.scope():
|
||||
if input_iterator is None:
|
||||
return self._extended.call_for_each_replica(fn)
|
||||
else:
|
||||
inputs = input_iterator.get_next()
|
||||
return self._extended.call_for_each_replica(fn, args=(inputs,))
|
||||
|
||||
@doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
|
||||
def broadcast(self, tensor, destinations=None):
|
||||
"""DEPRECATED: use extended.broadcast_to() instead."""
|
||||
|
@ -75,6 +75,10 @@ tf_class {
|
||||
name: "experimental_initialize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "finalize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -74,6 +74,10 @@ tf_class {
|
||||
name: "experimental_initialize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "finalize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -75,6 +75,10 @@ tf_class {
|
||||
name: "experimental_initialize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "finalize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -74,6 +74,10 @@ tf_class {
|
||||
name: "experimental_initialize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "finalize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
|
Loading…
x
Reference in New Issue
Block a user