Add option to construct tf.train.Checkpoint with a root object.
PiperOrigin-RevId: 326701488
Change-Id: Ifc01382a513b8977793e68e801e2410cecface3a
parent 38d3be9f51, commit 3bc9d4420e
@@ -157,6 +157,14 @@
 *   <ADD RELEASE NOTES HERE>
 *   Tracing and Debugging:
     *   <ADD RELEASE NOTES HERE>
+*   `tf.train.Checkpoint`:
+    *   Now accepts a `root` argument in the initialization, which generates a
+        checkpoint with a root object. This allows users to create a `Checkpoint`
+        object that is compatible with Keras `model.save_weights()` and
+        `model.load_weights`. The checkpoint is also compatible with the
+        checkpoint saved in the `variables/` folder in the SavedModel.
+    *   When restoring, `save_path` can be a path to a SavedModel. The function
+        will automatically find the checkpoint in the SavedModel.
 *   Other:
     *   We have replaced uses of "whitelist" and "blacklist" with "allowlist"
         and "denylist" where possible. Please see
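To make the release note concrete, here is a minimal sketch of the new usage (it mirrors the docstring examples added later in this commit; the paths are placeholders):

```python
import tensorflow as tf

model = tf.keras.Model(...)  # any trackable object can serve as the root

# Rooting the checkpoint at `model` makes the saved object graph line up
# with what `model.save_weights()` would write.
checkpoint = tf.train.Checkpoint(model)
save_path = checkpoint.save('/tmp/training_checkpoints')
checkpoint.restore(save_path)

# `restore` also accepts a SavedModel directory and finds the checkpoint
# stored under its `variables/` folder.
tf.saved_model.save(model, '/tmp/saved_model')
checkpoint.restore('/tmp/saved_model').expect_partial()
```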
@@ -262,6 +262,18 @@ def get_or_create_debug_dir(export_dir):
   return debug_dir


+def get_saved_model_pbtxt_path(export_dir):
+  return os.path.join(
+      compat.as_bytes(compat.path_to_str(export_dir)),
+      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
+
+
+def get_saved_model_pb_path(export_dir):
+  return os.path.join(
+      compat.as_bytes(compat.path_to_str(export_dir)),
+      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
+
+
 def get_debug_dir(export_dir):
   """Returns path to the debug sub-directory in the SavedModel."""
   return os.path.join(
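These helpers are what lets `Checkpoint.restore` (see the hunk further down in this diff) detect whether a user-supplied path is a SavedModel directory. A minimal sketch of that check; the wrapper function itself is hypothetical and not part of this commit:

```python
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import utils_impl

def _is_saved_model_dir(path):  # hypothetical helper, for illustration only
  """Returns True if `path` contains a saved_model.pb or saved_model.pbtxt."""
  return gfile.IsDirectory(path) and (
      gfile.Exists(utils_impl.get_saved_model_pb_path(path)) or
      gfile.Exists(utils_impl.get_saved_model_pbtxt_path(path)))
```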
@@ -146,6 +146,7 @@ py_library(
         "//tensorflow/python:variables",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
+        "//tensorflow/python/saved_model:utils",
         "//tensorflow/python/training/saving:checkpoint_options",
         "//tensorflow/python/training/saving:functional_saver",
         "//tensorflow/python/training/saving:saveable_object_util",
@@ -184,6 +185,7 @@ tf_py_test(
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:test",
+        "//tensorflow/python/saved_model:save",
         "//tensorflow/python/training/saving:checkpoint_options",
         "@absl_py//absl/testing:parameterized",
         "@six_archive//:six",
@@ -142,7 +142,7 @@ def _serialize_slot_variables(trackable_objects, node_ids, object_names):
 class ObjectGraphView(object):
   """Gathers and serializes an object graph."""

-  def __init__(self, root, saveables_cache=None):
+  def __init__(self, root, saveables_cache=None, attached_dependencies=None):
     """Configure the graph view.

     Args:
@@ -151,16 +151,24 @@ class ObjectGraphView(object):
       saveables_cache: A dictionary mapping `Trackable` objects ->
         attribute names -> SaveableObjects, used to avoid re-creating
         SaveableObjects when graph building.
+      attached_dependencies: Dependencies to attach to the root object. Used
+        when saving a Checkpoint with a defined root object.
     """
     self._root_ref = root
     self._saveables_cache = saveables_cache
+    self._attached_dependencies = attached_dependencies

   def list_dependencies(self, obj):
     # pylint: disable=protected-access
     obj._maybe_initialize_trackable()
-    return obj._checkpoint_dependencies
+    dependencies = obj._checkpoint_dependencies
     # pylint: enable=protected-access
+
+    if obj is self.root and self._attached_dependencies:
+      dependencies = dependencies.copy()
+      dependencies.extend(self._attached_dependencies)
+    return dependencies

   @property
   def saveables_cache(self):
     """Maps Trackable objects -> attribute names -> list(SaveableObjects).
@@ -173,6 +181,19 @@ class ObjectGraphView(object):
     """
     return self._saveables_cache

+  @property
+  def attached_dependencies(self):
+    """Returns list of dependencies that should be saved in the checkpoint.
+
+    These dependencies are not tracked by root, but are in the checkpoint.
+    This is defined when the user creates a Checkpoint with both root and
+    kwargs set.
+
+    Returns:
+      A list of TrackableReferences.
+    """
+    return self._attached_dependencies
+
   @property
   def root(self):
     if isinstance(self._root_ref, weakref.ref):
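A small sketch of the resulting behavior, assuming `root` and `child` are trackable objects (names are illustrative, and `base` refers to the tracking `base` module as used elsewhere in this file):

```python
view = ObjectGraphView(
    root, attached_dependencies=[base.TrackableReference("child", child)])

# The attached reference is appended only when listing the root's own
# dependencies; every other object is unaffected.
view.list_dependencies(root)   # root's dependencies plus the "child" reference
view.list_dependencies(child)  # child's dependencies only
```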
@@ -40,7 +40,9 @@ from tensorflow.python.ops import gen_io_ops as io_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import utils_impl
 from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import py_checkpoint_reader
 from tensorflow.python.training import saver as v1_saver_lib
@@ -1325,6 +1327,30 @@ class TrackableSaver(object):
           options=options)
       base.CheckpointPosition(
           checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)
+
+    # Attached dependencies are not attached to the root, so should be restored
+    # separately.
+    if self._graph_view.attached_dependencies:
+      for ref in self._graph_view.attached_dependencies:
+        if ref.name == "root":
+          # Root dependency is automatically added to attached dependencies --
+          # this can be ignored since it maps back to the root object.
+          continue
+        proto_id = None
+        # Find proto ID of attached dependency (if it is in the proto).
+        for proto_ref in object_graph_proto.nodes[0].children:
+          if proto_ref.local_name == ref.name:
+            proto_id = proto_ref.node_id
+            break
+
+        if proto_id in checkpoint.object_by_proto_id:
+          # Object has already been restored. This can happen when there's an
+          # indirect connection from the attached object to the root.
+          continue
+
+        base.CheckpointPosition(
+            checkpoint=checkpoint, proto_id=proto_id).restore(ref.ref)
+
     load_status = CheckpointLoadStatus(
         checkpoint,
         graph_view=self._graph_view,
@@ -1358,7 +1384,7 @@ def frozen_saver(root_trackable):
   return functional_saver.MultiDeviceSaver(named_saveable_objects)


-def saver_with_op_caching(obj):
+def saver_with_op_caching(obj, attached_dependencies=None):
   """A TrackableSaver with a SaveableObject cache when graph building."""
   if context.executing_eagerly():
     saveables_cache = None
@@ -1366,7 +1392,19 @@ def saver_with_op_caching(obj):
     saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()
   return TrackableSaver(
       graph_view_lib.ObjectGraphView(
-          weakref.ref(obj), saveables_cache=saveables_cache))
+          weakref.ref(obj), saveables_cache=saveables_cache,
+          attached_dependencies=attached_dependencies))
+
+
+def _assert_trackable(obj):
+  if not isinstance(
+      obj, (base.Trackable, def_function.Function)):
+    raise ValueError(
+        "`Checkpoint` was expecting a trackable object (an object "
+        "derived from `TrackableBase`), got {}. If you believe this "
+        "object should be trackable (i.e. it is part of the "
+        "TensorFlow Python API and manages state), please open an issue."
+        .format(obj))


 # Mentions graph building / Sessions. The v2 version is below.
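As a quick illustration of what the new `_assert_trackable` helper enforces (a sketch; plain Python values are not trackable and must be wrapped, e.g. in a `tf.Variable`):

```python
import tensorflow as tf

tf.train.Checkpoint(step=tf.Variable(0))  # accepted: `tf.Variable` is trackable
tf.train.Checkpoint(step=0)               # raises ValueError: not a trackable object
```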
@@ -1737,15 +1775,32 @@ class CheckpointV1(tracking.AutoTrackable):

 @tf_export("train.Checkpoint", v1=[])
 class Checkpoint(tracking.AutoTrackable):
-  """Groups trackable objects, saving and restoring them.
+  """Manages saving/restoring trackable values to disk.

-  `Checkpoint`'s constructor accepts keyword arguments whose values are types
-  that contain trackable state, such as `tf.keras.optimizers.Optimizer`
-  implementations, `tf.Variable`s, `tf.data.Dataset` iterators, `tf.keras.Layer`
-  implementations, or `tf.keras.Model` implementations. It saves these values
-  with a checkpoint, and maintains a `save_counter` for numbering checkpoints.
+  TensorFlow objects may contain trackable state, such as `tf.Variable`s,
+  `tf.keras.optimizers.Optimizer` implementations, `tf.data.Dataset` iterators,
+  `tf.keras.Layer` implementations, or `tf.keras.Model` implementations.
+  These are called **trackable objects**.

-  Example usage:
+  A `Checkpoint` object can be constructed to save either a single or group of
+  trackable objects to a checkpoint file. It maintains a `save_counter` for
+  numbering checkpoints.
+
+  Example:
+
+  ```python
+  model = tf.keras.Model(...)
+  checkpoint = tf.train.Checkpoint(model)
+
+  # Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
+  # checkpoint.save is called, the save counter is increased.
+  save_path = checkpoint.save('/tmp/training_checkpoints')
+
+  # Restore the checkpointed values to the `model` object.
+  checkpoint.restore(save_path)
+  ```
+
+  Example 2:

   ```python
   import tensorflow as tf
@@ -1805,45 +1860,77 @@ class Checkpoint(tracking.AutoTrackable):
   as a single checkpoint. This avoids copying all variables to one worker, but
   does require that all workers see a common filesystem.

-  While `tf.keras.Model.save_weights` and `tf.train.Checkpoint.save` save in the
-  same format, note that the root of the resulting checkpoint is the object the
-  save method is attached to. This means saving a `tf.keras.Model` using
-  `save_weights` and loading into a `tf.train.Checkpoint` with a `Model`
-  attached (or vice versa) will not match the `Model`'s variables. See the
-  [guide to training
-  checkpoints](https://www.tensorflow.org/guide/checkpoint) for
-  details. Prefer `tf.train.Checkpoint` over `tf.keras.Model.save_weights` for
-  training checkpoints.
+  This function differs slightly from the Keras Model `save_weights` function.
+  `tf.keras.Model.save_weights` creates a checkpoint file with the name
+  specified in `filepath`, while `tf.train.Checkpoint` numbers the checkpoints,
+  using `filepath` as the prefix for the checkpoint file names. Aside from this,
+  `model.save_weights()` and `tf.train.Checkpoint(model).save()` are equivalent.
+
+  See the [guide to training
+  checkpoints](https://www.tensorflow.org/guide/checkpoint) for
+  details.

   Attributes:
     save_counter: Incremented when `save()` is called. Used to number
       checkpoints.
   """

-  def __init__(self, **kwargs):
-    """Group objects into a training checkpoint.
+  def __init__(self, root=None, **kwargs):
+    """Creates a training checkpoint for a single or group of objects.

     Args:
+      root: The root object to checkpoint.
       **kwargs: Keyword arguments are set as attributes of this object, and are
         saved with the checkpoint. Values must be trackable objects.

     Raises:
-      ValueError: If objects in `kwargs` are not trackable.
+      ValueError: If `root` or the objects in `kwargs` are not trackable. A
+        `ValueError` is also raised if the `root` object tracks different
+        objects from the ones listed in attributes in kwargs (e.g.
+        `root.child = A` and `tf.train.Checkpoint(root, child=B)` are
+        incompatible).
+
     """
     super(Checkpoint, self).__init__()
-    for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
-      setattr(self, k, v)
-      if not isinstance(
-          getattr(self, k), (base.Trackable, def_function.Function)):
-        raise ValueError(
-            ("`Checkpoint` was expecting a trackable object (an object "
-             "derived from `TrackableBase`), got %s. If you believe this "
-             "object should be trackable (i.e. it is part of the "
-             "TensorFlow Python API and manages state), please open an issue.")
-            % (v,))
+    saver_root = self
+    attached_dependencies = None
     self._save_counter = None  # Created lazily for restore-on-create.
     self._save_assign_op = None
-    self._saver = saver_with_op_caching(self)
+
+    if root:
+      _assert_trackable(root)
+      saver_root = root
+      attached_dependencies = []
+
+      # All keyword arguments (including root itself) are set as children
+      # of root.
+      kwargs["root"] = root
+      root._maybe_initialize_trackable()
+
+      self._save_counter = root._lookup_dependency("save_counter")
+      self._root = root
+
+    for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
+      setattr(self, k, v)
+
+      # Call getattr instead of directly using v because setattr converts
+      # v to a Trackable data structure when v is a list/dict/tuple.
+      converted_v = getattr(self, k)
+      _assert_trackable(converted_v)
+
+      if root:
+        # Make sure that root doesn't already have dependencies with these names
+        child = root._lookup_dependency(k)
+        if child is None:
+          attached_dependencies.append(base.TrackableReference(k, converted_v))
+        elif child != converted_v:
+          raise ValueError(
+              "Cannot create a Checkpoint with keyword argument {name} if "
+              "root.{name} already exists.".format(name=k))
+
+    self._saver = saver_with_op_caching(saver_root, attached_dependencies)
+    self._attached_dependencies = attached_dependencies

   def _maybe_create_save_counter(self):
     """Create a save counter if it does not yet exist."""
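A sketch of the constructor contract spelled out in the `Raises` section above; `MyModel` stands in for any trackable type:

```python
import tensorflow as tf

model = MyModel()       # hypothetical trackable object
extra = tf.Variable(5.)

# `extra` is attached as a child of `model` in the saved object graph, so the
# checkpoint stays compatible with `model.save_weights()`.
checkpoint = tf.train.Checkpoint(model, extra=extra)

# A keyword argument may not collide with an existing dependency of root
# unless it is the same object:
model.opt = tf.keras.optimizers.Adam()
tf.train.Checkpoint(model, opt=tf.keras.optimizers.Adam())  # ValueError
tf.train.Checkpoint(model, opt=model.opt)                   # accepted
```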
@ -1859,6 +1946,15 @@ class Checkpoint(tracking.AutoTrackable):
|
|||||||
initializer=0,
|
initializer=0,
|
||||||
dtype=dtypes.int64,
|
dtype=dtypes.int64,
|
||||||
trainable=False))
|
trainable=False))
|
||||||
|
if self._attached_dependencies is not None:
|
||||||
|
self._attached_dependencies.append(
|
||||||
|
base.TrackableReference("save_counter", self._save_counter))
|
||||||
|
# When loading a checkpoint, the save counter is created after
|
||||||
|
# the checkpoint has been loaded, so it must be handled in a deferred
|
||||||
|
# manner.
|
||||||
|
restore = self.root._deferred_dependencies.get("save_counter") # pylint: disable=protected-access
|
||||||
|
if restore:
|
||||||
|
restore[0].restore(self._save_counter)
|
||||||
|
|
||||||
def write(self, file_prefix, options=None):
|
def write(self, file_prefix, options=None):
|
||||||
"""Writes a training checkpoint.
|
"""Writes a training checkpoint.
|
||||||
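For reference, the `save_counter` being wired up here is what numbers checkpoint files; a short sketch (assuming `model` is any trackable object, path is a placeholder):

```python
checkpoint = tf.train.Checkpoint(model)
checkpoint.save('/tmp/ckpt')         # writes /tmp/ckpt-1
checkpoint.save('/tmp/ckpt')         # writes /tmp/ckpt-2
print(int(checkpoint.save_counter))  # 2
```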
@@ -2074,15 +2170,32 @@ class Checkpoint(tracking.AutoTrackable):
     a matching Python object.

     Name-based `tf.compat.v1.train.Saver` checkpoints from TensorFlow 1.x can be
-    loaded
-    using this method. Names are used to match variables. Re-encode name-based
-    checkpoints using `tf.train.Checkpoint.save` as soon as possible.
+    loaded using this method. Names are used to match variables. Re-encode
+    name-based checkpoints using `tf.train.Checkpoint.save` as soon as possible.
+
+    **Loading from SavedModel checkpoints**
+
+    To load values from a SavedModel, just pass the SavedModel directory
+    to checkpoint.restore:
+
+    ```python
+    model = tf.keras.Model(...)
+    tf.saved_model.save(model, path)  # or model.save(path, save_format='tf')
+
+    checkpoint = tf.train.Checkpoint(model)
+    checkpoint.restore(path).expect_partial()
+    ```
+
+    This example calls `expect_partial()` on the loaded status, since
+    SavedModels saved from Keras often generate extra keys in the checkpoint.
+    Otherwise, the program prints a lot of warnings about unused keys at exit
+    time.

     Args:
       save_path: The path to the checkpoint, as returned by `save` or
         `tf.train.latest_checkpoint`. If the checkpoint was written by the
         name-based `tf.compat.v1.train.Saver`, names are used to match
-        variables.
+        variables. This path may also be a SavedModel directory.
       options: Optional `tf.train.CheckpointOptions` object.

     Returns:
@@ -2121,8 +2234,25 @@ class Checkpoint(tracking.AutoTrackable):
       restores. Warnings are otherwise printed for unused parts of the
       checkpoint file or object when the `Checkpoint` object is deleted
       (often at program shutdown).
+
+    Raises:
+      NotFoundError: if a checkpoint or SavedModel cannot be found at
+        `save_path`.
     """
-    status = self.read(save_path, options=options)
+    orig_save_path = save_path
+
+    if save_path is not None and gfile.IsDirectory(save_path) and (
+        (gfile.Exists(utils_impl.get_saved_model_pb_path(save_path)) or
+         gfile.Exists(utils_impl.get_saved_model_pbtxt_path(save_path)))):
+      save_path = utils_impl.get_variables_path(save_path)
+
+    try:
+      status = self.read(save_path, options=options)
+    except errors_impl.NotFoundError:
+      raise errors_impl.NotFoundError(
+          None, None,
+          "Could not find checkpoint or SavedModel at {}."
+          .format(orig_save_path))
     # Create the save counter now so it gets initialized with other variables
     # when graph building. Creating it earlier would lead to errors when using,
     # say, train.Saver() to save the model before initializing it.
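The new error path surfaces to users roughly like this (a sketch mirroring the test added below; the bad path is a placeholder):

```python
checkpoint = tf.train.Checkpoint(model)
try:
  checkpoint.restore('/path/without/checkpoint')
except tf.errors.NotFoundError as e:
  print(e.message)  # "Could not find checkpoint or SavedModel at ..."
```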
@@ -26,6 +26,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import control_flow_ops
@@ -37,6 +38,7 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import save as saved_model_save
 from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training.saving import checkpoint_options
@@ -794,6 +796,101 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
     self.assertAllClose(self.evaluate(load_checkpoint.a), [0, 1])
     self.assertAllClose(self.evaluate(load_checkpoint.b), {"a": 2, "b": 3})

+  def _create_trackable(self):
+
+    class Model(tracking.AutoTrackable):
+
+      def __init__(self):
+        self.v = variables_lib.Variable(2.)
+
+      def __call__(self, x):
+        return self.v * x
+
+    return Model()
+
+  def test_initialize_with_root_object(self):
+    model = self._create_trackable()
+    input_value = constant_op.constant([[3.]])
+    expected_output = self.evaluate(model(input_value))
+    model.deferred_variable = variables_lib.Variable(5.)
+
+    checkpoint = trackable_utils.Checkpoint(model)
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+    save_path = checkpoint.save(checkpoint_prefix)
+
+    new_model = self._create_trackable()
+    load_checkpoint = trackable_utils.Checkpoint(new_model)
+    load_checkpoint.restore(save_path)
+    self.assertAllClose(expected_output, new_model(input_value))
+
+    new_model.deferred_variable = variables_lib.Variable(1.)
+    self.assertEqual(self.evaluate(new_model.deferred_variable), 5)
+
+  def test_initialize_with_root_object_and_kwargs(self):
+    model = self._create_trackable()
+    model.v.assign(3.)
+    separate_variable = variables_lib.Variable(5.)
+
+    with self.assertRaisesRegex(ValueError, "root.v already exists"):
+      trackable_utils.Checkpoint(model, v=separate_variable)
+
+    checkpoint = trackable_utils.Checkpoint(
+        model, separate_variable=separate_variable)
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+    save_path = checkpoint.save(checkpoint_prefix)
+
+    # Case 1: Loading checkpoint with same configuration.
+    new_model = self._create_trackable()
+    separate_variable = variables_lib.Variable(1.)
+    load_checkpoint = trackable_utils.Checkpoint(
+        new_model, separate_variable=separate_variable)
+    load_checkpoint.restore(save_path).assert_consumed()
+    self.assertEqual(self.evaluate(new_model.v), 3)
+    self.assertEqual(self.evaluate(separate_variable), 5)
+    self.assertEqual(self.evaluate(load_checkpoint.save_counter), 1)
+
+    # Case 2: Loading checkpoint where v and separate_variable are swapped:
+    # v is not attached to the root, while separate_variable is attached to
+    # root.
+    new_model = tracking.AutoTrackable()
+    new_model.separate_variable = variables_lib.Variable(200.)
+    v = variables_lib.Variable(100.)
+    load_checkpoint = trackable_utils.Checkpoint(new_model, v=v)
+    load_checkpoint.restore(save_path).assert_consumed()
+    self.assertEqual(self.evaluate(v), 3)
+    self.assertEqual(self.evaluate(new_model.separate_variable), 5)
+    self.assertEqual(self.evaluate(load_checkpoint.save_counter), 1)
+
+    # Case 3: Loading checkpoint where no root object is specified.
+    separate_variable = variables_lib.Variable(200.)
+    v = variables_lib.Variable(100.)
+    load_checkpoint = trackable_utils.Checkpoint(
+        v=v, separate_variable=separate_variable)
+    load_checkpoint.restore(save_path).assert_consumed()
+    self.assertEqual(self.evaluate(v), 3)
+    self.assertEqual(self.evaluate(separate_variable), 5)
+    self.assertEqual(self.evaluate(load_checkpoint.save_counter), 1)
+
+  def test_checkpoint_saved_model_compatibility(self):
+    model = self._create_trackable()
+    input_value = constant_op.constant([[3.]])
+    expected_output = self.evaluate(model(input_value))
+    model.deferred_variable = variables_lib.Variable(5.)
+    saved_model_dir = os.path.join(self.get_temp_dir(), "saved_model")
+    saved_model_save.save(model, saved_model_dir)
+
+    new_model = self._create_trackable()
+    load_checkpoint = trackable_utils.Checkpoint(new_model)
+
+    with self.assertRaisesRegex(errors_impl.NotFoundError,
+                                "Could not find checkpoint or SavedModel"):
+      load_checkpoint.restore(saved_model_dir + "no").expect_partial()
+
+    load_checkpoint.restore(saved_model_dir).expect_partial()
+    self.assertAllClose(expected_output, new_model(input_value))
+
+    new_model.deferred_variable = variables_lib.Variable(1.)
+    self.assertEqual(self.evaluate(new_model.deferred_variable), 5)
+

 class TemplateTests(parameterized.TestCase, test.TestCase):
@@ -10,7 +10,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+    argspec: "args=[\'self\', \'root\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
   }
   member_method {
     name: "read"