From 90a840fbcb0d5db6049de261061c48061d345678 Mon Sep 17 00:00:00 2001 From: Peter Buchlovsky Date: Tue, 11 Dec 2018 10:37:05 -0800 Subject: [PATCH] Add the run function from the revised Distribution Strategy proposal. PiperOrigin-RevId: 225028975 --- .../python/distribute/distribute_lib.py | 36 +++++++++++++++++++ ...orflow.distribute.-mirrored-strategy.pbtxt | 4 +++ .../v1/tensorflow.distribute.-strategy.pbtxt | 4 +++ ...orflow.distribute.-mirrored-strategy.pbtxt | 4 +++ .../v2/tensorflow.distribute.-strategy.pbtxt | 4 +++ 5 files changed, 52 insertions(+) diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 87bf510ec54..60bb75ded00 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -422,6 +422,42 @@ class DistributionStrategy(object): return self.extended._make_input_fn_iterator( # pylint: disable=protected-access input_fn, replication_mode=replication_mode) + def experimental_run(self, fn, input_iterator=None): + """Runs ops in `fn` on each replica, with inputs from `input_iterator`. + + When eager execution is enabled, executes ops specified by `fn` on each + replica. Otherwise, builds a graph to execute the ops on each replica. + + Each replica will take a single, different input from the inputs provided by + one `get_next` call on the input iterator. + + `fn` may call `tf.distribute.get_replica_context()` to access members such + as `replica_id_in_sync_group`. + + IMPORTANT: Depending on the `DistributionStrategy` being used, and whether + eager execution is enabled, `fn` may be called one or more times (once for + each replica). + + Args: + fn: function to run. The inputs to the function must match the outputs of + `input_iterator.get_next()`. The output must be a `tf.nest` of + `Tensor`s. + input_iterator: (Optional) input iterator from which the inputs are taken. + + Returns: + Merged return value of `fn` across replicas. The structure of the return + value is the same as the return value from `fn`. Each element in the + structure can either be `PerReplica` (if the values are unsynchronized), + `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a + single replica). + """ + with self.scope(): + if input_iterator is None: + return self._extended.call_for_each_replica(fn) + else: + inputs = input_iterator.get_next() + return self._extended.call_for_each_replica(fn, args=(inputs,)) + @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended` def broadcast(self, tensor, destinations=None): """DEPRECATED: use extended.broadcast_to() instead.""" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-mirrored-strategy.pbtxt index a613e2d3d1d..81224f00a4a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-mirrored-strategy.pbtxt @@ -75,6 +75,10 @@ tf_class { name: "experimental_initialize" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "experimental_run" + argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "finalize" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt index 9eb73d2c0d9..63b6584caf0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt @@ -74,6 +74,10 @@ tf_class { name: "experimental_initialize" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "experimental_run" + argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "finalize" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt index a613e2d3d1d..81224f00a4a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt @@ -75,6 +75,10 @@ tf_class { name: "experimental_initialize" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "experimental_run" + argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "finalize" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt index 9eb73d2c0d9..63b6584caf0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt @@ -74,6 +74,10 @@ tf_class { name: "experimental_initialize" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "experimental_run" + argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "finalize" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"