Add tf.train.Checkpoint for reading and writing object-based checkpoints.
Previously exposed as tf.contrib.eager.Checkpoint / tfe.Checkpoint. Spiffies up the documentation a bit, but otherwise just adds the export decorator. Compatible in both directions with tf.train.Saver (object-based checkpoints can be fed to tf.train.Saver, and name-based checkpoints can be fed to tf.train.Checkpoint). PiperOrigin-RevId: 193439442
This commit is contained in:
parent
fddfa9f8dc
commit
5ec3b021fd
@ -38,6 +38,7 @@ from tensorflow.python.training import checkpointable as checkpointable_lib
|
||||
from tensorflow.python.training import optimizer as optimizer_lib
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
_ESCAPE_CHAR = "." # For avoiding conflicts with user-specified names.
|
||||
@ -822,30 +823,92 @@ class CheckpointableSaver(object):
|
||||
return load_status
|
||||
|
||||
|
||||
@tf_export("train.Checkpoint")
|
||||
class Checkpoint(checkpointable_lib.Checkpointable):
|
||||
"""A utility class which groups `Checkpointable` objects.
|
||||
"""Groups checkpointable objects, saving and restoring them.
|
||||
|
||||
Accepts arbitrary keyword arguments to its constructor and saves those values
|
||||
with a checkpoint. Maintains a `save_counter` for numbering checkpoints.
|
||||
`Checkpoint`'s constructor accepts keyword arguments whose values are types
|
||||
that contain checkpointable state, such as `tf.train.Optimizer`
|
||||
implementations, `tf.Variable`, `tf.keras.layers.Layer` implementations, or
|
||||
`tf.keras.Model` implementations. It saves these values with a checkpoint, and
|
||||
maintains a `save_counter` for numbering checkpoints.
|
||||
|
||||
Example usage:
|
||||
Example usage when graph building:
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
import tensorflow.contrib.eager as tfe
|
||||
import os
|
||||
|
||||
checkpoint_directory = "/tmp/training_checkpoints"
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
|
||||
root = tfe.Checkpoint(optimizer=optimizer, model=model)
|
||||
root.restore(tf.train.latest_checkpoint(checkpoint_directory))
|
||||
for _ in range(num_training_steps):
|
||||
optimizer.minimize( ... )
|
||||
root.save(file_prefix=checkpoint_prefix)
|
||||
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
|
||||
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
|
||||
train_op = optimizer.minimize( ... )
|
||||
status.assert_consumed() # Optional sanity checks.
|
||||
with tf.Session() as session:
|
||||
# Use the Session to restore variables, or initialize them if
|
||||
# tf.train.latest_checkpoint returned None.
|
||||
status.initialize_or_restore(session)
|
||||
for _ in range(num_training_steps):
|
||||
session.run(train_op)
|
||||
checkpoint.save(file_prefix=checkpoint_prefix)
|
||||
```
|
||||
|
||||
For more manual control over saving, use `tfe.CheckpointableSaver` directly.
|
||||
Example usage with eager execution enabled:
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
import os
|
||||
|
||||
tf.enable_eager_execution()
|
||||
|
||||
checkpoint_directory = "/tmp/training_checkpoints"
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
|
||||
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
|
||||
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
|
||||
for _ in range(num_training_steps):
|
||||
optimizer.minimize( ... ) # Variables will be restored on creation.
|
||||
status.assert_consumed() # Optional sanity checks.
|
||||
checkpoint.save(file_prefix=checkpoint_prefix)
|
||||
```
|
||||
|
||||
`Checkpoint.save` and `Checkpoint.restore` write and read object-based
|
||||
checkpoints, in contrast to `tf.train.Saver` which writes and reads
|
||||
`variable.name` based checkpoints. Object-based checkpointing saves a graph of
|
||||
dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s,
|
||||
etc.) with named edges, and this graph is used to match variables when
|
||||
restoring a checkpoint. It can be more robust to changes in the Python
|
||||
program, and helps to support restore-on-create for variables when executing
|
||||
eagerly. Prefer `tf.train.Checkpoint` over `tf.train.Saver` for new code.
|
||||
|
||||
`Checkpoint` objects have dependencies on the objects passed as keyword
|
||||
arguments to their constructors, and each dependency is given a name that is
|
||||
identical to the name of the keyword argument for which it was created.
|
||||
TensorFlow classes like `Layer`s and `Optimizer`s will automatically add
|
||||
dependencies on their variables (e.g. "kernel" and "bias" for
|
||||
`tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing
|
||||
dependencies easy in user-defined classes, since `Model` hooks into attribute
|
||||
assignment. For example:
|
||||
|
||||
```python
|
||||
class Regress(tf.keras.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(Regress, self).__init__()
|
||||
self.input_transform = tf.keras.layers.Dense(10)
|
||||
# ...
|
||||
|
||||
def call(self, inputs):
|
||||
x = self.input_transform(inputs)
|
||||
# ...
|
||||
```
|
||||
|
||||
This `Model` has a dependency named "input_transform" on its `Dense` layer,
|
||||
which in turn depends on its variables. As a result, saving an instance of
|
||||
`Regress` using `tf.train.Checkpoint` will also save all the variables created
|
||||
by the `Dense` layer.
|
||||
|
||||
Attributes:
|
||||
save_counter: Incremented when `save()` is called. Used to number
|
||||
@ -857,17 +920,19 @@ class Checkpoint(checkpointable_lib.Checkpointable):
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments are set as attributes of this object, and are
|
||||
saved with the checkpoint. Attribute values must derive from
|
||||
`CheckpointableBase`.
|
||||
saved with the checkpoint. Values must be checkpointable objects.
|
||||
Raises:
|
||||
ValueError: If objects in `kwargs` are not Checkpointable.
|
||||
ValueError: If objects in `kwargs` are not checkpointable.
|
||||
"""
|
||||
super(Checkpoint, self).__init__()
|
||||
for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
|
||||
if not isinstance(v, checkpointable_lib.CheckpointableBase):
|
||||
raise ValueError(
|
||||
("`Checkpoint` was expecting an object derived from "
|
||||
"`CheckpointableBase`, got %s.") % (v,))
|
||||
("`Checkpoint` was expecting a checkpointable object (an object "
|
||||
"derived from `CheckpointableBase`), got %s. If you believe this "
|
||||
"object should be checkpointable (i.e. it is part of the "
|
||||
"TensorFlow Python API and manages state), please open an issue.")
|
||||
% (v,))
|
||||
setattr(self, k, v)
|
||||
self._save_counter = None # Created lazily for restore-on-create.
|
||||
self._saver = CheckpointableSaver(weakref.ref(self))
|
||||
@ -893,7 +958,23 @@ class Checkpoint(checkpointable_lib.Checkpointable):
|
||||
return self._save_counter
|
||||
|
||||
def save(self, file_prefix, session=None):
|
||||
"""Save a checkpoint. Wraps `tfe.CheckpointableSaver.save`."""
|
||||
"""Save a training checkpoint.
|
||||
|
||||
The saved checkpoint includes variables created by this object and any
|
||||
checkpointable objects it depends on at the time `Checkpoint.save()` is
|
||||
called.
|
||||
|
||||
Args:
|
||||
file_prefix: A prefix to use for the checkpoint filenames
|
||||
(/path/to/directory/and_a_prefix). Names are generated based on this
|
||||
prefix and `Checkpoint.save_counter`.
|
||||
session: The session to evaluate variables in. Ignored when executing
|
||||
eagerly. If not provided when graph building, the default session is
|
||||
used.
|
||||
|
||||
Returns:
|
||||
The full path to the checkpoint.
|
||||
"""
|
||||
in_graph_mode = not context.executing_eagerly()
|
||||
if in_graph_mode:
|
||||
if session is None:
|
||||
@ -913,7 +994,81 @@ class Checkpoint(checkpointable_lib.Checkpointable):
|
||||
session=session)
|
||||
|
||||
def restore(self, save_path):
|
||||
"""Restore a checkpoint. Wraps `tfe.CheckpointableSaver.restore`."""
|
||||
"""Restore a training checkpoint.
|
||||
|
||||
Restores this `Checkpoint` and any objects it depends on.
|
||||
|
||||
When executing eagerly, either assigns values immediately if variables to
|
||||
restore have been created already, or defers restoration until the variables
|
||||
are created. Dependencies added after this call will be matched if they have
|
||||
a corresponding object in the checkpoint (the restore request will queue in
|
||||
any checkpointable object waiting for the expected dependency to be added).
|
||||
|
||||
When graph building, restoration ops are added to the graph but not run
|
||||
immediately.
|
||||
|
||||
To ensure that loading is complete and no more assignments will take place,
|
||||
use the `assert_consumed()` method of the status object returned by
|
||||
`restore`:
|
||||
|
||||
```python
|
||||
checkpoint = tf.train.Checkpoint( ... )
|
||||
checkpoint.restore(path).assert_consumed()
|
||||
```
|
||||
|
||||
An exception will be raised if any Python objects in the dependency graph
|
||||
were not found in the checkpoint, or if any checkpointed values do not have
|
||||
a matching Python object.
|
||||
|
||||
When graph building, `assert_consumed()` indicates that all of the restore
|
||||
ops that will be created for this checkpoint have been created. They can be
|
||||
run via the `run_restore_ops()` method of the status object:
|
||||
|
||||
```python
|
||||
checkpoint.restore(path).assert_consumed().run_restore_ops()
|
||||
```
|
||||
|
||||
If the checkpoint has not been consumed completely, then the list of restore
|
||||
ops will grow as more objects are added to the dependency graph.
|
||||
|
||||
Name-based `tf.train.Saver` checkpoints can be loaded using this
|
||||
method. There is no deferred loading, and names are used to match
|
||||
variables. No restore ops are created/run until `run_restore_ops()` or
|
||||
`initialize_or_restore()` is called on the returned status object, even
|
||||
when executing eagerly. Re-encode name-based checkpoints using
|
||||
`tf.train.Checkpoint.save` as soon as possible.
|
||||
|
||||
Args:
|
||||
save_path: The path to the checkpoint, as returned by `save` or
|
||||
`tf.train.latest_checkpoint`. If None (as when there is no latest
|
||||
checkpoint for `tf.train.latest_checkpoint` to return), returns an
|
||||
object which may run initializers for objects in the dependency
|
||||
graph. If the checkpoint was written by the name-based `tf.train.Saver`,
|
||||
names are used to match variables.
|
||||
|
||||
Returns:
|
||||
A load status object, which can be used to make assertions about the
|
||||
status of a checkpoint restoration and run initialization/restore ops.
|
||||
|
||||
The returned status object has the following methods:
|
||||
- `assert_consumed()`:
|
||||
Raises an exception if any variables/objects are unmatched: either
|
||||
checkpointed values which don't have a matching Python object or
|
||||
Python objects in the dependency graph with no values in the
|
||||
checkpoint. This method returns the status object, and so may be
|
||||
chained with `initialize_or_restore` or `run_restore_ops`.
|
||||
- `initialize_or_restore(session=None)`:
|
||||
When graph building, runs variable initializers if `save_path` is
|
||||
`None`, but otherwise runs restore operations. If no `session` is
|
||||
explicitly specified, the default session is used. No effect for
|
||||
object-based checkpoints when executing eagerly (variables are
|
||||
initialized or restored eagerly).
|
||||
- `run_restore_ops(session=None)`:
|
||||
When graph building, runs restore operations. If no `session` is
|
||||
explicitly specified, the default session is used. No effect for
|
||||
object-based checkpoints when executing eagerly (restore operations
|
||||
are run eagerly). May only be called when `save_path` is not `None`.
|
||||
"""
|
||||
status = self._saver.restore(save_path=save_path)
|
||||
# Create the save counter now so it gets initialized with other variables
|
||||
# when graph building. Creating it earlier would lead to double
|
||||
|
@ -1824,12 +1824,10 @@ class Saver(object):
|
||||
# This is an object-based checkpoint. We'll print a warning and then do
|
||||
# the restore.
|
||||
logging.warning(
|
||||
# TODO(allenl): Modify instructions for using the object-based saver
|
||||
# once that's in core.
|
||||
"Restoring an object-based checkpoint using a name-based saver. This "
|
||||
"may be somewhat fragile, and will re-build the Saver. Instead, "
|
||||
"consider loading object-based checkpoints using "
|
||||
"tf.contrib.eager.Checkpoint().")
|
||||
"tf.train.Checkpoint().")
|
||||
self._restore_from_object_based_checkpoint(
|
||||
sess=sess, save_path=save_path,
|
||||
object_graph_string=object_graph_string)
|
||||
|
@ -156,6 +156,7 @@ from tensorflow.python.training.basic_session_run_hooks import FinalOpsHook
|
||||
from tensorflow.python.training.basic_session_run_hooks import FeedFnHook
|
||||
from tensorflow.python.training.basic_session_run_hooks import ProfilerHook
|
||||
from tensorflow.python.training.basic_loops import basic_train_loop
|
||||
from tensorflow.python.training.checkpointable_utils import Checkpoint
|
||||
from tensorflow.python.training.checkpoint_utils import init_from_checkpoint
|
||||
from tensorflow.python.training.checkpoint_utils import list_variables
|
||||
from tensorflow.python.training.checkpoint_utils import load_checkpoint
|
||||
|
@ -0,0 +1,23 @@
|
||||
path: "tensorflow.train.Checkpoint"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.checkpointable_utils.Checkpoint\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.checkpointable.Checkpointable\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.checkpointable.CheckpointableBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "save_counter"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "restore"
|
||||
argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'file_prefix\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
}
|
@ -20,6 +20,10 @@ tf_module {
|
||||
name: "BytesList"
|
||||
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
|
||||
}
|
||||
member {
|
||||
name: "Checkpoint"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "CheckpointSaverHook"
|
||||
mtype: "<type \'type\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user