Adding new APIs under tf.distribute: gather and all_gather.
`tf.distribute.Strategy.gather` and `tf.distribute.ReplicaContext.all_gather` are APIs to gather and concatenate `tf.distribute.DistributedValues` objects across workers and devices. They are counterparts for the cross-replica context and the replica context, respectively. These methods are implemented for all strategies except ParameterServerStrategy.

PiperOrigin-RevId: 337972679
Change-Id: I1d61c96b830683da135d5b4e89da29693c51262c
parent eccb7ec454
commit a7467d5d51
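For orientation, a minimal usage sketch of the two new APIs (not part of the diff below; it assumes a TensorFlow build that includes this change, and the GPU device names are illustrative):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

@tf.function
def replica_fn():
  ctx = tf.distribute.get_replica_context()
  local = tf.constant([[1, 2]])
  # Replica context: every replica receives the concatenated tensor.
  return ctx.all_gather(local, axis=0)   # shape (2, 2) on each replica

per_replica = strategy.run(replica_fn)
# Cross-replica context: copy and concatenate the per-replica values to the host.
host_tensor = strategy.gather(per_replica, axis=0)  # shape (4, 2)
```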
@@ -767,6 +767,7 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
 * Fix the issue that `strategy.reduce()` inside `tf.function` may raise exceptions when the values to reduce are from loops or if-clauses.
 * Fix the issue that `tf.distribute.MirroredStrategy` cannot be used together with `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
 * Add a `tf.distribute.cluster_resolver.TPUClusterResolver.connect` API to simplify TPU initialization.
+* Add `tf.distribute.Strategy.gather` and `tf.distribute.ReplicaContext.all_gather` methods to gather and concatenate `tf.distribute.DistributedValues` across workers and devices.

 ### `tf.keras`:
 * Introduces experimental preprocessing layers API (`tf.keras.layers.experimental.preprocessing`) to handle data preprocessing operations such as categorical feature encoding, text vectorization, data normalization, and data discretization (binning). The newly added layers provide a replacement for the legacy feature column API, and support composite tensor inputs.
@@ -295,7 +295,8 @@ def get_loss_reduction():
 # Internal API for validating the current thread mode


-def _require_cross_replica_or_default_context_extended(extended):
+def _require_cross_replica_or_default_context_extended(extended,
+                                                       error_message=None):
   """Verify in cross-replica context."""
   context = _get_per_thread_mode()
   cross_replica = context.cross_replica_context
@@ -308,8 +309,10 @@ def _require_cross_replica_or_default_context_extended(extended):
   if context.strategy is not strategy:
     _wrong_strategy_scope(strategy, context)
   assert cross_replica is None
-  raise RuntimeError("Method requires being in cross-replica context, use "
-                     "get_replica_context().merge_call()")
+  if not error_message:
+    error_message = ("Method requires being in cross-replica context, use "
+                     "get_replica_context().merge_call()")
+  raise RuntimeError(error_message)


 def _wrong_strategy_scope(strategy, context):
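A self-contained illustration of the pattern this hunk introduces — an optional `error_message` that callers may override while the previous text remains the default. The helper name and logic below are simplified stand-ins, not the actual TensorFlow internals:

```python
def require_cross_replica(in_cross_replica_context, error_message=None):
  """Simplified stand-in for the internal validator changed above."""
  if in_cross_replica_context:
    return
  if not error_message:
    error_message = ("Method requires being in cross-replica context, use "
                     "get_replica_context().merge_call()")
  raise RuntimeError(error_message)

# A caller such as `Strategy.gather` can now point users at the right API:
# require_cross_replica(False, "use get_replica_context().all_gather() instead")
```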
@@ -1454,107 +1457,6 @@ class StrategyBase(object):
     denom = math_ops.cast(denom, numer.dtype)
     return math_ops.truediv(numer, denom)

-  # TODO(wxinyi): generate docs after it is implemented for all strategies.
-  # TODO(wxinyi): hide from V1 API
-  def _gather(self, value, axis):
-    # pylint: disable=line-too-long, protected-access
-    """Gather `value` across replicas along `axis` to the current device.
-
-    Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like
-    object `value`, this API gathers and concatenates `value` along the
-    `axis`-th dimension. The result is copied to the "current" device - which
-    would typically be the CPU of the worker on which the program is running.
-    For `tf.distribute.TPUStrategy`, it is the first TPU host. For multi-client
-    `MultiWorkerMirroredStrategy`, this is CPU of each worker.
-
-    This API can only be called in the cross-replica context. For a counterpart
-    in the replica context, see `tf.distribute.ReplicaContext.all_gather`.
-
-    Note: the input `value` on different replicas must have the same rank, and
-    they must have shapes that are consistent along all dimensions except the
-    `axis`-th dimension. For example, given a `tf.distribute.DistributedValues`
-    with tensors of shape `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can
-    call `gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or
-    `gather(..., axis=2, ...)`.
-
-    # TODO(wxinyi): convert to testable docstring after implemented for MirroredStrategy
-    ```python
-    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
-    local_tensor = tf.constant([[1, 2], [3, 4]])
-    distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(local_tensor))
-    @tf.function
-    def run():
-      return strategy.gather(distributed_values, axis=0)
-    run()
-    # <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
-    # array([[1, 2],
-    #        [3, 4],
-    #        [1, 2],
-    #        [3, 4]], dtype=int32)>
-    ```
-
-    Some more example cases:
-
-    ```python
-    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
-    local_tensor = tf.reshape(tf.range(6), shape=(1,2,3))
-    distributed_values = strategy.experimental_distribute_values_from_function(lambda _: local_tensor)
-    @tf.function
-    def run():
-      return strategy.gather(distributed_values, axis=AXIS)
-    run()
-
-    # With AXIS=0, the result is
-    # <tf.Tensor: shape=(4, 2, 3), dtype=int32, numpy=
-    # array([[[0, 1, 2],
-    #         [3, 4, 5]],
-    #        [[0, 1, 2],
-    #         [3, 4, 5]],
-    #        [[0, 1, 2],
-    #         [3, 4, 5]],
-    #        [[0, 1, 2],
-    #         [3, 4, 5]]], dtype=int32)>
-    # With AXIS=1, the result is
-    # <tf.Tensor: shape=(1, 8, 3), dtype=int32, numpy=
-    # array([[[0, 1, 2],
-    #         [3, 4, 5],
-    #         [0, 1, 2],
-    #         [3, 4, 5],
-    #         [0, 1, 2],
-    #         [3, 4, 5],
-    #         [0, 1, 2],
-    #         [3, 4, 5]]], dtype=int32)>
-    # With AXIS=2, the result is
-    # <tf.Tensor: shape=(1, 2, 12), dtype=int32, numpy=
-    # array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2],
-    #         [3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)>
-
-    ```
-
-    Args:
-      value: a `tf.distribute.DistributedValues` instance, e.g. returned by
-        `Strategy.run`, to be combined into a single tensor. It can also be a
-        regular tensor when used with `OneDeviceStrategy` or default strategy.
-        The underlying tensor constructs can only be dense tensors with non-zero
-        rank, NOT `tf.IndexedSlices`.
-      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
-        range [0, rank(value)).
-
-    Returns:
-      A `Tensor` that's the concatenation of `value` across replicas along
-      `axis` dimension.
-    """
-    # pylint: enable=line-too-long
-    _require_cross_replica_or_default_context_extended(self._extended)
-    dst = device_util.current(
-    ) or self._extended._default_device or "/device:CPU:0"
-    if isinstance(value, ops.IndexedSlices):
-      raise NotImplementedError("gather/all_gather does not support "
-                                "IndexedSlices")
-    return self._extended._local_results(
-        self._extended._gather_to(value, dst, axis))[0]
-
   @doc_controls.do_not_doc_inheritable  # DEPRECATED
   def unwrap(self, value):
     """Returns the list of all local per-replica values contained in `value`.
@@ -1787,6 +1689,112 @@ class Strategy(StrategyBase):
     return self._extended._experimental_distribute_values_from_function(  # pylint: disable=protected-access
         value_fn)

+  def gather(self, value, axis):
+    # pylint: disable=line-too-long, protected-access
+    """Gather `value` across replicas along `axis` to the current device.
+
+    Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like
+    object `value`, this API gathers and concatenates `value` across replicas
+    along the `axis`-th dimension. The result is copied to the "current" device
+    - which would typically be the CPU of the worker on which the program is
+    running. For `tf.distribute.TPUStrategy`, it is the first TPU host. For
+    multi-client `MultiWorkerMirroredStrategy`, this is CPU of each worker.
+
+    This API can only be called in the cross-replica context. For a counterpart
+    in the replica context, see `tf.distribute.ReplicaContext.all_gather`.
+
+    Note: For all strategies except `tf.distribute.TPUStrategy`, the input
+    `value` on different replicas must have the same rank, and their shapes must
+    be the same in all dimensions except the `axis`-th dimension. In other
+    words, their shapes cannot be different in a dimension `d` where `d` does
+    not equal to the `axis` argument. For example, given a
+    `tf.distribute.DistributedValues` with component tensors of shape
+    `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
+    `gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or
+    `gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`,
+    all tensors must have exactly the same rank and same shape.
+
+    Note: Given a `tf.distribute.DistributedValues` `value`, its component
+    tensors must have a non-zero rank. Otherwise, consider using
+    `tf.expand_dims` before gathering them.
+
+    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+    >>> # A DistributedValues with component tensor of shape (2, 1) on each replica
+    ... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]])))
+    >>> @tf.function
+    ... def run():
+    ...   return strategy.gather(distributed_values, axis=0)
+    >>> run()
+    <tf.Tensor: shape=(4, 1), dtype=int32, numpy=
+    array([[1],
+           [2],
+           [1],
+           [2]], dtype=int32)>
+
+
+    Consider the following example for more combinations:
+
+    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
+    >>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3))
+    >>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor))
+    >>> @tf.function
+    ... def run(axis):
+    ...   return strategy.gather(distributed_values, axis=axis)
+    >>> axis=0
+    >>> run(axis)
+    <tf.Tensor: shape=(4, 2, 3), dtype=int32, numpy=
+    array([[[0, 1, 2],
+            [3, 4, 5]],
+           [[0, 1, 2],
+            [3, 4, 5]],
+           [[0, 1, 2],
+            [3, 4, 5]],
+           [[0, 1, 2],
+            [3, 4, 5]]], dtype=int32)>
+    >>> axis=1
+    >>> run(axis)
+    <tf.Tensor: shape=(1, 8, 3), dtype=int32, numpy=
+    array([[[0, 1, 2],
+            [3, 4, 5],
+            [0, 1, 2],
+            [3, 4, 5],
+            [0, 1, 2],
+            [3, 4, 5],
+            [0, 1, 2],
+            [3, 4, 5]]], dtype=int32)>
+    >>> axis=2
+    >>> run(axis)
+    <tf.Tensor: shape=(1, 2, 12), dtype=int32, numpy=
+    array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2],
+            [3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)>
+
+
+    Args:
+      value: a `tf.distribute.DistributedValues` instance, e.g. returned by
+        `Strategy.run`, to be combined into a single tensor. It can also be a
+        regular tensor when used with `tf.distribute.OneDeviceStrategy` or the
+        default strategy. The tensors that constitute the DistributedValues
+        can only be dense tensors with non-zero rank, NOT a `tf.IndexedSlices`.
+      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
+        range [0, rank(value)).
+
+    Returns:
+      A `Tensor` that's the concatenation of `value` across replicas along
+      `axis` dimension.
+    """
+    # pylint: enable=line-too-long
+    error_message = ("tf.distribute.Strategy.gather method requires "
+                     "cross-replica context, use "
+                     "get_replica_context().all_gather() instead")
+    _require_cross_replica_or_default_context_extended(self._extended,
+                                                       error_message)
+    dst = device_util.current(
+    ) or self._extended._default_device or "/device:CPU:0"
+    if isinstance(value, ops.IndexedSlices):
+      raise NotImplementedError("gather does not support IndexedSlices")
+    return self._extended._local_results(
+        self._extended._gather_to(value, dst, axis))[0]
+

 # TF v1.x version has additional deprecated APIs
 @tf_export(v1=["distribute.Strategy"])
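To see the customized error message end to end, a hedged sketch (assuming a build with this change; device names are illustrative): calling the cross-replica-only `gather` from inside `Strategy.run` now fails with a message that points at `all_gather`:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

def replica_fn():
  # Deliberately the wrong context: `gather` requires the cross-replica context.
  return strategy.gather(tf.constant([1, 2]), axis=0)

try:
  strategy.run(replica_fn)
except RuntimeError as e:
  print(e)  # tf.distribute.Strategy.gather method requires cross-replica
            # context, use get_replica_context().all_gather() instead
```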
@@ -2834,8 +2842,7 @@ class StrategyExtendedV1(StrategyExtendedV2):
 # It sets the current Strategy for purposes of
 # `get_strategy()` and `has_strategy()`
 # and switches the thread mode to a "cross-replica context".
-@tf_export("distribute.ReplicaContext")
-class ReplicaContext(object):
+class ReplicaContextBase(object):
   """A class with a collection of APIs that can be called in a replica context.

   You can use `tf.distribute.get_replica_context` to get an instance of
@@ -3095,77 +3102,118 @@ class ReplicaContext(object):
   # to that point that the first result is needed. Most likely this can be
   # implemented in terms of `merge_call()` and `batch_reduce_to()`.

-  # TODO(wxinyi): generate docs after it is implemented for all strategies.
-  def _all_gather(self, value, axis, options=None):
+
+@tf_export("distribute.ReplicaContext", v1=[])
+class ReplicaContext(ReplicaContextBase):
+
+  __doc__ = ReplicaContextBase.__doc__
+
+  def all_gather(self, value, axis, options=None):
     """All-gathers `value` across all replicas along `axis`.

-    Note: An `all_gather` method can only be called in replica context. To find
+    Note: An `all_gather` method can only be called in replica context. For
     a cross-replica context counterpart, see `tf.distribute.Strategy.gather`.
     All replicas need to participate in the all-gather, otherwise this
     operation hangs. So if `all_gather` is called in any replica, it must be
     called in all replicas.

-    Note: If there're multiple all-gather calls, they need to execute in
-    the same order on all replicas. Dispatching all-gather based on conditions
+    Note: If there are multiple `all_gather` calls, they need to be executed in
+    the same order on all replicas. Dispatching `all_gather` based on conditions
     is usually error-prone.

-    # TODO(wxinyi): convert to testable docstring after implemented for MirroredStrategy
-    ```python
-    strategy = tf.distribute.MirroredStrategy(["GPU:0", "CPU:0"])
-    @tf.function
-    def gather_value():
-      ctx = tf.distribute.get_replica_context()
-      value = tf.constant([1, 2, 3])
-      # all_gather a `tf.distribute.DistributedValues`
-      return strategy.run(ctx.all_gather(value, axis=0))
-    strategy.experimental_local_results(gather_value)
-    # Result:
-    # (<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
-    # dtype=int32)>,
-    # <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
-    # dtype=int32)>)
-    ```
-
-    ```python
-    strategy = tf.distribute.MirroredStrategy(["GPU:0", "CPU:0"])
-    @tf.function
-    def gather_nest():
-      ctx = tf.distribute.get_replica_context()
-      value_1 = tf.constant([1, 2, 3])
-      value_2 = tf.constant([[1, 2], [3, 4]])
-      # all_gather a nest of `tf.distribute.DistributedValues`
-      return ctx.all_gather([value_1, value_2], axis=0)
-    strategy.experimental_local_results(gather_nest)
-    # Result:
-    # ([<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
-    # dtype=int32)>, <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
-    # array([[1, 2],
-    #        [3, 4],
-    #        [1, 2],
-    #        [3, 4]], dtype=int32)],
-    # [<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
-    # dtype=int32)>, <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
-    # array([[1, 2],
-    #        [3, 4],
-    #        [1, 2],
-    #        [3, 4]], dtype=int32)])
-    ```
-
-    Example with two replicas:
-    Replica 0 `value`: {'a': [0], 'b': [[0, 1]]}
-    Replica 1 `value`: {'a': [1], 'b': [[2, 3], [4, 5]]}
-
-    Result for `all_gather` with `axis`=0: (on all replicas):
+    For all strategies except `tf.distribute.TPUStrategy`, the input
+    `value` on different replicas must have the same rank, and their shapes must
+    be the same in all dimensions except the `axis`-th dimension. In other
+    words, their shapes cannot be different in a dimension `d` where `d` does
+    not equal to the `axis` argument. For example, given a
+    `tf.distribute.DistributedValues` with component tensors of shape
+    `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
+    `all_gather(..., axis=1, ...)` on it, but not `all_gather(..., axis=0, ...)`
+    or `all_gather(..., axis=2, ...)`. However, with
+    `tf.distribute.TPUStrategy`, all tensors must have exactly the same rank and
+    same shape.
+
+    Note: The input `value` must have a non-zero rank. Otherwise, consider using
+    `tf.expand_dims` before gathering them.
+
+    You can pass in a single tensor to all-gather:
+
+    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+    >>> @tf.function
+    ... def gather_value():
+    ...   ctx = tf.distribute.get_replica_context()
+    ...   local_value = tf.constant([1, 2, 3])
+    ...   return ctx.all_gather(local_value, axis=0)
+    >>> result = strategy.run(gather_value)
+    >>> result
+    PerReplica:{
+      0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
+      1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
+    }
+    >>> strategy.experimental_local_results(result)
+    (<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
+    dtype=int32)>,
+    <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
+    dtype=int32)>)
+
+
+    You can also pass in a nested structure of tensors to all-gather, say, a
+    list:
+
+    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+    >>> @tf.function
+    ... def gather_nest():
+    ...   ctx = tf.distribute.get_replica_context()
+    ...   value_1 = tf.constant([1, 2, 3])
+    ...   value_2 = tf.constant([[1, 2], [3, 4]])
+    ...   # all_gather a nest of `tf.distribute.DistributedValues`
+    ...   return ctx.all_gather([value_1, value_2], axis=0)
+    >>> result = strategy.run(gather_nest)
+    >>> result
+    [PerReplica:{
+      0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
+      1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
+    }, PerReplica:{
+      0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
+    array([[1, 2],
+           [3, 4],
+           [1, 2],
+           [3, 4]], dtype=int32)>,
+      1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
+    array([[1, 2],
+           [3, 4],
+           [1, 2],
+           [3, 4]], dtype=int32)>
+    }]
+    >>> strategy.experimental_local_results(result)
+    ([PerReplica:{
+      0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
+      1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
+    }, PerReplica:{
+      0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
+    array([[1, 2],
+           [3, 4],
+           [1, 2],
+           [3, 4]], dtype=int32)>,
+      1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
+    array([[1, 2],
+           [3, 4],
+           [1, 2],
+           [3, 4]], dtype=int32)>
+    }],)
+
+
+    What if you are all-gathering tensors with different shapes on different
+    replicas? Consider the following example with two replicas, where you have
+    `value` as a nested structure consisting of two items to all-gather, `a` and
+    `b`.
+
+    On Replica 0, `value` is {'a': [0], 'b': [[0, 1]]}
+    On Replica 1, `value` is {'a': [1], 'b': [[2, 3], [4, 5]]}
+
+    Result for `all_gather` with `axis`=0: (on each of the replicas):
     {'a': [1, 2], 'b': [[0, 1], [2, 3], [4, 5]]}

-    Note: an input to be all_gathered must have the same rank on different
-    replicas, and they must have shapes that are consistent along all dimensions
-    except the `axis`-th dimension. For example, given a
-    `tf.distribute.DistributedValues` with tensors of shape `(1, 2, 3)` and
-    `(1, 3, 3)` on two replicas, you can call `all_gather(..., axis=1, ...)` on
-    it, but not `all_gather(..., axis=0, ...)` or `all_gather(..., axis=2, ...)`.
-
     Args:
       value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts,
         or a `tf.distribute.DistributedValues` instance. The structure of the
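The "all replicas must participate" rule in the docstring above is easy to trip over; a short hedged sketch of the safe pattern (device names illustrative):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

@tf.function
def replica_fn(x):
  ctx = tf.distribute.get_replica_context()
  # Safe: every replica reaches the same all_gather call, in the same order.
  gathered = ctx.all_gather(x, axis=0)
  # Risky: guarding the call with a per-replica condition such as
  # `if ctx.replica_id_in_sync_group == 0:` can hang, because the other
  # replicas never join the collective.
  return gathered

print(strategy.run(replica_fn, args=(tf.constant([1, 2]),)))
```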
@@ -3185,8 +3233,7 @@ class ReplicaContext(object):
     """
     for v in nest.flatten(value):
       if isinstance(v, ops.IndexedSlices):
-        raise NotImplementedError("gather/all_gather does not support "
-                                  "IndexedSlices")
+        raise NotImplementedError("all_gather does not support IndexedSlices")

     if options is None:
       options = collective_util.Options()
@@ -3200,11 +3247,16 @@ class ReplicaContext(object):
     def grad_wrapper(*xs):
       ys = self.merge_call(batch_all_gather, args=xs)
       # The gradient of an all-gather is itself an all-gather.
-      return ys, lambda *dy_s: self._all_gather(dy_s, axis)
+      return ys, lambda *dy_s: self.all_gather(dy_s, axis)

     return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))


+@tf_export(v1=["distribute.ReplicaContext"])
+class ReplicaContextV1(ReplicaContextBase):
+  __doc__ = ReplicaContextBase.__doc__
+
+
 def _batch_reduce_destination(x):
   """Returns the destinations for batch all-reduce."""
   if isinstance(x, ops.Tensor):
@@ -74,7 +74,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
         lambda _: array_ops.identity(value_on_replica))

     def run():
-      return strategy._gather(distributed_values, axis=axis)
+      return strategy.gather(distributed_values, axis=axis)

     if not pure_eager:
       run = def_function.function(run)
@@ -133,7 +133,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
     axis = 0

     def run():
-      return strategy._gather(distributed_values, axis=axis)
+      return strategy.gather(distributed_values, axis=axis)

     if not pure_eager:
       run = def_function.function(run)
@@ -155,7 +155,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
     axis = 1

     def run():
-      return strategy._gather(distributed_values, axis=axis)
+      return strategy.gather(distributed_values, axis=axis)

     if not pure_eager:
       run = def_function.function(run)
@@ -183,7 +183,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
     axis = 0

     def run():
-      return strategy._gather(distributed_values, axis=axis)
+      return strategy.gather(distributed_values, axis=axis)

     if not pure_eager:
       run = def_function.function(run)
@@ -210,11 +210,11 @@ class GatherTest(test.TestCase, parameterized.TestCase):
         values=[[1., 2.]], indices=[2], dense_shape=dense_shape)

     def run(value):
-      return strategy._gather(value, axis=0)
+      return strategy.gather(value, axis=0)

     with self.assertRaisesRegex(
         NotImplementedError,
-        r'gather/all_gather does not support IndexedSlices'):
+        r'gather does not support IndexedSlices'):
       if pure_eager:
         run(t0)
       else:
@@ -235,7 +235,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
     axis = 0

     def run():
-      return strategy._gather(distributed_values, axis=axis)
+      return strategy.gather(distributed_values, axis=axis)

     if not pure_eager:
       run = def_function.function(run)
@@ -271,7 +271,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
     def replica_fn(per_replica_value):
       ctx = ds_context.get_replica_context()
       local_value = array_ops.identity(per_replica_value)
-      return ctx._all_gather(local_value, axis=axis)
+      return ctx.all_gather(local_value, axis=axis)

     if not pure_eager:
       replica_fn = def_function.function(replica_fn)
@@ -342,7 +342,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
     @def_function.function
     def replica_fn(per_replica_value):
       ctx = ds_context.get_replica_context()
-      return ctx._all_gather(array_ops.identity(per_replica_value), axis=axis)
+      return ctx.all_gather(array_ops.identity(per_replica_value), axis=axis)

     result = strategy.experimental_local_results(
         strategy.run(replica_fn, args=(next(input_iterator),)))
@@ -369,7 +369,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
     def run(value):
       value_identity = array_ops.identity(value)
       ctx = ds_context.get_replica_context()
-      return ctx._all_gather(value_identity, axis=0)
+      return ctx.all_gather(value_identity, axis=0)

     if not pure_eager:
       run = def_function.function(run)
@@ -397,7 +397,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
     def run(value):
       value_identity = array_ops.identity(value)
       ctx = ds_context.get_replica_context()
-      return ctx._all_gather(value_identity, axis=1)
+      return ctx.all_gather(value_identity, axis=1)

     if not pure_eager:
       run = def_function.function(run)
@@ -436,7 +436,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
       value_1 = array_ops.identity(value)
       value_3 = array_ops.identity(value_2)
       ctx = ds_context.get_replica_context()
-      return ctx._all_gather([value_1, value_3], axis=axis)
+      return ctx.all_gather([value_1, value_3], axis=axis)

     if not pure_eager:
       run = def_function.function(run)
@@ -455,7 +455,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
     def run():
       value_identity = array_ops.identity(single_value)
       ctx = ds_context.get_replica_context()
-      return ctx._all_gather([value_identity, value_identity], axis=axis)
+      return ctx.all_gather([value_identity, value_identity], axis=axis)

     if not pure_eager:
       run = def_function.function(run)
@@ -491,7 +491,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
     def run(value):
       value_identity = array_ops.identity(value)
       ctx = ds_context.get_replica_context()
-      return ctx._all_gather(value_identity, axis=0)
+      return ctx.all_gather(value_identity, axis=0)

     if not pure_eager:
       run = def_function.function(run)
@@ -519,11 +519,11 @@ class GatherTest(test.TestCase, parameterized.TestCase):

     def replica_fn(value):
       ctx = ds_context.get_replica_context()
-      return ctx._all_gather(value, axis=0)
+      return ctx.all_gather(value, axis=0)

     with self.assertRaisesRegex(
         NotImplementedError,
-        r'gather/all_gather does not support IndexedSlices'):
+        r'all_gather does not support IndexedSlices'):
       if not pure_eager:
         strategy.run(def_function.function(replica_fn), args=(t0,))
       else:
@@ -548,7 +548,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
     def run(value):
       value_identity = array_ops.identity(value)
       ctx = ds_context.get_replica_context()
-      return ctx._all_gather(value_identity, axis=0)
+      return ctx.all_gather(value_identity, axis=0)

     if not pure_eager:
       run = def_function.function(run)
@@ -58,7 +58,7 @@ def _gather(strategy, value):
     return array_ops.stack(value._values)
   assert len(strategy.extended.worker_devices) == len(value._values)
   inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values]
-  return strategy._gather(values.PerReplica(inputs), axis=0)
+  return strategy.gather(values.PerReplica(inputs), axis=0)
   # pylint: enable=protected-access


@@ -1425,12 +1425,11 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
     return self.strategy.extended.experimental_logical_device(logical_device_id)

   # TODO(wxinyi): Investigate whether to use cross_replica_sum to optimize it.
-  def _all_gather(self, value, axis, options=None):
-    del options
+  def all_gather(self, value, axis, experimental_hints=None):
+    del experimental_hints
     for v in nest.flatten(value):
       if isinstance(v, ops.IndexedSlices):
-        raise NotImplementedError("gather/all_gather does not support "
-                                  "IndexedSlices")
+        raise NotImplementedError("all_gather does not support IndexedSlices")

     def _all_to_all(value, axis):
       # The underlying AllToAllOp first do a split of the input value and then
@@ -1484,7 +1483,7 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
     @custom_gradient.custom_gradient
     def grad_wrapper(*xs):
       ys = [_all_to_all(t, axis=axis) for t in xs]
-      return ys, lambda *dy_s: self._all_gather(dy_s, axis)
+      return ys, lambda *dy_s: self.all_gather(dy_s, axis)

     return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))

@@ -2755,7 +2755,7 @@ def _collective_all_reduce_multi_worker(strategy):
 # for all strategies
 def _multi_worker_concat(v, strategy):
   """Order PerReplica objects for CollectiveAllReduceStrategy and concat."""
-  replicas = strategy._gather(v, axis=0)  # pylint: disable=protected-access
+  replicas = strategy.gather(v, axis=0)  # pylint: disable=protected-access
   # v might not have the same shape on different replicas
   if isinstance(v, ds_values.PerReplica):
     shapes = array_ops.concat([
@@ -2763,10 +2763,10 @@ def _multi_worker_concat(v, strategy):
         for single_value in v.values
     ],
                               axis=0)
-    all_shapes = strategy._gather(shapes, axis=0)  # pylint: disable=protected-access
+    all_shapes = strategy.gather(shapes, axis=0)  # pylint: disable=protected-access
   else:
     # v is a tensor. This may happen when, say, we have 2x1 multi-worker.
-    all_shapes = strategy._gather(  # pylint: disable=protected-access
+    all_shapes = strategy.gather(  # pylint: disable=protected-access
         array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0),
         axis=0)

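`_multi_worker_concat` above relies on `gather` accepting per-replica tensors whose sizes differ along the gather axis; a hedged sketch of that behaviour on a single worker (device names illustrative):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

def value_fn(ctx):
  # Replica 0 contributes 1 row, replica 1 contributes 2 rows; the ranks and
  # the non-gathered dimensions still match, as the docstring requires.
  rows = ctx.replica_id_in_sync_group + 1
  return tf.ones((rows, 3))

dist_values = strategy.experimental_distribute_values_from_function(value_fn)
print(strategy.gather(dist_values, axis=0).shape)  # (3, 3): 1 + 2 rows
```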
@@ -1,6 +1,7 @@
 path: "tensorflow.distribute.ReplicaContext"
 tf_class {
-  is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContext\'>"
+  is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContextV1\'>"
+  is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContextBase\'>"
   is_instance: "<type \'object\'>"
   member {
     name: "devices"
@@ -52,6 +52,10 @@ tf_class {
     name: "experimental_run"
     argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "gather"
+    argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "group"
     argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -52,6 +52,10 @@ tf_class {
     name: "experimental_run"
     argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "gather"
+    argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "group"
     argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -1,6 +1,7 @@
 path: "tensorflow.distribute.ReplicaContext"
 tf_class {
   is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContext\'>"
+  is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContextBase\'>"
   is_instance: "<type \'object\'>"
   member {
     name: "devices"
|
|||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "all_gather"
|
||||||
|
argspec: "args=[\'self\', \'value\', \'axis\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "all_reduce"
|
name: "all_reduce"
|
||||||
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
@@ -51,6 +51,10 @@ tf_class {
     name: "experimental_run"
     argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "gather"
+    argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "group"
     argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -64,6 +64,10 @@ tf_class {
     name: "experimental_split_to_logical_devices"
     argspec: "args=[\'self\', \'tensor\', \'partition_dimensions\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "gather"
+    argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "group"
     argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -52,6 +52,10 @@ tf_class {
     name: "experimental_run"
     argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "gather"
+    argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "group"
     argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -52,6 +52,10 @@ tf_class {
     name: "experimental_run"
     argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "gather"
+    argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "group"
     argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -52,6 +52,10 @@ tf_class {
     name: "experimental_run"
     argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "gather"
+    argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "group"
     argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -52,6 +52,10 @@ tf_class {
     name: "experimental_run"
     argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "gather"
+    argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "group"
     argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "