Use non-batch all_reduce in sync batch norm to avoid potential deadlock

Batched all_reduce may use NCCL, which can deadlock because we have no way
to ensure that NCCL launch order is the same on every worker.

For now, we switch to the non-batch version so that RING is used instead.

PiperOrigin-RevId: 336368256
Change-Id: I47d268be6434101a5e90addc64a63fc3cb061074
Ran Chen 2020-10-09 14:40:39 -07:00 committed by TensorFlower Gardener
parent ead4d024f8
commit d995be3625
3 changed files with 16 additions and 6 deletions
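
For context, here is a minimal, hedged sketch (not part of this commit; it uses the public TF 2.x APIs tf.distribute.experimental.MultiWorkerMirroredStrategy and tf.keras.layers.experimental.SyncBatchNormalization, and the model shape is made up) of the kind of setup the fix targets: sync batch norm layers trained under a collective-based strategy, where every training step all-reduces the per-replica batch statistics.

# Sketch only: layer names and shapes are illustrative, not from this commit.
import tensorflow as tf

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

with strategy.scope():
  inputs = tf.keras.Input(shape=(28, 28, 3))
  x = tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu')(inputs)
  # Each SyncBatchNormalization all-reduces its per-replica moments per step.
  x = tf.keras.layers.experimental.SyncBatchNormalization(name='bn1')(x)
  x = tf.keras.layers.GlobalAveragePooling2D()(x)
  outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
  model = tf.keras.Model(inputs, outputs)
  model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy')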


@@ -1098,9 +1098,11 @@ class CollectiveAllReduce(CrossDeviceOps):
     batch_size = len(per_replica_values)
     # Pass self._communication to the runtime as a communication hint.
     communication = self._communication.value
-    # For now, we use NCCL only when batch_size > 1.
+    # For now, we use NCCL only when batch_size > 1 since we don't have a way to
+    # order NCCL launches. We're hoping that there's only one batched
+    # all-reduce, which is the gradients.
     # TODO(b/132575814): switch to NCCL for all collectives when communication
-    # is NCCL.
+    # is NCCL if and only if we can order collectives deterministically.
     if self._communication == CollectiveCommunication.NCCL and batch_size == 1:
       communication = CollectiveCommunication.AUTO.value
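
For reference, a hedged sketch (not part of the diff; it assumes the TF 2.x constructor signature, which later releases changed) of where this communication hint typically originates: the strategy constructor accepts the same CollectiveCommunication enum and forwards it to CollectiveAllReduce. With the change above, an NCCL hint only takes effect for batched all-reduces (batch_size > 1), such as the single batched gradient all-reduce, while one-off all-reduces are downgraded to AUTO.

import tensorflow as tf

# Assumption: TF 2.x era API; later versions replaced the `communication`
# argument with `communication_options`.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    communication=tf.distribute.experimental.CollectiveCommunication.NCCL)
# The batched gradient all-reduce may now use NCCL; single-tensor
# all-reduces (batch_size == 1) fall back to AUTO per the hunk above.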


@@ -51,7 +51,10 @@ class DistributionStrategyCnnCorrectnessTest(
     if self.with_batch_norm == 'regular':
       c1 = keras.layers.BatchNormalization(name='bn1')(c1)
     elif self.with_batch_norm == 'sync':
-      c1 = keras.layers.SyncBatchNormalization(name='bn1')(c1)
+      # Test with parallel batch norms to verify all-reduce works OK.
+      bn1 = keras.layers.SyncBatchNormalization(name='bn1')(c1)
+      bn2 = keras.layers.SyncBatchNormalization(name='bn2')(c1)
+      c1 = keras.layers.Add()([bn1, bn2])
     c1 = keras.layers.MaxPooling2D(pool_size=(2, 2))(c1)
     logits = keras.layers.Dense(
         10, activation='softmax', name='pred')(
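
A standalone sketch of the pattern the test now builds (hypothetical shapes, using the public tf.keras API rather than the internal keras module imported by the test): two parallel SyncBatchNormalization branches feeding an Add layer, so a single training step issues several independent statistics all-reduces.

import tensorflow as tf
from tensorflow import keras

inputs = keras.Input(shape=(28, 28, 3))
c1 = keras.layers.Conv2D(16, 3, padding='same', activation='relu')(inputs)
# Two parallel sync batch norms -> multiple statistics all-reduces per step.
bn1 = keras.layers.experimental.SyncBatchNormalization(name='bn1')(c1)
bn2 = keras.layers.experimental.SyncBatchNormalization(name='bn2')(c1)
merged = keras.layers.Add()([bn1, bn2])
pooled = keras.layers.MaxPooling2D(pool_size=(2, 2))(merged)
logits = keras.layers.Dense(10, activation='softmax', name='pred')(
    keras.layers.Flatten()(pooled))
model = keras.Model(inputs, logits)
# Building this model under a collective strategy's scope and training it
# exercises the parallel sync all-reduces the test is checking.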


@@ -170,9 +170,14 @@ class SyncBatchNormalization(normalization.BatchNormalizationBase):
       local_squared_sum = math_ops.reduce_sum(math_ops.square(y), axis=axes,
                                               keepdims=True)
       batch_size = math_ops.cast(array_ops.shape_v2(y)[0], dtypes.float32)
-      y_sum, y_squared_sum, global_batch_size = (
-          replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, [
-              local_sum, local_squared_sum, batch_size]))
+      # TODO(b/163099951): batch the all-reduces once we sort out the ordering
+      # issue for NCCL. We don't have a mechanism to launch NCCL in the same
+      # order in each replica nowadays, so we limit NCCL to batch all-reduces.
+      y_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, local_sum)
+      y_squared_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM,
+                                             local_squared_sum)
+      global_batch_size = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM,
+                                                 batch_size)
       axes_vals = [(array_ops.shape_v2(y))[i] for i in range(1, len(axes))]
       multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals),
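
To make the difference concrete with the public API, here is a hedged sketch (assumes TF 2.2+ and tf.distribute.ReplicaContext.all_reduce; the tensors are made up, not taken from this commit) of the batched versus unbatched call styles. The change above switches SyncBatchNormalization to the unbatched form so the statistics reductions no longer depend on NCCL launch ordering.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

@tf.function
def step():
  def replica_fn():
    ctx = tf.distribute.get_replica_context()
    local_sum = tf.constant([1.0, 2.0])
    local_squared_sum = tf.constant([1.0, 4.0])
    batch_size = tf.constant(2.0)
    # Batched form: one collective over a list of tensors, which the runtime
    # may lower to NCCL:
    #   y_sum, y_sq, n = ctx.all_reduce(
    #       tf.distribute.ReduceOp.SUM,
    #       [local_sum, local_squared_sum, batch_size])
    # Unbatched form (what the change uses): separate collectives, which
    # sidestep the NCCL launch-ordering problem.
    y_sum = ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_sum)
    y_sq = ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_squared_sum)
    n = ctx.all_reduce(tf.distribute.ReduceOp.SUM, batch_size)
    return y_sum, y_sq, n

  return strategy.run(replica_fn)

print(step())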