Add DistributionStrategy.experimental_run_v2.
PiperOrigin-RevId: 237071002
This commit is contained in:
parent
20c80123a1
commit
9d703eecbf
@ -445,13 +445,13 @@ class DistributionStrategy(object):
|
||||
`fn` may call `tf.distribute.get_replica_context()` to access members such
|
||||
as `replica_id_in_sync_group`.
|
||||
|
||||
IMPORTANT: Depending on the `DistributionStrategy` being used, and whether
|
||||
eager execution is enabled, `fn` may be called one or more times (once for
|
||||
each replica).
|
||||
IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being
|
||||
used, and whether eager execution is enabled, `fn` may be called one or more
|
||||
times (once for each replica).
|
||||
|
||||
Args:
|
||||
fn: function to run. The inputs to the function must match the outputs of
|
||||
`input_iterator.get_next()`. The output must be a `tf.nest` of
|
||||
fn: The function to run. The inputs to the function must match the outputs
|
||||
of `input_iterator.get_next()`. The output must be a `tf.nest` of
|
||||
`Tensor`s.
|
||||
input_iterator: (Optional) input iterator from which the inputs are taken.
|
||||
|
||||
@ -463,11 +463,36 @@ class DistributionStrategy(object):
|
||||
single replica).
|
||||
"""
|
||||
with self.scope():
|
||||
if input_iterator is None:
|
||||
return self._extended.call_for_each_replica(fn)
|
||||
else:
|
||||
inputs = input_iterator.get_next()
|
||||
return self._extended.call_for_each_replica(fn, args=(inputs,))
|
||||
args = (input_iterator.get_next(),) if input_iterator is not None else ()
|
||||
return self.experimental_run_v2(fn, args=args)
|
||||
|
||||
def experimental_run_v2(self, fn, args=(), kwargs=None):
|
||||
"""Runs ops in `fn` on each replica, with the given arguments.
|
||||
|
||||
When eager execution is enabled, executes ops specified by `fn` on each
|
||||
replica. Otherwise, builds a graph to execute the ops on each replica.
|
||||
|
||||
`fn` may call `tf.distribute.get_replica_context()` to access members such
|
||||
as `replica_id_in_sync_group`.
|
||||
|
||||
IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being
|
||||
used, and whether eager execution is enabled, `fn` may be called one or more
|
||||
times (once for each replica).
|
||||
|
||||
Args:
|
||||
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
|
||||
args: (Optional) Positional arguments to `fn`.
|
||||
kwargs: (Optional) Keyword arguments to `fn`.
|
||||
|
||||
Returns:
|
||||
Merged return value of `fn` across replicas. The structure of the return
|
||||
value is the same as the return value from `fn`. Each element in the
|
||||
structure can either be `PerReplica` (if the values are unsynchronized),
|
||||
`Mirrored` (if the values are kept in sync), or `Tensor` (if running on a
|
||||
single replica).
|
||||
"""
|
||||
with self.scope():
|
||||
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
|
||||
|
||||
def reduce(self, reduce_op, value):
|
||||
"""Reduce `value` across replicas.
|
||||
|
||||
@ -154,32 +154,28 @@ class TPUStrategy(distribute_lib.DistributionStrategy):
|
||||
# TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
|
||||
# can use the default implementation.
|
||||
# This implementation runs a single step. It does not use infeed or outfeed.
|
||||
def experimental_run(self, fn, input_iterator=None):
|
||||
def experimental_run_v2(self, fn, args=(), kwargs=None):
|
||||
"""See base class."""
|
||||
if context.executing_eagerly() and not ops.inside_function():
|
||||
raise NotImplementedError(
|
||||
"Eager mode not supported in TPUStrategy outside TF functions.")
|
||||
|
||||
if input_iterator is None:
|
||||
inputs = []
|
||||
else:
|
||||
inputs = input_iterator.get_next()
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
result = [None]
|
||||
def replicated_fn(replica_id, replica_input):
|
||||
def replicated_fn(replica_id, replica_args, replica_kwargs):
|
||||
"""Wraps user function to provide replica ID and `Tensor` inputs."""
|
||||
with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
|
||||
if input_iterator is None:
|
||||
result[0] = fn()
|
||||
else:
|
||||
result[0] = fn(replica_input)
|
||||
result[0] = fn(*replica_args, **replica_kwargs)
|
||||
return result[0]
|
||||
|
||||
replicate_inputs = [] # By replica.
|
||||
for i in range(self.num_replicas_in_sync):
|
||||
replicate_inputs.append(
|
||||
[constant_op.constant(i, dtype=dtypes.int32),
|
||||
values.select_replica(i, inputs)])
|
||||
values.select_replica(i, args),
|
||||
values.select_replica(i, kwargs)])
|
||||
|
||||
with self.scope():
|
||||
replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)
|
||||
|
||||
@ -31,6 +31,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
||||
@ -31,6 +31,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
||||
@ -30,6 +30,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
||||
@ -31,6 +31,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
||||
@ -31,6 +31,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
||||
@ -35,6 +35,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
||||
@ -31,6 +31,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
||||
@ -31,6 +31,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
||||
@ -30,6 +30,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
||||
@ -31,6 +31,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
||||
@ -31,6 +31,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
||||
@ -35,6 +35,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user