Add DistributionStrategy.experimental_run_v2.

PiperOrigin-RevId: 237071002
This commit is contained in:
Chris Jones 2019-03-06 10:16:40 -08:00 committed by TensorFlower Gardener
parent 20c80123a1
commit 9d703eecbf
14 changed files with 91 additions and 22 deletions

View File

@ -445,13 +445,13 @@ class DistributionStrategy(object):
`fn` may call `tf.distribute.get_replica_context()` to access members such
as `replica_id_in_sync_group`.
IMPORTANT: Depending on the `DistributionStrategy` being used, and whether
eager execution is enabled, `fn` may be called one or more times (once for
each replica).
IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being
used, and whether eager execution is enabled, `fn` may be called one or more
times (once for each replica).
Args:
fn: function to run. The inputs to the function must match the outputs of
`input_iterator.get_next()`. The output must be a `tf.nest` of
fn: The function to run. The inputs to the function must match the outputs
of `input_iterator.get_next()`. The output must be a `tf.nest` of
`Tensor`s.
input_iterator: (Optional) input iterator from which the inputs are taken.
@ -463,11 +463,36 @@ class DistributionStrategy(object):
single replica).
"""
with self.scope():
if input_iterator is None:
return self._extended.call_for_each_replica(fn)
else:
inputs = input_iterator.get_next()
return self._extended.call_for_each_replica(fn, args=(inputs,))
args = (input_iterator.get_next(),) if input_iterator is not None else ()
return self.experimental_run_v2(fn, args=args)
def experimental_run_v2(self, fn, args=(), kwargs=None):
"""Runs ops in `fn` on each replica, with the given arguments.
When eager execution is enabled, executes ops specified by `fn` on each
replica. Otherwise, builds a graph to execute the ops on each replica.
`fn` may call `tf.distribute.get_replica_context()` to access members such
as `replica_id_in_sync_group`.
IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being
used, and whether eager execution is enabled, `fn` may be called one or more
times (once for each replica).
Args:
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
args: (Optional) Positional arguments to `fn`.
kwargs: (Optional) Keyword arguments to `fn`.
Returns:
Merged return value of `fn` across replicas. The structure of the return
value is the same as the return value from `fn`. Each element in the
structure can either be `PerReplica` (if the values are unsynchronized),
`Mirrored` (if the values are kept in sync), or `Tensor` (if running on a
single replica).
"""
with self.scope():
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
def reduce(self, reduce_op, value):
"""Reduce `value` across replicas.

View File

@ -154,32 +154,28 @@ class TPUStrategy(distribute_lib.DistributionStrategy):
# TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
# can use the default implementation.
# This implementation runs a single step. It does not use infeed or outfeed.
def experimental_run(self, fn, input_iterator=None):
def experimental_run_v2(self, fn, args=(), kwargs=None):
"""See base class."""
if context.executing_eagerly() and not ops.inside_function():
raise NotImplementedError(
"Eager mode not supported in TPUStrategy outside TF functions.")
if input_iterator is None:
inputs = []
else:
inputs = input_iterator.get_next()
if kwargs is None:
kwargs = {}
result = [None]
def replicated_fn(replica_id, replica_input):
def replicated_fn(replica_id, replica_args, replica_kwargs):
"""Wraps user function to provide replica ID and `Tensor` inputs."""
with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
if input_iterator is None:
result[0] = fn()
else:
result[0] = fn(replica_input)
result[0] = fn(*replica_args, **replica_kwargs)
return result[0]
replicate_inputs = [] # By replica.
for i in range(self.num_replicas_in_sync):
replicate_inputs.append(
[constant_op.constant(i, dtype=dtypes.int32),
values.select_replica(i, inputs)])
values.select_replica(i, args),
values.select_replica(i, kwargs)])
with self.scope():
replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

View File

@ -31,6 +31,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -31,6 +31,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -30,6 +30,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -31,6 +31,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -31,6 +31,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -35,6 +35,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -31,6 +31,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -31,6 +31,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -30,6 +30,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -31,6 +31,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -31,6 +31,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -35,6 +35,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "