Fix TPU saved_model.save
with SyncOnReadVariable.
PiperOrigin-RevId: 304032507 Change-Id: I93abd43d43351b50eff544fbf4e2d5ef3a26fe17
This commit is contained in:
parent
b5f5af4c9d
commit
44fe00d030
@ -672,11 +672,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
reduce_op, value, destinations, self._num_replicas_in_sync)
|
||||
|
||||
# TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
|
||||
# Always performs the reduction on the TPU host.
|
||||
with ops.device(self._host_device):
|
||||
output = math_ops.add_n(value.values)
|
||||
if reduce_op == reduce_util.ReduceOp.MEAN:
|
||||
output *= (1. / len(value.values))
|
||||
output = math_ops.add_n(value.values)
|
||||
if reduce_op == reduce_util.ReduceOp.MEAN:
|
||||
output *= (1. / len(value.values))
|
||||
|
||||
devices = cross_device_ops_lib.get_devices_from(destinations)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user