diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 4d45bfa7a9b..da99887b0f2 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -1261,8 +1261,17 @@ class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject): def restore(self, restored_tensors, restored_shapes): """Restore the same value into all variables.""" + # To preserve the sum across save and restore, we have to divide the + # total across all devices when restoring a variable that was summed + # when saving. tensor, = restored_tensors - return self._sync_on_read_variable.assign(tensor) + if self._sync_on_read_variable.aggregation == vs.VariableAggregation.SUM: + tensor = math_ops.cast(tensor / len(self._sync_on_read_variable.devices), + self._sync_on_read_variable.dtype) + return control_flow_ops.group( + tuple( + _assign_on_device(v.device, v, tensor) + for v in self._sync_on_read_variable.values)) def _assert_replica_context(strategy):