Adding new APIs under tf.distribute: gather and all_gather.
`tf.distribute.Strategy.gather` and `tf.distribute.ReplicaContext.all_gather` methods are APIs to gather and concatenate `tf.distribute.DistributedValues` object(s) across workers and devices. They are counterparts in cross-replica context and replica context. This methods are implemented for all strategies except ParameterServerStrategy. PiperOrigin-RevId: 337972679 Change-Id: I1d61c96b830683da135d5b4e89da29693c51262c
This commit is contained in:
parent
eccb7ec454
commit
a7467d5d51
RELEASE.md
tensorflow
python
distribute
keras/engine
tools/api/golden
v1
v2
tensorflow.distribute.-mirrored-strategy.pbtxttensorflow.distribute.-one-device-strategy.pbtxttensorflow.distribute.-replica-context.pbtxttensorflow.distribute.-strategy.pbtxttensorflow.distribute.-t-p-u-strategy.pbtxttensorflow.distribute.experimental.-central-storage-strategy.pbtxttensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxttensorflow.distribute.experimental.-parameter-server-strategy.pbtxttensorflow.distribute.experimental.-t-p-u-strategy.pbtxt
@ -767,6 +767,7 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
|
||||
* Fix the issue that `strategy.reduce()` inside `tf.function` may raise exceptions when the values to reduce are from loops or if-clauses.
|
||||
* Fix the issue that `tf.distribute.MirroredStrategy` cannot be used together with `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
|
||||
* Add a `tf.distribute.cluster_resolver.TPUClusterResolver.connect` API to simplify TPU initialization.
|
||||
* Add `tf.distribute.Strategy.gather` and `tf.distribute.ReplicaContext.all_gather` methods to gather and concatenate `tf.distribute.DistributedValues` across workers and devices.
|
||||
|
||||
### `tf.keras`:
|
||||
* Introduces experimental preprocessing layers API (`tf.keras.layers.experimental.preprocessing`) to handle data preprocessing operations such as categorical feature encoding, text vectorization, data normalization, and data discretization (binning). The newly added layers provide a replacement for the legacy feature column API, and support composite tensor inputs.
|
||||
|
@ -295,7 +295,8 @@ def get_loss_reduction():
|
||||
# Internal API for validating the current thread mode
|
||||
|
||||
|
||||
def _require_cross_replica_or_default_context_extended(extended):
|
||||
def _require_cross_replica_or_default_context_extended(extended,
|
||||
error_message=None):
|
||||
"""Verify in cross-replica context."""
|
||||
context = _get_per_thread_mode()
|
||||
cross_replica = context.cross_replica_context
|
||||
@ -308,8 +309,10 @@ def _require_cross_replica_or_default_context_extended(extended):
|
||||
if context.strategy is not strategy:
|
||||
_wrong_strategy_scope(strategy, context)
|
||||
assert cross_replica is None
|
||||
raise RuntimeError("Method requires being in cross-replica context, use "
|
||||
if not error_message:
|
||||
error_message = ("Method requires being in cross-replica context, use "
|
||||
"get_replica_context().merge_call()")
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
|
||||
def _wrong_strategy_scope(strategy, context):
|
||||
@ -1454,107 +1457,6 @@ class StrategyBase(object):
|
||||
denom = math_ops.cast(denom, numer.dtype)
|
||||
return math_ops.truediv(numer, denom)
|
||||
|
||||
# TODO(wxinyi): generate docs after it is implemented for all strategies.
|
||||
# TODO(wxinyi): hide from V1 API
|
||||
def _gather(self, value, axis):
|
||||
# pylint: disable=line-too-long, protected-access
|
||||
"""Gather `value` across replicas along `axis` to the current device.
|
||||
|
||||
Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like
|
||||
object `value`, this API gathers and concatenates `value` along the
|
||||
`axis`-th dimension. The result is copied to the "current" device - which
|
||||
would typically be the CPU of the worker on which the program is running.
|
||||
For `tf.distribute.TPUStrategy`, it is the first TPU host. For multi-client
|
||||
`MultiWorkerMirroredStrategy`, this is CPU of each worker.
|
||||
|
||||
This API can only be called in the cross-replica context. For a counterpart
|
||||
in the replica context, see `tf.distribute.ReplicaContext.all_gather`.
|
||||
|
||||
Note: the input `value` on different replicas must have the same rank, and
|
||||
they must have shapes that are consistent along all dimensions except the
|
||||
`axis`-th dimension. For example, given a `tf.distribute.DistributedValues`
|
||||
with tensors of shape `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can
|
||||
call `gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or
|
||||
`gather(..., axis=2, ...)`.
|
||||
|
||||
|
||||
# TODO(wxinyi): convert to testable docstring after implemented for MirroredStrategy
|
||||
```python
|
||||
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
||||
local_tensor = tf.constant([[1, 2], [3, 4]])
|
||||
distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(local_tensor))
|
||||
@tf.function
|
||||
def run():
|
||||
return strategy.gather(distributed_values, axis=0)
|
||||
run()
|
||||
# <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
|
||||
# array([[1, 2],
|
||||
# [3, 4],
|
||||
# [1, 2],
|
||||
# [3, 4]], dtype=int32)>
|
||||
```
|
||||
|
||||
Some more example cases:
|
||||
|
||||
```python
|
||||
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
|
||||
local_tensor = tf.reshape(tf.range(6), shape=(1,2,3))
|
||||
distributed_values = strategy.experimental_distribute_values_from_function(lambda _: local_tensor)
|
||||
@tf.function
|
||||
def run():
|
||||
return strategy.gather(distributed_values, axis=AXIS)
|
||||
run()
|
||||
|
||||
# With AXIS=0, the result is
|
||||
# <tf.Tensor: shape=(4, 2, 3), dtype=int32, numpy=
|
||||
# array([[[0, 1, 2],
|
||||
# [3, 4, 5]],
|
||||
# [[0, 1, 2],
|
||||
# [3, 4, 5]],
|
||||
# [[0, 1, 2],
|
||||
# [3, 4, 5]],
|
||||
# [[0, 1, 2],
|
||||
# [3, 4, 5]]], dtype=int32)>
|
||||
# With AXIS=1, the result is
|
||||
# <tf.Tensor: shape=(1, 8, 3), dtype=int32, numpy=
|
||||
# array([[[0, 1, 2],
|
||||
# [3, 4, 5],
|
||||
# [0, 1, 2],
|
||||
# [3, 4, 5],
|
||||
# [0, 1, 2],
|
||||
# [3, 4, 5],
|
||||
# [0, 1, 2],
|
||||
# [3, 4, 5]]], dtype=int32)>
|
||||
# With AXIS=2, the result is
|
||||
# <tf.Tensor: shape=(1, 2, 12), dtype=int32, numpy=
|
||||
# array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2],
|
||||
# [3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)>
|
||||
|
||||
```
|
||||
|
||||
Args:
|
||||
value: a `tf.distribute.DistributedValues` instance, e.g. returned by
|
||||
`Strategy.run`, to be combined into a single tensor. It can also be a
|
||||
regular tensor when used with `OneDeviceStrategy` or default strategy.
|
||||
The underlying tensor constructs can only be dense tensors with non-zero
|
||||
rank, NOT `tf.IndexedSlices`.
|
||||
axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
|
||||
range [0, rank(value)).
|
||||
|
||||
Returns:
|
||||
A `Tensor` that's the concatenation of `value` across replicas along
|
||||
`axis` dimension.
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
_require_cross_replica_or_default_context_extended(self._extended)
|
||||
dst = device_util.current(
|
||||
) or self._extended._default_device or "/device:CPU:0"
|
||||
if isinstance(value, ops.IndexedSlices):
|
||||
raise NotImplementedError("gather/all_gather does not support "
|
||||
"IndexedSlices")
|
||||
return self._extended._local_results(
|
||||
self._extended._gather_to(value, dst, axis))[0]
|
||||
|
||||
@doc_controls.do_not_doc_inheritable # DEPRECATED
|
||||
def unwrap(self, value):
|
||||
"""Returns the list of all local per-replica values contained in `value`.
|
||||
@ -1787,6 +1689,112 @@ class Strategy(StrategyBase):
|
||||
return self._extended._experimental_distribute_values_from_function( # pylint: disable=protected-access
|
||||
value_fn)
|
||||
|
||||
def gather(self, value, axis):
|
||||
# pylint: disable=line-too-long, protected-access
|
||||
"""Gather `value` across replicas along `axis` to the current device.
|
||||
|
||||
Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like
|
||||
object `value`, this API gathers and concatenates `value` across replicas
|
||||
along the `axis`-th dimension. The result is copied to the "current" device
|
||||
- which would typically be the CPU of the worker on which the program is
|
||||
running. For `tf.distribute.TPUStrategy`, it is the first TPU host. For
|
||||
multi-client `MultiWorkerMirroredStrategy`, this is CPU of each worker.
|
||||
|
||||
This API can only be called in the cross-replica context. For a counterpart
|
||||
in the replica context, see `tf.distribute.ReplicaContext.all_gather`.
|
||||
|
||||
Note: For all strategies except `tf.distribute.TPUStrategy`, the input
|
||||
`value` on different replicas must have the same rank, and their shapes must
|
||||
be the same in all dimensions except the `axis`-th dimension. In other
|
||||
words, their shapes cannot be different in a dimension `d` where `d` does
|
||||
not equal to the `axis` argument. For example, given a
|
||||
`tf.distribute.DistributedValues` with component tensors of shape
|
||||
`(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
|
||||
`gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or
|
||||
`gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`,
|
||||
all tensors must have exactly the same rank and same shape.
|
||||
|
||||
Note: Given a `tf.distribute.DistributedValues` `value`, its component
|
||||
tensors must have a non-zero rank. Otherwise, consider using
|
||||
`tf.expand_dims` before gathering them.
|
||||
|
||||
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
||||
>>> # A DistributedValues with component tensor of shape (2, 1) on each replica
|
||||
... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]])))
|
||||
>>> @tf.function
|
||||
... def run():
|
||||
... return strategy.gather(distributed_values, axis=0)
|
||||
>>> run()
|
||||
<tf.Tensor: shape=(4, 1), dtype=int32, numpy=
|
||||
array([[1],
|
||||
[2],
|
||||
[1],
|
||||
[2]], dtype=int32)>
|
||||
|
||||
|
||||
Consider the following example for more combinations:
|
||||
|
||||
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
|
||||
>>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3))
|
||||
>>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor))
|
||||
>>> @tf.function
|
||||
... def run(axis):
|
||||
... return strategy.gather(distributed_values, axis=axis)
|
||||
>>> axis=0
|
||||
>>> run(axis)
|
||||
<tf.Tensor: shape=(4, 2, 3), dtype=int32, numpy=
|
||||
array([[[0, 1, 2],
|
||||
[3, 4, 5]],
|
||||
[[0, 1, 2],
|
||||
[3, 4, 5]],
|
||||
[[0, 1, 2],
|
||||
[3, 4, 5]],
|
||||
[[0, 1, 2],
|
||||
[3, 4, 5]]], dtype=int32)>
|
||||
>>> axis=1
|
||||
>>> run(axis)
|
||||
<tf.Tensor: shape=(1, 8, 3), dtype=int32, numpy=
|
||||
array([[[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[0, 1, 2],
|
||||
[3, 4, 5]]], dtype=int32)>
|
||||
>>> axis=2
|
||||
>>> run(axis)
|
||||
<tf.Tensor: shape=(1, 2, 12), dtype=int32, numpy=
|
||||
array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2],
|
||||
[3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)>
|
||||
|
||||
|
||||
Args:
|
||||
value: a `tf.distribute.DistributedValues` instance, e.g. returned by
|
||||
`Strategy.run`, to be combined into a single tensor. It can also be a
|
||||
regular tensor when used with `tf.distribute.OneDeviceStrategy` or the
|
||||
default strategy. The tensors that constitute the DistributedValues
|
||||
can only be dense tensors with non-zero rank, NOT a `tf.IndexedSlices`.
|
||||
axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
|
||||
range [0, rank(value)).
|
||||
|
||||
Returns:
|
||||
A `Tensor` that's the concatenation of `value` across replicas along
|
||||
`axis` dimension.
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
error_message = ("tf.distribute.Strategy.gather method requires "
|
||||
"cross-replica context, use "
|
||||
"get_replica_context().all_gather() instead")
|
||||
_require_cross_replica_or_default_context_extended(self._extended,
|
||||
error_message)
|
||||
dst = device_util.current(
|
||||
) or self._extended._default_device or "/device:CPU:0"
|
||||
if isinstance(value, ops.IndexedSlices):
|
||||
raise NotImplementedError("gather does not support IndexedSlices")
|
||||
return self._extended._local_results(
|
||||
self._extended._gather_to(value, dst, axis))[0]
|
||||
|
||||
|
||||
# TF v1.x version has additional deprecated APIs
|
||||
@tf_export(v1=["distribute.Strategy"])
|
||||
@ -2834,8 +2842,7 @@ class StrategyExtendedV1(StrategyExtendedV2):
|
||||
# It sets the current Strategy for purposes of
|
||||
# `get_strategy()` and `has_strategy()`
|
||||
# and switches the thread mode to a "cross-replica context".
|
||||
@tf_export("distribute.ReplicaContext")
|
||||
class ReplicaContext(object):
|
||||
class ReplicaContextBase(object):
|
||||
"""A class with a collection of APIs that can be called in a replica context.
|
||||
|
||||
You can use `tf.distribute.get_replica_context` to get an instance of
|
||||
@ -3095,77 +3102,118 @@ class ReplicaContext(object):
|
||||
# to that point that the first result is needed. Most likely this can be
|
||||
# implemented in terms of `merge_call()` and `batch_reduce_to()`.
|
||||
|
||||
# TODO(wxinyi): generate docs after it is implemented for all strategies.
|
||||
def _all_gather(self, value, axis, options=None):
|
||||
|
||||
@tf_export("distribute.ReplicaContext", v1=[])
|
||||
class ReplicaContext(ReplicaContextBase):
|
||||
|
||||
__doc__ = ReplicaContextBase.__doc__
|
||||
|
||||
def all_gather(self, value, axis, options=None):
|
||||
"""All-gathers `value` across all replicas along `axis`.
|
||||
|
||||
Note: An `all_gather` method can only be called in replica context. To find
|
||||
Note: An `all_gather` method can only be called in replica context. For
|
||||
a cross-replica context counterpart, see `tf.distribute.Strategy.gather`.
|
||||
All replicas need to participate in the all-gather, otherwise this
|
||||
operation hangs. So if `all_gather` is called in any replica, it must be
|
||||
called in all replicas.
|
||||
|
||||
Note: If there're multiple all-gather calls, they need to execute in
|
||||
the same order on all replicas. Dispatching all-gather based on conditions
|
||||
Note: If there are multiple `all_gather` calls, they need to be executed in
|
||||
the same order on all replicas. Dispatching `all_gather` based on conditions
|
||||
is usually error-prone.
|
||||
|
||||
# TODO(wxinyi): convert to testable docstring after implemented for MirroredStrategy
|
||||
```python
|
||||
strategy = tf.distribute.MirroredStrategy(["GPU:0", "CPU:0"])
|
||||
@tf.function
|
||||
def gather_value():
|
||||
ctx = tf.distribute.get_replica_context()
|
||||
value = tf.constant([1, 2, 3])
|
||||
# all_gather a `tf.distribute.DistributedValues`
|
||||
return strategy.run(ctx.all_gather(value, axis=0))
|
||||
strategy.experimental_local_results(gather_value)
|
||||
# Result:
|
||||
# (<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
|
||||
# dtype=int32)>,
|
||||
# <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
|
||||
# dtype=int32)>)
|
||||
```
|
||||
For all strategies except `tf.distribute.TPUStrategy`, the input
|
||||
`value` on different replicas must have the same rank, and their shapes must
|
||||
be the same in all dimensions except the `axis`-th dimension. In other
|
||||
words, their shapes cannot be different in a dimension `d` where `d` does
|
||||
not equal to the `axis` argument. For example, given a
|
||||
`tf.distribute.DistributedValues` with component tensors of shape
|
||||
`(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
|
||||
`all_gather(..., axis=1, ...)` on it, but not `all_gather(..., axis=0, ...)`
|
||||
or `all_gather(..., axis=2, ...)`. However, with
|
||||
`tf.distribute.TPUStrategy`, all tensors must have exactly the same rank and
|
||||
same shape.
|
||||
|
||||
```python
|
||||
strategy = tf.distribute.MirroredStrategy(["GPU:0", "CPU:0"])
|
||||
@tf.function
|
||||
def gather_nest():
|
||||
ctx = tf.distribute.get_replica_context()
|
||||
value_1 = tf.constant([1, 2, 3])
|
||||
value_2 = tf.constant([[1, 2], [3, 4]])
|
||||
# all_gather a nest of `tf.distribute.DistributedValues`
|
||||
return ctx.all_gather([value_1, value_2], axis=0)
|
||||
strategy.experimental_local_results(gather_nest)
|
||||
# Result:
|
||||
# ([<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
|
||||
# dtype=int32)>, <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
|
||||
# array([[1, 2],
|
||||
# [3, 4],
|
||||
# [1, 2],
|
||||
# [3, 4]], dtype=int32)],
|
||||
# [<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
|
||||
# dtype=int32)>, <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
|
||||
# array([[1, 2],
|
||||
# [3, 4],
|
||||
# [1, 2],
|
||||
# [3, 4]], dtype=int32)])
|
||||
```
|
||||
Note: The input `value` must have a non-zero rank. Otherwise, consider using
|
||||
`tf.expand_dims` before gathering them.
|
||||
|
||||
Example with two replicas:
|
||||
Replica 0 `value`: {'a': [0], 'b': [[0, 1]]}
|
||||
Replica 1 `value`: {'a': [1], 'b': [[2, 3], [4, 5]]}
|
||||
You can pass in a single tensor to all-gather:
|
||||
|
||||
Result for `all_gather` with `axis`=0: (on all replicas):
|
||||
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
||||
>>> @tf.function
|
||||
... def gather_value():
|
||||
... ctx = tf.distribute.get_replica_context()
|
||||
... local_value = tf.constant([1, 2, 3])
|
||||
... return ctx.all_gather(local_value, axis=0)
|
||||
>>> result = strategy.run(gather_value)
|
||||
>>> result
|
||||
PerReplica:{
|
||||
0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
|
||||
1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
|
||||
}
|
||||
>>> strategy.experimental_local_results(result)
|
||||
(<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
|
||||
dtype=int32)>,
|
||||
<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
|
||||
dtype=int32)>)
|
||||
|
||||
|
||||
You can also pass in a nested structure of tensors to all-gather, say, a
|
||||
list:
|
||||
|
||||
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
||||
>>> @tf.function
|
||||
... def gather_nest():
|
||||
... ctx = tf.distribute.get_replica_context()
|
||||
... value_1 = tf.constant([1, 2, 3])
|
||||
... value_2 = tf.constant([[1, 2], [3, 4]])
|
||||
... # all_gather a nest of `tf.distribute.DistributedValues`
|
||||
... return ctx.all_gather([value_1, value_2], axis=0)
|
||||
>>> result = strategy.run(gather_nest)
|
||||
>>> result
|
||||
[PerReplica:{
|
||||
0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
|
||||
1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
|
||||
}, PerReplica:{
|
||||
0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
|
||||
array([[1, 2],
|
||||
[3, 4],
|
||||
[1, 2],
|
||||
[3, 4]], dtype=int32)>,
|
||||
1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
|
||||
array([[1, 2],
|
||||
[3, 4],
|
||||
[1, 2],
|
||||
[3, 4]], dtype=int32)>
|
||||
}]
|
||||
>>> strategy.experimental_local_results(result)
|
||||
([PerReplica:{
|
||||
0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
|
||||
1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
|
||||
}, PerReplica:{
|
||||
0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
|
||||
array([[1, 2],
|
||||
[3, 4],
|
||||
[1, 2],
|
||||
[3, 4]], dtype=int32)>,
|
||||
1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
|
||||
array([[1, 2],
|
||||
[3, 4],
|
||||
[1, 2],
|
||||
[3, 4]], dtype=int32)>
|
||||
}],)
|
||||
|
||||
|
||||
What if you are all-gathering tensors with different shapes on different
|
||||
replicas? Consider the following example with two replicas, where you have
|
||||
`value` as a nested structure consisting of two items to all-gather, `a` and
|
||||
`b`.
|
||||
|
||||
On Replica 0, `value` is {'a': [0], 'b': [[0, 1]]}
|
||||
On Replica 1, `value` is {'a': [1], 'b': [[2, 3], [4, 5]]}
|
||||
|
||||
Result for `all_gather` with `axis`=0: (on each of the replicas):
|
||||
{'a': [1, 2], 'b': [[0, 1], [2, 3], [4, 5]]}
|
||||
|
||||
Note: an input to be all_gathered must have the same rank on different
|
||||
replicas, and they must have shapes that are consistent along all dimensions
|
||||
except the `axis`-th dimension. For example, given a
|
||||
`tf.distribute.DistributedValues` with tensors of shape `(1, 2, 3)` and
|
||||
`(1, 3, 3)` on two replicas, you can call `all_gather(..., axis=1, ...)` on
|
||||
it, but not `all_gather(..., axis=0, ...)` or `all_gather(..., axis=2, ...)`.
|
||||
|
||||
|
||||
Args:
|
||||
value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts,
|
||||
or a `tf.distribute.DistributedValues` instance. The structure of the
|
||||
@ -3185,8 +3233,7 @@ class ReplicaContext(object):
|
||||
"""
|
||||
for v in nest.flatten(value):
|
||||
if isinstance(v, ops.IndexedSlices):
|
||||
raise NotImplementedError("gather/all_gather does not support "
|
||||
"IndexedSlices")
|
||||
raise NotImplementedError("all_gather does not support IndexedSlices")
|
||||
|
||||
if options is None:
|
||||
options = collective_util.Options()
|
||||
@ -3200,11 +3247,16 @@ class ReplicaContext(object):
|
||||
def grad_wrapper(*xs):
|
||||
ys = self.merge_call(batch_all_gather, args=xs)
|
||||
# The gradient of an all-gather is itself an all-gather.
|
||||
return ys, lambda *dy_s: self._all_gather(dy_s, axis)
|
||||
return ys, lambda *dy_s: self.all_gather(dy_s, axis)
|
||||
|
||||
return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
|
||||
|
||||
|
||||
@tf_export(v1=["distribute.ReplicaContext"])
|
||||
class ReplicaContextV1(ReplicaContextBase):
|
||||
__doc__ = ReplicaContextBase.__doc__
|
||||
|
||||
|
||||
def _batch_reduce_destination(x):
|
||||
"""Returns the destinations for batch all-reduce."""
|
||||
if isinstance(x, ops.Tensor):
|
||||
|
@ -74,7 +74,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
lambda _: array_ops.identity(value_on_replica))
|
||||
|
||||
def run():
|
||||
return strategy._gather(distributed_values, axis=axis)
|
||||
return strategy.gather(distributed_values, axis=axis)
|
||||
|
||||
if not pure_eager:
|
||||
run = def_function.function(run)
|
||||
@ -133,7 +133,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
axis = 0
|
||||
|
||||
def run():
|
||||
return strategy._gather(distributed_values, axis=axis)
|
||||
return strategy.gather(distributed_values, axis=axis)
|
||||
|
||||
if not pure_eager:
|
||||
run = def_function.function(run)
|
||||
@ -155,7 +155,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
axis = 1
|
||||
|
||||
def run():
|
||||
return strategy._gather(distributed_values, axis=axis)
|
||||
return strategy.gather(distributed_values, axis=axis)
|
||||
|
||||
if not pure_eager:
|
||||
run = def_function.function(run)
|
||||
@ -183,7 +183,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
axis = 0
|
||||
|
||||
def run():
|
||||
return strategy._gather(distributed_values, axis=axis)
|
||||
return strategy.gather(distributed_values, axis=axis)
|
||||
|
||||
if not pure_eager:
|
||||
run = def_function.function(run)
|
||||
@ -210,11 +210,11 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
values=[[1., 2.]], indices=[2], dense_shape=dense_shape)
|
||||
|
||||
def run(value):
|
||||
return strategy._gather(value, axis=0)
|
||||
return strategy.gather(value, axis=0)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
r'gather/all_gather does not support IndexedSlices'):
|
||||
r'gather does not support IndexedSlices'):
|
||||
if pure_eager:
|
||||
run(t0)
|
||||
else:
|
||||
@ -235,7 +235,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
axis = 0
|
||||
|
||||
def run():
|
||||
return strategy._gather(distributed_values, axis=axis)
|
||||
return strategy.gather(distributed_values, axis=axis)
|
||||
|
||||
if not pure_eager:
|
||||
run = def_function.function(run)
|
||||
@ -271,7 +271,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
def replica_fn(per_replica_value):
|
||||
ctx = ds_context.get_replica_context()
|
||||
local_value = array_ops.identity(per_replica_value)
|
||||
return ctx._all_gather(local_value, axis=axis)
|
||||
return ctx.all_gather(local_value, axis=axis)
|
||||
|
||||
if not pure_eager:
|
||||
replica_fn = def_function.function(replica_fn)
|
||||
@ -342,7 +342,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
@def_function.function
|
||||
def replica_fn(per_replica_value):
|
||||
ctx = ds_context.get_replica_context()
|
||||
return ctx._all_gather(array_ops.identity(per_replica_value), axis=axis)
|
||||
return ctx.all_gather(array_ops.identity(per_replica_value), axis=axis)
|
||||
|
||||
result = strategy.experimental_local_results(
|
||||
strategy.run(replica_fn, args=(next(input_iterator),)))
|
||||
@ -369,7 +369,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
def run(value):
|
||||
value_identity = array_ops.identity(value)
|
||||
ctx = ds_context.get_replica_context()
|
||||
return ctx._all_gather(value_identity, axis=0)
|
||||
return ctx.all_gather(value_identity, axis=0)
|
||||
|
||||
if not pure_eager:
|
||||
run = def_function.function(run)
|
||||
@ -397,7 +397,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
def run(value):
|
||||
value_identity = array_ops.identity(value)
|
||||
ctx = ds_context.get_replica_context()
|
||||
return ctx._all_gather(value_identity, axis=1)
|
||||
return ctx.all_gather(value_identity, axis=1)
|
||||
|
||||
if not pure_eager:
|
||||
run = def_function.function(run)
|
||||
@ -436,7 +436,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
value_1 = array_ops.identity(value)
|
||||
value_3 = array_ops.identity(value_2)
|
||||
ctx = ds_context.get_replica_context()
|
||||
return ctx._all_gather([value_1, value_3], axis=axis)
|
||||
return ctx.all_gather([value_1, value_3], axis=axis)
|
||||
|
||||
if not pure_eager:
|
||||
run = def_function.function(run)
|
||||
@ -455,7 +455,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
def run():
|
||||
value_identity = array_ops.identity(single_value)
|
||||
ctx = ds_context.get_replica_context()
|
||||
return ctx._all_gather([value_identity, value_identity], axis=axis)
|
||||
return ctx.all_gather([value_identity, value_identity], axis=axis)
|
||||
|
||||
if not pure_eager:
|
||||
run = def_function.function(run)
|
||||
@ -491,7 +491,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
def run(value):
|
||||
value_identity = array_ops.identity(value)
|
||||
ctx = ds_context.get_replica_context()
|
||||
return ctx._all_gather(value_identity, axis=0)
|
||||
return ctx.all_gather(value_identity, axis=0)
|
||||
|
||||
if not pure_eager:
|
||||
run = def_function.function(run)
|
||||
@ -519,11 +519,11 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def replica_fn(value):
|
||||
ctx = ds_context.get_replica_context()
|
||||
return ctx._all_gather(value, axis=0)
|
||||
return ctx.all_gather(value, axis=0)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
r'gather/all_gather does not support IndexedSlices'):
|
||||
r'all_gather does not support IndexedSlices'):
|
||||
if not pure_eager:
|
||||
strategy.run(def_function.function(replica_fn), args=(t0,))
|
||||
else:
|
||||
@ -548,7 +548,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
def run(value):
|
||||
value_identity = array_ops.identity(value)
|
||||
ctx = ds_context.get_replica_context()
|
||||
return ctx._all_gather(value_identity, axis=0)
|
||||
return ctx.all_gather(value_identity, axis=0)
|
||||
|
||||
if not pure_eager:
|
||||
run = def_function.function(run)
|
||||
|
@ -58,7 +58,7 @@ def _gather(strategy, value):
|
||||
return array_ops.stack(value._values)
|
||||
assert len(strategy.extended.worker_devices) == len(value._values)
|
||||
inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values]
|
||||
return strategy._gather(values.PerReplica(inputs), axis=0)
|
||||
return strategy.gather(values.PerReplica(inputs), axis=0)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
|
@ -1425,12 +1425,11 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
|
||||
return self.strategy.extended.experimental_logical_device(logical_device_id)
|
||||
|
||||
# TODO(wxinyi): Investigate whether to use cross_replica_sum to optimize it.
|
||||
def _all_gather(self, value, axis, options=None):
|
||||
del options
|
||||
def all_gather(self, value, axis, experimental_hints=None):
|
||||
del experimental_hints
|
||||
for v in nest.flatten(value):
|
||||
if isinstance(v, ops.IndexedSlices):
|
||||
raise NotImplementedError("gather/all_gather does not support "
|
||||
"IndexedSlices")
|
||||
raise NotImplementedError("all_gather does not support IndexedSlices")
|
||||
|
||||
def _all_to_all(value, axis):
|
||||
# The underlying AllToAllOp first do a split of the input value and then
|
||||
@ -1484,7 +1483,7 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
|
||||
@custom_gradient.custom_gradient
|
||||
def grad_wrapper(*xs):
|
||||
ys = [_all_to_all(t, axis=axis) for t in xs]
|
||||
return ys, lambda *dy_s: self._all_gather(dy_s, axis)
|
||||
return ys, lambda *dy_s: self.all_gather(dy_s, axis)
|
||||
|
||||
return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
|
||||
|
||||
|
@ -2755,7 +2755,7 @@ def _collective_all_reduce_multi_worker(strategy):
|
||||
# for all strategies
|
||||
def _multi_worker_concat(v, strategy):
|
||||
"""Order PerReplica objects for CollectiveAllReduceStrategy and concat."""
|
||||
replicas = strategy._gather(v, axis=0) # pylint: disable=protected-access
|
||||
replicas = strategy.gather(v, axis=0) # pylint: disable=protected-access
|
||||
# v might not have the same shape on different replicas
|
||||
if isinstance(v, ds_values.PerReplica):
|
||||
shapes = array_ops.concat([
|
||||
@ -2763,10 +2763,10 @@ def _multi_worker_concat(v, strategy):
|
||||
for single_value in v.values
|
||||
],
|
||||
axis=0)
|
||||
all_shapes = strategy._gather(shapes, axis=0) # pylint: disable=protected-access
|
||||
all_shapes = strategy.gather(shapes, axis=0) # pylint: disable=protected-access
|
||||
else:
|
||||
# v is a tensor. This may happen when, say, we have 2x1 multi-worker.
|
||||
all_shapes = strategy._gather( # pylint: disable=protected-access
|
||||
all_shapes = strategy.gather( # pylint: disable=protected-access
|
||||
array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0),
|
||||
axis=0)
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
path: "tensorflow.distribute.ReplicaContext"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContext\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContextV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContextBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "devices"
|
||||
|
@ -52,6 +52,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "gather"
|
||||
argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -52,6 +52,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "gather"
|
||||
argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -1,6 +1,7 @@
|
||||
path: "tensorflow.distribute.ReplicaContext"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContext\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContextBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "devices"
|
||||
@ -22,6 +23,10 @@ tf_class {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "all_gather"
|
||||
argspec: "args=[\'self\', \'value\', \'axis\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "all_reduce"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -51,6 +51,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "gather"
|
||||
argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -64,6 +64,10 @@ tf_class {
|
||||
name: "experimental_split_to_logical_devices"
|
||||
argspec: "args=[\'self\', \'tensor\', \'partition_dimensions\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "gather"
|
||||
argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -52,6 +52,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "gather"
|
||||
argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -52,6 +52,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "gather"
|
||||
argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -52,6 +52,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "gather"
|
||||
argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -52,6 +52,10 @@ tf_class {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "gather"
|
||||
argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user