Remove the workaround for the collective cancellation issue in Keras.
It's no longer needed.
PiperOrigin-RevId: 353715404
Change-Id: I32ef068e338f28625c7e0d9a96cc5289e0b8eeb2
This commit is contained in:
parent
a1668cbcd9
commit
d66729431d
@@ -2771,24 +2771,19 @@ def _collective_all_reduce_multi_worker(strategy):
|
||||
# for all strategies
|
||||
def _multi_worker_concat(v, strategy):
|
||||
"""Order PerReplica objects for CollectiveAllReduceStrategy and concat."""
|
||||
replicas = strategy.gather(v, axis=0) # pylint: disable=protected-access
|
||||
# TODO(b/170435030): We now need to make sure these run after the iterator
|
||||
# GetNext, so that we don't trigger aborting collective ops in the case of
|
||||
# EOF. Remove after the issue is fixed.
|
||||
with ops.control_dependencies([replicas]):
|
||||
# v might not have the same shape on different replicas
|
||||
if isinstance(v, ds_values.PerReplica):
|
||||
shapes = array_ops.concat([
|
||||
array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
|
||||
for single_value in v.values
|
||||
],
|
||||
axis=0)
|
||||
all_shapes = strategy.gather(shapes, axis=0)
|
||||
else:
|
||||
# v is a tensor. This may happen when, say, we have 2x1 multi-worker.
|
||||
all_shapes = strategy.gather(
|
||||
array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0),
|
||||
axis=0)
|
||||
replicas = strategy.gather(v, axis=0)
|
||||
# v might not have the same shape on different replicas
|
||||
if isinstance(v, ds_values.PerReplica):
|
||||
shapes = array_ops.concat([
|
||||
array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
|
||||
for single_value in v.values
|
||||
],
|
||||
axis=0)
|
||||
all_shapes = strategy.gather(shapes, axis=0)
|
||||
else:
|
||||
# v is a tensor. This may happen when, say, we have 2x1 multi-worker.
|
||||
all_shapes = strategy.gather(
|
||||
array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0), axis=0)
|
||||
|
||||
replicas = array_ops.split(
|
||||
replicas,
|
||||
|
Loading…
Reference in New Issue
Block a user