Fix TPU saved_model.save with SyncOnReadVariable.

PiperOrigin-RevId: 304032507
Change-Id: I93abd43d43351b50eff544fbf4e2d5ef3a26fe17
This commit is contained in:
Ruoxin Sang 2020-03-31 13:19:02 -07:00 committed by TensorFlower Gardener
parent b5f5af4c9d
commit 44fe00d030

View File

@ -672,11 +672,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
reduce_op, value, destinations, self._num_replicas_in_sync)
# TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
# Always performs the reduction on the TPU host.
with ops.device(self._host_device):
output = math_ops.add_n(value.values)
if reduce_op == reduce_util.ReduceOp.MEAN:
output *= (1. / len(value.values))
output = math_ops.add_n(value.values)
if reduce_op == reduce_util.ReduceOp.MEAN:
output *= (1. / len(value.values))
devices = cross_device_ops_lib.get_devices_from(destinations)