Remove the workaround for collective cancellation issue in keras

It's no longer needed

PiperOrigin-RevId: 353715404
Change-Id: I32ef068e338f28625c7e0d9a96cc5289e0b8eeb2
This commit is contained in:
Ran Chen 2021-01-25 13:19:07 -08:00 committed by TensorFlower Gardener
parent a1668cbcd9
commit d66729431d

View File

@ -2771,24 +2771,19 @@ def _collective_all_reduce_multi_worker(strategy):
# for all strategies
def _multi_worker_concat(v, strategy):
"""Order PerReplica objects for CollectiveAllReduceStrategy and concat."""
replicas = strategy.gather(v, axis=0) # pylint: disable=protected-access
# TODO(b/170435030): We now need to make sure these run after the iterator
# GetNext, so that we don't trigger aborting collective ops in the case of
# EOF. Remove after the issue is fixed.
with ops.control_dependencies([replicas]):
# v might not have the same shape on different replicas
if isinstance(v, ds_values.PerReplica):
shapes = array_ops.concat([
array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
for single_value in v.values
],
axis=0)
all_shapes = strategy.gather(shapes, axis=0)
else:
# v is a tensor. This may happen when, say, we have 2x1 multi-worker.
all_shapes = strategy.gather(
array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0),
axis=0)
replicas = strategy.gather(v, axis=0)
# v might not have the same shape on different replicas
if isinstance(v, ds_values.PerReplica):
shapes = array_ops.concat([
array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
for single_value in v.values
],
axis=0)
all_shapes = strategy.gather(shapes, axis=0)
else:
# v is a tensor. This may happen when, say, we have 2x1 multi-worker.
all_shapes = strategy.gather(
array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0), axis=0)
replicas = array_ops.split(
replicas,