Add tf.train.Checkpoint for reading and writing object-based checkpoints.

Previously exposed as tf.contrib.eager.Checkpoint / tfe.Checkpoint.

Spiffies up the documentation a bit, but otherwise just adds the export decorator.

Compatible in both directions with tf.train.Saver (object-based checkpoints can be fed to tf.train.Saver, and name-based checkpoints can be fed to tf.train.Checkpoint).
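
As a sketch of the round trip (graph building, a single shared variable; the paths and variable names here are placeholders, not part of this change):

```python
import tensorflow as tf

v = tf.Variable(1.0, name="v")  # Placeholder variable shared by both savers.
object_saver = tf.train.Checkpoint(v=v)    # Object-based.
name_saver = tf.train.Saver(var_list=[v])  # Name-based.

with tf.Session() as session:
  session.run(v.initializer)

  # Write an object-based checkpoint, read it back with the name-based Saver.
  object_path = object_saver.save("/tmp/round_trip/object", session=session)
  name_saver.restore(session, object_path)  # Logs a warning, then restores.

  # Write a name-based checkpoint, read it back with tf.train.Checkpoint.
  name_path = name_saver.save(session, "/tmp/round_trip/name")
  object_saver.restore(name_path).run_restore_ops(session)
```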

PiperOrigin-RevId: 193439442
Allen Lavoie, 2018-04-18 16:48:17 -07:00 (committed by TensorFlower Gardener)
commit 5ec3b021fd (parent fddfa9f8dc)
5 changed files with 202 additions and 21 deletions

tensorflow/python/training/checkpointable_utils.py

@@ -38,6 +38,7 @@ from tensorflow.python.training import checkpointable as checkpointable_lib
 from tensorflow.python.training import optimizer as optimizer_lib
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export


 _ESCAPE_CHAR = "."  # For avoiding conflicts with user-specified names.
@@ -822,30 +823,92 @@ class CheckpointableSaver(object):
     return load_status


+@tf_export("train.Checkpoint")
 class Checkpoint(checkpointable_lib.Checkpointable):
-  """A utility class which groups `Checkpointable` objects.
+  """Groups checkpointable objects, saving and restoring them.

-  Accepts arbitrary keyword arguments to its constructor and saves those values
-  with a checkpoint. Maintains a `save_counter` for numbering checkpoints.
+  `Checkpoint`'s constructor accepts keyword arguments whose values are types
+  that contain checkpointable state, such as `tf.train.Optimizer`
+  implementations, `tf.Variable`, `tf.keras.Layer` implementations, or
+  `tf.keras.Model` implementations. It saves these values with a checkpoint, and
+  maintains a `save_counter` for numbering checkpoints.

-  Example usage:
+  Example usage when graph building:

   ```python
   import tensorflow as tf
-  import tensorflow.contrib.eager as tfe
   import os

   checkpoint_directory = "/tmp/training_checkpoints"
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

-  root = tfe.Checkpoint(optimizer=optimizer, model=model)
-  root.restore(tf.train.latest_checkpoint(checkpoint_directory))
-  for _ in range(num_training_steps):
-    optimizer.minimize( ... )
-  root.save(file_prefix=checkpoint_prefix)
+  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
+  status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
+  train_op = optimizer.minimize( ... )
+  status.assert_consumed()  # Optional sanity checks.
+  with tf.Session() as session:
+    # Use the Session to restore variables, or initialize them if
+    # tf.train.latest_checkpoint returned None.
+    status.initialize_or_restore(session)
+    for _ in range(num_training_steps):
+      session.run(train_op)
+    checkpoint.save(file_prefix=checkpoint_prefix)
   ```

-  For more manual control over saving, use `tfe.CheckpointableSaver` directly.
+  Example usage with eager execution enabled:
+
+  ```python
+  import tensorflow as tf
+  import os
+
+  tf.enable_eager_execution()
+
+  checkpoint_directory = "/tmp/training_checkpoints"
+  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+
+  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
+  status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
+  for _ in range(num_training_steps):
+    optimizer.minimize( ... )  # Variables will be restored on creation.
+  status.assert_consumed()  # Optional sanity checks.
+  checkpoint.save(file_prefix=checkpoint_prefix)
+  ```
+
+  `Checkpoint.save` and `Checkpoint.restore` write and read object-based
+  checkpoints, in contrast to `tf.train.Saver` which writes and reads
+  `variable.name` based checkpoints. Object-based checkpointing saves a graph of
+  dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s,
+  etc.) with named edges, and this graph is used to match variables when
+  restoring a checkpoint. It can be more robust to changes in the Python
+  program, and helps to support restore-on-create for variables when executing
+  eagerly. Prefer `tf.train.Checkpoint` over `tf.train.Saver` for new code.
+
+  `Checkpoint` objects have dependencies on the objects passed as keyword
+  arguments to their constructors, and each dependency is given a name that is
+  identical to the name of the keyword argument for which it was created.
+  TensorFlow classes like `Layer`s and `Optimizer`s will automatically add
+  dependencies on their variables (e.g. "kernel" and "bias" for
+  `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing
+  dependencies easy in user-defined classes, since `Model` hooks into attribute
+  assignment. For example:
+
+  ```python
+  class Regress(tf.keras.Model):
+
+    def __init__(self):
+      super(Regress, self).__init__()
+      self.input_transform = tf.keras.layers.Dense(10)
+      # ...
+
+    def call(self, inputs):
+      x = self.input_transform(inputs)
+      # ...
+  ```
+
+  This `Model` has a dependency named "input_transform" on its `Dense` layer,
+  which in turn depends on its variables. As a result, saving an instance of
+  `Regress` using `tf.train.Checkpoint` will also save all the variables created
+  by the `Dense` layer.

   Attributes:
     save_counter: Incremented when `save()` is called. Used to number
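
To make those named edges concrete, one can write a checkpoint for a `Regress` instance and list the saved keys with `tf.train.list_variables`; each key traces a path through the dependency graph. A minimal sketch under eager execution (the checkpoint path is a placeholder, and the exact key suffixes are an assumption, not something this change specifies):

```python
import tensorflow as tf

tf.enable_eager_execution()

class Regress(tf.keras.Model):

  def __init__(self):
    super(Regress, self).__init__()
    self.input_transform = tf.keras.layers.Dense(10)

  def call(self, inputs):
    return self.input_transform(inputs)

model = Regress()
model(tf.ones([1, 3]))  # Calling the model once creates its variables.
path = tf.train.Checkpoint(model=model).save("/tmp/regress_demo/ckpt")

# Expect keys like "model/input_transform/kernel/..." mirroring the keyword
# argument name ("model") and the attribute name ("input_transform").
for name, shape in tf.train.list_variables(path):
  print(name, shape)
```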
@@ -857,17 +920,19 @@ class Checkpoint(checkpointable_lib.Checkpointable):
     Args:
       **kwargs: Keyword arguments are set as attributes of this object, and are
-        saved with the checkpoint. Attribute values must derive from
-        `CheckpointableBase`.
+        saved with the checkpoint. Values must be checkpointable objects.

     Raises:
-      ValueError: If objects in `kwargs` are not Checkpointable.
+      ValueError: If objects in `kwargs` are not checkpointable.
     """
     super(Checkpoint, self).__init__()
     for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
       if not isinstance(v, checkpointable_lib.CheckpointableBase):
         raise ValueError(
-            ("`Checkpoint` was expecting an object derived from "
-             "`CheckpointableBase`, got %s.") % (v,))
+            ("`Checkpoint` was expecting a checkpointable object (an object "
+             "derived from `CheckpointableBase`), got %s. If you believe this "
+             "object should be checkpointable (i.e. it is part of the "
+             "TensorFlow Python API and manages state), please open an issue.")
+            % (v,))
       setattr(self, k, v)
     self._save_counter = None  # Created lazily for restore-on-create.
     self._saver = CheckpointableSaver(weakref.ref(self))
@@ -893,7 +958,23 @@ class Checkpoint(checkpointable_lib.Checkpointable):
     return self._save_counter

   def save(self, file_prefix, session=None):
-    """Save a checkpoint. Wraps `tfe.CheckpointableSaver.save`."""
+    """Save a training checkpoint.
+
+    The saved checkpoint includes variables created by this object and any
+    checkpointable objects it depends on at the time `Checkpoint.save()` is
+    called.
+
+    Args:
+      file_prefix: A prefix to use for the checkpoint filenames
+        (/path/to/directory/and_a_prefix). Names are generated based on this
+        prefix and `Checkpoint.save_counter`.
+      session: The session to evaluate variables in. Ignored when executing
+        eagerly. If not provided when graph building, the default session is
+        used.
+
+    Returns:
+      The full path to the checkpoint.
+    """
     in_graph_mode = not context.executing_eagerly()
     if in_graph_mode:
       if session is None:
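
Since `save()` embeds `save_counter` in the generated filenames, successive calls write distinct checkpoints. A small sketch under eager execution (the prefix is a placeholder; `tf.contrib.eager.Variable` is used because plain `tf.Variable` is not eager-compatible in this era):

```python
import tensorflow as tf

tf.enable_eager_execution()

v = tf.contrib.eager.Variable(1.0)  # Eager-compatible variable.
checkpoint = tf.train.Checkpoint(v=v)

# Each call increments save_counter and appends it to the file prefix.
first = checkpoint.save("/tmp/counter_demo/ckpt")   # e.g. ".../ckpt-1"
second = checkpoint.save("/tmp/counter_demo/ckpt")  # e.g. ".../ckpt-2"
```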
@@ -913,7 +994,81 @@ class Checkpoint(checkpointable_lib.Checkpointable):
         session=session)

   def restore(self, save_path):
-    """Restore a checkpoint. Wraps `tfe.CheckpointableSaver.restore`."""
+    """Restore a training checkpoint.
+
+    Restores this `Checkpoint` and any objects it depends on.
+
+    When executing eagerly, either assigns values immediately if variables to
+    restore have been created already, or defers restoration until the variables
+    are created. Dependencies added after this call will be matched if they have
+    a corresponding object in the checkpoint (the restore request will queue in
+    any checkpointable object waiting for the expected dependency to be added).
+
+    When graph building, restoration ops are added to the graph but not run
+    immediately.
+
+    To ensure that loading is complete and no more assignments will take place,
+    use the `assert_consumed()` method of the status object returned by
+    `restore`:
+
+    ```python
+    checkpoint = tf.train.Checkpoint( ... )
+    checkpoint.restore(path).assert_consumed()
+    ```
+
+    An exception will be raised if any Python objects in the dependency graph
+    were not found in the checkpoint, or if any checkpointed values do not have
+    a matching Python object.
+
+    When graph building, `assert_consumed()` indicates that all of the restore
+    ops that will be created for this checkpoint have been created. They can be
+    run via the `run_restore_ops()` method of the status object:
+
+    ```python
+    checkpoint.restore(path).assert_consumed().run_restore_ops()
+    ```
+
+    If the checkpoint has not been consumed completely, then the list of restore
+    ops will grow as more objects are added to the dependency graph.
+
+    Name-based `tf.train.Saver` checkpoints can be loaded using this
+    method. There is no deferred loading, and names are used to match
+    variables. No restore ops are created/run until `run_restore_ops()` or
+    `initialize_or_restore()` are called on the returned status object, even
+    when executing eagerly. Re-encode name-based checkpoints using
+    `tf.train.Checkpoint.save` as soon as possible.
+
+    Args:
+      save_path: The path to the checkpoint, as returned by `save` or
+        `tf.train.latest_checkpoint`. If None (as when there is no latest
+        checkpoint for `tf.train.latest_checkpoint` to return), returns an
+        object which may run initializers for objects in the dependency
+        graph. If the checkpoint was written by the name-based `tf.train.Saver`,
+        names are used to match variables.
+
+    Returns:
+      A load status object, which can be used to make assertions about the
+      status of a checkpoint restoration and run initialization/restore ops.
+
+      The returned status object has the following methods:
+      - `assert_consumed()`:
+          Raises an exception if any variables/objects are unmatched: either
+          checkpointed values which don't have a matching Python object or
+          Python objects in the dependency graph with no values in the
+          checkpoint. This method returns the status object, and so may be
+          chained with `initialize_or_restore` or `run_restore_ops`.
+      - `initialize_or_restore(session=None)`:
+          When graph building, runs variable initializers if `save_path` is
+          `None`, but otherwise runs restore operations. If no `session` is
+          explicitly specified, the default session is used. No effect for
+          object-based checkpoints when executing eagerly (variables are
+          initialized or restored eagerly).
+      - `run_restore_ops(session=None)`:
+          When graph building, runs restore operations. If no `session` is
+          explicitly specified, the default session is used. No effect for
+          object-based checkpoints when executing eagerly (restore operations
+          are run eagerly). May only be called when `save_path` is not `None`.
+    """
     status = self._saver.restore(save_path=save_path)
     # Create the save counter now so it gets initialized with other variables
     # when graph building. Creating it earlier would lead to double
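
Putting the status-object methods documented above together, a graph-building program might be structured like this (a minimal sketch; the directory and the `Dense` layer standing in for a real model are placeholder assumptions):

```python
import os
import tensorflow as tf

checkpoint_directory = "/tmp/status_demo"  # Placeholder path.
model = tf.keras.layers.Dense(10)  # Stands in for any checkpointable object.
checkpoint = tf.train.Checkpoint(model=model)

# Returns a restore status for a real checkpoint, or an initialization-only
# status when latest_checkpoint returns None on the first run.
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
output = model(tf.ones([4, 3]))  # Creates the layer's variables.

with tf.Session() as session:
  # Runs initializers on the first run, restore ops afterwards.
  status.initialize_or_restore(session)
  session.run(output)
  path = checkpoint.save(
      os.path.join(checkpoint_directory, "ckpt"), session=session)
  # Restoring from the path just written should consume every saved value.
  checkpoint.restore(path).assert_consumed().run_restore_ops(session)
```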

tensorflow/python/training/saver.py

@@ -1824,12 +1824,10 @@ class Saver(object):
       # This is an object-based checkpoint. We'll print a warning and then do
       # the restore.
       logging.warning(
-          # TODO(allenl): Modify instructions for using the object-based saver
-          # once that's in core.
           "Restoring an object-based checkpoint using a name-based saver. This "
           "may be somewhat fragile, and will re-build the Saver. Instead, "
           "consider loading object-based checkpoints using "
-          "tf.contrib.eager.Checkpoint().")
+          "tf.train.Checkpoint().")
       self._restore_from_object_based_checkpoint(
           sess=sess, save_path=save_path,
           object_graph_string=object_graph_string)

tensorflow/python/training/training.py

@@ -156,6 +156,7 @@ from tensorflow.python.training.basic_session_run_hooks import FinalOpsHook
 from tensorflow.python.training.basic_session_run_hooks import FeedFnHook
 from tensorflow.python.training.basic_session_run_hooks import ProfilerHook
 from tensorflow.python.training.basic_loops import basic_train_loop
+from tensorflow.python.training.checkpointable_utils import Checkpoint
 from tensorflow.python.training.checkpoint_utils import init_from_checkpoint
 from tensorflow.python.training.checkpoint_utils import list_variables
 from tensorflow.python.training.checkpoint_utils import load_checkpoint

tensorflow/tools/api/golden/tensorflow.train.-checkpoint.pbtxt

@@ -0,0 +1,23 @@
+path: "tensorflow.train.Checkpoint"
+tf_class {
+  is_instance: "<class \'tensorflow.python.training.checkpointable_utils.Checkpoint\'>"
+  is_instance: "<class \'tensorflow.python.training.checkpointable.Checkpointable\'>"
+  is_instance: "<class \'tensorflow.python.training.checkpointable.CheckpointableBase\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "save_counter"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "restore"
+    argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "save"
+    argspec: "args=[\'self\', \'file_prefix\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+}

tensorflow/tools/api/golden/tensorflow.train.pbtxt

@@ -20,6 +20,10 @@ tf_module {
     name: "BytesList"
     mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
   }
+  member {
+    name: "Checkpoint"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "CheckpointSaverHook"
     mtype: "<type \'type\'>"