From aff087ab4216e3c80788fd62481325dcaa3dbd0a Mon Sep 17 00:00:00 2001 From: Ruoxin Sang Date: Wed, 30 Oct 2019 09:39:59 -0700 Subject: [PATCH] Support restoring on creation for Sonnet model with TPUReplicator. PiperOrigin-RevId: 277523053 Change-Id: I756233199b24f1755eb477a596c9eec2f909a86e --- tensorflow/python/distribute/values.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 4d45bfa7a9b..da99887b0f2 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -1261,8 +1261,17 @@ class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject): def restore(self, restored_tensors, restored_shapes): """Restore the same value into all variables.""" + # To preserve the sum across save and restore, we have to divide the + # total across all devices when restoring a variable that was summed + # when saving. tensor, = restored_tensors - return self._sync_on_read_variable.assign(tensor) + if self._sync_on_read_variable.aggregation == vs.VariableAggregation.SUM: + tensor = math_ops.cast(tensor / len(self._sync_on_read_variable.devices), + self._sync_on_read_variable.dtype) + return control_flow_ops.group( + tuple( + _assign_on_device(v.device, v, tensor) + for v in self._sync_on_read_variable.values)) def _assert_replica_context(strategy):