tf.train.Checkpoint: Allow saving dictionaries with replaced dependencies

We already allow reassigning attributes of regular objects (obj.x = v1; obj.x = v2) so doing it for dictionaries too seems reasonable.

Lists are somewhat more complicated since e.g. deleting in the middle renames everything after, but potentially we can relax the check there too if it's getting in the way.

PiperOrigin-RevId: 247047184
This commit is contained in:
Allen Lavoie 2019-05-07 10:38:48 -07:00 committed by TensorFlower Gardener
parent 1b27b3cacd
commit b17b0103a9
2 changed files with 2 additions and 44 deletions

View File

@ -665,7 +665,6 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
wrapt.ObjectProxy.__init__(self, wrapped_dict) wrapt.ObjectProxy.__init__(self, wrapped_dict)
TrackableDataStructure.__init__(self) TrackableDataStructure.__init__(self)
self._self_non_string_key = False self._self_non_string_key = False
self._self_non_append_mutation = False
self._self_external_modification = False self._self_external_modification = False
self.__wrapped__.update( self.__wrapped__.update(
{key: self._track_value( {key: self._track_value(
@ -690,14 +689,12 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
# pylint: disable=protected-access # pylint: disable=protected-access
def __copy__(self): def __copy__(self):
copied = _DictWrapper(copy.copy(self.__wrapped__)) copied = _DictWrapper(copy.copy(self.__wrapped__))
copied._self_non_append_mutation = self._self_non_append_mutation
copied._self_external_modification = self._self_external_modification copied._self_external_modification = self._self_external_modification
copied._self_non_string_key = self._self_non_string_key copied._self_non_string_key = self._self_non_string_key
return copied return copied
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
copied = _DictWrapper(copy.deepcopy(self.__wrapped__, memo)) copied = _DictWrapper(copy.deepcopy(self.__wrapped__, memo))
copied._self_non_append_mutation = self._self_non_append_mutation
copied._self_external_modification = self._self_external_modification copied._self_external_modification = self._self_external_modification
copied._self_non_string_key = self._self_non_string_key copied._self_non_string_key = self._self_non_string_key
return copied return copied
@ -725,15 +722,6 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
"checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency " "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency "
"object; it will be automatically un-wrapped and subsequently " "object; it will be automatically un-wrapped and subsequently "
"ignored." % (self,)) "ignored." % (self,))
if self._self_non_append_mutation:
raise ValueError(
"Unable to save the object %s (a dictionary wrapper constructed "
"automatically on attribute assignment). A key mapping to a "
"trackable object was overwritten or deleted, which would "
"cause problems for restoration.\n\nIf you don't need this "
"dictionary checkpointed, wrap it in a "
"tf.contrib.checkpoint.NoDependency object; it will be automatically "
"un-wrapped and subsequently ignored." % (self,))
if self._self_external_modification: if self._self_external_modification:
raise ValueError( raise ValueError(
"Unable to save the object %s (a dictionary wrapper constructed " "Unable to save the object %s (a dictionary wrapper constructed "
@ -752,7 +740,6 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
def _dirty(self): def _dirty(self):
"""Check if there has already been a mutation which prevents saving.""" """Check if there has already been a mutation which prevents saving."""
return (self._self_external_modification return (self._self_external_modification
or self._self_non_append_mutation
or self._self_non_string_key) or self._self_non_string_key)
def _check_self_external_modification(self): def _check_self_external_modification(self):
@ -800,39 +787,20 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
self._maybe_initialize_trackable() self._maybe_initialize_trackable()
no_dep = isinstance(value, NoDependency) no_dep = isinstance(value, NoDependency)
if isinstance(key, six.string_types): if isinstance(key, six.string_types):
existing_dependency = self._lookup_dependency(key)
value = self._track_value(value, name=key) value = self._track_value(value, name=key)
else: else:
value = _wrap_or_unwrap(value) value = _wrap_or_unwrap(value)
existing_dependency = None
if not no_dep and isinstance(value, base.Trackable): if not no_dep and isinstance(value, base.Trackable):
# Non-string keys are OK as long as we have no reason to add a # Non-string keys are OK as long as we have no reason to add a
# dependency on the value (either because the value is not # dependency on the value (either because the value is not
# trackable, or because it was wrapped in a NoDependency object). # trackable, or because it was wrapped in a NoDependency object).
self._self_non_string_key = True self._self_non_string_key = True
if key in self.__wrapped__:
previous_value = self.__wrapped__[key]
if previous_value is not value:
if ((not no_dep and isinstance(value, base.Trackable))
# We don't want to just check that the existing object is
# trackable, since it may have been wrapped in a NoDependency
# object.
or existing_dependency is not None):
# A trackable object was replaced under the same key; this means
# that restoring would be error-prone, so we'll throw an exception on
# save.
self._self_non_append_mutation = True
self.__wrapped__[key] = value self.__wrapped__[key] = value
self._update_snapshot() self._update_snapshot()
def __delitem__(self, key): def __delitem__(self, key):
self._check_self_external_modification() self._check_self_external_modification()
existing_value = self[key]
if isinstance(existing_value, base.Trackable):
# Deleting tracked trackable values means restoring is problematic,
# so we'll throw an exception on save.
self._self_non_append_mutation = True
del self.__wrapped__[key] del self.__wrapped__[key]
self._update_snapshot() self._update_snapshot()

View File

@ -663,15 +663,6 @@ class MappingTests(test.TestCase):
model.save_weights(save_path) model.save_weights(save_path)
model.load_weights(save_path) model.load_weights(save_path)
def testDelNoSave(self):
model = training.Model()
model.d = {}
model.d["a"] = []
del model.d["a"]
save_path = os.path.join(self.get_temp_dir(), "ckpt")
with self.assertRaisesRegexp(ValueError, "overwritten or deleted"):
model.save_weights(save_path)
def testPopNoSave(self): def testPopNoSave(self):
model = training.Model() model = training.Model()
model.d = {} model.d = {}
@ -690,14 +681,13 @@ class MappingTests(test.TestCase):
with self.assertRaisesRegexp(ValueError, "modified outside the wrapper"): with self.assertRaisesRegexp(ValueError, "modified outside the wrapper"):
model.save_weights(save_path) model.save_weights(save_path)
def testOverwriteNoSave(self): def testOverwriteCanStillSave(self):
model = training.Model() model = training.Model()
model.d = {} model.d = {}
model.d["a"] = {} model.d["a"] = {}
model.d["a"] = {} model.d["a"] = {}
save_path = os.path.join(self.get_temp_dir(), "ckpt") save_path = os.path.join(self.get_temp_dir(), "ckpt")
with self.assertRaisesRegexp(ValueError, "overwritten or deleted"): model.save_weights(save_path)
model.save_weights(save_path)
def testIter(self): def testIter(self):
model = training.Model() model = training.Model()