Rename Checkpointable -> Trackable and AutoCheckpointable -> AutoTrackable
No API changes in this CL. Just more refactoring for a future API change.

PiperOrigin-RevId: 234242335

This commit is contained in:
parent 6655a2e6ea
commit bd36b48c55
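For orientation, here is a minimal sketch of what the rename means for code built on these internal modules (the module paths below are the ones this CL touches; the example class is hypothetical):

# Old internal paths, removed by this CL:
#   from tensorflow.python.training.checkpointable import base as checkpointable
#   from tensorflow.python.training.checkpointable import tracking
# New internal paths:
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import tracking


class SlotHolder(tracking.AutoTrackable):
  """Hypothetical example: attribute assignments become checkpoint
  dependencies, exactly as with the old AutoCheckpointable."""


assert issubclass(tracking.AutoTrackable, trackable.Trackable)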
@@ -33,7 +33,7 @@ from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensem
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import resources
 from tensorflow.python.training import saver
-from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.tracking import tracking

 ops.NotDifferentiable("TreeEnsembleVariable")
 ops.NotDifferentiable("TreeEnsembleSerialize")
@@ -33,7 +33,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import resources
 from tensorflow.python.training import saver
-from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.tracking import tracking

 # Pattern to remove all non alpha numeric from a string.
 _PATTERN = re.compile(r"[\W_]+")
@@ -26,7 +26,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import resources
 from tensorflow.python.training import saver
-from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.tracking import tracking

 # Pattern to remove all non alpha numeric from a string.
 _PATTERN = re.compile(r"[\W_]+")
@@ -27,7 +27,7 @@ Managing dependencies:
 @@NoDependency
 @@split_dependency

-Checkpointable data structures:
+Trackable data structures:
 @@List
 @@Mapping
 @@UniqueNameTracker
@@ -49,17 +49,16 @@ from tensorflow.contrib.checkpoint.python.python_state import NumpyState
 from tensorflow.contrib.checkpoint.python.python_state import PythonStateWrapper
 from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
 from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
-from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
+from tensorflow.core.protobuf.trackable_object_graph_pb2 import TrackableObjectGraph as CheckpointableObjectGraph
 from tensorflow.python.training.checkpoint_management import CheckpointManager
-from tensorflow.python.training.checkpointable.base import Checkpointable as CheckpointableBase
-from tensorflow.python.training.checkpointable.data_structures import List
-from tensorflow.python.training.checkpointable.data_structures import Mapping
-from tensorflow.python.training.checkpointable.data_structures import NoDependency
-from tensorflow.python.training.checkpointable.tracking import AutoCheckpointable as Checkpointable
-from tensorflow.python.training.checkpointable.util import capture_dependencies
-from tensorflow.python.training.checkpointable.util import list_objects
-from tensorflow.python.training.checkpointable.util import object_metadata
+from tensorflow.python.training.tracking.base import Trackable as CheckpointableBase
+from tensorflow.python.training.tracking.data_structures import List
+from tensorflow.python.training.tracking.data_structures import Mapping
+from tensorflow.python.training.tracking.data_structures import NoDependency
+from tensorflow.python.training.tracking.tracking import AutoTrackable as Checkpointable
+from tensorflow.python.training.tracking.util import capture_dependencies
+from tensorflow.python.training.tracking.util import list_objects
+from tensorflow.python.training.tracking.util import object_metadata
 from tensorflow.python.util.all_util import remove_undocumented

 remove_undocumented(module_name=__name__)
@@ -12,7 +12,7 @@ py_library(
         ":python_state",
         ":split_dependency",
         ":visualize",
-        "//tensorflow/python/training/checkpointable:data_structures",
+        "//tensorflow/python/training/tracking:data_structures",
     ],
 )

@@ -22,8 +22,8 @@ py_library(
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:internal"],
     deps = [
-        "//tensorflow/python/training/checkpointable:base",
-        "//tensorflow/python/training/checkpointable:data_structures",
+        "//tensorflow/python/training/tracking:base",
+        "//tensorflow/python/training/tracking:data_structures",
     ],
 )

@@ -36,8 +36,8 @@ tf_py_test(
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:resource_variable_ops",
-        "//tensorflow/python/training/checkpointable:base",
-        "//tensorflow/python/training/checkpointable:util",
+        "//tensorflow/python/training/tracking:base",
+        "//tensorflow/python/training/tracking:util",
     ],
 )

@@ -47,7 +47,7 @@ py_library(
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:internal"],
     deps = [
-        "//tensorflow/python/training/checkpointable:base",
+        "//tensorflow/python/training/tracking:base",
         "//third_party/py/numpy",
         "@six_archive//:six",
     ],
@@ -64,7 +64,7 @@ tf_py_test(
         "//tensorflow/python:session",
         "//tensorflow/python:variables",
         "//tensorflow/python/eager:test",
-        "//tensorflow/python/training/checkpointable:util",
+        "//tensorflow/python/training/tracking:util",
     ],
 )

@@ -76,7 +76,7 @@ py_library(
     deps = [
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:training",
-        "//tensorflow/python/training/checkpointable:base",
+        "//tensorflow/python/training/tracking:base",
     ],
 )

@@ -89,8 +89,8 @@ tf_py_test(
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python/eager:test",
-        "//tensorflow/python/training/checkpointable:base",
-        "//tensorflow/python/training/checkpointable:util",
+        "//tensorflow/python/training/tracking:base",
+        "//tensorflow/python/training/tracking:util",
     ],
 )

@@ -101,8 +101,8 @@ py_library(
     visibility = ["//tensorflow:internal"],
     deps = [
         "//tensorflow/python:pywrap_tensorflow",
-        "//tensorflow/python/training/checkpointable:base",
-        "//tensorflow/python/training/checkpointable:util",
+        "//tensorflow/python/training/tracking:base",
+        "//tensorflow/python/training/tracking:util",
     ],
 )

@@ -118,7 +118,7 @@ tf_py_test(
         "//tensorflow/python/eager:test",
         "//tensorflow/python/keras:engine",
         "//tensorflow/python/keras:layers",
-        "//tensorflow/python/training/checkpointable:util",
+        "//tensorflow/python/training/tracking:util",
     ],
     tags = ["nooss"],  # b/124472244
 )
@@ -1,4 +1,4 @@
-"""Checkpointable data structures."""
+"""Trackable data structures."""
 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,12 +17,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.training.checkpointable import base as checkpointable_lib
-from tensorflow.python.training.checkpointable import data_structures
+from tensorflow.python.training.tracking import base as trackable_lib
+from tensorflow.python.training.tracking import data_structures


-class UniqueNameTracker(data_structures.CheckpointableDataStructure):
-  """Adds dependencies on checkpointable objects with name hints.
+class UniqueNameTracker(data_structures.TrackableDataStructure):
+  """Adds dependencies on trackable objects with name hints.

   Useful for creating dependencies with locally unique names.

@@ -43,30 +43,30 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure):

   def __init__(self):
     super(UniqueNameTracker, self).__init__()
-    self._maybe_initialize_checkpointable()
+    self._maybe_initialize_trackable()
     self._name_counts = {}

   @property
   def _values(self):
     return [dep.ref for dep in self._checkpoint_dependencies]

-  def track(self, checkpointable, base_name):
-    """Add a dependency on `checkpointable`.
+  def track(self, trackable, base_name):
+    """Add a dependency on `trackable`.

     Args:
-      checkpointable: An object to add a checkpoint dependency on.
+      trackable: An object to add a checkpoint dependency on.
       base_name: A name hint, which is uniquified to determine the dependency
         name.
     Returns:
-      `checkpointable`, for chaining.
+      `trackable`, for chaining.
     Raises:
-      ValueError: If `checkpointable` is not a checkpointable object.
+      ValueError: If `trackable` is not a trackable object.
     """

-    if not isinstance(checkpointable, checkpointable_lib.Checkpointable):
+    if not isinstance(trackable, trackable_lib.Trackable):
       raise ValueError(
-          ("Expected a checkpointable value, got %s which does not inherit "
-           "from CheckpointableBase.") % (checkpointable,))
+          ("Expected a trackable value, got %s which does not inherit "
+           "from tf.track.Trackable.") % (trackable,))

     def _format_name(prefix, number):
       if number > 0:
@@ -80,5 +80,5 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure):
       count += 1
       candidate = _format_name(base_name, count)
     self._name_counts[base_name] = count + 1
-    self._track_value(checkpointable, name=candidate)
-    return checkpointable
+    self._track_value(trackable, name=candidate)
+    return trackable
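As a usage sketch (this mirrors the SlotManager test below and the public alias exported from the package __init__ above; variable names are illustrative, and a TF 1.x build containing tf.contrib.checkpoint is assumed):

import tensorflow as tf

tf.enable_eager_execution()

slot_manager = tf.contrib.checkpoint.UniqueNameTracker()
slots = []
for _ in range(3):
  # Dependency names are uniquified from the hint: "slot", "slot_1", "slot_2".
  slots.append(slot_manager.track(tf.Variable(1.0), "slot"))

# The tracker is itself trackable, so it can be saved under a Checkpoint root.
checkpoint = tf.train.Checkpoint(slot_manager=slot_manager)
save_path = checkpoint.save("/tmp/unique_names/ckpt")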
@@ -26,9 +26,9 @@ from tensorflow.python.keras import layers
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import test
-from tensorflow.python.training.checkpointable import data_structures
-from tensorflow.python.training.checkpointable import tracking
-from tensorflow.python.training.checkpointable import util
+from tensorflow.python.training.tracking import data_structures
+from tensorflow.python.training.tracking import tracking
+from tensorflow.python.training.tracking import util


 class UniqueNameTrackerTests(test.TestCase):
@@ -52,7 +52,7 @@ class UniqueNameTrackerTests(test.TestCase):
     save_root = util.Checkpoint(slots=slots)
     save_path = save_root.save(checkpoint_prefix)

-    restore_slots = tracking.AutoCheckpointable()
+    restore_slots = tracking.AutoTrackable()
     restore_root = util.Checkpoint(
         slots=restore_slots)
     status = restore_root.restore(save_path)
@@ -68,7 +68,7 @@ class UniqueNameTrackerTests(test.TestCase):

   @test_util.run_in_graph_and_eager_modes
   def testExample(self):
-    class SlotManager(tracking.AutoCheckpointable):
+    class SlotManager(tracking.AutoTrackable):

       def __init__(self):
         self.slotdeps = containers.UniqueNameTracker()
@@ -23,7 +23,7 @@ import six

 import numpy

-from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.tracking import base

 # pylint: disable=g-import-not-at-top
 try:
@@ -34,8 +34,8 @@ except ImportError:
 # pylint: enable=g-import-not-at-top


-class NumpyState(base.Checkpointable):
-  """A checkpointable object whose NumPy array attributes are saved/restored.
+class NumpyState(base.Trackable):
+  """A trackable object whose NumPy array attributes are saved/restored.

   Example usage:

@@ -72,7 +72,7 @@ class NumpyState(base.Checkpointable):
     """Create placeholder NumPy arrays for to-be-restored attributes.

     Typically `_lookup_dependency` is used to check by name whether a dependency
-    exists. We cheat slightly by creating a checkpointable object for `name` if
+    exists. We cheat slightly by creating a trackable object for `name` if
     we don't already have one, giving us attribute re-creation behavior when
     loading a checkpoint.

@@ -85,7 +85,7 @@ class NumpyState(base.Checkpointable):
     value = super(NumpyState, self)._lookup_dependency(name)
     if value is None:
       value = _NumpyWrapper(numpy.array([]))
-      new_reference = base.CheckpointableReference(name=name, ref=value)
+      new_reference = base.TrackableReference(name=name, ref=value)
       self._unconditional_checkpoint_dependencies.append(new_reference)
       self._unconditional_dependency_names[name] = value
       super(NumpyState, self).__setattr__(name, value)
@@ -101,7 +101,7 @@ class NumpyState(base.Checkpointable):
   def __setattr__(self, name, value):
     """Automatically wrap NumPy arrays assigned to attributes."""
     # TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making
-    # ndarrays checkpointable natively and using standard checkpointable list
+    # ndarrays trackable natively and using standard trackable list
     # tracking.
     if isinstance(value, (numpy.ndarray, numpy.generic)):
       try:
@@ -110,19 +110,19 @@ class NumpyState(base.Checkpointable):
         return
       except AttributeError:
         value = _NumpyWrapper(value)
-        self._track_checkpointable(value, name=name, overwrite=True)
+        self._track_trackable(value, name=name, overwrite=True)
     elif (name not in ("_setattr_tracking", "_update_uid")
           and getattr(self, "_setattr_tracking", True)):
-      # Mixing restore()-created attributes with user-added checkpointable
+      # Mixing restore()-created attributes with user-added trackable
       # objects is tricky, since we can't use the `_lookup_dependency` trick to
       # re-create attributes (we might accidentally steal the restoration for
-      # another checkpointable object). For now `NumpyState` objects must be
+      # another trackable object). For now `NumpyState` objects must be
       # leaf nodes. Theoretically we could add some extra arguments to
       # `_lookup_dependency` to figure out whether we should create a NumPy
       # array for the attribute or not.
       raise NotImplementedError(
           ("Assigned %s to the %s property of %s, which is not a NumPy array. "
-           "Currently mixing NumPy arrays and other checkpointable objects is "
+           "Currently mixing NumPy arrays and other trackable objects is "
           "not supported. File a feature request if this limitation bothers "
           "you.")
          % (value, name, self))
@@ -130,7 +130,7 @@ class NumpyState(base.Checkpointable):


 @six.add_metaclass(abc.ABCMeta)
-class PythonStateWrapper(base.Checkpointable):
+class PythonStateWrapper(base.Trackable):
   """Wraps a Python object for storage in an object-based checkpoint."""

   @abc.abstractmethod
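A usage sketch for NumpyState (the class docstring's own example is elided by the hunk above; this follows the described save/restore behavior, with illustrative paths, on a TF 1.x build containing tf.contrib.checkpoint):

import numpy
import tensorflow as tf

tf.enable_eager_execution()

arrays = tf.contrib.checkpoint.NumpyState()
checkpoint = tf.train.Checkpoint(numpy_arrays=arrays)
arrays.x = numpy.zeros([3, 4])  # ndarray attributes are wrapped and tracked
save_path = checkpoint.save("/tmp/np_state/ckpt")
arrays.x[1, 1] = 4.
checkpoint.restore(save_path)  # restores the saved array values in place
assert (arrays.x == numpy.zeros([3, 4])).all()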
@@ -26,7 +26,7 @@ from tensorflow.python.eager import test
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import variables
-from tensorflow.python.training.checkpointable import util
+from tensorflow.python.training.tracking import util


 class NumpyStateTests(test.TestCase):
@@ -21,7 +21,7 @@ import functools

 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.training import saver as saver_lib
-from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.tracking import base as trackable


 class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
@@ -43,7 +43,7 @@ class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
     return self._restore_callback(tensor)


-class _SplitDependency(checkpointable.Checkpointable):
+class _SplitDependency(trackable.Trackable):
   """Looks like a regular variable while synchronizing save/restores."""

   def __init__(self, save_buffer, restore_buffer, name, dtype, num_components,
@@ -81,9 +81,9 @@ class _SplitDependency(checkpointable.Checkpointable):
     return control_flow_ops.no_op()

   def _gather_saveables_for_checkpoint(self):
-    """Looks to Checkpointable like a regular variable."""
+    """Looks to Trackable like a regular variable."""
     return {
-        checkpointable.VARIABLE_VALUE_KEY:
+        trackable.VARIABLE_VALUE_KEY:
            functools.partial(_CallbackSaveable,
                              dtype=self._dtype,
                              save_callback=self._save,
@@ -117,7 +117,7 @@ def split_dependency(component_names, component_dtypes,
       may return `None`).

   Returns:
-    A dictionary mapping from names to Checkpointable objects. If one is
+    A dictionary mapping from names to Trackable objects. If one is
     reachable from an object as a dependency, the others should be too; adding
     dependencies on some but not all of the objects will result in errors.
   """
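The save/restore flow is easier to see in miniature. A sketch of the split_dependency pattern, using only the keyword arguments that appear in this file (the closures and the owner object mirror the test file below; names are illustrative):

# One logical value is exposed as several named Trackable dependencies.
split_deps = split_dependency.split_dependency(
    component_names=("first_half", "second_half"),
    component_dtypes=(dtypes.float32, dtypes.float32),
    fill_save_buffer_fn=_split_variable_closure(combined_variable),
    consume_restore_buffer_fn=_combine_variable_closure(combined_variable))
for name, dep in split_deps.items():
  # Each component is tracked under its own name, as in SaveTensorSlicesAsDeps.
  owner._track_trackable(dep, name=name)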
@@ -23,9 +23,9 @@ from tensorflow.python.eager import test
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.training.checkpointable import base
-from tensorflow.python.training.checkpointable import tracking
-from tensorflow.python.training.checkpointable import util
+from tensorflow.python.training.tracking import base
+from tensorflow.python.training.tracking import tracking
+from tensorflow.python.training.tracking import util


 def _split_variable_closure(variable):
@@ -44,7 +44,7 @@ def _combine_variable_closure(variable):
   return _consume_restore_buffer_fn


-class SaveTensorSlicesAsDeps(base.Checkpointable):
+class SaveTensorSlicesAsDeps(base.Trackable):

   def __init__(self):
     self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.])
@@ -56,17 +56,17 @@ class SaveTensorSlicesAsDeps(base.Checkpointable):
         consume_restore_buffer_fn=_combine_variable_closure(
             self.combined))
     for name, dep in split_dependencies.items():
-      self._track_checkpointable(dep, name=name)
+      self._track_trackable(dep, name=name)


-class HasRegularDeps(tracking.AutoCheckpointable):
+class HasRegularDeps(tracking.AutoTrackable):

   def __init__(self):
     self.first_half = resource_variable_ops.ResourceVariable([0., 0.])
     self.second_half = resource_variable_ops.ResourceVariable([0., 0.])


-class OnlyOneDep(tracking.AutoCheckpointable):
+class OnlyOneDep(tracking.AutoTrackable):

   def __init__(self):
     self.first_half = resource_variable_ops.ResourceVariable([0., 0.])
@@ -18,8 +18,8 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.python import pywrap_tensorflow
-from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.training.checkpointable import util as checkpointable_utils
+from tensorflow.python.training.tracking import base as trackable
+from tensorflow.python.training.tracking import util as trackable_utils


 def dot_graph_from_checkpoint(save_path):
@@ -51,7 +51,7 @@ def dot_graph_from_checkpoint(save_path):
     A graph in DOT format as a string.
   """
   reader = pywrap_tensorflow.NewCheckpointReader(save_path)
-  object_graph = checkpointable_utils.object_metadata(save_path)
+  object_graph = trackable_utils.object_metadata(save_path)
   shape_map = reader.get_variable_to_shape_map()
   dtype_map = reader.get_variable_to_dtype_map()
   graph = 'digraph {\n'
@@ -63,7 +63,7 @@ def dot_graph_from_checkpoint(save_path):
       slot_ids.add(slot_reference.slot_variable_node_id)
   for node_id, node in enumerate(object_graph.nodes):
     if (len(node.attributes) == 1
-        and node.attributes[0].name == checkpointable.VARIABLE_VALUE_KEY):
+        and node.attributes[0].name == trackable.VARIABLE_VALUE_KEY):
       if node_id in slot_ids:
         color = 'orange'
         tooltip_prefix = 'Slot variable'
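A usage sketch for the visualizer (consistent with the docstring above; the checkpoint path and output file are illustrative, and a TF 1.x build containing tf.contrib.checkpoint is assumed):

import tensorflow as tf

# Produce a DOT-format string describing the object graph in a checkpoint.
dot_string = tf.contrib.checkpoint.dot_graph_from_checkpoint("/tmp/model/ckpt-1")
with open("/tmp/model/object_graph.dot", "w") as f:
  f.write(dot_string)
# Render with graphviz, e.g.: dot -Tsvg /tmp/model/object_graph.dot -o graph.svg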
@@ -28,7 +28,7 @@ from tensorflow.python.keras.engine import training
 from tensorflow.python.keras.layers import core
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.training import adam
-from tensorflow.python.training.checkpointable import util as checkpointable_utils
+from tensorflow.python.training.tracking import util as trackable_utils

 try:
   import pydot  # pylint: disable=g-import-not-at-top
@@ -57,7 +57,7 @@ class DotGraphTests(test.TestCase):
     model = MyModel()
     optimizer = adam.AdamOptimizer(0.001)
     optimizer_step = resource_variable_ops.ResourceVariable(12)
-    save_checkpoint = checkpointable_utils.Checkpoint(
+    save_checkpoint = trackable_utils.Checkpoint(
         optimizer=optimizer, model=model, optimizer_step=optimizer_step)
     optimizer.minimize(functools.partial(model, input_value))
     checkpoint_directory = self.get_temp_dir()
@@ -72,7 +72,7 @@ tensorflow/python/tools
 tensorflow/python/tools/api
 tensorflow/python/tools/api/generator
 tensorflow/python/training
-tensorflow/python/training/checkpointable
+tensorflow/python/training/tracking
 tensorflow/python/user_ops
 tensorflow/python/util
 tensorflow/python/util/protobuf
@@ -58,7 +58,7 @@ from tensorflow.python.training import gradient_descent
 from tensorflow.python.training import momentum
 from tensorflow.python.training import rmsprop
 from tensorflow.python.training import saver as saver_lib
-from tensorflow.python.training.checkpointable import util as checkpointable_utils
+from tensorflow.python.training.tracking import util as trackable_utils


 CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM
@@ -709,7 +709,7 @@ class CudnnRNNTestSaveRestore(test_util.TensorFlowTestCase):
     self._TestSaveRestoreHelper(CUDNN_RNN_RELU)


-class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
+class CudnnRNNTestSaveRestoreTrackable(test_util.TensorFlowTestCase):

   def _VerifyCheckpoint(
       self, checkpoint_path, compatible_cell_fn, cudnn_cell_fn,
@@ -718,7 +718,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
     with ops.device("gpu:0"):
       cudnn_layer = cudnn_cell_fn()
-      cudnn_checkpoint = checkpointable_utils.Checkpoint(cell=cudnn_layer)
+      cudnn_checkpoint = trackable_utils.Checkpoint(cell=cudnn_layer)
       status = cudnn_checkpoint.restore(checkpoint_path)
       inputs = 3. * array_ops.ones([num_applications, num_layers, input_size],
                                    dtype=dtypes.float32)
@@ -726,7 +726,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
     status.run_restore_ops()
     second_save_path = cudnn_checkpoint.save(checkpoint_prefix)
     restore_layer = compatible_cell_fn()
-    restore_layer_checkpoint = checkpointable_utils.Checkpoint(
+    restore_layer_checkpoint = trackable_utils.Checkpoint(
        cell=restore_layer)
     status = restore_layer_checkpoint.restore(second_save_path)
     current_state = restore_layer.zero_state(1, dtypes.float32)
@@ -742,7 +742,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
     self.assertAllClose(self.evaluate(restore_layer_output),
                         self.evaluate(cudnn_output)[-1, -1:, ...])

-  def _CheckpointableSingleCellUnidirectionalTestTemplate(
+  def _TrackableSingleCellUnidirectionalTestTemplate(
       self, single_cell_fn, cudnn_cell_fn):
     # Single-layer cuDNN cells with object-based checkpointing should be
     # checkpoint compatible with either single CudnnCompatible cells or
@@ -759,7 +759,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
       value = np.random.normal(size=variable.shape)
       expected_values.append(value)
       self.evaluate(variable.assign(value))
-    save_checkpoint = checkpointable_utils.Checkpoint(cell=save_cell_layer)
+    save_checkpoint = trackable_utils.Checkpoint(cell=save_cell_layer)
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
     first_save_path = save_checkpoint.save(checkpoint_prefix)
@@ -775,10 +775,10 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
   @unittest.skipUnless(test.is_built_with_cuda(),
                        "Test only applicable when running on GPUs")
   @test_util.run_in_graph_and_eager_modes
-  def testLSTMCheckpointableSingleLayer(self):
+  def testLSTMTrackableSingleLayer(self):
     num_units = 2
     direction = CUDNN_RNN_UNIDIRECTION
-    self._CheckpointableSingleCellUnidirectionalTestTemplate(
+    self._TrackableSingleCellUnidirectionalTestTemplate(
        single_cell_fn=functools.partial(
            cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units),
        cudnn_cell_fn=functools.partial(
@@ -788,19 +788,19 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
   @unittest.skipUnless(test.is_built_with_cuda(),
                        "Test only applicable when running on GPUs")
   @test_util.run_in_graph_and_eager_modes
-  def testGRUCheckpointableSingleLayer(self):
+  def testGRUTrackableSingleLayer(self):
     num_units = 2
     direction = CUDNN_RNN_UNIDIRECTION
     with self.assertRaises(NotImplementedError):
       # TODO(allenl): Implement object-based saving for GRUs and other cells.
-      self._CheckpointableSingleCellUnidirectionalTestTemplate(
+      self._TrackableSingleCellUnidirectionalTestTemplate(
          single_cell_fn=functools.partial(
              cudnn_rnn_ops.CudnnCompatibleGRUCell, num_units=num_units),
          cudnn_cell_fn=functools.partial(
              cudnn_rnn.CudnnGRU, num_layers=1, num_units=num_units,
              direction=direction, name="awesome_gru"))

-  def _CheckpointableMultiLayerTestTemplate(
+  def _TrackableMultiLayerTestTemplate(
       self, single_cell_fn, cudnn_cell_fn, num_layers):

     def _MultiCellFn():
@@ -819,7 +819,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
       value = np.random.normal(size=variable.shape)
       expected_values.append(value)
       self.evaluate(variable.assign(value))
-    save_checkpoint = checkpointable_utils.Checkpoint(cell=save_layer)
+    save_checkpoint = trackable_utils.Checkpoint(cell=save_layer)
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
     first_save_path = save_checkpoint.save(checkpoint_prefix)
@@ -837,7 +837,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
     num_units = 2
     num_layers = 3
     direction = CUDNN_RNN_UNIDIRECTION
-    self._CheckpointableMultiLayerTestTemplate(
+    self._TrackableMultiLayerTestTemplate(
        single_cell_fn=functools.partial(
            cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units),
        cudnn_cell_fn=functools.partial(
@@ -518,8 +518,8 @@ class _CudnnRNN(base_layer.Layer):
        direction=self.direction,
        scope=vs.get_variable_scope(),
        name="%s_saveable" % self.trainable_variables[0].name.split(":")[0])
-    self._saveable._add_checkpointable_dependencies(  # pylint: disable=protected-access
-        checkpointable=self, dtype=self._plain_dtype)
+    self._saveable._add_trackable_dependencies(  # pylint: disable=protected-access
+        trackable=self, dtype=self._plain_dtype)
     ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)


@@ -33,7 +33,7 @@ from tensorflow.python.ops import rnn_cell_impl
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.training import saver
-from tensorflow.python.training.checkpointable import tracking as checkpointable_lib
+from tensorflow.python.training.tracking import tracking as trackable_lib

 CUDNN_RNN_UNIDIRECTION = "unidirectional"
 CUDNN_RNN_BIDIRECTION = "bidirectional"
@@ -737,13 +737,13 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
     return state_ops.assign(
        self._variables, opaque_params, validate_shape=False)

-  def _checkpointable_save(self, save_buffer):
+  def _trackable_save(self, save_buffer):
     weights, biases = self.format_converter.opaque_to_tf_canonical(
        self._variables)
     for name, tensor in zip(self._param_names, weights + biases):
       save_buffer[name] = array_ops.identity(tensor)

-  def _checkpointable_restore(self, restore_buffer):
+  def _trackable_restore(self, restore_buffer):
     tensors = [
        array_ops.identity(restore_buffer[name]) for name in self._param_names
     ]
@@ -752,26 +752,26 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
        restored_shapes=None  # Unused
     )

-  def _add_checkpointable_dependencies(self, checkpointable, dtype):
-    """Add canonical weight dependencies to `checkpointable`.
+  def _add_trackable_dependencies(self, trackable, dtype):
+    """Add canonical weight dependencies to `trackable`.

     When saving or restoring, converts to or from the opaque buffer
     format. Weights are saved and loaded in the configuration expected by
     cuDNN-compatible cells.

     Args:
-      checkpointable: An object inheriting from `CheckpointableBase` to add
+      trackable: An object inheriting from `Trackable` to add
        dependencies too (typically the cuDNN `Layer`).
      dtype: The dtype for the canonical parameter Tensors.
    """
    split_dependencies = split_dependency.split_dependency(
        component_names=self._param_names,
        component_dtypes=(dtype,) * len(self._param_names),
-        fill_save_buffer_fn=self._checkpointable_save,
-        consume_restore_buffer_fn=self._checkpointable_restore)
-    self._checkpointable_track_params(checkpointable, split_dependencies)
+        fill_save_buffer_fn=self._trackable_save,
+        consume_restore_buffer_fn=self._trackable_restore)
+    self._trackable_track_params(trackable, split_dependencies)

-  def _checkpointable_track_params(self, checkpointable, params):
+  def _trackable_track_params(self, trackable, params):
     """Tracks parameters in a canonical configuration."""
     return  # NotImplementedError raised by the Layer.

@@ -819,7 +819,7 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable):
     tf_weights_names.append(prefix + "/kernel")
     tf_bias_names.append(prefix + "/bias")

-  def _checkpointable_track_params(self, checkpointable, params):
+  def _trackable_track_params(self, trackable, params):
     """Track parameters for compatibility with CudnnCompatibleLSTMCell."""
     biases = []
     weights = []
@@ -833,12 +833,12 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable):
       # wrapping.
       kernel, = weights  # pylint: disable=unbalanced-tuple-unpacking
       bias, = biases  # pylint: disable=unbalanced-tuple-unpacking
-      checkpointable._track_checkpointable(kernel, name="kernel")  # pylint: disable=protected-access
-      checkpointable._track_checkpointable(bias, name="bias")  # pylint: disable=protected-access
+      trackable._track_trackable(kernel, name="kernel")  # pylint: disable=protected-access
+      trackable._track_trackable(bias, name="bias")  # pylint: disable=protected-access
     assert len(biases) == len(weights)
     for cell_index, (bias, kernel) in enumerate(zip(biases, weights)):
-      cell = checkpointable_lib.AutoCheckpointable()
-      checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index)  # pylint: disable=protected-access
+      cell = trackable_lib.AutoTrackable()
+      trackable._track_trackable(cell, name="cell-%d" % cell_index)  # pylint: disable=protected-access
       cell.bias = bias
       cell.kernel = kernel

@@ -800,6 +800,6 @@ tf_xla_py_test(
        ":tpu_strategy",
        "//tensorflow/compiler/tests:xla_test",
        "//tensorflow/python/eager:test",
-        "//tensorflow/python/training/checkpointable:util",
+        "//tensorflow/python/training/tracking:util",
    ],
 )
@@ -30,15 +30,15 @@ from tensorflow.python.platform import test
 from tensorflow.python.training import adam as adam_v1
 from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import training_util
-from tensorflow.python.training.checkpointable import tracking
-from tensorflow.python.training.checkpointable import util as checkpointable_utils
+from tensorflow.python.training.tracking import tracking
+from tensorflow.python.training.tracking import util as trackable_utils


-class NonLayerCheckpointable(tracking.AutoCheckpointable):
+class NonLayerTrackable(tracking.AutoTrackable):

   def __init__(self):
-    super(NonLayerCheckpointable, self).__init__()
-    self.a_variable = checkpointable_utils.add_variable(
+    super(NonLayerTrackable, self).__init__()
+    self.a_variable = trackable_utils.add_variable(
        self, name="a_variable", shape=[])


@@ -49,8 +49,8 @@ class Subclassed(training.Model):
     super(Subclassed, self).__init__()
     self._named_dense = core.Dense(1, use_bias=True)
     self._second = core.Dense(1, use_bias=False)
-    # We can still track Checkpointables which aren't Layers.
-    self._non_layer = NonLayerCheckpointable()
+    # We can still track Trackables which aren't Layers.
+    self._non_layer = NonLayerTrackable()

   def call(self, values):
     ret = self._second(self._named_dense(values))
@@ -76,7 +76,7 @@ class TrainingCheckpointTests(xla_test.XLATestCase):
     with strategy.scope():
       model = Subclassed()
       optimizer = adam_v1.AdamOptimizer(0.001)
-      root = checkpointable_utils.Checkpoint(
+      root = trackable_utils.Checkpoint(
          optimizer=optimizer, model=model,
          optimizer_step=training_util.get_or_create_global_step())
      root.restore(checkpoint_management.latest_checkpoint(
@@ -414,7 +414,7 @@ class TestDistributionStrategySaveLoadWeights(test.TestCase,

   @combinations.generate(
       keras_test_lib.all_strategy_combinations_minus_default())
-  def test_save_load_checkpointable(self, distribution):
+  def test_save_load_trackable(self, distribution):
     # TODO(sourabhbajaj): Test fails with optimizer v2 without h5
     with self.cached_session():
       dataset = keras_test_lib.get_dataset(distribution)
@@ -144,7 +144,7 @@ py_library(
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:function",
-        "//tensorflow/python/training/checkpointable:base",
+        "//tensorflow/python/training/tracking:base",
    ],
 )

@@ -37,7 +37,7 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import script_ops
 from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training.checkpointable import util as checkpointable_utils
+from tensorflow.python.training.tracking import util as trackable_utils


 class IteratorTest(test.TestCase):
@@ -238,7 +238,7 @@ class IteratorTest(test.TestCase):
     dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
     dataset = dataset.map(math_ops.square).batch(2)
     iterator = datasets.Iterator(dataset)
-    checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
+    checkpoint = trackable_utils.Checkpoint(iterator=iterator)
     self.assertAllEqual([1, 4], iterator.get_next().numpy())
     save_path = checkpoint.save(checkpoint_prefix)
     self.assertAllEqual([9, 16], iterator.get_next().numpy())
@@ -257,7 +257,7 @@ class IteratorTest(test.TestCase):
     dataset_2 = Dataset.range(10)
     iterator_3 = datasets.Iterator(dataset_2)

-    checkpoint = checkpointable_utils.Checkpoint(
+    checkpoint = trackable_utils.Checkpoint(
        iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
     self.assertAllEqual([1, 4], iterator_1.get_next().numpy())
     self.assertEqual(0, iterator_3.get_next().numpy())
@@ -279,7 +279,7 @@ class IteratorTest(test.TestCase):
     dataset = Dataset.range(3)
     iterator = datasets.Iterator(dataset)

-    checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
+    checkpoint = trackable_utils.Checkpoint(iterator=iterator)
     self.assertEqual(0, iterator.get_next().numpy())
     self.assertEqual(1, iterator.get_next().numpy())
     save_path = checkpoint.save(checkpoint_prefix)
@@ -293,7 +293,7 @@ class IteratorTest(test.TestCase):
     dataset = Dataset.range(10)
     for i in range(5):
       iterator = datasets.Iterator(dataset)
-      checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
+      checkpoint = trackable_utils.Checkpoint(iterator=iterator)
       checkpoint.restore(checkpoint_management.latest_checkpoint(
          checkpoint_directory))
       for j in range(2):
@@ -37,7 +37,7 @@ from tensorflow.contrib.summary import summary_test_util
 from tensorflow.python.eager import test
 from tensorflow.python.framework import test_util
 from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training.checkpointable import util as checkpointable_utils
+from tensorflow.python.training.tracking import util as trackable_utils
 # pylint: enable=g-bad-import-order


@@ -421,7 +421,7 @@ class SpinnTest(test_util.TensorFlowTestCase):

     # 5. Verify that checkpoints exist and contains all the expected variables.
     self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
-    object_graph = checkpointable_utils.object_metadata(
+    object_graph = trackable_utils.object_metadata(
        checkpoint_management.latest_checkpoint(config.logdir))
     ckpt_variable_names = set()
     for node in object_graph.nodes:
@@ -32,12 +32,12 @@ from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import summary_ops_v2 as summary_ops
 from tensorflow.python.ops import variable_scope
-from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.tracking import base as trackable

 _to_replace = re.compile("[^A-Za-z0-9.]")


-class Metric(checkpointable.Checkpointable):
+class Metric(trackable.Trackable):
   """A metric holds state for aggregating statistics over an evaluation run.

   Example use with eager execution:
@@ -269,7 +269,7 @@ class Metric(checkpointable.Checkpointable):
     else:
       collections = [ops.GraphKeys.LOCAL_VARIABLES]
       collections += [ops.GraphKeys.METRIC_VARIABLES]
-    # Variables are Checkpointable dependencies of Metrics regardless of the
+    # Variables are Trackable dependencies of Metrics regardless of the
     # global/local distinction. Users can avoid saving variables by not adding a
     # dependency on the Metric.
     v = self._add_variable_with_custom_getter(
@@ -282,7 +282,7 @@ class Metric(checkpointable.Checkpointable):
        use_resource=True,
        getter=variable_scope.get_variable,
        # Raise duplicate variable exceptions from get_variable rather than
-        # Checkpointable.
+        # Trackable.
        overwrite=True)
     self._vars.append(v)
     if context.executing_eagerly():
@@ -35,7 +35,7 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import summary_ops_v2 as summary_ops
 from tensorflow.python.training import training_util
-from tensorflow.python.training.checkpointable import util as checkpointable_utils
+from tensorflow.python.training.tracking import util as trackable_utils


 class MetricsTest(test.TestCase):
@@ -314,7 +314,7 @@ class MetricsTest(test.TestCase):
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
     mean = metrics.Mean()
-    checkpoint = checkpointable_utils.Checkpoint(mean=mean)
+    checkpoint = trackable_utils.Checkpoint(mean=mean)
     mean.build()
     mean._built = True
     self.evaluate(mean.init_variables())
@@ -327,7 +327,7 @@ class MetricsTest(test.TestCase):
     self.assertAllEqual(200., self.evaluate(mean.value()))

     restore_mean = metrics.Mean()
-    restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean)
+    restore_checkpoint = trackable_utils.Checkpoint(mean=restore_mean)
     status = restore_checkpoint.restore(save_path)
     restore_update = restore_mean(300.)
     status.assert_consumed().run_restore_ops()
@@ -31,7 +31,7 @@ from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.training import training_util
-from tensorflow.python.training.checkpointable import util as checkpointable_utils
+from tensorflow.python.training.tracking import util as trackable_utils


 # pylint: disable=not-callable
@@ -65,7 +65,7 @@ class NetworkTest(test.TestCase):

   def test_checkpointing_not_implemented(self):
     checkpoint_directory = self.get_temp_dir()
-    checkpoint = checkpointable_utils.Checkpoint(net=MyNetwork())
+    checkpoint = trackable_utils.Checkpoint(net=MyNetwork())
     with self.assertRaises(NotImplementedError):
       checkpoint.save(checkpoint_directory)

@@ -30,7 +30,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
-from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.tracking import base as trackable


 def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
@@ -129,8 +129,8 @@ class SharedVariable(resource_variable_ops.ResourceVariable):
     if constraint is not None and not callable(constraint):
       raise ValueError("The `constraint` argument must be a callable.")

-    if isinstance(initial_value, checkpointable.CheckpointInitialValue):
-      self._maybe_initialize_checkpointable()
+    if isinstance(initial_value, trackable.CheckpointInitialValue):
+      self._maybe_initialize_trackable()
      self._update_uid = initial_value.checkpoint_position.restore_uid
      initial_value = initial_value.wrapped_value

@@ -137,8 +137,8 @@ from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Vari
 from tensorflow.python.ops.variable_scope import EagerVariableStore
 from tensorflow.python.ops import script_ops
 from tensorflow.python.ops import template
-from tensorflow.python.training.checkpointable.tracking import AutoCheckpointable as Checkpointable
-from tensorflow.python.training.checkpointable.util import Checkpoint
+from tensorflow.python.training.tracking.tracking import AutoTrackable as Checkpointable
+from tensorflow.python.training.tracking.util import Checkpoint
 from tensorflow.python.util.all_util import remove_undocumented

 py_func = script_ops.eager_py_func
@@ -36,7 +36,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.util import deprecation


-# TODO(rohanj): This should subclass Checkpointable and implement
+# TODO(rohanj): This should subclass Trackable and implement
 # _gather_saveables_for_checkpoint.
 class ShardedMutableDenseHashTable(object):
   """A sharded version of MutableDenseHashTable.
@@ -1483,3 +1483,4 @@ class IdTableWithHashBucketsTest(test.TestCase):

 if __name__ == "__main__":
   test.main()
+
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training import saver as core_saver
|
||||
from tensorflow.python.training import training_util
|
||||
from tensorflow.python.training.checkpointable import graph_view
|
||||
from tensorflow.python.training.checkpointable import tracking
|
||||
from tensorflow.python.training.checkpointable import util
|
||||
from tensorflow.python.training.tracking import graph_view
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.training.tracking import util
|
||||
|
||||
|
||||
class NonLayerCheckpointable(tracking.AutoCheckpointable):
|
||||
class NonLayerTrackable(tracking.AutoTrackable):
|
||||
|
||||
def __init__(self):
|
||||
super(NonLayerCheckpointable, self).__init__()
|
||||
super(NonLayerTrackable, self).__init__()
|
||||
self.a_variable = util.add_variable(
|
||||
self, name="a_variable", shape=[])
|
||||
|
||||
@ -65,8 +65,8 @@ class MyModel(training.Model):
|
||||
super(MyModel, self).__init__()
|
||||
self._named_dense = core.Dense(1, use_bias=True)
|
||||
self._second = core.Dense(1, use_bias=False)
|
||||
# We can still track Checkpointables which aren't Layers.
|
||||
self._non_layer = NonLayerCheckpointable()
|
||||
# We can still track Trackables which aren't Layers.
|
||||
self._non_layer = NonLayerTrackable()
|
||||
|
||||
def call(self, values):
|
||||
ret = self._second(self._named_dense(values))
|
||||
@ -101,7 +101,7 @@ class CheckpointingTests(test.TestCase):
|
||||
other_model = MyModel()
|
||||
optimizer = adam.AdamOptimizer(0.001)
|
||||
optimizer_step = training_util.get_or_create_global_step()
|
||||
root_checkpointable = util.Checkpoint(
|
||||
root_trackable = util.Checkpoint(
|
||||
optimizer=optimizer, model=model, optimizer_step=optimizer_step)
|
||||
if context.executing_eagerly():
|
||||
optimizer.minimize(
|
||||
@ -117,10 +117,10 @@ class CheckpointingTests(test.TestCase):
|
||||
other_model(input_value),
|
||||
global_step=optimizer_step)
|
||||
self.evaluate(util.gather_initializers(
|
||||
root_checkpointable))
|
||||
root_trackable))
|
||||
self.evaluate(train_op)
|
||||
named_variables, serialized_graph, _ = graph_view.ObjectGraphView(
|
||||
root_checkpointable).serialize_object_graph()
|
||||
root_trackable).serialize_object_graph()
|
||||
expected_checkpoint_names = (
|
||||
# Created in the root node, so no prefix.
|
||||
"optimizer_step",
|
||||
@ -208,7 +208,7 @@ class CheckpointingTests(test.TestCase):
|
||||
def testSaveRestore(self):
|
||||
model = MyModel()
|
||||
optimizer = adam.AdamOptimizer(0.001)
|
||||
root_checkpointable = util.Checkpoint(
|
||||
root_trackable = util.Checkpoint(
|
||||
optimizer=optimizer, model=model)
|
||||
input_value = constant_op.constant([[3.]])
|
||||
if context.executing_eagerly():
|
||||
@ -217,24 +217,24 @@ class CheckpointingTests(test.TestCase):
|
||||
else:
|
||||
train_op = optimizer.minimize(model(input_value))
|
||||
# TODO(allenl): Make initialization more pleasant when graph building.
|
||||
root_checkpointable.save_counter # pylint: disable=pointless-statement
|
||||
root_trackable.save_counter # pylint: disable=pointless-statement
|
||||
self.evaluate(util.gather_initializers(
|
||||
root_checkpointable))
|
||||
root_trackable))
|
||||
self.evaluate(train_op)
|
||||
prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.]))
|
||||
m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m")
|
||||
self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
|
||||
save_path = root_checkpointable.save(file_prefix=prefix)
|
||||
save_path = root_trackable.save(file_prefix=prefix)
|
||||
self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.]))
|
||||
self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3))
|
||||
self.evaluate(state_ops.assign(root_trackable.save_counter, 3))
|
||||
optimizer_variables = self.evaluate(optimizer.variables())
|
||||
self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
|
||||
# Immediate restoration
|
||||
status = root_checkpointable.restore(save_path=save_path).assert_consumed()
|
||||
status = root_trackable.restore(save_path=save_path).assert_consumed()
|
||||
status.run_restore_ops()
|
||||
self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1]))
|
||||
self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter))
|
||||
self.assertAllEqual(1, self.evaluate(root_trackable.save_counter))
|
||||
self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
|
||||
if not context.executing_eagerly():
|
||||
return # Restore-on-create is only supported when executing eagerly
|
||||
@ -542,11 +542,11 @@ class CheckpointingTests(test.TestCase):
|
||||
first_session = session_lib.Session(graph=first_graph)
|
||||
with first_graph.as_default(), first_session.as_default():
|
||||
first_variable = resource_variable_ops.ResourceVariable([1.])
|
||||
first_root_checkpointable = util.Checkpoint(
|
||||
first_root_trackable = util.Checkpoint(
|
||||
optimizer=optimizer, variable=first_variable)
|
||||
train_op = optimizer.minimize(first_variable.read_value)
|
||||
self.evaluate(util.gather_initializers(
|
||||
first_root_checkpointable))
|
||||
first_root_trackable))
|
||||
self.evaluate(train_op)
|
||||
self.evaluate(first_variable.assign([1.]))
|
||||
self.evaluate(optimizer.get_slot(
|
||||
@ -558,23 +558,23 @@ class CheckpointingTests(test.TestCase):
|
||||
second_graph = ops.Graph()
|
||||
with second_graph.as_default(), session_lib.Session(graph=second_graph):
|
||||
second_variable = resource_variable_ops.ResourceVariable([1.])
|
||||
second_root_checkpointable = util.Checkpoint(
|
||||
second_root_trackable = util.Checkpoint(
|
||||
optimizer=optimizer, variable=second_variable)
|
||||
train_op = optimizer.minimize(second_variable.read_value)
|
||||
second_root_checkpointable.restore(None).initialize_or_restore()
|
||||
second_root_trackable.restore(None).initialize_or_restore()
|
||||
self.evaluate(train_op)
|
||||
self.evaluate(second_variable.assign([4.]))
|
||||
self.evaluate(optimizer.get_slot(
|
||||
var=second_variable, name="m").assign([5.]))
|
||||
beta_1_power, _ = optimizer._get_beta_accumulators()
|
||||
self.evaluate(beta_1_power.assign(6.))
|
||||
save_path = second_root_checkpointable.save(checkpoint_prefix)
|
||||
save_path = second_root_trackable.save(checkpoint_prefix)
|
||||
self.evaluate(second_variable.assign([7.]))
|
||||
self.evaluate(optimizer.get_slot(
|
||||
var=second_variable, name="m").assign([8.]))
|
||||
beta_1_power, _ = optimizer._get_beta_accumulators()
|
||||
self.assertAllEqual(6., self.evaluate(beta_1_power))
|
||||
status = second_root_checkpointable.restore(save_path)
|
||||
status = second_root_trackable.restore(save_path)
|
||||
status.assert_consumed().run_restore_ops()
|
||||
self.assertAllEqual([4.], self.evaluate(second_variable))
|
||||
self.assertAllEqual([5.], self.evaluate(optimizer.get_slot(
|
||||
@ -594,7 +594,7 @@ class CheckpointingTests(test.TestCase):
|
||||
class TemplateTests(test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_checkpointable_save_restore(self):
|
||||
def test_trackable_save_restore(self):
|
||||
|
||||
def _templated():
|
||||
v = variable_scope.get_variable(
|
||||
@@ -641,13 +641,13 @@ class CheckpointCompatibilityTests(test.TestCase):
     model = MyModel()
     optimizer = adam.AdamOptimizer(0.001)
     optimizer_step = training_util.get_or_create_global_step()
-    root_checkpointable = util.Checkpoint(
+    root_trackable = util.Checkpoint(
        optimizer=optimizer, model=model, optimizer_step=optimizer_step)
     train_op = optimizer.minimize(
        functools.partial(model, input_value),
        global_step=optimizer_step)
     self.evaluate(util.gather_initializers(
-        root_checkpointable))
+        root_trackable))
     self.evaluate(train_op)
     # A regular variable, a slot variable, and a non-slot Optimizer variable
     # with known values to check when loading.
@@ -656,24 +656,24 @@ class CheckpointCompatibilityTests(test.TestCase):
        var=model._named_dense.bias, name="m").assign([2.]))
     beta_1_power, _ = optimizer._get_beta_accumulators()
     self.evaluate(beta_1_power.assign(3.))
-    return root_checkpointable
+    return root_trackable

-  def _set_sentinels(self, root_checkpointable):
-    self.evaluate(root_checkpointable.model._named_dense.bias.assign([101.]))
+  def _set_sentinels(self, root_trackable):
+    self.evaluate(root_trackable.model._named_dense.bias.assign([101.]))
     self.evaluate(
-        root_checkpointable.optimizer.get_slot(
-            var=root_checkpointable.model._named_dense.bias, name="m")
+        root_trackable.optimizer.get_slot(
+            var=root_trackable.model._named_dense.bias, name="m")
        .assign([102.]))
-    beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+    beta_1_power, _ = root_trackable.optimizer._get_beta_accumulators()
     self.evaluate(beta_1_power.assign(103.))

-  def _check_sentinels(self, root_checkpointable):
+  def _check_sentinels(self, root_trackable):
     self.assertAllEqual(
-        [1.], self.evaluate(root_checkpointable.model._named_dense.bias))
+        [1.], self.evaluate(root_trackable.model._named_dense.bias))
     self.assertAllEqual([2.], self.evaluate(
-        root_checkpointable.optimizer.get_slot(
-            var=root_checkpointable.model._named_dense.bias, name="m")))
-    beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+        root_trackable.optimizer.get_slot(
+            var=root_trackable.model._named_dense.bias, name="m")))
+    beta_1_power, _ = root_trackable.optimizer._get_beta_accumulators()
     self.assertAllEqual(3., self.evaluate(beta_1_power))

   def _write_name_based_checkpoint(self):
@@ -698,7 +698,7 @@ class CheckpointCompatibilityTests(test.TestCase):
     self._set_sentinels(root)
     with self.assertRaises(AssertionError):
       self._check_sentinels(root)
-    object_saver = util.CheckpointableSaver(graph_view.ObjectGraphView(root))
+    object_saver = util.TrackableSaver(graph_view.ObjectGraphView(root))
     self._set_sentinels(root)
     status = object_saver.restore(save_path)
     if context.executing_eagerly():
@ -38,7 +38,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import optimizer as optimizer_v1
from tensorflow.python.training import slot_creator
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest

@ -223,7 +223,7 @@ class _OptimizerV2State(object):
}
self._slots = {}
self._non_slot_dict = {}
# Extra state to help Optimizers implement Checkpointable. Holds information
# Extra state to help Optimizers implement Trackable. Holds information
# about variables which will be restored as soon as they're created.
self._deferred_dependencies = {} # Non-slot variables
self._deferred_slot_restorations = {} # Slot variables
@ -366,8 +366,8 @@ class _OptimizerV2State(object):
slot variable needs to be restored).

Args:
slot_variable_position: A `checkpointable._CheckpointPosition` object
indicating the slot variable `Checkpointable` object to be restored.
slot_variable_position: A `trackable._CheckpointPosition` object
indicating the slot variable `Trackable` object to be restored.
slot_name: The name of this `Optimizer`'s slot to restore into.
variable: The variable object this slot is being created for.
optional_op_name: Name to use when scoping the Variable that needs to be
@ -385,7 +385,7 @@ class _OptimizerV2State(object):
# (aside from double initialization), and makes variable creator scopes
# behave the same way they do when graph building.
and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access
initializer = checkpointable.CheckpointInitialValue(
initializer = trackable.CheckpointInitialValue(
checkpoint_position=slot_variable_position)
slot_variable = self.create_slot(
var=variable,
@ -1259,10 +1259,10 @@ class OptimizerV2(optimizer_v1.Optimizer):
return self._per_graph_state.get(var._graph_key, None)

# --------------
# Overridden methods from Checkpointable.
# Overridden methods from Trackable.
# --------------

def _track_checkpointable(self, *args, **kwargs):
def _track_trackable(self, *args, **kwargs):
"""Optimizers may not track dependencies. Raises an error."""
raise NotImplementedError(
"Optimizers may not have dependencies. File a feature request if this "
@ -1270,7 +1270,7 @@ class OptimizerV2(optimizer_v1.Optimizer):

@property
def _checkpoint_dependencies(self):
"""From Checkpointable. Gather graph-specific non-slot variables to save."""
"""From Trackable. Gather graph-specific non-slot variables to save."""
current_graph_non_slot_variables = []
state = self._get_per_graph_state()
if state is not None:
@ -1279,14 +1279,14 @@ class OptimizerV2(optimizer_v1.Optimizer):
# Avoid comparing variables
key=lambda item: item[0]):
current_graph_non_slot_variables.append(
checkpointable.CheckpointableReference(
trackable.TrackableReference(
name=name, ref=variable_object))
# Note: ignores super(); Optimizers may not have any dependencies outside of
# state objects.
return current_graph_non_slot_variables

def _lookup_dependency(self, name):
"""From Checkpointable. Find a non-slot variable in the current graph."""
"""From Trackable. Find a non-slot variable in the current graph."""
state = self._get_per_graph_state()
if state is None:
return None
@ -1295,10 +1295,10 @@ class OptimizerV2(optimizer_v1.Optimizer):

@property
def _deferred_dependencies(self):
"""Lets Checkpointable know where non-slot variables are created.
"""Lets Trackable know where non-slot variables are created.

If necessary, creates a new state object for the current default graph.
Checkpointable will then add entries to that state's deferred dependency
Trackable will then add entries to that state's deferred dependency
dictionary. The state object will check that dictionary when creating
non-slot variables, restoring their value if an entry is found.

@ -1311,14 +1311,14 @@ class OptimizerV2(optimizer_v1.Optimizer):

def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
variable):
"""Checkpointable: Restore a slot variable's value, possibly creating it.
"""Trackable: Restore a slot variable's value, possibly creating it.

Called when a variable which has an associated slot variable is created or
restored.

Args:
slot_variable_position: A `checkpointable._CheckpointPosition` object
indicating the slot variable `Checkpointable` object to be restored.
slot_variable_position: A `trackable._CheckpointPosition` object
indicating the slot variable `Trackable` object to be restored.
slot_name: The name of this `Optimizer`'s slot to restore into.
variable: The variable object this slot is being created for.
"""
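Editor's note: the docstrings above describe deferred restoration — if a checkpoint is restored before the optimizer has created its slot variables, the restore is queued and applied when the slots appear. A hedged sketch of that behavior under TF 1.13 eager mode, continuing from a checkpoint written under /tmp/ckpt as in the previous sketch:

```python
import tensorflow as tf

tf.enable_eager_execution()

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model(tf.ones([1, 1]))                    # create the model variables now
optimizer = tf.train.AdamOptimizer(0.001)
step = tf.train.get_or_create_global_step()
root = tf.train.Checkpoint(optimizer=optimizer, model=model, optimizer_step=step)
status = root.restore(tf.train.latest_checkpoint("/tmp/ckpt"))

# No "m"/"v" slots exist yet, so their restore is queued. It happens here,
# when apply_gradients() first creates them:
with tf.GradientTape() as tape:
  loss = tf.reduce_sum(model(tf.ones([1, 1])) ** 2)
grads = tape.gradient(loss, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights), global_step=step)
status.assert_consumed()                  # slots included, everything matched
```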
@ -35,7 +35,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import resources
from tensorflow.python.platform import resource_loader
from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking


_model_ops = loader.load_op_library(
@ -32,7 +32,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import resources
from tensorflow.python.platform import resource_loader
from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking


_stats_ops = loader.load_op_library(
@ -228,7 +228,7 @@ CORE_PROTO_SRCS = COMMON_PROTO_SRCS + ERROR_CODES_PROTO_SRCS
# ones with individual proto_library targets.
ADDITIONAL_CORE_PROTO_SRCS = [
"example/example_parser_configuration.proto",
"protobuf/checkpointable_object_graph.proto",
"protobuf/trackable_object_graph.proto",
"protobuf/control_flow.proto",
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# "protobuf/critical_section.proto",
@ -8,10 +8,10 @@ package tensorflow;
// own variables, allowing for more robust checkpoint loading into modified
// programs.

message CheckpointableObjectGraph {
message CheckpointableObject {
message TrackableObjectGraph {
message TrackableObject {
message ObjectReference {
// An index into `CheckpointableObjectGraph.nodes`, indicating the object
// An index into `TrackableObjectGraph.nodes`, indicating the object
// being referenced.
int32 node_id = 1;
// A user-provided name for the edge.
@ -37,12 +37,12 @@ message CheckpointableObjectGraph {
}

message SlotVariableReference {
// An index into `CheckpointableObjectGraph.nodes`, indicating the
// An index into `TrackableObjectGraph.nodes`, indicating the
// variable object this slot was created for.
int32 original_variable_node_id = 1;
// The name of the slot (e.g. "m"/"v").
string slot_name = 2;
// An index into `CheckpointableObjectGraph.nodes`, indicating the
// An index into `TrackableObjectGraph.nodes`, indicating the
// `Object` with the value of the slot variable.
int32 slot_variable_node_id = 3;
}
@ -55,5 +55,5 @@ message CheckpointableObjectGraph {
repeated SlotVariableReference slot_variables = 3;
}

repeated CheckpointableObject nodes = 1;
repeated TrackableObject nodes = 1;
}
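Editor's note: for reference, a hedged sketch of populating the renamed message from Python, using only fields visible in this hunk. The generated-module path follows the usual *_pb2 convention for the BUILD entry above and is an assumption, as is treating node 0 as the optimizer that owns the slot.

```python
from tensorflow.core.protobuf import trackable_object_graph_pb2 as pb

graph = pb.TrackableObjectGraph()
optimizer_node = graph.nodes.add()       # node 0: e.g. an optimizer object
graph.nodes.add()                        # node 1: the trained variable
graph.nodes.add()                        # node 2: its "m" slot variable

slot_ref = optimizer_node.slot_variables.add()
slot_ref.original_variable_node_id = 1   # index into TrackableObjectGraph.nodes
slot_ref.slot_name = "m"
slot_ref.slot_variable_node_id = 2
```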
@ -33,7 +33,7 @@ def main(argv):
del argv

root = tf.train.Checkpoint()
# Create a cell and attach to our checkpointable.
# Create a cell and attach to our trackable.
root.rnn_cell = tf.keras.layers.LSTMCell(units=10, recurrent_initializer=None)

# Wrap the rnn_cell.__call__ function and assign to next_state.
@ -27,7 +27,7 @@ import tensorflow as tf

# TODO(vbardiovsky): remove these when symbols are public.
from tensorflow.python.ops import lookup_ops
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking

FLAGS = flags.FLAGS
@ -1815,7 +1815,7 @@ tf_gen_op_wrapper_private_py(
visibility = [
"//learning/brain/python/ops:__pkg__",
"//tensorflow/python/kernel_tests:__pkg__",
"//tensorflow/python/training/checkpointable:__pkg__",
"//tensorflow/python/training/tracking:__pkg__",
],
)

@ -3371,7 +3371,7 @@ py_library(
":util",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/tracking:base",
],
)

@ -3730,7 +3730,7 @@ py_library(
["training/**/*.py"],
exclude = [
"**/*test*",
"training/checkpointable/**/*.py",
"training/tracking/**/*.py",
"training/saving/**/*.py",
# The following targets have their own build rules (same name as the
# file):
@ -3791,8 +3791,8 @@ py_library(
"//tensorflow/python/eager:context",
"//tensorflow/python/keras/optimizer_v2:learning_rate_schedule",
"//tensorflow/python/ops/losses",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/checkpointable:util",
"//tensorflow/python/training/tracking:base",
"//tensorflow/python/training/tracking:util",
"//third_party/py/numpy",
"@six_archive//:six",
],
@ -3872,9 +3872,9 @@ py_library(
":variables",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/saving:saveable_object",
"//tensorflow/python/training/saving:saveable_object_util",
"//tensorflow/python/training/tracking:base",
"//third_party/py/numpy",
"@six_archive//:six",
],
@ -291,7 +291,7 @@ tf_py_test(
":test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:util",
"//tensorflow/python/training/tracking:util",
"//tensorflow/python:checkpoint_management",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
@ -344,7 +344,7 @@ cuda_py_test(
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/util:structure",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:util",
"//tensorflow/python/training/tracking:util",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@ -28,7 +28,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.training.tracking import util as trackable_utils


@test_util.run_all_in_graph_and_eager_modes
@ -43,7 +43,7 @@ class IteratorCheckpointingTest(test_base.DatasetTestBase):
) else dataset_ops.make_one_shot_iterator(dataset)
get_next = iterator.get_next if context.executing_eagerly(
) else functools.partial(self.evaluate, iterator.get_next())
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
checkpoint = trackable_utils.Checkpoint(iterator=iterator)
self.assertAllEqual([1, 4], get_next())
save_path = checkpoint.save(checkpoint_prefix)
self.assertAllEqual([9, 16], get_next())
@ -73,7 +73,7 @@ class IteratorCheckpointingTest(test_base.DatasetTestBase):
) else dataset_ops.make_one_shot_iterator(dataset_2)
get_next_3 = iterator_3.get_next if context.executing_eagerly(
) else functools.partial(self.evaluate, iterator_3.get_next())
checkpoint = checkpointable_utils.Checkpoint(
checkpoint = trackable_utils.Checkpoint(
iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
self.assertAllEqual([1, 4], get_next_1())
self.assertAllEqual(0, get_next_3())
@ -96,7 +96,7 @@ class IteratorCheckpointingTest(test_base.DatasetTestBase):
) else dataset_ops.make_one_shot_iterator(dataset)
get_next = iterator.get_next if context.executing_eagerly(
) else functools.partial(self.evaluate, iterator.get_next())
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
checkpoint = trackable_utils.Checkpoint(iterator=iterator)
self.assertAllEqual(0, get_next())
self.assertAllEqual(1, get_next())
save_path = checkpoint.save(checkpoint_prefix)
@ -115,7 +115,7 @@ class IteratorCheckpointingTest(test_base.DatasetTestBase):
iterator = iter(dataset) if context.executing_eagerly(
) else dataset_ops.make_initializable_iterator(dataset)
get_next = iterator.get_next
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
checkpoint = trackable_utils.Checkpoint(iterator=iterator)
for i in range(5):
checkpoint.restore(
checkpoint_management.latest_checkpoint(
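Editor's note: a minimal sketch of the iterator-checkpointing pattern these tests cover, assuming TF 1.13 eager execution (graph mode would additionally run the restore ops, as the tests' evaluate calls do); dataset contents and paths are illustrative.

```python
import tensorflow as tf

tf.enable_eager_execution()

dataset = tf.data.Dataset.range(10)
iterator = dataset.make_one_shot_iterator()
checkpoint = tf.train.Checkpoint(iterator=iterator)

print(iterator.get_next().numpy())            # 0
save_path = checkpoint.save("/tmp/iterator/ckpt")
print(iterator.get_next().numpy())            # 1
checkpoint.restore(save_path)                 # rewinds the iterator state
print(iterator.get_next().numpy())            # 1 again
```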
@ -74,7 +74,7 @@ py_library(
"//tensorflow/python/data/util:sparse",
"//tensorflow/python/data/util:structure",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/tracking:base",
],
)
@ -31,8 +31,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.saver import BaseSaverBuilder
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util.tf_export import tf_export


@ -68,7 +68,7 @@ def _device_stack_is_empty():


@tf_export(v1=["data.Iterator"])
class Iterator(checkpointable.Checkpointable):
class Iterator(trackable.Trackable):
"""Represents the state of iterating through a `Dataset`."""

def __init__(self, iterator_resource, initializer, output_types,
@ -491,7 +491,7 @@ def _generate_shared_name(prefix):
return "{}{}".format(prefix, uid)


class EagerIterator(checkpointable.Checkpointable):
class EagerIterator(trackable.Trackable):
"""An iterator producing tf.Tensor objects from a tf.data.Dataset."""

def __init__(self, dataset):
@ -641,7 +641,7 @@ class EagerIterator(checkpointable.Checkpointable):
return {"ITERATOR": _saveable_factory}


# TODO(b/71645805): Expose checkpointable stateful objects from dataset
# TODO(b/71645805): Expose trackable stateful objects from dataset
# attributes(potential).
class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
"""SaveableObject for saving/restoring iterator state."""
@ -460,7 +460,7 @@ py_library(
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/tracking:base",
"@six_archive//:six",
],
)
@ -37,7 +37,7 @@ from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest


@ -630,7 +630,7 @@ class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):


class MirroredVariable(DistributedVariable, Mirrored,
checkpointable.Checkpointable):
trackable.Trackable):
"""Holds a map from device to variables whose values are kept in sync."""

def __init__(
@ -710,7 +710,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
return self.get()._as_graph_element()

def _gather_saveables_for_checkpoint(self):
"""Overrides CheckpointableBase method.
"""Overrides Trackable method.

This allows both name-based and object-based save and restore of
MirroredVariables.
@ -720,7 +720,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
"""
def _saveable_factory(name=self._common_name):
return _MirroredSaveable(self, self.primary, name)
return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}


# Register a conversion function which reads the value of the variable,
@ -752,7 +752,7 @@ def _enclosing_tpu_context():
# tpu.replicate() because it assumes that you're in a device context where you
# can operate on a single version of the variable, but a tpu.replicate()
# operates on all variables and is replicated during a rewrite pass.
class TPUMirroredVariable(checkpointable.Checkpointable):
class TPUMirroredVariable(trackable.Trackable):
"""Holds a map from device to TPU variables whose values are kept in sync."""

def __init__(
@ -1085,7 +1085,7 @@ class TPUMirroredVariable(checkpointable.Checkpointable):
return self._read_variable_op()

def _gather_saveables_for_checkpoint(self):
"""Overrides CheckpointableBase method.
"""Overrides Trackable method.

This allows both name-based and object-based save and restore of
MirroredVariables.
@ -1095,7 +1095,7 @@ class TPUMirroredVariable(checkpointable.Checkpointable):
"""
def _saveable_factory(name=self._common_name):
return _MirroredSaveable(self, self.primary, name)
return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

def _should_act_as_resource_variable(self):
"""Pass resource_variable_ops.is_resource_variable check."""
@ -1205,7 +1205,7 @@ def _assert_replica_context(strategy):


class ReplicaLocalVariable(DistributedVariable, PerReplica,
checkpointable.Checkpointable):
trackable.Trackable):
"""Holds a map from device to variables whose values are reduced on save."""

def __init__(
@ -1256,7 +1256,7 @@ class ReplicaLocalVariable(DistributedVariable, PerReplica,
return self.get()._as_graph_element()

def _gather_saveables_for_checkpoint(self):
"""Overrides CheckpointableBase method.
"""Overrides Trackable method.

This allows both name-based and object-based save and restore of
ReplicaLocalVariables.
@ -1266,7 +1266,7 @@ class ReplicaLocalVariable(DistributedVariable, PerReplica,
"""
def _saveable_factory(name=self._common_name):
return _ReplicaLocalSaveable(self, name)
return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}


# Register a conversion function for ReplicaLocalVariable which allows as_ref to
@ -1436,7 +1436,7 @@ def value_container(val):


# TODO(josh11b): Descend from Variable.
class AggregatingVariable(checkpointable.Checkpointable):
class AggregatingVariable(trackable.Trackable):
"""A wrapper around a variable that aggregates updates across replicas."""

def __init__(self, strategy, v, aggregation):
@ -1514,7 +1514,7 @@ class AggregatingVariable(checkpointable.Checkpointable):

# TODO(josh11b): Test saving & restoring.
def _gather_saveables_for_checkpoint(self):
return {checkpointable.VARIABLE_VALUE_KEY: self._v}
return {trackable.VARIABLE_VALUE_KEY: self._v}

# pylint: disable=multiple-statements
def __add__(self, o): return self._v + o
@ -500,7 +500,7 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:while_v2", # TODO(b/118513001): Imported via control_flow_ops; remove.
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/tracking:base",
],
)

@ -555,7 +555,7 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:template",
"//tensorflow/python:variable_scope",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/tracking:base",
],
)
@ -31,7 +31,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export
@ -113,8 +113,8 @@ class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable):
if constraint is not None and not callable(constraint):
raise ValueError("The `constraint` argument must be a callable.")

if isinstance(initial_value, checkpointable.CheckpointInitialValue):
self._maybe_initialize_checkpointable()
if isinstance(initial_value, trackable.CheckpointInitialValue):
self._maybe_initialize_trackable()
self._update_uid = initial_value.checkpoint_position.restore_uid
initial_value = initial_value.wrapped_value
@ -62,7 +62,7 @@ class VariableHolder(object):
return self._fn(*args, **kwargs)


# TODO(allenl): make this checkpointable
# TODO(allenl): make this trackable
class WrappedFunction(function.ConcreteFunction):
"""Wraps a tf V1 piece of code in a function."""
@ -141,11 +141,11 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine.base_layer import Layer
# TODO(b/118385027): Dependency on keras can be problematic if Keras moves out
# of the main repo.
from tensorflow.python.keras import utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@ -162,7 +162,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
@ -3228,7 +3228,7 @@ def _raise_shared_embedding_column_error():
'`DenseFeatures` or `LinearModel` instead.')


class SharedEmbeddingColumnCreator(tracking.AutoCheckpointable):
class SharedEmbeddingColumnCreator(tracking.AutoTrackable):

def __init__(self,
dimension,
@ -170,7 +170,7 @@ py_library(
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/distribute:reduce_util",
"//tensorflow/python/training/checkpointable:data_structures",
"//tensorflow/python/training/tracking:data_structures",
"//tensorflow/tools/docs:doc_controls",
"@six_archive//:six",
],
@ -46,8 +46,8 @@ from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list # p
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import layer_utils as checkpointable_layer_utils
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
@ -57,7 +57,7 @@ from tensorflow.tools.docs import doc_controls


@keras_export('keras.layers.Layer')
class Layer(checkpointable.Checkpointable):
class Layer(trackable.Trackable):
"""Base layer class.

This is the class from which all layers inherit.
@ -110,7 +110,7 @@ class Layer(checkpointable.Checkpointable):
constraints on inputs that can be accepted by the layer.
"""

@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def __init__(self, trainable=True, name=None, dtype=None, dynamic=False,
**kwargs):
# These properties should be set by the user via keyword arguments.
@ -272,7 +272,7 @@ class Layer(checkpointable.Checkpointable):
marked as non-trainable. `trainable` defaults to `True` unless
`synchronization` is set to `ON_READ`.
constraint: constraint instance (callable).
partitioner: Partitioner to be passed to the `Checkpointable` API.
partitioner: Partitioner to be passed to the `Trackable` API.
use_resource: Whether to use `ResourceVariable`.
synchronization: Indicates when a distributed a variable will be
aggregated. Accepted values are constants defined in the class
@ -345,9 +345,9 @@ class Layer(checkpointable.Checkpointable):
name=name,
shape=shape,
# TODO(allenl): a `make_variable` equivalent should be added as a
# `Checkpointable` method.
# `Trackable` method.
getter=getter or base_layer_utils.make_variable,
# Manage errors in Layer rather than Checkpointable.
# Manage errors in Layer rather than Trackable.
overwrite=True,
initializer=initializer,
dtype=dtype,
@ -1629,7 +1629,7 @@ class Layer(checkpointable.Checkpointable):

# Append value to self._layers if relevant
if (isinstance(value, Layer) or
checkpointable_layer_utils.has_weights(value)):
trackable_layer_utils.has_weights(value)):
# Initialize `_layers` here in case `__init__` has not yet been called.
if not hasattr(self, '_layers'):
self._layers = []
@ -1666,7 +1666,7 @@ class Layer(checkpointable.Checkpointable):
return []

# This is a hack so that the is_layer (within
# training/checkpointable/layer_utils.py) check doesn't get the weights attr.
# training/trackable/layer_utils.py) check doesn't get the weights attr.
# TODO(b/110718070): Remove when fixed.
def _is_layer(self):
return True
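Editor's note: the `__setattr__` hook above is what makes plain attribute assignment sufficient for tracking sub-layers and their weights. A sketch under TF 1.13 eager mode, with illustrative layer sizes:

```python
import tensorflow as tf

tf.enable_eager_execution()

class TwoDense(tf.keras.Model):

  def __init__(self):
    super(TwoDense, self).__init__()
    self.d1 = tf.keras.layers.Dense(4)   # picked up by __setattr__
    self.d2 = tf.keras.layers.Dense(1)

  def call(self, inputs):
    return self.d2(self.d1(inputs))

model = TwoDense()
model(tf.ones([1, 8]))                   # builds the weights
assert len(model.layers) == 2            # both sub-layers were tracked
assert len(model.weights) == 4           # kernel + bias for each Dense
```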
@ -76,7 +76,7 @@ def make_variable(name,
that has fewer constraints (`variable_scope.variable()`).

In the longer term, it seems like a similar "default variable creator" method
should exist in `CheckpointableBase` instead. When this happens, we can get
should exist in `Trackable` instead. When this happens, we can get
rid of this temporary solution.

TODO(fchollet): remove this method when no longer needed.
@ -44,10 +44,10 @@ from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.checkpointable import layer_utils as checkpointable_layer_utils
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect

@ -152,7 +152,7 @@ class Network(base_layer.Layer):
# empty lists shouldn't cause issues; adding or removing them will not break
# checkpoints, but may cause "all Python objects matched" assertions to fail
# (in which case less strict assertions may be substituted if necessary).
@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def _base_init(self, name=None):
# The following are implemented as property functions:
# self.trainable_weights
@ -206,10 +206,10 @@ class Network(base_layer.Layer):
self._outbound_nodes = []
self._inbound_nodes = []

self._checkpointable_saver = (
checkpointable_utils.saver_with_op_caching(self))
self._trackable_saver = (
trackable_utils.saver_with_op_caching(self))

@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def _init_graph_network(self, inputs, outputs, name=None):
self._call_convention = (base_layer_utils
.CallConvention.EXPLICIT_INPUTS_ARGUMENT)
@ -309,7 +309,7 @@ class Network(base_layer.Layer):
for layer in self._output_layers:
self.output_names.append(layer.name)

@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def _init_subclassed_network(self, name=None, dynamic=False):
self._base_init(name=name)
self._is_graph_network = False
@ -370,20 +370,20 @@ class Network(base_layer.Layer):
return base_layer_utils.CallConvention.POSITIONAL_ARGUMENTS_ARE_INPUTS

def _track_layers(self, layers):
"""Add Checkpointable dependencies on a list of Layers."""
"""Add Trackable dependencies on a list of Layers."""
weight_layer_index = 0
for layer_index, layer in enumerate(layers):
if layer.weights:
# Keep a separate index for layers which have weights. This allows users
# to insert Layers without weights anywhere in the network without
# breaking checkpoints.
self._track_checkpointable(
self._track_trackable(
layer, name='layer_with_weights-%d' % weight_layer_index,
overwrite=True)
weight_layer_index += 1
# Even if it doesn't have weights, we should still track everything in
# case it has/will have Checkpointable dependencies.
self._track_checkpointable(
# case it has/will have Trackable dependencies.
self._track_trackable(
layer, name='layer-%d' % layer_index, overwrite=True)

def __setattr__(self, name, value):
@ -393,18 +393,18 @@ class Network(base_layer.Layer):

if all(
isinstance(v, (base_layer.Layer,
data_structures.CheckpointableDataStructure)) or
checkpointable_layer_utils.has_weights(v) for v in nest.flatten(value)):
data_structures.TrackableDataStructure)) or
trackable_layer_utils.has_weights(v) for v in nest.flatten(value)):
try:
self._is_graph_network
except AttributeError:
raise RuntimeError('It looks like you are subclassing `Model` and you '
'forgot to call `super(YourClass, self).__init__()`.'
' Always start with this line.')
# Keep track of checkpointable objects,
# Keep track of trackable objects,
# for the needs of `self.save/save_weights`.
value = data_structures.sticky_attribute_assignment(
checkpointable=self, value=value, name=name)
trackable=self, value=value, name=name)
super(Network, self).__setattr__(name, value)

# Keep track of metric instance created in subclassed model/layer.
@ -481,7 +481,7 @@ class Network(base_layer.Layer):

@property
def layers(self):
return checkpointable_layer_utils.filter_empty_layer_containers(
return trackable_layer_utils.filter_empty_layer_containers(
self._layers)

def get_layer(self, name=None, index=None):
@ -542,7 +542,7 @@ class Network(base_layer.Layer):
losses += layer.losses
return losses

@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def _clear_losses(self):
"""Used every step in eager to reset losses."""
self._eager_losses = []
@ -682,14 +682,14 @@ class Network(base_layer.Layer):

@property
def trainable_weights(self):
return checkpointable_layer_utils.gather_trainable_weights(
return trackable_layer_utils.gather_trainable_weights(
trainable=self.trainable,
sub_layers=self._layers,
extra_variables=self._trainable_weights)

@property
def non_trainable_weights(self):
return checkpointable_layer_utils.gather_non_trainable_weights(
return trackable_layer_utils.gather_non_trainable_weights(
trainable=self.trainable,
sub_layers=self._layers,
extra_variables=self._non_trainable_weights + self._trainable_weights)
@ -1397,7 +1397,7 @@ class Network(base_layer.Layer):
session = backend.get_session()
optimizer = getattr(self, 'optimizer', None)
if (optimizer
and not isinstance(optimizer, checkpointable.Checkpointable)):
and not isinstance(optimizer, trackable.Trackable)):
logging.warning(
('This model was compiled with a Keras optimizer (%s) but is being '
'saved in TensorFlow format with `save_weights`. The model\'s '
@ -1405,7 +1405,7 @@ class Network(base_layer.Layer):
'the TensorFlow format the optimizer\'s state will not be '
'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.')
% (optimizer,))
self._checkpointable_saver.save(filepath, session=session)
self._trackable_saver.save(filepath, session=session)
# Record this checkpoint so it's visible from tf.train.latest_checkpoint.
checkpoint_management.update_checkpoint_state_internal(
save_dir=os.path.dirname(filepath),
@ -1464,7 +1464,7 @@ class Network(base_layer.Layer):
# The checkpoint is not readable in TensorFlow format. Try HDF5.
save_format = 'h5'
if save_format == 'tf':
status = self._checkpointable_saver.restore(filepath)
status = self._trackable_saver.restore(filepath)
if by_name:
raise NotImplementedError(
'Weights may only be loaded based on topology into Models when '
@ -1474,7 +1474,7 @@ class Network(base_layer.Layer):
session = backend.get_session()
# Restore existing variables (if any) immediately, and set up a
# streaming restore for any variables created in the future.
checkpointable_utils.streaming_restore(status=status, session=session)
trackable_utils.streaming_restore(status=status, session=session)
status.assert_nontrivial_match()
return status
if h5py is None:
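Editor's note: the renamed saver above is what backs Model.save_weights/load_weights in the TensorFlow format; a minimal usage sketch, assuming TF 1.13 eager mode, with illustrative paths.

```python
import tensorflow as tf

tf.enable_eager_execution()

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.save_weights("/tmp/net/ckpt")        # object-based checkpoint on disk

clone = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
status = clone.load_weights("/tmp/net/ckpt")
status.assert_nontrivial_match()           # at least the Dense kernel/bias matched
```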
@ -28,7 +28,7 @@ from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export
@ -93,7 +93,7 @@ class Sequential(training.Model):
```
"""

@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def __init__(self, layers=None, name=None):
super(Sequential, self).__init__(name=name)
self.supports_masking = True
@ -112,7 +112,7 @@ class Sequential(training.Model):
# Historically, `sequential.layers` only returns layers that were added
# via `add`, and omits the auto-generated `InputLayer` that comes at the
# bottom of the stack.
# `CheckpointableBase` manages the `_layers` attributes and does filtering
# `Trackable` manages the `_layers` attributes and does filtering
# over it.
layers = super(Sequential, self).layers
if layers and isinstance(layers[0], input_layer.InputLayer):
@ -123,7 +123,7 @@ class Sequential(training.Model):
def dynamic(self):
return any(layer.dynamic for layer in self.layers)

@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def add(self, layer):
"""Adds a layer instance on top of the layer stack.

@ -193,7 +193,7 @@ class Sequential(training.Model):

self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)

@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def pop(self):
"""Removes the last layer in the model.
@ -48,7 +48,7 @@ from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export

@ -141,7 +141,7 @@ class Model(Network):
return super(Model, self).get_weights()
return super(Model, self).get_weights()

@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def compile(self,
optimizer,
loss=None,
@ -245,9 +245,9 @@ class Model(Network):

self.optimizer = optimizer
# We've disabled automatic dependency tracking for this method, but do want
# to add a checkpoint dependency on the optimizer if it's checkpointable.
if isinstance(self.optimizer, checkpointable.Checkpointable):
self._track_checkpointable(
# to add a checkpoint dependency on the optimizer if it's trackable.
if isinstance(self.optimizer, trackable.Trackable):
self._track_trackable(
self.optimizer, name='optimizer', overwrite=True)
self.loss = loss
self._compile_metrics = metrics or []
@ -2663,7 +2663,7 @@ class Model(Network):
'However we received `validation_data=%s`' % validation_data)
return val_x, val_y, val_sample_weight

@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def _set_inputs(self, inputs, outputs=None, training=None):
"""Set model's input and output specs based on the input data received.
@ -42,7 +42,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_cudnn_rnn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export

@ -78,7 +78,7 @@ class StackedRNNCells(Layer):
```
"""

@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def __init__(self, cells, **kwargs):
for cell in cells:
if not hasattr(cell, 'call'):
@ -443,7 +443,7 @@ class RNN(Layer):
```
"""

@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def __init__(self,
cell,
return_sequences=False,
@ -468,8 +468,8 @@ class RNN(Layer):
self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False)
super(RNN, self).__init__(**kwargs)
self.cell = cell
if isinstance(cell, checkpointable.Checkpointable):
self._track_checkpointable(self.cell, name='cell')
if isinstance(cell, trackable.Trackable):
self._track_trackable(self.cell, name='cell')
self.return_sequences = return_sequences
self.return_state = return_state
self.go_backwards = go_backwards
@ -40,7 +40,7 @@ from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import util as checkpointable_util
from tensorflow.python.training.tracking import util as trackable_util
from tensorflow.python.util import nest

# Used for nested input/output/state RNN test.
@ -715,7 +715,7 @@ class RNNTest(keras_parameterized.TestCase):
[tuple(o.as_list()) for o in output_shape],
expected_output_shape)

def test_checkpointable_dependencies(self):
def test_trackable_dependencies(self):
rnn = keras.layers.SimpleRNN
x = np.random.random((2, 2, 2))
y = np.random.random((2, 2))
@ -728,8 +728,8 @@ class RNNTest(keras_parameterized.TestCase):
model.fit(x, y, epochs=1, batch_size=1)

# check whether the model variables are present in the
# checkpointable list of objects
checkpointed_objects = set(checkpointable_util.list_objects(model))
# trackable list of objects
checkpointed_objects = set(trackable_util.list_objects(model))
for v in model.variables:
self.assertIn(v, checkpointed_objects)
@ -29,7 +29,7 @@ from tensorflow.python.keras.layers.recurrent import _standardize_args
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export

@ -46,7 +46,7 @@ class Wrapper(Layer):
layer: The layer to be wrapped.
"""

@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def __init__(self, layer, **kwargs):
assert isinstance(layer, Layer)
self.layer = layer
@ -170,7 +170,7 @@ class TimeDistributed(Wrapper):
'`Layer` instance. You passed: {input}'.format(input=layer))
super(TimeDistributed, self).__init__(layer, **kwargs)
self.supports_masking = True
self._track_checkpointable(layer, name='layer')
self._track_trackable(layer, name='layer')

def _get_shape_tuple(self, init_tuple, tensor, start_idx, int_shape=None):
"""Finds non-specific dimensions in the static shapes.
@ -386,7 +386,7 @@ class Bidirectional(Wrapper):
```
"""

@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
if not isinstance(layer, Layer):
raise ValueError(
@ -419,8 +419,8 @@ class Bidirectional(Wrapper):
self._num_constants = None
super(Bidirectional, self).__init__(layer, **kwargs)
self.input_spec = layer.input_spec
self._track_checkpointable(self.forward_layer, name='forward_layer')
self._track_checkpointable(self.backward_layer, name='backward_layer')
self._track_trackable(self.forward_layer, name='forward_layer')
self._track_trackable(self.backward_layer, name='backward_layer')

@property
def trainable(self):
@ -27,7 +27,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import util as checkpointable_util
from tensorflow.python.training.tracking import util as trackable_util


class _RNNCellWithConstants(keras.layers.Layer):
@ -88,8 +88,8 @@ class TimeDistributedTest(test.TestCase):
model.get_config()

# check whether the model variables are present in the
# checkpointable list of objects
checkpointed_objects = set(checkpointable_util.list_objects(model))
# trackable list of objects
checkpointed_objects = set(trackable_util.list_objects(model))
for v in model.variables:
self.assertIn(v, checkpointed_objects)

@ -303,8 +303,8 @@ class BidirectionalTest(test.TestCase):
model.fit(x, y, epochs=1, batch_size=1)

# check whether the model variables are present in the
# checkpointable list of objects
checkpointed_objects = set(checkpointable_util.list_objects(model))
# trackable list of objects
checkpointed_objects = set(trackable_util.list_objects(model))
for v in model.variables:
self.assertIn(v, checkpointed_objects)
@ -37,7 +37,7 @@ from tensorflow.python.keras import testing_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.training.tracking import util as trackable_utils


@test_util.run_all_in_graph_and_eager_modes
@ -131,7 +131,7 @@ class KerasSumTest(test.TestCase):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
m = metrics.Sum()
checkpoint = checkpointable_utils.Checkpoint(sum=m)
checkpoint = trackable_utils.Checkpoint(sum=m)
self.evaluate(variables.variables_initializer(m.variables))

# update state
@ -149,7 +149,7 @@ class KerasSumTest(test.TestCase):

# restore to a different checkpoint sum object
restore_sum = metrics.Sum()
restore_checkpoint = checkpointable_utils.Checkpoint(sum=restore_sum)
restore_checkpoint = trackable_utils.Checkpoint(sum=restore_sum)
status = restore_checkpoint.restore(save_path)
restore_update = restore_sum(300.)
status.assert_consumed().run_restore_ops()
@ -267,7 +267,7 @@ class KerasMeanTest(test.TestCase):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
m = metrics.Mean()
checkpoint = checkpointable_utils.Checkpoint(mean=m)
checkpoint = trackable_utils.Checkpoint(mean=m)
self.evaluate(variables.variables_initializer(m.variables))

# update state
@ -285,7 +285,7 @@ class KerasMeanTest(test.TestCase):

# restore to a different checkpoint mean object
restore_mean = metrics.Mean()
restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean)
restore_checkpoint = trackable_utils.Checkpoint(mean=restore_mean)
status = restore_checkpoint.restore(save_path)
restore_update = restore_mean(300.)
status.assert_consumed().run_restore_ops()
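Editor's note: the pattern the Sum/Mean tests above exercise, as a standalone sketch assuming TF 1.13's tf.keras.metrics in eager mode; the metric value and path are illustrative.

```python
import tensorflow as tf

tf.enable_eager_execution()

m = tf.keras.metrics.Mean()
m(100.)                                    # total=100, count=1
save_path = tf.train.Checkpoint(mean=m).save("/tmp/metric/ckpt")

restored = tf.keras.metrics.Mean()
status = tf.train.Checkpoint(mean=restored).restore(save_path)
status.assert_consumed()                   # total and count both matched
assert restored.result().numpy() == 100.0
```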
@ -35,7 +35,7 @@ from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.tracking import data_structures

try:
import h5py # pylint:disable=g-import-not-at-top
@ -43,7 +43,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export

@ -70,7 +70,7 @@ def _deduplicate_indexed_slices(values, indices):

@six.add_metaclass(abc.ABCMeta)
@keras_export("keras.optimizers.Optimizer")
class OptimizerV2(checkpointable.Checkpointable):
class OptimizerV2(trackable.Trackable):
"""Updated base class for optimizers.

This class defines the API to add Ops to train a model. You never use this
@ -244,9 +244,9 @@ class OptimizerV2(checkpointable.Checkpointable):
self._weights = []
self._iterations = None

# For implementing Checkpointable. Stores information about how to restore
# For implementing Trackable. Stores information about how to restore
# slot variables which have not yet been created
# (checkpointable._CheckpointPosition objects).
# (trackable._CheckpointPosition objects).
# {slot_name :
# {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
# ... }
@ -829,7 +829,7 @@ class OptimizerV2(checkpointable.Checkpointable):
return x.value()

# ---------------
# For implementing the checkpointable interface
# For implementing the trackable interface
# ---------------

def _restore_slot_variable(self, slot_name, variable, slot_variable):
@ -860,8 +860,8 @@ class OptimizerV2(checkpointable.Checkpointable):
slot variable needs to be restored).

Args:
slot_variable_position: A `checkpointable._CheckpointPosition` object
indicating the slot variable `Checkpointable` object to be restored.
slot_variable_position: A `trackable._CheckpointPosition` object
indicating the slot variable `Trackable` object to be restored.
slot_name: The name of this `Optimizer`'s slot to restore into.
variable: The variable object this slot is being created for.
"""
@ -879,7 +879,7 @@ class OptimizerV2(checkpointable.Checkpointable):
# (aside from double initialization), and makes variable creator scopes
# behave the same way they do when graph building.
and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access
initializer = checkpointable.CheckpointInitialValue(
initializer = trackable.CheckpointInitialValue(
checkpoint_position=slot_variable_position)
slot_variable = self.add_slot(
var=variable,
@ -40,7 +40,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util.tf_export import keras_export


@ -710,19 +710,19 @@ class Nadam(Optimizer):
return dict(list(base_config.items()) + list(config.items()))


class TFOptimizer(Optimizer, checkpointable.Checkpointable):
class TFOptimizer(Optimizer, trackable.Trackable):
"""Wrapper class for native TensorFlow optimizers.
"""

def __init__(self, optimizer, iterations=None): # pylint: disable=super-init-not-called
self.optimizer = optimizer
self._track_checkpointable(optimizer, name='optimizer')
self._track_trackable(optimizer, name='optimizer')
if iterations is None:
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
else:
self.iterations = iterations
self._track_checkpointable(self.iterations, name='global_step')
self._track_trackable(self.iterations, name='global_step')

def apply_gradients(self, grads):
self.optimizer.apply_gradients(grads, global_step=self.iterations)
@ -40,7 +40,7 @@ from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import training as training_module
from tensorflow.python.training.checkpointable import util as checkpointable
from tensorflow.python.training.tracking import util as trackable

try:
import h5py # pylint:disable=g-import-not-at-top
@ -994,7 +994,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):

@test_util.run_in_graph_and_eager_modes
def test_incompatible_checkpoint(self):
save_path = checkpointable.Checkpoint().save(
save_path = trackable.Checkpoint().save(
os.path.join(self.get_temp_dir(), 'ckpt'))
m = keras.Model()
with self.assertRaisesRegexp(AssertionError, 'Nothing to load'):
@ -37,7 +37,7 @@ from tensorflow.python.saved_model import model_utils
from tensorflow.python.saved_model import save as save_lib
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.checkpointable import graph_view
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
@ -300,10 +300,10 @@ def _export_mode(
# not counting optimizer objects. Optimizer objects are ignored because
# if the model has not trained, the slot variables will not have been
# created yet.
# TODO(b/113179535): Replace with checkpointable equivalence.
# TODO(b/113179535): Replace with trackable equivalence.
_assert_same_non_optimizer_objects(model, model_graph, clone, g)

# TODO(b/113178242): Use value transfer for checkpointable objects.
# TODO(b/113178242): Use value transfer for trackable objects.
clone.load_weights(checkpoint_path)

# Add graph and variables to SavedModel.
@ -361,7 +361,7 @@ def _create_signature_def_map(model, mode):


def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph): # pylint: disable=unused-argument
"""Asserts model and clone contain the same checkpointable objects."""
"""Asserts model and clone contain the same trackable objects."""

# TODO(fchollet, kathywu): make sure this works in eager mode.
return True
@ -38,7 +38,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.training.checkpointable import util as checkpointable
from tensorflow.python.training.tracking import util as trackable


class HashTableTest(test.TestCase):
@ -1691,7 +1691,7 @@ class MutableHashTableOpTest(test.TestCase):
table = lookup_ops.MutableHashTable(
dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)

checkpoint = checkpointable.Checkpoint(table=table, v0=v0, v1=v1)
checkpoint = trackable.Checkpoint(table=table, v0=v0, v1=v1)
self.evaluate([v0.initializer, v1.initializer])

# Check that the parameter nodes have been initialized.
@ -1716,7 +1716,7 @@ class MutableHashTableOpTest(test.TestCase):
constant_op.constant([12, 24], dtypes.int64)))
self.assertAllEqual(2, self.evaluate(table.size()))

checkpoint = checkpointable.Checkpoint(table=table, v0=v0, v1=v1)
checkpoint = trackable.Checkpoint(table=table, v0=v0, v1=v1)

# Restore the saved values in the parameter nodes.
checkpoint.restore(save_path).run_restore_ops()
@ -2512,7 +2512,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
checkpoint=True,
initial_num_buckets=32)

save_checkpoint = checkpointable.Checkpoint(table=save_table)
save_checkpoint = trackable.Checkpoint(table=save_table)

self.assertAllEqual(0, self.evaluate(save_table.size()))
self.evaluate(save_table.insert(keys, values))
@ -2538,7 +2538,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual(2, self.evaluate(load_table.size()))
self.assertAllEqual(64, len(self.evaluate(load_table.export()[0])))

restore_checkpoint = checkpointable.Checkpoint(table=load_table)
restore_checkpoint = trackable.Checkpoint(table=load_table)

# Restore the saved values in the parameter nodes.
restore_checkpoint.restore(save_path).run_restore_ops()

@ -49,7 +49,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import nest


@ -2804,7 +2804,7 @@ class RNNCellTest(test.TestCase, parameterized.TestCase):
wrapper(array_ops.ones([1, 1]),
state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32))
self.evaluate([v.initializer for v in cell.variables])
checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper)
checkpoint = trackable_utils.Checkpoint(wrapper=wrapper)
prefix = os.path.join(self.get_temp_dir(), "ckpt")
self.evaluate(cell._bias.assign([40.]))
save_path = checkpoint.save(prefix)

@ -26,7 +26,7 @@ from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
@ -554,7 +554,7 @@ class Layer(base_layer.Layer):

def __setattr__(self, value, name):
# By-pass the automatic dependency tracking performed by the parent Layer.
super(checkpointable.Checkpointable, self).__setattr__(value, name)
super(trackable.Trackable, self).__setattr__(value, name)


def _add_elements_to_collection(elements, collection_list):

@ -13,7 +13,7 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:util",
"//tensorflow/python:variables",
"//tensorflow/python/training/checkpointable:tracking",
"//tensorflow/python/training/tracking",
"@six_archive//:six",
],
)

@ -27,7 +27,7 @@ import six
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
@ -149,7 +149,7 @@ def with_name_scope(unbound_method):


@tf_export("Module", "experimental.Module")
class Module(six.with_metaclass(ModuleMetaclass, tracking.AutoCheckpointable)):
class Module(six.with_metaclass(ModuleMetaclass, tracking.AutoTrackable)):
"""Base neural network module class.

A module is a named container for `tf.Variable`s, other `tf.Module`s and
@ -375,7 +375,7 @@ def camel_to_snake(value):
return _CAMEL_TO_SNAKE_R.sub(r"_\1", value).lower()


# AutoCheckpointable adds object attributes that users will not expect us to
# AutoTrackable adds object attributes that users will not expect us to
# include when flattening (these reference dependencies reachable via other
# object attributes).
AUTO_CHECKPOINTABLE_ATTRS = ("_unconditional_checkpoint_dependencies",

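`Module` deriving from `tracking.AutoTrackable` is what makes plain attribute assignment sufficient for variable discovery. A small usage sketch, assuming the `tf.Module` API this file exports (the layer here is illustrative):

```python
import tensorflow as tf

class Dense(tf.Module):

  def __init__(self, in_features, out_features, name=None):
    super(Dense, self).__init__(name=name)
    # Plain attribute assignment; AutoTrackable turns these into tracked
    # dependencies, so the `variables` property below can find them.
    self.w = tf.Variable(tf.random.normal([in_features, out_features]))
    self.b = tf.Variable(tf.zeros([out_features]))

  def __call__(self, x):
    return tf.matmul(x, self.w) + self.b

layer = Dense(3, 2)
print([v.shape for v in layer.variables])  # shapes (3, 2) and (2,)
```
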
@ -43,7 +43,7 @@ from tensorflow.python.ops.gen_boosted_trees_ops import is_boosted_trees_quantil
# pylint: enable=unused-import

from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking


class PruningMode(object):

@ -38,10 +38,10 @@ from tensorflow.python.ops import string_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_lookup_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.training.checkpointable import base as checkpointable_base
from tensorflow.python.training.checkpointable import tracking as checkpointable
from tensorflow.python.training.saver import BaseSaverBuilder
# pylint: enable=wildcard-import
from tensorflow.python.training.tracking import base as trackable_base
from tensorflow.python.training.tracking import tracking as trackable
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@ -102,7 +102,7 @@ def _check_table_dtypes(table, key_dtype, value_dtype):
(table.value_dtype, value_dtype))


class LookupInterface(checkpointable.TrackableResource):
class LookupInterface(trackable.TrackableResource):
"""Represent a lookup table that persists across different steps."""

def __init__(self, key_dtype, value_dtype):
@ -165,8 +165,8 @@ class InitializableLookupTableBase(LookupInterface):
self._default_value = ops.convert_to_tensor(
default_value, dtype=self._value_dtype)
self._default_value.get_shape().merge_with(tensor_shape.scalar())
if isinstance(initializer, checkpointable_base.Checkpointable):
self._initializer = self._track_checkpointable(
if isinstance(initializer, trackable_base.Trackable):
self._initializer = self._track_trackable(
initializer, "_initializer")
self._resource_handle = self.create_resource()
self._init_op = self.initialize()
@ -314,7 +314,7 @@ class HashTable(InitializableLookupTableBase):
return exported_keys, exported_values


class TableInitializerBase(checkpointable_base.Checkpointable):
class TableInitializerBase(trackable_base.Trackable):
"""Base class for lookup table initializers."""

def __init__(self, key_dtype, value_dtype):
@ -543,8 +543,8 @@ class TextFileInitializer(TableInitializerBase):
self._vocab_size = vocab_size
self._delimiter = delimiter
self._name = name
self._filename = self._track_checkpointable(
checkpointable.TrackableAsset(filename),
self._filename = self._track_trackable(
trackable.TrackableAsset(filename),
"_filename")

super(TextFileInitializer, self).__init__(key_dtype, value_dtype)

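`TextFileInitializer` now tracks its vocabulary file as a `TrackableAsset`, so a table initialized from a file carries that file with it through object-based saving and SavedModel export. A hedged sketch of the user-facing pattern, assuming the TF 2.x `tf.lookup` entry points that wrap the classes touched here (at the time of this commit these still lived in their v1 locations):

```python
import tensorflow as tf

# Write a two-entry vocabulary file: one token per line.
path = '/tmp/vocab.txt'
with open(path, 'w') as f:
  f.write('emerson\nlake\n')

# The initializer holds the file as a tracked asset; the table tracks
# the initializer via _track_trackable, as in the hunk above.
initializer = tf.lookup.TextFileInitializer(
    path,
    key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
    value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
table = tf.lookup.StaticHashTable(initializer, default_value=-1)

print(table.lookup(tf.constant(['lake', 'palmer'])).numpy())  # [ 1 -1]
```
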
@ -43,7 +43,7 @@ from tensorflow.python.ops import variables
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_resource_variable_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated

@ -505,8 +505,8 @@ class ResourceVariable(variables.VariableV1):
if constraint is not None and not callable(constraint):
raise ValueError("The `constraint` argument must be a callable.")

if isinstance(initial_value, checkpointable.CheckpointInitialValue):
self._maybe_initialize_checkpointable()
if isinstance(initial_value, trackable.CheckpointInitialValue):
self._maybe_initialize_trackable()
self._update_uid = initial_value.checkpoint_position.restore_uid
initial_value = initial_value.wrapped_value

@ -1684,7 +1684,7 @@ def copy_to_graph_uninitialized(var):
constraint=var._constraint,
dtype=var.dtype,
name=var._shared_name)
new_variable._maybe_initialize_checkpointable()
new_variable._maybe_initialize_trackable()
# pylint: enable=protected-access
return new_variable

@ -50,7 +50,7 @@ from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@ -1095,8 +1095,8 @@ class _RNNCellWrapperV1(RNNCell):
def __init__(self, cell):
super(_RNNCellWrapperV1, self).__init__()
self._cell = cell
if isinstance(cell, checkpointable.Checkpointable):
self._track_checkpointable(self._cell, name="cell")
if isinstance(cell, trackable.Trackable):
self._track_trackable(self._cell, name="cell")

def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
"""Calls the wrapped cell and performs the wrapping logic.
@ -1611,8 +1611,8 @@ class DeviceWrapper(RNNCell):
"""
super(DeviceWrapper, self).__init__()
self._cell = cell
if isinstance(cell, checkpointable.Checkpointable):
self._track_checkpointable(self._cell, name="cell")
if isinstance(cell, trackable.Trackable):
self._track_trackable(self._cell, name="cell")
self._device = device

@property
@ -1678,11 +1678,11 @@ class MultiRNNCell(RNNCell):

self._cells = cells
for cell_number, cell in enumerate(self._cells):
# Add Checkpointable dependencies on these cells so their variables get
# Add Trackable dependencies on these cells so their variables get
# saved with this object when using object-based saving.
if isinstance(cell, checkpointable.Checkpointable):
# TODO(allenl): Track down non-Checkpointable callers.
self._track_checkpointable(cell, name="cell-%d" % (cell_number,))
if isinstance(cell, trackable.Trackable):
# TODO(allenl): Track down non-Trackable callers.
self._track_trackable(cell, name="cell-%d" % (cell_number,))
self._state_is_tuple = state_is_tuple
if not state_is_tuple:
if any(nest.is_sequence(c.state_size) for c in self._cells):

@ -27,7 +27,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.checkpointable import \
from tensorflow.python.training.tracking import \
tracking
from tensorflow.python.util.tf_export import tf_export

@ -144,7 +144,7 @@ def _shape_tensor(shape):


@tf_export("random.experimental.Generator")
class Generator(tracking.AutoCheckpointable):
class Generator(tracking.AutoTrackable):
"""Random-number generator.

It uses Variable to manage its internal state.

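Because `Generator` extends `AutoTrackable`, the RNG state variable participates in checkpoints like any other tracked variable: restoring a checkpoint rewinds the generator. A brief sketch, assuming the later, stabilized form of this experimental API (`from_seed` is the constructor that eventually stuck; this early revision's constructor may have differed):

```python
import tensorflow as tf

# The generator's state lives in a tf.Variable, so saving the generator
# in a tf.train.Checkpoint also saves the RNG state.
gen = tf.random.experimental.Generator.from_seed(1234)
ckpt = tf.train.Checkpoint(generator=gen)
path = ckpt.save('/tmp/rng_ckpt')

first = gen.normal([2])          # advances the state
ckpt.restore(path)               # rewinds the state
print(gen.normal([2]) == first)  # same draw again: [ True  True ]
```
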
@ -26,8 +26,8 @@ from tensorflow.python.eager import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import util as checkpointable_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import util as trackable_util
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.deprecation import deprecated
@ -232,7 +232,7 @@ def _skip_common_stack_elements(stacktrace, base_case):
return stacktrace[-1:]


class Template(checkpointable.Checkpointable):
class Template(trackable.Trackable):
"""Wrap a function to aid in variable sharing.

Templates are functions that create variables the first time they are called
@ -306,8 +306,8 @@ class Template(checkpointable.Checkpointable):
result = self._func(*args, **kwargs)
else:
# The first time we run, restore variables if necessary (via
# Checkpointable).
with checkpointable_util.capture_dependencies(template=self):
# Trackable).
with trackable_util.capture_dependencies(template=self):
result = self._func(*args, **kwargs)

if self._variables_created:
@ -577,8 +577,8 @@ class EagerTemplate(Template):
result = self._func(*args, **kwargs)
else:
# The first time we run, restore variables if necessary (via
# Checkpointable).
with checkpointable_util.capture_dependencies(template=self):
# Trackable).
with trackable_util.capture_dependencies(template=self):
result = self._func(*args, **kwargs)

if self._variables_created:

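`Template` restores variables through `capture_dependencies` the first time the wrapped function runs, which is why templates compose with object-based checkpoints: variables created on the first call become the template's tracked dependencies. A minimal usage sketch, assuming the v1 `make_template` API under eager execution:

```python
import tensorflow as tf

def scale_by_y(x):
  # Variables created on the first call are owned (and tracked) by the
  # template; later calls reuse them instead of creating new ones.
  y = tf.compat.v1.get_variable('y', [], initializer=tf.ones_initializer())
  return x * y

scale = tf.compat.v1.make_template('scale_by_y', scale_by_y)
print(scale(tf.constant(3.)).numpy())  # 3.0; creates y on this first call
print(scale.variables)                 # [<tf.Variable 'scale_by_y/y:0' ...>]
```
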
@ -35,7 +35,7 @@ from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import compat
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.deprecation import deprecated
@ -204,7 +204,7 @@ class VariableMetaclass(type):

@tf_export("Variable", v1=[])
class Variable(six.with_metaclass(VariableMetaclass,
checkpointable.Checkpointable)):
trackable.Trackable)):
"""See the [Variables Guide](https://tensorflow.org/guide/variables).

A variable maintains state in the graph across calls to `run()`. You add a
@ -1018,8 +1018,8 @@ class Variable(six.with_metaclass(VariableMetaclass,
return self.shape

def _gather_saveables_for_checkpoint(self):
"""For implementing `Checkpointable`. This object is saveable on its own."""
return {checkpointable.VARIABLE_VALUE_KEY: self}
"""For implementing `Trackable`. This object is saveable on its own."""
return {trackable.VARIABLE_VALUE_KEY: self}

def to_proto(self, export_scope=None):
"""Converts a `Variable` to a `VariableDef` protocol buffer.
@ -1506,8 +1506,8 @@ class RefVariable(VariableV1):
# Store the graph key so optimizers know how to only retrieve variables from
# this graph.
self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
if isinstance(initial_value, checkpointable.CheckpointInitialValue):
self._maybe_initialize_checkpointable()
if isinstance(initial_value, trackable.CheckpointInitialValue):
self._maybe_initialize_trackable()
self._update_uid = initial_value.checkpoint_position.restore_uid
initial_value = initial_value.wrapped_value

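`Variable._gather_saveables_for_checkpoint` maps the well-known `VARIABLE_VALUE_KEY` to the variable itself; that mapping is what names the value inside an object-based checkpoint. A rough illustration of the resulting key layout, assuming the public checkpoint reader API:

```python
import tensorflow as tf

v = tf.Variable(3.0)
path = tf.train.Checkpoint(v=v).save('/tmp/var_ckpt')

# VARIABLE_VALUE_KEY is the 'VARIABLE_VALUE' attribute; combined with the
# dependency name 'v' it produces the checkpoint key below.
reader = tf.train.load_checkpoint(path)
print(reader.get_tensor('v/.ATTRIBUTES/VARIABLE_VALUE'))  # 3.0
```
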
@ -275,7 +275,7 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:function",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/tracking:base",
],
)

@ -310,12 +310,12 @@ py_library(
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:function",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/checkpointable:graph_view",
"//tensorflow/python/training/checkpointable:object_identity",
"//tensorflow/python/training/checkpointable:tracking",
"//tensorflow/python/training/checkpointable:util",
"//tensorflow/python/training/saving:functional_saver",
"//tensorflow/python/training/tracking",
"//tensorflow/python/training/tracking:base",
"//tensorflow/python/training/tracking:graph_view",
"//tensorflow/python/training/tracking:object_identity",
"//tensorflow/python/training/tracking:util",
],
)

@ -356,10 +356,10 @@ py_library(
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//tensorflow/python:variables",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/checkpointable:graph_view",
"//tensorflow/python/training/checkpointable:tracking",
"//tensorflow/python/training/checkpointable:util",
"//tensorflow/python/training/tracking",
"//tensorflow/python/training/tracking:base",
"//tensorflow/python/training/tracking:graph_view",
"//tensorflow/python/training/tracking:util",
],
)

@ -375,7 +375,7 @@ py_library(
"//tensorflow/python:constant_op",
"//tensorflow/python:saver",
"//tensorflow/python/eager:wrap_function",
"//tensorflow/python/training/checkpointable:tracking",
"//tensorflow/python/training/tracking",
],
)

@ -392,7 +392,7 @@ tf_py_test(
"//tensorflow/python:tensor_spec",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
"//tensorflow/python/training/checkpointable:tracking",
"//tensorflow/python/training/tracking:tracking",
],
)

@ -417,7 +417,7 @@ tf_py_test(
"//tensorflow/python:tensor_spec",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
"//tensorflow/python/training/checkpointable:tracking",
"//tensorflow/python/training/tracking:tracking",
"//tensorflow/python:variables",
],
)

@ -49,7 +49,7 @@ def serialize_concrete_function(concrete_function, node_ids, coder):
"captures tensor %s which is unsupported or not reachable from root. "
"One reason could be that a stateful object or a variable that the "
"function depends on is not assigned to an attribute of the serialized "
"checkpointable object "
"trackable object "
"(see SaveTest.test_captures_unreachable_variable)."
% (concrete_function.name, capture))
concrete_function_proto = saved_object_graph_pb2.SavedConcreteFunction()

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import a checkpointable object from a SavedModel."""
"""Import a trackable object from a SavedModel."""

from __future__ import absolute_import
from __future__ import division
@ -36,10 +36,10 @@ from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training.checkpointable import base
from tensorflow.python.training.checkpointable import graph_view
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import compat
from tensorflow.python.util import nest

@ -149,16 +149,16 @@ class _Loader(object):
def _restore_checkpoint(self):
"""Load state from checkpoint into the deserialized objects."""
variables_path = saved_model_utils.get_variables_path(self._export_dir)
# TODO(andresp): Clean use of private methods of CheckpointableSaver.
# TODO(andresp): Clean use of private methods of TrackableSaver.
# pylint: disable=protected-access
saver = util.CheckpointableSaver(graph_view.ObjectGraphView(self.get(0)))
saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
saver._file_prefix_placeholder = constant_op.constant(variables_path)
load_status = saver.restore(variables_path)
load_status.assert_existing_objects_matched()
checkpoint = load_status._checkpoint

# When running in eager mode, the `restore` call above has already run and
# restored the state of checkpointables, call `position.restore_ops()` will
# restored the state of trackables, call `position.restore_ops()` will
# return an empty list as there is nothing left to do. In graph mode, that
# will return the list of ops that must run to restore the object on that
# position. We have to wire them in the initializers of the objects so that
@ -205,7 +205,7 @@ class _Loader(object):
# individually callable by adding a `__call__` method to the classes of
# the objects instances that have a `__call__` property.

class _UserObject(tracking.AutoCheckpointable):
class _UserObject(tracking.AutoTrackable):
pass

return _UserObject(), setattr
@ -282,7 +282,7 @@ def load(export_dir, tags=None):
print(f(x=tf.constant([[1.]])))
```

Objects exported with `tf.saved_model.save` additionally have checkpointable
Objects exported with `tf.saved_model.save` additionally have trackable
objects and functions assigned to attributes:

```python
@ -303,9 +303,9 @@ def load(export_dir, tags=None):
`tf.saved_model.load`.

Returns:
A checkpointable object with a `signatures` attribute mapping from signature
A trackable object with a `signatures` attribute mapping from signature
keys to functions. If the SavedModel was exported by `tf.saved_model.load`,
it also points to checkpointable objects and functions which were attached
it also points to trackable objects and functions which were attached
to the exported object.

Raises:

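The `load` docstring edits above describe the round trip this rename serves: whatever trackable attributes hang off the exported object come back as attributes of the loaded object. A compact sketch, assuming the TF 2.x `tf.saved_model.save`/`tf.saved_model.load` entry points (in 1.x these lived under experimental names):

```python
import tensorflow as tf

root = tf.train.Checkpoint()  # any trackable object can serve as a root
root.v = tf.Variable(2.)
root.f = tf.function(
    lambda x: root.v * x,
    input_signature=[tf.TensorSpec([], tf.float32)])

tf.saved_model.save(root, '/tmp/trackable_demo')
loaded = tf.saved_model.load('/tmp/trackable_demo')

print(loaded.v.numpy())                   # 2.0, restored variable
print(loaded.f(tf.constant(3.)).numpy())  # 6.0, restored function
print(list(loaded.signatures))            # e.g. ['serving_default']
```
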
@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for checkpointable object SavedModel loading."""
"""Tests for trackable object SavedModel loading."""

from __future__ import absolute_import
from __future__ import division
@ -40,8 +40,8 @@ from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import monitored_session
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import tf_inspect


@ -63,17 +63,17 @@ class LoadTest(test.TestCase, parameterized.TestCase):
return loaded

def test_structure_import(self, cycles):
root = tracking.AutoCheckpointable()
root.dep_one = tracking.AutoCheckpointable()
root.dep_two = tracking.AutoCheckpointable()
root.dep_two.dep = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.dep_one = tracking.AutoTrackable()
root.dep_two = tracking.AutoTrackable()
root.dep_two.dep = tracking.AutoTrackable()
root.dep_three = root.dep_two.dep
imported = self.cycle(root, cycles)
self.assertIs(imported.dep_three, imported.dep_two.dep)
self.assertIsNot(imported.dep_one, imported.dep_two)

def test_variables(self, cycles):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.v1 = variables.Variable(1., trainable=True)
root.v2 = variables.Variable(2., trainable=False)
imported = self.cycle(root, cycles)
@ -83,7 +83,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertFalse(imported.v2.trainable)

def test_capture_variables(self, cycles):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.weights = variables.Variable(2.)
root.f = def_function.function(
lambda x: root.weights * x,
@ -103,7 +103,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
file1 = self._make_asset("contents 1")
file2 = self._make_asset("contents 2")

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.asset1 = tracking.TrackableAsset(file1)
root.asset2 = tracking.TrackableAsset(file2)

@ -122,7 +122,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual("contents 2", f.read())

def test_capture_assets(self, cycles):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.vocab = tracking.TrackableAsset(self._make_asset("contents"))
root.f = def_function.function(
lambda: root.vocab.asset_path,
@ -135,7 +135,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual("contents", f.read())

def test_capture_assets_in_graph(self, cycles):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.vocab = tracking.TrackableAsset(self._make_asset("contents"))
root.f = def_function.function(
lambda: root.vocab.asset_path,
@ -159,7 +159,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):

def test_dedup_assets(self, cycles):
vocab = self._make_asset("contents")
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.asset1 = tracking.TrackableAsset(vocab)
root.asset2 = tracking.TrackableAsset(vocab)
imported = self.cycle(root, cycles)
@ -171,7 +171,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def func(x):
return 2 * x

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func

# Add two traces.
@ -189,7 +189,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def func(x):
return 2 * x

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func

imported = self.cycle(root, cycles)
@ -200,7 +200,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def func(x):
return 2 * x

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func

imported = self.cycle(
@ -219,7 +219,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
lambda x: f(x) + 1.0,
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.g = g
imported = self.cycle(root, cycles)
imported.g(constant_op.constant([1.0]))
@ -232,7 +232,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
else:
return 7

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(func)

self.assertEqual(20, root.f(constant_op.constant(10), True).numpy())
@ -252,7 +252,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
else:
return array_ops.zeros(shape=x.shape, dtype=dtypes.float32)

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(func)

self.assertAllEqual([0.0, 0.0, 0.0],
@ -286,17 +286,17 @@ class LoadTest(test.TestCase, parameterized.TestCase):

def test_function_no_return(self, cycles):

class CheckpointableWithOneVariable(tracking.AutoCheckpointable):
class TrackableWithOneVariable(tracking.AutoTrackable):

def __init__(self, initial_value=0.0):
super(CheckpointableWithOneVariable, self).__init__()
super(TrackableWithOneVariable, self).__init__()
self.variable = variables.Variable(initial_value)

@def_function.function
def increase(self, by=1.0):
self.variable.assign_add(by)

obj = CheckpointableWithOneVariable(5.0)
obj = TrackableWithOneVariable(5.0)

obj.increase(constant_op.constant(10.0))
self.assertEqual(15.0, obj.variable.numpy())
@ -320,7 +320,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
else:
return 7

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(func)

x = constant_op.constant(10)
@ -352,7 +352,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
named_tuple = named_tuple_type(a=input1 + input2, b=input1 * input2)
return [named_tuple, input2, {"x": 0.5}]

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(func)

result = root.f(constant_op.constant(2), constant_op.constant(3))
@ -382,7 +382,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
else:
return 7

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(func)

self.assertEqual(20, root.f(constant_op.constant(10), True).numpy())
@ -404,7 +404,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
else:
return 7

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(func)

x = constant_op.constant(10)
@ -419,10 +419,10 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(7, imported.f(x, learning_rate=0.5, epochs=3).numpy())

def test_member_function(self, cycles):
class CheckpointableWithMember(tracking.AutoCheckpointable):
class TrackableWithMember(tracking.AutoTrackable):

def __init__(self):
super(CheckpointableWithMember, self).__init__()
super(TrackableWithMember, self).__init__()
self._some_value = 20

@def_function.function
@ -432,7 +432,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
else:
return 7 + self._some_value

root = CheckpointableWithMember()
root = TrackableWithMember()

self.assertEqual(20, root.f(constant_op.constant(10), True).numpy())
self.assertEqual(27, root.f(constant_op.constant(1)).numpy())
@ -444,7 +444,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(27, imported.f(constant_op.constant(2)).numpy())

def test_side_effect_listing(self, cycles):
class M(tracking.AutoCheckpointable):
class M(tracking.AutoTrackable):

def __init__(self):
super(M, self).__init__()
@ -468,7 +468,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
lambda x: x*weight + bias,
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.weight = weight
root.bias = bias
root.g = g
@ -508,7 +508,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def h(x):
return g(x) + bias,

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.weight = weight
root.bias = bias
root.g = h
@ -521,16 +521,16 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertAllClose(grad, [3.5, 2.0])

def test_callable(self, cycles):
class M1(tracking.AutoCheckpointable):
class M1(tracking.AutoTrackable):

@def_function.function(
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
def __call__(self, x):
return x

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.m1 = M1()
root.m2 = tracking.AutoCheckpointable()
root.m2 = tracking.AutoTrackable()
root.m2.__call__ = def_function.function(
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])(
lambda x: x*3.0)
@ -553,9 +553,9 @@ class LoadTest(test.TestCase, parameterized.TestCase):
func = def_function.function(
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])(
lambda x: x*3.0)
root = tracking.AutoCheckpointable()
root.__call__ = tracking.AutoCheckpointable()
root.__call__.__call__ = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.__call__ = tracking.AutoTrackable()
root.__call__.__call__ = tracking.AutoTrackable()
root.__call__.__call__.__call__ = func

imported = self.cycle(root, cycles)
@ -564,7 +564,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(imported(x).numpy(), 3.0)

def test_load_in_graph_mode(self, cycles):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.v1 = variables.Variable(1.)
root.v2 = variables.Variable(2.)
root.f = def_function.function(
@ -585,7 +585,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(4.0, sess.run(output))

def test_load_in_func_graph(self, cycles):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.v1 = variables.Variable(1.)
root.v2 = variables.Variable(2.)
root.f = def_function.function(
@ -597,7 +597,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
path = tempfile.mkdtemp(prefix=self.get_temp_dir())
save.save(root, path)

closure = tracking.AutoCheckpointable()
closure = tracking.AutoTrackable()
@def_function.function
def func(x):
if not hasattr(closure, "model"):
@ -614,7 +614,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def func(x):
return 2 * x

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func

self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy())
@ -650,7 +650,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
tensor_spec.TensorSpec([None], dtypes.int32), True)
func.get_concrete_function(tensor_spec.TensorSpec([None], dtypes.float32))

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func

imported = self.cycle(root, cycles)
@ -674,7 +674,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def func(x):
return 2 * x

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func.get_concrete_function()

self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy())
@ -695,7 +695,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def func(x):
return 2 * x

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func.get_concrete_function()

self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy())
@ -711,7 +711,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def func(x):
return 2 * x

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func.get_concrete_function(constant_op.constant([1]))
self.assertAllEqual([4], root.f(constant_op.constant([2])).numpy())
# TODO(andresp): Fix exporting of loaded concrete functions as signatures.
@ -724,7 +724,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
input_signature=[tensor_spec.TensorSpec([None], dtypes.float32)])
def func(x):
return x ** 2.
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func.get_concrete_function()

def _compute_gradient(function):
@ -744,7 +744,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
@def_function.function
def func(x, y):
return x * (y + 1.)
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func.get_concrete_function(
tensor_spec.TensorSpec([], dtypes.float32),
tensor_spec.TensorSpec([], dtypes.float32))
@ -761,7 +761,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def func(*args):
x, y = args
return x * (y + 1.)
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func.get_concrete_function(
tensor_spec.TensorSpec([], dtypes.float32, name="x"),
tensor_spec.TensorSpec([], dtypes.float32, name="y"))
@ -782,7 +782,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
capture.assign_sub(1)

vsave = variables.Variable(1)
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func.get_concrete_function(vsave)
root.capture = capture
self.assertEqual(1, vsave.numpy())
@ -805,7 +805,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def func(v):
return v + 1

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.func = func
root.concrete_func = func.get_concrete_function(
tensor_spec.TensorSpec(None, dtypes.int32))
@ -817,7 +817,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(2, imported.concrete_func(one).numpy())

def test_dict(self, cycles):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.variables = dict(a=variables.Variable(1.))
root.variables["b"] = variables.Variable(2.)
root.variables["c"] = 1
@ -832,7 +832,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(100., imported.funcs["conc"]().numpy())

def test_list(self, cycles):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.variables = [variables.Variable(1.)]
root.variables.append(1)
root.variables.append(variables.Variable(3.))
@ -843,7 +843,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(3, len(imported.variables))

def test_functions_list(self, cycles):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
v1 = variables.Variable(1.)
root.losses = [def_function.function(lambda: math_ops.reduce_sum(v1 ** 2))]
root.variables = [v1]
@ -865,7 +865,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):

def test_captured_constant(self, cycles):
const = array_ops.zeros([100])
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(lambda: const + 1.)
root.g = def_function.function(lambda: const + 2.)
self.assertAllClose(array_ops.ones([100]), root.f())
@ -885,7 +885,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):

def test_functions_accessed_once(self, cycles):

class Exported(tracking.AutoCheckpointable):
class Exported(tracking.AutoTrackable):

def __init__(self):
self._counter = 0
@ -905,7 +905,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(1, exported.make_func().numpy())

def test_overwritten_signatures_error(self, cycles):
exported = tracking.AutoCheckpointable()
exported = tracking.AutoTrackable()
exported.f = def_function.function(lambda: constant_op.constant(1.))
imported = self.cycle(
exported, cycles,
@ -917,7 +917,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):

def test_signature_loading(self, cycles):

class Exported(tracking.AutoCheckpointable):
class Exported(tracking.AutoTrackable):

def __init__(self):
self.v = variables.Variable(3.)
@ -961,7 +961,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
return def_function.function(input_signature=signature)(
lambda x: table.lookup(x)) # pylint: disable=unnecessary-lambda

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.table1 = table1
root.lookup1 = _make_lookup_function(table1)
root.table2 = table2
@ -999,7 +999,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):

original_fullargspec = tf_inspect.getfullargspec(f)

root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(f)
imported = self.cycle(root, cycles)

@ -1010,7 +1010,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
class SingleCycleTests(test.TestCase, parameterized.TestCase):

def test_load_with_tags(self):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
path = tempfile.mkdtemp(prefix=self.get_temp_dir())
save.save(root, path)
with self.assertRaises(ValueError):

@ -27,7 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import signature_serialization
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking


class _Initializer(tracking.TrackableResource):
@ -123,7 +123,7 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
self.restore_variables(wrapped, saver)
with wrapped.graph.as_default():
init_op = loader_impl.get_init_op(meta_graph_def)
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
if init_op is not None:
asset_feed_tensors = []
asset_paths = []
@ -141,6 +141,7 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
else:
root.asset_paths = []
signature_functions = self._extract_signatures(wrapped, meta_graph_def)

root.signatures = signature_serialization.create_signature_map(
signature_functions)
root.variables = list(wrapped.graph.variables)

@ -31,7 +31,7 @@ class VersionedTypeRegistration(object):

Args:
object_factory: A callable which takes a SavedUserObject proto and returns
a checkpointable object. Dependencies are added later via `setter`.
a trackable object. Dependencies are added later via `setter`.
version: An integer, the producer version of this wrapper type. When
making incompatible changes to a wrapper, add a new
`VersionedTypeRegistration` with an incremented `version`. The most
@ -45,11 +45,11 @@ class VersionedTypeRegistration(object):
with this object. `min_consumer_version` should be set to the lowest
version number which can successfully load protos saved by this
object. If no matching registration is available on load, the object
will be revived with a generic checkpointable type.
will be revived with a generic trackable type.

`min_consumer_version` and `bad_consumers` are a blunt tool, and using
them will generally break forward compatibility: previous versions of
TensorFlow will revive newly saved objects as opaque checkpointable
TensorFlow will revive newly saved objects as opaque trackable
objects rather than wrapped objects. When updating wrappers, prefer
saving new information but preserving compatibility with previous
wrapper versions. They are, however, useful for ensuring that
@ -83,7 +83,7 @@ class VersionedTypeRegistration(object):
bad_consumers=self._bad_consumers))

def from_proto(self, proto):
"""Recreate a checkpointable object from a SavedUserObject proto."""
"""Recreate a trackable object from a SavedUserObject proto."""
return self._object_factory(proto)

def should_load(self, proto):
@ -111,7 +111,7 @@ def register_revived_type(identifier, predicate, versions):
Args:
identifier: A unique string identifying this class of objects.
predicate: A Boolean predicate for this registration. Takes a
checkpointable object as an argument. If True, `type_registration` may be
trackable object as an argument. If True, `type_registration` may be
used to save and restore the object.
versions: A list of `VersionedTypeRegistration` objects.
"""
@ -138,7 +138,7 @@ def register_revived_type(identifier, predicate, versions):


def serialize(obj):
"""Create a SavedUserObject from a checkpointable object."""
"""Create a SavedUserObject from a trackable object."""
for identifier in _TYPE_IDENTIFIERS:
predicate, versions = _REVIVED_TYPE_REGISTRY[identifier]
if predicate(obj):
@ -148,15 +148,15 @@ def serialize(obj):


def deserialize(proto):
"""Create a checkpointable object from a SavedUserObject proto.
"""Create a trackable object from a SavedUserObject proto.

Args:
proto: A SavedUserObject to deserialize.

Returns:
A tuple of (checkpointable, assignment_fn) where assignment_fn has the same
A tuple of (trackable, assignment_fn) where assignment_fn has the same
signature as setattr and should be used to add dependencies to
`checkpointable` when they are available.
`trackable` when they are available.
"""
_, type_registrations = _REVIVED_TYPE_REGISTRY.get(
proto.identifier, (None, None))

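`register_revived_type` is what lets a wrapper class round-trip through a SavedModel instead of coming back as a generic trackable. A hedged sketch against the internal module touched above; this is not a public API, and the exact `VersionedTypeRegistration` keyword arguments are a best-effort reading of the docstrings in this file:

```python
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.tracking import tracking

class MyWrapper(tracking.AutoTrackable):
  pass

# Objects matching the predicate are serialized under "my_wrapper" and
# revived through object_factory; dependencies are attached via `setter`.
revived_types.register_revived_type(
    "my_wrapper",
    lambda obj: isinstance(obj, MyWrapper),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: MyWrapper(),
        version=1, min_producer_version=1, min_consumer_version=1,
        setter=setattr)])
```
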
@ -22,10 +22,10 @@ from tensorflow.core.framework import versions_pb2
from tensorflow.python.platform import test
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking


class CustomTestClass(tracking.AutoCheckpointable):
class CustomTestClass(tracking.AutoTrackable):

def __init__(self, version):
self.version = version
@ -56,7 +56,7 @@ revived_types.register_revived_type(
class RegistrationMatchingTest(test.TestCase):

def test_save_typecheck(self):
self.assertIs(revived_types.serialize(tracking.AutoCheckpointable()), None)
self.assertIs(revived_types.serialize(tracking.AutoTrackable()), None)

def test_load_identifier_not_found(self):
nothing_matches = revived_types.deserialize(

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Exports a SavedModel from a Checkpointable Python object."""
"""Exports a SavedModel from a Trackable Python object."""

from __future__ import absolute_import
from __future__ import division
@ -46,12 +46,12 @@ from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import signature_serialization
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils_impl
from tensorflow.python.training.checkpointable import base
from tensorflow.python.training.checkpointable import graph_view
from tensorflow.python.training.checkpointable import object_identity
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import object_identity
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export

@ -97,13 +97,13 @@ class _AugmentedGraphView(graph_view.ObjectGraphView):
for name, dep in super(_AugmentedGraphView, self).list_dependencies(obj):
used_names.add(name)
if name in extra_dependencies:
yield base.CheckpointableReference(name, extra_dependencies[name])
yield base.TrackableReference(name, extra_dependencies[name])
else:
yield base.CheckpointableReference(name, dep)
yield base.TrackableReference(name, dep)
for name, dep in extra_dependencies.items():
if name in used_names:
continue
yield base.CheckpointableReference(name, dep)
yield base.TrackableReference(name, dep)

def list_functions(self, obj):
obj_functions = self._functions.get(obj, None)
@ -114,12 +114,12 @@ class _AugmentedGraphView(graph_view.ObjectGraphView):


class _SaveableView(object):
"""Provides a frozen view over a checkpointable root.
"""Provides a frozen view over a trackable root.

This class helps creating a single stable view over an object to save. The
saving code should access properties and functions via this class and not via
the original object as there are cases where an object construct their
checkpointable attributes and functions dynamically per call and will yield
trackable attributes and functions dynamically per call and will yield
different objects if invoked more than once.

Changes to the graph, for example adding objects, must happen in
@ -130,9 +130,9 @@ class _SaveableView(object):

def __init__(self, checkpoint_view):
self.checkpoint_view = checkpoint_view
checkpointable_objects, node_ids, slot_variables = (
trackable_objects, node_ids, slot_variables = (
self.checkpoint_view.objects_ids_and_slot_variables())
self.nodes = checkpointable_objects
self.nodes = trackable_objects
self.node_ids = node_ids
self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
self.slot_variables = slot_variables
@ -544,7 +544,7 @@ def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions):
|
||||
|
||||
def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
|
||||
"""Save a SavedObjectGraph proto for `root`."""
|
||||
# SavedObjectGraph is similar to the CheckpointableObjectGraph proto in the
|
||||
# SavedObjectGraph is similar to the TrackableObjectGraph proto in the
|
||||
# checkpoint. It will eventually go into the SavedModel.
|
||||
proto = saved_object_graph_pb2.SavedObjectGraph()
|
||||
saveable_view.fill_object_graph_proto(proto)
|
||||
@ -603,7 +603,7 @@ def _write_object_proto(obj, proto, asset_file_def_index):
|
||||
@tf_export("saved_model.save", v1=["saved_model.experimental.save"])
|
||||
def save(obj, export_dir, signatures=None):
|
||||
# pylint: disable=line-too-long
|
||||
"""Exports the Checkpointable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).
|
||||
"""Exports the Trackable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).
|
||||
|
||||
Example usage:
|
||||
|
||||
@ -651,7 +651,7 @@ def save(obj, export_dir, signatures=None):
|
||||
`.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
|
||||
on an object with a custom `.signatures` attribute will raise an exception.
|
||||
|
||||
Since `tf.keras.Model` objects are also Checkpointable, this function can be
|
||||
Since `tf.keras.Model` objects are also Trackable, this function can be
|
||||
used to export Keras models. For example, exporting with a signature
|
||||
specified:
|
||||
|
||||
@ -737,7 +737,7 @@ def save(obj, export_dir, signatures=None):
|
||||
prior to the TensorFlow 2.0 release.
|
||||
|
||||
Args:
|
||||
obj: A checkpointable object to export.
|
||||
obj: A trackable object to export.
|
||||
export_dir: A directory in which to write the SavedModel.
|
||||
signatures: Optional, either a `tf.function` with an input signature
|
||||
specified or the result of `f.get_concrete_function` on a
|
||||
@ -750,7 +750,7 @@ def save(obj, export_dir, signatures=None):
|
||||
`tf.saved_model.signature_constants` module.
|
||||
|
||||
Raises:
|
||||
ValueError: If `obj` is not checkpointable.
|
||||
ValueError: If `obj` is not trackable.
|
||||
|
||||
@compatibility(eager)
|
||||
Not supported when graph building. From TensorFlow 1.x,
|
||||
@ -771,9 +771,9 @@ def save(obj, export_dir, signatures=None):
|
||||
"tf.enable_eager_execution() must run first when calling it from "
|
||||
"TensorFlow 1.x.")
|
||||
# pylint: enable=line-too-long
|
||||
if not isinstance(obj, base.Checkpointable):
|
||||
if not isinstance(obj, base.Trackable):
|
||||
raise ValueError(
|
||||
"Expected a Checkpointable object for export, got {}.".format(obj))
|
||||
"Expected a Trackable object for export, got {}.".format(obj))
|
||||
|
||||
checkpoint_graph_view = _AugmentedGraphView(obj)
|
||||
if signatures is None:
|
||||
@ -799,7 +799,7 @@ def save(obj, export_dir, signatures=None):
|
||||
# making a SavedModel proto and writing it directly.
|
||||
saved_model = saved_model_pb2.SavedModel()
|
||||
meta_graph_def = saved_model.meta_graphs.add()
|
||||
object_saver = util.CheckpointableSaver(checkpoint_graph_view)
|
||||
object_saver = util.TrackableSaver(checkpoint_graph_view)
|
||||
asset_info, exported_graph = _fill_meta_graph_def(
|
||||
meta_graph_def, saveable_view, signatures)
|
||||
saved_model.saved_model_schema_version = (
|
||||
|
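For reference, the renamed pieces compose as in the following minimal sketch. It is not part of the change; it mirrors test_method_save_signature in the test file below, uses the post-rename internal module paths from this diff, and the export directory is hypothetical.

# Sketch only: build an AutoTrackable, attach a tf.function, export it.
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.saved_model import save
from tensorflow.python.training.tracking import tracking

root = tracking.AutoTrackable()  # formerly tracking.AutoCheckpointable()
root.f = def_function.function(
    lambda x: 2. * x,
    input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
save.save(root, "/tmp/saved_model")  # save() accepts any base.Trackable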
@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for checkpointable object SavedModel save."""
"""Tests for trackable object SavedModel save."""

from __future__ import absolute_import
from __future__ import division
@ -41,8 +41,8 @@ from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import compat


@ -87,7 +87,7 @@ def _import_and_infer(
class SaveTest(test.TestCase):

  def test_method_save_signature(self):
    root = tracking.AutoCheckpointable()
    root = tracking.AutoTrackable()
    root.f = def_function.function(
        lambda x: 2. * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
@ -99,7 +99,7 @@ class SaveTest(test.TestCase):
        _import_and_infer(save_dir, {"x": 1.}))

  def test_method_save_concrete(self):
    root = tracking.AutoCheckpointable()
    root = tracking.AutoTrackable()
    root.f = def_function.function(
        lambda z: {"out": 2. * z})
    root.f(constant_op.constant(1.))
@ -115,7 +115,7 @@ class SaveTest(test.TestCase):
            save_dir, {"z": 1.}, signature_key="non_default_key"))

  def test_non_concrete_error(self):
    root = tracking.AutoCheckpointable()
    root = tracking.AutoTrackable()
    root.f = def_function.function(lambda x: 2. * x)
    root.f(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
@ -124,7 +124,7 @@ class SaveTest(test.TestCase):
      save.save(root, save_dir, root.f)

  def test_captures_unreachable_variable(self):
    root = tracking.AutoCheckpointable()
    root = tracking.AutoTrackable()
    unreachable_variable = variables.Variable([5.0, 2.0])
    root.reachable_variable = variables.Variable([1.0, 3.0])

@ -143,7 +143,7 @@ class SaveTest(test.TestCase):
      save.save(root, save_dir)

  def test_nested_inputs(self):
    root = tracking.AutoCheckpointable()
    root = tracking.AutoTrackable()
    root.f = def_function.function(
        lambda x: 2. * x[0],
        input_signature=([tensor_spec.TensorSpec(None, dtypes.float32),
@ -156,7 +156,7 @@ class SaveTest(test.TestCase):
      root.f.get_concrete_function()

  def test_nested_outputs(self):
    root = tracking.AutoCheckpointable()
    root = tracking.AutoTrackable()
    root.f = def_function.function(lambda x: (2. * x, (3. * x, 4. * x)))
    root.f(constant_op.constant(1.))
    to_save = root.f.get_concrete_function(constant_op.constant(1.))
@ -177,7 +177,7 @@ class SaveTest(test.TestCase):
      save.save(root, save_dir, to_save)

  def test_variable(self):
    root = tracking.AutoCheckpointable()
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(
@ -214,7 +214,7 @@ class SaveTest(test.TestCase):
        {"x": [[3., 4.]], "y": [2.]}))

  def test_single_function_default_signature(self):
    model = tracking.AutoCheckpointable()
    model = tracking.AutoTrackable()
    model.f = def_function.function(lambda: 3., input_signature=())
    model.f()
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
@ -223,7 +223,7 @@ class SaveTest(test.TestCase):
        _import_and_infer(save_dir, {}))

  def test_single_function_no_signature(self):
    model = tracking.AutoCheckpointable()
    model = tracking.AutoTrackable()
    model.f = def_function.function(lambda: 3.)
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(model, save_dir)
@ -322,7 +322,7 @@ class AssetTests(test.TestCase):
      f.write("alpha\nbeta\ngamma\n")

  def test_asset_path_returned(self):
    root = tracking.AutoCheckpointable()
    root = tracking.AutoTrackable()
    root.path = tracking.TrackableAsset(self._vocab_path)
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    root.get_asset = def_function.function(lambda: root.path.asset_path)
@ -362,7 +362,7 @@ class AssetTests(test.TestCase):
        _import_and_infer(second_dir, {"keys": ["gamma", "beta"]}))

  def test_unused_asset(self):
    root = tracking.AutoCheckpointable()
    root = tracking.AutoTrackable()
    root.f = def_function.function(
        lambda x: 2. * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
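The AssetTests hunks above also exercise tracking.TrackableAsset; a hedged sketch of that pattern (vocab path and export directory hypothetical):

# Sketch only: a file attached as an asset is copied into the SavedModel.
from tensorflow.python.eager import def_function
from tensorflow.python.saved_model import save
from tensorflow.python.training.tracking import tracking

root = tracking.AutoTrackable()
root.path = tracking.TrackableAsset("/tmp/vocab.txt")  # hypothetical vocab file
root.get_asset = def_function.function(lambda: root.path.asset_path)
save.save(root, "/tmp/asset_model",
          signatures=root.get_asset.get_concrete_function())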
@ -1,6 +1,6 @@
syntax = "proto3";

import "tensorflow/core/protobuf/checkpointable_object_graph.proto";
import "tensorflow/core/protobuf/trackable_object_graph.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
import "tensorflow/core/framework/versions.proto";
@ -14,9 +14,9 @@ package tensorflow;
// describes the directed graph of Python objects (or equivalent in other
// languages) that make up a model, with nodes[0] at the root.

// SavedObjectGraph shares some structure with CheckpointableObjectGraph, but
// SavedObjectGraph shares some structure with TrackableObjectGraph, but
// ObjectGraph belongs to the SavedModel and contains pointers to functions and
// type information, while CheckpointableObjectGraph lives in the checkpoint and
// type information, while TrackableObjectGraph lives in the checkpoint and
// contains pointers only to variable values.

// NOTE: This protocol buffer format is experimental and subject to change.
@ -38,10 +38,9 @@ message SavedObject {
  // graph.
  //
  // Note: only valid if kind == "object".
  repeated CheckpointableObjectGraph.CheckpointableObject.ObjectReference
      children = 1;
  repeated TrackableObjectGraph.TrackableObject.ObjectReference children = 1;

  // Removed when forking from CheckpointableObjectGraph.
  // Removed when forking from TrackableObjectGraph.
  reserved "attributes";
  reserved 2;

@ -50,7 +49,7 @@ message SavedObject {
  // depend on the others directly.
  //
  // Note: only valid if kind == "object".
  repeated CheckpointableObjectGraph.CheckpointableObject.SlotVariableReference
  repeated TrackableObjectGraph.TrackableObject.SlotVariableReference
      slot_variables = 3;

  oneof kind {
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.saved_model import revived_types
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.training.checkpointable import base
|
||||
from tensorflow.python.training.tracking import base
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
@ -170,7 +170,7 @@ def _normalize_outputs(outputs, function_name, signature_key):
|
||||
# saved if they contain a _SignatureMap. A ".signatures" attribute containing
|
||||
# any other type (e.g. a regular dict) will raise an exception asking the user
|
||||
# to first "del obj.signatures" if they want it overwritten.
|
||||
class _SignatureMap(collections.Mapping, base.Checkpointable):
|
||||
class _SignatureMap(collections.Mapping, base.Trackable):
|
||||
"""A collection of SavedModel signatures."""
|
||||
|
||||
def __init__(self):
|
||||
@ -205,7 +205,7 @@ revived_types.register_revived_type(
|
||||
"signature_map",
|
||||
lambda obj: isinstance(obj, _SignatureMap),
|
||||
versions=[revived_types.VersionedTypeRegistration(
|
||||
# Standard dependencies are enough to reconstruct the checkpointable
|
||||
# Standard dependencies are enough to reconstruct the trackable
|
||||
# items in dictionaries, so we don't need to save any extra information.
|
||||
object_factory=lambda proto: _SignatureMap(),
|
||||
version=1,
|
||||
|
@ -38,7 +38,7 @@ from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training import saver as saver_module
|
||||
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
|
||||
from tensorflow.python.training.checkpointable import util
|
||||
from tensorflow.python.training.tracking import util
|
||||
|
||||
|
||||
class LatestCheckpointWithRelativePaths(test.TestCase):
|
||||
|
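A hedged sketch of how the renamed object-based Checkpoint (now in tensorflow/python/training/tracking/util.py) composes with CheckpointManager from the module under test above (directory hypothetical):

# Sketch only: object-based saving plus checkpoint rotation.
import tensorflow as tf
from tensorflow.python.training import checkpoint_management

tf.enable_eager_execution()
ckpt = tf.train.Checkpoint(step=tf.Variable(0))
manager = checkpoint_management.CheckpointManager(
    ckpt, directory="/tmp/managed_ckpts", max_to_keep=3)
path = manager.save()  # e.g. ".../ckpt-1"; older checkpoints are rotated out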
@ -41,8 +41,8 @@ from tensorflow.python.training import queue_runner
from tensorflow.python.training import saver as training_saver
from tensorflow.python.training import session_manager as sm
from tensorflow.python.training import session_run_hook
from tensorflow.python.training.checkpointable import graph_view
from tensorflow.python.training.checkpointable import util as checkpointable_util
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import util as trackable_util
from tensorflow.python.util import function_utils
from tensorflow.python.util.tf_export import tf_export

@ -228,7 +228,7 @@ class Scaffold(object):
    if self._saver is None:
      self._saver = training_saver._get_saver_or_default()  # pylint: disable=protected-access
    # pylint: enable=g-long-lambda
    if isinstance(self._saver, checkpointable_util.Checkpoint):
    if isinstance(self._saver, trackable_util.Checkpoint):
      self._saver = training_saver.Saver(
          var_list=graph_view.ObjectGraphView(
              self._saver).frozen_saveable_objects(),
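The branch above is what lets a v1 Scaffold accept an object-based Checkpoint as its saver; a hedged sketch of that calling pattern:

# Sketch only: Scaffold converts the Checkpoint to a name-based Saver
# internally via ObjectGraphView(...).frozen_saveable_objects(), per the hunk
# above.
import tensorflow as tf

v = tf.Variable(1.0, name="v")
scaffold = tf.train.Scaffold(saver=tf.train.Checkpoint(v=v))
with tf.train.MonitoredTrainingSession(scaffold=scaffold) as sess:
  print(sess.run(v))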
@ -39,7 +39,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

@ -214,10 +214,10 @@ def _get_processor(v):

@tf_export(v1=["train.Optimizer"])
class Optimizer(
    # Optimizers inherit from CheckpointableBase rather than Checkpointable
    # Optimizers inherit from Trackable rather than AutoTrackable
    # since they do most of their dependency management themselves (slot
    # variables are special-cased, and non-slot variables are keyed to graphs).
    checkpointable.Checkpointable):
    trackable.Trackable):
  """Base class for optimizers.

  This class defines the API to add Ops to train a model. You never use this
@ -333,9 +333,9 @@ class Optimizer(
    #     ... }
    self._slots = {}
    self._non_slot_dict = {}
    # For implementing Checkpointable. Stores information about how to restore
    # For implementing Trackable. Stores information about how to restore
    # slot variables which have not yet been created
    # (checkpointable._CheckpointPosition objects).
    # (trackable._CheckpointPosition objects).
    #   {slot_name :
    #       {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
    #    ... }
@ -796,7 +796,7 @@ class Optimizer(
    key = (name, graph)
    v = self._non_slot_dict.get(key, None)
    if v is None:
      self._maybe_initialize_checkpointable()
      self._maybe_initialize_trackable()
      distribution_strategy = distribute_ctx.get_strategy()
      with distribution_strategy.extended.colocate_vars_with(colocate_with):
        if eager:
@ -809,19 +809,19 @@ class Optimizer(
              use_resource=resource_variable_ops.is_resource_variable(
                  colocate_with))
      # Restore this variable by name if necessary, but don't add a
      # Checkpointable dependency. Optimizers return the current graph's
      # Trackable dependency. Optimizers return the current graph's
      # non-slot variables from _checkpoint_dependencies explicitly rather
      # than unconditionally adding dependencies (since there may be multiple
      # non-slot variables with the same name in different graphs, trying to
      # save all of them would result in errors).
      self._handle_deferred_dependencies(name=name, checkpointable=v)
      self._handle_deferred_dependencies(name=name, trackable=v)
      self._non_slot_dict[key] = v

    return v

  @property
  def _checkpoint_dependencies(self):
    """From Checkpointable. Gather graph-specific non-slot variables to save."""
    """From Trackable. Gather graph-specific non-slot variables to save."""
    current_graph_non_slot_variables = []
    current_graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    for (name, _), variable_object in sorted(self._non_slot_dict.items(),
@ -829,13 +829,13 @@ class Optimizer(
                                             key=lambda item: item[0][0]):
      if variable_object._graph_key == current_graph_key:  # pylint: disable=protected-access
        current_graph_non_slot_variables.append(
            checkpointable.CheckpointableReference(
            trackable.TrackableReference(
                name=name, ref=variable_object))
    return (super(Optimizer, self)._checkpoint_dependencies
            + current_graph_non_slot_variables)

  def _lookup_dependency(self, name):
    """From Checkpointable. Find a non-slot variable in the current graph."""
    """From Trackable. Find a non-slot variable in the current graph."""
    unconditional = super(Optimizer, self)._lookup_dependency(name)
    if unconditional is not None:
      return unconditional
@ -1140,7 +1140,7 @@ class Optimizer(
    return named_slots[_var_key(var)]

  # --------------
  # For implementing the Checkpointable interface.
  # For implementing the Trackable interface.
  # --------------

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
@ -1171,8 +1171,8 @@ class Optimizer(
      slot variable needs to be restored).

    Args:
      slot_variable_position: A `checkpointable._CheckpointPosition` object
        indicating the slot variable `Checkpointable` object to be restored.
      slot_variable_position: A `trackable._CheckpointPosition` object
        indicating the slot variable `Trackable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
    """
@ -1190,7 +1190,7 @@ class Optimizer(
        # (aside from double initialization), and makes variable creator scopes
        # behave the same way they do when graph building.
        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
      initializer = checkpointable.CheckpointInitialValue(
      initializer = trackable.CheckpointInitialValue(
          checkpoint_position=slot_variable_position)
      slot_variable = self._get_or_make_slot(
          var=variable,
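Because Optimizer is itself a Trackable, with the conditional dependencies described above, its slot and non-slot variables ride along in object-based checkpoints; a hedged sketch (save prefix hypothetical):

# Sketch only: Adam's "m"/"v" slots and beta accumulators are saved through
# the Optimizer's own dependency management rather than by variable name.
import tensorflow as tf

tf.enable_eager_execution()
dense = tf.keras.layers.Dense(1)
dense(tf.ones([1, 1]))  # create the layer's variables
optimizer = tf.train.AdamOptimizer(0.001)
optimizer.minimize(lambda: tf.reduce_sum(dense(tf.ones([1, 1]))))
ckpt = tf.train.Checkpoint(model=dense, optimizer=optimizer)
path = ckpt.save("/tmp/opt_ckpt")  # includes slots for dense's kernel and bias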
@ -17,7 +17,7 @@
"""Save and restore variables.

Symbols in this file are deprecated. See replacements in
tensorflow/python/training/checkpointable and tensorflow/python/training/saving.
tensorflow/python/training/trackable and tensorflow/python/training/saving.
"""
from __future__ import absolute_import
from __future__ import division
@ -29,10 +29,9 @@ import time
import uuid

import numpy as np

from tensorflow.core.protobuf import checkpointable_object_graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.core.protobuf import trackable_object_graph_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.eager import context
@ -51,9 +50,9 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export

@ -1605,9 +1604,9 @@ def object_graph_key_mapping(checkpoint_path):
  """
  reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
  object_graph_string = reader.get_tensor(
      checkpointable.OBJECT_GRAPH_PROTO_KEY)
      trackable.OBJECT_GRAPH_PROTO_KEY)
  object_graph_proto = (
      checkpointable_object_graph_pb2.CheckpointableObjectGraph())
      trackable_object_graph_pb2.TrackableObjectGraph())
  object_graph_proto.ParseFromString(object_graph_string)
  names_to_keys = {}
  for node in object_graph_proto.nodes:
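Note that the checkpoint key itself is not renamed (OBJECT_GRAPH_PROTO_KEY stays "_CHECKPOINTABLE_OBJECT_GRAPH", as the base.py hunk further down shows), so existing checkpoints remain readable. A hedged sketch of what object_graph_key_mapping does with it (checkpoint path hypothetical):

# Sketch only: parse the object graph proto stored in an object-based checkpoint.
from tensorflow.core.protobuf import trackable_object_graph_pb2
from tensorflow.python import pywrap_tensorflow

reader = pywrap_tensorflow.NewCheckpointReader("/tmp/opt_ckpt-1")
graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
graph_proto.ParseFromString(reader.get_tensor("_CHECKPOINTABLE_OBJECT_GRAPH"))
for child in graph_proto.nodes[0].children:  # dependencies of the root object
  print(child.local_name, "->", child.node_id)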
@ -73,9 +73,9 @@ from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training import saver_test_utils
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base as checkpointable_base
from tensorflow.python.training.checkpointable import tracking as checkpointable_tracking
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.training.tracking import base as trackable_base
from tensorflow.python.training.tracking import tracking as trackable_tracking
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import compat


@ -2775,15 +2775,15 @@ class ScopedGraphTest(test.TestCase):
    self.assertEqual(2.0, self.evaluate(var_dict2["variable2:0"]))


class _OwnsAVariableSimple(checkpointable_base.Checkpointable):
  """A Checkpointable object which can be saved using a tf.train.Saver."""
class _OwnsAVariableSimple(trackable_base.Trackable):
  """A Trackable object which can be saved using a tf.train.Saver."""

  def __init__(self):
    self.non_dep_variable = variable_scope.get_variable(
        name="non_dep_variable", initializer=6., use_resource=True)

  def _gather_saveables_for_checkpoint(self):
    return {checkpointable_base.VARIABLE_VALUE_KEY: self.non_dep_variable}
    return {trackable_base.VARIABLE_VALUE_KEY: self.non_dep_variable}

  # The Saver sorts by name before parsing, so we need a name property.
  @property
@ -2808,8 +2808,8 @@ class _MirroringSaveable(
          self._mirrored_variable.assign(tensor))


class _OwnsMirroredVariables(checkpointable_base.Checkpointable):
  """A Checkpointable object which returns a more complex SaveableObject."""
class _OwnsMirroredVariables(trackable_base.Trackable):
  """A Trackable object which returns a more complex SaveableObject."""

  def __init__(self):
    self.non_dep_variable = variable_scope.get_variable(
@ -2823,7 +2823,7 @@ class _OwnsMirroredVariables(checkpointable_base.Checkpointable):
          primary_variable=self.non_dep_variable,
          mirrored_variable=self.mirrored,
          name=name)
    return {checkpointable_base.VARIABLE_VALUE_KEY: _saveable_factory}
    return {trackable_base.VARIABLE_VALUE_KEY: _saveable_factory}

  # The Saver sorts by name before parsing, so we need a name property.
  @property
@ -2831,11 +2831,11 @@ class _OwnsMirroredVariables(checkpointable_base.Checkpointable):
    return self.non_dep_variable.name


class NonLayerCheckpointable(checkpointable_tracking.AutoCheckpointable):
class NonLayerTrackable(trackable_tracking.AutoTrackable):

  def __init__(self):
    super(NonLayerCheckpointable, self).__init__()
    self.a_variable = checkpointable_utils.add_variable(
    super(NonLayerTrackable, self).__init__()
    self.a_variable = trackable_utils.add_variable(
        self, name="a_variable", shape=[])


@ -2846,19 +2846,19 @@ class MyModel(training.Model):
    super(MyModel, self).__init__()
    self._named_dense = core.Dense(1, use_bias=True)
    self._second = core.Dense(1, use_bias=False)
    # We can still track Checkpointables which aren't Layers.
    self._non_layer = NonLayerCheckpointable()
    # We can still track Trackables which aren't Layers.
    self._non_layer = NonLayerTrackable()

  def call(self, values):
    ret = self._second(self._named_dense(values))
    return ret


class CheckpointableCompatibilityTests(test.TestCase):
class TrackableCompatibilityTests(test.TestCase):

  # TODO(allenl): Track down python3 reference cycles in these tests.
  @test_util.run_in_graph_and_eager_modes
  def testNotSaveableButIsCheckpointable(self):
  def testNotSaveableButIsTrackable(self):
    v = _OwnsAVariableSimple()
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
@ -2923,13 +2923,13 @@ class CheckpointableCompatibilityTests(test.TestCase):
    model = MyModel()
    optimizer = adam.AdamOptimizer(0.001)
    optimizer_step = training_util.get_or_create_global_step()
    root_checkpointable = checkpointable_utils.Checkpoint(
    root_trackable = trackable_utils.Checkpoint(
        optimizer=optimizer, model=model, optimizer_step=optimizer_step)
    train_op = optimizer.minimize(
        functools.partial(model, input_value),
        global_step=optimizer_step)
    self.evaluate(checkpointable_utils.gather_initializers(
        root_checkpointable))
    self.evaluate(trackable_utils.gather_initializers(
        root_trackable))
    self.evaluate(train_op)
    # A regular variable, a slot variable, and a non-slot Optimizer variable
    # with known values to check when loading.
@ -2938,24 +2938,24 @@ class CheckpointableCompatibilityTests(test.TestCase):
        var=model._named_dense.bias, name="m").assign([2.]))
    beta1_power, _ = optimizer._get_beta_accumulators()
    self.evaluate(beta1_power.assign(3.))
    return root_checkpointable
    return root_trackable

  def _set_sentinels(self, root_checkpointable):
    self.evaluate(root_checkpointable.model._named_dense.bias.assign([101.]))
  def _set_sentinels(self, root_trackable):
    self.evaluate(root_trackable.model._named_dense.bias.assign([101.]))
    self.evaluate(
        root_checkpointable.optimizer.get_slot(
            var=root_checkpointable.model._named_dense.bias, name="m")
        root_trackable.optimizer.get_slot(
            var=root_trackable.model._named_dense.bias, name="m")
        .assign([102.]))
    beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
    beta1_power, _ = root_trackable.optimizer._get_beta_accumulators()
    self.evaluate(beta1_power.assign(103.))

  def _check_sentinels(self, root_checkpointable):
  def _check_sentinels(self, root_trackable):
    self.assertAllEqual(
        [1.], self.evaluate(root_checkpointable.model._named_dense.bias))
        [1.], self.evaluate(root_trackable.model._named_dense.bias))
    self.assertAllEqual([2.], self.evaluate(
        root_checkpointable.optimizer.get_slot(
            var=root_checkpointable.model._named_dense.bias, name="m")))
    beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
        root_trackable.optimizer.get_slot(
            var=root_trackable.model._named_dense.bias, name="m")))
    beta1_power, _ = root_trackable.optimizer._get_beta_accumulators()
    self.assertAllEqual(3., self.evaluate(beta1_power))

  def testVariableNotFoundErrorRaised(self):
@ -3012,13 +3012,13 @@ class CheckpointableCompatibilityTests(test.TestCase):
    save_graph = ops_lib.Graph()
    with save_graph.as_default(), self.session(graph=save_graph) as sess:
      root = self._initialized_model()
      object_saver = checkpointable_utils.Checkpoint(root=root)
      object_saver = trackable_utils.Checkpoint(root=root)
      save_path = object_saver.save(file_prefix=checkpoint_prefix)

      # An incompatible object-based checkpoint to check error messages
      var = resource_variable_ops.ResourceVariable(1., name="a")
      self.evaluate(var.initializer)
      second_saver = checkpointable_utils.Checkpoint(v=var)
      second_saver = trackable_utils.Checkpoint(v=var)
      second_path = second_saver.save(file_prefix=os.path.join(
          checkpoint_directory, "second"))

@ -3046,7 +3046,7 @@ class CheckpointableCompatibilityTests(test.TestCase):
    save_graph = ops_lib.Graph()
    with save_graph.as_default(), self.session(graph=save_graph):
      root = self._initialized_model()
      object_saver = checkpointable_utils.Checkpoint(root=root)
      object_saver = trackable_utils.Checkpoint(root=root)
      save_path = object_saver.save(file_prefix=checkpoint_prefix)

    with context.eager_mode():
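A hedged sketch condensing the pattern _OwnsAVariableSimple exercises above: a Trackable that exposes one variable under VARIABLE_VALUE_KEY and can therefore be handed straight to a name-based tf.train.Saver:

# Sketch only: mirrors _OwnsAVariableSimple (graph mode).
import tensorflow as tf
from tensorflow.python.training.tracking import base as trackable_base


class OwnsAVariable(trackable_base.Trackable):
  """A Trackable exposing a single variable to name-based savers."""

  def __init__(self):
    self.non_dep_variable = tf.get_variable(
        "non_dep_variable", initializer=6., use_resource=True)

  def _gather_saveables_for_checkpoint(self):
    return {trackable_base.VARIABLE_VALUE_KEY: self.non_dep_variable}

  @property
  def name(self):  # tf.train.Saver sorts saveables by name
    return self.non_dep_variable.name


saver = tf.train.Saver(var_list=[OwnsAVariable()])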
@ -49,7 +49,7 @@ py_library(
    deps = [
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:variables",
        "//tensorflow/python/training/checkpointable:base",
        "//tensorflow/python/training/tracking:base",
        "@six_archive//:six",
    ],
)
@ -26,8 +26,8 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.tracking import base as trackable


# Op names which identify variable reads which should be saved.
@ -137,7 +137,7 @@ def saveable_objects_for_op(op, name):
  if not isinstance(name, six.string_types):
    raise TypeError(
        "names_to_saveables must be a dict mapping string names to "
        "checkpointable operations. Name is not a string: %s" % name)
        "trackable operations. Name is not a string: %s" % name)
  if isinstance(op, saveable_object.SaveableObject):
    yield op
  elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
@ -165,11 +165,11 @@ def saveable_objects_for_op(op, name):
        yield ResourceVariableSaveable(
            variable, variable._save_slice_info.spec, name)
    # pylint: enable=protected-access
  elif isinstance(op, checkpointable.Checkpointable) and not isinstance(
  elif isinstance(op, trackable.Trackable) and not isinstance(
      op, variables.Variable):
    # pylint: disable=protected-access
    for attr, factory in op._gather_saveables_for_checkpoint().items():
      if attr == checkpointable.VARIABLE_VALUE_KEY:
      if attr == trackable.VARIABLE_VALUE_KEY:
        # Keep original name for classes masquerading as variables.
        full_name = name
      else:
@ -250,13 +250,13 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
        names_to_saveables[name].append(var)
      else:
        names_to_saveables[name] = [var]
    elif (isinstance(var, checkpointable.Checkpointable)
    elif (isinstance(var, trackable.Trackable)
          and not isinstance(var, variables.Variable)):
      checkpointable_saveables = [
      trackable_saveables = [
          (factory() if callable(factory) else factory)
          for factory in var._gather_saveables_for_checkpoint().values()]
      names_to_saveables.update(
          op_list_to_dict(checkpointable_saveables))
          op_list_to_dict(trackable_saveables))
    else:
      # Variables (reference and resource) have an _in_graph_mode property
      # indicating whether they were created in a graph building context. We
@ -326,7 +326,7 @@ def validate_and_slice_inputs(names_to_saveables):

  Raises:
    TypeError: If any of the keys are not strings or any of the
      values are not one of Tensor or Variable or a checkpointable operation.
      values are not one of Tensor or Variable or a trackable operation.
    ValueError: If the same operation is given in more than one value
      (this also applies to slices of SlicedVariables).
  """
|
||||
|
||||
|
||||
# A key indicating a variable's value in an object's checkpointed Tensors
|
||||
# (Checkpointable._gather_saveables_for_checkpoint). If this is the only key and
|
||||
# (Trackable._gather_saveables_for_checkpoint). If this is the only key and
|
||||
# the object has no dependencies, then its value may be restored on object
|
||||
# creation (avoiding double assignment when executing eagerly).
|
||||
VARIABLE_VALUE_KEY = "VARIABLE_VALUE"
|
||||
OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"
|
||||
|
||||
CheckpointableReference = collections.namedtuple(
|
||||
"CheckpointableReference",
|
||||
TrackableReference = collections.namedtuple(
|
||||
"TrackableReference",
|
||||
[
|
||||
# The local name for this dependency.
|
||||
"name",
|
||||
# The Checkpointable object being referenced.
|
||||
# The Trackable object being referenced.
|
||||
"ref"
|
||||
])
|
||||
|
||||
@ -195,26 +195,26 @@ class CheckpointPosition(object):
|
||||
|
||||
Args:
|
||||
checkpoint: A _CheckpointRestoreCoordinator object.
|
||||
proto_id: The index of this object in CheckpointableObjectGraph.nodes.
|
||||
proto_id: The index of this object in TrackableObjectGraph.nodes.
|
||||
"""
|
||||
self._checkpoint = checkpoint
|
||||
self._proto_id = proto_id
|
||||
|
||||
def restore(self, checkpointable):
|
||||
"""Restore this value into `checkpointable`."""
|
||||
def restore(self, trackable):
|
||||
"""Restore this value into `trackable`."""
|
||||
with ops.init_scope():
|
||||
if self.bind_object(checkpointable):
|
||||
if self.bind_object(trackable):
|
||||
# This object's correspondence with a checkpointed object is new, so
|
||||
# process deferred restorations for it and its dependencies.
|
||||
restore_ops = checkpointable._restore_from_checkpoint_position(self) # pylint: disable=protected-access
|
||||
restore_ops = trackable._restore_from_checkpoint_position(self) # pylint: disable=protected-access
|
||||
if restore_ops:
|
||||
self._checkpoint.new_restore_ops(restore_ops)
|
||||
|
||||
def bind_object(self, checkpointable):
|
||||
def bind_object(self, trackable):
|
||||
"""Set a checkpoint<->object correspondence and process slot variables.
|
||||
|
||||
Args:
|
||||
checkpointable: The object to record a correspondence for.
|
||||
trackable: The object to record a correspondence for.
|
||||
Returns:
|
||||
True if this is a new assignment, False if this object has already been
|
||||
mapped to a checkpointed `Object` proto.
|
||||
@ -222,13 +222,13 @@ class CheckpointPosition(object):
|
||||
AssertionError: If another object is already bound to the `Object` proto.
|
||||
"""
|
||||
checkpoint = self.checkpoint
|
||||
checkpoint.all_python_objects.add(checkpointable)
|
||||
checkpoint.all_python_objects.add(trackable)
|
||||
current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
|
||||
if current_assignment is None:
|
||||
checkpoint.object_by_proto_id[self._proto_id] = checkpointable
|
||||
checkpoint.object_by_proto_id[self._proto_id] = trackable
|
||||
for deferred_slot_restoration in (
|
||||
checkpoint.deferred_slot_restorations.pop(self._proto_id, ())):
|
||||
checkpointable._create_or_restore_slot_variable( # pylint: disable=protected-access
|
||||
trackable._create_or_restore_slot_variable( # pylint: disable=protected-access
|
||||
slot_variable_position=CheckpointPosition(
|
||||
checkpoint=checkpoint,
|
||||
proto_id=deferred_slot_restoration.slot_variable_id),
|
||||
@ -244,7 +244,7 @@ class CheckpointPosition(object):
|
||||
checkpoint.deferred_slot_restorations.setdefault(
|
||||
slot_restoration.optimizer_id, []).append(
|
||||
_DeferredSlotVariableRestoration(
|
||||
original_variable=checkpointable,
|
||||
original_variable=trackable,
|
||||
slot_variable_id=slot_restoration.slot_variable_id,
|
||||
slot_name=slot_restoration.slot_name))
|
||||
else:
|
||||
@ -252,7 +252,7 @@ class CheckpointPosition(object):
|
||||
slot_variable_position=CheckpointPosition(
|
||||
checkpoint=checkpoint,
|
||||
proto_id=slot_restoration.slot_variable_id),
|
||||
variable=checkpointable,
|
||||
variable=trackable,
|
||||
slot_name=slot_restoration.slot_name)
|
||||
return True # New assignment
|
||||
else:
|
||||
@ -260,14 +260,14 @@ class CheckpointPosition(object):
|
||||
# we don't need to do anything besides check that the mapping is
|
||||
# consistent (if the dependency DAG is not a tree then there are
|
||||
# multiple paths to the same object).
|
||||
if current_assignment is not checkpointable:
|
||||
if current_assignment is not trackable:
|
||||
logging.warning(
|
||||
("Inconsistent references when loading the checkpoint into this "
|
||||
"object graph. Either the Checkpointable object references in the "
|
||||
"object graph. Either the Trackable object references in the "
|
||||
"Python program have changed in an incompatible way, or the "
|
||||
"checkpoint was generated in an incompatible program.\n\nTwo "
|
||||
"checkpoint references resolved to different objects (%s and %s).")
|
||||
% (current_assignment, checkpointable))
|
||||
% (current_assignment, trackable))
|
||||
return False # Not a new assignment
|
||||
|
||||
def is_simple_variable(self):
|
||||
@ -306,7 +306,7 @@ class CheckpointPosition(object):
|
||||
|
||||
def _gather_ops_or_named_saveables(self):
|
||||
"""Looks up or creates SaveableObjects which don't have cached ops."""
|
||||
saveables = self.checkpointable._gather_saveables_for_checkpoint() # pylint: disable=protected-access
|
||||
saveables = self.trackable._gather_saveables_for_checkpoint() # pylint: disable=protected-access
|
||||
# Name saveables based on the name this object had when it was checkpointed.
|
||||
named_saveables = {}
|
||||
python_saveables = []
|
||||
@ -334,7 +334,7 @@ class CheckpointPosition(object):
|
||||
# attribute, we can re-use it to avoid re-creating some ops when graph
|
||||
# building.
|
||||
saveable_list = saveables_cache.get(
|
||||
self.checkpointable, {}).get(serialized_tensor.name, (None,))
|
||||
self.trackable, {}).get(serialized_tensor.name, (None,))
|
||||
if len(saveable_list) == 1:
|
||||
# Almost every attribute will have exactly one SaveableObject.
|
||||
saveable, = saveable_list
|
||||
@ -348,7 +348,7 @@ class CheckpointPosition(object):
|
||||
# the SaveableObject.
|
||||
if serialized_tensor.checkpoint_key not in saveable.name:
|
||||
saveable = None
|
||||
del saveables_cache[self.checkpointable]
|
||||
del saveables_cache[self.trackable]
|
||||
break
|
||||
if saveable is None:
|
||||
# If there was no cached SaveableObject, we should check if the Python
|
||||
@ -361,7 +361,7 @@ class CheckpointPosition(object):
|
||||
# checkpoint was loaded.
|
||||
if not serialized_tensor.optional_restore:
|
||||
self._checkpoint.unused_attributes.setdefault(
|
||||
self.checkpointable, []).append(serialized_tensor.name)
|
||||
self.trackable, []).append(serialized_tensor.name)
|
||||
continue
|
||||
if callable(saveable_factory):
|
||||
saveable = saveable_factory(name=serialized_tensor.checkpoint_key)
|
||||
@ -369,7 +369,7 @@ class CheckpointPosition(object):
|
||||
saveable = saveable_factory
|
||||
if saveables_cache is not None:
|
||||
saveables_cache.setdefault(
|
||||
self.checkpointable, {})[serialized_tensor.name] = [saveable]
|
||||
self.trackable, {})[serialized_tensor.name] = [saveable]
|
||||
if isinstance(saveable, PythonStateSaveable):
|
||||
python_saveables.append(saveable)
|
||||
else:
|
||||
@ -379,7 +379,7 @@ class CheckpointPosition(object):
|
||||
def restore_ops(self):
|
||||
"""Create or fetch restore ops for this object's attributes.
|
||||
|
||||
Requires that the `Checkpointable` Python object has been bound to an object
|
||||
Requires that the `Trackable` Python object has been bound to an object
|
||||
ID in the checkpoint.
|
||||
|
||||
Returns:
|
||||
@ -398,7 +398,7 @@ class CheckpointPosition(object):
|
||||
return self._checkpoint
|
||||
|
||||
@property
|
||||
def checkpointable(self):
|
||||
def trackable(self):
|
||||
return self._checkpoint.object_by_proto_id[self._proto_id]
|
||||
|
||||
@property
|
||||
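CheckpointPosition above is the machinery behind deferred ("restore-on-create") loading: restore() binds checkpointed values to objects as they are attached to the trackable object graph. A hedged sketch of the user-visible behavior (save prefix hypothetical):

# Sketch only: a restore() issued before the dependency exists is queued, then
# applied by bind_object/_handle_deferred_dependencies at attachment time.
import tensorflow as tf

tf.enable_eager_execution()
root = tf.train.Checkpoint(v=tf.Variable(3.0))
path = root.save("/tmp/deferred")

fresh = tf.train.Checkpoint()
status = fresh.restore(path)  # nothing to bind yet; restoration is deferred
fresh.v = tf.Variable(0.0)    # binding happens here; 3.0 is restored
assert fresh.v.numpy() == 3.0
status.assert_consumed()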
@ -436,11 +436,11 @@ _SlotVariableRestoration = collections.namedtuple(
def no_automatic_dependency_tracking(method):
  """Disables automatic dependency tracking on attribute assignment.

  Use to decorate any method of a Checkpointable object. Attribute assignment in
  Use to decorate any method of a Trackable object. Attribute assignment in
  that method will not add dependencies (also respected in Model). Harmless if
  used in a class which does not do automatic dependency tracking (which means
  it's safe to use in base classes which may have subclasses which also inherit
  from Checkpointable).
  from Trackable).

  Args:
    method: The method to decorate.
@ -461,37 +461,37 @@ def no_automatic_dependency_tracking(method):
      target=method, decorator_func=_method_wrapper)

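A hedged sketch of the decorator above in use (an internal API; the module paths are the post-rename ones introduced by this CL, and the list_objects check is an assumption about how untracked attributes surface):

# Sketch only: assignments inside the decorated method skip dependency tracking.
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util


class Holder(tracking.AutoTrackable):

  @base.no_automatic_dependency_tracking
  def set_scratch(self, value):
    self.scratch = value  # not recorded as a checkpoint dependency


holder = Holder()
holder.set_scratch(tracking.AutoTrackable())
assert len(util.list_objects(holder)) == 1  # only the holder itself is tracked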
class Checkpointable(object):
  """Base class for `Checkpointable` objects without automatic dependencies.
class Trackable(object):
  """Base class for `Trackable` objects without automatic dependencies.

  This class has no __setattr__ override for performance reasons. Dependencies
  must be added explicitly. Unless attribute assignment is performance-critical,
  use `AutoCheckpointable` instead. Use `Checkpointable` for `isinstance`
  use `AutoTrackable` instead. Use `Trackable` for `isinstance`
  checks.
  """

  # Checkpointable does not do automatic dependency tracking, but uses the
  # Trackable does not do automatic dependency tracking, but uses the
  # no_automatic_dependency_tracking decorator so it can avoid adding
  # dependencies if a subclass is Checkpointable / inherits from Model (both of
  # dependencies if a subclass is Trackable / inherits from Model (both of
  # which have __setattr__ overrides).
  @no_automatic_dependency_tracking
  def _maybe_initialize_checkpointable(self):
  def _maybe_initialize_trackable(self):
    """Initialize dependency management.

    Not __init__, since most objects will forget to call it.
    """
    if hasattr(self, "_unconditional_checkpoint_dependencies"):
      # __init__ already called. This check means that we don't need
      # Checkpointable.__init__() in the constructor of every TensorFlow object.
      # Trackable.__init__() in the constructor of every TensorFlow object.
      return
    # A list of CheckpointableReference objects. Some classes implementing
    # `Checkpointable`, notably `Optimizer`s, may override the
    # A list of TrackableReference objects. Some classes implementing
    # `Trackable`, notably `Optimizer`s, may override the
    # _checkpoint_dependencies property with conditional dependencies
    # (e.g. based on the current graph when saving).
    self._unconditional_checkpoint_dependencies = []
    # Maps names -> Checkpointable objects
    # Maps names -> Trackable objects
    self._unconditional_dependency_names = {}
    # Restorations for other Checkpointable objects on which this object may
    # Restorations for other Trackable objects on which this object may
    # eventually depend. Maps local name -> CheckpointPosition list. Optimizers
    # tack on conditional dependencies, and so need separate management of
    # deferred dependencies too.
@ -530,8 +530,8 @@ class Checkpointable(object):
    May be overridden to include conditional dependencies.

    Returns:
      A list of `CheckpointableReference` objects indicating named
      `Checkpointable` dependencies which should be saved along with this
      A list of `TrackableReference` objects indicating named
      `Trackable` dependencies which should be saved along with this
      object.
    """
    return self._unconditional_checkpoint_dependencies
@ -540,7 +540,7 @@ class Checkpointable(object):
  def _deferred_dependencies(self):
    """A dictionary with deferred dependencies.

    Stores restorations for other Checkpointable objects on which this object
    Stores restorations for other Trackable objects on which this object
    may eventually depend. May be overridden by sub-classes (e.g. Optimizers use
    conditional dependencies based the current graph, and so need separate
    management of deferred dependencies too).
@ -559,7 +559,7 @@ class Checkpointable(object):
    Args:
      name: The local name of the dependency.
    Returns:
      A `Checkpointable` object, or `None` if no dependency by this name was
      A `Trackable` object, or `None` if no dependency by this name was
      found.
    """
    return self._unconditional_dependency_names.get(name, None)
@ -568,9 +568,9 @@ class Checkpointable(object):
      self, name, shape=None, dtype=dtypes.float32,
      initializer=None, getter=None, overwrite=False,
      **kwargs_for_getter):
    """Restore-on-create for a variable be saved with this `Checkpointable`.
    """Restore-on-create for a variable be saved with this `Trackable`.

    If the user has requested that this object or another `Checkpointable` which
    If the user has requested that this object or another `Trackable` which
    depends on this object be restored from a checkpoint (deferred loading
    before variable object creation), `initializer` may be ignored and the value
    from the checkpoint used instead.
@ -592,7 +592,7 @@ class Checkpointable(object):
    Raises:
      ValueError: If the variable name is not unique.
    """
    self._maybe_initialize_checkpointable()
    self._maybe_initialize_trackable()
    with ops.init_scope():
      if context.executing_eagerly():
        # If this is a variable with a single Tensor stored in the checkpoint,
@ -608,11 +608,11 @@ class Checkpointable(object):
          isinstance(initializer, CheckpointInitialValue)
          and (initializer.restore_uid
               > checkpoint_initializer.restore_uid))):
        # If multiple Checkpointable objects are "creating" the same variable
        # If multiple Trackable objects are "creating" the same variable
        # via the magic of custom getters, the one with the highest restore UID
        # (the one called last) has to make the final initializer. If another
        # custom getter interrupts this process by overwriting the initializer,
        # then we'll catch that when we call _track_checkpointable. So this is
        # then we'll catch that when we call _track_trackable. So this is
        # "best effort" to set the initializer with the highest restore UID.
        initializer = checkpoint_initializer
        shape = None
@ -624,12 +624,12 @@ class Checkpointable(object):
    # assign again. It will add this variable to our dependencies, and if there
    # is a non-trivial restoration queued, it will handle that. This also
    # handles slot variables.
    if not overwrite or isinstance(new_variable, Checkpointable):
      return self._track_checkpointable(new_variable, name=name,
                                        overwrite=overwrite)
    if not overwrite or isinstance(new_variable, Trackable):
      return self._track_trackable(new_variable, name=name,
                                   overwrite=overwrite)
    else:
      # TODO(allenl): Some variable types are not yet supported. Remove this
      # fallback once all get_variable() return types are Checkpointable.
      # fallback once all get_variable() return types are Trackable.
      return new_variable

  def _preload_simple_restoration(self, name, shape):
@ -668,46 +668,46 @@ class Checkpointable(object):
    return CheckpointInitialValue(
        checkpoint_position=checkpoint_position, shape=shape)

  def _track_checkpointable(self, checkpointable, name, overwrite=False):
    """Declare a dependency on another `Checkpointable` object.
  def _track_trackable(self, trackable, name, overwrite=False):
    """Declare a dependency on another `Trackable` object.

    Indicates that checkpoints for this object should include variables from
    `checkpointable`.
    `trackable`.

    Variables in a checkpoint are mapped to `Checkpointable`s based on the names
    Variables in a checkpoint are mapped to `Trackable`s based on the names
    provided when the checkpoint was written. To avoid breaking existing
    checkpoints when modifying a class, neither variable names nor dependency
    names (the names passed to `_track_checkpointable`) may change.
    names (the names passed to `_track_trackable`) may change.

    Args:
      checkpointable: A `Checkpointable` which this object depends on.
      name: A local name for `checkpointable`, used for loading checkpoints into
      trackable: A `Trackable` which this object depends on.
      name: A local name for `trackable`, used for loading checkpoints into
        the correct objects.
      overwrite: Boolean, whether silently replacing dependencies is OK. Used
        for __setattr__, where throwing an error on attribute reassignment would
        be inappropriate.

    Returns:
      `checkpointable`, for convenience when declaring a dependency and
      `trackable`, for convenience when declaring a dependency and
      assigning to a member variable in one statement.

    Raises:
      TypeError: If `checkpointable` does not inherit from `Checkpointable`.
      TypeError: If `trackable` does not inherit from `Trackable`.
      ValueError: If another object is already tracked by this name.
    """
    self._maybe_initialize_checkpointable()
    if not isinstance(checkpointable, Checkpointable):
    self._maybe_initialize_trackable()
    if not isinstance(trackable, Trackable):
      raise TypeError(
          ("Checkpointable._track_checkpointable() passed type %s, not a "
           "Checkpointable.") % (type(checkpointable),))
    new_reference = CheckpointableReference(name=name, ref=checkpointable)
          ("Trackable._track_trackable() passed type %s, not a "
           "Trackable.") % (type(trackable),))
    new_reference = TrackableReference(name=name, ref=trackable)
    current_object = self._lookup_dependency(name)
    if (current_object is not None
        and current_object is not checkpointable):
        and current_object is not trackable):
      if not overwrite:
        raise ValueError(
            ("Called Checkpointable._track_checkpointable() with name='%s', "
             "but a Checkpointable with this name is already declared as a "
            ("Called Trackable._track_trackable() with name='%s', "
             "but a Trackable with this name is already declared as a "
             "dependency. Names must be unique (or overwrite=True).") % (name,))
      # This is a weird thing to do, but we're not going to stop people from
      # using __setattr__.
@ -718,20 +718,20 @@ class Checkpointable(object):
    elif current_object is None:
      self._unconditional_checkpoint_dependencies.append(new_reference)
      self._handle_deferred_dependencies(
          name=name, checkpointable=checkpointable)
    self._unconditional_dependency_names[name] = checkpointable
    return checkpointable
          name=name, trackable=trackable)
    self._unconditional_dependency_names[name] = trackable
    return trackable
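A hedged sketch of _track_trackable above; it matches testOverwrite in the test file at the end of this diff:

# Sketch only: explicit dependency declaration on the plain (non-Auto) base class.
from tensorflow.python.training.tracking import base

root = base.Trackable()
leaf = base.Trackable()
root._track_trackable(leaf, name="leaf")  # checkpoints of root now include leaf
(name, dep), = root._checkpoint_dependencies
assert name == "leaf" and dep is leaf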
def _handle_deferred_dependencies(self, name, checkpointable):
|
||||
"""Pop and load any deferred checkpoint restores into `checkpointable`.
|
||||
def _handle_deferred_dependencies(self, name, trackable):
|
||||
"""Pop and load any deferred checkpoint restores into `trackable`.
|
||||
|
||||
This method does not add a new dependency on `checkpointable`, but it does
|
||||
This method does not add a new dependency on `trackable`, but it does
|
||||
check if any outstanding/deferred dependencies have been queued waiting for
|
||||
this dependency to be added (matched based on `name`). If so,
|
||||
`checkpointable` and its dependencies are restored. The restorations are
|
||||
`trackable` and its dependencies are restored. The restorations are
|
||||
considered fulfilled and so are deleted.
|
||||
|
||||
`_track_checkpointable` is more appropriate for adding a
|
||||
`_track_trackable` is more appropriate for adding a
|
||||
normal/unconditional dependency, and includes handling for deferred
|
||||
restorations. This method allows objects such as `Optimizer` to use the same
|
||||
restoration logic while managing conditional dependencies themselves, by
|
||||
@ -741,25 +741,25 @@ class Checkpointable(object):
|
||||
|
||||
Args:
|
||||
name: The name of the dependency within this object (`self`), used to
|
||||
match `checkpointable` with values saved in a checkpoint.
|
||||
checkpointable: The Checkpointable object to restore (inheriting from
|
||||
`Checkpointable`).
|
||||
match `trackable` with values saved in a checkpoint.
|
||||
trackable: The Trackable object to restore (inheriting from
|
||||
`Trackable`).
|
||||
"""
|
||||
self._maybe_initialize_checkpointable()
|
||||
checkpointable._maybe_initialize_checkpointable() # pylint: disable=protected-access
|
||||
self._maybe_initialize_trackable()
|
||||
trackable._maybe_initialize_trackable() # pylint: disable=protected-access
|
||||
deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
|
||||
for checkpoint_position in sorted(
|
||||
deferred_dependencies_list,
|
||||
key=lambda restore: restore.checkpoint.restore_uid,
|
||||
reverse=True):
|
||||
checkpoint_position.restore(checkpointable)
|
||||
checkpoint_position.restore(trackable)
|
||||
|
||||
# Pass on any name-based restores queued in this object.
|
||||
for name_based_restore in sorted(
|
||||
self._name_based_restores,
|
||||
key=lambda checkpoint: checkpoint.restore_uid,
|
||||
reverse=True):
|
||||
checkpointable._name_based_attribute_restore(name_based_restore) # pylint: disable=protected-access
|
||||
trackable._name_based_attribute_restore(name_based_restore) # pylint: disable=protected-access
|
||||
|
||||
  def _restore_from_checkpoint_position(self, checkpoint_position):
    """Restore this object and its dependencies (may be deferred)."""
@ -772,7 +772,7 @@ class Checkpointable(object):
    while visit_queue:
      current_position = visit_queue.popleft()
      restore_ops.extend(nest.flatten(
          current_position.checkpointable  # pylint: disable=protected-access
          current_position.trackable  # pylint: disable=protected-access
          ._single_restoration_from_checkpoint_position(
              checkpoint_position=current_position,
              visit_queue=visit_queue)))
@ -781,7 +781,7 @@ class Checkpointable(object):
  def _single_restoration_from_checkpoint_position(
      self, checkpoint_position, visit_queue):
    """Restore this object, and either queue its dependencies or defer them."""
    self._maybe_initialize_checkpointable()
    self._maybe_initialize_trackable()
    checkpoint = checkpoint_position.checkpoint
    # If the UID of this restore is lower than our current update UID, we don't
    # need to actually restore the object. However, we should pass the
@ -802,7 +802,7 @@ class Checkpointable(object):
        self._deferred_dependencies.setdefault(child.local_name, []).append(
            child_position)
      else:
        if child_position.bind_object(checkpointable=local_object):
        if child_position.bind_object(trackable=local_object):
          # This object's correspondence is new, so dependencies need to be
          # visited. Delay doing it so that we get a breadth-first dependency
          # resolution order (shallowest paths first). The caller is responsible
@ -818,7 +818,7 @@ class Checkpointable(object):
    or variables easily converted to `SaveableObject`s (as in `tf.train.Saver`'s
    `var_list` constructor argument).

    `SaveableObjects` have a name set, which Checkpointable needs to generate
    `SaveableObjects` have a name set, which Trackable needs to generate
    itself. So rather than returning `SaveableObjects` directly, this method
    should return a dictionary of callables which take `name` arguments and
    return `SaveableObjects` with that name.
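The callables-taking-`name` contract described here can be seen in the test change further down; a condensed sketch of an override, using the `PythonStringStateSaveable` helper that those tests also use (`obj` and the state string are illustrative, and `base` is the module imported in the surrounding diff):

```python
import functools

# Each key names an attribute in the checkpoint; each value is a callable
# that accepts the full checkpoint name and returns a SaveableObject.
attributes = {
    "foo_attr": functools.partial(
        base.PythonStringStateSaveable,
        state_callback=lambda: "serialized state")}
obj._gather_saveables_for_checkpoint = lambda: attributes
```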
@ -861,10 +861,10 @@ class Checkpointable(object):
            state_callback=_state_callback)}

  def _list_functions_for_serialization(self):
    """Lists the functions of this checkpointable to serialize.
    """Lists the functions of this trackable to serialize.

    Internal sub-classes can override this with specific logic. E.g.
    `AutoCheckpointable` provides an implementation that returns the `attr`
    `AutoTrackable` provides an implementation that returns the `attr`
    that return functions.

    Returns:
@ -22,29 +22,29 @@ import os
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import base
from tensorflow.python.training.checkpointable import util
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import util


class InterfaceTests(test.TestCase):

  def testOverwrite(self):
    root = base.Checkpointable()
    leaf = base.Checkpointable()
    root._track_checkpointable(leaf, name="leaf")
    root = base.Trackable()
    leaf = base.Trackable()
    root._track_trackable(leaf, name="leaf")
    (current_name, current_dependency), = root._checkpoint_dependencies
    self.assertIs(leaf, current_dependency)
    self.assertEqual("leaf", current_name)
    duplicate_name_dep = base.Checkpointable()
    duplicate_name_dep = base.Trackable()
    with self.assertRaises(ValueError):
      root._track_checkpointable(duplicate_name_dep, name="leaf")
    root._track_checkpointable(duplicate_name_dep, name="leaf", overwrite=True)
      root._track_trackable(duplicate_name_dep, name="leaf")
    root._track_trackable(duplicate_name_dep, name="leaf", overwrite=True)
    (current_name, current_dependency), = root._checkpoint_dependencies
    self.assertIs(duplicate_name_dep, current_dependency)
    self.assertEqual("leaf", current_name)

  def testAddVariableOverwrite(self):
    root = base.Checkpointable()
    root = base.Trackable()
    a = root._add_variable_with_custom_getter(
        name="v", shape=[], getter=variable_scope.get_variable)
    self.assertEqual([root, a], util.list_objects(root))
@ -61,15 +61,15 @@ class InterfaceTests(test.TestCase):
        getter=variable_scope.get_variable)

  def testAssertConsumedWithUnusedPythonState(self):
    has_config = base.Checkpointable()
    has_config = base.Trackable()
    has_config.get_config = lambda: {}
    saved = util.Checkpoint(obj=has_config)
    save_path = saved.save(os.path.join(self.get_temp_dir(), "ckpt"))
    restored = util.Checkpoint(obj=base.Checkpointable())
    restored = util.Checkpoint(obj=base.Trackable())
    restored.restore(save_path).assert_consumed()

  def testAssertConsumedFailsWithUsedPythonState(self):
    has_config = base.Checkpointable()
    has_config = base.Trackable()
    attributes = {
        "foo_attr": functools.partial(
            base.PythonStringStateSaveable,
@ -78,7 +78,7 @@ class InterfaceTests(test.TestCase):
    has_config._gather_saveables_for_checkpoint = lambda: attributes
    saved = util.Checkpoint(obj=has_config)
    save_path = saved.save(os.path.join(self.get_temp_dir(), "ckpt"))
    restored = util.Checkpoint(obj=base.Checkpointable())
    restored = util.Checkpoint(obj=base.Trackable())
    status = restored.restore(save_path)
    with self.assertRaisesRegexp(AssertionError, "foo_attr"):
      status.assert_consumed()
@ -1,4 +1,4 @@
"""Checkpointable data structures."""
"""Trackable data structures."""
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -28,16 +28,16 @@ from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.checkpointable import base
from tensorflow.python.training.checkpointable import layer_utils
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import layer_utils


class NoDependency(object):
  """Allows attribute assignment to `Checkpointable` objects with no dependency.
  """Allows attribute assignment to `Trackable` objects with no dependency.

  Example usage:
  ```python
  obj = Checkpointable()
  obj = Trackable()
  obj.has_dependency = tf.Variable(0., name="dep")
  obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
  assert obj.no_dependency.name == "nodep:0"
@ -61,8 +61,8 @@ def _wrap_or_unwrap(value):
  """Wraps basic data structures, unwraps NoDependency objects."""
  if isinstance(value, NoDependency):
    return value.value
  if isinstance(value, base.Checkpointable):
    return value  # Skip conversion for already checkpointable objects.
  if isinstance(value, base.Trackable):
    return value  # Skip conversion for already trackable objects.
  elif isinstance(value, dict):
    return _DictWrapper(value)
  elif isinstance(value, list):
@ -77,19 +77,19 @@ def _wrap_or_unwrap(value):
  # come up with names. Dictionaries should look like lists.


def sticky_attribute_assignment(checkpointable, name, value):
def sticky_attribute_assignment(trackable, name, value):
  """Adds dependencies, generally called from __setattr__.

  This behavior is shared between Checkpointable and Model.
  This behavior is shared between Trackable and Model.

  Respects NoDependency indicators, but otherwise makes checkpointable objects
  Respects NoDependency indicators, but otherwise makes trackable objects
  out of common data structures and tracks objects by their attribute names.

  Args:
    checkpointable: The object to add dependencies to (generally the one having
    trackable: The object to add dependencies to (generally the one having
      an attribute assigned).
    name: The attribute name being assigned.
    value: The value being assigned. Not necessarily a checkpointable object.
    value: The value being assigned. Not necessarily a trackable object.

  Returns:
    The value which should be stored in the attribute (unwrapped from a
@ -102,18 +102,18 @@ def sticky_attribute_assignment(checkpointable, name, value):
  value = _wrap_or_unwrap(value)
  if not add_dependency:
    return value
  if isinstance(value, base.Checkpointable):
    checkpointable._track_checkpointable(  # pylint: disable=protected-access
  if isinstance(value, base.Trackable):
    trackable._track_trackable(  # pylint: disable=protected-access
        value, name=name,
        # Allow the user to switch the Checkpointable which is tracked by this
        # Allow the user to switch the Trackable which is tracked by this
        # name, since assigning a new variable to an attribute has
        # historically been fine (e.g. Adam did this).
        overwrite=True)
  return value

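In practice this hook means containers are converted on assignment while `NoDependency` opts out. A small sketch, assuming a TF 1.x build where `tf.contrib.checkpoint` is available (`Block` and its attribute names are illustrative):

```python
import tensorflow as tf


class Block(tf.keras.Model):

  def __init__(self):
    super(Block, self).__init__()
    # Routed through sticky_attribute_assignment: wrapped in _ListWrapper
    # and tracked under the attribute name "layer_list".
    self.layer_list = [tf.keras.layers.Dense(3)]
    # NoDependency is unwrapped to the raw list; nothing is tracked.
    self.plain = tf.contrib.checkpoint.NoDependency([1, 2, 3])


block = Block()
assert isinstance(block.layer_list, list)  # a tracked list subclass
assert type(block.plain) is list           # untouched built-in list
```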
class CheckpointableDataStructure(base.Checkpointable):
  """Base class for data structures which contain checkpointable objects."""
class TrackableDataStructure(base.Trackable):
  """Base class for data structures which contain trackable objects."""

  def __init__(self):
    self.trainable = True
@ -122,14 +122,14 @@ class CheckpointableDataStructure(base.Checkpointable):
  def _track_value(self, value, name):
    """Add a dependency on `value`."""
    value = sticky_attribute_assignment(
        checkpointable=self, value=value, name=name)
        trackable=self, value=value, name=name)
    if isinstance(value, variables.Variable):
      self._extra_variables.append(value)
    if not isinstance(value, base.Checkpointable):
    if not isinstance(value, base.Trackable):
      raise ValueError(
          ("Only checkpointable objects (such as Layers or Optimizers) may be "
          ("Only trackable objects (such as Layers or Optimizers) may be "
           "stored in a List object. Got %s, which does not inherit from "
           "Checkpointable.") % (value,))
           "Trackable.") % (value,))
    if hasattr(value, "_use_resource_variables"):
      # In subclassed models, legacy layers (tf.layers) must always use
      # resource variables.
@ -138,7 +138,7 @@ class CheckpointableDataStructure(base.Checkpointable):

  @property
  def _values(self):
    """An iterable/sequence which may contain checkpointable objects."""
    """An iterable/sequence which may contain trackable objects."""
    raise NotImplementedError("Abstract method")

  @property
@ -148,7 +148,7 @@ class CheckpointableDataStructure(base.Checkpointable):
    # they're wrapping if out of sync.
    collected = []
    for obj in self._values:
      if (isinstance(obj, CheckpointableDataStructure)
      if (isinstance(obj, TrackableDataStructure)
          or layer_utils.is_layer(obj)
          or layer_utils.has_weights(obj)):
        collected.append(obj)
@ -215,19 +215,19 @@ class CheckpointableDataStructure(base.Checkpointable):
    return id(self)

  def __eq__(self, other):
    # Similar to Tensors, checkpointable data structures use object-identity
    # Similar to Tensors, trackable data structures use object-identity
    # equality to support set/dict membership.
    return self is other


class List(CheckpointableDataStructure, collections.Sequence):
  """An append-only sequence type which is checkpointable.
class List(TrackableDataStructure, collections.Sequence):
  """An append-only sequence type which is trackable.

  Maintains checkpoint dependencies on its contents (which must also be
  checkpointable), and forwards any `Layer` metadata such as updates and losses.
  trackable), and forwards any `Layer` metadata such as updates and losses.

  Note that `List` is purely a container. It lets a `tf.keras.Model` or
  other checkpointable object know about its contents, but does not call any
  other trackable object know about its contents, but does not call any
  `Layer` instances which are added to it. To indicate a sequence of `Layer`
  instances which should be called sequentially, use `tf.keras.Sequential`.

@ -248,7 +248,7 @@ class List(CheckpointableDataStructure, collections.Sequence):
      return aggregation
  ```

  This kind of wrapping is necessary because `Checkpointable` objects do not
  This kind of wrapping is necessary because `Trackable` objects do not
  (yet) deeply inspect regular Python data structures, so for example assigning
  a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a
  checkpoint dependency and does not add the `Layer` instance's weights to its
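The docstring example is elided by the hunk above; a compact version of the same pattern, assuming the `tf.contrib.checkpoint.List` export (class name and layer sizes are illustrative):

```python
import tensorflow as tf


class HasList(tf.keras.Model):

  def __init__(self):
    super(HasList, self).__init__()
    # List tracks its contents for checkpointing but never calls them.
    self.layer_list = tf.contrib.checkpoint.List([tf.keras.layers.Dense(3)])
    self.layer_list.append(tf.keras.layers.Dense(4))

  def call(self, x):
    for layer in self.layer_list:  # Calling order stays the model's job.
      x = layer(x)
    return x
```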
@ -284,12 +284,12 @@ class List(CheckpointableDataStructure, collections.Sequence):
    return self

  def append(self, value):
    """Add a new checkpointable value."""
    """Add a new trackable value."""
    value = self._track_value(value, self._name_element(len(self._storage)))
    self._storage.append(value)

  def extend(self, values):
    """Add a sequence of checkpointable values."""
    """Add a sequence of trackable values."""
    for value in values:
      self.append(value)

@ -350,7 +350,7 @@ class _ListWrapper(List, collections.MutableSequence,
  occupied, meaning both elements get the same names at different times) and
  refuses to save.

  On assignment to an attribute of a Model or Checkpointable object, Python
  On assignment to an attribute of a Model or Trackable object, Python
  lists are replaced with _ListWrapper. Wrapping a list in a
  `tf.contrib.checkpoint.NoDependency` object prevents this.
  """
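To make the refuse-to-save behavior concrete, a hedged sketch of how it surfaces (layer sizes and the commented path are illustrative):

```python
import tensorflow as tf

model = tf.keras.Model()
model.blocks = [tf.keras.layers.Dense(3)]   # auto-wrapped in _ListWrapper
model.blocks[0] = tf.keras.layers.Dense(4)  # non-append mutation

# model.save_weights("/tmp/w.ckpt") would now raise the
# "Unable to save the object ... A list element was replaced" ValueError;
# append-only use keeps the wrapper saveable.
```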
@ -410,7 +410,7 @@ class _ListWrapper(List, collections.MutableSequence,
    if self._non_append_mutation:
      raise ValueError(
          ("Unable to save the object %s (a list wrapper constructed to track "
           "checkpointable TensorFlow objects). A list element was replaced "
           "trackable TensorFlow objects). A list element was replaced "
           "(__setitem__, __setslice__), deleted (__delitem__, __delslice__), "
           "or moved (sort). In order to support restoration on object "
           "creation, tracking is exclusively for append-only data structures."
@ -420,7 +420,7 @@ class _ListWrapper(List, collections.MutableSequence,
    if self._external_modification:
      raise ValueError(
          ("Unable to save the object %s (a list wrapper constructed to track "
           "checkpointable TensorFlow objects). The wrapped list was modified "
           "trackable TensorFlow objects). The wrapped list was modified "
           "outside the wrapper (its final value was %s, its value when a "
           "checkpoint dependency was added was %s), which breaks restoration "
           "on object creation.\n\nIf you don't need this list checkpointed, "
@ -449,7 +449,7 @@ class _ListWrapper(List, collections.MutableSequence,
        value_now = self._storage[i] if i < len_now else None
        value_before = storage_copy[i] if i < len_before else None

        if isinstance(value_before, base.Checkpointable):
        if isinstance(value_before, base.Trackable):
          self._non_append_mutation = True

        if value_now is not None and value_now != value_before:
@ -457,20 +457,20 @@ class _ListWrapper(List, collections.MutableSequence,
                                               self._name_element(i))

    else:
      if isinstance(self._storage[key], base.Checkpointable):
      if isinstance(self._storage[key], base.Trackable):
        self._non_append_mutation = True
      self._storage[key] = self._track_value(value, self._name_element(key))

    self._update_snapshot()

  def append(self, value):
    """Add a new checkpointable value."""
    """Add a new trackable value."""
    self._check_external_modification()
    super(_ListWrapper, self).append(value)
    self._update_snapshot()

  def extend(self, values):
    """Add a sequence of checkpointable values."""
    """Add a sequence of trackable values."""
    self._check_external_modification()
    super(_ListWrapper, self).extend(values)
    self._update_snapshot()
@ -514,14 +514,14 @@ class _ListWrapper(List, collections.MutableSequence,
    del self._storage[slice(i, j)]

  def _track_value(self, value, name):
    """Allows storage of non-checkpointable objects."""
    """Allows storage of non-trackable objects."""
    try:
      value = super(_ListWrapper, self)._track_value(value=value, name=name)
    except ValueError:
      # Even if this value isn't checkpointable, we need to make sure
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      value = sticky_attribute_assignment(
          checkpointable=self, value=value, name=name)
          trackable=self, value=value, name=name)
    return value

  def __repr__(self):
@ -534,11 +534,11 @@ class _ListWrapper(List, collections.MutableSequence,
    }


class Mapping(CheckpointableDataStructure, collections.Mapping):
  """An append-only checkpointable mapping data structure with string keys.
class Mapping(TrackableDataStructure, collections.Mapping):
  """An append-only trackable mapping data structure with string keys.

  Maintains checkpoint dependencies on its contents (which must also be
  checkpointable), named based on its keys.
  trackable), named based on its keys.

  Note that once a key has been added, it may not be deleted or replaced. If
  names may not be unique, see `tf.contrib.checkpoint.UniqueNameTracker`.
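Since `Mapping` is append-only, re-binding a key appears to be rejected at assignment time rather than at save time; a sketch assuming the `tf.contrib.checkpoint.Mapping` export (key names are illustrative):

```python
import tensorflow as tf

m = tf.contrib.checkpoint.Mapping()
m["encoder"] = tf.keras.layers.Dense(3)  # first binding of a key is fine

try:
  m["encoder"] = tf.keras.layers.Dense(4)  # replacing a tracked key
except ValueError:
  pass  # append-only: keys may not be re-bound to new objects
```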
@ -615,7 +615,7 @@ class Mapping(CheckpointableDataStructure, collections.Mapping):
# patching all of the "wrapped" dict's methods instead of creating a wrapper
# object is an option, but not a very attractive one (replacing methods without
# creating reference cycles is difficult, and then dicts would need to be
# special cased everywhere as being checkpointable).
# special cased everywhere as being trackable).
class _DictWrapper(Mapping, collections.MutableMapping):
  """Wraps built-in dicts to support restore-on-create for variables.

@ -671,7 +671,7 @@ class _DictWrapper(Mapping, collections.MutableMapping):
      raise ValueError(
          "Unable to save the object %s (a dictionary wrapper constructed "
          "automatically on attribute assignment). The wrapped dictionary "
          "contains a non-string key which maps to a checkpointable object or "
          "contains a non-string key which maps to a trackable object or "
          "mutable data structure.\n\nIf you don't need this dictionary "
          "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency "
          "object; it will be automatically un-wrapped and subsequently "
@ -680,7 +680,7 @@ class _DictWrapper(Mapping, collections.MutableMapping):
      raise ValueError(
          "Unable to save the object %s (a dictionary wrapper constructed "
          "automatically on attribute assignment). A key mapping to a "
          "checkpointable object was overwritten or deleted, which would "
          "trackable object was overwritten or deleted, which would "
          "cause problems for restoration.\n\nIf you don't need this "
          "dictionary checkpointed, wrap it in a "
          "tf.contrib.checkpoint.NoDependency object; it will be automatically "
@ -721,7 +721,7 @@ class _DictWrapper(Mapping, collections.MutableMapping):
    self._last_wrapped_dict_snapshot = dict(self)

  def _track_value(self, value, name):
    """Allows storage of non-checkpointable objects."""
    """Allows storage of non-trackable objects."""
    if isinstance(name, six.string_types):
      string_key = True
    else:
@ -731,15 +731,15 @@ class _DictWrapper(Mapping, collections.MutableMapping):
      no_dependency = isinstance(value, NoDependency)
      value = super(_DictWrapper, self)._track_value(value=value, name=name)
      if not (string_key or no_dependency):
        # A non-string key maps to a checkpointable value. This data structure
        # A non-string key maps to a trackable value. This data structure
        # is not saveable.
        self._non_string_key = True
      return value
    except ValueError:
      # Even if this value isn't checkpointable, we need to make sure
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      return sticky_attribute_assignment(
          checkpointable=self, value=value, name=name)
          trackable=self, value=value, name=name)

  def _name_element(self, key):
    """Don't throw errors for non-string keys."""
@ -758,19 +758,19 @@ class _DictWrapper(Mapping, collections.MutableMapping):
    else:
      value = _wrap_or_unwrap(value)
      existing_dependency = None
      if not no_dep and isinstance(value, base.Checkpointable):
      if not no_dep and isinstance(value, base.Trackable):
        # Non-string keys are OK as long as we have no reason to add a
        # dependency on the value (either because the value is not
        # checkpointable, or because it was wrapped in a NoDependency object).
        # trackable, or because it was wrapped in a NoDependency object).
        self._non_string_key = True
    current_value = self._storage.setdefault(key, value)
    if current_value is not value:
      if ((not no_dep and isinstance(value, base.Checkpointable))
      if ((not no_dep and isinstance(value, base.Trackable))
          # We don't want to just check that the existing object is
          # checkpointable, since it may have been wrapped in a NoDependency
          # trackable, since it may have been wrapped in a NoDependency
          # object.
          or existing_dependency is not None):
        # A checkpointable object was replaced under the same key; this means
        # A trackable object was replaced under the same key; this means
        # that restoring would be error-prone, so we'll throw an exception on
        # save.
        self._non_append_mutation = True
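These rules reduce to: non-string keys and key replacement are tolerated only while the affected values are untracked. A sketch mirroring the tests further down, assuming the `tf.contrib.checkpoint` exports (`obj` and the keys are illustrative):

```python
import tensorflow as tf

obj = tf.contrib.checkpoint.Checkpointable()
obj.d = {}                                          # wrapped in _DictWrapper
obj.d["layer"] = tf.keras.layers.Dense(3)           # string key: tracked
obj.d[1] = tf.contrib.checkpoint.NoDependency([3])  # non-string, untracked: OK

# A non-string key mapping to a *tracked* value (e.g. obj.d[2] = a layer)
# would mark the wrapper unsaveable, per the error message above.
```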
@ -781,8 +781,8 @@ class _DictWrapper(Mapping, collections.MutableMapping):
  def __delitem__(self, key):
    self._check_external_modification()
    existing_value = self[key]
    if isinstance(existing_value, base.Checkpointable):
      # Deleting tracked checkpointable values means restoring is problematic,
    if isinstance(existing_value, base.Trackable):
      # Deleting tracked trackable values means restoring is problematic,
      # so we'll throw an exception on save.
      self._non_append_mutation = True
    del self._storage[key]
@ -812,10 +812,10 @@ def _is_function(x):
  return isinstance(x, (def_function.Function, defun.ConcreteFunction))

revived_types.register_revived_type(
    "checkpointable_dict_wrapper",
    "trackable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[revived_types.VersionedTypeRegistration(
        # Standard dependencies are enough to reconstruct the checkpointable
        # Standard dependencies are enough to reconstruct the trackable
        # items in dictionaries, so we don't need to save any extra information.
        object_factory=lambda proto: _DictWrapper({}),
        version=1,
@ -832,7 +832,7 @@ def _set_list_item(list_object, index_string, value):


revived_types.register_revived_type(
    "checkpointable_list_wrapper",
    "trackable_list_wrapper",
    lambda obj: isinstance(obj, _ListWrapper),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: _ListWrapper([]),
@ -34,9 +34,9 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util


class HasList(training.Model):
@ -145,12 +145,12 @@ class ListTests(test.TestCase):
    model.l2.append(second_layer)
    self.assertEqual([first_layer, second_layer], model.layers)

  def testNotCheckpointable(self):
    class NotCheckpointable(object):
  def testNotTrackable(self):
    class NotTrackable(object):
      pass

    with self.assertRaises(ValueError):
      data_structures.List([NotCheckpointable()])
      data_structures.List([NotTrackable()])

  def testCallNotImplemented(self):
    with self.assertRaisesRegexp(TypeError, "not callable"):
@ -287,8 +287,8 @@ class ListWrapperTest(test.TestCase):
  def testListWrapperBasic(self):
    # _ListWrapper, unlike List, compares like the built-in list type (since it
    # is used to automatically replace lists).
    a = tracking.AutoCheckpointable()
    b = tracking.AutoCheckpointable()
    a = tracking.AutoTrackable()
    b = tracking.AutoTrackable()
    self.assertEqual([a, a],
                     [a, a])
    self.assertEqual(data_structures._ListWrapper([a, a]),
@ -321,7 +321,7 @@ class ListWrapperTest(test.TestCase):
    self.assertEqual([a, a], [a] + data_structures._ListWrapper([a]))
    self.assertIsInstance(data_structures._ListWrapper([a]), list)

  def testAcceptsNonCheckpointableContent(self):
  def testAcceptsNonTrackableContent(self):
    l = data_structures._ListWrapper([1, 2, 3])
    self.assertEqual(l, [1, 2, 3])

@ -360,14 +360,14 @@ class ListWrapperTest(test.TestCase):
    self.assertEqual(l, [1, 2, 4])
    self.assertUnableToSave(l, "Unable to save .*__delslice__")

  def testSetSlice_canSaveForNonCheckpointableItems(self):
  def testSetSlice_canSaveForNonTrackableItems(self):
    l = data_structures._ListWrapper([1, 2, 3, 4])
    l[:] = 2, 8, 9, 0
    self.assertEqual(l, [2, 8, 9, 0])
    l._maybe_initialize_checkpointable()  # pylint: disable=protected-access
    l._maybe_initialize_trackable()  # pylint: disable=protected-access
    self.assertEqual(len(l._checkpoint_dependencies), 0)  # pylint: disable=protected-access

  def testSetSlice_cannotSaveIfCheckpointableModified(self):
  def testSetSlice_cannotSaveIfTrackableModified(self):
    v1 = resource_variable_ops.ResourceVariable(1.)
    v2 = resource_variable_ops.ResourceVariable(1.)
    l = data_structures._ListWrapper([1, 2, v1, v2])
@ -391,12 +391,12 @@ class ListWrapperTest(test.TestCase):
    self.assertEqual(l, [1, 2, 3, 4])
    # Regardless of being a no-op for the input list, we still refuse to save.
    # This is intentional since otherwise we would end up with a hard to debug
    # case for users (e.g. sometimes sort on a ListWrapper is checkpointable and
    # case for users (e.g. sometimes sort on a ListWrapper is trackable and
    # other times it is not).
    self.assertUnableToSave(l, "Unable to save .*sort")

  def assertUnableToSave(self, l, msg):
    l._maybe_initialize_checkpointable()  # pylint: disable=protected-access
    l._maybe_initialize_trackable()  # pylint: disable=protected-access
    with self.assertRaisesRegexp(ValueError, msg):
      return l._checkpoint_dependencies  # pylint: disable=protected-access

@ -466,7 +466,7 @@ class MappingTests(test.TestCase):

  def testLayerCollectionWithExternalMutation(self):
    d = {}
    root = tracking.AutoCheckpointable()
    root = tracking.AutoTrackable()
    root.wrapper = d
    self.assertEqual([], root.wrapper.layers)
    self.assertEqual([], root.wrapper.trainable_weights)
@ -484,7 +484,7 @@ class MappingTests(test.TestCase):
    self.assertEqual(2, len(has_mappings))
    self.assertNotIn(data_structures.Mapping(), has_mappings)
    # In contrast to Mapping, dict wrappers are not hashable
    a = tracking.AutoCheckpointable()
    a = tracking.AutoTrackable()
    a.d = {}
    self.assertEqual({}, a.d)
    self.assertFalse({} != a.d)  # pylint: disable=g-explicit-bool-comparison
@ -493,7 +493,7 @@ class MappingTests(test.TestCase):
      set([a.d])

  def testDictWrapperBadKeys(self):
    a = tracking.AutoCheckpointable()
    a = tracking.AutoTrackable()
    a.d = {}
    a.d[1] = data_structures.List()
    model = training.Model()
@ -503,7 +503,7 @@ class MappingTests(test.TestCase):
      model.save_weights(save_path)

  def testDictWrapperNoDependency(self):
    a = tracking.AutoCheckpointable()
    a = tracking.AutoTrackable()
    a.d = data_structures.NoDependency({})
    a.d[1] = [3]
    self.assertEqual([a], util.list_objects(a))
@ -513,8 +513,8 @@ class MappingTests(test.TestCase):
    model.save_weights(save_path)
    model.load_weights(save_path)

  def testNonStringKeyNotCheckpointableValue(self):
    a = tracking.AutoCheckpointable()
  def testNonStringKeyNotTrackableValue(self):
    a = tracking.AutoTrackable()
    a.d = {}
    a.d["a"] = [3]
    a.d[1] = data_structures.NoDependency([3])
@ -525,18 +525,18 @@ class MappingTests(test.TestCase):
    model.save_weights(save_path)
    model.load_weights(save_path)

  def testNonAppendNotCheckpointable(self):
  def testNonAppendNotTrackable(self):
    # Non-append mutations (deleting or overwriting values) are OK when the
    # values aren't tracked.
    a = tracking.AutoCheckpointable()
    a = tracking.AutoTrackable()
    a.d = {}
    a.d["a"] = [3]
    a.d[1] = 3
    a.d[1] = 2
    self.assertEqual(2, a.d[1])
    del a.d[1]
    a.d[2] = data_structures.NoDependency(tracking.AutoCheckpointable())
    second = tracking.AutoCheckpointable()
    a.d[2] = data_structures.NoDependency(tracking.AutoTrackable())
    second = tracking.AutoTrackable()
    a.d[2] = data_structures.NoDependency(second)
    self.assertIs(second, a.d[2])
    self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
@ -598,7 +598,7 @@ class MappingTests(test.TestCase):
    self.assertEqual({1: 3}, new_dict)

  def testListShallowCopy(self):
    root = tracking.AutoCheckpointable()
    root = tracking.AutoTrackable()
    orig_list = [[1.]]
    root.a = orig_list
    copied = copy.copy(root.a)
@ -615,7 +615,7 @@ class MappingTests(test.TestCase):
      util.list_objects(copy.copy(root.a))

  def testListDeepCopy(self):
    root = tracking.AutoCheckpointable()
    root = tracking.AutoTrackable()
    orig_list = [[1.]]
    root.a = orig_list
    copied = copy.deepcopy(root.a)
@ -632,7 +632,7 @@ class MappingTests(test.TestCase):
      util.list_objects(copy.deepcopy(root.a))

  def testDictShallowCopy(self):
    root = tracking.AutoCheckpointable()
    root = tracking.AutoTrackable()
    orig_dict = {"a": [1.]}
    root.a = orig_dict
    copied = copy.copy(root.a)
@ -649,7 +649,7 @@ class MappingTests(test.TestCase):
      util.list_objects(copy.copy(root.a))

  def testDictDeepCopy(self):
    root = tracking.AutoCheckpointable()
    root = tracking.AutoTrackable()
    orig_dict = {"a": [1.]}
    root.a = orig_dict
    copied = copy.deepcopy(root.a)
@ -665,9 +665,9 @@ class MappingTests(test.TestCase):
    with self.assertRaises(ValueError):
      util.list_objects(copy.deepcopy(root.a))

  def testShallowCopyCheckpointable(self):
    original = tracking.AutoCheckpointable()
    original_sub = tracking.AutoCheckpointable()
  def testShallowCopyTrackable(self):
    original = tracking.AutoTrackable()
    original_sub = tracking.AutoTrackable()
    original.a = [[1.]]
    original.b = {"a": original_sub}
    shallow_copied = copy.copy(original)
@ -679,16 +679,16 @@ class MappingTests(test.TestCase):
    self.assertIn(shallow_copied.b, shallow_deps)
    self.assertIn(shallow_copied.b["a"], shallow_deps)

  def testDeepCopyCheckpointable(self):
    original = tracking.AutoCheckpointable()
    original_sub = tracking.AutoCheckpointable()
  def testDeepCopyTrackable(self):
    original = tracking.AutoTrackable()
    original_sub = tracking.AutoTrackable()
    original.a = [[1.]]
    original.b = {"a": original_sub}
    deep_copied = copy.deepcopy(original)
    self.assertIsNot(original, deep_copied)
    self.assertIsNot(original_sub, deep_copied.b["a"])
    self.assertEqual([[1.]], deep_copied.a)
    self.assertIsInstance(deep_copied.b["a"], tracking.AutoCheckpointable)
    self.assertIsInstance(deep_copied.b["a"], tracking.AutoTrackable)
    deps = util.list_objects(deep_copied)
    self.assertIn(deep_copied.a, deps)
    self.assertIn(deep_copied.b, deps)