Add DistributionStrategy.experimental_run_v2.
PiperOrigin-RevId: 237071002
This commit is contained in:
parent
20c80123a1
commit
9d703eecbf
@ -445,13 +445,13 @@ class DistributionStrategy(object):
|
|||||||
`fn` may call `tf.distribute.get_replica_context()` to access members such
|
`fn` may call `tf.distribute.get_replica_context()` to access members such
|
||||||
as `replica_id_in_sync_group`.
|
as `replica_id_in_sync_group`.
|
||||||
|
|
||||||
IMPORTANT: Depending on the `DistributionStrategy` being used, and whether
|
IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being
|
||||||
eager execution is enabled, `fn` may be called one or more times (once for
|
used, and whether eager execution is enabled, `fn` may be called one or more
|
||||||
each replica).
|
times (once for each replica).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fn: function to run. The inputs to the function must match the outputs of
|
fn: The function to run. The inputs to the function must match the outputs
|
||||||
`input_iterator.get_next()`. The output must be a `tf.nest` of
|
of `input_iterator.get_next()`. The output must be a `tf.nest` of
|
||||||
`Tensor`s.
|
`Tensor`s.
|
||||||
input_iterator: (Optional) input iterator from which the inputs are taken.
|
input_iterator: (Optional) input iterator from which the inputs are taken.
|
||||||
|
|
||||||
@ -463,11 +463,36 @@ class DistributionStrategy(object):
|
|||||||
single replica).
|
single replica).
|
||||||
"""
|
"""
|
||||||
with self.scope():
|
with self.scope():
|
||||||
if input_iterator is None:
|
args = (input_iterator.get_next(),) if input_iterator is not None else ()
|
||||||
return self._extended.call_for_each_replica(fn)
|
return self.experimental_run_v2(fn, args=args)
|
||||||
else:
|
|
||||||
inputs = input_iterator.get_next()
|
def experimental_run_v2(self, fn, args=(), kwargs=None):
|
||||||
return self._extended.call_for_each_replica(fn, args=(inputs,))
|
"""Runs ops in `fn` on each replica, with the given arguments.
|
||||||
|
|
||||||
|
When eager execution is enabled, executes ops specified by `fn` on each
|
||||||
|
replica. Otherwise, builds a graph to execute the ops on each replica.
|
||||||
|
|
||||||
|
`fn` may call `tf.distribute.get_replica_context()` to access members such
|
||||||
|
as `replica_id_in_sync_group`.
|
||||||
|
|
||||||
|
IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being
|
||||||
|
used, and whether eager execution is enabled, `fn` may be called one or more
|
||||||
|
times (once for each replica).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
|
||||||
|
args: (Optional) Positional arguments to `fn`.
|
||||||
|
kwargs: (Optional) Keyword arguments to `fn`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Merged return value of `fn` across replicas. The structure of the return
|
||||||
|
value is the same as the return value from `fn`. Each element in the
|
||||||
|
structure can either be `PerReplica` (if the values are unsynchronized),
|
||||||
|
`Mirrored` (if the values are kept in sync), or `Tensor` (if running on a
|
||||||
|
single replica).
|
||||||
|
"""
|
||||||
|
with self.scope():
|
||||||
|
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
|
||||||
|
|
||||||
def reduce(self, reduce_op, value):
|
def reduce(self, reduce_op, value):
|
||||||
"""Reduce `value` across replicas.
|
"""Reduce `value` across replicas.
|
||||||
|
|||||||
@ -154,32 +154,28 @@ class TPUStrategy(distribute_lib.DistributionStrategy):
|
|||||||
# TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
|
# TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
|
||||||
# can use the default implementation.
|
# can use the default implementation.
|
||||||
# This implementation runs a single step. It does not use infeed or outfeed.
|
# This implementation runs a single step. It does not use infeed or outfeed.
|
||||||
def experimental_run(self, fn, input_iterator=None):
|
def experimental_run_v2(self, fn, args=(), kwargs=None):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
if context.executing_eagerly() and not ops.inside_function():
|
if context.executing_eagerly() and not ops.inside_function():
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Eager mode not supported in TPUStrategy outside TF functions.")
|
"Eager mode not supported in TPUStrategy outside TF functions.")
|
||||||
|
|
||||||
if input_iterator is None:
|
if kwargs is None:
|
||||||
inputs = []
|
kwargs = {}
|
||||||
else:
|
|
||||||
inputs = input_iterator.get_next()
|
|
||||||
|
|
||||||
result = [None]
|
result = [None]
|
||||||
def replicated_fn(replica_id, replica_input):
|
def replicated_fn(replica_id, replica_args, replica_kwargs):
|
||||||
"""Wraps user function to provide replica ID and `Tensor` inputs."""
|
"""Wraps user function to provide replica ID and `Tensor` inputs."""
|
||||||
with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
|
with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
|
||||||
if input_iterator is None:
|
result[0] = fn(*replica_args, **replica_kwargs)
|
||||||
result[0] = fn()
|
|
||||||
else:
|
|
||||||
result[0] = fn(replica_input)
|
|
||||||
return result[0]
|
return result[0]
|
||||||
|
|
||||||
replicate_inputs = [] # By replica.
|
replicate_inputs = [] # By replica.
|
||||||
for i in range(self.num_replicas_in_sync):
|
for i in range(self.num_replicas_in_sync):
|
||||||
replicate_inputs.append(
|
replicate_inputs.append(
|
||||||
[constant_op.constant(i, dtype=dtypes.int32),
|
[constant_op.constant(i, dtype=dtypes.int32),
|
||||||
values.select_replica(i, inputs)])
|
values.select_replica(i, args),
|
||||||
|
values.select_replica(i, kwargs)])
|
||||||
|
|
||||||
with self.scope():
|
with self.scope():
|
||||||
replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)
|
replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)
|
||||||
|
|||||||
@ -31,6 +31,10 @@ tf_class {
|
|||||||
name: "experimental_run"
|
name: "experimental_run"
|
||||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run_v2"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "group"
|
name: "group"
|
||||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
|||||||
@ -31,6 +31,10 @@ tf_class {
|
|||||||
name: "experimental_run"
|
name: "experimental_run"
|
||||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run_v2"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "group"
|
name: "group"
|
||||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
|||||||
@ -30,6 +30,10 @@ tf_class {
|
|||||||
name: "experimental_run"
|
name: "experimental_run"
|
||||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run_v2"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "group"
|
name: "group"
|
||||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
|||||||
@ -31,6 +31,10 @@ tf_class {
|
|||||||
name: "experimental_run"
|
name: "experimental_run"
|
||||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run_v2"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "group"
|
name: "group"
|
||||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
|||||||
@ -31,6 +31,10 @@ tf_class {
|
|||||||
name: "experimental_run"
|
name: "experimental_run"
|
||||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run_v2"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "group"
|
name: "group"
|
||||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
|||||||
@ -35,6 +35,10 @@ tf_class {
|
|||||||
name: "experimental_run"
|
name: "experimental_run"
|
||||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run_v2"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "group"
|
name: "group"
|
||||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
|||||||
@ -31,6 +31,10 @@ tf_class {
|
|||||||
name: "experimental_run"
|
name: "experimental_run"
|
||||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run_v2"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "group"
|
name: "group"
|
||||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
|||||||
@ -31,6 +31,10 @@ tf_class {
|
|||||||
name: "experimental_run"
|
name: "experimental_run"
|
||||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run_v2"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "group"
|
name: "group"
|
||||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
|||||||
@ -30,6 +30,10 @@ tf_class {
|
|||||||
name: "experimental_run"
|
name: "experimental_run"
|
||||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run_v2"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "group"
|
name: "group"
|
||||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
|||||||
@ -31,6 +31,10 @@ tf_class {
|
|||||||
name: "experimental_run"
|
name: "experimental_run"
|
||||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run_v2"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "group"
|
name: "group"
|
||||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
|||||||
@ -31,6 +31,10 @@ tf_class {
|
|||||||
name: "experimental_run"
|
name: "experimental_run"
|
||||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run_v2"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "group"
|
name: "group"
|
||||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
|||||||
@ -35,6 +35,10 @@ tf_class {
|
|||||||
name: "experimental_run"
|
name: "experimental_run"
|
||||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run_v2"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "group"
|
name: "group"
|
||||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user