Refactor collective utils to operate on a single replica

We used to have build_collective_* helpers that launch collectives on all replicas. However, each replica is independent, so the utils become simpler if they only launch collectives on one replica.

This is mainly to prepare for switching to collective V2 ops and to reduce the use of merge_call in MirroredStrategy/MultiWorkerMirroredStrategy. The former becomes easier after this change; for the latter, ReplicaContext.all_reduce can call CollectiveReplicaLauncher.all_reduce directly, without going through merge_call in MS/MWMS.
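
To illustrate the shape of that change, here is a minimal, framework-free sketch of a per-replica launcher; the class and field names are illustrative assumptions, not the real TensorFlow implementation:

import itertools

class ReplicaLauncherSketch:
    """Sketch of a launcher that issues collectives for a single replica."""

    def __init__(self, group_key, group_size, device):
        self._group_key = group_key
        self._group_size = group_size
        self._device = device
        # The real launcher gets instance keys from CollectiveKeys; a plain
        # counter is enough to show the bookkeeping here.
        self._instance_keys = itertools.count(100)

    def all_reduce(self, value):
        # The real method wraps collective_ops.all_reduce placed on
        # self._device; here we only return the arguments it would pass.
        return {
            "op": "CollectiveReduce",
            "device": self._device,
            "group_key": self._group_key,
            "group_size": self._group_size,
            "instance_key": next(self._instance_keys),
            "input": value,
        }

# One launcher per replica; each replica launches only its own collective,
# so no cross-replica merge_call is needed.
launchers = [ReplicaLauncherSketch(group_key=1, group_size=2, device=d)
             for d in ("/device:GPU:0", "/device:GPU:1")]
print([launcher.all_reduce(3.0) for launcher in launchers])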

Note that the original CollectiveKeys design is outdated: instance keys are now scoped under each group key, and are no longer global.
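
As a rough sketch of that scoping (assuming only what the diff below shows), each group key owns a per-device table of instance keys instead of a single global counter:

import threading

class CollectiveKeysSketch:
    """Hands out group keys and per-(group, device) instance keys."""

    def __init__(self, group_key_start=1, instance_key_start=100):
        self._next_group_key = group_key_start
        self._instance_key_start = instance_key_start
        self._group_key_table = {}      # sorted device tuple -> group key
        self._instance_key_table = {}   # group key -> {device: next key}
        self._lock = threading.Lock()

    def get_group_key(self, devices):
        key_id = tuple(sorted(devices))
        with self._lock:
            if key_id not in self._group_key_table:
                group_key = self._next_group_key
                self._next_group_key += 1
                self._group_key_table[key_id] = group_key
                self._instance_key_table[group_key] = {
                    d: self._instance_key_start for d in devices}
            return self._group_key_table[key_id]

    def get_instance_key(self, group_key, device):
        with self._lock:
            group = self._instance_key_table.get(group_key)
            if group is None or device not in group:
                raise ValueError("unknown group %r / device %r" % (group_key, device))
            instance_key = group[device]
            group[device] += 1
            return instance_key

# Usage: keys from different groups no longer interfere with each other.
keys = CollectiveKeysSketch()
gk = keys.get_group_key(["/device:GPU:0", "/device:GPU:1"])
assert keys.get_instance_key(gk, "/device:GPU:0") == 100
assert keys.get_instance_key(gk, "/device:GPU:0") == 101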

PiperOrigin-RevId: 337007588
Change-Id: I111efad0ecbe829260c8e4d2f555c6b2dc1d4c05
Ran Chen 2020-10-13 19:57:19 -07:00 committed by TensorFlower Gardener
parent c451bbccd9
commit 746af76e4f
9 changed files with 321 additions and 476 deletions

@ -96,13 +96,14 @@ py_library(
":values", ":values",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:collective_ops", "//tensorflow/python:collective_ops",
"//tensorflow/python:device", "//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes", "//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:nccl_ops", "//tensorflow/python:nccl_ops",
"//tensorflow/python:platform", "//tensorflow/python:platform",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
], ],
) )

@ -101,6 +101,9 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
# TODO(anjalisridhar): Update our guides with examples showing how we can use # TODO(anjalisridhar): Update our guides with examples showing how we can use
# the cluster_resolver argument. # the cluster_resolver argument.
# The starting number for collective keys. This should only be set in tests.
_collective_key_base = 0
def __init__( def __init__(
self, self,
communication=cross_device_ops_lib.CollectiveCommunication.AUTO, communication=cross_device_ops_lib.CollectiveCommunication.AUTO,
@ -362,7 +365,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
else: else:
local_devices = (self._worker_device,) local_devices = (self._worker_device,)
self._collective_keys = cross_device_utils.CollectiveKeys() self._collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=1 + CollectiveAllReduceStrategy._collective_key_base) # pylint: disable=protected-access
self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
devices=local_devices, devices=local_devices,
group_size=len(local_devices) * self._num_workers, group_size=len(local_devices) * self._num_workers,
@ -428,7 +432,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
group_key = self._collective_keys.get_group_key([device]) group_key = self._collective_keys.get_group_key([device])
group_size = self._num_workers group_size = self._num_workers
collective_instance_key = ( collective_instance_key = (
self._collective_keys.get_variable_instance_key()) self._collective_keys.get_instance_key(group_key, device))
with ops.device(device): with ops.device(device):
initial_value = kwargs["initial_value"] initial_value = kwargs["initial_value"]

@ -29,7 +29,6 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib
from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import distribution_strategy_context
@ -59,6 +58,8 @@ from tensorflow.python.platform import test
from tensorflow.python.training.server_lib import ClusterSpec from tensorflow.python.training.server_lib import ClusterSpec
CollectiveAllReduceStrategy = (
collective_all_reduce_strategy.CollectiveAllReduceStrategy)
CollectiveAllReduceExtended = ( CollectiveAllReduceExtended = (
collective_all_reduce_strategy.CollectiveAllReduceExtended) collective_all_reduce_strategy.CollectiveAllReduceExtended)
@ -93,14 +94,10 @@ def create_test_objects(cluster_spec=None,
class CollectiveAllReduceStrategyTestBase( class CollectiveAllReduceStrategyTestBase(
multi_worker_test_base.MultiWorkerTestBase): multi_worker_test_base.MultiWorkerTestBase):
collective_key_base = 0
def setUp(self): def setUp(self):
# We use a different key_base for each test so that collective keys won't be # We use a different key_base for each test so that collective keys won't be
# reused. # reused.
# TODO(yuefengz, ayushd): enable it to reuse collective keys in different CollectiveAllReduceStrategy._collective_key_base += 100000
# tests.
CollectiveAllReduceStrategyTestBase.collective_key_base += 100000
super(CollectiveAllReduceStrategyTestBase, self).setUp() super(CollectiveAllReduceStrategyTestBase, self).setUp()
def _get_test_object(self, task_type, task_id, num_gpus=0): def _get_test_object(self, task_type, task_id, num_gpus=0):
@ -109,18 +106,6 @@ class CollectiveAllReduceStrategyTestBase(
task_type=task_type, task_type=task_type,
task_id=task_id, task_id=task_id,
num_gpus=num_gpus) num_gpus=num_gpus)
collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=10 +
CollectiveAllReduceStrategyTestBase.collective_key_base,
op_instance_key_start=100 +
CollectiveAllReduceStrategyTestBase.collective_key_base,
variable_instance_key_start=10000 +
CollectiveAllReduceStrategyTestBase.collective_key_base)
strategy.extended._collective_keys = collective_keys
strategy.extended._cross_device_ops._collective_keys = collective_keys
strategy.extended._host_cross_device_ops._collective_keys = collective_keys
return strategy, target, session_config return strategy, target, session_config
def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):

@ -36,7 +36,7 @@ from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as value_lib from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.eager import executor from tensorflow.python.eager import executor as executor_lib
from tensorflow.python.framework import kernels from tensorflow.python.framework import kernels
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
@ -144,8 +144,8 @@ def _normalize_value_destination_pairs(value_destination_pairs):
def _validate_value_destination_pairs(value_destination_pairs): def _validate_value_destination_pairs(value_destination_pairs):
"""Validates value_destination_pairs are valid."""
# TODO(yuefengz): raise exceptions instead of returning False. # TODO(yuefengz): raise exceptions instead of returning False.
# pylint: disable=g-missing-docstring
if not value_destination_pairs: return False if not value_destination_pairs: return False
if not isinstance(value_destination_pairs, (list, tuple)): return False if not isinstance(value_destination_pairs, (list, tuple)): return False
if not all(isinstance(pair, tuple) for pair in value_destination_pairs): if not all(isinstance(pair, tuple) for pair in value_destination_pairs):
@ -197,7 +197,7 @@ def simple_broadcast(value, destinations, always_mirrored=False):
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn, def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
reduce_op): reduce_op):
# pylint: disable=g-missing-docstring """Reduces the value by accumulation_fn and reduce_op."""
all_values = per_replica_value.values all_values = per_replica_value.values
if not all_values: if not all_values:
raise ValueError("`per_replica_value` must be non-empty") raise ValueError("`per_replica_value` must be non-empty")
@ -1019,7 +1019,6 @@ class CollectiveAllReduce(CrossDeviceOps):
if group_size % len(devices) > 0: if group_size % len(devices) > 0:
raise ValueError("group_size must be divisible by the number of devices.") raise ValueError("group_size must be divisible by the number of devices.")
self._devices = tuple(device_util.canonicalize(d) for d in devices)
self._group_size = group_size self._group_size = group_size
self._collective_keys = (collective_keys or self._collective_keys = (collective_keys or
cross_device_utils.CollectiveKeys()) cross_device_utils.CollectiveKeys())
@ -1039,14 +1038,21 @@ class CollectiveAllReduce(CrossDeviceOps):
# This deadlocks since neither collective is able to finish. # This deadlocks since neither collective is able to finish.
self._lock = threading.Lock() self._lock = threading.Lock()
self._devices = tuple(device_util.canonicalize(d) for d in devices)
group_key = self._collective_keys.get_group_key(self._devices)
# Collective ops requires all devices to participate and is blocking. In # Collective ops requires all devices to participate and is blocking. In
# eager, we need one async executor for each device to be able to launch # eager, we need one async executor for each device to be able to launch
# them altogether. Note that async doesn't imply concurrency. Within an # them altogether. Note that async doesn't imply concurrency. Within an
# async executor operations are still executed sequentially. In graph or # async executor operations are still executed sequentially. In graph or
# function building, the executors are not used. # function building, the executors are not used.
self._executors = [] self._executors = []
for _ in range(len(devices)): self._launchers = []
self._executors.append(executor.new_executor(enable_async=True)) for device in self._devices:
executor = executor_lib.new_executor(enable_async=True)
self._executors.append(executor)
launcher = cross_device_utils.CollectiveReplicaLauncher(
group_key, group_size, self._collective_keys, device, executor)
self._launchers.append(launcher)
super(CollectiveAllReduce, self).__init__() super(CollectiveAllReduce, self).__init__()
@ -1148,63 +1154,39 @@ class CollectiveAllReduce(CrossDeviceOps):
# queuing time due to concurrent intense computation. # queuing time due to concurrent intense computation.
# #
# TODO(b/147393503): explore solutions for optimal ordering. # TODO(b/147393503): explore solutions for optimal ordering.
packs = cross_device_utils.pack_by_size( values_by_device = [[] for _ in range(len(self._devices))]
list(reversed(per_replica_values)), experimental_hints.bytes_per_pack) for per_replica in reversed(per_replica_values):
for i in range(len(self._devices)):
values_by_device[i].append(per_replica.values[i])
if batch_size > 1: outputs_by_device = []
logging.info(
"Collective batch_all_reduce: %d all-reduces, num_devices = %d, "
"group_size = %d, communication_hint = %s, num_packs = %d",
batch_size, len(self._devices), self._group_size, communication,
len(packs))
else:
logging.log_first_n(
logging.INFO, "Collective batch_all_reduce: %d all-reduces, "
"num_devices = %d, group_size = %d, communication_hint = %s, "
"num_packs = %d" % (batch_size, len(
self._devices), self._group_size, communication, len(packs)), 10)
reduced_values = []
with self._lock: with self._lock:
for pack in packs: for i in range(len(self._devices)):
# By placing all CollectiveReduce ops in a pack under single name scope, packs = cross_device_utils.group_by_size(
# we ensure they will be picked up by the `ScopedAllocator` grappler values_by_device[i], experimental_hints.bytes_per_pack)
# optimizer and packed into a single all-reduce. if not context.executing_eagerly() and i == 0:
with ops.name_scope("allreduce"): logging.info(
for per_replica in pack: "Collective batch_all_reduce: %d all-reduces, num_devices = %d, "
# Add control dependencies per device from the last gradients to the "group_size = %d, communication_hint = %s, num_packs = %d",
# current set, in order to serialize NCCL launches. batch_size, len(self._launchers), self._group_size, communication,
if (communication == CollectiveCommunication.NCCL.value and len(packs))
reduced_values): outputs_by_device.append(self._launchers[i].batch_all_reduce(
control_inputs = list(reduced_values[-1]) packs, communication, experimental_hints.timeout_seconds))
else:
control_inputs = None
reduced_values.append(
cross_device_utils.build_collective_reduce(
per_replica.values,
self._devices,
self._group_size,
self._collective_keys,
"Add",
"Id",
communication,
control_inputs,
executors=self._executors,
timeout=experimental_hints.timeout_seconds))
for e in self._executors: for e in self._executors:
e.wait() e.wait()
mirrored = [] mirrored = []
# Reverse the order of reduced value to recover the order in the input. for values in zip(*outputs_by_device):
for value in reversed(reduced_values):
if reduce_op == reduce_util.ReduceOp.MEAN: if reduce_op == reduce_util.ReduceOp.MEAN:
for i, v in enumerate(value): values = list(values)
for i, v in enumerate(values):
with ops.device(v.device): with ops.device(v.device):
value[i] = v / self._group_size values[i] = v / self._group_size
mirrored.append( mirrored.append(
distribute_utils.regroup(value, wrap_class=value_lib.Mirrored)) distribute_utils.regroup(values, wrap_class=value_lib.Mirrored))
return mirrored # Reverse the order of reduced value to recover the order in the input.
return list(reversed(mirrored))
def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values, def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values,
experimental_hints): experimental_hints):
@ -1217,24 +1199,16 @@ class CollectiveAllReduce(CrossDeviceOps):
# Pass self._communication to the runtime as a communication hint. # Pass self._communication to the runtime as a communication hint.
communication_hint = self._communication.value communication_hint = self._communication.value
# For now, we use NCCL only when batch_size > 1.
# TODO(b/132575814): switch to NCCL for all collectives when communication
# is NCCL.
if self._communication == CollectiveCommunication.NCCL and len(
per_replica_values) == 1:
communication_hint = CollectiveCommunication.AUTO.value
gathered_values = [] gathered_values = []
with self._lock, ops.name_scope("allreduce"): with self._lock:
for per_replica in per_replica_values: for per_replica in per_replica_values:
gathered_values.append( outputs = []
cross_device_utils.build_collective_gather_indexed_slices( for i in range(len(self._devices)):
per_replica.values, outputs.append(self._launchers[i].all_reduce_indexed_slices(
self._devices, per_replica.values[i], communication_hint,
self._group_size, experimental_hints.timeout_seconds))
self._collective_keys, gathered_values.append(outputs)
communication_hint,
timeout=experimental_hints.timeout_seconds))
mirrored = [] mirrored = []
for value in gathered_values: for value in gathered_values:
@ -1281,11 +1255,6 @@ class CollectiveAllReduce(CrossDeviceOps):
batch_size = len(per_replica_values) batch_size = len(per_replica_values)
# Pass self._communication to the runtime as a communication hint. # Pass self._communication to the runtime as a communication hint.
communication = self._communication.value communication = self._communication.value
# For now, we use NCCL only when batch_size > 1.
# TODO(b/132575814): switch to NCCL for all collectives when communication
# is NCCL.
if self._communication == CollectiveCommunication.NCCL and batch_size == 1:
communication = CollectiveCommunication.AUTO.value
logging.log_first_n( logging.log_first_n(
logging.INFO, "Collective batch_all_gather: %d all-gathers, " logging.INFO, "Collective batch_all_gather: %d all-gathers, "
@ -1296,21 +1265,12 @@ class CollectiveAllReduce(CrossDeviceOps):
gathered_values = [] gathered_values = []
with self._lock, ops.name_scope("allgather"): with self._lock, ops.name_scope("allgather"):
for per_replica in per_replica_values: for per_replica in per_replica_values:
if (communication == CollectiveCommunication.NCCL.value and outputs = []
gathered_values): for i in range(len(self._devices)):
control_inputs = list(gathered_values[-1]) outputs.append(self._launchers[i].all_gather(
else: per_replica.values[i], axis, communication,
control_inputs = None experimental_hints.timeout_seconds))
gathered_values.append( gathered_values.append(outputs)
cross_device_utils.build_collective_gather(
per_replica.values,
self._devices,
self._group_size,
self._collective_keys,
axis,
communication,
control_inputs,
timeout=experimental_hints.timeout_seconds))
return gathered_values return gathered_values
if context.executing_eagerly(): if context.executing_eagerly():
@ -1319,8 +1279,7 @@ class CollectiveAllReduce(CrossDeviceOps):
gathered_values = compute_gathered_values() gathered_values = compute_gathered_values()
mirrored = [] mirrored = []
# Reverse the order of gathered value to recover the order in the input. for value in gathered_values:
for value in reversed(gathered_values):
mirrored.append( mirrored.append(
distribute_utils.regroup(value, wrap_class=value_lib.Mirrored)) distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
return mirrored return mirrored

@ -24,7 +24,6 @@ import threading
from tensorflow.python.distribute import values as value_lib from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -34,7 +33,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops from tensorflow.python.ops import nccl_ops
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
OP_INSTANCE_KEY_START_NUMBER = 100 INSTANCE_KEY_START_NUMBER = 100
def aggregate_gradients_using_nccl(replica_grads): def aggregate_gradients_using_nccl(replica_grads):
@ -181,69 +180,66 @@ class CollectiveKeys(object):
*Instance key*: an integer key to identify the set of same counterpart of *Instance key*: an integer key to identify the set of same counterpart of
tensors on different devices in a device group that need to be all-reduced. tensors on different devices in a device group that need to be all-reduced.
"Graph key": an integer key that is unique key graph. This is used to support
multiple graphs per client session. It must be non-zero and set in the
`config` argument of each call to `session.run`.
This class is thread safe. This class is thread safe.
""" """
def __init__(self, def __init__(self, group_key_start=1):
group_key_start=1,
op_instance_key_start=OP_INSTANCE_KEY_START_NUMBER,
variable_instance_key_start=1000000):
"""Initializes the object. """Initializes the object.
Args: Args:
group_key_start: the starting integer of group key. group_key_start: the starting integer of group key.
op_instance_key_start: the starting integer of instance key for ops.
variable_instance_key_start: the starting integer of instance key for
variables.
""" """
self._group_key = group_key_start self._group_key = group_key_start
self._group_key_table = {} self._group_key_table = {}
self._instance_key_table = {}
assert op_instance_key_start != variable_instance_key_start
self._op_instance_key = op_instance_key_start
self._variable_instance_key = variable_instance_key_start
self._lock = threading.Lock() self._lock = threading.Lock()
def get_group_key(self, devices): def get_group_key(self, devices):
"""Returns a group key for the set of devices. """Returns a group key for the set of devices.
Args: Args:
devices: list of strings naming devices in a collective group. devices: a list of canonical device strings in a collective group.
Returns: Returns:
int key uniquely identifying the set of device names. int key uniquely identifying the set of device names.
""" """
parsed = [pydev.DeviceSpec.from_string(d) for d in devices] key_id = hash(tuple(sorted(devices)))
# In the between-graph replicated training, different workers need to get
# the same device key. So we remove the task_type and task_id from the
# devices.
# TODO(yuefengz): in the in-graph replicated training, we need to include
# task_type and task_id.
names = sorted(['%s:%d' % (d.device_type, d.device_index) for d in parsed])
key_id = ','.join(names)
with self._lock: with self._lock:
if key_id not in self._group_key_table: if key_id not in self._group_key_table:
new_key = self._group_key new_key = self._group_key
self._group_key += 1 self._group_key += 1
self._group_key_table[key_id] = new_key self._group_key_table[key_id] = new_key
self._instance_key_table[new_key] = {}
for device in devices:
self._instance_key_table[new_key][device] = INSTANCE_KEY_START_NUMBER
return self._group_key_table[key_id] return self._group_key_table[key_id]
def get_op_instance_key(self): def get_instance_key(self, group_key, device):
"""Returns a new instance key for use in defining a collective op.""" """Returns a new instance key for use in defining a collective op.
with self._lock:
v = self._op_instance_key
self._op_instance_key += 1
return v
def get_variable_instance_key(self): You should call this once per each collective op of a collective instance.
"""Returns a new instance key for use in creating a Variable."""
Args:
group_key: the group key returned by get_group_key(). You should not
assign the group key yourself.
device: a canonical device string. It should be the device this collective
op is on.
Returns:
a new instance key.
Raises:
ValueError: when the group key is invalid or the device is not in the
group.
"""
with self._lock: with self._lock:
v = self._variable_instance_key group = self._instance_key_table.get(group_key, None)
self._variable_instance_key += 1 if group is None:
raise ValueError('group {} not found'.format(group_key))
if device not in group:
raise ValueError('{} not in group {}'.format(device, group_key))
v = group[device]
group[device] += 1
return v return v
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
@ -252,135 +248,146 @@ class CollectiveKeys(object):
copied = CollectiveKeys() copied = CollectiveKeys()
copied._group_key = self._group_key copied._group_key = self._group_key
copied._group_key_table = copy.deepcopy(self._group_key_table, memo) copied._group_key_table = copy.deepcopy(self._group_key_table, memo)
copied._op_instance_key = self._op_instance_key copied._instance_key_table = copy.deepcopy(self._instance_key_table, memo)
copied._variable_instance_key = self._variable_instance_key
return copied return copied
def build_collective_reduce(input_tensors, class CollectiveReplicaLauncher(object):
devices, """Launch collectives on one replica."""
group_size,
collective_keys,
reduction_op='Add',
unary_op='Id',
communication_hint='AUTO',
control_inputs=None,
executors=None,
timeout=None):
"""Build a subgraph that does one full all-reduce, using the collective Op.
If called in eager mode, it's required to supply a list of async executors for def __init__(self,
each input Tensor. group_key,
group_size,
collective_keys,
device,
executor=None):
if executor and not executor.is_async():
raise ValueError('executor must be async')
self._group_key = group_key
self._group_size = group_size
self._collective_keys = collective_keys
self._device = device
self._executor = executor
Args: def _executor_scope(self):
input_tensors: tensors within a single worker graph that are to be reduced if context.executing_eagerly() and not self._executor:
together; must be one per device. raise ValueError('collectives requires a async executor in eager mode')
devices: a list of device strings to run the collective on.
group_size: total number of devices globally that will be doing this same
reduction. The reduction will actually include the corresponding tensors
at all these workers.
collective_keys: a CollectiveKeys object.
reduction_op: string naming the reduction op.
unary_op: string naming the unary final op.
communication_hint: string providing hint to runtime for choosing collective
implementation.
control_inputs: if not None, add control edges between control_inputs and
(index-wise) corresponding collective_reduce tensors
executors: a list of async executor. Required for eager execution.
timeout: a float or None. The timeout in seconds.
Returns:
An array of final tensors, one per device, computed by the full reduction.
Raises:
ValueError: There must be at least two tensors over all the workers.
"""
if context.executing_eagerly():
if (not executors or len(executors) != len(input_tensors) or
not all(e.is_async() for e in executors)):
raise ValueError(
'collectives requires async executors for each device in eager mode')
if len(input_tensors) != len(devices):
raise ValueError('collective requires one input tensor for each device, '
'len(input_tensors) = %d, len(devices) = %d' %
(len(input_tensors), len(devices)))
if group_size < 2:
return input_tensors
group_key = collective_keys.get_group_key(devices)
instance_key = collective_keys.get_op_instance_key()
subdiv_offsets = [0] # TODO(tucker): maybe support non-default subdiv spec
out_tensors = []
for idx, input_tensor in enumerate(input_tensors):
if context.executing_eagerly(): if context.executing_eagerly():
executor_scope = context.executor_scope(executors[idx]) return context.executor_scope(self._executor)
else: return ops.NullContextmanager()
executor_scope = ops.NullContextmanager()
with executor_scope, \ def _control_input(self, control_input):
ops.device(devices[idx]), \ if control_input is not None:
ops.control_dependencies( return ops.control_dependencies([control_input])
_control_input(devices, control_inputs, idx)): return ops.NullContextmanager()
out_tensor = collective_ops.all_reduce(
def all_reduce(self,
input_tensor,
control_input=None,
communication_hint='AUTO',
timeout=0):
"""All-reduce a dense tensor.
This can be called in eager mode if a async executor is supplied when
creating the launcher.
Args:
input_tensor: a dense tensor. It must have the same shape on all replicas.
control_input: if not None, add control edges between control_input and
the all-reduce.
communication_hint: string providing hint to runtime for choosing
collective implementation.
timeout: a float. The timeout in seconds.
Returns:
The reduced tensor.
"""
instance_key = self._collective_keys.get_instance_key(
self._group_key, self._device)
with self._executor_scope(), \
ops.device(self._device), \
self._control_input(control_input):
return collective_ops.all_reduce(
input_tensor, input_tensor,
group_size, self._group_size,
group_key, self._group_key,
instance_key, instance_key,
reduction_op, communication_hint=communication_hint,
unary_op,
subdiv_offsets,
communication_hint,
timeout=timeout) timeout=timeout)
out_tensors.append(out_tensor)
return out_tensors
def batch_all_reduce(self,
input_tensor_packs,
communication_hint='AUTO',
timeout=0):
"""Batch all-reduce dense tensors.
def build_collective_gather(input_tensors, This takes a list of batches of tensors. Using multiple batches have the
devices, benefit that it doesn't need to wait for all inputs to be ready to start the
group_size, all-reduce.
collective_keys,
axis,
communication_hint='AUTO',
control_inputs=None,
timeout=None):
"""Build a subgraph that does one full all-gather, using the collective Op.
This method must be called in graph mode or inside a tf.function. This can be called in eager mode if a async executor is supplied when
creating the launcher.
Args: Args:
input_tensors: tensors within a single worker graph that are to be gathered input_tensor_packs: a list of lists of dense tensors.
together; must be one per device. Input tensors cannot have rank 0. communication_hint: string providing hint to runtime for choosing
devices: a list of device strings to run the collective on. collective implementation.
group_size: total number of devices globally that will be doing this same timeout: a float. The timeout in seconds.
gathering. The gathering will actually include the corresponding tensors
at all these workers.
collective_keys: a CollectiveKeys object.
axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
range [0, rank(value)).
communication_hint: string providing hint to runtime for choosing collective
implementation. Available options are `AUTO`, `NCCL`, and `RING`.
control_inputs: if not None, add control edges between control_inputs and
(index-wise) corresponding collective_gather tensors
timeout: a float or None. The timeout in seconds.
Returns: Returns:
An array of final tensors, one per device, computed by the full gather. A flat list of reduced tensors.
""" """
if len(input_tensors) != len(devices): outputs = []
raise ValueError( for pack in input_tensor_packs:
'collective requires one input tensor for each device, %d != %d' % # By placing all CollectiveReduce ops in a batch under single name scope,
(len(input_tensors), len(devices))) # we ensure they will be picked up by the `ScopedAllocator` grappler
# optimizer and packed into a single all-reduce.
with ops.name_scope('allreduce'):
# TODO(b/169168846): inserts a parallel all_gather to verify packings
# are the same on each replica.
for input_tensor in pack:
if communication_hint == 'NCCL' and outputs:
control_input = outputs[-1]
else:
control_input = None
outputs.append(
self.all_reduce(input_tensor, control_input, communication_hint,
timeout))
return outputs
if group_size < 2: def all_gather(self,
return input_tensors input_tensor,
group_key = collective_keys.get_group_key(devices) axis,
instance_key_tensor = collective_keys.get_op_instance_key() communication_hint='AUTO',
instance_key_shape = collective_keys.get_op_instance_key() timeout=0):
"""All-gather a dense tensor.
out_tensors = [] This method must be called inside a tf.function.
for idx, input_tensor in enumerate(input_tensors):
with ops.device(devices[idx]), ops.control_dependencies( Args:
_control_input(devices, control_inputs, idx)): input_tensor: a dense tensor. It must have the same rank on all replicas,
and dimensions other than `axis` need to be the same as well.
axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
range [0, rank(value)).
communication_hint: string providing hint to runtime for choosing
collective implementation. Available options are `AUTO`, `NCCL`, and
`RING`.
timeout: a float. The timeout in seconds.
Returns:
The gathered Tensor.
Raises:
RuntimeError: if called in eager mode.
"""
if context.executing_eagerly():
raise RuntimeError('all_gather in eager mode is not supported')
instance_key_tensor = self._collective_keys.get_instance_key(
self._group_key, self._device)
instance_key_shape = self._collective_keys.get_instance_key(
self._group_key, self._device)
with ops.device(self._device):
# 1. Transpose # 1. Transpose
# E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3, # E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3,
# we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which # we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which
@ -394,8 +401,8 @@ def build_collective_gather(input_tensors,
# 2. Pad # 2. Pad
gathered_shape = collective_ops.all_gather( gathered_shape = collective_ops.all_gather(
array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0), array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0),
group_size, self._group_size,
group_key, self._group_key,
instance_key_shape, instance_key_shape,
communication_hint, communication_hint,
timeout=timeout) timeout=timeout)
@ -406,8 +413,8 @@ def build_collective_gather(input_tensors,
# 3. Gather # 3. Gather
gather_padded_out_tensor = collective_ops.all_gather( gather_padded_out_tensor = collective_ops.all_gather(
padded_input_tensor, padded_input_tensor,
group_size, self._group_size,
group_key, self._group_key,
instance_key_tensor, instance_key_tensor,
communication_hint, communication_hint,
timeout=timeout) timeout=timeout)
@ -424,93 +431,56 @@ def build_collective_gather(input_tensors,
(math_ops.range(1, axis + 1), [0], (math_ops.range(1, axis + 1), [0],
math_ops.range(axis + 1, array_ops.rank(input_tensor_t))), math_ops.range(axis + 1, array_ops.rank(input_tensor_t))),
axis=0) axis=0)
out_tensor = array_ops.transpose(out_tensor_t, perm=perm_after) return array_ops.transpose(out_tensor_t, perm=perm_after)
out_tensors.append(out_tensor)
return out_tensors
def all_reduce_indexed_slices(self,
input_slices,
communication_hint='AUTO',
timeout=0):
"""All-reduce an IndexedSlices.
def _pad_util(input_tensor, full_axis_dim): This method must be called inside a tf.function.
"""Pad the `input_tensor`'s first dimension to be `full_axis_dim`."""
missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0]
tensor_rank = array_ops.rank(input_tensor)
paddings_axis = [[0, missing_axis_dim]]
paddings = array_ops.concat([
paddings_axis,
array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32)
],
axis=0)
padded_input_tensor = array_ops.pad(input_tensor, paddings)
return padded_input_tensor
Args:
input_slices: an IndexedSlices.
communication_hint: string providing hint to runtime for choosing
collective implementation.
timeout: a float. The timeout in seconds.
def build_collective_gather_indexed_slices(input_slices_list, Returns:
devices, The reduced IndexedSlices.
group_size,
collective_keys,
communication_hint='AUTO',
control_inputs=None,
timeout=None):
"""Build a subgraph that all-gathers IndexedSlices using the collective Op.
This method must be called in graph mode or inside a tf.function. Raises:
RuntimeError: if called in eager mode.
"""
if context.executing_eagerly():
raise RuntimeError(
'all_reduce_indexed_slices in eager mode is not supported')
Args: gather_length_key = self._collective_keys.get_instance_key(
input_slices_list: a list of IndexedSlices within a single worker graph that self._group_key, self._device)
are to be gathered together; must be one per device. gather_indices_key = self._collective_keys.get_instance_key(
devices: a list of device strings to run the collective on. self._group_key, self._device)
group_size: total number of devices globally that will be doing this same gather_values_key = self._collective_keys.get_instance_key(
gathering. The gathering will actually include the corresponding tensors self._group_key, self._device)
at all these workers. reduce_densified_key = self._collective_keys.get_instance_key(
collective_keys: a CollectiveKeys object. self._group_key, self._device)
communication_hint: string providing hint to runtime for choosing collective
implementation.
control_inputs: if not None, add control edges between control_inputs and
(index-wise) corresponding collective_reduce tensors
timeout: a float or None. The timeout in seconds.
Returns: # Current CollectiveAllGather implementations require input IndexedSlices to
An array of final IndexedSlices, one per device, computed by the full # have consistent length across the board, we handle the reduction of
gather. # IndexedSlices as follows:
# 1. Gather the lengths of IndexedSlices from all participants.
Raises: # 2. If they have consistent length, apply all_gather.
ValueError: if control_inputs is not None and doesn't match the length and # 3. Otherwise convert IndexedSlices to dense tensors and apply
devices of inputs. # all_reduce.
""" with ops.device(self._device):
assert not context.executing_eagerly(), (
'build_collective_gather_indexed_slices can only be called in graph mode'
' or inside tf.function')
if len(input_slices_list) != len(devices):
raise ValueError(
'collective requires one input IndexedSlice for each device, %d != %d' %
(len(input_slices_list), len(devices)))
if group_size < 2:
return input_slices_list
group_key = collective_keys.get_group_key(devices)
gather_length_key = collective_keys.get_op_instance_key()
gather_indices_key = collective_keys.get_op_instance_key()
gather_values_key = collective_keys.get_op_instance_key()
reduce_densified_key = collective_keys.get_op_instance_key()
# Current CollectiveAllGather implementations require input IndexedSlices to
# have consistent length across the board, we handle the reduction of
# IndexedSlices as follows:
# 1. Gather the lengths of IndexedSlices from all participants.
# 2. If they have consistent length, apply all_gather.
# 3. Otherwise convert IndexedSlices to dense tensors and apply
# all_reduce.
out_slices_list = []
for idx, input_slices in enumerate(input_slices_list):
# pylint: disable = cell-var-from-loop
with ops.device(devices[idx]):
def all_gather(): def all_gather():
"""Use all_gather to aggregate `IndexedSlices`.""" """Use all_gather to aggregate `IndexedSlices`."""
all_values = collective_ops.all_gather( all_values = collective_ops.all_gather(
input_slices.values, input_slices.values,
group_size, self._group_size,
group_key, self._group_key,
gather_values_key, gather_values_key,
communication_hint, communication_hint,
timeout=timeout) timeout=timeout)
@ -519,8 +489,8 @@ def build_collective_gather_indexed_slices(input_slices_list,
with ops.control_dependencies(control): with ops.control_dependencies(control):
all_indices = collective_ops.all_gather( all_indices = collective_ops.all_gather(
input_slices.indices, input_slices.indices,
group_size, self._group_size,
group_key, self._group_key,
gather_indices_key, gather_indices_key,
communication_hint, communication_hint,
timeout=timeout) timeout=timeout)
@ -534,8 +504,8 @@ def build_collective_gather_indexed_slices(input_slices_list,
densified = ops.convert_to_tensor(input_slices) densified = ops.convert_to_tensor(input_slices)
reduced = collective_ops.all_reduce( reduced = collective_ops.all_reduce(
densified, densified,
group_size, self._group_size,
group_key, self._group_key,
reduce_densified_key, reduce_densified_key,
'Add', 'Add',
'Id', [0], 'Id', [0],
@ -550,23 +520,18 @@ def build_collective_gather_indexed_slices(input_slices_list,
dense_shape=input_slices.dense_shape) dense_shape=input_slices.dense_shape)
length = array_ops.shape(input_slices.indices) length = array_ops.shape(input_slices.indices)
with ops.control_dependencies( all_lengths = collective_ops.all_gather(
_control_input(input_slices, control_inputs, idx)): length,
all_lengths = collective_ops.all_gather( self._group_size,
length, self._group_key,
group_size, gather_length_key,
group_key, communication_hint,
gather_length_key, timeout=timeout)
communication_hint, return control_flow_ops.cond(
timeout=timeout)
out_slices = control_flow_ops.cond(
math_ops.equal( math_ops.equal(
math_ops.reduce_max(all_lengths), math_ops.reduce_max(all_lengths),
math_ops.reduce_min(all_lengths)), all_gather, math_ops.reduce_min(all_lengths)), all_gather,
densify_and_all_reduce) densify_and_all_reduce)
out_slices_list.append(out_slices)
# pylint: enable=cell-var-from-loop
return out_slices_list
def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n): def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
@ -653,56 +618,35 @@ def stitch_values(values_and_indices_list):
return result return result
def per_replica_num_elements(per_replica): def group_by_size(input_tensors, bytes_per_pack):
"""Returns the static number of elements of one replica. """Groups `input_tensors` into chunks of `bytes_per_pack`.
Args: The method preserves the original order of `input_tensors`. The grouping is
per_replica: A PerReplica of Tensor or IndexedSlices.
Returns:
Number of elements. None if some replica has a different or unknown shape.
"""
values = per_replica._values # pylint: disable=protected-access
s0 = values[0].shape
for v in values:
assert not isinstance(v, ops.IndexedSlices)
if v.shape != s0:
return None
return s0.num_elements()
def pack_by_size(per_replica_list, bytes_per_pack):
"""Packs `per_replica_list` into chunks of `bytes_per_pack`.
The method preserves the original order of `per_replica_list`. The packing is
best effort, each pack could have more or less bytes than `bytes_per_pack`. best effort, each pack could have more or less bytes than `bytes_per_pack`.
It only packs values with known shape. Note that, the usage is different from It only groups values with known shape.
`cross_device_ops._pack_tensors`, this function is intended to work with the
ScopeAllocator style batching used in `CollectiveAllReduce`.
Args: Args:
per_replica_list: A list of PerReplica. input_tensors: a list of Tensor.
bytes_per_pack: Bytes per pack. bytes_per_pack: an integer.
Returns: Returns:
A list of packs of PerReplica. All values are packed into one pack if A list of packs of Tensor. All values are grouped into one pack if
`bytes_per_pack` is zero or any of the value has unknown shape. `bytes_per_pack` is zero or any of the value has unknown shape.
""" """
if bytes_per_pack == 0: if bytes_per_pack == 0:
return [per_replica_list] return [input_tensors]
packs = [] packs = []
last_pack_size = 0 last_pack_size = 0
for value in per_replica_list: for value in input_tensors:
num_elements = per_replica_num_elements(value) num_elements = value.shape.num_elements()
if num_elements is None: if num_elements is None:
# Can't pack values with unknown shape. # Can't pack values with unknown shape.
logging.warning( logging.warning(
'not packing values due to the unknown or inconsistent shape of %s', 'not packing values due to the unknown or inconsistent shape of %s',
value) value)
return [per_replica_list] return [input_tensors]
size = num_elements * value._primary.dtype.size # pylint: disable=protected-access size = num_elements * value.dtype.size
# Try to keep each pack as close to bytes_per_pack as possible, while each # Try to keep each pack as close to bytes_per_pack as possible, while each
# pack is at least bytes_per_pack large. I.E. we err on the side of having # pack is at least bytes_per_pack large. I.E. we err on the side of having
# few but large packs. # few but large packs.
@ -714,24 +658,15 @@ def pack_by_size(per_replica_list, bytes_per_pack):
return packs return packs
def _control_input(devices, control_inputs, idx): def _pad_util(input_tensor, full_axis_dim):
"""Returns the `idx`-th item in control_inputs to be used in ops.control_dependencies. """Pad the `input_tensor`'s first dimension to be `full_axis_dim`."""
missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0]
This is a helper function for building collective ops. tensor_rank = array_ops.rank(input_tensor)
paddings_axis = [[0, missing_axis_dim]]
Args: paddings = array_ops.concat([
devices: a list of device strings the collective run on. paddings_axis,
control_inputs: a list or None. array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32)
idx: the index into `inputs` and `control_inputs`. ],
axis=0)
Returns: padded_input_tensor = array_ops.pad(input_tensor, paddings)
A one item list of the `idx`-th element of `control_inputs`, or an empty return padded_input_tensor
list if `control_inputs` is None.
"""
if control_inputs is None:
return []
if len(control_inputs) != len(devices):
raise ValueError(
'control_inputs must match the length of the devices, %s != %s' %
(len(control_inputs), len(devices)))
return [control_inputs[idx]]

@ -23,7 +23,6 @@ from absl.testing import parameterized
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import test from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
@ -114,11 +113,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
device_util.resolve(destination), device_util.resolve(result.device)) device_util.resolve(destination), device_util.resolve(result.device))
class PackBySizeTest(test.TestCase): class GroupBySizeTest(test.TestCase):
def assertShape(self, per_replica, shape):
for v in per_replica._values: # pylint: disable=protected-access
self.assertEqual(v.shape, shape)
def testPreferLargerPack(self): def testPreferLargerPack(self):
# Each packs except the last one should be equal or larger than # Each packs except the last one should be equal or larger than
@ -133,49 +128,38 @@ class PackBySizeTest(test.TestCase):
# size = 1 * 4 = 4 # size = 1 * 4 = 4
array_ops.ones([1], dtype=dtypes.int32), array_ops.ones([1], dtype=dtypes.int32),
] ]
per_replica_values = [value_lib.PerReplica([v, v]) for v in values] packs = cross_device_utils.group_by_size(values, bytes_per_pack=200)
packs = cross_device_utils.pack_by_size(
per_replica_values, bytes_per_pack=200)
self.assertLen(packs, 2) self.assertLen(packs, 2)
self.assertLen(packs[0], 3) self.assertLen(packs[0], 3)
self.assertShape(packs[0][0], [2, 4, 4]) self.assertEqual(packs[0][0].shape, [2, 4, 4])
self.assertShape(packs[0][1], [8]) self.assertEqual(packs[0][1].shape, [8])
self.assertShape(packs[0][2], [10, 10]) self.assertEqual(packs[0][2].shape, [10, 10])
self.assertLen(packs[1], 1) self.assertLen(packs[1], 1)
self.assertShape(packs[1][0], [1]) self.assertEqual(packs[1][0].shape, [1])
def testZeroBytesPerPack(self): def testZeroBytesPerPack(self):
values = [ values = [
array_ops.ones([1], dtype=dtypes.float32), array_ops.ones([1], dtype=dtypes.float32),
array_ops.ones([2], dtype=dtypes.float32), array_ops.ones([2], dtype=dtypes.float32),
] ]
per_replica_values = [value_lib.PerReplica([v, v]) for v in values] packs = cross_device_utils.group_by_size(values, bytes_per_pack=0)
packs = cross_device_utils.pack_by_size(
per_replica_values, bytes_per_pack=0)
self.assertLen(packs, 1) self.assertLen(packs, 1)
self.assertLen(packs[0], 2) self.assertLen(packs[0], 2)
self.assertShape(packs[0][0], [1]) self.assertEqual(packs[0][0].shape, [1])
self.assertShape(packs[0][1], [2]) self.assertEqual(packs[0][1].shape, [2])
def testUnknownShape(self): def testUnknownShape(self):
def create_placeholder(shape, dtype): def create_placeholder(shape, dtype):
with ops.Graph().as_default(): with ops.Graph().as_default():
return array_ops.placeholder(dtype=dtype, shape=shape) return array_ops.placeholder(dtype=dtype, shape=shape)
per_replica_values = [ values = [
value_lib.PerReplica([ array_ops.ones([10, 10], dtype=dtypes.float32),
array_ops.ones([10, 10], dtype=dtypes.float32), create_placeholder([None, 10], dtype=dtypes.float32),
array_ops.ones([10, 10], dtype=dtypes.float32),
]),
value_lib.PerReplica([
array_ops.ones([10, 10], dtype=dtypes.float32),
create_placeholder([None, 10], dtype=dtypes.float32),
]),
] ]
packs = cross_device_utils.pack_by_size( packs = cross_device_utils.group_by_size(values, bytes_per_pack=1)
per_replica_values, bytes_per_pack=1)
self.assertLen(packs, 1) self.assertLen(packs, 1)
self.assertEqual(packs[0], per_replica_values) self.assertEqual(packs[0], values)
if __name__ == "__main__": if __name__ == "__main__":

@ -27,7 +27,7 @@ from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import cluster_resolver from tensorflow.python.distribute import cluster_resolver
from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib
from tensorflow.python.distribute import collective_util from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
@ -451,6 +451,8 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
# Reusing keys is not supported well. So we have to give a different # Reusing keys is not supported well. So we have to give a different
# collective key base for different tests. # collective key base for different tests.
CollectiveAllReduceTest.collective_key_base += 100000 CollectiveAllReduceTest.collective_key_base += 100000
mwms_lib.CollectiveAllReduceStrategy._collective_key_base = (
CollectiveAllReduceTest.collective_key_base)
def _get_test_objects(self, def _get_test_objects(self,
task_type, task_type,
@ -460,10 +462,7 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
use_strategy_object=False, use_strategy_object=False,
local_mode=False): local_mode=False):
collective_keys = cross_device_utils.CollectiveKeys( collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=10 + CollectiveAllReduceTest.collective_key_base, group_key_start=10 + CollectiveAllReduceTest.collective_key_base)
op_instance_key_start=100 + CollectiveAllReduceTest.collective_key_base,
variable_instance_key_start=10000 +
CollectiveAllReduceTest.collective_key_base)
if local_mode: if local_mode:
if num_gpus: if num_gpus:
devices = ["/device:GPU:%d" % i for i in range(num_gpus)] devices = ["/device:GPU:%d" % i for i in range(num_gpus)]
@ -471,13 +470,8 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
devices = ["/device:CPU:0"] devices = ["/device:CPU:0"]
if use_strategy_object: if use_strategy_object:
strategy = ( strategy = (mwms_lib.CollectiveAllReduceStrategy
collective_all_reduce_strategy.CollectiveAllReduceStrategy ._from_local_devices(devices, communication=communication)) # pylint: disable=protected-access
._from_local_devices(devices, communication=communication)) # pylint: disable=protected-access
strategy.extended._collective_keys = collective_keys
strategy.extended._cross_device_ops._collective_keys = collective_keys
strategy.extended._host_cross_device_ops._collective_keys = (
collective_keys)
return strategy, devices, "" return strategy, devices, ""
else: else:
collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
@ -507,10 +501,8 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
task_type=task_type, task_type=task_type,
task_id=task_id, task_id=task_id,
num_accelerators={"GPU": num_gpus}) num_accelerators={"GPU": num_gpus})
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy( strategy = mwms_lib.CollectiveAllReduceStrategy(
cluster_resolver=resolver, communication=communication) cluster_resolver=resolver, communication=communication)
strategy.extended._collective_keys = collective_keys
strategy.extended._cross_device_ops._collective_keys = collective_keys
return (strategy, devices, return (strategy, devices,
"grpc://" + self._cluster_spec[task_type][task_id]) "grpc://" + self._cluster_spec[task_type][task_id])
else: else:

@ -689,7 +689,7 @@ cuda_py_test(
name = "multi_worker_test", name = "multi_worker_test",
srcs = ["multi_worker_test.py"], srcs = ["multi_worker_test.py"],
python_version = "PY3", python_version = "PY3",
shard_count = 32, shard_count = 2,
tags = [ tags = [
"multi_and_single_gpu", "multi_and_single_gpu",
"no_oss", # TODO(b/130369494): Investigate why it times out on OSS. "no_oss", # TODO(b/130369494): Investigate why it times out on OSS.

@ -24,9 +24,8 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat import v2_compat from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib
from tensorflow.python.distribute import combinations as ds_combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import multi_worker_util
@ -80,7 +79,7 @@ def create_test_objects(cluster_spec=None,
ClusterSpec({}), num_accelerators={'GPU': num_gpus}) ClusterSpec({}), num_accelerators={'GPU': num_gpus})
target = '' target = ''
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy( strategy = mwms_lib.CollectiveAllReduceStrategy(
cluster_resolver=cluster_resolver) cluster_resolver=cluster_resolver)
sess_config = strategy.update_config_proto(sess_config) sess_config = strategy.update_config_proto(sess_config)
@ -95,9 +94,7 @@ class CollectiveAllReduceStrategyTestBase(
def setUp(self): def setUp(self):
# We use a different key_base for each test so that collective keys won't be # We use a different key_base for each test so that collective keys won't be
# reused. # reused.
# TODO(yuefengz, ayushd): enable it to reuse collective keys in different mwms_lib.CollectiveAllReduceStrategy._collective_key_base += 100000
# tests.
CollectiveAllReduceStrategyTestBase.collective_key_base += 100000
super(CollectiveAllReduceStrategyTestBase, self).setUp() super(CollectiveAllReduceStrategyTestBase, self).setUp()
def _get_test_object(self, task_type, task_id, num_gpus=0): def _get_test_object(self, task_type, task_id, num_gpus=0):
@ -106,18 +103,6 @@ class CollectiveAllReduceStrategyTestBase(
task_type=task_type, task_type=task_type,
task_id=task_id, task_id=task_id,
num_gpus=num_gpus) num_gpus=num_gpus)
collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=10 +
CollectiveAllReduceStrategyTestBase.collective_key_base,
op_instance_key_start=100 +
CollectiveAllReduceStrategyTestBase.collective_key_base,
variable_instance_key_start=10000 +
CollectiveAllReduceStrategyTestBase.collective_key_base)
strategy.extended._collective_keys = collective_keys
strategy.extended._cross_device_ops._collective_keys = collective_keys
strategy.extended._host_cross_device_ops._collective_keys = collective_keys
return strategy, target, session_config return strategy, target, session_config
def _test_complex_model(self, task_type, task_id, num_gpus): def _test_complex_model(self, task_type, task_id, num_gpus):