Add an experimental_hints argument to batch all-reduce

This argument carries performance hints to the API. Currently the only hint is bytes_per_pack, which splits a large batch into multiple packs so that communication can overlap with computation. Packing is only possible when all tensors in the batch have known shapes.

PiperOrigin-RevId: 297269428
Change-Id: Iaf7d7d3adf7c6cad59aa6079fbcd36b31e92c4b5
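Usage sketch (editor's addition, not part of the diff): the snippet below shows how the new hint is threaded through the public API. Only `tf.distribute.experimental.CollectiveHints` and the `experimental_hints` arguments are introduced by this change; the strategy/layer/optimizer setup and the TF 2.1-era names `experimental_run_v2` and `all_reduce_sum_gradients` are assumptions used to keep the example self-contained.

```python
import tensorflow as tf

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

with strategy.scope():
  dense = tf.keras.layers.Dense(4)
  dense.build((None, 4))  # Create the mirrored variables eagerly.
  optimizer = tf.keras.optimizers.SGD(0.1)

# Ask the collective implementation to group gradient all-reduces into packs
# of roughly 50MB, so weight updates can overlap with the remaining
# all-reduces.
hints = tf.distribute.experimental.CollectiveHints(
    bytes_per_pack=50 * 1024 * 1024)


@tf.function
def train_step(features, labels):

  def replica_fn(features, labels):
    with tf.GradientTape() as tape:
      loss = tf.reduce_mean(tf.square(dense(features) - labels))
    grads = tape.gradient(loss, dense.trainable_variables)
    # The hint is forwarded through batch_reduce_to down to the batched
    # collective all-reduce.
    grads = tf.distribute.get_replica_context().all_reduce(
        'sum', grads, experimental_hints=hints)
    optimizer.apply_gradients(
        zip(grads, dense.trainable_variables),
        all_reduce_sum_gradients=False)

  strategy.experimental_run_v2(replica_fn, args=(features, labels))


train_step(tf.ones([8, 4]), tf.ones([8, 4]))
```

The same `hints` object can also be passed to `strategy.extended.reduce_to(...)` and `strategy.extended.batch_reduce_to(...)`, whose signatures gain the same `experimental_hints=None` argument in this change.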
parent 26a24de29b
commit fa08cfd489

Changed files:

tensorflow/python/distribute/
  BUILD
  collective_all_reduce_strategy.py
  collective_util.py
  cross_device_ops.py
  cross_device_ops_test.py
  cross_device_utils.py
  cross_device_utils_test.py
  distribute_lib.py
  distribute_lib_test.py
  mirrored_strategy.py
  one_device_strategy.py
  parameter_server_strategy.py
  tpu_strategy.py

tensorflow/tools/api/golden/v1/
  tensorflow.distribute.-cross-device-ops.pbtxt
  tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt
  tensorflow.distribute.-nccl-all-reduce.pbtxt
  tensorflow.distribute.-reduction-to-one-device.pbtxt
  tensorflow.distribute.-replica-context.pbtxt
  tensorflow.distribute.-strategy-extended.pbtxt
  tensorflow.distribute.experimental.-collective-hints.pbtxt
  tensorflow.distribute.experimental.pbtxt

tensorflow/tools/api/golden/v2/
  tensorflow.distribute.-cross-device-ops.pbtxt
  tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt
  tensorflow.distribute.-nccl-all-reduce.pbtxt
  tensorflow.distribute.-reduction-to-one-device.pbtxt
  tensorflow.distribute.-replica-context.pbtxt
  tensorflow.distribute.-strategy-extended.pbtxt
  tensorflow.distribute.experimental.-collective-hints.pbtxt
  tensorflow.distribute.experimental.pbtxt
@ -63,6 +63,7 @@ py_library(
|
||||
srcs = ["cross_device_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":collective_util",
|
||||
":cross_device_utils",
|
||||
":device_util",
|
||||
":reduce_util",
|
||||
@ -97,6 +98,7 @@ py_library(
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nccl_ops",
|
||||
"//tensorflow/python:platform",
|
||||
],
|
||||
)
|
||||
|
||||
@ -145,6 +147,7 @@ py_library(
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":collective_util",
|
||||
":device_util",
|
||||
":numpy_dataset",
|
||||
":reduce_util",
|
||||
@ -580,6 +583,15 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "collective_util",
|
||||
srcs = ["collective_util.py"],
|
||||
deps = [
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "shared_variable_creator",
|
||||
srcs = ["shared_variable_creator.py"],
|
||||
@ -795,7 +807,9 @@ cuda_py_test(
|
||||
name = "cross_device_utils_test",
|
||||
srcs = ["cross_device_utils_test.py"],
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
@ -815,6 +829,7 @@ cuda_py_test(
|
||||
],
|
||||
deps = [
|
||||
":collective_all_reduce_strategy",
|
||||
":collective_util",
|
||||
":mirrored_strategy",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
|
@ -95,6 +95,7 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
|
||||
TFConfigClusterResolver which is instantiated from the TF_CONFIG env
|
||||
var.
|
||||
"""
|
||||
# TODO(b/150151677): consider move communication to CollectiveHints.
|
||||
super(CollectiveAllReduceStrategy, self).__init__(
|
||||
CollectiveAllReduceExtended(
|
||||
self,
|
||||
@ -505,7 +506,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
|
||||
return updated_config
|
||||
|
||||
def _reduce_to(self, reduce_op, value, destinations):
|
||||
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
|
||||
if (isinstance(value, values.Mirrored) and
|
||||
reduce_op == reduce_util.ReduceOp.MEAN):
|
||||
return value
|
||||
@ -526,7 +527,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
return cross_device_ops_lib.reduce_non_distributed_value(
|
||||
reduce_op, value, destinations, len(self.worker_devices))
|
||||
return self._get_cross_device_ops().reduce(
|
||||
reduce_op, value, destinations=destinations)
|
||||
reduce_op,
|
||||
value,
|
||||
destinations=destinations,
|
||||
experimental_hints=experimental_hints)
|
||||
|
||||
def _warn_nccl_no_gpu(self):
|
||||
if ((self._communication ==
|
||||
|
tensorflow/python/distribute/collective_util.py  (new file, 63 lines)
@@ -0,0 +1,63 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for collectives."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.util.tf_export import tf_export


@tf_export("distribute.experimental.CollectiveHints")
class Hints(object):
  """Hints for collective operations like AllReduce.

  This can be passed to methods like
  `tf.distribute.get_replica_context().all_reduce()` to optimize collective
  operation performance. Note that these are only hints, which may or may not
  change the actual behavior. Some options only apply to certain strategies
  and are ignored by others.

  One common optimization is to break gradient all-reduce into multiple packs
  so that weight updates can overlap with gradient all-reduce.

  Example:

  ```python
  hints = tf.distribute.experimental.CollectiveHints(
      bytes_per_pack=50 * 1024 * 1024)
  grads = tf.distribute.get_replica_context().all_reduce(
      'sum', grads, experimental_hints=hints)
  optimizer.apply_gradients(zip(grads, vars), all_reduce_sum_gradients=False)
  ```

  """

  def __init__(self, bytes_per_pack=0):
    """Creates a CollectiveHints.

    Args:
      bytes_per_pack: A non-negative integer. Breaks collective operations into
        packs of certain size. If it's zero, the value is determined
        automatically. This currently only applies to all-reduce with
        `MultiWorkerMirroredStrategy`.

    Raises:
      ValueError: When arguments have invalid values.
    """
    if bytes_per_pack < 0:
      raise ValueError("bytes_per_pack must be non-negative")
    self.bytes_per_pack = bytes_per_pack
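Editor's note: callers may pass `None` for the new argument; the consumers added in `cross_device_ops.py` and `distribute_lib.py` substitute a default `Hints()` (i.e. `bytes_per_pack=0`, "decide automatically") before dispatching to the implementation. A minimal sketch of that defaulting pattern, using a hypothetical wrapper name:

```python
from tensorflow.python.distribute import collective_util


def _reduce_with_default_hints(cross_device_ops, reduce_op, per_replica_value,
                               destinations, experimental_hints=None):
  # Hypothetical helper mirroring CrossDeviceOps.reduce() in this change:
  # a missing hint becomes Hints() so implementations can rely on it being set.
  if experimental_hints is None:
    experimental_hints = collective_util.Hints()
  return cross_device_ops.reduce_implementation(
      reduce_op, per_replica_value, destinations, experimental_hints)
```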
@ -19,11 +19,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import enum
|
||||
|
||||
import enum
|
||||
import six
|
||||
|
||||
from tensorflow.python.client import device_lib
|
||||
from tensorflow.python.distribute import collective_util
|
||||
from tensorflow.python.distribute import cross_device_utils
|
||||
from tensorflow.python.distribute import device_util
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
@ -222,7 +223,11 @@ class CrossDeviceOps(object):
|
||||
# Returns 1 by default, the value may be overridden by sub classes.
|
||||
return 1
|
||||
|
||||
def reduce(self, reduce_op, per_replica_value, destinations):
|
||||
def reduce(self,
|
||||
reduce_op,
|
||||
per_replica_value,
|
||||
destinations,
|
||||
experimental_hints=None):
|
||||
"""Reduce `per_replica_value` to `destinations`.
|
||||
|
||||
It runs the reduction operation defined by `reduce_op` and put the
|
||||
@ -231,8 +236,10 @@ class CrossDeviceOps(object):
|
||||
Args:
|
||||
reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
|
||||
per_replica_value will be reduced.
|
||||
per_replica_value: a PerReplica object or a tensor with device set.
|
||||
per_replica_value: A PerReplica object or a tensor with device set.
|
||||
destinations: the reduction destinations.
|
||||
experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints
|
||||
to perform collective operations.
|
||||
|
||||
Returns:
|
||||
a Mirrored object.
|
||||
@ -254,10 +261,15 @@ class CrossDeviceOps(object):
|
||||
per_replica_value.values,
|
||||
wrap_class=value_lib.Mirrored)
|
||||
|
||||
if experimental_hints is None:
|
||||
experimental_hints = collective_util.Hints()
|
||||
return self.reduce_implementation(reduce_op, per_replica_value,
|
||||
destinations)
|
||||
destinations, experimental_hints)
|
||||
|
||||
def batch_reduce(self, reduce_op, value_destination_pairs):
|
||||
def batch_reduce(self,
|
||||
reduce_op,
|
||||
value_destination_pairs,
|
||||
experimental_hints=None):
|
||||
"""Reduce PerReplica objects in a batch.
|
||||
|
||||
Reduce each first element in `value_destination_pairs` to each second
|
||||
@ -267,10 +279,12 @@ class CrossDeviceOps(object):
|
||||
fuse several tensors into one or multiple packs before reduction.
|
||||
|
||||
Args:
|
||||
reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
|
||||
the `per_replica_value` will be reduced.
|
||||
value_destination_pairs: a list or a tuple of PerReplica objects
|
||||
(or tensors with device set if there is one device) and destinations.
|
||||
reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how the
|
||||
`per_replica_value` will be reduced.
|
||||
value_destination_pairs: A list or a tuple of PerReplica objects (or
|
||||
tensors with device set if there is one device) and destinations.
|
||||
experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints
|
||||
to perform collective operations.
|
||||
|
||||
Returns:
|
||||
a list of Mirrored objects.
|
||||
@ -299,7 +313,10 @@ class CrossDeviceOps(object):
|
||||
for v, _ in value_destination_pairs
|
||||
]
|
||||
|
||||
return self.batch_reduce_implementation(reduce_op, value_destination_pairs)
|
||||
if experimental_hints is None:
|
||||
experimental_hints = collective_util.Hints()
|
||||
return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
|
||||
experimental_hints)
|
||||
|
||||
def broadcast(self, tensor, destinations):
|
||||
"""Broadcast the `tensor` to destinations.
|
||||
@ -315,7 +332,8 @@ class CrossDeviceOps(object):
|
||||
return self.broadcast_implementation(tensor, destinations)
|
||||
|
||||
@doc_controls.for_subclass_implementers
|
||||
def reduce_implementation(self, reduce_op, per_replica_value, destinations):
|
||||
def reduce_implementation(self, reduce_op, per_replica_value, destinations,
|
||||
experimental_hints):
|
||||
"""The implementation of reduce of `per_replica_value` to `destinations`.
|
||||
|
||||
Overriding this method is useful for subclass implementers.
|
||||
@ -326,8 +344,10 @@ class CrossDeviceOps(object):
|
||||
Args:
|
||||
reduce_op: An instance `tf.distribute.ReduceOp` that indicates of how
|
||||
per_replica_value will be reduced.
|
||||
per_replica_value: a PerReplica object or a tensor with device set.
|
||||
per_replica_value: A PerReplica object or a tensor with device set.
|
||||
destinations: the reduction destinations.
|
||||
experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints
|
||||
to perform collective operations.
|
||||
|
||||
Returns:
|
||||
a Mirrored object.
|
||||
@ -340,7 +360,8 @@ class CrossDeviceOps(object):
|
||||
"_reduce method must be implemented in descendants.")
|
||||
|
||||
@doc_controls.for_subclass_implementers
|
||||
def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
|
||||
def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
|
||||
experimental_hints):
|
||||
"""Implementation of reduce PerReplica objects in a batch.
|
||||
|
||||
Overriding this method is useful for subclass implementers.
|
||||
@ -351,8 +372,10 @@ class CrossDeviceOps(object):
|
||||
Args:
|
||||
reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
|
||||
per_replica_value will be reduced.
|
||||
value_destination_pairs: an iterable of tuples of PerReplica objects
|
||||
value_destination_pairs: An iterable of tuples of PerReplica objects
|
||||
(or tensors with device set if there is one device) and destinations.
|
||||
experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints
|
||||
to perform collective operations.
|
||||
|
||||
Returns:
|
||||
a list of Mirrored objects.
|
||||
@ -362,7 +385,8 @@ class CrossDeviceOps(object):
|
||||
tuples of PerReplica objects and destinations
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"_batch_reduce method must be implemented in descendants.")
|
||||
"batch_reduce_implementation method must be implemented in descendants."
|
||||
)
|
||||
|
||||
@doc_controls.for_subclass_implementers
|
||||
def broadcast_implementation(self, tensor, destinations):
|
||||
@ -403,7 +427,9 @@ class ReductionToOneDevice(CrossDeviceOps):
|
||||
self.accumulation_fn = accumulation_fn or math_ops.add_n
|
||||
super(ReductionToOneDevice, self).__init__()
|
||||
|
||||
def reduce_implementation(self, reduce_op, per_replica_value, destinations):
|
||||
def reduce_implementation(self, reduce_op, per_replica_value, destinations,
|
||||
experimental_hints):
|
||||
del experimental_hints # Unused.
|
||||
if check_destinations(destinations):
|
||||
devices = get_devices_from(destinations)
|
||||
else:
|
||||
@ -416,9 +442,11 @@ class ReductionToOneDevice(CrossDeviceOps):
|
||||
self.accumulation_fn, reduce_op)
|
||||
return self.broadcast(reduced, destinations)
|
||||
|
||||
def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
|
||||
def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
|
||||
experimental_hints):
|
||||
return [
|
||||
self.reduce_implementation(reduce_op, t, destinations=v)
|
||||
self.reduce_implementation(
|
||||
reduce_op, t, destinations=v, experimental_hints=experimental_hints)
|
||||
for t, v in value_destination_pairs
|
||||
]
|
||||
|
||||
@ -626,21 +654,24 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
|
||||
self._simple_cross_replica_ops = ReductionToOneDevice()
|
||||
super(AllReduceCrossDeviceOps, self).__init__()
|
||||
|
||||
def reduce_implementation(self, reduce_op, per_replica_value, destinations):
|
||||
def reduce_implementation(self, reduce_op, per_replica_value, destinations,
|
||||
experimental_hints):
|
||||
del experimental_hints # Unused.
|
||||
if _devices_match(per_replica_value, destinations):
|
||||
return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
|
||||
else:
|
||||
return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
|
||||
destinations)
|
||||
|
||||
def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
|
||||
def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
|
||||
experimental_hints):
|
||||
if _all_devices_match(value_destination_pairs):
|
||||
return self._batch_all_reduce(reduce_op,
|
||||
[v[0] for v in value_destination_pairs])
|
||||
else:
|
||||
return [
|
||||
self.reduce_implementation(reduce_op, t, destinations=v)
|
||||
for t, v in value_destination_pairs
|
||||
self.reduce_implementation(reduce_op, value, dest, experimental_hints)
|
||||
for value, dest in value_destination_pairs
|
||||
]
|
||||
|
||||
def _batch_all_reduce(self, reduce_op, per_replica_values):
|
||||
@ -904,7 +935,6 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
def __init__(self,
|
||||
num_workers=1,
|
||||
num_gpus_per_worker=0,
|
||||
num_packs=1,
|
||||
collective_keys=None,
|
||||
communication=CollectiveCommunication.AUTO):
|
||||
"""Initializes the object.
|
||||
@ -912,13 +942,11 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
Args:
|
||||
num_workers: number of workers in the between-graph replicated training.
|
||||
num_gpus_per_worker: number of GPUs per worker.
|
||||
num_packs: gradients will be packed into `num_packs` chunks.
|
||||
collective_keys: an optional CollectiveKey object.
|
||||
communication: indicates which collective communication to use.
|
||||
"""
|
||||
self._num_workers = num_workers
|
||||
self._num_gpus_per_worker = num_gpus_per_worker
|
||||
self._num_packs = num_packs
|
||||
self._collective_keys = (collective_keys or
|
||||
cross_device_utils.CollectiveKeys())
|
||||
self._communication = communication
|
||||
@ -928,8 +956,10 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
def _num_between_graph_workers(self):
|
||||
return self._num_workers
|
||||
|
||||
def reduce_implementation(self, reduce_op, per_replica_value, destinations):
|
||||
all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value])[0]
|
||||
def reduce_implementation(self, reduce_op, per_replica_value, destinations,
|
||||
experimental_hints):
|
||||
all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value],
|
||||
experimental_hints)[0]
|
||||
devices = get_devices_from(destinations)
|
||||
|
||||
if (isinstance(all_reduced, value_lib.Mirrored) and
|
||||
@ -958,11 +988,13 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
index.append(array_ops.identity(all_reduced._primary)) # pylint: disable=protected-access
|
||||
return value_lib.regroup(index, wrap_class=value_lib.Mirrored)
|
||||
|
||||
def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
|
||||
def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
|
||||
experimental_hints):
|
||||
all_devices_match = _all_devices_match(value_destination_pairs)
|
||||
if all_devices_match:
|
||||
return self._batch_all_reduce(reduce_op,
|
||||
[v[0] for v in value_destination_pairs])
|
||||
[v[0] for v in value_destination_pairs],
|
||||
experimental_hints)
|
||||
else:
|
||||
if not all_devices_match:
|
||||
logging.log_first_n(
|
||||
@ -970,47 +1002,18 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
"destinations are different.", 10)
|
||||
|
||||
return [
|
||||
self.reduce_implementation(reduce_op, t, destinations=v)
|
||||
for t, v in value_destination_pairs
|
||||
self.reduce_implementation(reduce_op, value, dest, experimental_hints)
|
||||
for value, dest in value_destination_pairs
|
||||
]
|
||||
|
||||
def _make_gradient_chunks(self, per_replica_values, num_packs):
|
||||
"""Make `per_replica_values` into chunks."""
|
||||
chunked_by_device = _group_value_by_device(per_replica_values)
|
||||
chunked_by_var = list(zip(*chunked_by_device))
|
||||
# chunked_by_var is chunked by variables and takes the following format:
|
||||
# [((grad0_gpu0, v0_gpu0), (grad0_gpu1, v0_gpu1), (grad0_gpu2, v0_gpu2) ..),
|
||||
# ((grad1_gpu0, v1_gpu0), (grad1_gpu1, v1_gpu1), (grad1_gpu0, v1_gpu2) ..),
|
||||
# ((grad2_gpu0, v2_gpu0), (grad2_gpu1, v2_gpu1), (grad2_gpu0, v2_gpu2) ..),
|
||||
# ...
|
||||
# ]
|
||||
|
||||
# No chunking if number of variables is fewer than number of packs.
|
||||
if len(chunked_by_var) < num_packs:
|
||||
return [chunked_by_var]
|
||||
|
||||
# First n-1 chunks get `chunk_size` grads, last chunk gets leftover grads.
|
||||
# This strategy can cause the last chunk to have larger size compared to the
|
||||
# first n-1 chunks. Alternatively, we can increment chunk_size by 1 to get
|
||||
# slightly larger first n-1 chunks and smaller last chunk.
|
||||
# TODO(ayushd): compare different packing strategies.
|
||||
chunk_size = len(chunked_by_var) // num_packs
|
||||
leftover_size = len(chunked_by_var) - chunk_size * (num_packs - 1)
|
||||
assert leftover_size > 0
|
||||
chunked_gv = [
|
||||
chunked_by_var[x:x + chunk_size]
|
||||
for x in range(0, len(chunked_by_var) - leftover_size, chunk_size)
|
||||
]
|
||||
chunked_gv.append(chunked_by_var[-leftover_size:])
|
||||
|
||||
return chunked_gv
|
||||
|
||||
def _batch_all_reduce(self, reduce_op, per_replica_values):
|
||||
def _batch_all_reduce(self, reduce_op, per_replica_values,
|
||||
experimental_hints):
|
||||
"""All reduce algorithm in a batch."""
|
||||
dense_values, dense_indices, sparse_values, sparse_indices = (
|
||||
cross_device_utils.split_by_sparsity(per_replica_values))
|
||||
if dense_values:
|
||||
dense_results = self._do_batch_all_reduce_dense(reduce_op, dense_values)
|
||||
dense_results = self._do_batch_all_reduce_dense(reduce_op, dense_values,
|
||||
experimental_hints)
|
||||
else:
|
||||
dense_results = []
|
||||
if sparse_values:
|
||||
@ -1018,83 +1021,84 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
sparse_values)
|
||||
else:
|
||||
sparse_results = []
|
||||
return cross_device_utils.stitch_values(((dense_results, dense_indices),
|
||||
(sparse_results, sparse_indices)))
|
||||
return cross_device_utils.stitch_values(
|
||||
((dense_results, dense_indices), (sparse_results, sparse_indices)))
|
||||
|
||||
def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values):
|
||||
def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values,
|
||||
experimental_hints):
|
||||
"""All-reduce across all workers in a batch."""
|
||||
|
||||
chunked_gv = self._make_gradient_chunks(per_replica_values, self._num_packs)
|
||||
# Actual number of packs may be different from `self._num_packs`. e.g. if
|
||||
# there are fewer tensors than `self._num_packs`.
|
||||
num_actual_packs = len(chunked_gv)
|
||||
|
||||
batch_size = len(per_replica_values)
|
||||
# Pass self._communication to the runtime as a communication hint.
|
||||
communication_hint = self._communication.value
|
||||
communication = self._communication.value
|
||||
# For now, we use NCCL only when batch_size > 1.
|
||||
# TODO(b/132575814): switch to NCCL for all collectives when communication
|
||||
# is NCCL.
|
||||
if self._communication == CollectiveCommunication.NCCL and batch_size == 1:
|
||||
communication_hint = CollectiveCommunication.AUTO.value
|
||||
communication = CollectiveCommunication.AUTO.value
|
||||
|
||||
# Reverse the lists so that there's better chance that values follows
|
||||
# the order in which they are calculated (e.g. when they're gradients), so
|
||||
# as to overlap calculation with communication. However, this may not be
|
||||
# optimal for cases like gradients of complicated non-sequential models.
|
||||
#
|
||||
# Note that we reverse the list before packing so that the first pack won't
|
||||
# be too small, since it's more likely for first few packs to have long
|
||||
# queuing time due to concurrent intense computation.
|
||||
#
|
||||
# TODO(b/147393503): explore solutions for optimal ordering.
|
||||
packs = cross_device_utils.pack_by_size(
|
||||
list(reversed(per_replica_values)), experimental_hints.bytes_per_pack)
|
||||
|
||||
if batch_size > 1:
|
||||
logging.info(
|
||||
"Collective batch_all_reduce: %d all-reduces, num_workers = %d, "
|
||||
"communication_hint = %s, num_packs = %d" %
|
||||
(batch_size, self._num_workers, communication_hint, num_actual_packs))
|
||||
"communication_hint = %s, num_packs = %d", batch_size,
|
||||
self._num_workers, communication, len(packs))
|
||||
else:
|
||||
logging.log_first_n(
|
||||
logging.INFO, "Collective batch_all_reduce: %d all-reduces, "
|
||||
"num_workers = %d, communication_hint = %s, num_packs = %d" %
|
||||
(batch_size, self._num_workers, communication_hint, num_actual_packs),
|
||||
10)
|
||||
(batch_size, self._num_workers, communication, len(packs)), 10)
|
||||
|
||||
def batch_fn():
|
||||
"""Wrapper function around batched all-reduce calls."""
|
||||
reduced_gv_list = []
|
||||
# Reverse the gradient lists so that the gradient grouping roughly follows
|
||||
# the order in which gradients are calculated in backprop. This should
|
||||
# enable overlapping gradient all-reduce with backprop for most models.
|
||||
# However, it is likely that for some complicated non-sequential models
|
||||
# this grouping is not optimal.
|
||||
#
|
||||
# TODO(b/147393503): explore solutions for optimal gradient grouping.
|
||||
for chunk in reversed(chunked_gv):
|
||||
# By placing all CollectiveReduce ops in a chunk under single name
|
||||
# scope, we ensure they will be picked up by the `ScopedAllocator`
|
||||
# grappler optimizer and packed into a single all-reduce.
|
||||
reduced_values = []
|
||||
for pack in packs:
|
||||
# By placing all CollectiveReduce ops in a pack under single name scope,
|
||||
# we ensure they will be picked up by the `ScopedAllocator` grappler
|
||||
# optimizer and packed into a single all-reduce.
|
||||
with ops.name_scope("allreduce"):
|
||||
for grad_and_vars in reversed(chunk):
|
||||
# Gradients for the same variable but from different devices.
|
||||
grads = [g for g, _ in grad_and_vars]
|
||||
for per_replica in pack:
|
||||
# Add control dependencies per device from the last gradients to the
|
||||
# current set, in order to serialize NCCL launches.
|
||||
if (communication_hint == CollectiveCommunication.NCCL.value and
|
||||
reduced_gv_list):
|
||||
control_input_grads = [g for g, _ in reduced_gv_list[-1]]
|
||||
if (communication == CollectiveCommunication.NCCL.value and
|
||||
reduced_values):
|
||||
control_inputs = [g for g in reduced_values[-1]]
|
||||
else:
|
||||
control_input_grads = None
|
||||
collective_reduced = cross_device_utils.build_collective_reduce(
|
||||
grads, self._num_workers, self._collective_keys, "Add", "Id",
|
||||
communication_hint, control_input_grads)
|
||||
result = []
|
||||
for (_, v), g in zip(grad_and_vars, collective_reduced):
|
||||
result.append([g, v])
|
||||
reduced_gv_list.append(result)
|
||||
# Reverse the batch reduced gradients to (approximately) recover the order
|
||||
# in the input per_replica_values.
|
||||
reduced_gv_list.reverse()
|
||||
return reduced_gv_list
|
||||
control_inputs = None
|
||||
reduced_values.append(
|
||||
cross_device_utils.build_collective_reduce(
|
||||
per_replica.values, self._num_workers,
|
||||
self._collective_keys, "Add", "Id", communication,
|
||||
control_inputs))
|
||||
return reduced_values
|
||||
|
||||
if context.executing_eagerly():
|
||||
batch_fn = def_function.function(batch_fn)
|
||||
|
||||
new_device_grads = [list(x) for x in zip(*batch_fn())]
|
||||
return _ungroup_and_make_mirrored(
|
||||
new_device_grads,
|
||||
per_replica_values[0],
|
||||
reduce_op,
|
||||
num_between_graph_workers=self._num_workers)
|
||||
reduced_values = batch_fn()
|
||||
mirrored = []
|
||||
# Reverse the order of reduced value to recover the order in the input.
|
||||
for value in reversed(reduced_values):
|
||||
if reduce_op == reduce_util.ReduceOp.MEAN:
|
||||
# Assume each worker has the same number of replicas.
|
||||
num_replicas = len(value) * self._num_workers
|
||||
for i, v in enumerate(value):
|
||||
with ops.device(v.device):
|
||||
value[i] = v / num_replicas
|
||||
mirrored.append(value_lib.regroup(value, wrap_class=value_lib.Mirrored))
|
||||
return mirrored
|
||||
|
||||
def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values):
|
||||
"""All-reduce IndexedSlices across all workers in a batch."""
|
||||
@ -1106,48 +1110,31 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
|
||||
# Pass self._communication to the runtime as a communication hint.
|
||||
communication_hint = self._communication.value
|
||||
# For now, we use NCCL only when batch_size > 1 and num_packs is 1.
|
||||
# TODO(b/132575814): Enable NCCL if num_packs > 1.
|
||||
# TODO(b/132575814): Switch to NCCL for all collectives when communication
|
||||
# For now, we use NCCL only when batch_size > 1.
|
||||
# TODO(b/132575814): switch to NCCL for all collectives when communication
|
||||
# is NCCL.
|
||||
if self._communication == CollectiveCommunication.NCCL and (
|
||||
len(per_replica_values) == 1 or self._num_packs != 1):
|
||||
if self._communication == CollectiveCommunication.NCCL and len(
|
||||
per_replica_values) == 1:
|
||||
communication_hint = CollectiveCommunication.AUTO.value
|
||||
|
||||
chunked_gv = self._make_gradient_chunks(per_replica_values, self._num_packs)
|
||||
gathered_values = []
|
||||
with ops.name_scope("allreduce"):
|
||||
for per_replica in per_replica_values:
|
||||
gathered_values.append(
|
||||
cross_device_utils.build_collective_gather_indexed_slices(
|
||||
per_replica.values, self._num_workers, self._collective_keys,
|
||||
communication_hint))
|
||||
|
||||
reduced_gv_list = []
|
||||
for chunk in chunked_gv:
|
||||
# By placing all CollectiveReduce ops in a chunk under single name scope,
|
||||
# we ensure they will be picked up by the `ScopedAllocator` grappler
|
||||
# optimizer and packed into a single all-reduce.
|
||||
with ops.name_scope("allreduce"):
|
||||
for grad_and_vars in chunk:
|
||||
grads = [g for g, _ in grad_and_vars]
|
||||
|
||||
# Add control dependencies per device from the last gradients to the
|
||||
# current set, in order to serialize NCCL launches.
|
||||
if (communication_hint == CollectiveCommunication.NCCL.value and
|
||||
reduced_gv_list):
|
||||
control_input_grads = [g for g, _ in reduced_gv_list[-1]]
|
||||
else:
|
||||
control_input_grads = None
|
||||
|
||||
collective_reduced = (
|
||||
cross_device_utils.build_collective_gather_indexed_slices(
|
||||
grads, self._num_workers, self._collective_keys,
|
||||
communication_hint, control_input_grads))
|
||||
result = []
|
||||
for (_, v), g in zip(grad_and_vars, collective_reduced):
|
||||
result.append([g, v])
|
||||
reduced_gv_list.append(result)
|
||||
|
||||
new_device_grads = [list(x) for x in zip(*reduced_gv_list)]
|
||||
return _ungroup_and_make_mirrored(
|
||||
new_device_grads,
|
||||
per_replica_values[0],
|
||||
reduce_op,
|
||||
num_between_graph_workers=self._num_workers)
|
||||
mirrored = []
|
||||
for value in gathered_values:
|
||||
if reduce_op == reduce_util.ReduceOp.MEAN:
|
||||
# Assume each worker has the same number of replicas.
|
||||
num_replicas = len(value) * self._num_workers
|
||||
for i, v in enumerate(value):
|
||||
with ops.device(v.device):
|
||||
value[i].values = value[i].values / num_replicas
|
||||
mirrored.append(value_lib.regroup(value, wrap_class=value_lib.Mirrored))
|
||||
return mirrored
|
||||
|
||||
|
||||
def choose_the_best(devices, session_config=None):
|
||||
|
@ -24,6 +24,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||
from tensorflow.python.distribute import collective_util
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
|
||||
from tensorflow.python.distribute import cross_device_utils
|
||||
@ -463,8 +464,7 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
|
||||
num_gpus=0,
|
||||
communication=CollectiveCommunication.AUTO,
|
||||
use_strategy_object=False,
|
||||
local_mode=False,
|
||||
num_packs=1):
|
||||
local_mode=False):
|
||||
collective_keys = cross_device_utils.CollectiveKeys(
|
||||
group_key_start=10 + CollectiveAllReduceTest.collective_key_base,
|
||||
op_instance_key_start=100 + CollectiveAllReduceTest.collective_key_base,
|
||||
@ -487,7 +487,6 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
|
||||
1,
|
||||
num_gpus,
|
||||
collective_keys=collective_keys,
|
||||
num_packs=num_packs,
|
||||
communication=communication)
|
||||
return collective_all_reduce_ops, devices, ""
|
||||
else:
|
||||
@ -520,7 +519,6 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
|
||||
NUM_WORKERS,
|
||||
num_gpus,
|
||||
collective_keys=collective_keys,
|
||||
num_packs=num_packs,
|
||||
communication=communication)
|
||||
return (collective_all_reduce_ops, devices,
|
||||
"grpc://" + self._cluster_spec[task_type][task_id])
|
||||
@ -532,15 +530,14 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
|
||||
communication,
|
||||
use_strategy_object=False,
|
||||
local_mode=False,
|
||||
num_packs=1):
|
||||
hints=None):
|
||||
collective_all_reduce, devices, master_target = self._get_test_objects(
|
||||
task_type,
|
||||
task_id,
|
||||
num_gpus,
|
||||
communication=communication,
|
||||
use_strategy_object=use_strategy_object,
|
||||
local_mode=local_mode,
|
||||
num_packs=num_packs)
|
||||
local_mode=local_mode)
|
||||
if local_mode:
|
||||
num_workers = 1
|
||||
worker_device = None
|
||||
@ -553,17 +550,19 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
|
||||
if use_strategy_object:
|
||||
with test_object.scope():
|
||||
return test_object.extended.reduce_to(reduce_op, per_replica,
|
||||
destinations)
|
||||
destinations, hints)
|
||||
else:
|
||||
return test_object.reduce(reduce_op, per_replica, destinations)
|
||||
return test_object.reduce(reduce_op, per_replica, destinations, hints)
|
||||
|
||||
def _batch_reduce(test_object, reduce_op, value_destination_pairs):
|
||||
if use_strategy_object:
|
||||
with test_object.scope():
|
||||
return test_object.extended.batch_reduce_to(reduce_op,
|
||||
value_destination_pairs)
|
||||
value_destination_pairs,
|
||||
hints)
|
||||
else:
|
||||
return test_object.batch_reduce(reduce_op, value_destination_pairs)
|
||||
return test_object.batch_reduce(reduce_op, value_destination_pairs,
|
||||
hints)
|
||||
|
||||
with ops.Graph().as_default(), \
|
||||
ops.device(worker_device), \
|
||||
@ -724,16 +723,17 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
|
||||
mode=["graph"],
|
||||
required_gpus=[0, 1, 2],
|
||||
use_strategy_object=[True, False],
|
||||
num_packs=[1, 2]))
|
||||
bytes_per_pack=[0, 1, 4]))
|
||||
def testReductionDistributed(self, required_gpus, use_strategy_object,
|
||||
num_packs):
|
||||
bytes_per_pack):
|
||||
hints = collective_util.Hints(bytes_per_pack=bytes_per_pack)
|
||||
self._run_between_graph_clients(
|
||||
self._test_reduction,
|
||||
self._cluster_spec,
|
||||
required_gpus,
|
||||
communication=CollectiveCommunication.RING,
|
||||
use_strategy_object=use_strategy_object,
|
||||
num_packs=num_packs)
|
||||
hints=hints)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
|
@ -33,6 +33,7 @@ from tensorflow.python.ops import collective_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nccl_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
|
||||
OP_INSTANCE_KEY_START_NUMBER = 100
|
||||
@ -896,6 +897,67 @@ def stitch_values(values_and_indices_list):
|
||||
return result
|
||||
|
||||
|
||||
def per_replica_num_elements(per_replica):
  """Returns the static number of elements of one replica.

  Args:
    per_replica: A PerReplica of Tensor or IndexedSlices.

  Returns:
    Number of elements. None if some replica has a different or unknown shape.
  """
  values = per_replica._values  # pylint: disable=protected-access
  s0 = values[0].shape
  for v in values:
    assert not isinstance(v, ops.IndexedSlices)
    if v.shape != s0:
      return None
  return s0.num_elements()


def pack_by_size(per_replica_list, bytes_per_pack):
  """Packs `per_replica_list` into chunks of `bytes_per_pack`.

  The method preserves the original order of `per_replica_list`. The packing is
  best effort; each pack could have more or fewer bytes than `bytes_per_pack`.
  It only packs values with known shapes. Note that the usage is different from
  `cross_device_ops._pack_tensors`: this function is intended to work with the
  ScopedAllocator style batching used in `CollectiveAllReduce`.

  Args:
    per_replica_list: A list of PerReplica.
    bytes_per_pack: Bytes per pack.

  Returns:
    A list of packs of PerReplica. All values are packed into one pack if
    `bytes_per_pack` is zero or any of the values has an unknown shape.
  """
  if bytes_per_pack == 0:
    return [per_replica_list]
  packs = []
  last_pack_size = 0
  for value in per_replica_list:
    num_elements = per_replica_num_elements(value)
    if num_elements is None:
      # Can't pack values with unknown shape.
      logging.warning(
          'not packing values due to the unknown or inconsistent shape of %s',
          value)
      return [per_replica_list]
    size = num_elements * value._primary.dtype.size  # pylint: disable=protected-access
    # Try to keep each pack as close to bytes_per_pack as possible, while each
    # pack is at least bytes_per_pack large. I.e. we err on the side of having
    # few but large packs.
    if not packs or last_pack_size > bytes_per_pack:
      packs.append([])
      last_pack_size = 0
    packs[-1].append(value)
    last_pack_size += size
  return packs


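A quick illustration of the packing behavior above (editor's sketch, not part of the diff; sizes assume float32 values and two replicas): packs are filled greedily, and a new pack is started only once the current pack has reached `bytes_per_pack`, so the function errs on the side of fewer, larger packs.

```python
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops

v0 = array_ops.ones([256], dtype=dtypes.float32)  # 1024 bytes per replica
v1 = array_ops.ones([64], dtype=dtypes.float32)   # 256 bytes per replica
v2 = array_ops.ones([512], dtype=dtypes.float32)  # 2048 bytes per replica
per_replica = [value_lib.PerReplica([v, v]) for v in (v0, v1, v2)]

packs = cross_device_utils.pack_by_size(per_replica, bytes_per_pack=1024)
# Pack 0 holds v0 and v1 (v1 is still appended because the pack had not yet
# exceeded 1024 bytes when it was considered); pack 1 holds v2.
assert len(packs) == 2 and len(packs[0]) == 2 and len(packs[1]) == 1
```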
def _control_input(inputs, control_inputs, idx):
|
||||
"""Returns the `idx`-th item in control_inputs to be used in ops.control_dependencies.
|
||||
|
||||
|
@ -26,8 +26,11 @@ from tensorflow.python.distribute import device_util
|
||||
from tensorflow.python.distribute import values as value_lib
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.engine import input_layer
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
|
||||
@ -133,8 +136,86 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.assertIsInstance(result, ops.IndexedSlices)
|
||||
self._assert_values_equal(t, result)
|
||||
self.assertEqual(device_util.resolve(destination),
|
||||
device_util.resolve(result.device))
|
||||
self.assertEqual(
|
||||
device_util.resolve(destination), device_util.resolve(result.device))
|
||||
|
||||
|
||||
class PackBySizeTest(test.TestCase):
|
||||
|
||||
def assertShape(self, per_replica, shape):
|
||||
for v in per_replica._values: # pylint: disable=protected-access
|
||||
self.assertEqual(v.shape, shape)
|
||||
|
||||
def testPreferLargerPack(self):
|
||||
# Each packs except the last one should be equal or larger than
|
||||
# bytes_per_pack.
|
||||
values = [
|
||||
# size = 2 * 4 * 4 * 4 = 128
|
||||
array_ops.ones([2, 4, 4], dtype=dtypes.float32),
|
||||
# size = 8 * 4 = 32
|
||||
array_ops.ones([8], dtype=dtypes.int32),
|
||||
# size = 10 * 10 * 8 = 800
|
||||
array_ops.ones([10, 10], dtype=dtypes.int64),
|
||||
# size = 1 * 4 = 4
|
||||
array_ops.ones([1], dtype=dtypes.int32),
|
||||
]
|
||||
per_replica_values = [value_lib.PerReplica([v, v]) for v in values]
|
||||
packs = cross_device_utils.pack_by_size(
|
||||
per_replica_values, bytes_per_pack=200)
|
||||
self.assertLen(packs, 2)
|
||||
self.assertLen(packs[0], 3)
|
||||
self.assertShape(packs[0][0], [2, 4, 4])
|
||||
self.assertShape(packs[0][1], [8])
|
||||
self.assertShape(packs[0][2], [10, 10])
|
||||
self.assertLen(packs[1], 1)
|
||||
self.assertShape(packs[1][0], [1])
|
||||
|
||||
def testZeroBytesPerPack(self):
|
||||
values = [
|
||||
array_ops.ones([1], dtype=dtypes.float32),
|
||||
array_ops.ones([2], dtype=dtypes.float32),
|
||||
]
|
||||
per_replica_values = [value_lib.PerReplica([v, v]) for v in values]
|
||||
packs = cross_device_utils.pack_by_size(
|
||||
per_replica_values, bytes_per_pack=0)
|
||||
self.assertLen(packs, 1)
|
||||
self.assertLen(packs[0], 2)
|
||||
self.assertShape(packs[0][0], [1])
|
||||
self.assertShape(packs[0][1], [2])
|
||||
|
||||
def testUnknownShape(self):
|
||||
per_replica_values = [
|
||||
value_lib.PerReplica([
|
||||
array_ops.ones([10, 10], dtype=dtypes.float32),
|
||||
array_ops.ones([10, 10], dtype=dtypes.float32),
|
||||
]),
|
||||
value_lib.PerReplica([
|
||||
array_ops.ones([10, 10], dtype=dtypes.float32),
|
||||
input_layer.Input(
|
||||
shape=(10), batch_size=None, dtype=dtypes.float32),
|
||||
]),
|
||||
]
|
||||
packs = cross_device_utils.pack_by_size(
|
||||
per_replica_values, bytes_per_pack=1)
|
||||
self.assertLen(packs, 1)
|
||||
self.assertEqual(packs[0], per_replica_values)
|
||||
|
||||
def testInconsistentShape(self):
|
||||
per_replica_values = [
|
||||
value_lib.PerReplica([
|
||||
array_ops.ones([10, 10], dtype=dtypes.float32),
|
||||
array_ops.ones([10, 10], dtype=dtypes.float32),
|
||||
]),
|
||||
value_lib.PerReplica([
|
||||
array_ops.ones([10, 10], dtype=dtypes.float32),
|
||||
input_layer.Input(
|
||||
shape=(10), batch_size=None, dtype=dtypes.float32),
|
||||
]),
|
||||
]
|
||||
packs = cross_device_utils.pack_by_size(
|
||||
per_replica_values, bytes_per_pack=1)
|
||||
self.assertLen(packs, 1)
|
||||
self.assertEqual(packs[0], per_replica_values)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -108,6 +108,7 @@ import six
|
||||
from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
|
||||
from tensorflow.python.autograph.impl import api as autograph
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import collective_util
|
||||
from tensorflow.python.distribute import device_util
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import numpy_dataset
|
||||
@ -1719,10 +1720,10 @@ class StrategyExtendedV2(object):
|
||||
def _reduce(self, reduce_op, value):
|
||||
# Default implementation until we have an implementation for each strategy.
|
||||
return self._local_results(
|
||||
self._reduce_to(reduce_op, value,
|
||||
device_util.current() or "/device:CPU:0"))[0]
|
||||
self.reduce_to(reduce_op, value,
|
||||
device_util.current() or "/device:CPU:0"))[0]
|
||||
|
||||
def reduce_to(self, reduce_op, value, destinations):
|
||||
def reduce_to(self, reduce_op, value, destinations, experimental_hints=None):
|
||||
"""Combine (via e.g. sum or mean) values across replicas.
|
||||
|
||||
Args:
|
||||
@ -1732,6 +1733,8 @@ class StrategyExtendedV2(object):
|
||||
string. The return value will be copied to all destination devices (or
|
||||
all the devices where the `destinations` value resides). To perform an
|
||||
all-reduction, pass `value` to `destinations`.
|
||||
experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints
|
||||
to perform collective operations.
|
||||
|
||||
Returns:
|
||||
A tensor or value mirrored to `destinations`.
|
||||
@ -1744,18 +1747,25 @@ class StrategyExtendedV2(object):
|
||||
reduce_op = reduce_util.ReduceOp(reduce_op.upper())
|
||||
assert (reduce_op == reduce_util.ReduceOp.SUM or
|
||||
reduce_op == reduce_util.ReduceOp.MEAN)
|
||||
return self._reduce_to(reduce_op, value, destinations)
|
||||
if experimental_hints is None:
|
||||
experimental_hints = collective_util.Hints()
|
||||
return self._reduce_to(reduce_op, value, destinations, experimental_hints)
|
||||
|
||||
def _reduce_to(self, reduce_op, value, destinations):
|
||||
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
|
||||
raise NotImplementedError("must be implemented in descendants")
|
||||
|
||||
def batch_reduce_to(self, reduce_op, value_destination_pairs):
|
||||
def batch_reduce_to(self,
|
||||
reduce_op,
|
||||
value_destination_pairs,
|
||||
experimental_hints=None):
|
||||
"""Combine multiple `reduce_to` calls into one for faster execution.
|
||||
|
||||
Args:
|
||||
reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
|
||||
value_destination_pairs: A sequence of (value, destinations)
|
||||
pairs. See `reduce_to()` for a description.
|
||||
value_destination_pairs: A sequence of (value, destinations) pairs. See
|
||||
`reduce_to()` for a description.
|
||||
experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints
|
||||
to perform collective operations.
|
||||
|
||||
Returns:
|
||||
A list of mirrored values, one per pair in `value_destination_pairs`.
|
||||
@ -1765,11 +1775,16 @@ class StrategyExtendedV2(object):
|
||||
assert not isinstance(reduce_op, variable_scope.VariableAggregation)
|
||||
if isinstance(reduce_op, six.string_types):
|
||||
reduce_op = reduce_util.ReduceOp(reduce_op.upper())
|
||||
return self._batch_reduce_to(reduce_op, value_destination_pairs)
|
||||
if experimental_hints is None:
|
||||
experimental_hints = collective_util.Hints()
|
||||
return self._batch_reduce_to(reduce_op, value_destination_pairs,
|
||||
experimental_hints)
|
||||
|
||||
def _batch_reduce_to(self, reduce_op, value_destination_pairs):
|
||||
def _batch_reduce_to(self, reduce_op, value_destination_pairs,
|
||||
experimental_hints):
|
||||
return [
|
||||
self.reduce_to(reduce_op, t, destinations=v)
|
||||
self.reduce_to(
|
||||
reduce_op, t, destinations=v, experimental_hints=experimental_hints)
|
||||
for t, v in value_destination_pairs
|
||||
]
|
||||
|
||||
@ -2267,7 +2282,7 @@ class ReplicaContext(object):
|
||||
require_replica_context(self)
|
||||
return (device_util.current(),)
|
||||
|
||||
def all_reduce(self, reduce_op, value):
|
||||
def all_reduce(self, reduce_op, value, experimental_hints=None):
|
||||
"""All-reduces the given `value Tensor` nest across replicas.
|
||||
|
||||
If `all_reduce` is called in any replica, it must be called in all replicas.
|
||||
@ -2289,16 +2304,21 @@ class ReplicaContext(object):
|
||||
reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
|
||||
value: The nested structure of `Tensor`s to all-reduce. The structure must
|
||||
be compatible with `tf.nest`.
|
||||
experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints
|
||||
to perform collective operations.
|
||||
|
||||
Returns:
|
||||
A `Tensor` nest with the reduced `value`s from each replica.
|
||||
"""
|
||||
if isinstance(reduce_op, six.string_types):
|
||||
reduce_op = reduce_util.ReduceOp(reduce_op.upper())
|
||||
if experimental_hints is None:
|
||||
experimental_hints = collective_util.Hints()
|
||||
|
||||
def batch_all_reduce(strategy, *value_flat):
|
||||
return strategy.extended.batch_reduce_to(
|
||||
reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat])
|
||||
reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat],
|
||||
experimental_hints)
|
||||
|
||||
if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]:
|
||||
# TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.
|
||||
@ -2449,9 +2469,9 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
|
||||
replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
def _reduce_to(self, reduce_op, value, destinations):
|
||||
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
|
||||
# TODO(josh11b): Use destinations?
|
||||
del reduce_op, destinations
|
||||
del reduce_op, destinations, experimental_hints
|
||||
return value
|
||||
|
||||
def _update(self, var, fn, args, kwargs, group):
|
||||
|
@ -95,8 +95,8 @@ class _TestExtended(distribute_lib.StrategyExtendedV1):
|
||||
def _local_results(self, value):
|
||||
return (value,)
|
||||
|
||||
def _reduce_to(self, reduce_op, value, destinations):
|
||||
del reduce_op, destinations
|
||||
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
|
||||
del reduce_op, destinations, experimental_hints
|
||||
return value
|
||||
|
||||
def _experimental_make_numpy_dataset(self, numpy_input, session):
|
||||
|
@ -788,7 +788,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
||||
def _get_cross_device_ops(self):
|
||||
return self._cross_device_ops or self._inferred_cross_device_ops
|
||||
|
||||
def _reduce_to(self, reduce_op, value, destinations):
|
||||
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
|
||||
if (isinstance(value, values.Mirrored) and
|
||||
reduce_op == reduce_util.ReduceOp.MEAN):
|
||||
return value
|
||||
@ -801,11 +801,16 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
||||
return cross_device_ops_lib.reduce_non_distributed_value(
|
||||
reduce_op, value, destinations, self._num_replicas_in_sync)
|
||||
return self._get_cross_device_ops().reduce(
|
||||
reduce_op, value, destinations=destinations)
|
||||
reduce_op,
|
||||
value,
|
||||
destinations=destinations,
|
||||
experimental_hints=experimental_hints)
|
||||
|
||||
def _batch_reduce_to(self, reduce_op, value_destination_pairs):
|
||||
def _batch_reduce_to(self, reduce_op, value_destination_pairs,
|
||||
experimental_hints):
|
||||
return self._get_cross_device_ops().batch_reduce(reduce_op,
|
||||
value_destination_pairs)
|
||||
value_destination_pairs,
|
||||
experimental_hints)
|
||||
|
||||
def _update(self, var, fn, args, kwargs, group):
|
||||
# TODO(josh11b): In eager mode, use one thread per device.
|
||||
|
@ -356,8 +356,8 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
|
||||
with ops.device(self._device), _OneDeviceReplicaContext(strategy):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
def _reduce_to(self, reduce_op, value, destinations):
|
||||
del reduce_op, destinations
|
||||
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
|
||||
del reduce_op, destinations, experimental_hints
|
||||
return value
|
||||
|
||||
def _update(self, var, fn, args, kwargs, group):
|
||||
|
@ -466,20 +466,25 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
|
||||
"Cannot reduce to another worker: %r, current worker is %r" %
|
||||
(d, self._input_workers.worker_devices[0]))
|
||||
|
||||
def _reduce_to(self, reduce_op, value, destinations):
|
||||
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
|
||||
self._verify_destinations_not_different_worker(destinations)
|
||||
if not isinstance(value, values.DistributedValues):
|
||||
# pylint: disable=protected-access
|
||||
return cross_device_ops_lib.reduce_non_distributed_value(
|
||||
reduce_op, value, destinations, self._num_replicas_in_sync)
|
||||
return self._cross_device_ops.reduce(
|
||||
reduce_op, value, destinations=destinations)
|
||||
reduce_op,
|
||||
value,
|
||||
destinations=destinations,
|
||||
experimental_hints=experimental_hints)
|
||||
|
||||
def _batch_reduce_to(self, reduce_op, value_destination_pairs):
|
||||
def _batch_reduce_to(self, reduce_op, value_destination_pairs,
|
||||
experimental_hints):
|
||||
for _, destinations in value_destination_pairs:
|
||||
self._verify_destinations_not_different_worker(destinations)
|
||||
return self._cross_device_ops.batch_reduce(reduce_op,
|
||||
value_destination_pairs)
|
||||
value_destination_pairs,
|
||||
experimental_hints)
|
||||
|
||||
def _select_single_value(self, structured):
|
||||
"""Select any single value in `structured`."""
|
||||
|
@ -659,7 +659,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
tpu_values.TPUSyncOnReadVariable,
|
||||
**kwargs)
|
||||
|
||||
def _reduce_to(self, reduce_op, value, destinations):
|
||||
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
|
||||
if (isinstance(value, values.DistributedValues) or
|
||||
tensor_util.is_tensor(value)
|
||||
) and tpu_values.enclosing_tpu_context() is not None:
|
||||
|
@ -8,11 +8,11 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "batch_reduce"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "batch_reduce_implementation"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "broadcast"
|
||||
@ -24,10 +24,10 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "reduce"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "reduce_implementation"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
@ -10,11 +10,11 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "batch_reduce"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "batch_reduce_implementation"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "broadcast"
|
||||
@ -26,10 +26,10 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "reduce"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "reduce_implementation"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
@ -10,11 +10,11 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "batch_reduce"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "batch_reduce_implementation"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "broadcast"
|
||||
@ -26,10 +26,10 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "reduce"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "reduce_implementation"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
@@ -9,11 +9,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -25,10 +25,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }

@@ -24,7 +24,7 @@ tf_class {
   }
   member_method {
     name: "all_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "merge_call"

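For context, a minimal sketch of calling ReplicaContext.all_reduce with the new optional experimental_hints argument recorded in the argspec above. The strategy choice, the step function, and the 50 MB pack size are illustrative assumptions, not part of this change.

    import tensorflow as tf

    # Assumed setup; the hints mainly matter for collective-based strategies.
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    # Illustrative pack size; bytes_per_pack=0 (the default) leaves batches unpacked.
    hints = tf.distribute.experimental.CollectiveHints(bytes_per_pack=50 * 1024 * 1024)

    @tf.function
    def train_step(value):
      def step_fn(v):
        ctx = tf.distribute.get_replica_context()
        # experimental_hints defaults to None, so existing call sites keep working.
        return ctx.all_reduce(tf.distribute.ReduceOp.SUM, v, experimental_hints=hints)
      return strategy.experimental_run_v2(step_fn, args=(value,))

    train_step(tf.constant(1.0))
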
@@ -37,7 +37,7 @@ tf_class {
   }
   member_method {
     name: "batch_reduce_to"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "broadcast_to"
@@ -69,7 +69,7 @@ tf_class {
   }
   member_method {
     name: "reduce_to"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "update"

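Similarly, a hedged sketch of the cross-replica side, where StrategyExtended.reduce_to and batch_reduce_to now accept experimental_hints. The MirroredStrategy setup, the variable, and the per-replica value below are assumptions for illustration only.

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()  # assumed setup
    hints = tf.distribute.experimental.CollectiveHints(bytes_per_pack=32 * 1024 * 1024)

    with strategy.scope():
      v = tf.Variable(0.0)

    per_replica = strategy.experimental_run_v2(lambda: tf.constant(3.0))
    # Both methods take the hints as a trailing keyword argument defaulting to None.
    total = strategy.extended.reduce_to(
        tf.distribute.ReduceOp.SUM, per_replica, destinations=v,
        experimental_hints=hints)
    totals = strategy.extended.batch_reduce_to(
        tf.distribute.ReduceOp.SUM, [(per_replica, v)],
        experimental_hints=hints)
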
@@ -0,0 +1,9 @@
+path: "tensorflow.distribute.experimental.CollectiveHints"
+tf_class {
+  is_instance: "<class \'tensorflow.python.distribute.collective_util.Hints\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'bytes_per_pack\'], varargs=None, keywords=None, defaults=[\'0\'], "
+  }
+}

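The new golden file above pins the public constructor: CollectiveHints takes a single bytes_per_pack argument defaulting to 0. A minimal construction sketch; the 50 MB value is illustrative, not a recommendation.

    import tensorflow as tf

    # bytes_per_pack=0 (the default) keeps current behaviour; a positive value
    # is a hint to split large all-reduce batches into packs of roughly that
    # many bytes.
    hints = tf.distribute.experimental.CollectiveHints(bytes_per_pack=50 * 1024 * 1024)
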
@@ -8,6 +8,10 @@ tf_module {
     name: "CollectiveCommunication"
     mtype: "<class \'enum.EnumMeta\'>"
   }
   member {
+    name: "CollectiveHints"
+    mtype: "<type \'type\'>"
+  }
+  member {
     name: "MultiWorkerMirroredStrategy"
     mtype: "<type \'type\'>"

@@ -8,11 +8,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -24,10 +24,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }

@@ -10,11 +10,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -26,10 +26,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }

@@ -10,11 +10,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -26,10 +26,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }

@@ -9,11 +9,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -25,10 +25,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }

@@ -24,7 +24,7 @@ tf_class {
   }
   member_method {
     name: "all_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "merge_call"

@@ -20,7 +20,7 @@ tf_class {
   }
   member_method {
     name: "batch_reduce_to"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "colocate_vars_with"
@@ -32,7 +32,7 @@ tf_class {
   }
   member_method {
     name: "reduce_to"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "update"

@@ -0,0 +1,9 @@
+path: "tensorflow.distribute.experimental.CollectiveHints"
+tf_class {
+  is_instance: "<class \'tensorflow.python.distribute.collective_util.Hints\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'bytes_per_pack\'], varargs=None, keywords=None, defaults=[\'0\'], "
+  }
+}

@@ -8,6 +8,10 @@ tf_module {
     name: "CollectiveCommunication"
     mtype: "<class \'enum.EnumMeta\'>"
   }
   member {
+    name: "CollectiveHints"
+    mtype: "<type \'type\'>"
+  }
+  member {
     name: "MultiWorkerMirroredStrategy"
    mtype: "<type \'type\'>"