From d66729431da1d47e2adda97be5aa0fa456eb7e26 Mon Sep 17 00:00:00 2001
From: Ran Chen
Date: Mon, 25 Jan 2021 13:19:07 -0800
Subject: [PATCH] Remove the workaround for collective cancellation issue in keras

It's no longer needed

PiperOrigin-RevId: 353715404
Change-Id: I32ef068e338f28625c7e0d9a96cc5289e0b8eeb2
---
 tensorflow/python/keras/engine/training.py | 31 +++++++++-------------
 1 file changed, 13 insertions(+), 18 deletions(-)

diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index d9137f4a62c..47aabffd7b0 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -2771,24 +2771,19 @@ def _collective_all_reduce_multi_worker(strategy):
 # for all strategies
 def _multi_worker_concat(v, strategy):
   """Order PerReplica objects for CollectiveAllReduceStrategy and concat."""
-  replicas = strategy.gather(v, axis=0)  # pylint: disable=protected-access
-  # TODO(b/170435030): We now need to make sure these run after the iterator
-  # GetNext, so that we don't trigger aborting collective ops in the case of
-  # EOF. Remove after the issue is fixed.
-  with ops.control_dependencies([replicas]):
-    # v might not have the same shape on different replicas
-    if isinstance(v, ds_values.PerReplica):
-      shapes = array_ops.concat([
-          array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
-          for single_value in v.values
-      ],
-                                axis=0)
-      all_shapes = strategy.gather(shapes, axis=0)
-    else:
-      # v is a tensor. This may happen when, say, we have 2x1 multi-worker.
-      all_shapes = strategy.gather(
-          array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0),
-          axis=0)
+  replicas = strategy.gather(v, axis=0)
+  # v might not have the same shape on different replicas
+  if isinstance(v, ds_values.PerReplica):
+    shapes = array_ops.concat([
+        array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
+        for single_value in v.values
+    ],
+                              axis=0)
+    all_shapes = strategy.gather(shapes, axis=0)
+  else:
+    # v is a tensor. This may happen when, say, we have 2x1 multi-worker.
+    all_shapes = strategy.gather(
+        array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0), axis=0)
 
   replicas = array_ops.split(
       replicas,
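
Note on the surrounding code: `_multi_worker_concat` gathers per-replica
values across workers with `strategy.gather(v, axis=0)`; because the final
batch of an epoch can be ragged, it also gathers each replica's batch size
(`all_shapes`) so the gathered tensor can be split back into per-replica
chunks and reordered before concatenation. Below is a minimal,
self-contained sketch of that split-and-reorder technique in plain
TensorFlow. The worker-major gather order, chunk sizes, and values are
illustrative assumptions, not taken from this patch:

    import tensorflow as tf

    # Hypothetical result of a gather across 2 workers x 2 replicas,
    # laid out worker-major: [w0r0, w0r1, w1r0, w1r1]. The last chunk is
    # short because the final batch did not divide evenly.
    gathered = tf.constant([0, 1, 4, 5, 2, 3, 6])
    sizes = [2, 2, 2, 1]  # per-replica batch sizes (assumed)
    num_replicas_per_worker = 2

    # Split the flat gather back into per-replica chunks, then interleave
    # with a stride so the chunks come out in global batch order rather
    # than worker-major order.
    chunks = tf.split(gathered, num_or_size_splits=sizes)
    ordered = []
    for i in range(num_replicas_per_worker):
        ordered += chunks[i::num_replicas_per_worker]
    print(tf.concat(ordered, axis=0))  # [0 1 2 3 4 5 6]

Gathering the shapes separately is what makes the ragged-final-batch case
work: `tf.split` needs concrete per-chunk sizes, which only the replicas
themselves know at run time.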