Support restoring on creation for Sonnet model with TPUReplicator.
PiperOrigin-RevId: 277523053 Change-Id: I756233199b24f1755eb477a596c9eec2f909a86e
This commit is contained in:
parent
a2323c1b1f
commit
aff087ab42
@ -1261,8 +1261,17 @@ class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject):
|
|||||||
|
|
||||||
def restore(self, restored_tensors, restored_shapes):
|
def restore(self, restored_tensors, restored_shapes):
|
||||||
"""Restore the same value into all variables."""
|
"""Restore the same value into all variables."""
|
||||||
|
# To preserve the sum across save and restore, we have to divide the
|
||||||
|
# total across all devices when restoring a variable that was summed
|
||||||
|
# when saving.
|
||||||
tensor, = restored_tensors
|
tensor, = restored_tensors
|
||||||
return self._sync_on_read_variable.assign(tensor)
|
if self._sync_on_read_variable.aggregation == vs.VariableAggregation.SUM:
|
||||||
|
tensor = math_ops.cast(tensor / len(self._sync_on_read_variable.devices),
|
||||||
|
self._sync_on_read_variable.dtype)
|
||||||
|
return control_flow_ops.group(
|
||||||
|
tuple(
|
||||||
|
_assign_on_device(v.device, v, tensor)
|
||||||
|
for v in self._sync_on_read_variable.values))
|
||||||
|
|
||||||
|
|
||||||
def _assert_replica_context(strategy):
|
def _assert_replica_context(strategy):
|
||||||
|
Loading…
Reference in New Issue
Block a user