Adding new APIs under tf.distribute: gather and all_gather.

`tf.distribute.Strategy.gather` and `tf.distribute.ReplicaContext.all_gather` methods are APIs to gather and concatenate `tf.distribute.DistributedValues` object(s) across workers and devices. They are counterparts in cross-replica context and replica context. This methods are implemented for all strategies except ParameterServerStrategy.

PiperOrigin-RevId: 337972679
Change-Id: I1d61c96b830683da135d5b4e89da29693c51262c
This commit is contained in:
Xinyi Wang 2020-10-19 18:00:45 -07:00 committed by TensorFlower Gardener
parent eccb7ec454
commit a7467d5d51
16 changed files with 282 additions and 192 deletions

View File

@ -767,6 +767,7 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
* Fix the issue that `strategy.reduce()` inside `tf.function` may raise exceptions when the values to reduce are from loops or if-clauses.
* Fix the issue that `tf.distribute.MirroredStrategy` cannot be used together with `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
* Add a `tf.distribute.cluster_resolver.TPUClusterResolver.connect` API to simplify TPU initialization.
* Add `tf.distribute.Strategy.gather` and `tf.distribute.ReplicaContext.all_gather` methods to gather and concatenate `tf.distribute.DistributedValues` across workers and devices.
### `tf.keras`:
* Introduces experimental preprocessing layers API (`tf.keras.layers.experimental.preprocessing`) to handle data preprocessing operations such as categorical feature encoding, text vectorization, data normalization, and data discretization (binning). The newly added layers provide a replacement for the legacy feature column API, and support composite tensor inputs.

View File

@ -295,7 +295,8 @@ def get_loss_reduction():
# Internal API for validating the current thread mode
def _require_cross_replica_or_default_context_extended(extended):
def _require_cross_replica_or_default_context_extended(extended,
error_message=None):
"""Verify in cross-replica context."""
context = _get_per_thread_mode()
cross_replica = context.cross_replica_context
@ -308,8 +309,10 @@ def _require_cross_replica_or_default_context_extended(extended):
if context.strategy is not strategy:
_wrong_strategy_scope(strategy, context)
assert cross_replica is None
raise RuntimeError("Method requires being in cross-replica context, use "
if not error_message:
error_message = ("Method requires being in cross-replica context, use "
"get_replica_context().merge_call()")
raise RuntimeError(error_message)
def _wrong_strategy_scope(strategy, context):
@ -1454,107 +1457,6 @@ class StrategyBase(object):
denom = math_ops.cast(denom, numer.dtype)
return math_ops.truediv(numer, denom)
# TODO(wxinyi): generate docs after it is implemented for all strategies.
# TODO(wxinyi): hide from V1 API
def _gather(self, value, axis):
# pylint: disable=line-too-long, protected-access
"""Gather `value` across replicas along `axis` to the current device.
Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like
object `value`, this API gathers and concatenates `value` along the
`axis`-th dimension. The result is copied to the "current" device - which
would typically be the CPU of the worker on which the program is running.
For `tf.distribute.TPUStrategy`, it is the first TPU host. For multi-client
`MultiWorkerMirroredStrategy`, this is CPU of each worker.
This API can only be called in the cross-replica context. For a counterpart
in the replica context, see `tf.distribute.ReplicaContext.all_gather`.
Note: the input `value` on different replicas must have the same rank, and
they must have shapes that are consistent along all dimensions except the
`axis`-th dimension. For example, given a `tf.distribute.DistributedValues`
with tensors of shape `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can
call `gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or
`gather(..., axis=2, ...)`.
# TODO(wxinyi): convert to testable docstring after implemented for MirroredStrategy
```python
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
local_tensor = tf.constant([[1, 2], [3, 4]])
distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(local_tensor))
@tf.function
def run():
return strategy.gather(distributed_values, axis=0)
run()
# <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
# array([[1, 2],
# [3, 4],
# [1, 2],
# [3, 4]], dtype=int32)>
```
Some more example cases:
```python
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
local_tensor = tf.reshape(tf.range(6), shape=(1,2,3))
distributed_values = strategy.experimental_distribute_values_from_function(lambda _: local_tensor)
@tf.function
def run():
return strategy.gather(distributed_values, axis=AXIS)
run()
# With AXIS=0, the result is
# <tf.Tensor: shape=(4, 2, 3), dtype=int32, numpy=
# array([[[0, 1, 2],
# [3, 4, 5]],
# [[0, 1, 2],
# [3, 4, 5]],
# [[0, 1, 2],
# [3, 4, 5]],
# [[0, 1, 2],
# [3, 4, 5]]], dtype=int32)>
# With AXIS=1, the result is
# <tf.Tensor: shape=(1, 8, 3), dtype=int32, numpy=
# array([[[0, 1, 2],
# [3, 4, 5],
# [0, 1, 2],
# [3, 4, 5],
# [0, 1, 2],
# [3, 4, 5],
# [0, 1, 2],
# [3, 4, 5]]], dtype=int32)>
# With AXIS=2, the result is
# <tf.Tensor: shape=(1, 2, 12), dtype=int32, numpy=
# array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2],
# [3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)>
```
Args:
value: a `tf.distribute.DistributedValues` instance, e.g. returned by
`Strategy.run`, to be combined into a single tensor. It can also be a
regular tensor when used with `OneDeviceStrategy` or default strategy.
The underlying tensor constructs can only be dense tensors with non-zero
rank, NOT `tf.IndexedSlices`.
axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
range [0, rank(value)).
Returns:
A `Tensor` that's the concatenation of `value` across replicas along
`axis` dimension.
"""
# pylint: enable=line-too-long
_require_cross_replica_or_default_context_extended(self._extended)
dst = device_util.current(
) or self._extended._default_device or "/device:CPU:0"
if isinstance(value, ops.IndexedSlices):
raise NotImplementedError("gather/all_gather does not support "
"IndexedSlices")
return self._extended._local_results(
self._extended._gather_to(value, dst, axis))[0]
@doc_controls.do_not_doc_inheritable # DEPRECATED
def unwrap(self, value):
"""Returns the list of all local per-replica values contained in `value`.
@ -1787,6 +1689,112 @@ class Strategy(StrategyBase):
return self._extended._experimental_distribute_values_from_function( # pylint: disable=protected-access
value_fn)
def gather(self, value, axis):
# pylint: disable=line-too-long, protected-access
"""Gather `value` across replicas along `axis` to the current device.
Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like
object `value`, this API gathers and concatenates `value` across replicas
along the `axis`-th dimension. The result is copied to the "current" device
- which would typically be the CPU of the worker on which the program is
running. For `tf.distribute.TPUStrategy`, it is the first TPU host. For
multi-client `MultiWorkerMirroredStrategy`, this is CPU of each worker.
This API can only be called in the cross-replica context. For a counterpart
in the replica context, see `tf.distribute.ReplicaContext.all_gather`.
Note: For all strategies except `tf.distribute.TPUStrategy`, the input
`value` on different replicas must have the same rank, and their shapes must
be the same in all dimensions except the `axis`-th dimension. In other
words, their shapes cannot be different in a dimension `d` where `d` does
not equal to the `axis` argument. For example, given a
`tf.distribute.DistributedValues` with component tensors of shape
`(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
`gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or
`gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`,
all tensors must have exactly the same rank and same shape.
Note: Given a `tf.distribute.DistributedValues` `value`, its component
tensors must have a non-zero rank. Otherwise, consider using
`tf.expand_dims` before gathering them.
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> # A DistributedValues with component tensor of shape (2, 1) on each replica
... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]])))
>>> @tf.function
... def run():
... return strategy.gather(distributed_values, axis=0)
>>> run()
<tf.Tensor: shape=(4, 1), dtype=int32, numpy=
array([[1],
[2],
[1],
[2]], dtype=int32)>
Consider the following example for more combinations:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
>>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3))
>>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor))
>>> @tf.function
... def run(axis):
... return strategy.gather(distributed_values, axis=axis)
>>> axis=0
>>> run(axis)
<tf.Tensor: shape=(4, 2, 3), dtype=int32, numpy=
array([[[0, 1, 2],
[3, 4, 5]],
[[0, 1, 2],
[3, 4, 5]],
[[0, 1, 2],
[3, 4, 5]],
[[0, 1, 2],
[3, 4, 5]]], dtype=int32)>
>>> axis=1
>>> run(axis)
<tf.Tensor: shape=(1, 8, 3), dtype=int32, numpy=
array([[[0, 1, 2],
[3, 4, 5],
[0, 1, 2],
[3, 4, 5],
[0, 1, 2],
[3, 4, 5],
[0, 1, 2],
[3, 4, 5]]], dtype=int32)>
>>> axis=2
>>> run(axis)
<tf.Tensor: shape=(1, 2, 12), dtype=int32, numpy=
array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2],
[3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)>
Args:
value: a `tf.distribute.DistributedValues` instance, e.g. returned by
`Strategy.run`, to be combined into a single tensor. It can also be a
regular tensor when used with `tf.distribute.OneDeviceStrategy` or the
default strategy. The tensors that constitute the DistributedValues
can only be dense tensors with non-zero rank, NOT a `tf.IndexedSlices`.
axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
range [0, rank(value)).
Returns:
A `Tensor` that's the concatenation of `value` across replicas along
`axis` dimension.
"""
# pylint: enable=line-too-long
error_message = ("tf.distribute.Strategy.gather method requires "
"cross-replica context, use "
"get_replica_context().all_gather() instead")
_require_cross_replica_or_default_context_extended(self._extended,
error_message)
dst = device_util.current(
) or self._extended._default_device or "/device:CPU:0"
if isinstance(value, ops.IndexedSlices):
raise NotImplementedError("gather does not support IndexedSlices")
return self._extended._local_results(
self._extended._gather_to(value, dst, axis))[0]
# TF v1.x version has additional deprecated APIs
@tf_export(v1=["distribute.Strategy"])
@ -2834,8 +2842,7 @@ class StrategyExtendedV1(StrategyExtendedV2):
# It sets the current Strategy for purposes of
# `get_strategy()` and `has_strategy()`
# and switches the thread mode to a "cross-replica context".
@tf_export("distribute.ReplicaContext")
class ReplicaContext(object):
class ReplicaContextBase(object):
"""A class with a collection of APIs that can be called in a replica context.
You can use `tf.distribute.get_replica_context` to get an instance of
@ -3095,77 +3102,118 @@ class ReplicaContext(object):
# to that point that the first result is needed. Most likely this can be
# implemented in terms of `merge_call()` and `batch_reduce_to()`.
# TODO(wxinyi): generate docs after it is implemented for all strategies.
def _all_gather(self, value, axis, options=None):
@tf_export("distribute.ReplicaContext", v1=[])
class ReplicaContext(ReplicaContextBase):
__doc__ = ReplicaContextBase.__doc__
def all_gather(self, value, axis, options=None):
"""All-gathers `value` across all replicas along `axis`.
Note: An `all_gather` method can only be called in replica context. To find
Note: An `all_gather` method can only be called in replica context. For
a cross-replica context counterpart, see `tf.distribute.Strategy.gather`.
All replicas need to participate in the all-gather, otherwise this
operation hangs. So if `all_gather` is called in any replica, it must be
called in all replicas.
Note: If there're multiple all-gather calls, they need to execute in
the same order on all replicas. Dispatching all-gather based on conditions
Note: If there are multiple `all_gather` calls, they need to be executed in
the same order on all replicas. Dispatching `all_gather` based on conditions
is usually error-prone.
# TODO(wxinyi): convert to testable docstring after implemented for MirroredStrategy
```python
strategy = tf.distribute.MirroredStrategy(["GPU:0", "CPU:0"])
@tf.function
def gather_value():
ctx = tf.distribute.get_replica_context()
value = tf.constant([1, 2, 3])
# all_gather a `tf.distribute.DistributedValues`
return strategy.run(ctx.all_gather(value, axis=0))
strategy.experimental_local_results(gather_value)
# Result:
# (<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
# dtype=int32)>,
# <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
# dtype=int32)>)
```
For all strategies except `tf.distribute.TPUStrategy`, the input
`value` on different replicas must have the same rank, and their shapes must
be the same in all dimensions except the `axis`-th dimension. In other
words, their shapes cannot be different in a dimension `d` where `d` does
not equal to the `axis` argument. For example, given a
`tf.distribute.DistributedValues` with component tensors of shape
`(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
`all_gather(..., axis=1, ...)` on it, but not `all_gather(..., axis=0, ...)`
or `all_gather(..., axis=2, ...)`. However, with
`tf.distribute.TPUStrategy`, all tensors must have exactly the same rank and
same shape.
```python
strategy = tf.distribute.MirroredStrategy(["GPU:0", "CPU:0"])
@tf.function
def gather_nest():
ctx = tf.distribute.get_replica_context()
value_1 = tf.constant([1, 2, 3])
value_2 = tf.constant([[1, 2], [3, 4]])
# all_gather a nest of `tf.distribute.DistributedValues`
return ctx.all_gather([value_1, value_2], axis=0)
strategy.experimental_local_results(gather_nest)
# Result:
# ([<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
# dtype=int32)>, <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
# array([[1, 2],
# [3, 4],
# [1, 2],
# [3, 4]], dtype=int32)],
# [<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
# dtype=int32)>, <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
# array([[1, 2],
# [3, 4],
# [1, 2],
# [3, 4]], dtype=int32)])
```
Note: The input `value` must have a non-zero rank. Otherwise, consider using
`tf.expand_dims` before gathering them.
Example with two replicas:
Replica 0 `value`: {'a': [0], 'b': [[0, 1]]}
Replica 1 `value`: {'a': [1], 'b': [[2, 3], [4, 5]]}
You can pass in a single tensor to all-gather:
Result for `all_gather` with `axis`=0: (on all replicas):
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> @tf.function
... def gather_value():
... ctx = tf.distribute.get_replica_context()
... local_value = tf.constant([1, 2, 3])
... return ctx.all_gather(local_value, axis=0)
>>> result = strategy.run(gather_value)
>>> result
PerReplica:{
0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
}
>>> strategy.experimental_local_results(result)
(<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
dtype=int32)>,
<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
dtype=int32)>)
You can also pass in a nested structure of tensors to all-gather, say, a
list:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> @tf.function
... def gather_nest():
... ctx = tf.distribute.get_replica_context()
... value_1 = tf.constant([1, 2, 3])
... value_2 = tf.constant([[1, 2], [3, 4]])
... # all_gather a nest of `tf.distribute.DistributedValues`
... return ctx.all_gather([value_1, value_2], axis=0)
>>> result = strategy.run(gather_nest)
>>> result
[PerReplica:{
0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
}, PerReplica:{
0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[1, 2],
[3, 4],
[1, 2],
[3, 4]], dtype=int32)>,
1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[1, 2],
[3, 4],
[1, 2],
[3, 4]], dtype=int32)>
}]
>>> strategy.experimental_local_results(result)
([PerReplica:{
0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
}, PerReplica:{
0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[1, 2],
[3, 4],
[1, 2],
[3, 4]], dtype=int32)>,
1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[1, 2],
[3, 4],
[1, 2],
[3, 4]], dtype=int32)>
}],)
What if you are all-gathering tensors with different shapes on different
replicas? Consider the following example with two replicas, where you have
`value` as a nested structure consisting of two items to all-gather, `a` and
`b`.
On Replica 0, `value` is {'a': [0], 'b': [[0, 1]]}
On Replica 1, `value` is {'a': [1], 'b': [[2, 3], [4, 5]]}
Result for `all_gather` with `axis`=0: (on each of the replicas):
{'a': [1, 2], 'b': [[0, 1], [2, 3], [4, 5]]}
Note: an input to be all_gathered must have the same rank on different
replicas, and they must have shapes that are consistent along all dimensions
except the `axis`-th dimension. For example, given a
`tf.distribute.DistributedValues` with tensors of shape `(1, 2, 3)` and
`(1, 3, 3)` on two replicas, you can call `all_gather(..., axis=1, ...)` on
it, but not `all_gather(..., axis=0, ...)` or `all_gather(..., axis=2, ...)`.
Args:
value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts,
or a `tf.distribute.DistributedValues` instance. The structure of the
@ -3185,8 +3233,7 @@ class ReplicaContext(object):
"""
for v in nest.flatten(value):
if isinstance(v, ops.IndexedSlices):
raise NotImplementedError("gather/all_gather does not support "
"IndexedSlices")
raise NotImplementedError("all_gather does not support IndexedSlices")
if options is None:
options = collective_util.Options()
@ -3200,11 +3247,16 @@ class ReplicaContext(object):
def grad_wrapper(*xs):
ys = self.merge_call(batch_all_gather, args=xs)
# The gradient of an all-gather is itself an all-gather.
return ys, lambda *dy_s: self._all_gather(dy_s, axis)
return ys, lambda *dy_s: self.all_gather(dy_s, axis)
return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
@tf_export(v1=["distribute.ReplicaContext"])
class ReplicaContextV1(ReplicaContextBase):
__doc__ = ReplicaContextBase.__doc__
def _batch_reduce_destination(x):
"""Returns the destinations for batch all-reduce."""
if isinstance(x, ops.Tensor):

View File

@ -74,7 +74,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
lambda _: array_ops.identity(value_on_replica))
def run():
return strategy._gather(distributed_values, axis=axis)
return strategy.gather(distributed_values, axis=axis)
if not pure_eager:
run = def_function.function(run)
@ -133,7 +133,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
axis = 0
def run():
return strategy._gather(distributed_values, axis=axis)
return strategy.gather(distributed_values, axis=axis)
if not pure_eager:
run = def_function.function(run)
@ -155,7 +155,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
axis = 1
def run():
return strategy._gather(distributed_values, axis=axis)
return strategy.gather(distributed_values, axis=axis)
if not pure_eager:
run = def_function.function(run)
@ -183,7 +183,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
axis = 0
def run():
return strategy._gather(distributed_values, axis=axis)
return strategy.gather(distributed_values, axis=axis)
if not pure_eager:
run = def_function.function(run)
@ -210,11 +210,11 @@ class GatherTest(test.TestCase, parameterized.TestCase):
values=[[1., 2.]], indices=[2], dense_shape=dense_shape)
def run(value):
return strategy._gather(value, axis=0)
return strategy.gather(value, axis=0)
with self.assertRaisesRegex(
NotImplementedError,
r'gather/all_gather does not support IndexedSlices'):
r'gather does not support IndexedSlices'):
if pure_eager:
run(t0)
else:
@ -235,7 +235,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
axis = 0
def run():
return strategy._gather(distributed_values, axis=axis)
return strategy.gather(distributed_values, axis=axis)
if not pure_eager:
run = def_function.function(run)
@ -271,7 +271,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
def replica_fn(per_replica_value):
ctx = ds_context.get_replica_context()
local_value = array_ops.identity(per_replica_value)
return ctx._all_gather(local_value, axis=axis)
return ctx.all_gather(local_value, axis=axis)
if not pure_eager:
replica_fn = def_function.function(replica_fn)
@ -342,7 +342,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
@def_function.function
def replica_fn(per_replica_value):
ctx = ds_context.get_replica_context()
return ctx._all_gather(array_ops.identity(per_replica_value), axis=axis)
return ctx.all_gather(array_ops.identity(per_replica_value), axis=axis)
result = strategy.experimental_local_results(
strategy.run(replica_fn, args=(next(input_iterator),)))
@ -369,7 +369,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
def run(value):
value_identity = array_ops.identity(value)
ctx = ds_context.get_replica_context()
return ctx._all_gather(value_identity, axis=0)
return ctx.all_gather(value_identity, axis=0)
if not pure_eager:
run = def_function.function(run)
@ -397,7 +397,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
def run(value):
value_identity = array_ops.identity(value)
ctx = ds_context.get_replica_context()
return ctx._all_gather(value_identity, axis=1)
return ctx.all_gather(value_identity, axis=1)
if not pure_eager:
run = def_function.function(run)
@ -436,7 +436,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
value_1 = array_ops.identity(value)
value_3 = array_ops.identity(value_2)
ctx = ds_context.get_replica_context()
return ctx._all_gather([value_1, value_3], axis=axis)
return ctx.all_gather([value_1, value_3], axis=axis)
if not pure_eager:
run = def_function.function(run)
@ -455,7 +455,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
def run():
value_identity = array_ops.identity(single_value)
ctx = ds_context.get_replica_context()
return ctx._all_gather([value_identity, value_identity], axis=axis)
return ctx.all_gather([value_identity, value_identity], axis=axis)
if not pure_eager:
run = def_function.function(run)
@ -491,7 +491,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
def run(value):
value_identity = array_ops.identity(value)
ctx = ds_context.get_replica_context()
return ctx._all_gather(value_identity, axis=0)
return ctx.all_gather(value_identity, axis=0)
if not pure_eager:
run = def_function.function(run)
@ -519,11 +519,11 @@ class GatherTest(test.TestCase, parameterized.TestCase):
def replica_fn(value):
ctx = ds_context.get_replica_context()
return ctx._all_gather(value, axis=0)
return ctx.all_gather(value, axis=0)
with self.assertRaisesRegex(
NotImplementedError,
r'gather/all_gather does not support IndexedSlices'):
r'all_gather does not support IndexedSlices'):
if not pure_eager:
strategy.run(def_function.function(replica_fn), args=(t0,))
else:
@ -548,7 +548,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
def run(value):
value_identity = array_ops.identity(value)
ctx = ds_context.get_replica_context()
return ctx._all_gather(value_identity, axis=0)
return ctx.all_gather(value_identity, axis=0)
if not pure_eager:
run = def_function.function(run)

View File

@ -58,7 +58,7 @@ def _gather(strategy, value):
return array_ops.stack(value._values)
assert len(strategy.extended.worker_devices) == len(value._values)
inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values]
return strategy._gather(values.PerReplica(inputs), axis=0)
return strategy.gather(values.PerReplica(inputs), axis=0)
# pylint: enable=protected-access

View File

@ -1425,12 +1425,11 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
return self.strategy.extended.experimental_logical_device(logical_device_id)
# TODO(wxinyi): Investigate whether to use cross_replica_sum to optimize it.
def _all_gather(self, value, axis, options=None):
del options
def all_gather(self, value, axis, experimental_hints=None):
del experimental_hints
for v in nest.flatten(value):
if isinstance(v, ops.IndexedSlices):
raise NotImplementedError("gather/all_gather does not support "
"IndexedSlices")
raise NotImplementedError("all_gather does not support IndexedSlices")
def _all_to_all(value, axis):
# The underlying AllToAllOp first do a split of the input value and then
@ -1484,7 +1483,7 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
@custom_gradient.custom_gradient
def grad_wrapper(*xs):
ys = [_all_to_all(t, axis=axis) for t in xs]
return ys, lambda *dy_s: self._all_gather(dy_s, axis)
return ys, lambda *dy_s: self.all_gather(dy_s, axis)
return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))

View File

@ -2755,7 +2755,7 @@ def _collective_all_reduce_multi_worker(strategy):
# for all strategies
def _multi_worker_concat(v, strategy):
"""Order PerReplica objects for CollectiveAllReduceStrategy and concat."""
replicas = strategy._gather(v, axis=0) # pylint: disable=protected-access
replicas = strategy.gather(v, axis=0) # pylint: disable=protected-access
# v might not have the same shape on different replicas
if isinstance(v, ds_values.PerReplica):
shapes = array_ops.concat([
@ -2763,10 +2763,10 @@ def _multi_worker_concat(v, strategy):
for single_value in v.values
],
axis=0)
all_shapes = strategy._gather(shapes, axis=0) # pylint: disable=protected-access
all_shapes = strategy.gather(shapes, axis=0) # pylint: disable=protected-access
else:
# v is a tensor. This may happen when, say, we have 2x1 multi-worker.
all_shapes = strategy._gather( # pylint: disable=protected-access
all_shapes = strategy.gather( # pylint: disable=protected-access
array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0),
axis=0)

View File

@ -1,6 +1,7 @@
path: "tensorflow.distribute.ReplicaContext"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContext\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContextV1\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContextBase\'>"
is_instance: "<type \'object\'>"
member {
name: "devices"

View File

@ -52,6 +52,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "gather"
argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -52,6 +52,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "gather"
argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -1,6 +1,7 @@
path: "tensorflow.distribute.ReplicaContext"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContext\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContextBase\'>"
is_instance: "<type \'object\'>"
member {
name: "devices"
@ -22,6 +23,10 @@ tf_class {
name: "__init__"
argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "all_gather"
argspec: "args=[\'self\', \'value\', \'axis\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "all_reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -51,6 +51,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "gather"
argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -64,6 +64,10 @@ tf_class {
name: "experimental_split_to_logical_devices"
argspec: "args=[\'self\', \'tensor\', \'partition_dimensions\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "gather"
argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -52,6 +52,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "gather"
argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -52,6 +52,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "gather"
argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -52,6 +52,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "gather"
argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -52,6 +52,10 @@ tf_class {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "gather"
argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "