Support restoring on creation for Sonnet model with TPUReplicator.
PiperOrigin-RevId: 277523053 Change-Id: I756233199b24f1755eb477a596c9eec2f909a86e
This commit is contained in:
parent
a2323c1b1f
commit
aff087ab42
@ -1261,8 +1261,17 @@ class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject):
|
||||
|
||||
def restore(self, restored_tensors, restored_shapes):
|
||||
"""Restore the same value into all variables."""
|
||||
# To preserve the sum across save and restore, we have to divide the
|
||||
# total across all devices when restoring a variable that was summed
|
||||
# when saving.
|
||||
tensor, = restored_tensors
|
||||
return self._sync_on_read_variable.assign(tensor)
|
||||
if self._sync_on_read_variable.aggregation == vs.VariableAggregation.SUM:
|
||||
tensor = math_ops.cast(tensor / len(self._sync_on_read_variable.devices),
|
||||
self._sync_on_read_variable.dtype)
|
||||
return control_flow_ops.group(
|
||||
tuple(
|
||||
_assign_on_device(v.device, v, tensor)
|
||||
for v in self._sync_on_read_variable.values))
|
||||
|
||||
|
||||
def _assert_replica_context(strategy):
|
||||
|
Loading…
Reference in New Issue
Block a user