Add Checkpoint.read() that just loads a checkpoint without adding a save_counter.

This is symmetrical to write(). Loading a checkpoint written with write() via
read() instead of restore() makes it possible to reliably call
assert_existing_objects_matched() to check that all objects were read from the
checkpoint.
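
For illustration, a minimal sketch of the intended usage (the path and variable
names are made up for the example):

```python
import tensorflow as tf

# Write a checkpoint without the save_counter that save() would add.
ckpt = tf.train.Checkpoint(v=tf.Variable(1.))
path = ckpt.write('/tmp/example_ckpt')

# read() does not expect a save_counter in the file, so both assertions
# below can pass; with restore(), the loader's save_counter would have no
# match in the checkpoint.
loader = tf.train.Checkpoint(v=tf.Variable(0.))
status = loader.read(path)
status.assert_existing_objects_matched()
status.assert_consumed()
```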

PiperOrigin-RevId: 301679428
Change-Id: I4acf3c2b7eb63ad25bb4db163bfca365e18bea6f
A. Unique TensorFlower 2020-03-18 15:15:03 -07:00 committed by TensorFlower Gardener
parent bfafc1acef
commit 4f315c18bc
3 changed files with 64 additions and 14 deletions


@@ -1847,6 +1847,8 @@ class Checkpoint(tracking.AutoTrackable):
use by higher level checkpoint management utilities. `save` provides a very
basic implementation of these features.
Checkpoints written with `write` must be read with `read`.
Args:
file_prefix: A prefix to use for the checkpoint filenames
(/path/to/directory/and_a_prefix).
@@ -1888,7 +1890,7 @@ class Checkpoint(tracking.AutoTrackable):
sequentially numbering checkpoints using `save_counter` and updating the
metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint
management, for example garbage collection and custom numbering, may be
provided by other utilities which also wrap `write`
provided by other utilities which also wrap `write` and `read`.
(`tf.train.CheckpointManager` for example).
Args:
@@ -1932,20 +1934,58 @@ class Checkpoint(tracking.AutoTrackable):
save_relative_paths=True)
return file_path
def read(self, save_path):
"""Read a training checkpoint written with `write`.
Reads this `Checkpoint` and any objects it depends on.
This method is just like `restore()` but does not expect the `save_counter`
variable in the checkpoint. It only restores the objects that this
`Checkpoint` object already depends on.
The method is primarily intended for use by higher level checkpoint
management utilities that use `write()` instead of `save()` and have their
own mechanisms to number and track checkpoints.
Example usage:
```python
# Create a checkpoint with write()
ckpt = tf.train.Checkpoint(v=tf.Variable(1.))
path = ckpt.write('/tmp/my_checkpoint')
# Later, load the checkpoint with read()
# With restore(), assert_consumed() would have failed.
ckpt.read(path).assert_consumed()
```
Args:
save_path: The path to the checkpoint as returned by `write`.
Returns:
A load status object, which can be used to make assertions about the
status of a checkpoint restoration. See `restore` for details.
"""
return self._saver.restore(save_path=save_path)
def restore(self, save_path):
"""Restore a training checkpoint.
Restores this `Checkpoint` and any objects it depends on.
Either assigns values immediately if variables to restore have been created
already, or defers restoration until the variables are created. Dependencies
added after this call will be matched if they have a corresponding object in
the checkpoint (the restore request will queue in any trackable object
waiting for the expected dependency to be added).
This method is intended to be used to load checkpoints created by `save()`.
For checkpoints created by `write()` use the `read()` method which does not
expect the `save_counter` variable added by `save()`.
`restore()` either assigns values immediately if variables to restore have
been created already, or defers restoration until the variables are
created. Dependencies added after this call will be matched if they have a
corresponding object in the checkpoint (the restore request will queue in
any trackable object waiting for the expected dependency to be added).
To ensure that loading is complete and no more assignments will take place,
use the `assert_consumed()` method of the status object returned by
`restore`:
`restore()`:
```python
checkpoint = tf.train.Checkpoint( ... )
@@ -2006,7 +2046,7 @@ class Checkpoint(tracking.AutoTrackable):
checkpoint file or object when the `Checkpoint` object is deleted
(often at program shutdown).
"""
status = self._saver.restore(save_path=save_path)
status = self.read(save_path)
# Create the save counter now so it gets initialized with other variables
# when graph building. Creating it earlier would lead to errors when using,
# say, train.Saver() to save the model before initializing it.

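The `save` docstring above notes that more advanced checkpoint management can
be built from utilities that wrap `write` and `read`. As a rough, hypothetical
sketch (this helper class is not part of the change), a caller could keep its
own numbering:

```python
import tensorflow as tf

class SimpleManager:
  """Hypothetical wrapper that numbers checkpoints itself via write()/read()."""

  def __init__(self, checkpoint, prefix):
    self._checkpoint = checkpoint
    self._prefix = prefix
    self._step = 0

  def save(self):
    # write() does not touch save_counter, so the numbering stays fully ours.
    path = self._checkpoint.write(f"{self._prefix}-{self._step}")
    self._step += 1
    return path

  def load(self, path):
    # read() does not expect a save_counter in a file written by write().
    return self._checkpoint.read(path)

ckpt = tf.train.Checkpoint(v=tf.Variable(1.))
manager = SimpleManager(ckpt, "/tmp/demo_ckpt")
path = manager.save()
manager.load(path).assert_existing_objects_matched()
```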

@@ -1376,8 +1376,7 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_write_checkpoint_from_function(self):
checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
save_checkpoint = trackable_utils.Checkpoint(
v=variables_lib.Variable(1.))
save_checkpoint = trackable_utils.Checkpoint(v=variables_lib.Variable(1.))
@def_function.function
def _write_checkpoint():
@@ -1386,14 +1385,21 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
self.evaluate([save_checkpoint.v.initializer])
self.evaluate(_write_checkpoint())
load_checkpoint = trackable_utils.Checkpoint(
v=variables_lib.Variable(0.))
load_checkpoint.restore(checkpoint_prefix).run_restore_ops()
load_checkpoint = trackable_utils.Checkpoint(v=variables_lib.Variable(0.))
# Use read() instead of restore() so that we can check that all existing
# objects were loaded.
status = load_checkpoint.read(checkpoint_prefix)
status.assert_existing_objects_matched()
status.assert_consumed()
status.run_restore_ops()
self.assertEqual(1., self.evaluate(load_checkpoint.v))
self.evaluate(save_checkpoint.v.assign(3.))
self.evaluate(_write_checkpoint())
self.evaluate(save_checkpoint.v.assign(0.))
load_checkpoint.restore(checkpoint_prefix).run_restore_ops()
status = load_checkpoint.read(checkpoint_prefix)
status.assert_existing_objects_matched()
status.assert_consumed()
status.run_restore_ops()
self.assertEqual(3., self.evaluate(load_checkpoint.v))
def test_inititialize_with_data_structures(self):


@@ -12,6 +12,10 @@ tf_class {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "read"
argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "restore"
argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None"