Add the run function from the revised Distribution Strategy proposal.
PiperOrigin-RevId: 225028975
This commit is contained in:
parent
5741f4b940
commit
90a840fbcb
@ -422,6 +422,42 @@ class DistributionStrategy(object):
|
|||||||
return self.extended._make_input_fn_iterator( # pylint: disable=protected-access
|
return self.extended._make_input_fn_iterator( # pylint: disable=protected-access
|
||||||
input_fn, replication_mode=replication_mode)
|
input_fn, replication_mode=replication_mode)
|
||||||
|
|
||||||
|
def experimental_run(self, fn, input_iterator=None):
|
||||||
|
"""Runs ops in `fn` on each replica, with inputs from `input_iterator`.
|
||||||
|
|
||||||
|
When eager execution is enabled, executes ops specified by `fn` on each
|
||||||
|
replica. Otherwise, builds a graph to execute the ops on each replica.
|
||||||
|
|
||||||
|
Each replica will take a single, different input from the inputs provided by
|
||||||
|
one `get_next` call on the input iterator.
|
||||||
|
|
||||||
|
`fn` may call `tf.distribute.get_replica_context()` to access members such
|
||||||
|
as `replica_id_in_sync_group`.
|
||||||
|
|
||||||
|
IMPORTANT: Depending on the `DistributionStrategy` being used, and whether
|
||||||
|
eager execution is enabled, `fn` may be called one or more times (once for
|
||||||
|
each replica).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: function to run. The inputs to the function must match the outputs of
|
||||||
|
`input_iterator.get_next()`. The output must be a `tf.nest` of
|
||||||
|
`Tensor`s.
|
||||||
|
input_iterator: (Optional) input iterator from which the inputs are taken.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Merged return value of `fn` across replicas. The structure of the return
|
||||||
|
value is the same as the return value from `fn`. Each element in the
|
||||||
|
structure can either be `PerReplica` (if the values are unsynchronized),
|
||||||
|
`Mirrored` (if the values are kept in sync), or `Tensor` (if running on a
|
||||||
|
single replica).
|
||||||
|
"""
|
||||||
|
with self.scope():
|
||||||
|
if input_iterator is None:
|
||||||
|
return self._extended.call_for_each_replica(fn)
|
||||||
|
else:
|
||||||
|
inputs = input_iterator.get_next()
|
||||||
|
return self._extended.call_for_each_replica(fn, args=(inputs,))
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
|
@doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
|
||||||
def broadcast(self, tensor, destinations=None):
|
def broadcast(self, tensor, destinations=None):
|
||||||
"""DEPRECATED: use extended.broadcast_to() instead."""
|
"""DEPRECATED: use extended.broadcast_to() instead."""
|
||||||
|
@ -75,6 +75,10 @@ tf_class {
|
|||||||
name: "experimental_initialize"
|
name: "experimental_initialize"
|
||||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "finalize"
|
name: "finalize"
|
||||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
@ -74,6 +74,10 @@ tf_class {
|
|||||||
name: "experimental_initialize"
|
name: "experimental_initialize"
|
||||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "finalize"
|
name: "finalize"
|
||||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
@ -75,6 +75,10 @@ tf_class {
|
|||||||
name: "experimental_initialize"
|
name: "experimental_initialize"
|
||||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "finalize"
|
name: "finalize"
|
||||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
@ -74,6 +74,10 @@ tf_class {
|
|||||||
name: "experimental_initialize"
|
name: "experimental_initialize"
|
||||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "finalize"
|
name: "finalize"
|
||||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user