diff --git a/tensorflow/python/training/tracking/util.py b/tensorflow/python/training/tracking/util.py index e4138864bd3..eeaa2a541c5 100644 --- a/tensorflow/python/training/tracking/util.py +++ b/tensorflow/python/training/tracking/util.py @@ -1847,6 +1847,8 @@ class Checkpoint(tracking.AutoTrackable): use by higher level checkpoint management utilities. `save` provides a very basic implementation of these features. + Checkpoints written with `write` must be read with `read`. + Args: file_prefix: A prefix to use for the checkpoint filenames (/path/to/directory/and_a_prefix). @@ -1888,7 +1890,7 @@ class Checkpoint(tracking.AutoTrackable): sequentially numbering checkpoints using `save_counter` and updating the metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint management, for example garbage collection and custom numbering, may be - provided by other utilities which also wrap `write` + provided by other utilities which also wrap `write` and `read`. (`tf.train.CheckpointManager` for example). Args: @@ -1932,20 +1934,58 @@ class Checkpoint(tracking.AutoTrackable): save_relative_paths=True) return file_path + def read(self, save_path): + """Read a training checkpoint written with `write`. + + Reads this `Checkpoint` and any objects it depends on. + + This method is just like `restore()` but does not expect the `save_counter` + variable in the checkpoint. It only restores the objects that the checkpoint + already depends on. + + The method is primarily intended for use by higher level checkpoint + management utilities that use `write()` instead of `save()` and have their + own mechanisms to number and track checkpoints. + + Example usage: + + ```python + # Create a checkpoint with write() + ckpt = tf.train.Checkpoint(v=tf.Variable(1.)) + path = ckpt.write('/tmp/my_checkpoint') + + # Later, load the checkpoint with read() + # With restore() assert_consumed() would have failed. 
+ ckpt.read(path).assert_consumed() + ``` + + Args: + save_path: The path to the checkpoint as returned by `write`. + + Returns: + A load status object, which can be used to make assertions about the + status of a checkpoint restoration. See `restore` for details. + """ + return self._saver.restore(save_path=save_path) + def restore(self, save_path): """Restore a training checkpoint. Restores this `Checkpoint` and any objects it depends on. - Either assigns values immediately if variables to restore have been created - already, or defers restoration until the variables are created. Dependencies - added after this call will be matched if they have a corresponding object in - the checkpoint (the restore request will queue in any trackable object - waiting for the expected dependency to be added). + This method is intended to be used to load checkpoints created by `save()`. + For checkpoints created by `write()` use the `read()` method which does not + expect the `save_counter` variable added by `save()`. + + `restore()` either assigns values immediately if variables to restore have + been created already, or defers restoration until the variables are + created. Dependencies added after this call will be matched if they have a + corresponding object in the checkpoint (the restore request will queue in + any trackable object waiting for the expected dependency to be added). To ensure that loading is complete and no more assignments will take place, use the `assert_consumed()` method of the status object returned by - `restore`: + `restore()`: ```python checkpoint = tf.train.Checkpoint( ... ) @@ -2006,7 +2046,7 @@ class Checkpoint(tracking.AutoTrackable): checkpoint file or object when the `Checkpoint` object is deleted (often at program shutdown). """ - status = self._saver.restore(save_path=save_path) + status = self.read(save_path) # Create the save counter now so it gets initialized with other variables # when graph building. 
Creating it earlier would lead to errors when using, # say, train.Saver() to save the model before initializing it. diff --git a/tensorflow/python/training/tracking/util_test.py b/tensorflow/python/training/tracking/util_test.py index 6e57d690726..e63baa60003 100644 --- a/tensorflow/python/training/tracking/util_test.py +++ b/tensorflow/python/training/tracking/util_test.py @@ -1376,8 +1376,7 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase): @test_util.run_in_graph_and_eager_modes def test_write_checkpoint_from_function(self): checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt") - save_checkpoint = trackable_utils.Checkpoint( - v=variables_lib.Variable(1.)) + save_checkpoint = trackable_utils.Checkpoint(v=variables_lib.Variable(1.)) @def_function.function def _write_checkpoint(): @@ -1386,14 +1385,21 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase): self.evaluate([save_checkpoint.v.initializer]) self.evaluate(_write_checkpoint()) - load_checkpoint = trackable_utils.Checkpoint( - v=variables_lib.Variable(0.)) - load_checkpoint.restore(checkpoint_prefix).run_restore_ops() + load_checkpoint = trackable_utils.Checkpoint(v=variables_lib.Variable(0.)) + # Use read() instead of restore() which allows us to check that all + # existing objects were loaded. 
+ status = load_checkpoint.read(checkpoint_prefix) + status.assert_existing_objects_matched() + status.assert_consumed() + status.run_restore_ops() self.assertEqual(1., self.evaluate(load_checkpoint.v)) self.evaluate(save_checkpoint.v.assign(3.)) self.evaluate(_write_checkpoint()) self.evaluate(save_checkpoint.v.assign(0.)) - load_checkpoint.restore(checkpoint_prefix).run_restore_ops() + status = load_checkpoint.read(checkpoint_prefix) + status.assert_existing_objects_matched() + status.assert_consumed() + status.run_restore_ops() self.assertEqual(3., self.evaluate(load_checkpoint.v)) def test_inititialize_with_data_structures(self): diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt index deb93d7adca..d7e93a0f937 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt @@ -12,6 +12,10 @@ tf_class { name: "__init__" argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" } + member_method { + name: "read" + argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "restore" argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None"