Add experimental_hints to batch all reduce

This argument carries all performance hints for the API. Currently there's only bytes_per_pack, which splits large batches into multiple packs, allowing communication to overlap with computation.

Currently we can only pack if all Tensors in the batch have known shapes.
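
For example, following the `CollectiveHints` docstring added in this change, the hint is constructed once and passed to `all_reduce` inside a replica context (`grads`, `vars` and `optimizer` below are placeholders):

```python
# Split the batched gradient all-reduce into ~50MB packs so that weight
# updates and the remaining backprop can overlap with communication.
hints = tf.distribute.experimental.CollectiveHints(
    bytes_per_pack=50 * 1024 * 1024)
grads = tf.distribute.get_replica_context().all_reduce(
    'sum', grads, experimental_hints=hints)
optimizer.apply_gradients(zip(grads, vars),
                          all_reduce_sum_gradients=False)
```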

PiperOrigin-RevId: 297269428
Change-Id: Iaf7d7d3adf7c6cad59aa6079fbcd36b31e92c4b5
Ran Chen 2020-02-25 20:30:00 -08:00 committed by TensorFlower Gardener
parent 26a24de29b
commit fa08cfd489
29 changed files with 504 additions and 236 deletions

View File

@ -63,6 +63,7 @@ py_library(
srcs = ["cross_device_ops.py"], srcs = ["cross_device_ops.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":collective_util",
":cross_device_utils", ":cross_device_utils",
":device_util", ":device_util",
":reduce_util", ":reduce_util",
@ -97,6 +98,7 @@ py_library(
"//tensorflow/python:gradients", "//tensorflow/python:gradients",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:nccl_ops", "//tensorflow/python:nccl_ops",
"//tensorflow/python:platform",
], ],
) )
@ -145,6 +147,7 @@ py_library(
], ],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":collective_util",
":device_util", ":device_util",
":numpy_dataset", ":numpy_dataset",
":reduce_util", ":reduce_util",
@ -580,6 +583,15 @@ py_library(
], ],
) )
py_library(
name = "collective_util",
srcs = ["collective_util.py"],
deps = [
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
],
)
py_library( py_library(
name = "shared_variable_creator", name = "shared_variable_creator",
srcs = ["shared_variable_creator.py"], srcs = ["shared_variable_creator.py"],
@ -795,7 +807,9 @@ cuda_py_test(
name = "cross_device_utils_test", name = "cross_device_utils_test",
srcs = ["cross_device_utils_test.py"], srcs = ["cross_device_utils_test.py"],
deps = [ deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op", "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:combinations",
@ -815,6 +829,7 @@ cuda_py_test(
], ],
deps = [ deps = [
":collective_all_reduce_strategy", ":collective_all_reduce_strategy",
":collective_util",
":mirrored_strategy", ":mirrored_strategy",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:constant_op", "//tensorflow/python:constant_op",

View File

@ -95,6 +95,7 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
TFConfigClusterResolver which is instantiated from the TF_CONFIG env TFConfigClusterResolver which is instantiated from the TF_CONFIG env
var. var.
""" """
# TODO(b/150151677): consider move communication to CollectiveHints.
super(CollectiveAllReduceStrategy, self).__init__( super(CollectiveAllReduceStrategy, self).__init__(
CollectiveAllReduceExtended( CollectiveAllReduceExtended(
self, self,
@ -505,7 +506,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
return updated_config return updated_config
def _reduce_to(self, reduce_op, value, destinations): def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
if (isinstance(value, values.Mirrored) and if (isinstance(value, values.Mirrored) and
reduce_op == reduce_util.ReduceOp.MEAN): reduce_op == reduce_util.ReduceOp.MEAN):
return value return value
@ -526,7 +527,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
return cross_device_ops_lib.reduce_non_distributed_value( return cross_device_ops_lib.reduce_non_distributed_value(
reduce_op, value, destinations, len(self.worker_devices)) reduce_op, value, destinations, len(self.worker_devices))
return self._get_cross_device_ops().reduce( return self._get_cross_device_ops().reduce(
reduce_op, value, destinations=destinations) reduce_op,
value,
destinations=destinations,
experimental_hints=experimental_hints)
def _warn_nccl_no_gpu(self): def _warn_nccl_no_gpu(self):
if ((self._communication == if ((self._communication ==

View File

@ -0,0 +1,63 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for collectives."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.util.tf_export import tf_export
@tf_export("distribute.experimental.CollectiveHints")
class Hints(object):
"""Hints for collective operations like AllReduce.
This can be passed to methods like
`tf.distribute.get_replica_context().all_reduce()` to optimize collective
operation performance. Note that these are only hints, which may or may not
change the actual behavior. Some options only apply to certain strategies and
are ignored by others.
One common optimization is to break the gradient all-reduce into multiple packs
so that weight updates can overlap with gradient all-reduce.
Example:
```python
hints = tf.distribute.experimental.CollectiveHints(
bytes_per_pack=50 * 1024 * 1024)
grads = tf.distribute.get_replica_context().all_reduce(
'sum', grads, experimental_hints=hints)
optimizer.apply_gradients(zip(grads, vars), all_reduce_sum_gradients=False)
```
"""
def __init__(self, bytes_per_pack=0):
"""Creates a CollectiveHints.
Args:
bytes_per_pack: A non-negative integer. Breaks collective operations into
packs of a certain size. If it's zero, the value is determined
automatically. This only applies to all-reduce with
`MultiWorkerMirroredStrategy` currently.
Raises:
ValueError: When arguments have invalid value.
"""
if bytes_per_pack < 0:
raise ValueError("bytes_per_pack must be non-negative")
self.bytes_per_pack = bytes_per_pack
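
The same hints object is accepted by the cross-replica entry points extended in this change, for example `StrategyExtended.batch_reduce_to`. A minimal sketch, where `strategy` and `grads_and_vars` are placeholders:

```python
hints = tf.distribute.experimental.CollectiveHints(
    bytes_per_pack=32 * 1024 * 1024)
# Batch all gradient all-reduces into a single call; each gradient is
# reduced to the devices of its corresponding variable.
reduced = strategy.extended.batch_reduce_to(
    tf.distribute.ReduceOp.SUM,
    [(g, v) for g, v in grads_and_vars],
    experimental_hints=hints)
```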

View File

@ -19,11 +19,12 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections import collections
import enum
import enum
import six import six
from tensorflow.python.client import device_lib from tensorflow.python.client import device_lib
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import reduce_util
@ -222,7 +223,11 @@ class CrossDeviceOps(object):
# Returns 1 by default, the value may be overridden by sub classes. # Returns 1 by default, the value may be overridden by sub classes.
return 1 return 1
def reduce(self, reduce_op, per_replica_value, destinations): def reduce(self,
reduce_op,
per_replica_value,
destinations,
experimental_hints=None):
"""Reduce `per_replica_value` to `destinations`. """Reduce `per_replica_value` to `destinations`.
It runs the reduction operation defined by `reduce_op` and put the It runs the reduction operation defined by `reduce_op` and put the
@ -231,8 +236,10 @@ class CrossDeviceOps(object):
Args: Args:
reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
per_replica_value will be reduced. per_replica_value will be reduced.
per_replica_value: a PerReplica object or a tensor with device set. per_replica_value: A PerReplica object or a tensor with device set.
destinations: the reduction destinations. destinations: the reduction destinations.
experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
to perform collective operations.
Returns: Returns:
a Mirrored object. a Mirrored object.
@ -254,10 +261,15 @@ class CrossDeviceOps(object):
per_replica_value.values, per_replica_value.values,
wrap_class=value_lib.Mirrored) wrap_class=value_lib.Mirrored)
if experimental_hints is None:
experimental_hints = collective_util.Hints()
return self.reduce_implementation(reduce_op, per_replica_value, return self.reduce_implementation(reduce_op, per_replica_value,
destinations) destinations, experimental_hints)
def batch_reduce(self, reduce_op, value_destination_pairs): def batch_reduce(self,
reduce_op,
value_destination_pairs,
experimental_hints=None):
"""Reduce PerReplica objects in a batch. """Reduce PerReplica objects in a batch.
Reduce each first element in `value_destination_pairs` to each second Reduce each first element in `value_destination_pairs` to each second
@ -267,10 +279,12 @@ class CrossDeviceOps(object):
fuse several tensors into one or multiple packs before reduction. fuse several tensors into one or multiple packs before reduction.
Args: Args:
reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how the
the `per_replica_value` will be reduced. `per_replica_value` will be reduced.
value_destination_pairs: a list or a tuple of PerReplica objects value_destination_pairs: A list or a tuple of PerReplica objects (or
(or tensors with device set if there is one device) and destinations. tensors with device set if there is one device) and destinations.
experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
to perform collective operations.
Returns: Returns:
a list of Mirrored objects. a list of Mirrored objects.
@ -299,7 +313,10 @@ class CrossDeviceOps(object):
for v, _ in value_destination_pairs for v, _ in value_destination_pairs
] ]
return self.batch_reduce_implementation(reduce_op, value_destination_pairs) if experimental_hints is None:
experimental_hints = collective_util.Hints()
return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
experimental_hints)
def broadcast(self, tensor, destinations): def broadcast(self, tensor, destinations):
"""Broadcast the `tensor` to destinations. """Broadcast the `tensor` to destinations.
@ -315,7 +332,8 @@ class CrossDeviceOps(object):
return self.broadcast_implementation(tensor, destinations) return self.broadcast_implementation(tensor, destinations)
@doc_controls.for_subclass_implementers @doc_controls.for_subclass_implementers
def reduce_implementation(self, reduce_op, per_replica_value, destinations): def reduce_implementation(self, reduce_op, per_replica_value, destinations,
experimental_hints):
"""The implementation of reduce of `per_replica_value` to `destinations`. """The implementation of reduce of `per_replica_value` to `destinations`.
Overriding this method is useful for subclass implementers. Overriding this method is useful for subclass implementers.
@ -326,8 +344,10 @@ class CrossDeviceOps(object):
Args: Args:
reduce_op: An instance `tf.distribute.ReduceOp` that indicates of how reduce_op: An instance `tf.distribute.ReduceOp` that indicates of how
per_replica_value will be reduced. per_replica_value will be reduced.
per_replica_value: a PerReplica object or a tensor with device set. per_replica_value: A PerReplica object or a tensor with device set.
destinations: the reduction destinations. destinations: the reduction destinations.
experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
to perform collective operations.
Returns: Returns:
a Mirrored object. a Mirrored object.
@ -340,7 +360,8 @@ class CrossDeviceOps(object):
"_reduce method must be implemented in descendants.") "_reduce method must be implemented in descendants.")
@doc_controls.for_subclass_implementers @doc_controls.for_subclass_implementers
def batch_reduce_implementation(self, reduce_op, value_destination_pairs): def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
experimental_hints):
"""Implementation of reduce PerReplica objects in a batch. """Implementation of reduce PerReplica objects in a batch.
Overriding this method is useful for subclass implementers. Overriding this method is useful for subclass implementers.
@ -351,8 +372,10 @@ class CrossDeviceOps(object):
Args: Args:
reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
per_replica_value will be reduced. per_replica_value will be reduced.
value_destination_pairs: an iterable of tuples of PerReplica objects value_destination_pairs: An iterable of tuples of PerReplica objects
(or tensors with device set if there is one device) and destinations. (or tensors with device set if there is one device) and destinations.
experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
to perform collective operations.
Returns: Returns:
a list of Mirrored objects. a list of Mirrored objects.
@ -362,7 +385,8 @@ class CrossDeviceOps(object):
tuples of PerReplica objects and destinations tuples of PerReplica objects and destinations
""" """
raise NotImplementedError( raise NotImplementedError(
"_batch_reduce method must be implemented in descendants.") "batch_reduce_implementation method must be implemented in descendants."
)
@doc_controls.for_subclass_implementers @doc_controls.for_subclass_implementers
def broadcast_implementation(self, tensor, destinations): def broadcast_implementation(self, tensor, destinations):
@ -403,7 +427,9 @@ class ReductionToOneDevice(CrossDeviceOps):
self.accumulation_fn = accumulation_fn or math_ops.add_n self.accumulation_fn = accumulation_fn or math_ops.add_n
super(ReductionToOneDevice, self).__init__() super(ReductionToOneDevice, self).__init__()
def reduce_implementation(self, reduce_op, per_replica_value, destinations): def reduce_implementation(self, reduce_op, per_replica_value, destinations,
experimental_hints):
del experimental_hints # Unused.
if check_destinations(destinations): if check_destinations(destinations):
devices = get_devices_from(destinations) devices = get_devices_from(destinations)
else: else:
@ -416,9 +442,11 @@ class ReductionToOneDevice(CrossDeviceOps):
self.accumulation_fn, reduce_op) self.accumulation_fn, reduce_op)
return self.broadcast(reduced, destinations) return self.broadcast(reduced, destinations)
def batch_reduce_implementation(self, reduce_op, value_destination_pairs): def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
experimental_hints):
return [ return [
self.reduce_implementation(reduce_op, t, destinations=v) self.reduce_implementation(
reduce_op, t, destinations=v, experimental_hints=experimental_hints)
for t, v in value_destination_pairs for t, v in value_destination_pairs
] ]
@ -626,21 +654,24 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
self._simple_cross_replica_ops = ReductionToOneDevice() self._simple_cross_replica_ops = ReductionToOneDevice()
super(AllReduceCrossDeviceOps, self).__init__() super(AllReduceCrossDeviceOps, self).__init__()
def reduce_implementation(self, reduce_op, per_replica_value, destinations): def reduce_implementation(self, reduce_op, per_replica_value, destinations,
experimental_hints):
del experimental_hints # Unused.
if _devices_match(per_replica_value, destinations): if _devices_match(per_replica_value, destinations):
return self._batch_all_reduce(reduce_op, [per_replica_value])[0] return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
else: else:
return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value, return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
destinations) destinations)
def batch_reduce_implementation(self, reduce_op, value_destination_pairs): def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
experimental_hints):
if _all_devices_match(value_destination_pairs): if _all_devices_match(value_destination_pairs):
return self._batch_all_reduce(reduce_op, return self._batch_all_reduce(reduce_op,
[v[0] for v in value_destination_pairs]) [v[0] for v in value_destination_pairs])
else: else:
return [ return [
self.reduce_implementation(reduce_op, t, destinations=v) self.reduce_implementation(reduce_op, value, dest, experimental_hints)
for t, v in value_destination_pairs for value, dest in value_destination_pairs
] ]
def _batch_all_reduce(self, reduce_op, per_replica_values): def _batch_all_reduce(self, reduce_op, per_replica_values):
@ -904,7 +935,6 @@ class CollectiveAllReduce(CrossDeviceOps):
def __init__(self, def __init__(self,
num_workers=1, num_workers=1,
num_gpus_per_worker=0, num_gpus_per_worker=0,
num_packs=1,
collective_keys=None, collective_keys=None,
communication=CollectiveCommunication.AUTO): communication=CollectiveCommunication.AUTO):
"""Initializes the object. """Initializes the object.
@ -912,13 +942,11 @@ class CollectiveAllReduce(CrossDeviceOps):
Args: Args:
num_workers: number of workers in the between-graph replicated training. num_workers: number of workers in the between-graph replicated training.
num_gpus_per_worker: number of GPUs per worker. num_gpus_per_worker: number of GPUs per worker.
num_packs: gradients will be packed into `num_packs` chunks.
collective_keys: an optional CollectiveKey object. collective_keys: an optional CollectiveKey object.
communication: indicates which collective communication to use. communication: indicates which collective communication to use.
""" """
self._num_workers = num_workers self._num_workers = num_workers
self._num_gpus_per_worker = num_gpus_per_worker self._num_gpus_per_worker = num_gpus_per_worker
self._num_packs = num_packs
self._collective_keys = (collective_keys or self._collective_keys = (collective_keys or
cross_device_utils.CollectiveKeys()) cross_device_utils.CollectiveKeys())
self._communication = communication self._communication = communication
@ -928,8 +956,10 @@ class CollectiveAllReduce(CrossDeviceOps):
def _num_between_graph_workers(self): def _num_between_graph_workers(self):
return self._num_workers return self._num_workers
def reduce_implementation(self, reduce_op, per_replica_value, destinations): def reduce_implementation(self, reduce_op, per_replica_value, destinations,
all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value])[0] experimental_hints):
all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value],
experimental_hints)[0]
devices = get_devices_from(destinations) devices = get_devices_from(destinations)
if (isinstance(all_reduced, value_lib.Mirrored) and if (isinstance(all_reduced, value_lib.Mirrored) and
@ -958,11 +988,13 @@ class CollectiveAllReduce(CrossDeviceOps):
index.append(array_ops.identity(all_reduced._primary)) # pylint: disable=protected-access index.append(array_ops.identity(all_reduced._primary)) # pylint: disable=protected-access
return value_lib.regroup(index, wrap_class=value_lib.Mirrored) return value_lib.regroup(index, wrap_class=value_lib.Mirrored)
def batch_reduce_implementation(self, reduce_op, value_destination_pairs): def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
experimental_hints):
all_devices_match = _all_devices_match(value_destination_pairs) all_devices_match = _all_devices_match(value_destination_pairs)
if all_devices_match: if all_devices_match:
return self._batch_all_reduce(reduce_op, return self._batch_all_reduce(reduce_op,
[v[0] for v in value_destination_pairs]) [v[0] for v in value_destination_pairs],
experimental_hints)
else: else:
if not all_devices_match: if not all_devices_match:
logging.log_first_n( logging.log_first_n(
@ -970,47 +1002,18 @@ class CollectiveAllReduce(CrossDeviceOps):
"destinations are different.", 10) "destinations are different.", 10)
return [ return [
self.reduce_implementation(reduce_op, t, destinations=v) self.reduce_implementation(reduce_op, value, dest, experimental_hints)
for t, v in value_destination_pairs for value, dest in value_destination_pairs
] ]
def _make_gradient_chunks(self, per_replica_values, num_packs): def _batch_all_reduce(self, reduce_op, per_replica_values,
"""Make `per_replica_values` into chunks.""" experimental_hints):
chunked_by_device = _group_value_by_device(per_replica_values)
chunked_by_var = list(zip(*chunked_by_device))
# chunked_by_var is chunked by variables and takes the following format:
# [((grad0_gpu0, v0_gpu0), (grad0_gpu1, v0_gpu1), (grad0_gpu2, v0_gpu2) ..),
# ((grad1_gpu0, v1_gpu0), (grad1_gpu1, v1_gpu1), (grad1_gpu0, v1_gpu2) ..),
# ((grad2_gpu0, v2_gpu0), (grad2_gpu1, v2_gpu1), (grad2_gpu0, v2_gpu2) ..),
# ...
# ]
# No chunking if number of variables is fewer than number of packs.
if len(chunked_by_var) < num_packs:
return [chunked_by_var]
# First n-1 chunks get `chunk_size` grads, last chunk gets leftover grads.
# This strategy can cause the last chunk to have larger size compared to the
# first n-1 chunks. Alternatively, we can increment chunk_size by 1 to get
# slightly larger first n-1 chunks and smaller last chunk.
# TODO(ayushd): compare different packing strategies.
chunk_size = len(chunked_by_var) // num_packs
leftover_size = len(chunked_by_var) - chunk_size * (num_packs - 1)
assert leftover_size > 0
chunked_gv = [
chunked_by_var[x:x + chunk_size]
for x in range(0, len(chunked_by_var) - leftover_size, chunk_size)
]
chunked_gv.append(chunked_by_var[-leftover_size:])
return chunked_gv
def _batch_all_reduce(self, reduce_op, per_replica_values):
"""All reduce algorithm in a batch.""" """All reduce algorithm in a batch."""
dense_values, dense_indices, sparse_values, sparse_indices = ( dense_values, dense_indices, sparse_values, sparse_indices = (
cross_device_utils.split_by_sparsity(per_replica_values)) cross_device_utils.split_by_sparsity(per_replica_values))
if dense_values: if dense_values:
dense_results = self._do_batch_all_reduce_dense(reduce_op, dense_values) dense_results = self._do_batch_all_reduce_dense(reduce_op, dense_values,
experimental_hints)
else: else:
dense_results = [] dense_results = []
if sparse_values: if sparse_values:
@ -1018,83 +1021,84 @@ class CollectiveAllReduce(CrossDeviceOps):
sparse_values) sparse_values)
else: else:
sparse_results = [] sparse_results = []
return cross_device_utils.stitch_values(((dense_results, dense_indices), return cross_device_utils.stitch_values(
(sparse_results, sparse_indices))) ((dense_results, dense_indices), (sparse_results, sparse_indices)))
def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values): def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values,
experimental_hints):
"""All-reduce across all workers in a batch.""" """All-reduce across all workers in a batch."""
chunked_gv = self._make_gradient_chunks(per_replica_values, self._num_packs)
# Actual number of packs may be different from `self._num_packs`. e.g. if
# there are fewer tensors than `self._num_packs`.
num_actual_packs = len(chunked_gv)
batch_size = len(per_replica_values) batch_size = len(per_replica_values)
# Pass self._communication to the runtime as a communication hint. # Pass self._communication to the runtime as a communication hint.
communication_hint = self._communication.value communication = self._communication.value
# For now, we use NCCL only when batch_size > 1. # For now, we use NCCL only when batch_size > 1.
# TODO(b/132575814): switch to NCCL for all collectives when communication # TODO(b/132575814): switch to NCCL for all collectives when communication
# is NCCL. # is NCCL.
if self._communication == CollectiveCommunication.NCCL and batch_size == 1: if self._communication == CollectiveCommunication.NCCL and batch_size == 1:
communication_hint = CollectiveCommunication.AUTO.value communication = CollectiveCommunication.AUTO.value
# Reverse the lists so that there's better chance that values follows
# the order in which they are calculated (e.g. when they're gradients), so
# as to overlap calculation with communication. However, this may not be
# optimal for cases like gradients of complicated non-sequential models.
#
# Note that we reverse the list before packing so that the first pack won't
# be too small, since it's more likely for first few packs to have long
# queuing time due to concurrent intense computation.
#
# TODO(b/147393503): explore solutions for optimal ordering.
packs = cross_device_utils.pack_by_size(
list(reversed(per_replica_values)), experimental_hints.bytes_per_pack)
if batch_size > 1: if batch_size > 1:
logging.info( logging.info(
"Collective batch_all_reduce: %d all-reduces, num_workers = %d, " "Collective batch_all_reduce: %d all-reduces, num_workers = %d, "
"communication_hint = %s, num_packs = %d" % "communication_hint = %s, num_packs = %d", batch_size,
(batch_size, self._num_workers, communication_hint, num_actual_packs)) self._num_workers, communication, len(packs))
else: else:
logging.log_first_n( logging.log_first_n(
logging.INFO, "Collective batch_all_reduce: %d all-reduces, " logging.INFO, "Collective batch_all_reduce: %d all-reduces, "
"num_workers = %d, communication_hint = %s, num_packs = %d" % "num_workers = %d, communication_hint = %s, num_packs = %d" %
(batch_size, self._num_workers, communication_hint, num_actual_packs), (batch_size, self._num_workers, communication, len(packs)), 10)
10)
def batch_fn(): def batch_fn():
"""Wrapper function around batched all-reduce calls.""" """Wrapper function around batched all-reduce calls."""
reduced_gv_list = [] reduced_values = []
# Reverse the gradient lists so that the gradient grouping roughly follows for pack in packs:
# the order in which gradients are calculated in backprop. This should # By placing all CollectiveReduce ops in a pack under single name scope,
# enable overlapping gradient all-reduce with backprop for most models. # we ensure they will be picked up by the `ScopedAllocator` grappler
# However, it is likely that for some complicated non-sequential models # optimizer and packed into a single all-reduce.
# this grouping is not optimal.
#
# TODO(b/147393503): explore solutions for optimal gradient grouping.
for chunk in reversed(chunked_gv):
# By placing all CollectiveReduce ops in a chunk under single name
# scope, we ensure they will be picked up by the `ScopedAllocator`
# grappler optimizer and packed into a single all-reduce.
with ops.name_scope("allreduce"): with ops.name_scope("allreduce"):
for grad_and_vars in reversed(chunk): for per_replica in pack:
# Gradients for the same variable but from different devices.
grads = [g for g, _ in grad_and_vars]
# Add control dependencies per device from the last gradients to the # Add control dependencies per device from the last gradients to the
# current set, in order to serialize NCCL launches. # current set, in order to serialize NCCL launches.
if (communication_hint == CollectiveCommunication.NCCL.value and if (communication == CollectiveCommunication.NCCL.value and
reduced_gv_list): reduced_values):
control_input_grads = [g for g, _ in reduced_gv_list[-1]] control_inputs = [g for g in reduced_values[-1]]
else: else:
control_input_grads = None control_inputs = None
collective_reduced = cross_device_utils.build_collective_reduce( reduced_values.append(
grads, self._num_workers, self._collective_keys, "Add", "Id", cross_device_utils.build_collective_reduce(
communication_hint, control_input_grads) per_replica.values, self._num_workers,
result = [] self._collective_keys, "Add", "Id", communication,
for (_, v), g in zip(grad_and_vars, collective_reduced): control_inputs))
result.append([g, v]) return reduced_values
reduced_gv_list.append(result)
# Reverse the batch reduced gradients to (approximately) recover the order
# in the input per_replica_values.
reduced_gv_list.reverse()
return reduced_gv_list
if context.executing_eagerly(): if context.executing_eagerly():
batch_fn = def_function.function(batch_fn) batch_fn = def_function.function(batch_fn)
new_device_grads = [list(x) for x in zip(*batch_fn())] reduced_values = batch_fn()
return _ungroup_and_make_mirrored( mirrored = []
new_device_grads, # Reverse the order of reduced value to recover the order in the input.
per_replica_values[0], for value in reversed(reduced_values):
reduce_op, if reduce_op == reduce_util.ReduceOp.MEAN:
num_between_graph_workers=self._num_workers) # Assume each worker has the same number of replicas.
num_replicas = len(value) * self._num_workers
for i, v in enumerate(value):
with ops.device(v.device):
value[i] = v / num_replicas
mirrored.append(value_lib.regroup(value, wrap_class=value_lib.Mirrored))
return mirrored
def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values): def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values):
"""All-reduce IndexedSlices across all workers in a batch.""" """All-reduce IndexedSlices across all workers in a batch."""
@ -1106,48 +1110,31 @@ class CollectiveAllReduce(CrossDeviceOps):
# Pass self._communication to the runtime as a communication hint. # Pass self._communication to the runtime as a communication hint.
communication_hint = self._communication.value communication_hint = self._communication.value
# For now, we use NCCL only when batch_size > 1 and num_packs is 1. # For now, we use NCCL only when batch_size > 1.
# TODO(b/132575814): Enable NCCL if num_packs > 1. # TODO(b/132575814): switch to NCCL for all collectives when communication
# TODO(b/132575814): Switch to NCCL for all collectives when communication
# is NCCL. # is NCCL.
if self._communication == CollectiveCommunication.NCCL and ( if self._communication == CollectiveCommunication.NCCL and len(
len(per_replica_values) == 1 or self._num_packs != 1): per_replica_values) == 1:
communication_hint = CollectiveCommunication.AUTO.value communication_hint = CollectiveCommunication.AUTO.value
chunked_gv = self._make_gradient_chunks(per_replica_values, self._num_packs) gathered_values = []
with ops.name_scope("allreduce"):
for per_replica in per_replica_values:
gathered_values.append(
cross_device_utils.build_collective_gather_indexed_slices(
per_replica.values, self._num_workers, self._collective_keys,
communication_hint))
reduced_gv_list = [] mirrored = []
for chunk in chunked_gv: for value in gathered_values:
# By placing all CollectiveReduce ops in a chunk under single name scope, if reduce_op == reduce_util.ReduceOp.MEAN:
# we ensure they will be picked up by the `ScopedAllocator` grappler # Assume each worker has the same number of replicas.
# optimizer and packed into a single all-reduce. num_replicas = len(value) * self._num_workers
with ops.name_scope("allreduce"): for i, v in enumerate(value):
for grad_and_vars in chunk: with ops.device(v.device):
grads = [g for g, _ in grad_and_vars] value[i].values = value[i].values / num_replicas
mirrored.append(value_lib.regroup(value, wrap_class=value_lib.Mirrored))
# Add control dependencies per device from the last gradients to the return mirrored
# current set, in order to serialize NCCL launches.
if (communication_hint == CollectiveCommunication.NCCL.value and
reduced_gv_list):
control_input_grads = [g for g, _ in reduced_gv_list[-1]]
else:
control_input_grads = None
collective_reduced = (
cross_device_utils.build_collective_gather_indexed_slices(
grads, self._num_workers, self._collective_keys,
communication_hint, control_input_grads))
result = []
for (_, v), g in zip(grad_and_vars, collective_reduced):
result.append([g, v])
reduced_gv_list.append(result)
new_device_grads = [list(x) for x in zip(*reduced_gv_list)]
return _ungroup_and_make_mirrored(
new_device_grads,
per_replica_values[0],
reduce_op,
num_between_graph_workers=self._num_workers)
def choose_the_best(devices, session_config=None): def choose_the_best(devices, session_config=None):

View File

@ -24,6 +24,7 @@ from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import cross_device_utils
@ -463,8 +464,7 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
num_gpus=0, num_gpus=0,
communication=CollectiveCommunication.AUTO, communication=CollectiveCommunication.AUTO,
use_strategy_object=False, use_strategy_object=False,
local_mode=False, local_mode=False):
num_packs=1):
collective_keys = cross_device_utils.CollectiveKeys( collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=10 + CollectiveAllReduceTest.collective_key_base, group_key_start=10 + CollectiveAllReduceTest.collective_key_base,
op_instance_key_start=100 + CollectiveAllReduceTest.collective_key_base, op_instance_key_start=100 + CollectiveAllReduceTest.collective_key_base,
@ -487,7 +487,6 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
1, 1,
num_gpus, num_gpus,
collective_keys=collective_keys, collective_keys=collective_keys,
num_packs=num_packs,
communication=communication) communication=communication)
return collective_all_reduce_ops, devices, "" return collective_all_reduce_ops, devices, ""
else: else:
@ -520,7 +519,6 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
NUM_WORKERS, NUM_WORKERS,
num_gpus, num_gpus,
collective_keys=collective_keys, collective_keys=collective_keys,
num_packs=num_packs,
communication=communication) communication=communication)
return (collective_all_reduce_ops, devices, return (collective_all_reduce_ops, devices,
"grpc://" + self._cluster_spec[task_type][task_id]) "grpc://" + self._cluster_spec[task_type][task_id])
@ -532,15 +530,14 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
communication, communication,
use_strategy_object=False, use_strategy_object=False,
local_mode=False, local_mode=False,
num_packs=1): hints=None):
collective_all_reduce, devices, master_target = self._get_test_objects( collective_all_reduce, devices, master_target = self._get_test_objects(
task_type, task_type,
task_id, task_id,
num_gpus, num_gpus,
communication=communication, communication=communication,
use_strategy_object=use_strategy_object, use_strategy_object=use_strategy_object,
local_mode=local_mode, local_mode=local_mode)
num_packs=num_packs)
if local_mode: if local_mode:
num_workers = 1 num_workers = 1
worker_device = None worker_device = None
@ -553,17 +550,19 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
if use_strategy_object: if use_strategy_object:
with test_object.scope(): with test_object.scope():
return test_object.extended.reduce_to(reduce_op, per_replica, return test_object.extended.reduce_to(reduce_op, per_replica,
destinations) destinations, hints)
else: else:
return test_object.reduce(reduce_op, per_replica, destinations) return test_object.reduce(reduce_op, per_replica, destinations, hints)
def _batch_reduce(test_object, reduce_op, value_destination_pairs): def _batch_reduce(test_object, reduce_op, value_destination_pairs):
if use_strategy_object: if use_strategy_object:
with test_object.scope(): with test_object.scope():
return test_object.extended.batch_reduce_to(reduce_op, return test_object.extended.batch_reduce_to(reduce_op,
value_destination_pairs) value_destination_pairs,
hints)
else: else:
return test_object.batch_reduce(reduce_op, value_destination_pairs) return test_object.batch_reduce(reduce_op, value_destination_pairs,
hints)
with ops.Graph().as_default(), \ with ops.Graph().as_default(), \
ops.device(worker_device), \ ops.device(worker_device), \
@ -724,16 +723,17 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
mode=["graph"], mode=["graph"],
required_gpus=[0, 1, 2], required_gpus=[0, 1, 2],
use_strategy_object=[True, False], use_strategy_object=[True, False],
num_packs=[1, 2])) bytes_per_pack=[0, 1, 4]))
def testReductionDistributed(self, required_gpus, use_strategy_object, def testReductionDistributed(self, required_gpus, use_strategy_object,
num_packs): bytes_per_pack):
hints = collective_util.Hints(bytes_per_pack=bytes_per_pack)
self._run_between_graph_clients( self._run_between_graph_clients(
self._test_reduction, self._test_reduction,
self._cluster_spec, self._cluster_spec,
required_gpus, required_gpus,
communication=CollectiveCommunication.RING, communication=CollectiveCommunication.RING,
use_strategy_object=use_strategy_object, use_strategy_object=use_strategy_object,
num_packs=num_packs) hints=hints)
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(

View File

@ -33,6 +33,7 @@ from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops from tensorflow.python.ops import nccl_ops
from tensorflow.python.platform import tf_logging as logging
OP_INSTANCE_KEY_START_NUMBER = 100 OP_INSTANCE_KEY_START_NUMBER = 100
@ -896,6 +897,67 @@ def stitch_values(values_and_indices_list):
return result return result
def per_replica_num_elements(per_replica):
"""Returns the static number of elements of one replica.
Args:
per_replica: A PerReplica of Tensor or IndexedSlices.
Returns:
Number of elements. None if some replica has a different or unknown shape.
"""
values = per_replica._values # pylint: disable=protected-access
s0 = values[0].shape
for v in values:
assert not isinstance(v, ops.IndexedSlices)
if v.shape != s0:
return None
return s0.num_elements()
def pack_by_size(per_replica_list, bytes_per_pack):
"""Packs `per_replica_list` into chunks of `bytes_per_pack`.
The method preserves the original order of `per_replica_list`. The packing is
best effort; each pack could have more or fewer bytes than `bytes_per_pack`.
It only packs values with known shapes. Note that the usage is different from
`cross_device_ops._pack_tensors`: this function is intended to work with the
ScopedAllocator-style batching used in `CollectiveAllReduce`.
Args:
per_replica_list: A list of PerReplica.
bytes_per_pack: Bytes per pack.
Returns:
A list of packs of PerReplica. All values are packed into one pack if
`bytes_per_pack` is zero or any of the value has unknown shape.
"""
if bytes_per_pack == 0:
return [per_replica_list]
packs = []
last_pack_size = 0
for value in per_replica_list:
num_elements = per_replica_num_elements(value)
if num_elements is None:
# Can't pack values with unknown shape.
logging.warning(
'not packing values due to the unknown or inconsistent shape of %s',
value)
return [per_replica_list]
size = num_elements * value._primary.dtype.size # pylint: disable=protected-access
# Try to keep each pack as close to bytes_per_pack as possible, while each
# pack is at least bytes_per_pack large. I.E. we err on the side of having
# few but large packs.
if not packs or last_pack_size > bytes_per_pack:
packs.append([])
last_pack_size = 0
packs[-1].append(value)
last_pack_size += size
return packs
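
To make the heuristic concrete, here is a standalone sketch of the same greedy packing applied to plain byte sizes (hypothetical numbers, mirroring the loop above rather than calling the TF API):

```python
def pack_sizes(sizes, bytes_per_pack):
  """Greedy packing: a new pack starts only once the current pack already
  exceeds bytes_per_pack, so packs err on the side of being large."""
  if bytes_per_pack == 0:
    return [sizes]
  packs, last_pack_size = [], 0
  for size in sizes:
    if not packs or last_pack_size > bytes_per_pack:
      packs.append([])
      last_pack_size = 0
    packs[-1].append(size)
    last_pack_size += size
  return packs

# With a 200-byte budget, sizes [128, 32, 800, 4] become [[128, 32, 800], [4]]:
# the 800-byte value still joins the first pack because that pack held only
# 160 bytes when the 800-byte value was considered.
print(pack_sizes([128, 32, 800, 4], bytes_per_pack=200))
```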
def _control_input(inputs, control_inputs, idx): def _control_input(inputs, control_inputs, idx):
"""Returns the `idx`-th item in control_inputs to be used in ops.control_dependencies. """Returns the `idx`-th item in control_inputs to be used in ops.control_dependencies.

View File

@ -26,8 +26,11 @@ from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import values as value_lib from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import test from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import input_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
@ -133,8 +136,86 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
self.assertIsInstance(result, ops.IndexedSlices) self.assertIsInstance(result, ops.IndexedSlices)
self._assert_values_equal(t, result) self._assert_values_equal(t, result)
self.assertEqual(device_util.resolve(destination), self.assertEqual(
device_util.resolve(result.device)) device_util.resolve(destination), device_util.resolve(result.device))
class PackBySizeTest(test.TestCase):
def assertShape(self, per_replica, shape):
for v in per_replica._values: # pylint: disable=protected-access
self.assertEqual(v.shape, shape)
def testPreferLargerPack(self):
# Each packs except the last one should be equal or larger than
# bytes_per_pack.
values = [
# size = 2 * 4 * 4 * 4 = 128
array_ops.ones([2, 4, 4], dtype=dtypes.float32),
# size = 8 * 4 = 32
array_ops.ones([8], dtype=dtypes.int32),
# size = 10 * 10 * 8 = 800
array_ops.ones([10, 10], dtype=dtypes.int64),
# size = 1 * 4 = 4
array_ops.ones([1], dtype=dtypes.int32),
]
per_replica_values = [value_lib.PerReplica([v, v]) for v in values]
packs = cross_device_utils.pack_by_size(
per_replica_values, bytes_per_pack=200)
self.assertLen(packs, 2)
self.assertLen(packs[0], 3)
self.assertShape(packs[0][0], [2, 4, 4])
self.assertShape(packs[0][1], [8])
self.assertShape(packs[0][2], [10, 10])
self.assertLen(packs[1], 1)
self.assertShape(packs[1][0], [1])
def testZeroBytesPerPack(self):
values = [
array_ops.ones([1], dtype=dtypes.float32),
array_ops.ones([2], dtype=dtypes.float32),
]
per_replica_values = [value_lib.PerReplica([v, v]) for v in values]
packs = cross_device_utils.pack_by_size(
per_replica_values, bytes_per_pack=0)
self.assertLen(packs, 1)
self.assertLen(packs[0], 2)
self.assertShape(packs[0][0], [1])
self.assertShape(packs[0][1], [2])
def testUnknownShape(self):
per_replica_values = [
value_lib.PerReplica([
array_ops.ones([10, 10], dtype=dtypes.float32),
array_ops.ones([10, 10], dtype=dtypes.float32),
]),
value_lib.PerReplica([
array_ops.ones([10, 10], dtype=dtypes.float32),
input_layer.Input(
shape=(10), batch_size=None, dtype=dtypes.float32),
]),
]
packs = cross_device_utils.pack_by_size(
per_replica_values, bytes_per_pack=1)
self.assertLen(packs, 1)
self.assertEqual(packs[0], per_replica_values)
def testInconsistentShape(self):
per_replica_values = [
value_lib.PerReplica([
array_ops.ones([10, 10], dtype=dtypes.float32),
array_ops.ones([10, 10], dtype=dtypes.float32),
]),
value_lib.PerReplica([
array_ops.ones([10, 10], dtype=dtypes.float32),
input_layer.Input(
shape=(10), batch_size=None, dtype=dtypes.float32),
]),
]
packs = cross_device_utils.pack_by_size(
per_replica_values, bytes_per_pack=1)
self.assertLen(packs, 1)
self.assertEqual(packs[0], per_replica_values)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -108,6 +108,7 @@ import six
from tensorflow.python.autograph.core import ag_ctx as autograph_ctx from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import device_util from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import numpy_dataset
@ -1719,10 +1720,10 @@ class StrategyExtendedV2(object):
def _reduce(self, reduce_op, value): def _reduce(self, reduce_op, value):
# Default implementation until we have an implementation for each strategy. # Default implementation until we have an implementation for each strategy.
return self._local_results( return self._local_results(
self._reduce_to(reduce_op, value, self.reduce_to(reduce_op, value,
device_util.current() or "/device:CPU:0"))[0] device_util.current() or "/device:CPU:0"))[0]
def reduce_to(self, reduce_op, value, destinations): def reduce_to(self, reduce_op, value, destinations, experimental_hints=None):
"""Combine (via e.g. sum or mean) values across replicas. """Combine (via e.g. sum or mean) values across replicas.
Args: Args:
@ -1732,6 +1733,8 @@ class StrategyExtendedV2(object):
string. The return value will be copied to all destination devices (or string. The return value will be copied to all destination devices (or
all the devices where the `destinations` value resides). To perform an all the devices where the `destinations` value resides). To perform an
all-reduction, pass `value` to `destinations`. all-reduction, pass `value` to `destinations`.
experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
to perform collective operations.
Returns: Returns:
A tensor or value mirrored to `destinations`. A tensor or value mirrored to `destinations`.
@ -1744,18 +1747,25 @@ class StrategyExtendedV2(object):
reduce_op = reduce_util.ReduceOp(reduce_op.upper()) reduce_op = reduce_util.ReduceOp(reduce_op.upper())
assert (reduce_op == reduce_util.ReduceOp.SUM or assert (reduce_op == reduce_util.ReduceOp.SUM or
reduce_op == reduce_util.ReduceOp.MEAN) reduce_op == reduce_util.ReduceOp.MEAN)
return self._reduce_to(reduce_op, value, destinations) if experimental_hints is None:
experimental_hints = collective_util.Hints()
return self._reduce_to(reduce_op, value, destinations, experimental_hints)
def _reduce_to(self, reduce_op, value, destinations): def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
raise NotImplementedError("must be implemented in descendants") raise NotImplementedError("must be implemented in descendants")
def batch_reduce_to(self, reduce_op, value_destination_pairs): def batch_reduce_to(self,
reduce_op,
value_destination_pairs,
experimental_hints=None):
"""Combine multiple `reduce_to` calls into one for faster execution. """Combine multiple `reduce_to` calls into one for faster execution.
Args: Args:
reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum. reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
value_destination_pairs: A sequence of (value, destinations) value_destination_pairs: A sequence of (value, destinations) pairs. See
pairs. See `reduce_to()` for a description. `reduce_to()` for a description.
experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
to perform collective operations.
Returns: Returns:
A list of mirrored values, one per pair in `value_destination_pairs`. A list of mirrored values, one per pair in `value_destination_pairs`.
@ -1765,11 +1775,16 @@ class StrategyExtendedV2(object):
assert not isinstance(reduce_op, variable_scope.VariableAggregation) assert not isinstance(reduce_op, variable_scope.VariableAggregation)
if isinstance(reduce_op, six.string_types): if isinstance(reduce_op, six.string_types):
reduce_op = reduce_util.ReduceOp(reduce_op.upper()) reduce_op = reduce_util.ReduceOp(reduce_op.upper())
return self._batch_reduce_to(reduce_op, value_destination_pairs) if experimental_hints is None:
experimental_hints = collective_util.Hints()
return self._batch_reduce_to(reduce_op, value_destination_pairs,
experimental_hints)
def _batch_reduce_to(self, reduce_op, value_destination_pairs): def _batch_reduce_to(self, reduce_op, value_destination_pairs,
experimental_hints):
return [ return [
self.reduce_to(reduce_op, t, destinations=v) self.reduce_to(
reduce_op, t, destinations=v, experimental_hints=experimental_hints)
for t, v in value_destination_pairs for t, v in value_destination_pairs
] ]
@ -2267,7 +2282,7 @@ class ReplicaContext(object):
require_replica_context(self) require_replica_context(self)
return (device_util.current(),) return (device_util.current(),)
def all_reduce(self, reduce_op, value): def all_reduce(self, reduce_op, value, experimental_hints=None):
"""All-reduces the given `value Tensor` nest across replicas. """All-reduces the given `value Tensor` nest across replicas.
If `all_reduce` is called in any replica, it must be called in all replicas. If `all_reduce` is called in any replica, it must be called in all replicas.
@ -2289,16 +2304,21 @@ class ReplicaContext(object):
reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum. reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
value: The nested structure of `Tensor`s to all-reduce. The structure must value: The nested structure of `Tensor`s to all-reduce. The structure must
be compatible with `tf.nest`. be compatible with `tf.nest`.
experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
to perform collective operations.
Returns: Returns:
A `Tensor` nest with the reduced `value`s from each replica. A `Tensor` nest with the reduced `value`s from each replica.
""" """
if isinstance(reduce_op, six.string_types): if isinstance(reduce_op, six.string_types):
reduce_op = reduce_util.ReduceOp(reduce_op.upper()) reduce_op = reduce_util.ReduceOp(reduce_op.upper())
if experimental_hints is None:
experimental_hints = collective_util.Hints()
def batch_all_reduce(strategy, *value_flat): def batch_all_reduce(strategy, *value_flat):
return strategy.extended.batch_reduce_to( return strategy.extended.batch_reduce_to(
reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat]) reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat],
experimental_hints)
if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]: if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]:
# TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad. # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.
@ -2449,9 +2469,9 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)): replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
return fn(*args, **kwargs) return fn(*args, **kwargs)
def _reduce_to(self, reduce_op, value, destinations): def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
# TODO(josh11b): Use destinations? # TODO(josh11b): Use destinations?
del reduce_op, destinations del reduce_op, destinations, experimental_hints
return value return value
def _update(self, var, fn, args, kwargs, group): def _update(self, var, fn, args, kwargs, group):

View File

@ -95,8 +95,8 @@ class _TestExtended(distribute_lib.StrategyExtendedV1):
def _local_results(self, value): def _local_results(self, value):
return (value,) return (value,)
def _reduce_to(self, reduce_op, value, destinations): def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
del reduce_op, destinations del reduce_op, destinations, experimental_hints
return value return value
def _experimental_make_numpy_dataset(self, numpy_input, session): def _experimental_make_numpy_dataset(self, numpy_input, session):

View File

@ -788,7 +788,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
def _get_cross_device_ops(self): def _get_cross_device_ops(self):
return self._cross_device_ops or self._inferred_cross_device_ops return self._cross_device_ops or self._inferred_cross_device_ops
def _reduce_to(self, reduce_op, value, destinations): def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
if (isinstance(value, values.Mirrored) and if (isinstance(value, values.Mirrored) and
reduce_op == reduce_util.ReduceOp.MEAN): reduce_op == reduce_util.ReduceOp.MEAN):
return value return value
@ -801,11 +801,16 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
return cross_device_ops_lib.reduce_non_distributed_value( return cross_device_ops_lib.reduce_non_distributed_value(
reduce_op, value, destinations, self._num_replicas_in_sync) reduce_op, value, destinations, self._num_replicas_in_sync)
return self._get_cross_device_ops().reduce( return self._get_cross_device_ops().reduce(
reduce_op, value, destinations=destinations) reduce_op,
value,
destinations=destinations,
experimental_hints=experimental_hints)
def _batch_reduce_to(self, reduce_op, value_destination_pairs): def _batch_reduce_to(self, reduce_op, value_destination_pairs,
experimental_hints):
return self._get_cross_device_ops().batch_reduce(reduce_op, return self._get_cross_device_ops().batch_reduce(reduce_op,
value_destination_pairs) value_destination_pairs,
experimental_hints)
def _update(self, var, fn, args, kwargs, group): def _update(self, var, fn, args, kwargs, group):
# TODO(josh11b): In eager mode, use one thread per device. # TODO(josh11b): In eager mode, use one thread per device.

View File

@ -356,8 +356,8 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
with ops.device(self._device), _OneDeviceReplicaContext(strategy): with ops.device(self._device), _OneDeviceReplicaContext(strategy):
return fn(*args, **kwargs) return fn(*args, **kwargs)
def _reduce_to(self, reduce_op, value, destinations): def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
del reduce_op, destinations del reduce_op, destinations, experimental_hints
return value return value
def _update(self, var, fn, args, kwargs, group): def _update(self, var, fn, args, kwargs, group):

View File

@ -466,20 +466,25 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
"Cannot reduce to another worker: %r, current worker is %r" % "Cannot reduce to another worker: %r, current worker is %r" %
(d, self._input_workers.worker_devices[0])) (d, self._input_workers.worker_devices[0]))
def _reduce_to(self, reduce_op, value, destinations): def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
self._verify_destinations_not_different_worker(destinations) self._verify_destinations_not_different_worker(destinations)
if not isinstance(value, values.DistributedValues): if not isinstance(value, values.DistributedValues):
# pylint: disable=protected-access # pylint: disable=protected-access
return cross_device_ops_lib.reduce_non_distributed_value( return cross_device_ops_lib.reduce_non_distributed_value(
reduce_op, value, destinations, self._num_replicas_in_sync) reduce_op, value, destinations, self._num_replicas_in_sync)
return self._cross_device_ops.reduce( return self._cross_device_ops.reduce(
reduce_op, value, destinations=destinations) reduce_op,
value,
destinations=destinations,
experimental_hints=experimental_hints)
def _batch_reduce_to(self, reduce_op, value_destination_pairs): def _batch_reduce_to(self, reduce_op, value_destination_pairs,
experimental_hints):
for _, destinations in value_destination_pairs: for _, destinations in value_destination_pairs:
self._verify_destinations_not_different_worker(destinations) self._verify_destinations_not_different_worker(destinations)
return self._cross_device_ops.batch_reduce(reduce_op, return self._cross_device_ops.batch_reduce(reduce_op,
value_destination_pairs) value_destination_pairs,
experimental_hints)
def _select_single_value(self, structured): def _select_single_value(self, structured):
"""Select any single value in `structured`.""" """Select any single value in `structured`."""

View File

@ -659,7 +659,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
tpu_values.TPUSyncOnReadVariable, tpu_values.TPUSyncOnReadVariable,
**kwargs) **kwargs)
def _reduce_to(self, reduce_op, value, destinations): def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
if (isinstance(value, values.DistributedValues) or if (isinstance(value, values.DistributedValues) or
tensor_util.is_tensor(value) tensor_util.is_tensor(value)
) and tpu_values.enclosing_tpu_context() is not None: ) and tpu_values.enclosing_tpu_context() is not None:

View File

@@ -8,11 +8,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -24,10 +24,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
@@ -10,11 +10,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -26,10 +26,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
@@ -10,11 +10,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -26,10 +26,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
@@ -9,11 +9,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -25,10 +25,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
@@ -24,7 +24,7 @@ tf_class {
   }
   member_method {
     name: "all_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "merge_call"
@@ -37,7 +37,7 @@ tf_class {
   }
   member_method {
     name: "batch_reduce_to"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "broadcast_to"
@@ -69,7 +69,7 @@ tf_class {
   }
   member_method {
     name: "reduce_to"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "update"
@@ -0,0 +1,9 @@
+path: "tensorflow.distribute.experimental.CollectiveHints"
+tf_class {
+  is_instance: "<class \'tensorflow.python.distribute.collective_util.Hints\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'bytes_per_pack\'], varargs=None, keywords=None, defaults=[\'0\'], "
+  }
+}
@@ -8,6 +8,10 @@ tf_module {
     name: "CollectiveCommunication"
     mtype: "<class \'enum.EnumMeta\'>"
   }
+  member {
+    name: "CollectiveHints"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "MultiWorkerMirroredStrategy"
     mtype: "<type \'type\'>"
@@ -8,11 +8,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -24,10 +24,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
@@ -10,11 +10,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -26,10 +26,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
  }
 }
@@ -10,11 +10,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -26,10 +26,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
@@ -9,11 +9,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -25,10 +25,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
@@ -24,7 +24,7 @@ tf_class {
   }
   member_method {
     name: "all_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "merge_call"
@@ -20,7 +20,7 @@ tf_class {
   }
   member_method {
     name: "batch_reduce_to"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "colocate_vars_with"
@@ -32,7 +32,7 @@ tf_class {
   }
   member_method {
     name: "reduce_to"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "update"
@@ -0,0 +1,9 @@
+path: "tensorflow.distribute.experimental.CollectiveHints"
+tf_class {
+  is_instance: "<class \'tensorflow.python.distribute.collective_util.Hints\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'bytes_per_pack\'], varargs=None, keywords=None, defaults=[\'0\'], "
+  }
+}
@@ -8,6 +8,10 @@ tf_module {
     name: "CollectiveCommunication"
     mtype: "<class \'enum.EnumMeta\'>"
   }
+  member {
+    name: "CollectiveHints"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "MultiWorkerMirroredStrategy"
     mtype: "<type \'type\'>"