From b17b0103a9b7578280bb4f50dc43f561443063aa Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Tue, 7 May 2019 10:38:48 -0700 Subject: [PATCH] tf.train.Checkpoint: Allow saving dictionaries with replaced dependencies We already allow reassigning attributes of regular objects (obj.x = v1; obj.x = v2) so doing it for dictionaries too seems reasonable. Lists are somewhat more complicated since e.g. deleting in the middle renames everything after, but potentially we can relax the check there too if it's getting in the way. PiperOrigin-RevId: 247047184 --- .../training/tracking/data_structures.py | 32 ------------------- .../training/tracking/data_structures_test.py | 14 ++------ 2 files changed, 2 insertions(+), 44 deletions(-) diff --git a/tensorflow/python/training/tracking/data_structures.py b/tensorflow/python/training/tracking/data_structures.py index 73df6872c27..1695e44bad2 100644 --- a/tensorflow/python/training/tracking/data_structures.py +++ b/tensorflow/python/training/tracking/data_structures.py @@ -665,7 +665,6 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy): wrapt.ObjectProxy.__init__(self, wrapped_dict) TrackableDataStructure.__init__(self) self._self_non_string_key = False - self._self_non_append_mutation = False self._self_external_modification = False self.__wrapped__.update( {key: self._track_value( @@ -690,14 +689,12 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy): # pylint: disable=protected-access def __copy__(self): copied = _DictWrapper(copy.copy(self.__wrapped__)) - copied._self_non_append_mutation = self._self_non_append_mutation copied._self_external_modification = self._self_external_modification copied._self_non_string_key = self._self_non_string_key return copied def __deepcopy__(self, memo): copied = _DictWrapper(copy.deepcopy(self.__wrapped__, memo)) - copied._self_non_append_mutation = self._self_non_append_mutation copied._self_external_modification = self._self_external_modification copied._self_non_string_key = self._self_non_string_key return copied @@ -725,15 +722,6 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy): "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency " "object; it will be automatically un-wrapped and subsequently " "ignored." % (self,)) - if self._self_non_append_mutation: - raise ValueError( - "Unable to save the object %s (a dictionary wrapper constructed " - "automatically on attribute assignment). A key mapping to a " - "trackable object was overwritten or deleted, which would " - "cause problems for restoration.\n\nIf you don't need this " - "dictionary checkpointed, wrap it in a " - "tf.contrib.checkpoint.NoDependency object; it will be automatically " - "un-wrapped and subsequently ignored." % (self,)) if self._self_external_modification: raise ValueError( "Unable to save the object %s (a dictionary wrapper constructed " @@ -752,7 +740,6 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy): def _dirty(self): """Check if there has already been a mutation which prevents saving.""" return (self._self_external_modification - or self._self_non_append_mutation or self._self_non_string_key) def _check_self_external_modification(self): @@ -800,39 +787,20 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy): self._maybe_initialize_trackable() no_dep = isinstance(value, NoDependency) if isinstance(key, six.string_types): - existing_dependency = self._lookup_dependency(key) value = self._track_value(value, name=key) else: value = _wrap_or_unwrap(value) - existing_dependency = None if not no_dep and isinstance(value, base.Trackable): # Non-string keys are OK as long as we have no reason to add a # dependency on the value (either because the value is not # trackable, or because it was wrapped in a NoDependency object). self._self_non_string_key = True - if key in self.__wrapped__: - previous_value = self.__wrapped__[key] - if previous_value is not value: - if ((not no_dep and isinstance(value, base.Trackable)) - # We don't want to just check that the existing object is - # trackable, since it may have been wrapped in a NoDependency - # object. - or existing_dependency is not None): - # A trackable object was replaced under the same key; this means - # that restoring would be error-prone, so we'll throw an exception on - # save. - self._self_non_append_mutation = True self.__wrapped__[key] = value self._update_snapshot() def __delitem__(self, key): self._check_self_external_modification() - existing_value = self[key] - if isinstance(existing_value, base.Trackable): - # Deleting tracked trackable values means restoring is problematic, - # so we'll throw an exception on save. - self._self_non_append_mutation = True del self.__wrapped__[key] self._update_snapshot() diff --git a/tensorflow/python/training/tracking/data_structures_test.py b/tensorflow/python/training/tracking/data_structures_test.py index 2746c40e8da..42d75df460d 100644 --- a/tensorflow/python/training/tracking/data_structures_test.py +++ b/tensorflow/python/training/tracking/data_structures_test.py @@ -663,15 +663,6 @@ class MappingTests(test.TestCase): model.save_weights(save_path) model.load_weights(save_path) - def testDelNoSave(self): - model = training.Model() - model.d = {} - model.d["a"] = [] - del model.d["a"] - save_path = os.path.join(self.get_temp_dir(), "ckpt") - with self.assertRaisesRegexp(ValueError, "overwritten or deleted"): - model.save_weights(save_path) - def testPopNoSave(self): model = training.Model() model.d = {} @@ -690,14 +681,13 @@ class MappingTests(test.TestCase): with self.assertRaisesRegexp(ValueError, "modified outside the wrapper"): model.save_weights(save_path) - def testOverwriteNoSave(self): + def testOverwriteCanStillSave(self): model = training.Model() model.d = {} model.d["a"] = {} model.d["a"] = {} save_path = os.path.join(self.get_temp_dir(), "ckpt") - with self.assertRaisesRegexp(ValueError, "overwritten or deleted"): - model.save_weights(save_path) + model.save_weights(save_path) def testIter(self): model = training.Model()