From a7467d5d51079d8e5b9bec781e9f26ce8750fcea Mon Sep 17 00:00:00 2001 From: Xinyi Wang Date: Mon, 19 Oct 2020 18:00:45 -0700 Subject: [PATCH] Adding new APIs under tf.distribute: gather and all_gather. `tf.distribute.Strategy.gather` and `tf.distribute.ReplicaContext.all_gather` methods are APIs to gather and concatenate `tf.distribute.DistributedValues` object(s) across workers and devices. They are counterparts in cross-replica context and replica context. These methods are implemented for all strategies except ParameterServerStrategy. PiperOrigin-RevId: 337972679 Change-Id: I1d61c96b830683da135d5b4e89da29693c51262c --- RELEASE.md | 1 + .../python/distribute/distribute_lib.py | 382 ++++++++++-------- .../python/distribute/strategy_gather_test.py | 34 +- tensorflow/python/distribute/test_util.py | 2 +- tensorflow/python/distribute/tpu_strategy.py | 9 +- tensorflow/python/keras/engine/training.py | 6 +- ...nsorflow.distribute.-replica-context.pbtxt | 3 +- ...orflow.distribute.-mirrored-strategy.pbtxt | 4 + ...flow.distribute.-one-device-strategy.pbtxt | 4 + ...nsorflow.distribute.-replica-context.pbtxt | 5 + .../v2/tensorflow.distribute.-strategy.pbtxt | 4 + ...ensorflow.distribute.-t-p-u-strategy.pbtxt | 4 + ...perimental.-central-storage-strategy.pbtxt | 4 + ...ntal.-multi-worker-mirrored-strategy.pbtxt | 4 + ...erimental.-parameter-server-strategy.pbtxt | 4 + ...tribute.experimental.-t-p-u-strategy.pbtxt | 4 + 16 files changed, 282 insertions(+), 192 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 28135cc0e34..a05ad11779a 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -767,6 +767,7 @@ stjohnso98, , , , , * Fix the issue that `strategy.reduce()` inside `tf.function` may raise exceptions when the values to reduce are from loops or if-clauses. * Fix the issue that `tf.distribute.MirroredStrategy` cannot be used together with `tf.distribute.experimental.MultiWorkerMirroredStrategy`. 
* Add a `tf.distribute.cluster_resolver.TPUClusterResolver.connect` API to simplify TPU initialization. + * Add `tf.distribute.Strategy.gather` and `tf.distribute.ReplicaContext.all_gather` methods to gather and concatenate `tf.distribute.DistributedValues` across workers and devices. ### `tf.keras`: * Introduces experimental preprocessing layers API (`tf.keras.layers.experimental.preprocessing`) to handle data preprocessing operations such as categorical feature encoding, text vectorization, data normalization, and data discretization (binning). The newly added layers provide a replacement for the legacy feature column API, and support composite tensor inputs. diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 4fb68d1ef20..335b5eee310 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -295,7 +295,8 @@ def get_loss_reduction(): # Internal API for validating the current thread mode -def _require_cross_replica_or_default_context_extended(extended): +def _require_cross_replica_or_default_context_extended(extended, + error_message=None): """Verify in cross-replica context.""" context = _get_per_thread_mode() cross_replica = context.cross_replica_context @@ -308,8 +309,10 @@ def _require_cross_replica_or_default_context_extended(extended): if context.strategy is not strategy: _wrong_strategy_scope(strategy, context) assert cross_replica is None - raise RuntimeError("Method requires being in cross-replica context, use " + if not error_message: + error_message = ("Method requires being in cross-replica context, use " "get_replica_context().merge_call()") + raise RuntimeError(error_message) def _wrong_strategy_scope(strategy, context): @@ -1454,107 +1457,6 @@ class StrategyBase(object): denom = math_ops.cast(denom, numer.dtype) return math_ops.truediv(numer, denom) - # TODO(wxinyi): generate docs after it is implemented for all strategies. 
- # TODO(wxinyi): hide from V1 API - def _gather(self, value, axis): - # pylint: disable=line-too-long, protected-access - """Gather `value` across replicas along `axis` to the current device. - - Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like - object `value`, this API gathers and concatenates `value` along the - `axis`-th dimension. The result is copied to the "current" device - which - would typically be the CPU of the worker on which the program is running. - For `tf.distribute.TPUStrategy`, it is the first TPU host. For multi-client - `MultiWorkerMirroredStrategy`, this is CPU of each worker. - - This API can only be called in the cross-replica context. For a counterpart - in the replica context, see `tf.distribute.ReplicaContext.all_gather`. - - Note: the input `value` on different replicas must have the same rank, and - they must have shapes that are consistent along all dimensions except the - `axis`-th dimension. For example, given a `tf.distribute.DistributedValues` - with tensors of shape `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can - call `gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or - `gather(..., axis=2, ...)`. 
- - - # TODO(wxinyi): convert to testable docstring after implemented for MirroredStrategy - ```python - strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) - local_tensor = tf.constant([[1, 2], [3, 4]]) - distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(local_tensor)) - @tf.function - def run(): - return strategy.gather(distributed_values, axis=0) - run() - # - ``` - - Some more example cases: - - ```python - strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"]) - local_tensor = tf.reshape(tf.range(6), shape=(1,2,3)) - distributed_values = strategy.experimental_distribute_values_from_function(lambda _: local_tensor) - @tf.function - def run(): - return strategy.gather(distributed_values, axis=AXIS) - run() - - # With AXIS=0, the result is - # - # With AXIS=1, the result is - # - # With AXIS=2, the result is - # - - ``` - - Args: - value: a `tf.distribute.DistributedValues` instance, e.g. returned by - `Strategy.run`, to be combined into a single tensor. It can also be a - regular tensor when used with `OneDeviceStrategy` or default strategy. - The underlying tensor constructs can only be dense tensors with non-zero - rank, NOT `tf.IndexedSlices`. - axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the - range [0, rank(value)). - - Returns: - A `Tensor` that's the concatenation of `value` across replicas along - `axis` dimension. 
- """ - # pylint: enable=line-too-long - _require_cross_replica_or_default_context_extended(self._extended) - dst = device_util.current( - ) or self._extended._default_device or "/device:CPU:0" - if isinstance(value, ops.IndexedSlices): - raise NotImplementedError("gather/all_gather does not support " - "IndexedSlices") - return self._extended._local_results( - self._extended._gather_to(value, dst, axis))[0] - @doc_controls.do_not_doc_inheritable # DEPRECATED def unwrap(self, value): """Returns the list of all local per-replica values contained in `value`. @@ -1787,6 +1689,112 @@ class Strategy(StrategyBase): return self._extended._experimental_distribute_values_from_function( # pylint: disable=protected-access value_fn) + def gather(self, value, axis): + # pylint: disable=line-too-long, protected-access + """Gather `value` across replicas along `axis` to the current device. + + Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like + object `value`, this API gathers and concatenates `value` across replicas + along the `axis`-th dimension. The result is copied to the "current" device + - which would typically be the CPU of the worker on which the program is + running. For `tf.distribute.TPUStrategy`, it is the first TPU host. For + multi-client `MultiWorkerMirroredStrategy`, this is CPU of each worker. + + This API can only be called in the cross-replica context. For a counterpart + in the replica context, see `tf.distribute.ReplicaContext.all_gather`. + + Note: For all strategies except `tf.distribute.TPUStrategy`, the input + `value` on different replicas must have the same rank, and their shapes must + be the same in all dimensions except the `axis`-th dimension. In other + words, their shapes cannot be different in a dimension `d` where `d` does + not equal to the `axis` argument. 
For example, given a + `tf.distribute.DistributedValues` with component tensors of shape + `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call + `gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or + `gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`, + all tensors must have exactly the same rank and same shape. + + Note: Given a `tf.distribute.DistributedValues` `value`, its component + tensors must have a non-zero rank. Otherwise, consider using + `tf.expand_dims` before gathering them. + + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) + >>> # A DistributedValues with component tensor of shape (2, 1) on each replica + ... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]]))) + >>> @tf.function + ... def run(): + ... return strategy.gather(distributed_values, axis=0) + >>> run() + + + + Consider the following example for more combinations: + + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"]) + >>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3)) + >>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor)) + >>> @tf.function + ... def run(axis): + ... return strategy.gather(distributed_values, axis=axis) + >>> axis=0 + >>> run(axis) + + >>> axis=1 + >>> run(axis) + + >>> axis=2 + >>> run(axis) + + + + Args: + value: a `tf.distribute.DistributedValues` instance, e.g. returned by + `Strategy.run`, to be combined into a single tensor. It can also be a + regular tensor when used with `tf.distribute.OneDeviceStrategy` or the + default strategy. The tensors that constitute the DistributedValues + can only be dense tensors with non-zero rank, NOT a `tf.IndexedSlices`. + axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the + range [0, rank(value)). 
+ + Returns: + A `Tensor` that's the concatenation of `value` across replicas along + `axis` dimension. + """ + # pylint: enable=line-too-long + error_message = ("tf.distribute.Strategy.gather method requires " + "cross-replica context, use " + "get_replica_context().all_gather() instead") + _require_cross_replica_or_default_context_extended(self._extended, + error_message) + dst = device_util.current( + ) or self._extended._default_device or "/device:CPU:0" + if isinstance(value, ops.IndexedSlices): + raise NotImplementedError("gather does not support IndexedSlices") + return self._extended._local_results( + self._extended._gather_to(value, dst, axis))[0] + # TF v1.x version has additional deprecated APIs @tf_export(v1=["distribute.Strategy"]) @@ -2834,8 +2842,7 @@ class StrategyExtendedV1(StrategyExtendedV2): # It sets the current Strategy for purposes of # `get_strategy()` and `has_strategy()` # and switches the thread mode to a "cross-replica context". -@tf_export("distribute.ReplicaContext") -class ReplicaContext(object): +class ReplicaContextBase(object): """A class with a collection of APIs that can be called in a replica context. You can use `tf.distribute.get_replica_context` to get an instance of @@ -3095,77 +3102,118 @@ class ReplicaContext(object): # to that point that the first result is needed. Most likely this can be # implemented in terms of `merge_call()` and `batch_reduce_to()`. - # TODO(wxinyi): generate docs after it is implemented for all strategies. - def _all_gather(self, value, axis, options=None): + +@tf_export("distribute.ReplicaContext", v1=[]) +class ReplicaContext(ReplicaContextBase): + + __doc__ = ReplicaContextBase.__doc__ + + def all_gather(self, value, axis, options=None): """All-gathers `value` across all replicas along `axis`. - Note: An `all_gather` method can only be called in replica context. To find + Note: An `all_gather` method can only be called in replica context. 
For a cross-replica context counterpart, see `tf.distribute.Strategy.gather`. All replicas need to participate in the all-gather, otherwise this operation hangs. So if `all_gather` is called in any replica, it must be called in all replicas. - Note: If there're multiple all-gather calls, they need to execute in - the same order on all replicas. Dispatching all-gather based on conditions + Note: If there are multiple `all_gather` calls, they need to be executed in + the same order on all replicas. Dispatching `all_gather` based on conditions is usually error-prone. - # TODO(wxinyi): convert to testable docstring after implemented for MirroredStrategy - ```python - strategy = tf.distribute.MirroredStrategy(["GPU:0", "CPU:0"]) - @tf.function - def gather_value(): - ctx = tf.distribute.get_replica_context() - value = tf.constant([1, 2, 3]) - # all_gather a `tf.distribute.DistributedValues` - return strategy.run(ctx.all_gather(value, axis=0)) - strategy.experimental_local_results(gather_value) - # Result: - # (, - # ) - ``` + For all strategies except `tf.distribute.TPUStrategy`, the input + `value` on different replicas must have the same rank, and their shapes must + be the same in all dimensions except the `axis`-th dimension. In other + words, their shapes cannot be different in a dimension `d` where `d` does + not equal the `axis` argument. For example, given a + `tf.distribute.DistributedValues` with component tensors of shape + `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call + `all_gather(..., axis=1, ...)` on it, but not `all_gather(..., axis=0, ...)` + or `all_gather(..., axis=2, ...)`. However, with + `tf.distribute.TPUStrategy`, all tensors must have exactly the same rank and + same shape. 
- ```python - strategy = tf.distribute.MirroredStrategy(["GPU:0", "CPU:0"]) - @tf.function - def gather_nest(): - ctx = tf.distribute.get_replica_context() - value_1 = tf.constant([1, 2, 3]) - value_2 = tf.constant([[1, 2], [3, 4]]) - # all_gather a nest of `tf.distribute.DistributedValues` - return ctx.all_gather([value_1, value_2], axis=0) - strategy.experimental_local_results(gather_nest) - # Result: - # ([, , >> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) + >>> @tf.function + ... def gather_value(): + ... ctx = tf.distribute.get_replica_context() + ... local_value = tf.constant([1, 2, 3]) + ... return ctx.all_gather(local_value, axis=0) + >>> result = strategy.run(gather_value) + >>> result + PerReplica:{ + 0: , + 1: + } + >>> strategy.experimental_local_results(result) + (, + ) + + + You can also pass in a nested structure of tensors to all-gather, say, a + list: + + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) + >>> @tf.function + ... def gather_nest(): + ... ctx = tf.distribute.get_replica_context() + ... value_1 = tf.constant([1, 2, 3]) + ... value_2 = tf.constant([[1, 2], [3, 4]]) + ... # all_gather a nest of `tf.distribute.DistributedValues` + ... return ctx.all_gather([value_1, value_2], axis=0) + >>> result = strategy.run(gather_nest) + >>> result + [PerReplica:{ + 0: , + 1: + }, PerReplica:{ + 0: , + 1: + }] + >>> strategy.experimental_local_results(result) + ([PerReplica:{ + 0: , + 1: + }, PerReplica:{ + 0: , + 1: + }],) + + + What if you are all-gathering tensors with different shapes on different + replicas? Consider the following example with two replicas, where you have + `value` as a nested structure consisting of two items to all-gather, `a` and + `b`. 
+ + On Replica 0, `value` is {'a': [0], 'b': [[0, 1]]} + On Replica 1, `value` is {'a': [1], 'b': [[2, 3], [4, 5]]} + + Result for `all_gather` with `axis`=0: (on each of the replicas): {'a': [0, 1], 'b': [[0, 1], [2, 3], [4, 5]]} - Note: an input to be all_gathered must have the same rank on different - replicas, and they must have shapes that are consistent along all dimensions - except the `axis`-th dimension. For example, given a - `tf.distribute.DistributedValues` with tensors of shape `(1, 2, 3)` and - `(1, 3, 3)` on two replicas, you can call `all_gather(..., axis=1, ...)` on - it, but not `all_gather(..., axis=0, ...)` or `all_gather(..., axis=2, ...)`. - - Args: value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts, or a `tf.distribute.DistributedValues` instance. The structure of the @@ -3185,8 +3233,7 @@ class ReplicaContext(object): """ for v in nest.flatten(value): if isinstance(v, ops.IndexedSlices): - raise NotImplementedError("gather/all_gather does not support " - "IndexedSlices") + raise NotImplementedError("all_gather does not support IndexedSlices") if options is None: options = collective_util.Options() @@ -3200,11 +3247,16 @@ class ReplicaContext(object): def grad_wrapper(*xs): ys = self.merge_call(batch_all_gather, args=xs) # The gradient of an all-gather is itself an all-gather. 
- return ys, lambda *dy_s: self._all_gather(dy_s, axis) + return ys, lambda *dy_s: self.all_gather(dy_s, axis) return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) +@tf_export(v1=["distribute.ReplicaContext"]) +class ReplicaContextV1(ReplicaContextBase): + __doc__ = ReplicaContextBase.__doc__ + + def _batch_reduce_destination(x): """Returns the destinations for batch all-reduce.""" if isinstance(x, ops.Tensor): diff --git a/tensorflow/python/distribute/strategy_gather_test.py b/tensorflow/python/distribute/strategy_gather_test.py index 7cefcf396db..0055f0ea877 100644 --- a/tensorflow/python/distribute/strategy_gather_test.py +++ b/tensorflow/python/distribute/strategy_gather_test.py @@ -74,7 +74,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): lambda _: array_ops.identity(value_on_replica)) def run(): - return strategy._gather(distributed_values, axis=axis) + return strategy.gather(distributed_values, axis=axis) if not pure_eager: run = def_function.function(run) @@ -133,7 +133,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): axis = 0 def run(): - return strategy._gather(distributed_values, axis=axis) + return strategy.gather(distributed_values, axis=axis) if not pure_eager: run = def_function.function(run) @@ -155,7 +155,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): axis = 1 def run(): - return strategy._gather(distributed_values, axis=axis) + return strategy.gather(distributed_values, axis=axis) if not pure_eager: run = def_function.function(run) @@ -183,7 +183,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): axis = 0 def run(): - return strategy._gather(distributed_values, axis=axis) + return strategy.gather(distributed_values, axis=axis) if not pure_eager: run = def_function.function(run) @@ -210,11 +210,11 @@ class GatherTest(test.TestCase, parameterized.TestCase): values=[[1., 2.]], indices=[2], dense_shape=dense_shape) def run(value): - return strategy._gather(value, axis=0) + return 
strategy.gather(value, axis=0) with self.assertRaisesRegex( NotImplementedError, - r'gather/all_gather does not support IndexedSlices'): + r'gather does not support IndexedSlices'): if pure_eager: run(t0) else: @@ -235,7 +235,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): axis = 0 def run(): - return strategy._gather(distributed_values, axis=axis) + return strategy.gather(distributed_values, axis=axis) if not pure_eager: run = def_function.function(run) @@ -271,7 +271,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): def replica_fn(per_replica_value): ctx = ds_context.get_replica_context() local_value = array_ops.identity(per_replica_value) - return ctx._all_gather(local_value, axis=axis) + return ctx.all_gather(local_value, axis=axis) if not pure_eager: replica_fn = def_function.function(replica_fn) @@ -342,7 +342,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): @def_function.function def replica_fn(per_replica_value): ctx = ds_context.get_replica_context() - return ctx._all_gather(array_ops.identity(per_replica_value), axis=axis) + return ctx.all_gather(array_ops.identity(per_replica_value), axis=axis) result = strategy.experimental_local_results( strategy.run(replica_fn, args=(next(input_iterator),))) @@ -369,7 +369,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): def run(value): value_identity = array_ops.identity(value) ctx = ds_context.get_replica_context() - return ctx._all_gather(value_identity, axis=0) + return ctx.all_gather(value_identity, axis=0) if not pure_eager: run = def_function.function(run) @@ -397,7 +397,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): def run(value): value_identity = array_ops.identity(value) ctx = ds_context.get_replica_context() - return ctx._all_gather(value_identity, axis=1) + return ctx.all_gather(value_identity, axis=1) if not pure_eager: run = def_function.function(run) @@ -436,7 +436,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): value_1 = 
array_ops.identity(value) value_3 = array_ops.identity(value_2) ctx = ds_context.get_replica_context() - return ctx._all_gather([value_1, value_3], axis=axis) + return ctx.all_gather([value_1, value_3], axis=axis) if not pure_eager: run = def_function.function(run) @@ -455,7 +455,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): def run(): value_identity = array_ops.identity(single_value) ctx = ds_context.get_replica_context() - return ctx._all_gather([value_identity, value_identity], axis=axis) + return ctx.all_gather([value_identity, value_identity], axis=axis) if not pure_eager: run = def_function.function(run) @@ -491,7 +491,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): def run(value): value_identity = array_ops.identity(value) ctx = ds_context.get_replica_context() - return ctx._all_gather(value_identity, axis=0) + return ctx.all_gather(value_identity, axis=0) if not pure_eager: run = def_function.function(run) @@ -519,11 +519,11 @@ class GatherTest(test.TestCase, parameterized.TestCase): def replica_fn(value): ctx = ds_context.get_replica_context() - return ctx._all_gather(value, axis=0) + return ctx.all_gather(value, axis=0) with self.assertRaisesRegex( NotImplementedError, - r'gather/all_gather does not support IndexedSlices'): + r'all_gather does not support IndexedSlices'): if not pure_eager: strategy.run(def_function.function(replica_fn), args=(t0,)) else: @@ -548,7 +548,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): def run(value): value_identity = array_ops.identity(value) ctx = ds_context.get_replica_context() - return ctx._all_gather(value_identity, axis=0) + return ctx.all_gather(value_identity, axis=0) if not pure_eager: run = def_function.function(run) diff --git a/tensorflow/python/distribute/test_util.py b/tensorflow/python/distribute/test_util.py index 82867edb4c2..2f04b67347f 100644 --- a/tensorflow/python/distribute/test_util.py +++ b/tensorflow/python/distribute/test_util.py @@ -58,7 +58,7 @@ def 
_gather(strategy, value): return array_ops.stack(value._values) assert len(strategy.extended.worker_devices) == len(value._values) inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values] - return strategy._gather(values.PerReplica(inputs), axis=0) + return strategy.gather(values.PerReplica(inputs), axis=0) # pylint: enable=protected-access diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index 66eac2d7d93..6f3cc3e79f7 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -1425,12 +1425,11 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext): return self.strategy.extended.experimental_logical_device(logical_device_id) # TODO(wxinyi): Investigate whether to use cross_replica_sum to optimize it. - def _all_gather(self, value, axis, options=None): - del options + def all_gather(self, value, axis, experimental_hints=None): + del experimental_hints for v in nest.flatten(value): if isinstance(v, ops.IndexedSlices): - raise NotImplementedError("gather/all_gather does not support " - "IndexedSlices") + raise NotImplementedError("all_gather does not support IndexedSlices") def _all_to_all(value, axis): # The underlying AllToAllOp first do a split of the input value and then @@ -1484,7 +1483,7 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext): @custom_gradient.custom_gradient def grad_wrapper(*xs): ys = [_all_to_all(t, axis=axis) for t in xs] - return ys, lambda *dy_s: self._all_gather(dy_s, axis) + return ys, lambda *dy_s: self.all_gather(dy_s, axis) return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index a13d0b01718..986492de4a4 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -2755,7 +2755,7 @@ def _collective_all_reduce_multi_worker(strategy): # for all 
strategies def _multi_worker_concat(v, strategy): """Order PerReplica objects for CollectiveAllReduceStrategy and concat.""" - replicas = strategy._gather(v, axis=0) # pylint: disable=protected-access + replicas = strategy.gather(v, axis=0) # pylint: disable=protected-access # v might not have the same shape on different replicas if isinstance(v, ds_values.PerReplica): shapes = array_ops.concat([ @@ -2763,10 +2763,10 @@ def _multi_worker_concat(v, strategy): for single_value in v.values ], axis=0) - all_shapes = strategy._gather(shapes, axis=0) # pylint: disable=protected-access + all_shapes = strategy.gather(shapes, axis=0) # pylint: disable=protected-access else: # v is a tensor. This may happen when, say, we have 2x1 multi-worker. - all_shapes = strategy._gather( # pylint: disable=protected-access + all_shapes = strategy.gather( # pylint: disable=protected-access array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0), axis=0) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt index 5a9c88dddc0..b364014e55a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.distribute.ReplicaContext" tf_class { - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "devices" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt index 60c0e8d7663..75b0dd33fd3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt @@ -52,6 +52,10 @@ tf_class { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, 
defaults=[\'None\'], " } + member_method { + name: "gather" + argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "group" argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt index 17af12cf279..559ee5e9519 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt @@ -52,6 +52,10 @@ tf_class { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "gather" + argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "group" argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt index 5a9c88dddc0..7379ddc856d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.distribute.ReplicaContext" tf_class { is_instance: "" + is_instance: "" is_instance: "" member { name: "devices" @@ -22,6 +23,10 @@ tf_class { name: "__init__" argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "all_gather" + argspec: "args=[\'self\', \'value\', \'axis\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "all_reduce" argspec: "args=[\'self\', 
\'reduce_op\', \'value\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt index 702cdc98e88..5991d60fd81 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt @@ -51,6 +51,10 @@ tf_class { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "gather" + argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "group" argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-t-p-u-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-t-p-u-strategy.pbtxt index 0cb06ec3b01..3d8e791613a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-t-p-u-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-t-p-u-strategy.pbtxt @@ -64,6 +64,10 @@ tf_class { name: "experimental_split_to_logical_devices" argspec: "args=[\'self\', \'tensor\', \'partition_dimensions\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "gather" + argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "group" argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt index e1b92b44b73..00d7d652a89 100644 --- 
a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt @@ -52,6 +52,10 @@ tf_class { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "gather" + argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "group" argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt index d0fb8c9a632..9669433cdd8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt @@ -52,6 +52,10 @@ tf_class { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "gather" + argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "group" argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt index e445d7c4dab..3eb65ccfbd2 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt 
@@ -52,6 +52,10 @@ tf_class { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "gather" + argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "group" argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt index 6536eefe414..25c525f8e18 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt @@ -52,6 +52,10 @@ tf_class { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "gather" + argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "group" argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "