Add option to construct tf.train.Checkpoint with a root object.
PiperOrigin-RevId: 326701488 Change-Id: Ifc01382a513b8977793e68e801e2410cecface3a
This commit is contained in:
parent
38d3be9f51
commit
3bc9d4420e
@ -157,6 +157,14 @@
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* Tracing and Debugging:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* `tf.train.Checkpoint`:
|
||||
* Now accepts a `root` argument in the initialization, which generates a
|
||||
checkpoint with a root object. This allows users to create a `Checkpoint`
|
||||
object that is compatible with Keras `model.save_weights()` and
|
||||
`model.load_weights`. The checkpoint is also compatible with the
|
||||
checkpoint saved in the `variables/` folder in the SavedModel.
|
||||
* When restoring, `save_path` can be a path to a SavedModel. The function
|
||||
will automatically find the checkpoint in the SavedModel.
|
||||
* Other:
|
||||
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
|
||||
and "denylist" where possible. Please see
|
||||
|
@ -262,6 +262,18 @@ def get_or_create_debug_dir(export_dir):
|
||||
return debug_dir
|
||||
|
||||
|
||||
def get_saved_model_pbtxt_path(export_dir):
|
||||
return os.path.join(
|
||||
compat.as_bytes(compat.path_to_str(export_dir)),
|
||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
|
||||
|
||||
|
||||
def get_saved_model_pb_path(export_dir):
|
||||
return os.path.join(
|
||||
compat.as_bytes(compat.path_to_str(export_dir)),
|
||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
|
||||
|
||||
|
||||
def get_debug_dir(export_dir):
|
||||
"""Returns path to the debug sub-directory in the SavedModel."""
|
||||
return os.path.join(
|
||||
|
@ -146,6 +146,7 @@ py_library(
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/saved_model:utils",
|
||||
"//tensorflow/python/training/saving:checkpoint_options",
|
||||
"//tensorflow/python/training/saving:functional_saver",
|
||||
"//tensorflow/python/training/saving:saveable_object_util",
|
||||
@ -184,6 +185,7 @@ tf_py_test(
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/saved_model:save",
|
||||
"//tensorflow/python/training/saving:checkpoint_options",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"@six_archive//:six",
|
||||
|
@ -142,7 +142,7 @@ def _serialize_slot_variables(trackable_objects, node_ids, object_names):
|
||||
class ObjectGraphView(object):
|
||||
"""Gathers and serializes an object graph."""
|
||||
|
||||
def __init__(self, root, saveables_cache=None):
|
||||
def __init__(self, root, saveables_cache=None, attached_dependencies=None):
|
||||
"""Configure the graph view.
|
||||
|
||||
Args:
|
||||
@ -151,16 +151,24 @@ class ObjectGraphView(object):
|
||||
saveables_cache: A dictionary mapping `Trackable` objects ->
|
||||
attribute names -> SaveableObjects, used to avoid re-creating
|
||||
SaveableObjects when graph building.
|
||||
attached_dependencies: Dependencies to attach to the root object. Used
|
||||
when saving a Checkpoint with a defined root object.
|
||||
"""
|
||||
self._root_ref = root
|
||||
self._saveables_cache = saveables_cache
|
||||
self._attached_dependencies = attached_dependencies
|
||||
|
||||
def list_dependencies(self, obj):
|
||||
# pylint: disable=protected-access
|
||||
obj._maybe_initialize_trackable()
|
||||
return obj._checkpoint_dependencies
|
||||
dependencies = obj._checkpoint_dependencies
|
||||
# pylint: enable=protected-access
|
||||
|
||||
if obj is self.root and self._attached_dependencies:
|
||||
dependencies = dependencies.copy()
|
||||
dependencies.extend(self._attached_dependencies)
|
||||
return dependencies
|
||||
|
||||
@property
|
||||
def saveables_cache(self):
|
||||
"""Maps Trackable objects -> attribute names -> list(SaveableObjects).
|
||||
@ -173,6 +181,19 @@ class ObjectGraphView(object):
|
||||
"""
|
||||
return self._saveables_cache
|
||||
|
||||
@property
|
||||
def attached_dependencies(self):
|
||||
"""Returns list of dependencies that should be saved in the checkpoint.
|
||||
|
||||
These dependencies are not tracked by root, but are in the the checkpoint.
|
||||
This is defined when the user creates a Checkpoint with both root and kwargs
|
||||
set.
|
||||
|
||||
Returns:
|
||||
A list of TrackableReferences.
|
||||
"""
|
||||
return self._attached_dependencies
|
||||
|
||||
@property
|
||||
def root(self):
|
||||
if isinstance(self._root_ref, weakref.ref):
|
||||
|
@ -40,7 +40,9 @@ from tensorflow.python.ops import gen_io_ops as io_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.saved_model import utils_impl
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training import py_checkpoint_reader
|
||||
from tensorflow.python.training import saver as v1_saver_lib
|
||||
@ -1325,6 +1327,30 @@ class TrackableSaver(object):
|
||||
options=options)
|
||||
base.CheckpointPosition(
|
||||
checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)
|
||||
|
||||
# Attached dependencies are not attached to the root, so should be restored
|
||||
# separately.
|
||||
if self._graph_view.attached_dependencies:
|
||||
for ref in self._graph_view.attached_dependencies:
|
||||
if ref.name == "root":
|
||||
# Root dependency is automatically added to attached dependencies --
|
||||
# this can be ignored since it maps back to the root object.
|
||||
continue
|
||||
proto_id = None
|
||||
# Find proto ID of attached dependency (if it is in the proto).
|
||||
for proto_ref in object_graph_proto.nodes[0].children:
|
||||
if proto_ref.local_name == ref.name:
|
||||
proto_id = proto_ref.node_id
|
||||
break
|
||||
|
||||
if proto_id in checkpoint.object_by_proto_id:
|
||||
# Object has already been restored. This can happen when there's an
|
||||
# indirect connection from the attached object to the root.
|
||||
continue
|
||||
|
||||
base.CheckpointPosition(
|
||||
checkpoint=checkpoint, proto_id=proto_id).restore(ref.ref)
|
||||
|
||||
load_status = CheckpointLoadStatus(
|
||||
checkpoint,
|
||||
graph_view=self._graph_view,
|
||||
@ -1358,7 +1384,7 @@ def frozen_saver(root_trackable):
|
||||
return functional_saver.MultiDeviceSaver(named_saveable_objects)
|
||||
|
||||
|
||||
def saver_with_op_caching(obj):
|
||||
def saver_with_op_caching(obj, attached_dependencies=None):
|
||||
"""A TrackableSaver with a SaveableObject cache when graph building."""
|
||||
if context.executing_eagerly():
|
||||
saveables_cache = None
|
||||
@ -1366,7 +1392,19 @@ def saver_with_op_caching(obj):
|
||||
saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()
|
||||
return TrackableSaver(
|
||||
graph_view_lib.ObjectGraphView(
|
||||
weakref.ref(obj), saveables_cache=saveables_cache))
|
||||
weakref.ref(obj), saveables_cache=saveables_cache,
|
||||
attached_dependencies=attached_dependencies))
|
||||
|
||||
|
||||
def _assert_trackable(obj):
|
||||
if not isinstance(
|
||||
obj, (base.Trackable, def_function.Function)):
|
||||
raise ValueError(
|
||||
"`Checkpoint` was expecting a trackable object (an object "
|
||||
"derived from `TrackableBase`), got {}. If you believe this "
|
||||
"object should be trackable (i.e. it is part of the "
|
||||
"TensorFlow Python API and manages state), please open an issue."
|
||||
.format(obj))
|
||||
|
||||
|
||||
# Mentions graph building / Sessions. The v2 version is below.
|
||||
@ -1737,15 +1775,32 @@ class CheckpointV1(tracking.AutoTrackable):
|
||||
|
||||
@tf_export("train.Checkpoint", v1=[])
|
||||
class Checkpoint(tracking.AutoTrackable):
|
||||
"""Groups trackable objects, saving and restoring them.
|
||||
"""Manages saving/restoring trackable values to disk.
|
||||
|
||||
`Checkpoint`'s constructor accepts keyword arguments whose values are types
|
||||
that contain trackable state, such as `tf.keras.optimizers.Optimizer`
|
||||
implementations, `tf.Variable`s, `tf.data.Dataset` iterators, `tf.keras.Layer`
|
||||
implementations, or `tf.keras.Model` implementations. It saves these values
|
||||
with a checkpoint, and maintains a `save_counter` for numbering checkpoints.
|
||||
TensorFlow objects may contain trackable state, such as `tf.Variable`s,
|
||||
`tf.keras.optimizers.Optimizer` implementations, `tf.data.Dataset` iterators,
|
||||
`tf.keras.Layer` implementations, or `tf.keras.Model` implementations.
|
||||
These are called **trackable objects**.
|
||||
|
||||
Example usage:
|
||||
A `Checkpoint` object can be constructed to save either a single or group of
|
||||
trackable objects to a checkpoint file. It maintains a `save_counter` for
|
||||
numbering checkpoints.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
model = tf.keras.Model(...)
|
||||
checkpoint = tf.train.Checkpoint(model)
|
||||
|
||||
# Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
|
||||
# checkpoint.save is called, the save counter is increased.
|
||||
save_path = checkpoint.save('/tmp/training_checkpoints')
|
||||
|
||||
# Restore the checkpointed values to the `model` object.
|
||||
checkpoint.restore(save_path)
|
||||
```
|
||||
|
||||
Example 2:
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
@ -1805,45 +1860,77 @@ class Checkpoint(tracking.AutoTrackable):
|
||||
as a single checkpoint. This avoids copying all variables to one worker, but
|
||||
does require that all workers see a common filesystem.
|
||||
|
||||
While `tf.keras.Model.save_weights` and `tf.train.Checkpoint.save` save in the
|
||||
same format, note that the root of the resulting checkpoint is the object the
|
||||
save method is attached to. This means saving a `tf.keras.Model` using
|
||||
`save_weights` and loading into a `tf.train.Checkpoint` with a `Model`
|
||||
attached (or vice versa) will not match the `Model`'s variables. See the
|
||||
[guide to training
|
||||
This function differs slightly from the Keras Model `save_weights` function.
|
||||
`tf.keras.Model.save_weights` creates a checkpoint file with the name
|
||||
specified in `filepath`, while `tf.train.Checkpoint` numbers the checkpoints,
|
||||
using `filepath` as the prefix for the checkpoint file names. Aside from this,
|
||||
`model.save_weights()` and `tf.train.Checkpoint(model).save()` are equivalent.
|
||||
|
||||
See the [guide to training
|
||||
checkpoints](https://www.tensorflow.org/guide/checkpoint) for
|
||||
details. Prefer `tf.train.Checkpoint` over `tf.keras.Model.save_weights` for
|
||||
training checkpoints.
|
||||
details.
|
||||
|
||||
Attributes:
|
||||
save_counter: Incremented when `save()` is called. Used to number
|
||||
checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Group objects into a training checkpoint.
|
||||
def __init__(self, root=None, **kwargs):
|
||||
"""Creates a training checkpoint for a single or group of objects.
|
||||
|
||||
Args:
|
||||
root: The root object to checkpoint.
|
||||
**kwargs: Keyword arguments are set as attributes of this object, and are
|
||||
saved with the checkpoint. Values must be trackable objects.
|
||||
|
||||
Raises:
|
||||
ValueError: If objects in `kwargs` are not trackable.
|
||||
ValueError: If `root` or the objects in `kwargs` are not trackable. A
|
||||
`ValueError` is also raised if the `root` object tracks different
|
||||
objects from the ones listed in attributes in kwargs (e.g.
|
||||
`root.child = A` and `tf.train.Checkpoint(root, child=B)` are
|
||||
incompatible).
|
||||
|
||||
"""
|
||||
super(Checkpoint, self).__init__()
|
||||
for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
|
||||
setattr(self, k, v)
|
||||
if not isinstance(
|
||||
getattr(self, k), (base.Trackable, def_function.Function)):
|
||||
raise ValueError(
|
||||
("`Checkpoint` was expecting a trackable object (an object "
|
||||
"derived from `TrackableBase`), got %s. If you believe this "
|
||||
"object should be trackable (i.e. it is part of the "
|
||||
"TensorFlow Python API and manages state), please open an issue.")
|
||||
% (v,))
|
||||
|
||||
saver_root = self
|
||||
attached_dependencies = None
|
||||
self._save_counter = None # Created lazily for restore-on-create.
|
||||
self._save_assign_op = None
|
||||
self._saver = saver_with_op_caching(self)
|
||||
|
||||
if root:
|
||||
_assert_trackable(root)
|
||||
saver_root = root
|
||||
attached_dependencies = []
|
||||
|
||||
# All keyword arguments (including root itself) are set as children
|
||||
# of root.
|
||||
kwargs["root"] = root
|
||||
root._maybe_initialize_trackable()
|
||||
|
||||
self._save_counter = root._lookup_dependency("save_counter")
|
||||
self._root = root
|
||||
|
||||
for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
|
||||
setattr(self, k, v)
|
||||
|
||||
# Call getattr instead of directly using v because setattr converts
|
||||
# v to a Trackable data structure when v is a list/dict/tuple.
|
||||
converted_v = getattr(self, k)
|
||||
_assert_trackable(converted_v)
|
||||
|
||||
if root:
|
||||
# Make sure that root doesn't already have dependencies with these names
|
||||
child = root._lookup_dependency(k)
|
||||
if child is None:
|
||||
attached_dependencies.append(base.TrackableReference(k, converted_v))
|
||||
elif child != converted_v:
|
||||
raise ValueError(
|
||||
"Cannot create a Checkpoint with keyword argument {name} if "
|
||||
"root.{name} already exists.".format(name=k))
|
||||
|
||||
self._saver = saver_with_op_caching(saver_root, attached_dependencies)
|
||||
self._attached_dependencies = attached_dependencies
|
||||
|
||||
def _maybe_create_save_counter(self):
|
||||
"""Create a save counter if it does not yet exist."""
|
||||
@ -1859,6 +1946,15 @@ class Checkpoint(tracking.AutoTrackable):
|
||||
initializer=0,
|
||||
dtype=dtypes.int64,
|
||||
trainable=False))
|
||||
if self._attached_dependencies is not None:
|
||||
self._attached_dependencies.append(
|
||||
base.TrackableReference("save_counter", self._save_counter))
|
||||
# When loading a checkpoint, the save counter is created after
|
||||
# the checkpoint has been loaded, so it must be handled in a deferred
|
||||
# manner.
|
||||
restore = self.root._deferred_dependencies.get("save_counter") # pylint: disable=protected-access
|
||||
if restore:
|
||||
restore[0].restore(self._save_counter)
|
||||
|
||||
def write(self, file_prefix, options=None):
|
||||
"""Writes a training checkpoint.
|
||||
@ -2074,15 +2170,32 @@ class Checkpoint(tracking.AutoTrackable):
|
||||
a matching Python object.
|
||||
|
||||
Name-based `tf.compat.v1.train.Saver` checkpoints from TensorFlow 1.x can be
|
||||
loaded
|
||||
using this method. Names are used to match variables. Re-encode name-based
|
||||
checkpoints using `tf.train.Checkpoint.save` as soon as possible.
|
||||
loaded using this method. Names are used to match variables. Re-encode
|
||||
name-based checkpoints using `tf.train.Checkpoint.save` as soon as possible.
|
||||
|
||||
**Loading from SavedModel checkpoints**
|
||||
|
||||
To load values from a SavedModel, just pass the SavedModel directory
|
||||
to checkpoint.restore:
|
||||
|
||||
```python
|
||||
model = tf.keras.Model(...)
|
||||
tf.saved_model.save(model, path) # or model.save(path, save_format='tf')
|
||||
|
||||
checkpoint = tf.train.Checkpoint(model)
|
||||
checkpoint.restore(path).expect_partial()
|
||||
```
|
||||
|
||||
This example calls `expect_partial()` on the loaded status, since
|
||||
SavedModels saved from Keras often generates extra keys in the checkpoint.
|
||||
Otherwise, the program prints a lot of warnings about unused keys at exit
|
||||
time.
|
||||
|
||||
Args:
|
||||
save_path: The path to the checkpoint, as returned by `save` or
|
||||
`tf.train.latest_checkpoint`. If the checkpoint was written by the
|
||||
name-based `tf.compat.v1.train.Saver`, names are used to match
|
||||
variables.
|
||||
variables. This path may also be a SavedModel directory.
|
||||
options: Optional `tf.train.CheckpointOptions` object.
|
||||
|
||||
Returns:
|
||||
@ -2121,8 +2234,25 @@ class Checkpoint(tracking.AutoTrackable):
|
||||
restores. Warnings are otherwise printed for unused parts of the
|
||||
checkpoint file or object when the `Checkpoint` object is deleted
|
||||
(often at program shutdown).
|
||||
|
||||
Raises:
|
||||
NotFoundError: if the a checkpoint or SavedModel cannot be found at
|
||||
`save_path`.
|
||||
"""
|
||||
orig_save_path = save_path
|
||||
|
||||
if save_path is not None and gfile.IsDirectory(save_path) and (
|
||||
(gfile.Exists(utils_impl.get_saved_model_pb_path(save_path)) or
|
||||
gfile.Exists(utils_impl.get_saved_model_pbtxt_path(save_path)))):
|
||||
save_path = utils_impl.get_variables_path(save_path)
|
||||
|
||||
try:
|
||||
status = self.read(save_path, options=options)
|
||||
except errors_impl.NotFoundError:
|
||||
raise errors_impl.NotFoundError(
|
||||
None, None,
|
||||
"Could not find checkpoint or SavedModel at {}."
|
||||
.format(orig_save_path))
|
||||
# Create the save counter now so it gets initialized with other variables
|
||||
# when graph building. Creating it earlier would lead to errors when using,
|
||||
# say, train.Saver() to save the model before initializing it.
|
||||
|
@ -26,6 +26,7 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -37,6 +38,7 @@ from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.saved_model import save as saved_model_save
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.training.saving import checkpoint_options
|
||||
@ -794,6 +796,101 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
|
||||
self.assertAllClose(self.evaluate(load_checkpoint.a), [0, 1])
|
||||
self.assertAllClose(self.evaluate(load_checkpoint.b), {"a": 2, "b": 3})
|
||||
|
||||
def _create_trackable(self):
|
||||
class Model(tracking.AutoTrackable):
|
||||
|
||||
def __init__(self):
|
||||
self.v = variables_lib.Variable(2.)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.v * x
|
||||
return Model()
|
||||
|
||||
def test_initialize_with_root_object(self):
|
||||
model = self._create_trackable()
|
||||
input_value = constant_op.constant([[3.]])
|
||||
expected_output = self.evaluate(model(input_value))
|
||||
model.deferred_variable = variables_lib.Variable(5.)
|
||||
|
||||
checkpoint = trackable_utils.Checkpoint(model)
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
save_path = checkpoint.save(checkpoint_prefix)
|
||||
|
||||
new_model = self._create_trackable()
|
||||
load_checkpoint = trackable_utils.Checkpoint(new_model)
|
||||
load_checkpoint.restore(save_path)
|
||||
self.assertAllClose(expected_output, new_model(input_value))
|
||||
|
||||
new_model.deferred_variable = variables_lib.Variable(1.)
|
||||
self.assertEqual(self.evaluate(new_model.deferred_variable), 5)
|
||||
|
||||
def test_initialize_with_root_object_and_kwargs(self):
|
||||
model = self._create_trackable()
|
||||
model.v.assign(3.)
|
||||
separate_variable = variables_lib.Variable(5.)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "root.v already exists"):
|
||||
trackable_utils.Checkpoint(model, v=separate_variable)
|
||||
|
||||
checkpoint = trackable_utils.Checkpoint(
|
||||
model, separate_variable=separate_variable)
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
save_path = checkpoint.save(checkpoint_prefix)
|
||||
|
||||
# Case 1: Loading checkpoint with same configuration.
|
||||
new_model = self._create_trackable()
|
||||
separate_variable = variables_lib.Variable(1.)
|
||||
load_checkpoint = trackable_utils.Checkpoint(
|
||||
new_model, separate_variable=separate_variable)
|
||||
load_checkpoint.restore(save_path).assert_consumed()
|
||||
self.assertEqual(self.evaluate(new_model.v), 3)
|
||||
self.assertEqual(self.evaluate(separate_variable), 5)
|
||||
self.assertEqual(self.evaluate(load_checkpoint.save_counter), 1)
|
||||
|
||||
# Case 2: Loading checkpoint where v and separate_variable are swapped:
|
||||
# v is not attached to the root, while separate variable is attached to root
|
||||
new_model = tracking.AutoTrackable()
|
||||
new_model.separate_variable = variables_lib.Variable(200.)
|
||||
v = variables_lib.Variable(100.)
|
||||
load_checkpoint = trackable_utils.Checkpoint(new_model, v=v)
|
||||
load_checkpoint.restore(save_path).assert_consumed()
|
||||
self.assertEqual(self.evaluate(v), 3)
|
||||
self.assertEqual(self.evaluate(new_model.separate_variable), 5)
|
||||
self.assertEqual(self.evaluate(load_checkpoint.save_counter), 1)
|
||||
|
||||
# Case 3: Loading checkpoint where no root object is specified
|
||||
separate_variable = variables_lib.Variable(200.)
|
||||
v = variables_lib.Variable(100.)
|
||||
load_checkpoint = trackable_utils.Checkpoint(
|
||||
v=v, separate_variable=separate_variable)
|
||||
load_checkpoint.restore(save_path).assert_consumed()
|
||||
self.assertEqual(self.evaluate(v), 3)
|
||||
self.assertEqual(self.evaluate(new_model.separate_variable), 5)
|
||||
self.assertEqual(self.evaluate(load_checkpoint.save_counter), 1)
|
||||
|
||||
def test_checkpoint_saved_model_compatibility(self):
|
||||
model = self._create_trackable()
|
||||
input_value = constant_op.constant([[3.]])
|
||||
expected_output = self.evaluate(model(input_value))
|
||||
model.deferred_variable = variables_lib.Variable(5.)
|
||||
saved_model_dir = os.path.join(self.get_temp_dir(), "saved_model")
|
||||
saved_model_save.save(model, saved_model_dir)
|
||||
|
||||
new_model = self._create_trackable()
|
||||
load_checkpoint = trackable_utils.Checkpoint(new_model)
|
||||
|
||||
with self.assertRaisesRegex(errors_impl.NotFoundError,
|
||||
"Could not find checkpoint or SavedModel"):
|
||||
load_checkpoint.restore(saved_model_dir + "no").expect_partial()
|
||||
|
||||
load_checkpoint.restore(saved_model_dir).expect_partial()
|
||||
self.assertAllClose(expected_output, new_model(input_value))
|
||||
|
||||
new_model.deferred_variable = variables_lib.Variable(1.)
|
||||
self.assertEqual(self.evaluate(new_model.deferred_variable), 5)
|
||||
|
||||
|
||||
class TemplateTests(parameterized.TestCase, test.TestCase):
|
||||
|
||||
|
@ -10,7 +10,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
argspec: "args=[\'self\', \'root\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "read"
|
||||
|
Loading…
Reference in New Issue
Block a user