Support restoring on creation for Sonnet model with TPUReplicator.

PiperOrigin-RevId: 277523053
Change-Id: I756233199b24f1755eb477a596c9eec2f909a86e
This commit is contained in:
Ruoxin Sang 2019-10-30 09:39:59 -07:00 committed by TensorFlower Gardener
parent a2323c1b1f
commit aff087ab42

View File

@ -1261,8 +1261,17 @@ class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject):
def restore(self, restored_tensors, restored_shapes): def restore(self, restored_tensors, restored_shapes):
"""Restore the same value into all variables.""" """Restore the same value into all variables."""
# To preserve the sum across save and restore, we have to divide the
# total across all devices when restoring a variable that was summed
# when saving.
tensor, = restored_tensors tensor, = restored_tensors
return self._sync_on_read_variable.assign(tensor) if self._sync_on_read_variable.aggregation == vs.VariableAggregation.SUM:
tensor = math_ops.cast(tensor / len(self._sync_on_read_variable.devices),
self._sync_on_read_variable.dtype)
return control_flow_ops.group(
tuple(
_assign_on_device(v.device, v, tensor)
for v in self._sync_on_read_variable.values))
def _assert_replica_context(strategy): def _assert_replica_context(strategy):