Use non-batch all_reduce in sync batch norm to avoid a potential deadlock

Batch all_reduce may use NCCL and can deadlock because we have no way to ensure that NCCL collectives are launched in the same order on every worker. For now, switch to the non-batch version so that it uses RING instead.

PiperOrigin-RevId: 336368256
Change-Id: I47d268be6434101a5e90addc64a63fc3cb061074
This commit is contained in:
parent ead4d024f8
commit d995be3625
Changed paths: tensorflow/python/distribute, tensorflow/python/keras
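The deadlock scenario in the commit message comes down to the difference between one batched collective over a list of tensors and one collective per tensor. The snippet below is a minimal sketch of that difference under a tf.distribute.MirroredStrategy; the step function and tensors are illustrative, and whether the batched form actually lowers to NCCL depends on how the strategy's collective communication is configured.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # e.g. one process, N local GPUs

def step():
  ctx = tf.distribute.get_replica_context()
  x = tf.constant(1.0)
  y = tf.constant(2.0)

  # Batched form: a single collective over a list of tensors. Depending on
  # the configured communication, this may be lowered to NCCL, whose launch
  # order must match across workers.
  x_sum, y_sum = ctx.all_reduce(tf.distribute.ReduceOp.SUM, [x, y])

  # Non-batched form: one collective per tensor. This is what sync batch norm
  # switches to in this commit, so the RING-based path is used instead of NCCL.
  x_sum2 = ctx.all_reduce(tf.distribute.ReduceOp.SUM, x)
  y_sum2 = ctx.all_reduce(tf.distribute.ReduceOp.SUM, y)
  return x_sum, y_sum, x_sum2, y_sum2

strategy.run(step)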
@@ -1098,9 +1098,11 @@ class CollectiveAllReduce(CrossDeviceOps):
     batch_size = len(per_replica_values)
     # Pass self._communication to the runtime as a communication hint.
     communication = self._communication.value
-    # For now, we use NCCL only when batch_size > 1.
+    # For now, we use NCCL only when batch_size > 1 since we don't have a way to
+    # order NCCL launches. We're hoping that there's only one batched
+    # all-reduce, which is the gradients.
     # TODO(b/132575814): switch to NCCL for all collectives when communication
-    # is NCCL.
+    # is NCCL if and only if we can order collectives deterministically.
     if self._communication == CollectiveCommunication.NCCL and batch_size == 1:
       communication = CollectiveCommunication.AUTO.value

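The hunk above only rewrites the comments around an existing guard; it is that guard which keeps NCCL reserved for batched all-reduces. The following standalone sketch restates the selection logic for illustration; the enum and the choose_communication_hint helper are hypothetical stand-ins, not the real cross_device_ops internals.

import enum

class CollectiveCommunication(enum.Enum):
  AUTO = "AUTO"
  RING = "RING"
  NCCL = "NCCL"

def choose_communication_hint(configured, per_replica_values):
  """Picks the communication hint passed to the collective runtime."""
  batch_size = len(per_replica_values)
  communication = configured.value
  # Keep NCCL only for batched all-reduces (batch_size > 1); a lone
  # all-reduce falls back to AUTO, which avoids the NCCL ordering problem.
  if configured == CollectiveCommunication.NCCL and batch_size == 1:
    communication = CollectiveCommunication.AUTO.value
  return communication

print(choose_communication_hint(CollectiveCommunication.NCCL, ["one_value"]))  # -> AUTO
print(choose_communication_hint(CollectiveCommunication.NCCL, ["g1", "g2"]))   # -> NCCL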
@@ -51,7 +51,10 @@ class DistributionStrategyCnnCorrectnessTest(
     if self.with_batch_norm == 'regular':
       c1 = keras.layers.BatchNormalization(name='bn1')(c1)
     elif self.with_batch_norm == 'sync':
-      c1 = keras.layers.SyncBatchNormalization(name='bn1')(c1)
+      # Test with parallel batch norms to verify all-reduce works OK.
+      bn1 = keras.layers.SyncBatchNormalization(name='bn1')(c1)
+      bn2 = keras.layers.SyncBatchNormalization(name='bn2')(c1)
+      c1 = keras.layers.Add()([bn1, bn2])
     c1 = keras.layers.MaxPooling2D(pool_size=(2, 2))(c1)
     logits = keras.layers.Dense(
         10, activation='softmax', name='pred')(
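Outside the correctness-test harness, the "parallel batch norm" topology above can be reproduced with a small functional model. This is a sketch assuming a 28x28x1 input and the experimental public export tf.keras.layers.experimental.SyncBatchNormalization that was current around this commit (the internal test spells it keras.layers.SyncBatchNormalization); the Conv2D and Flatten layers are illustrative additions.

import tensorflow as tf

def build_parallel_sync_bn_model():
  inputs = tf.keras.Input(shape=(28, 28, 1))
  c1 = tf.keras.layers.Conv2D(32, 3, activation='relu')(inputs)
  # Two sync batch norms on the same tensor trigger two independent
  # all-reduces per training step, which is what the test exercises.
  bn1 = tf.keras.layers.experimental.SyncBatchNormalization(name='bn1')(c1)
  bn2 = tf.keras.layers.experimental.SyncBatchNormalization(name='bn2')(c1)
  c1 = tf.keras.layers.Add()([bn1, bn2])
  c1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(c1)
  flat = tf.keras.layers.Flatten()(c1)
  logits = tf.keras.layers.Dense(10, activation='softmax', name='pred')(flat)
  return tf.keras.Model(inputs, logits)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = build_parallel_sync_bn_model()
  model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy')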
@@ -170,9 +170,14 @@ class SyncBatchNormalization(normalization.BatchNormalizationBase):
       local_squared_sum = math_ops.reduce_sum(math_ops.square(y), axis=axes,
                                               keepdims=True)
       batch_size = math_ops.cast(array_ops.shape_v2(y)[0], dtypes.float32)
-      y_sum, y_squared_sum, global_batch_size = (
-          replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, [
-              local_sum, local_squared_sum, batch_size]))
+      # TODO(b/163099951): batch the all-reduces once we sort out the ordering
+      # issue for NCCL. We don't have a mechanism to launch NCCL in the same
+      # order in each replica nowadays, so we limit NCCL to batch all-reduces.
+      y_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, local_sum)
+      y_squared_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM,
+                                             local_squared_sum)
+      global_batch_size = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM,
+                                                 batch_size)

       axes_vals = [(array_ops.shape_v2(y))[i] for i in range(1, len(axes))]
       multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals),
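For reference, the three reduced quantities feed the usual moment computation that follows in the layer. Below is a minimal restatement using public tf.* ops instead of the internal math_ops/array_ops modules, assuming NHWC input and reduction axes (0, 1, 2); the global_moments helper name is illustrative.

import tensorflow as tf

def global_moments(y, replica_ctx, axes=(0, 1, 2)):
  """Global mean/variance from per-replica sums (illustrative, NHWC input)."""
  local_sum = tf.reduce_sum(y, axis=axes, keepdims=True)
  local_squared_sum = tf.reduce_sum(tf.square(y), axis=axes, keepdims=True)
  batch_size = tf.cast(tf.shape(y)[0], tf.float32)

  # One all-reduce per tensor: the non-batched form this commit switches to.
  y_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_sum)
  y_squared_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM,
                                         local_squared_sum)
  global_batch_size = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM,
                                             batch_size)

  # Elements per channel = global batch size * product of the spatial dims.
  axes_vals = [tf.shape(y)[i] for i in range(1, len(axes))]
  multiplier = tf.cast(tf.reduce_prod(axes_vals), tf.float32) * global_batch_size

  mean = y_sum / multiplier                                 # E[y]
  variance = y_squared_sum / multiplier - tf.square(mean)   # E[y^2] - E[y]^2
  return mean, variance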