tf.train.Checkpoint: Allow saving dictionaries with replaced dependencies
We already allow reassigning attributes of regular objects (obj.x = v1; obj.x = v2) so doing it for dictionaries too seems reasonable. Lists are somewhat more complicated since e.g. deleting in the middle renames everything after, but potentially we can relax the check there too if it's getting in the way. PiperOrigin-RevId: 247047184
This commit is contained in:
parent
1b27b3cacd
commit
b17b0103a9
@ -665,7 +665,6 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
|
||||
wrapt.ObjectProxy.__init__(self, wrapped_dict)
|
||||
TrackableDataStructure.__init__(self)
|
||||
self._self_non_string_key = False
|
||||
self._self_non_append_mutation = False
|
||||
self._self_external_modification = False
|
||||
self.__wrapped__.update(
|
||||
{key: self._track_value(
|
||||
@ -690,14 +689,12 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
|
||||
# pylint: disable=protected-access
|
||||
def __copy__(self):
|
||||
copied = _DictWrapper(copy.copy(self.__wrapped__))
|
||||
copied._self_non_append_mutation = self._self_non_append_mutation
|
||||
copied._self_external_modification = self._self_external_modification
|
||||
copied._self_non_string_key = self._self_non_string_key
|
||||
return copied
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
copied = _DictWrapper(copy.deepcopy(self.__wrapped__, memo))
|
||||
copied._self_non_append_mutation = self._self_non_append_mutation
|
||||
copied._self_external_modification = self._self_external_modification
|
||||
copied._self_non_string_key = self._self_non_string_key
|
||||
return copied
|
||||
@ -725,15 +722,6 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
|
||||
"checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency "
|
||||
"object; it will be automatically un-wrapped and subsequently "
|
||||
"ignored." % (self,))
|
||||
if self._self_non_append_mutation:
|
||||
raise ValueError(
|
||||
"Unable to save the object %s (a dictionary wrapper constructed "
|
||||
"automatically on attribute assignment). A key mapping to a "
|
||||
"trackable object was overwritten or deleted, which would "
|
||||
"cause problems for restoration.\n\nIf you don't need this "
|
||||
"dictionary checkpointed, wrap it in a "
|
||||
"tf.contrib.checkpoint.NoDependency object; it will be automatically "
|
||||
"un-wrapped and subsequently ignored." % (self,))
|
||||
if self._self_external_modification:
|
||||
raise ValueError(
|
||||
"Unable to save the object %s (a dictionary wrapper constructed "
|
||||
@ -752,7 +740,6 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
|
||||
def _dirty(self):
|
||||
"""Check if there has already been a mutation which prevents saving."""
|
||||
return (self._self_external_modification
|
||||
or self._self_non_append_mutation
|
||||
or self._self_non_string_key)
|
||||
|
||||
def _check_self_external_modification(self):
|
||||
@ -800,39 +787,20 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
|
||||
self._maybe_initialize_trackable()
|
||||
no_dep = isinstance(value, NoDependency)
|
||||
if isinstance(key, six.string_types):
|
||||
existing_dependency = self._lookup_dependency(key)
|
||||
value = self._track_value(value, name=key)
|
||||
else:
|
||||
value = _wrap_or_unwrap(value)
|
||||
existing_dependency = None
|
||||
if not no_dep and isinstance(value, base.Trackable):
|
||||
# Non-string keys are OK as long as we have no reason to add a
|
||||
# dependency on the value (either because the value is not
|
||||
# trackable, or because it was wrapped in a NoDependency object).
|
||||
self._self_non_string_key = True
|
||||
if key in self.__wrapped__:
|
||||
previous_value = self.__wrapped__[key]
|
||||
if previous_value is not value:
|
||||
if ((not no_dep and isinstance(value, base.Trackable))
|
||||
# We don't want to just check that the existing object is
|
||||
# trackable, since it may have been wrapped in a NoDependency
|
||||
# object.
|
||||
or existing_dependency is not None):
|
||||
# A trackable object was replaced under the same key; this means
|
||||
# that restoring would be error-prone, so we'll throw an exception on
|
||||
# save.
|
||||
self._self_non_append_mutation = True
|
||||
self.__wrapped__[key] = value
|
||||
|
||||
self._update_snapshot()
|
||||
|
||||
def __delitem__(self, key):
|
||||
self._check_self_external_modification()
|
||||
existing_value = self[key]
|
||||
if isinstance(existing_value, base.Trackable):
|
||||
# Deleting tracked trackable values means restoring is problematic,
|
||||
# so we'll throw an exception on save.
|
||||
self._self_non_append_mutation = True
|
||||
del self.__wrapped__[key]
|
||||
self._update_snapshot()
|
||||
|
||||
|
@ -663,15 +663,6 @@ class MappingTests(test.TestCase):
|
||||
model.save_weights(save_path)
|
||||
model.load_weights(save_path)
|
||||
|
||||
def testDelNoSave(self):
|
||||
model = training.Model()
|
||||
model.d = {}
|
||||
model.d["a"] = []
|
||||
del model.d["a"]
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
with self.assertRaisesRegexp(ValueError, "overwritten or deleted"):
|
||||
model.save_weights(save_path)
|
||||
|
||||
def testPopNoSave(self):
|
||||
model = training.Model()
|
||||
model.d = {}
|
||||
@ -690,13 +681,12 @@ class MappingTests(test.TestCase):
|
||||
with self.assertRaisesRegexp(ValueError, "modified outside the wrapper"):
|
||||
model.save_weights(save_path)
|
||||
|
||||
def testOverwriteNoSave(self):
|
||||
def testOverwriteCanStillSave(self):
|
||||
model = training.Model()
|
||||
model.d = {}
|
||||
model.d["a"] = {}
|
||||
model.d["a"] = {}
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
with self.assertRaisesRegexp(ValueError, "overwritten or deleted"):
|
||||
model.save_weights(save_path)
|
||||
|
||||
def testIter(self):
|
||||
|
Loading…
Reference in New Issue
Block a user