Shortcut cross_device_ops reduce and batch_reduce methods if there is only one input in the PerReplica object.
PiperOrigin-RevId: 249860947
commit 1188b9e764 (parent 7eb6a3fefc)
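The idea behind the shortcut, as a minimal self-contained sketch: when a PerReplica value holds exactly one tensor that already lives on the destination device, there is nothing to reduce, so the value can be wrapped as a Mirrored result without any cross-device communication. The `PerReplica`, `Mirrored`, `_devices_match`, and `reduce_sum` names below are simplified stand-ins for illustration, not the real TensorFlow classes:

from dataclasses import dataclass
from typing import Tuple

@dataclass
class PerReplica:
    devices: Tuple[str, ...]
    values: Tuple[float, ...]

@dataclass
class Mirrored:
    devices: Tuple[str, ...]
    values: Tuple[float, ...]

def _devices_match(value, destinations):
    return value.devices == tuple(destinations)

def reduce_sum(value, destinations):
    # Shortcut: a single input already on the destination device needs no
    # cross-device communication; wrap it as a Mirrored value directly.
    if len(value.values) == 1 and _devices_match(value, destinations):
        return Mirrored(value.devices, value.values)
    # Fallback: a real reduction (a plain sum here, for illustration only).
    total = sum(value.values)
    return Mirrored(tuple(destinations), tuple(total for _ in destinations))

single = PerReplica(("/gpu:0",), (3.0,))
print(reduce_sum(single, ("/gpu:0",)))          # shortcut path: values=(3.0,)

multi = PerReplica(("/gpu:0", "/gpu:1"), (3.0, 4.0))
print(reduce_sum(multi, ("/gpu:0", "/gpu:1")))  # reduction path: values=(7.0, 7.0)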
@@ -232,6 +232,11 @@ class CrossDeviceOps(object):
 
   def __init__(self):
     pass
 
+  @property
+  def _num_between_graph_workers(self):
+    # Returns 1 by default; the value may be overridden by subclasses.
+    return 1
+
   def reduce(self, reduce_op, per_replica_value, destinations):
     """Reduce `per_replica_value` to `destinations`.
@@ -255,6 +260,14 @@ class CrossDeviceOps(object):
     per_replica_value = _make_tensor_into_per_replica(per_replica_value)
 
     validate_destinations(destinations)
+
+    # Shortcut if `per_replica_value` only contains one value.
+    if self._num_between_graph_workers == 1 and len(
+        per_replica_value.values) == 1 and _devices_match(
+            per_replica_value, destinations):
+      return value_lib.Mirrored(per_replica_value.device_map,
+                                per_replica_value.values)
+
     return self.reduce_implementation(reduce_op, per_replica_value,
                                       destinations)
@@ -288,6 +301,15 @@ class CrossDeviceOps(object):
     for _, d in value_destination_pairs:
       validate_destinations(d)
 
+    # Shortcut if all PerReplica objects only contain one value.
+    if self._num_between_graph_workers == 1 and _all_devices_match(
+        value_destination_pairs) and len(
+            value_destination_pairs[0][0].values) == 1:
+      return [
+          value_lib.Mirrored(v.device_map, v.values)
+          for v, _ in value_destination_pairs
+      ]
+
     return self.batch_reduce_implementation(reduce_op, value_destination_pairs)
 
   def broadcast(self, tensor, destinations):
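This hunk is the batch analogue of the shortcut above: when every `PerReplica` in `value_destination_pairs` holds a single value and all devices match, each value is wrapped in a `Mirrored` directly and the whole batch skips `batch_reduce_implementation`.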
@@ -974,6 +996,10 @@ class CollectiveAllReduce(CrossDeviceOps):
         cross_device_utils.CollectiveKeys())
     super(CollectiveAllReduce, self).__init__()
 
+  @property
+  def _num_between_graph_workers(self):
+    return self._num_workers
+
   def reduce_implementation(self, reduce_op, per_replica_value, destinations):
     all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value])[0]
     device_map, logical_device = get_device_map_from(destinations)
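Why `CollectiveAllReduce` overrides the property: it reports its real worker count, so with more than one between-graph worker the `== 1` guard fails and the shortcut stays disabled, presumably because each worker's collective op must still run to match its peers even when it holds only one local value. A toy sketch of how the guard plays out (these classes are stand-ins, not the real implementations):

class CrossDeviceOps(object):
    # Base default: a single between-graph worker, so the shortcut is legal.
    @property
    def _num_between_graph_workers(self):
        return 1

class CollectiveAllReduce(CrossDeviceOps):
    def __init__(self, num_workers):
        self._num_workers = num_workers

    # Reporting the real worker count makes the `== 1` guard in
    # reduce()/batch_reduce() fail, keeping every worker in the collective.
    @property
    def _num_between_graph_workers(self):
        return self._num_workers

def shortcut_allowed(ops, num_values):
    return ops._num_between_graph_workers == 1 and num_values == 1

print(shortcut_allowed(CrossDeviceOps(), 1))        # True: shortcut taken
print(shortcut_allowed(CollectiveAllReduce(2), 1))  # False: collective still runs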