Add option to construct tf.train.Checkpoint with a root object.

PiperOrigin-RevId: 326701488
Change-Id: Ifc01382a513b8977793e68e801e2410cecface3a
Katherine Wu 2020-08-14 11:59:51 -07:00 committed by TensorFlower Gardener
parent 38d3be9f51
commit 3bc9d4420e
7 changed files with 309 additions and 39 deletions

RELEASE.md View File

@@ -157,6 +157,14 @@
* <ADD RELEASE NOTES HERE>
* Tracing and Debugging:
* <ADD RELEASE NOTES HERE>
* `tf.train.Checkpoint`:
* Now accepts a `root` argument in the constructor, which sets the root
object of the generated checkpoint. This allows users to create a
`Checkpoint` object that is compatible with Keras `model.save_weights()`
and `model.load_weights()`. Such a checkpoint is also compatible with the
checkpoint saved in the `variables/` folder of a SavedModel.
* When restoring, `save_path` may be a path to a SavedModel directory; the
checkpoint inside it is found automatically (see the sketch below).
* Other:
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
and "denylist" where possible. Please see

tensorflow/python/saved_model/utils_impl.py View File

@@ -262,6 +262,18 @@ def get_or_create_debug_dir(export_dir):
return debug_dir
def get_saved_model_pbtxt_path(export_dir):
return os.path.join(
compat.as_bytes(compat.path_to_str(export_dir)),
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
def get_saved_model_pb_path(export_dir):
return os.path.join(
compat.as_bytes(compat.path_to_str(export_dir)),
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
def get_debug_dir(export_dir):
"""Returns path to the debug sub-directory in the SavedModel."""
return os.path.join(

tensorflow/python/training/tracking/BUILD View File

@@ -146,6 +146,7 @@ py_library(
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/saved_model:utils",
"//tensorflow/python/training/saving:checkpoint_options",
"//tensorflow/python/training/saving:functional_saver",
"//tensorflow/python/training/saving:saveable_object_util",
@@ -184,6 +185,7 @@ tf_py_test(
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
"//tensorflow/python/saved_model:save",
"//tensorflow/python/training/saving:checkpoint_options",
"@absl_py//absl/testing:parameterized",
"@six_archive//:six",

tensorflow/python/training/tracking/graph_view.py View File

@@ -142,7 +142,7 @@ def _serialize_slot_variables(trackable_objects, node_ids, object_names):
class ObjectGraphView(object):
"""Gathers and serializes an object graph."""
def __init__(self, root, saveables_cache=None):
def __init__(self, root, saveables_cache=None, attached_dependencies=None):
"""Configure the graph view.
Args:
@@ -151,16 +151,24 @@ class ObjectGraphView(object):
saveables_cache: A dictionary mapping `Trackable` objects ->
attribute names -> SaveableObjects, used to avoid re-creating
SaveableObjects when graph building.
attached_dependencies: Dependencies to attach to the root object. Used
when saving a Checkpoint with a defined root object.
"""
self._root_ref = root
self._saveables_cache = saveables_cache
self._attached_dependencies = attached_dependencies
def list_dependencies(self, obj):
# pylint: disable=protected-access
obj._maybe_initialize_trackable()
return obj._checkpoint_dependencies
dependencies = obj._checkpoint_dependencies
# pylint: enable=protected-access
if obj is self.root and self._attached_dependencies:
dependencies = dependencies.copy()
dependencies.extend(self._attached_dependencies)
return dependencies
@property
def saveables_cache(self):
"""Maps Trackable objects -> attribute names -> list(SaveableObjects).
@@ -173,6 +181,19 @@ class ObjectGraphView(object):
"""
return self._saveables_cache
@property
def attached_dependencies(self):
"""Returns list of dependencies that should be saved in the checkpoint.
These dependencies are not tracked by root, but are in the checkpoint.
This is defined when the user creates a Checkpoint with both root and kwargs
set.
Returns:
A list of TrackableReferences.
"""
return self._attached_dependencies
@property
def root(self):
if isinstance(self._root_ref, weakref.ref):
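
As a rough illustration of the hook above (this snippet is not from the
commit; it only uses names that appear in this diff), attached dependencies
are surfaced by `list_dependencies` for the root object only:

```python
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import graph_view as graph_view_lib
from tensorflow.python.training.tracking import tracking

root = tracking.AutoTrackable()
root.child = variables_lib.Variable(1.)  # Tracked directly by root.
extra = variables_lib.Variable(2.)       # Not an attribute of root.

view = graph_view_lib.ObjectGraphView(
    root,
    attached_dependencies=[base.TrackableReference("extra", extra)])

# "child" comes from root's own tracking; "extra" is appended because it was
# attached to the view. Both are therefore written to the checkpoint.
print([dep.name for dep in view.list_dependencies(root)])  # ['child', 'extra']
```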

tensorflow/python/training/tracking/util.py View File

@@ -40,7 +40,9 @@ from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import utils_impl
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import saver as v1_saver_lib
@@ -1325,6 +1327,30 @@ class TrackableSaver(object):
options=options)
base.CheckpointPosition(
checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)
# Attached dependencies are not attached to the root, so should be restored
# separately.
if self._graph_view.attached_dependencies:
for ref in self._graph_view.attached_dependencies:
if ref.name == "root":
# Root dependency is automatically added to attached dependencies --
# this can be ignored since it maps back to the root object.
continue
proto_id = None
# Find proto ID of attached dependency (if it is in the proto).
for proto_ref in object_graph_proto.nodes[0].children:
if proto_ref.local_name == ref.name:
proto_id = proto_ref.node_id
break
if proto_id in checkpoint.object_by_proto_id:
# Object has already been restored. This can happen when there's an
# indirect connection from the attached object to the root.
continue
base.CheckpointPosition(
checkpoint=checkpoint, proto_id=proto_id).restore(ref.ref)
load_status = CheckpointLoadStatus(
checkpoint,
graph_view=self._graph_view,
@@ -1358,7 +1384,7 @@ def frozen_saver(root_trackable):
return functional_saver.MultiDeviceSaver(named_saveable_objects)
def saver_with_op_caching(obj):
def saver_with_op_caching(obj, attached_dependencies=None):
"""A TrackableSaver with a SaveableObject cache when graph building."""
if context.executing_eagerly():
saveables_cache = None
@@ -1366,7 +1392,19 @@ def saver_with_op_caching(obj):
saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()
return TrackableSaver(
graph_view_lib.ObjectGraphView(
weakref.ref(obj), saveables_cache=saveables_cache))
weakref.ref(obj), saveables_cache=saveables_cache,
attached_dependencies=attached_dependencies))
def _assert_trackable(obj):
if not isinstance(
obj, (base.Trackable, def_function.Function)):
raise ValueError(
"`Checkpoint` was expecting a trackable object (an object "
"derived from `TrackableBase`), got {}. If you believe this "
"object should be trackable (i.e. it is part of the "
"TensorFlow Python API and manages state), please open an issue."
.format(obj))
# Mentions graph building / Sessions. The v2 version is below.
@@ -1737,15 +1775,32 @@ class CheckpointV1(tracking.AutoTrackable):
@tf_export("train.Checkpoint", v1=[])
class Checkpoint(tracking.AutoTrackable):
"""Groups trackable objects, saving and restoring them.
"""Manages saving/restoring trackable values to disk.
`Checkpoint`'s constructor accepts keyword arguments whose values are types
that contain trackable state, such as `tf.keras.optimizers.Optimizer`
implementations, `tf.Variable`s, `tf.data.Dataset` iterators, `tf.keras.Layer`
implementations, or `tf.keras.Model` implementations. It saves these values
with a checkpoint, and maintains a `save_counter` for numbering checkpoints.
TensorFlow objects may contain trackable state, such as `tf.Variable`s,
`tf.keras.optimizers.Optimizer` implementations, `tf.data.Dataset` iterators,
`tf.keras.Layer` implementations, or `tf.keras.Model` implementations.
These are called **trackable objects**.
Example usage:
A `Checkpoint` object can be constructed to save either a single trackable
object or a group of trackable objects to a checkpoint file. It maintains a
`save_counter` for numbering checkpoints.
Example:
```python
model = tf.keras.Model(...)
checkpoint = tf.train.Checkpoint(model)
# Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
# checkpoint.save is called, the save counter is increased.
save_path = checkpoint.save('/tmp/training_checkpoints')
# Restore the checkpointed values to the `model` object.
checkpoint.restore(save_path)
```
Example 2:
```python
import tensorflow as tf
@@ -1805,45 +1860,77 @@ class Checkpoint(tracking.AutoTrackable):
as a single checkpoint. This avoids copying all variables to one worker, but
does require that all workers see a common filesystem.
While `tf.keras.Model.save_weights` and `tf.train.Checkpoint.save` save in the
same format, note that the root of the resulting checkpoint is the object the
save method is attached to. This means saving a `tf.keras.Model` using
`save_weights` and loading into a `tf.train.Checkpoint` with a `Model`
attached (or vice versa) will not match the `Model`'s variables. See the
[guide to training
`tf.train.Checkpoint.save` differs slightly from the Keras
`Model.save_weights` method: `tf.keras.Model.save_weights` creates a
checkpoint file with the name specified in `filepath`, while
`tf.train.Checkpoint` numbers the checkpoints, using `filepath` as the prefix
for the checkpoint file names. Aside from this, `model.save_weights()` and
`tf.train.Checkpoint(model).save()` are equivalent.
See the [guide to training
checkpoints](https://www.tensorflow.org/guide/checkpoint) for
details. Prefer `tf.train.Checkpoint` over `tf.keras.Model.save_weights` for
training checkpoints.
details.
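For example (the paths are illustrative; file names follow the default
checkpoint naming):

```python
model.save_weights('/tmp/weights')  # Writes /tmp/weights.index, ...
ckpt = tf.train.Checkpoint(model)
ckpt.save('/tmp/ckpt')              # Writes /tmp/ckpt-1.index, ...
ckpt.save('/tmp/ckpt')              # The next call writes /tmp/ckpt-2.*
```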
Attributes:
save_counter: Incremented when `save()` is called. Used to number
checkpoints.
"""
def __init__(self, **kwargs):
"""Group objects into a training checkpoint.
def __init__(self, root=None, **kwargs):
"""Creates a training checkpoint for a single or group of objects.
Args:
root: The root object to checkpoint.
**kwargs: Keyword arguments are set as attributes of this object, and are
saved with the checkpoint. Values must be trackable objects.
Raises:
ValueError: If objects in `kwargs` are not trackable.
ValueError: If `root` or the objects in `kwargs` are not trackable. A
`ValueError` is also raised if the `root` object tracks different
objects from the ones listed as attributes in `kwargs` (e.g.
`root.child = A` and `tf.train.Checkpoint(root, child=B)` are
incompatible).
"""
super(Checkpoint, self).__init__()
for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
setattr(self, k, v)
if not isinstance(
getattr(self, k), (base.Trackable, def_function.Function)):
raise ValueError(
("`Checkpoint` was expecting a trackable object (an object "
"derived from `TrackableBase`), got %s. If you believe this "
"object should be trackable (i.e. it is part of the "
"TensorFlow Python API and manages state), please open an issue.")
% (v,))
saver_root = self
attached_dependencies = None
self._save_counter = None # Created lazily for restore-on-create.
self._save_assign_op = None
self._saver = saver_with_op_caching(self)
if root:
_assert_trackable(root)
saver_root = root
attached_dependencies = []
# All keyword arguments (including root itself) are set as children
# of root.
kwargs["root"] = root
root._maybe_initialize_trackable()
self._save_counter = root._lookup_dependency("save_counter")
self._root = root
for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
setattr(self, k, v)
# Call getattr instead of directly using v because setattr converts
# v to a Trackable data structure when v is a list/dict/tuple.
converted_v = getattr(self, k)
_assert_trackable(converted_v)
if root:
# Make sure that root doesn't already have dependencies with these names.
child = root._lookup_dependency(k)
if child is None:
attached_dependencies.append(base.TrackableReference(k, converted_v))
elif child != converted_v:
raise ValueError(
"Cannot create a Checkpoint with keyword argument {name} if "
"root.{name} already exists.".format(name=k))
self._saver = saver_with_op_caching(saver_root, attached_dependencies)
self._attached_dependencies = attached_dependencies
def _maybe_create_save_counter(self):
"""Create a save counter if it does not yet exist."""
@@ -1859,6 +1946,15 @@ class Checkpoint(tracking.AutoTrackable):
initializer=0,
dtype=dtypes.int64,
trainable=False))
if self._attached_dependencies is not None:
self._attached_dependencies.append(
base.TrackableReference("save_counter", self._save_counter))
# When loading a checkpoint, the save counter is created after
# the checkpoint has been loaded, so it must be handled in a deferred
# manner.
restore = self.root._deferred_dependencies.get("save_counter") # pylint: disable=protected-access
if restore:
restore[0].restore(self._save_counter)
def write(self, file_prefix, options=None):
"""Writes a training checkpoint.
@@ -2074,15 +2170,32 @@ class Checkpoint(tracking.AutoTrackable):
a matching Python object.
Name-based `tf.compat.v1.train.Saver` checkpoints from TensorFlow 1.x can be
loaded
using this method. Names are used to match variables. Re-encode name-based
checkpoints using `tf.train.Checkpoint.save` as soon as possible.
loaded using this method. Names are used to match variables. Re-encode
name-based checkpoints using `tf.train.Checkpoint.save` as soon as possible.
**Loading from SavedModel checkpoints**
To load values from a SavedModel, just pass the SavedModel directory
to `checkpoint.restore`:
```python
model = tf.keras.Model(...)
tf.saved_model.save(model, path) # or model.save(path, save_format='tf')
checkpoint = tf.train.Checkpoint(model)
checkpoint.restore(path).expect_partial()
```
This example calls `expect_partial()` on the loaded status, since
SavedModels saved from Keras often generate extra keys in the checkpoint.
Otherwise, the program prints many warnings about unused keys at exit
time.
Args:
save_path: The path to the checkpoint, as returned by `save` or
`tf.train.latest_checkpoint`. If the checkpoint was written by the
name-based `tf.compat.v1.train.Saver`, names are used to match
variables.
variables. This path may also be a SavedModel directory.
options: Optional `tf.train.CheckpointOptions` object.
Returns:
@@ -2121,8 +2234,25 @@ class Checkpoint(tracking.AutoTrackable):
restores. Warnings are otherwise printed for unused parts of the
checkpoint file or object when the `Checkpoint` object is deleted
(often at program shutdown).
Raises:
NotFoundError: if a checkpoint or SavedModel cannot be found at
`save_path`.
"""
orig_save_path = save_path
if save_path is not None and gfile.IsDirectory(save_path) and (
(gfile.Exists(utils_impl.get_saved_model_pb_path(save_path)) or
gfile.Exists(utils_impl.get_saved_model_pbtxt_path(save_path)))):
save_path = utils_impl.get_variables_path(save_path)
try:
status = self.read(save_path, options=options)
except errors_impl.NotFoundError:
raise errors_impl.NotFoundError(
None, None,
"Could not find checkpoint or SavedModel at {}."
.format(orig_save_path))
# Create the save counter now so it gets initialized with other variables
# when graph building. Creating it earlier would lead to errors when using,
# say, train.Saver() to save the model before initializing it.

tensorflow/python/training/tracking/util_test.py View File

@@ -26,6 +26,7 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
@@ -37,6 +38,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import save as saved_model_save
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.saving import checkpoint_options
@@ -794,6 +796,101 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
self.assertAllClose(self.evaluate(load_checkpoint.a), [0, 1])
self.assertAllClose(self.evaluate(load_checkpoint.b), {"a": 2, "b": 3})
def _create_trackable(self):
class Model(tracking.AutoTrackable):
def __init__(self):
self.v = variables_lib.Variable(2.)
def __call__(self, x):
return self.v * x
return Model()
def test_initialize_with_root_object(self):
model = self._create_trackable()
input_value = constant_op.constant([[3.]])
expected_output = self.evaluate(model(input_value))
model.deferred_variable = variables_lib.Variable(5.)
checkpoint = trackable_utils.Checkpoint(model)
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
save_path = checkpoint.save(checkpoint_prefix)
new_model = self._create_trackable()
load_checkpoint = trackable_utils.Checkpoint(new_model)
load_checkpoint.restore(save_path)
self.assertAllClose(expected_output, new_model(input_value))
new_model.deferred_variable = variables_lib.Variable(1.)
self.assertEqual(self.evaluate(new_model.deferred_variable), 5)
def test_initialize_with_root_object_and_kwargs(self):
model = self._create_trackable()
model.v.assign(3.)
separate_variable = variables_lib.Variable(5.)
with self.assertRaisesRegex(ValueError, "root.v already exists"):
trackable_utils.Checkpoint(model, v=separate_variable)
checkpoint = trackable_utils.Checkpoint(
model, separate_variable=separate_variable)
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
save_path = checkpoint.save(checkpoint_prefix)
# Case 1: Loading checkpoint with same configuration.
new_model = self._create_trackable()
separate_variable = variables_lib.Variable(1.)
load_checkpoint = trackable_utils.Checkpoint(
new_model, separate_variable=separate_variable)
load_checkpoint.restore(save_path).assert_consumed()
self.assertEqual(self.evaluate(new_model.v), 3)
self.assertEqual(self.evaluate(separate_variable), 5)
self.assertEqual(self.evaluate(load_checkpoint.save_counter), 1)
# Case 2: Loading checkpoint where v and separate_variable are swapped:
# v is not attached to the root, while separate_variable is attached to root.
new_model = tracking.AutoTrackable()
new_model.separate_variable = variables_lib.Variable(200.)
v = variables_lib.Variable(100.)
load_checkpoint = trackable_utils.Checkpoint(new_model, v=v)
load_checkpoint.restore(save_path).assert_consumed()
self.assertEqual(self.evaluate(v), 3)
self.assertEqual(self.evaluate(new_model.separate_variable), 5)
self.assertEqual(self.evaluate(load_checkpoint.save_counter), 1)
# Case 3: Loading checkpoint where no root object is specified
separate_variable = variables_lib.Variable(200.)
v = variables_lib.Variable(100.)
load_checkpoint = trackable_utils.Checkpoint(
v=v, separate_variable=separate_variable)
load_checkpoint.restore(save_path).assert_consumed()
self.assertEqual(self.evaluate(v), 3)
self.assertEqual(self.evaluate(new_model.separate_variable), 5)
self.assertEqual(self.evaluate(load_checkpoint.save_counter), 1)
def test_checkpoint_saved_model_compatibility(self):
model = self._create_trackable()
input_value = constant_op.constant([[3.]])
expected_output = self.evaluate(model(input_value))
model.deferred_variable = variables_lib.Variable(5.)
saved_model_dir = os.path.join(self.get_temp_dir(), "saved_model")
saved_model_save.save(model, saved_model_dir)
new_model = self._create_trackable()
load_checkpoint = trackable_utils.Checkpoint(new_model)
with self.assertRaisesRegex(errors_impl.NotFoundError,
"Could not find checkpoint or SavedModel"):
load_checkpoint.restore(saved_model_dir + "no").expect_partial()
load_checkpoint.restore(saved_model_dir).expect_partial()
self.assertAllClose(expected_output, new_model(input_value))
new_model.deferred_variable = variables_lib.Variable(1.)
self.assertEqual(self.evaluate(new_model.deferred_variable), 5)
class TemplateTests(parameterized.TestCase, test.TestCase):

tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt View File

@@ -10,7 +10,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
argspec: "args=[\'self\', \'root\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
}
member_method {
name: "read"