Add Checkpoint.read() that just loads a checkpoint without adding a save_counter.
This is symmetrical to write(). Loading a checkpoint written with write() using read() instead of restore() makes it possible to reliably call assert_existing_objects_matched() to check that all objects were read from the checkpoint. PiperOrigin-RevId: 301679428 Change-Id: I4acf3c2b7eb63ad25bb4db163bfca365e18bea6f
This commit is contained in:
parent
bfafc1acef
commit
4f315c18bc
@ -1847,6 +1847,8 @@ class Checkpoint(tracking.AutoTrackable):
|
|||||||
use by higher level checkpoint management utilities. `save` provides a very
|
use by higher level checkpoint management utilities. `save` provides a very
|
||||||
basic implementation of these features.
|
basic implementation of these features.
|
||||||
|
|
||||||
|
Checkpoints written with `write` must be read with `read`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_prefix: A prefix to use for the checkpoint filenames
|
file_prefix: A prefix to use for the checkpoint filenames
|
||||||
(/path/to/directory/and_a_prefix).
|
(/path/to/directory/and_a_prefix).
|
||||||
@ -1888,7 +1890,7 @@ class Checkpoint(tracking.AutoTrackable):
|
|||||||
sequentially numbering checkpoints using `save_counter` and updating the
|
sequentially numbering checkpoints using `save_counter` and updating the
|
||||||
metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint
|
metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint
|
||||||
management, for example garbage collection and custom numbering, may be
|
management, for example garbage collection and custom numbering, may be
|
||||||
provided by other utilities which also wrap `write`
|
provided by other utilities which also wrap `write` and `read`.
|
||||||
(`tf.train.CheckpointManager` for example).
|
(`tf.train.CheckpointManager` for example).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1932,20 +1934,58 @@ class Checkpoint(tracking.AutoTrackable):
|
|||||||
save_relative_paths=True)
|
save_relative_paths=True)
|
||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
|
def read(self, save_path):
|
||||||
|
"""Read a training checkpoint written with `write`.
|
||||||
|
|
||||||
|
Reads this `Checkpoint` and any objects it depends on.
|
||||||
|
|
||||||
|
This method is just like `restore()` but does not expect the `save_counter`
|
||||||
|
variable in the checkpoint. It only restores the objects that the checkpoint
|
||||||
|
already depends on.
|
||||||
|
|
||||||
|
The method is primarily intended for use by higher level checkpoint
|
||||||
|
management utilities that use `write()` instead of `save()` and have their
|
||||||
|
own mechanisms to number and track checkpoints.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Create a checkpoint with write()
|
||||||
|
ckpt = tf.train.Checkpoint(v=tf.Variable(1.))
|
||||||
|
path = ckpt.write('/tmp/my_checkpoint')
|
||||||
|
|
||||||
|
# Later, load the checkpoint with read()
|
||||||
|
# With restore(), assert_consumed() would have failed.
|
||||||
|
ckpt.read(path).assert_consumed()
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_path: The path to the checkpoint as returned by `write`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A load status object, which can be used to make assertions about the
|
||||||
|
status of a checkpoint restoration. See `restore` for details.
|
||||||
|
"""
|
||||||
|
return self._saver.restore(save_path=save_path)
|
||||||
|
|
||||||
def restore(self, save_path):
|
def restore(self, save_path):
|
||||||
"""Restore a training checkpoint.
|
"""Restore a training checkpoint.
|
||||||
|
|
||||||
Restores this `Checkpoint` and any objects it depends on.
|
Restores this `Checkpoint` and any objects it depends on.
|
||||||
|
|
||||||
Either assigns values immediately if variables to restore have been created
|
This method is intended to be used to load checkpoints created by `save()`.
|
||||||
already, or defers restoration until the variables are created. Dependencies
|
For checkpoints created by `write()` use the `read()` method which does not
|
||||||
added after this call will be matched if they have a corresponding object in
|
expect the `save_counter` variable added by `save()`.
|
||||||
the checkpoint (the restore request will queue in any trackable object
|
|
||||||
waiting for the expected dependency to be added).
|
`restore()` either assigns values immediately if variables to restore have
|
||||||
|
been created already, or defers restoration until the variables are
|
||||||
|
created. Dependencies added after this call will be matched if they have a
|
||||||
|
corresponding object in the checkpoint (the restore request will queue in
|
||||||
|
any trackable object waiting for the expected dependency to be added).
|
||||||
|
|
||||||
To ensure that loading is complete and no more assignments will take place,
|
To ensure that loading is complete and no more assignments will take place,
|
||||||
use the `assert_consumed()` method of the status object returned by
|
use the `assert_consumed()` method of the status object returned by
|
||||||
`restore`:
|
`restore()`:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
checkpoint = tf.train.Checkpoint( ... )
|
checkpoint = tf.train.Checkpoint( ... )
|
||||||
@ -2006,7 +2046,7 @@ class Checkpoint(tracking.AutoTrackable):
|
|||||||
checkpoint file or object when the `Checkpoint` object is deleted
|
checkpoint file or object when the `Checkpoint` object is deleted
|
||||||
(often at program shutdown).
|
(often at program shutdown).
|
||||||
"""
|
"""
|
||||||
status = self._saver.restore(save_path=save_path)
|
status = self.read(save_path)
|
||||||
# Create the save counter now so it gets initialized with other variables
|
# Create the save counter now so it gets initialized with other variables
|
||||||
# when graph building. Creating it earlier would lead to errors when using,
|
# when graph building. Creating it earlier would lead to errors when using,
|
||||||
# say, train.Saver() to save the model before initializing it.
|
# say, train.Saver() to save the model before initializing it.
|
||||||
|
@ -1376,8 +1376,7 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
|
|||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def test_write_checkpoint_from_function(self):
|
def test_write_checkpoint_from_function(self):
|
||||||
checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
||||||
save_checkpoint = trackable_utils.Checkpoint(
|
save_checkpoint = trackable_utils.Checkpoint(v=variables_lib.Variable(1.))
|
||||||
v=variables_lib.Variable(1.))
|
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
def _write_checkpoint():
|
def _write_checkpoint():
|
||||||
@ -1386,14 +1385,21 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
|
|||||||
|
|
||||||
self.evaluate([save_checkpoint.v.initializer])
|
self.evaluate([save_checkpoint.v.initializer])
|
||||||
self.evaluate(_write_checkpoint())
|
self.evaluate(_write_checkpoint())
|
||||||
load_checkpoint = trackable_utils.Checkpoint(
|
load_checkpoint = trackable_utils.Checkpoint(v=variables_lib.Variable(0.))
|
||||||
v=variables_lib.Variable(0.))
|
# Use read() instead of restore() which allows us to check that all
|
||||||
load_checkpoint.restore(checkpoint_prefix).run_restore_ops()
|
# existing objects were loaded.
|
||||||
|
status = load_checkpoint.read(checkpoint_prefix)
|
||||||
|
status.assert_existing_objects_matched()
|
||||||
|
status.assert_consumed()
|
||||||
|
status.run_restore_ops()
|
||||||
self.assertEqual(1., self.evaluate(load_checkpoint.v))
|
self.assertEqual(1., self.evaluate(load_checkpoint.v))
|
||||||
self.evaluate(save_checkpoint.v.assign(3.))
|
self.evaluate(save_checkpoint.v.assign(3.))
|
||||||
self.evaluate(_write_checkpoint())
|
self.evaluate(_write_checkpoint())
|
||||||
self.evaluate(save_checkpoint.v.assign(0.))
|
self.evaluate(save_checkpoint.v.assign(0.))
|
||||||
load_checkpoint.restore(checkpoint_prefix).run_restore_ops()
|
status = load_checkpoint.read(checkpoint_prefix)
|
||||||
|
status.assert_existing_objects_matched()
|
||||||
|
status.assert_consumed()
|
||||||
|
status.run_restore_ops()
|
||||||
self.assertEqual(3., self.evaluate(load_checkpoint.v))
|
self.assertEqual(3., self.evaluate(load_checkpoint.v))
|
||||||
|
|
||||||
def test_inititialize_with_data_structures(self):
|
def test_inititialize_with_data_structures(self):
|
||||||
|
@ -12,6 +12,10 @@ tf_class {
|
|||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
|
argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "read"
|
||||||
|
argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "restore"
|
name: "restore"
|
||||||
argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user