tf.train.Checkpoint: Allow saving dictionaries with replaced dependencies

We already allow reassigning attributes of regular objects (obj.x = v1; obj.x = v2) so doing it for dictionaries too seems reasonable.

Lists are somewhat more complicated since e.g. deleting in the middle renames everything after, but potentially we can relax the check there too if it's getting in the way.

PiperOrigin-RevId: 247047184
This commit is contained in:
Allen Lavoie 2019-05-07 10:38:48 -07:00 committed by TensorFlower Gardener
parent 1b27b3cacd
commit b17b0103a9
2 changed files with 2 additions and 44 deletions

View File

@ -665,7 +665,6 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
wrapt.ObjectProxy.__init__(self, wrapped_dict)
TrackableDataStructure.__init__(self)
self._self_non_string_key = False
self._self_non_append_mutation = False
self._self_external_modification = False
self.__wrapped__.update(
{key: self._track_value(
@ -690,14 +689,12 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
# pylint: disable=protected-access
def __copy__(self):
copied = _DictWrapper(copy.copy(self.__wrapped__))
copied._self_non_append_mutation = self._self_non_append_mutation
copied._self_external_modification = self._self_external_modification
copied._self_non_string_key = self._self_non_string_key
return copied
def __deepcopy__(self, memo):
copied = _DictWrapper(copy.deepcopy(self.__wrapped__, memo))
copied._self_non_append_mutation = self._self_non_append_mutation
copied._self_external_modification = self._self_external_modification
copied._self_non_string_key = self._self_non_string_key
return copied
@ -725,15 +722,6 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
"checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency "
"object; it will be automatically un-wrapped and subsequently "
"ignored." % (self,))
if self._self_non_append_mutation:
raise ValueError(
"Unable to save the object %s (a dictionary wrapper constructed "
"automatically on attribute assignment). A key mapping to a "
"trackable object was overwritten or deleted, which would "
"cause problems for restoration.\n\nIf you don't need this "
"dictionary checkpointed, wrap it in a "
"tf.contrib.checkpoint.NoDependency object; it will be automatically "
"un-wrapped and subsequently ignored." % (self,))
if self._self_external_modification:
raise ValueError(
"Unable to save the object %s (a dictionary wrapper constructed "
@ -752,7 +740,6 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
def _dirty(self):
"""Check if there has already been a mutation which prevents saving."""
return (self._self_external_modification
or self._self_non_append_mutation
or self._self_non_string_key)
def _check_self_external_modification(self):
@ -800,39 +787,20 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
self._maybe_initialize_trackable()
no_dep = isinstance(value, NoDependency)
if isinstance(key, six.string_types):
existing_dependency = self._lookup_dependency(key)
value = self._track_value(value, name=key)
else:
value = _wrap_or_unwrap(value)
existing_dependency = None
if not no_dep and isinstance(value, base.Trackable):
# Non-string keys are OK as long as we have no reason to add a
# dependency on the value (either because the value is not
# trackable, or because it was wrapped in a NoDependency object).
self._self_non_string_key = True
if key in self.__wrapped__:
previous_value = self.__wrapped__[key]
if previous_value is not value:
if ((not no_dep and isinstance(value, base.Trackable))
# We don't want to just check that the existing object is
# trackable, since it may have been wrapped in a NoDependency
# object.
or existing_dependency is not None):
# A trackable object was replaced under the same key; this means
# that restoring would be error-prone, so we'll throw an exception on
# save.
self._self_non_append_mutation = True
self.__wrapped__[key] = value
self._update_snapshot()
def __delitem__(self, key):
self._check_self_external_modification()
existing_value = self[key]
if isinstance(existing_value, base.Trackable):
# Deleting tracked trackable values means restoring is problematic,
# so we'll throw an exception on save.
self._self_non_append_mutation = True
del self.__wrapped__[key]
self._update_snapshot()

View File

@ -663,15 +663,6 @@ class MappingTests(test.TestCase):
model.save_weights(save_path)
model.load_weights(save_path)
def testDelNoSave(self):
model = training.Model()
model.d = {}
model.d["a"] = []
del model.d["a"]
save_path = os.path.join(self.get_temp_dir(), "ckpt")
with self.assertRaisesRegexp(ValueError, "overwritten or deleted"):
model.save_weights(save_path)
def testPopNoSave(self):
model = training.Model()
model.d = {}
@ -690,14 +681,13 @@ class MappingTests(test.TestCase):
with self.assertRaisesRegexp(ValueError, "modified outside the wrapper"):
model.save_weights(save_path)
def testOverwriteNoSave(self):
def testOverwriteCanStillSave(self):
model = training.Model()
model.d = {}
model.d["a"] = {}
model.d["a"] = {}
save_path = os.path.join(self.get_temp_dir(), "ckpt")
with self.assertRaisesRegexp(ValueError, "overwritten or deleted"):
model.save_weights(save_path)
model.save_weights(save_path)
def testIter(self):
model = training.Model()