Add Checkpoint.read() that just loads a checkpoint without adding a save_counter
.
This is symmetrical to write(). Loading a checkpoint written with write() with read() instead of restore() allows to reliably call assert_existing_objects_matched() to check that all objects were read from the checkpoint. PiperOrigin-RevId: 301679428 Change-Id: I4acf3c2b7eb63ad25bb4db163bfca365e18bea6f
This commit is contained in:
parent
bfafc1acef
commit
4f315c18bc
@ -1847,6 +1847,8 @@ class Checkpoint(tracking.AutoTrackable):
|
||||
use by higher level checkpoint management utilities. `save` provides a very
|
||||
basic implementation of these features.
|
||||
|
||||
Checkpoints written with `write` must be read with `read`.
|
||||
|
||||
Args:
|
||||
file_prefix: A prefix to use for the checkpoint filenames
|
||||
(/path/to/directory/and_a_prefix).
|
||||
@ -1888,7 +1890,7 @@ class Checkpoint(tracking.AutoTrackable):
|
||||
sequentially numbering checkpoints using `save_counter` and updating the
|
||||
metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint
|
||||
management, for example garbage collection and custom numbering, may be
|
||||
provided by other utilities which also wrap `write`
|
||||
provided by other utilities which also wrap `write` and `read`.
|
||||
(`tf.train.CheckpointManager` for example).
|
||||
|
||||
Args:
|
||||
@ -1932,20 +1934,58 @@ class Checkpoint(tracking.AutoTrackable):
|
||||
save_relative_paths=True)
|
||||
return file_path
|
||||
|
||||
def read(self, save_path):
|
||||
"""Read a training checkpoint written with `write`.
|
||||
|
||||
Reads this `Checkpoint` and any objects it depends on.
|
||||
|
||||
This method is just like `restore()` but does not expect the `save_counter`
|
||||
variable in the checkpoint. It only restores the objects that the checkpoint
|
||||
already depends on.
|
||||
|
||||
The method is primarily intended for use by higher level checkpoint
|
||||
management utilities that use `write()` instead of `save()` and have their
|
||||
own mechanisms to number and track checkpoints.
|
||||
|
||||
Example usage:
|
||||
|
||||
```python
|
||||
# Create a checkpoint with write()
|
||||
ckpt = tf.train.Checkpoint(v=tf.Variable(1.))
|
||||
path = ckpt.write('/tmp/my_checkpoint')
|
||||
|
||||
# Later, load the checkpoint with read()
|
||||
# With restore() assert_consumed() would have failed.
|
||||
ckpt.read(path).assert_consumed()
|
||||
```
|
||||
|
||||
Args:
|
||||
save_path: The path to the checkpoint as returned by `write`.
|
||||
|
||||
Returns:
|
||||
A load status object, which can be used to make assertions about the
|
||||
status of a checkpoint restoration. See `restore` for details.
|
||||
"""
|
||||
return self._saver.restore(save_path=save_path)
|
||||
|
||||
def restore(self, save_path):
|
||||
"""Restore a training checkpoint.
|
||||
|
||||
Restores this `Checkpoint` and any objects it depends on.
|
||||
|
||||
Either assigns values immediately if variables to restore have been created
|
||||
already, or defers restoration until the variables are created. Dependencies
|
||||
added after this call will be matched if they have a corresponding object in
|
||||
the checkpoint (the restore request will queue in any trackable object
|
||||
waiting for the expected dependency to be added).
|
||||
This method is intended to be used to load checkpoints created by `save()`.
|
||||
For checkpoints created by `write()` use the `read()` method which does not
|
||||
expect the `save_counter` variable added by `save()`.
|
||||
|
||||
`restore()` either assigns values immediately if variables to restore have
|
||||
been created already, or defers restoration until the variables are
|
||||
created. Dependencies added after this call will be matched if they have a
|
||||
corresponding object in the checkpoint (the restore request will queue in
|
||||
any trackable object waiting for the expected dependency to be added).
|
||||
|
||||
To ensure that loading is complete and no more assignments will take place,
|
||||
use the `assert_consumed()` method of the status object returned by
|
||||
`restore`:
|
||||
`restore()`:
|
||||
|
||||
```python
|
||||
checkpoint = tf.train.Checkpoint( ... )
|
||||
@ -2006,7 +2046,7 @@ class Checkpoint(tracking.AutoTrackable):
|
||||
checkpoint file or object when the `Checkpoint` object is deleted
|
||||
(often at program shutdown).
|
||||
"""
|
||||
status = self._saver.restore(save_path=save_path)
|
||||
status = self.read(save_path)
|
||||
# Create the save counter now so it gets initialized with other variables
|
||||
# when graph building. Creating it earlier would lead to errors when using,
|
||||
# say, train.Saver() to save the model before initializing it.
|
||||
|
@ -1376,8 +1376,7 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_write_checkpoint_from_function(self):
|
||||
checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
save_checkpoint = trackable_utils.Checkpoint(
|
||||
v=variables_lib.Variable(1.))
|
||||
save_checkpoint = trackable_utils.Checkpoint(v=variables_lib.Variable(1.))
|
||||
|
||||
@def_function.function
|
||||
def _write_checkpoint():
|
||||
@ -1386,14 +1385,21 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
|
||||
|
||||
self.evaluate([save_checkpoint.v.initializer])
|
||||
self.evaluate(_write_checkpoint())
|
||||
load_checkpoint = trackable_utils.Checkpoint(
|
||||
v=variables_lib.Variable(0.))
|
||||
load_checkpoint.restore(checkpoint_prefix).run_restore_ops()
|
||||
load_checkpoint = trackable_utils.Checkpoint(v=variables_lib.Variable(0.))
|
||||
# Use read() instead of restore() which allows us to check that all
|
||||
# existing objects were loaded.
|
||||
status = load_checkpoint.read(checkpoint_prefix)
|
||||
status.assert_existing_objects_matched()
|
||||
status.assert_consumed()
|
||||
status.run_restore_ops()
|
||||
self.assertEqual(1., self.evaluate(load_checkpoint.v))
|
||||
self.evaluate(save_checkpoint.v.assign(3.))
|
||||
self.evaluate(_write_checkpoint())
|
||||
self.evaluate(save_checkpoint.v.assign(0.))
|
||||
load_checkpoint.restore(checkpoint_prefix).run_restore_ops()
|
||||
status = load_checkpoint.read(checkpoint_prefix)
|
||||
status.assert_existing_objects_matched()
|
||||
status.assert_consumed()
|
||||
status.run_restore_ops()
|
||||
self.assertEqual(3., self.evaluate(load_checkpoint.v))
|
||||
|
||||
def test_inititialize_with_data_structures(self):
|
||||
|
@ -12,6 +12,10 @@ tf_class {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "read"
|
||||
argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "restore"
|
||||
argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None"
|
||||
|
Loading…
x
Reference in New Issue
Block a user