Rename Checkpointable -> Trackable and AutoCheckpointable -> AutoTrackable

No API changes in this CL. Just more refactoring for a future API change.

PiperOrigin-RevId: 234242335
This commit is contained in:
Allen Lavoie 2019-02-15 17:23:48 -08:00 committed by TensorFlower Gardener
parent 6655a2e6ea
commit bd36b48c55
504 changed files with 1695 additions and 1696 deletions

View File

@ -33,7 +33,7 @@ from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensem
from tensorflow.python.framework import ops
from tensorflow.python.ops import resources
from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking
ops.NotDifferentiable("TreeEnsembleVariable")
ops.NotDifferentiable("TreeEnsembleSerialize")

View File

@ -33,7 +33,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import resources
from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking
# Pattern to remove all non alpha numeric from a string.
_PATTERN = re.compile(r"[\W_]+")

View File

@ -26,7 +26,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import resources
from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking
# Pattern to remove all non alpha numeric from a string.
_PATTERN = re.compile(r"[\W_]+")

View File

@ -27,7 +27,7 @@ Managing dependencies:
@@NoDependency
@@split_dependency
Checkpointable data structures:
Trackable data structures:
@@List
@@Mapping
@@UniqueNameTracker
@ -49,17 +49,16 @@ from tensorflow.contrib.checkpoint.python.python_state import NumpyState
from tensorflow.contrib.checkpoint.python.python_state import PythonStateWrapper
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
from tensorflow.core.protobuf.trackable_object_graph_pb2 import TrackableObjectGraph as CheckpointableObjectGraph
from tensorflow.python.training.checkpoint_management import CheckpointManager
from tensorflow.python.training.checkpointable.base import Checkpointable as CheckpointableBase
from tensorflow.python.training.checkpointable.data_structures import List
from tensorflow.python.training.checkpointable.data_structures import Mapping
from tensorflow.python.training.checkpointable.data_structures import NoDependency
from tensorflow.python.training.checkpointable.tracking import AutoCheckpointable as Checkpointable
from tensorflow.python.training.checkpointable.util import capture_dependencies
from tensorflow.python.training.checkpointable.util import list_objects
from tensorflow.python.training.checkpointable.util import object_metadata
from tensorflow.python.training.tracking.base import Trackable as CheckpointableBase
from tensorflow.python.training.tracking.data_structures import List
from tensorflow.python.training.tracking.data_structures import Mapping
from tensorflow.python.training.tracking.data_structures import NoDependency
from tensorflow.python.training.tracking.tracking import AutoTrackable as Checkpointable
from tensorflow.python.training.tracking.util import capture_dependencies
from tensorflow.python.training.tracking.util import list_objects
from tensorflow.python.training.tracking.util import object_metadata
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(module_name=__name__)

View File

@ -12,7 +12,7 @@ py_library(
":python_state",
":split_dependency",
":visualize",
"//tensorflow/python/training/checkpointable:data_structures",
"//tensorflow/python/training/tracking:data_structures",
],
)
@ -22,8 +22,8 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/checkpointable:data_structures",
"//tensorflow/python/training/tracking:base",
"//tensorflow/python/training/tracking:data_structures",
],
)
@ -36,8 +36,8 @@ tf_py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/checkpointable:util",
"//tensorflow/python/training/tracking:base",
"//tensorflow/python/training/tracking:util",
],
)
@ -47,7 +47,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/tracking:base",
"//third_party/py/numpy",
"@six_archive//:six",
],
@ -64,7 +64,7 @@ tf_py_test(
"//tensorflow/python:session",
"//tensorflow/python:variables",
"//tensorflow/python/eager:test",
"//tensorflow/python/training/checkpointable:util",
"//tensorflow/python/training/tracking:util",
],
)
@ -76,7 +76,7 @@ py_library(
deps = [
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:training",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/tracking:base",
],
)
@ -89,8 +89,8 @@ tf_py_test(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python/eager:test",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/checkpointable:util",
"//tensorflow/python/training/tracking:base",
"//tensorflow/python/training/tracking:util",
],
)
@ -101,8 +101,8 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/checkpointable:util",
"//tensorflow/python/training/tracking:base",
"//tensorflow/python/training/tracking:util",
],
)
@ -118,7 +118,7 @@ tf_py_test(
"//tensorflow/python/eager:test",
"//tensorflow/python/keras:engine",
"//tensorflow/python/keras:layers",
"//tensorflow/python/training/checkpointable:util",
"//tensorflow/python/training/tracking:util",
],
tags = ["nooss"], # b/124472244
)

View File

@ -1,4 +1,4 @@
"""Checkpointable data structures."""
"""Trackable data structures."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -17,12 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.training.checkpointable import base as checkpointable_lib
from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.tracking import base as trackable_lib
from tensorflow.python.training.tracking import data_structures
class UniqueNameTracker(data_structures.CheckpointableDataStructure):
"""Adds dependencies on checkpointable objects with name hints.
class UniqueNameTracker(data_structures.TrackableDataStructure):
"""Adds dependencies on trackable objects with name hints.
Useful for creating dependencies with locally unique names.
@ -43,30 +43,30 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure):
def __init__(self):
super(UniqueNameTracker, self).__init__()
self._maybe_initialize_checkpointable()
self._maybe_initialize_trackable()
self._name_counts = {}
@property
def _values(self):
return [dep.ref for dep in self._checkpoint_dependencies]
def track(self, checkpointable, base_name):
"""Add a dependency on `checkpointable`.
def track(self, trackable, base_name):
"""Add a dependency on `trackable`.
Args:
checkpointable: An object to add a checkpoint dependency on.
trackable: An object to add a checkpoint dependency on.
base_name: A name hint, which is uniquified to determine the dependency
name.
Returns:
`checkpointable`, for chaining.
`trackable`, for chaining.
Raises:
ValueError: If `checkpointable` is not a checkpointable object.
ValueError: If `trackable` is not a trackable object.
"""
if not isinstance(checkpointable, checkpointable_lib.Checkpointable):
if not isinstance(trackable, trackable_lib.Trackable):
raise ValueError(
("Expected a checkpointable value, got %s which does not inherit "
"from CheckpointableBase.") % (checkpointable,))
("Expected a trackable value, got %s which does not inherit "
"from tf.track.Trackable.") % (trackable,))
def _format_name(prefix, number):
if number > 0:
@ -80,5 +80,5 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure):
count += 1
candidate = _format_name(base_name, count)
self._name_counts[base_name] = count + 1
self._track_value(checkpointable, name=candidate)
return checkpointable
self._track_value(trackable, name=candidate)
return trackable

View File

@ -26,9 +26,9 @@ from tensorflow.python.keras import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
class UniqueNameTrackerTests(test.TestCase):
@ -52,7 +52,7 @@ class UniqueNameTrackerTests(test.TestCase):
save_root = util.Checkpoint(slots=slots)
save_path = save_root.save(checkpoint_prefix)
restore_slots = tracking.AutoCheckpointable()
restore_slots = tracking.AutoTrackable()
restore_root = util.Checkpoint(
slots=restore_slots)
status = restore_root.restore(save_path)
@ -68,7 +68,7 @@ class UniqueNameTrackerTests(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testExample(self):
class SlotManager(tracking.AutoCheckpointable):
class SlotManager(tracking.AutoTrackable):
def __init__(self):
self.slotdeps = containers.UniqueNameTracker()

View File

@ -23,7 +23,7 @@ import six
import numpy
from tensorflow.python.training.checkpointable import base
from tensorflow.python.training.tracking import base
# pylint: disable=g-import-not-at-top
try:
@ -34,8 +34,8 @@ except ImportError:
# pylint: enable=g-import-not-at-top
class NumpyState(base.Checkpointable):
"""A checkpointable object whose NumPy array attributes are saved/restored.
class NumpyState(base.Trackable):
"""A trackable object whose NumPy array attributes are saved/restored.
Example usage:
@ -72,7 +72,7 @@ class NumpyState(base.Checkpointable):
"""Create placeholder NumPy arrays for to-be-restored attributes.
Typically `_lookup_dependency` is used to check by name whether a dependency
exists. We cheat slightly by creating a checkpointable object for `name` if
exists. We cheat slightly by creating a trackable object for `name` if
we don't already have one, giving us attribute re-creation behavior when
loading a checkpoint.
@ -85,7 +85,7 @@ class NumpyState(base.Checkpointable):
value = super(NumpyState, self)._lookup_dependency(name)
if value is None:
value = _NumpyWrapper(numpy.array([]))
new_reference = base.CheckpointableReference(name=name, ref=value)
new_reference = base.TrackableReference(name=name, ref=value)
self._unconditional_checkpoint_dependencies.append(new_reference)
self._unconditional_dependency_names[name] = value
super(NumpyState, self).__setattr__(name, value)
@ -101,7 +101,7 @@ class NumpyState(base.Checkpointable):
def __setattr__(self, name, value):
"""Automatically wrap NumPy arrays assigned to attributes."""
# TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making
# ndarrays checkpointable natively and using standard checkpointable list
# ndarrays trackable natively and using standard trackable list
# tracking.
if isinstance(value, (numpy.ndarray, numpy.generic)):
try:
@ -110,19 +110,19 @@ class NumpyState(base.Checkpointable):
return
except AttributeError:
value = _NumpyWrapper(value)
self._track_checkpointable(value, name=name, overwrite=True)
self._track_trackable(value, name=name, overwrite=True)
elif (name not in ("_setattr_tracking", "_update_uid")
and getattr(self, "_setattr_tracking", True)):
# Mixing restore()-created attributes with user-added checkpointable
# Mixing restore()-created attributes with user-added trackable
# objects is tricky, since we can't use the `_lookup_dependency` trick to
# re-create attributes (we might accidentally steal the restoration for
# another checkpointable object). For now `NumpyState` objects must be
# another trackable object). For now `NumpyState` objects must be
# leaf nodes. Theoretically we could add some extra arguments to
# `_lookup_dependency` to figure out whether we should create a NumPy
# array for the attribute or not.
raise NotImplementedError(
("Assigned %s to the %s property of %s, which is not a NumPy array. "
"Currently mixing NumPy arrays and other checkpointable objects is "
"Currently mixing NumPy arrays and other trackable objects is "
"not supported. File a feature request if this limitation bothers "
"you.")
% (value, name, self))
@ -130,7 +130,7 @@ class NumpyState(base.Checkpointable):
@six.add_metaclass(abc.ABCMeta)
class PythonStateWrapper(base.Checkpointable):
class PythonStateWrapper(base.Trackable):
"""Wraps a Python object for storage in an object-based checkpoint."""
@abc.abstractmethod

View File

@ -26,7 +26,7 @@ from tensorflow.python.eager import test
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables
from tensorflow.python.training.checkpointable import util
from tensorflow.python.training.tracking import util
class NumpyStateTests(test.TestCase):

View File

@ -21,7 +21,7 @@ import functools
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
@ -43,7 +43,7 @@ class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
return self._restore_callback(tensor)
class _SplitDependency(checkpointable.Checkpointable):
class _SplitDependency(trackable.Trackable):
"""Looks like a regular variable while synchronizing save/restores."""
def __init__(self, save_buffer, restore_buffer, name, dtype, num_components,
@ -81,9 +81,9 @@ class _SplitDependency(checkpointable.Checkpointable):
return control_flow_ops.no_op()
def _gather_saveables_for_checkpoint(self):
"""Looks to Checkpointable like a regular variable."""
"""Looks to Trackable like a regular variable."""
return {
checkpointable.VARIABLE_VALUE_KEY:
trackable.VARIABLE_VALUE_KEY:
functools.partial(_CallbackSaveable,
dtype=self._dtype,
save_callback=self._save,
@ -117,7 +117,7 @@ def split_dependency(component_names, component_dtypes,
may return `None`).
Returns:
A dictionary mapping from names to Checkpointable objects. If one is
A dictionary mapping from names to Trackable objects. If one is
reachable from an object as a dependency, the others should be too; adding
dependencies on some but not all of the objects will result in errors.
"""

View File

@ -23,9 +23,9 @@ from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.checkpointable import base
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
def _split_variable_closure(variable):
@ -44,7 +44,7 @@ def _combine_variable_closure(variable):
return _consume_restore_buffer_fn
class SaveTensorSlicesAsDeps(base.Checkpointable):
class SaveTensorSlicesAsDeps(base.Trackable):
def __init__(self):
self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.])
@ -56,17 +56,17 @@ class SaveTensorSlicesAsDeps(base.Checkpointable):
consume_restore_buffer_fn=_combine_variable_closure(
self.combined))
for name, dep in split_dependencies.items():
self._track_checkpointable(dep, name=name)
self._track_trackable(dep, name=name)
class HasRegularDeps(tracking.AutoCheckpointable):
class HasRegularDeps(tracking.AutoTrackable):
def __init__(self):
self.first_half = resource_variable_ops.ResourceVariable([0., 0.])
self.second_half = resource_variable_ops.ResourceVariable([0., 0.])
class OnlyOneDep(tracking.AutoCheckpointable):
class OnlyOneDep(tracking.AutoTrackable):
def __init__(self):
self.first_half = resource_variable_ops.ResourceVariable([0., 0.])

View File

@ -18,8 +18,8 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import util as trackable_utils
def dot_graph_from_checkpoint(save_path):
@ -51,7 +51,7 @@ def dot_graph_from_checkpoint(save_path):
A graph in DOT format as a string.
"""
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
object_graph = checkpointable_utils.object_metadata(save_path)
object_graph = trackable_utils.object_metadata(save_path)
shape_map = reader.get_variable_to_shape_map()
dtype_map = reader.get_variable_to_dtype_map()
graph = 'digraph {\n'
@ -63,7 +63,7 @@ def dot_graph_from_checkpoint(save_path):
slot_ids.add(slot_reference.slot_variable_node_id)
for node_id, node in enumerate(object_graph.nodes):
if (len(node.attributes) == 1
and node.attributes[0].name == checkpointable.VARIABLE_VALUE_KEY):
and node.attributes[0].name == trackable.VARIABLE_VALUE_KEY):
if node_id in slot_ids:
color = 'orange'
tooltip_prefix = 'Slot variable'

View File

@ -28,7 +28,7 @@ from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import core
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import adam
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.training.tracking import util as trackable_utils
try:
import pydot # pylint: disable=g-import-not-at-top
@ -57,7 +57,7 @@ class DotGraphTests(test.TestCase):
model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
optimizer_step = resource_variable_ops.ResourceVariable(12)
save_checkpoint = checkpointable_utils.Checkpoint(
save_checkpoint = trackable_utils.Checkpoint(
optimizer=optimizer, model=model, optimizer_step=optimizer_step)
optimizer.minimize(functools.partial(model, input_value))
checkpoint_directory = self.get_temp_dir()

View File

@ -72,7 +72,7 @@ tensorflow/python/tools
tensorflow/python/tools/api
tensorflow/python/tools/api/generator
tensorflow/python/training
tensorflow/python/training/checkpointable
tensorflow/python/training/tracking
tensorflow/python/user_ops
tensorflow/python/util
tensorflow/python/util/protobuf

View File

@ -58,7 +58,7 @@ from tensorflow.python.training import gradient_descent
from tensorflow.python.training import momentum
from tensorflow.python.training import rmsprop
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.training.tracking import util as trackable_utils
CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM
@ -709,7 +709,7 @@ class CudnnRNNTestSaveRestore(test_util.TensorFlowTestCase):
self._TestSaveRestoreHelper(CUDNN_RNN_RELU)
class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
class CudnnRNNTestSaveRestoreTrackable(test_util.TensorFlowTestCase):
def _VerifyCheckpoint(
self, checkpoint_path, compatible_cell_fn, cudnn_cell_fn,
@ -718,7 +718,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
with ops.device("gpu:0"):
cudnn_layer = cudnn_cell_fn()
cudnn_checkpoint = checkpointable_utils.Checkpoint(cell=cudnn_layer)
cudnn_checkpoint = trackable_utils.Checkpoint(cell=cudnn_layer)
status = cudnn_checkpoint.restore(checkpoint_path)
inputs = 3. * array_ops.ones([num_applications, num_layers, input_size],
dtype=dtypes.float32)
@ -726,7 +726,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
status.run_restore_ops()
second_save_path = cudnn_checkpoint.save(checkpoint_prefix)
restore_layer = compatible_cell_fn()
restore_layer_checkpoint = checkpointable_utils.Checkpoint(
restore_layer_checkpoint = trackable_utils.Checkpoint(
cell=restore_layer)
status = restore_layer_checkpoint.restore(second_save_path)
current_state = restore_layer.zero_state(1, dtypes.float32)
@ -742,7 +742,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
self.assertAllClose(self.evaluate(restore_layer_output),
self.evaluate(cudnn_output)[-1, -1:, ...])
def _CheckpointableSingleCellUnidirectionalTestTemplate(
def _TrackableSingleCellUnidirectionalTestTemplate(
self, single_cell_fn, cudnn_cell_fn):
# Single-layer cuDNN cells with object-based checkpointing should be
# checkpoint compatible with either single CudnnCompatible cells or
@ -759,7 +759,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
value = np.random.normal(size=variable.shape)
expected_values.append(value)
self.evaluate(variable.assign(value))
save_checkpoint = checkpointable_utils.Checkpoint(cell=save_cell_layer)
save_checkpoint = trackable_utils.Checkpoint(cell=save_cell_layer)
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
first_save_path = save_checkpoint.save(checkpoint_prefix)
@ -775,10 +775,10 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@test_util.run_in_graph_and_eager_modes
def testLSTMCheckpointableSingleLayer(self):
def testLSTMTrackableSingleLayer(self):
num_units = 2
direction = CUDNN_RNN_UNIDIRECTION
self._CheckpointableSingleCellUnidirectionalTestTemplate(
self._TrackableSingleCellUnidirectionalTestTemplate(
single_cell_fn=functools.partial(
cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units),
cudnn_cell_fn=functools.partial(
@ -788,19 +788,19 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@test_util.run_in_graph_and_eager_modes
def testGRUCheckpointableSingleLayer(self):
def testGRUTrackableSingleLayer(self):
num_units = 2
direction = CUDNN_RNN_UNIDIRECTION
with self.assertRaises(NotImplementedError):
# TODO(allenl): Implement object-based saving for GRUs and other cells.
self._CheckpointableSingleCellUnidirectionalTestTemplate(
self._TrackableSingleCellUnidirectionalTestTemplate(
single_cell_fn=functools.partial(
cudnn_rnn_ops.CudnnCompatibleGRUCell, num_units=num_units),
cudnn_cell_fn=functools.partial(
cudnn_rnn.CudnnGRU, num_layers=1, num_units=num_units,
direction=direction, name="awesome_gru"))
def _CheckpointableMultiLayerTestTemplate(
def _TrackableMultiLayerTestTemplate(
self, single_cell_fn, cudnn_cell_fn, num_layers):
def _MultiCellFn():
@ -819,7 +819,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
value = np.random.normal(size=variable.shape)
expected_values.append(value)
self.evaluate(variable.assign(value))
save_checkpoint = checkpointable_utils.Checkpoint(cell=save_layer)
save_checkpoint = trackable_utils.Checkpoint(cell=save_layer)
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
first_save_path = save_checkpoint.save(checkpoint_prefix)
@ -837,7 +837,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
num_units = 2
num_layers = 3
direction = CUDNN_RNN_UNIDIRECTION
self._CheckpointableMultiLayerTestTemplate(
self._TrackableMultiLayerTestTemplate(
single_cell_fn=functools.partial(
cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units),
cudnn_cell_fn=functools.partial(

View File

@ -518,8 +518,8 @@ class _CudnnRNN(base_layer.Layer):
direction=self.direction,
scope=vs.get_variable_scope(),
name="%s_saveable" % self.trainable_variables[0].name.split(":")[0])
self._saveable._add_checkpointable_dependencies( # pylint: disable=protected-access
checkpointable=self, dtype=self._plain_dtype)
self._saveable._add_trackable_dependencies( # pylint: disable=protected-access
trackable=self, dtype=self._plain_dtype)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)

View File

@ -33,7 +33,7 @@ from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import tracking as checkpointable_lib
from tensorflow.python.training.tracking import tracking as trackable_lib
CUDNN_RNN_UNIDIRECTION = "unidirectional"
CUDNN_RNN_BIDIRECTION = "bidirectional"
@ -737,13 +737,13 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
return state_ops.assign(
self._variables, opaque_params, validate_shape=False)
def _checkpointable_save(self, save_buffer):
def _trackable_save(self, save_buffer):
weights, biases = self.format_converter.opaque_to_tf_canonical(
self._variables)
for name, tensor in zip(self._param_names, weights + biases):
save_buffer[name] = array_ops.identity(tensor)
def _checkpointable_restore(self, restore_buffer):
def _trackable_restore(self, restore_buffer):
tensors = [
array_ops.identity(restore_buffer[name]) for name in self._param_names
]
@ -752,26 +752,26 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
restored_shapes=None # Unused
)
def _add_checkpointable_dependencies(self, checkpointable, dtype):
"""Add canonical weight dependencies to `checkpointable`.
def _add_trackable_dependencies(self, trackable, dtype):
"""Add canonical weight dependencies to `trackable`.
When saving or restoring, converts to or from the opaque buffer
format. Weights are saved and loaded in the configuration expected by
cuDNN-compatible cells.
Args:
checkpointable: An object inheriting from `CheckpointableBase` to add
trackable: An object inheriting from `Trackable` to add
dependencies too (typically the cuDNN `Layer`).
dtype: The dtype for the canonical parameter Tensors.
"""
split_dependencies = split_dependency.split_dependency(
component_names=self._param_names,
component_dtypes=(dtype,) * len(self._param_names),
fill_save_buffer_fn=self._checkpointable_save,
consume_restore_buffer_fn=self._checkpointable_restore)
self._checkpointable_track_params(checkpointable, split_dependencies)
fill_save_buffer_fn=self._trackable_save,
consume_restore_buffer_fn=self._trackable_restore)
self._trackable_track_params(trackable, split_dependencies)
def _checkpointable_track_params(self, checkpointable, params):
def _trackable_track_params(self, trackable, params):
"""Tracks parameters in a canonical configuration."""
return # NotImplementedError raised by the Layer.
@ -819,7 +819,7 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable):
tf_weights_names.append(prefix + "/kernel")
tf_bias_names.append(prefix + "/bias")
def _checkpointable_track_params(self, checkpointable, params):
def _trackable_track_params(self, trackable, params):
"""Track parameters for compatibility with CudnnCompatibleLSTMCell."""
biases = []
weights = []
@ -833,12 +833,12 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable):
# wrapping.
kernel, = weights # pylint: disable=unbalanced-tuple-unpacking
bias, = biases # pylint: disable=unbalanced-tuple-unpacking
checkpointable._track_checkpointable(kernel, name="kernel") # pylint: disable=protected-access
checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access
trackable._track_trackable(kernel, name="kernel") # pylint: disable=protected-access
trackable._track_trackable(bias, name="bias") # pylint: disable=protected-access
assert len(biases) == len(weights)
for cell_index, (bias, kernel) in enumerate(zip(biases, weights)):
cell = checkpointable_lib.AutoCheckpointable()
checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access
cell = trackable_lib.AutoTrackable()
trackable._track_trackable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access
cell.bias = bias
cell.kernel = kernel

View File

@ -800,6 +800,6 @@ tf_xla_py_test(
":tpu_strategy",
"//tensorflow/compiler/tests:xla_test",
"//tensorflow/python/eager:test",
"//tensorflow/python/training/checkpointable:util",
"//tensorflow/python/training/tracking:util",
],
)

View File

@ -30,15 +30,15 @@ from tensorflow.python.platform import test
from tensorflow.python.training import adam as adam_v1
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util as trackable_utils
class NonLayerCheckpointable(tracking.AutoCheckpointable):
class NonLayerTrackable(tracking.AutoTrackable):
def __init__(self):
super(NonLayerCheckpointable, self).__init__()
self.a_variable = checkpointable_utils.add_variable(
super(NonLayerTrackable, self).__init__()
self.a_variable = trackable_utils.add_variable(
self, name="a_variable", shape=[])
@ -49,8 +49,8 @@ class Subclassed(training.Model):
super(Subclassed, self).__init__()
self._named_dense = core.Dense(1, use_bias=True)
self._second = core.Dense(1, use_bias=False)
# We can still track Checkpointables which aren't Layers.
self._non_layer = NonLayerCheckpointable()
# We can still track Trackables which aren't Layers.
self._non_layer = NonLayerTrackable()
def call(self, values):
ret = self._second(self._named_dense(values))
@ -76,7 +76,7 @@ class TrainingCheckpointTests(xla_test.XLATestCase):
with strategy.scope():
model = Subclassed()
optimizer = adam_v1.AdamOptimizer(0.001)
root = checkpointable_utils.Checkpoint(
root = trackable_utils.Checkpoint(
optimizer=optimizer, model=model,
optimizer_step=training_util.get_or_create_global_step())
root.restore(checkpoint_management.latest_checkpoint(

View File

@ -414,7 +414,7 @@ class TestDistributionStrategySaveLoadWeights(test.TestCase,
@combinations.generate(
keras_test_lib.all_strategy_combinations_minus_default())
def test_save_load_checkpointable(self, distribution):
def test_save_load_trackable(self, distribution):
# TODO(sourabhbajaj): Test fails with optimizer v2 without h5
with self.cached_session():
dataset = keras_test_lib.get_dataset(distribution)

View File

@ -144,7 +144,7 @@ py_library(
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:function",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/tracking:base",
],
)

View File

@ -37,7 +37,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.training.tracking import util as trackable_utils
class IteratorTest(test.TestCase):
@ -238,7 +238,7 @@ class IteratorTest(test.TestCase):
dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
dataset = dataset.map(math_ops.square).batch(2)
iterator = datasets.Iterator(dataset)
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
checkpoint = trackable_utils.Checkpoint(iterator=iterator)
self.assertAllEqual([1, 4], iterator.get_next().numpy())
save_path = checkpoint.save(checkpoint_prefix)
self.assertAllEqual([9, 16], iterator.get_next().numpy())
@ -257,7 +257,7 @@ class IteratorTest(test.TestCase):
dataset_2 = Dataset.range(10)
iterator_3 = datasets.Iterator(dataset_2)
checkpoint = checkpointable_utils.Checkpoint(
checkpoint = trackable_utils.Checkpoint(
iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
self.assertAllEqual([1, 4], iterator_1.get_next().numpy())
self.assertEqual(0, iterator_3.get_next().numpy())
@ -279,7 +279,7 @@ class IteratorTest(test.TestCase):
dataset = Dataset.range(3)
iterator = datasets.Iterator(dataset)
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
checkpoint = trackable_utils.Checkpoint(iterator=iterator)
self.assertEqual(0, iterator.get_next().numpy())
self.assertEqual(1, iterator.get_next().numpy())
save_path = checkpoint.save(checkpoint_prefix)
@ -293,7 +293,7 @@ class IteratorTest(test.TestCase):
dataset = Dataset.range(10)
for i in range(5):
iterator = datasets.Iterator(dataset)
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
checkpoint = trackable_utils.Checkpoint(iterator=iterator)
checkpoint.restore(checkpoint_management.latest_checkpoint(
checkpoint_directory))
for j in range(2):

View File

@ -37,7 +37,7 @@ from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.training.tracking import util as trackable_utils
# pylint: enable=g-bad-import-order
@ -421,7 +421,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
# 5. Verify that checkpoints exist and contains all the expected variables.
self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
object_graph = checkpointable_utils.object_metadata(
object_graph = trackable_utils.object_metadata(
checkpoint_management.latest_checkpoint(config.logdir))
ckpt_variable_names = set()
for node in object_graph.nodes:

View File

@ -32,12 +32,12 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
_to_replace = re.compile("[^A-Za-z0-9.]")
class Metric(checkpointable.Checkpointable):
class Metric(trackable.Trackable):
"""A metric holds state for aggregating statistics over an evaluation run.
Example use with eager execution:
@ -269,7 +269,7 @@ class Metric(checkpointable.Checkpointable):
else:
collections = [ops.GraphKeys.LOCAL_VARIABLES]
collections += [ops.GraphKeys.METRIC_VARIABLES]
# Variables are Checkpointable dependencies of Metrics regardless of the
# Variables are Trackable dependencies of Metrics regardless of the
# global/local distinction. Users can avoid saving variables by not adding a
# dependency on the Metric.
v = self._add_variable_with_custom_getter(
@ -282,7 +282,7 @@ class Metric(checkpointable.Checkpointable):
use_resource=True,
getter=variable_scope.get_variable,
# Raise duplicate variable exceptions from get_variable rather than
# Checkpointable.
# Trackable.
overwrite=True)
self._vars.append(v)
if context.executing_eagerly():

View File

@ -35,7 +35,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.training.tracking import util as trackable_utils
class MetricsTest(test.TestCase):
@ -314,7 +314,7 @@ class MetricsTest(test.TestCase):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
mean = metrics.Mean()
checkpoint = checkpointable_utils.Checkpoint(mean=mean)
checkpoint = trackable_utils.Checkpoint(mean=mean)
mean.build()
mean._built = True
self.evaluate(mean.init_variables())
@ -327,7 +327,7 @@ class MetricsTest(test.TestCase):
self.assertAllEqual(200., self.evaluate(mean.value()))
restore_mean = metrics.Mean()
restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean)
restore_checkpoint = trackable_utils.Checkpoint(mean=restore_mean)
status = restore_checkpoint.restore(save_path)
restore_update = restore_mean(300.)
status.assert_consumed().run_restore_ops()

View File

@ -31,7 +31,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.training.tracking import util as trackable_utils
# pylint: disable=not-callable
@ -65,7 +65,7 @@ class NetworkTest(test.TestCase):
def test_checkpointing_not_implemented(self):
checkpoint_directory = self.get_temp_dir()
checkpoint = checkpointable_utils.Checkpoint(net=MyNetwork())
checkpoint = trackable_utils.Checkpoint(net=MyNetwork())
with self.assertRaises(NotImplementedError):
checkpoint.save(checkpoint_directory)

View File

@ -30,7 +30,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
@ -129,8 +129,8 @@ class SharedVariable(resource_variable_ops.ResourceVariable):
if constraint is not None and not callable(constraint):
raise ValueError("The `constraint` argument must be a callable.")
if isinstance(initial_value, checkpointable.CheckpointInitialValue):
self._maybe_initialize_checkpointable()
if isinstance(initial_value, trackable.CheckpointInitialValue):
self._maybe_initialize_trackable()
self._update_uid = initial_value.checkpoint_position.restore_uid
initial_value = initial_value.wrapped_value

View File

@ -137,8 +137,8 @@ from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Vari
from tensorflow.python.ops.variable_scope import EagerVariableStore
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import template
from tensorflow.python.training.checkpointable.tracking import AutoCheckpointable as Checkpointable
from tensorflow.python.training.checkpointable.util import Checkpoint
from tensorflow.python.training.tracking.tracking import AutoTrackable as Checkpointable
from tensorflow.python.training.tracking.util import Checkpoint
from tensorflow.python.util.all_util import remove_undocumented
py_func = script_ops.eager_py_func

View File

@ -36,7 +36,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
# TODO(rohanj): This should subclass Checkpointable and implement
# TODO(rohanj): This should subclass Trackable and implement
# _gather_saveables_for_checkpoint.
class ShardedMutableDenseHashTable(object):
"""A sharded version of MutableDenseHashTable.

View File

@ -1483,3 +1483,4 @@ class IdTableWithHashBucketsTest(test.TestCase):
if __name__ == "__main__":
test.main()

View File

@ -44,15 +44,15 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as core_saver
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import graph_view
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
class NonLayerCheckpointable(tracking.AutoCheckpointable):
class NonLayerTrackable(tracking.AutoTrackable):
def __init__(self):
super(NonLayerCheckpointable, self).__init__()
super(NonLayerTrackable, self).__init__()
self.a_variable = util.add_variable(
self, name="a_variable", shape=[])
@ -65,8 +65,8 @@ class MyModel(training.Model):
super(MyModel, self).__init__()
self._named_dense = core.Dense(1, use_bias=True)
self._second = core.Dense(1, use_bias=False)
# We can still track Checkpointables which aren't Layers.
self._non_layer = NonLayerCheckpointable()
# We can still track Trackables which aren't Layers.
self._non_layer = NonLayerTrackable()
def call(self, values):
ret = self._second(self._named_dense(values))
@ -101,7 +101,7 @@ class CheckpointingTests(test.TestCase):
other_model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
optimizer_step = training_util.get_or_create_global_step()
root_checkpointable = util.Checkpoint(
root_trackable = util.Checkpoint(
optimizer=optimizer, model=model, optimizer_step=optimizer_step)
if context.executing_eagerly():
optimizer.minimize(
@ -117,10 +117,10 @@ class CheckpointingTests(test.TestCase):
other_model(input_value),
global_step=optimizer_step)
self.evaluate(util.gather_initializers(
root_checkpointable))
root_trackable))
self.evaluate(train_op)
named_variables, serialized_graph, _ = graph_view.ObjectGraphView(
root_checkpointable).serialize_object_graph()
root_trackable).serialize_object_graph()
expected_checkpoint_names = (
# Created in the root node, so no prefix.
"optimizer_step",
@ -208,7 +208,7 @@ class CheckpointingTests(test.TestCase):
def testSaveRestore(self):
model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
root_checkpointable = util.Checkpoint(
root_trackable = util.Checkpoint(
optimizer=optimizer, model=model)
input_value = constant_op.constant([[3.]])
if context.executing_eagerly():
@ -217,24 +217,24 @@ class CheckpointingTests(test.TestCase):
else:
train_op = optimizer.minimize(model(input_value))
# TODO(allenl): Make initialization more pleasant when graph building.
root_checkpointable.save_counter # pylint: disable=pointless-statement
root_trackable.save_counter # pylint: disable=pointless-statement
self.evaluate(util.gather_initializers(
root_checkpointable))
root_trackable))
self.evaluate(train_op)
prefix = os.path.join(self.get_temp_dir(), "ckpt")
self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.]))
m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m")
self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
save_path = root_checkpointable.save(file_prefix=prefix)
save_path = root_trackable.save(file_prefix=prefix)
self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.]))
self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3))
self.evaluate(state_ops.assign(root_trackable.save_counter, 3))
optimizer_variables = self.evaluate(optimizer.variables())
self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
# Immediate restoration
status = root_checkpointable.restore(save_path=save_path).assert_consumed()
status = root_trackable.restore(save_path=save_path).assert_consumed()
status.run_restore_ops()
self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1]))
self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter))
self.assertAllEqual(1, self.evaluate(root_trackable.save_counter))
self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
if not context.executing_eagerly():
return # Restore-on-create is only supported when executing eagerly
@ -542,11 +542,11 @@ class CheckpointingTests(test.TestCase):
first_session = session_lib.Session(graph=first_graph)
with first_graph.as_default(), first_session.as_default():
first_variable = resource_variable_ops.ResourceVariable([1.])
first_root_checkpointable = util.Checkpoint(
first_root_trackable = util.Checkpoint(
optimizer=optimizer, variable=first_variable)
train_op = optimizer.minimize(first_variable.read_value)
self.evaluate(util.gather_initializers(
first_root_checkpointable))
first_root_trackable))
self.evaluate(train_op)
self.evaluate(first_variable.assign([1.]))
self.evaluate(optimizer.get_slot(
@ -558,23 +558,23 @@ class CheckpointingTests(test.TestCase):
second_graph = ops.Graph()
with second_graph.as_default(), session_lib.Session(graph=second_graph):
second_variable = resource_variable_ops.ResourceVariable([1.])
second_root_checkpointable = util.Checkpoint(
second_root_trackable = util.Checkpoint(
optimizer=optimizer, variable=second_variable)
train_op = optimizer.minimize(second_variable.read_value)
second_root_checkpointable.restore(None).initialize_or_restore()
second_root_trackable.restore(None).initialize_or_restore()
self.evaluate(train_op)
self.evaluate(second_variable.assign([4.]))
self.evaluate(optimizer.get_slot(
var=second_variable, name="m").assign([5.]))
beta_1_power, _ = optimizer._get_beta_accumulators()
self.evaluate(beta_1_power.assign(6.))
save_path = second_root_checkpointable.save(checkpoint_prefix)
save_path = second_root_trackable.save(checkpoint_prefix)
self.evaluate(second_variable.assign([7.]))
self.evaluate(optimizer.get_slot(
var=second_variable, name="m").assign([8.]))
beta_1_power, _ = optimizer._get_beta_accumulators()
self.assertAllEqual(6., self.evaluate(beta_1_power))
status = second_root_checkpointable.restore(save_path)
status = second_root_trackable.restore(save_path)
status.assert_consumed().run_restore_ops()
self.assertAllEqual([4.], self.evaluate(second_variable))
self.assertAllEqual([5.], self.evaluate(optimizer.get_slot(
@ -594,7 +594,7 @@ class CheckpointingTests(test.TestCase):
class TemplateTests(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_checkpointable_save_restore(self):
def test_trackable_save_restore(self):
def _templated():
v = variable_scope.get_variable(
@ -641,13 +641,13 @@ class CheckpointCompatibilityTests(test.TestCase):
model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
optimizer_step = training_util.get_or_create_global_step()
root_checkpointable = util.Checkpoint(
root_trackable = util.Checkpoint(
optimizer=optimizer, model=model, optimizer_step=optimizer_step)
train_op = optimizer.minimize(
functools.partial(model, input_value),
global_step=optimizer_step)
self.evaluate(util.gather_initializers(
root_checkpointable))
root_trackable))
self.evaluate(train_op)
# A regular variable, a slot variable, and a non-slot Optimizer variable
# with known values to check when loading.
@ -656,24 +656,24 @@ class CheckpointCompatibilityTests(test.TestCase):
var=model._named_dense.bias, name="m").assign([2.]))
beta_1_power, _ = optimizer._get_beta_accumulators()
self.evaluate(beta_1_power.assign(3.))
return root_checkpointable
return root_trackable
def _set_sentinels(self, root_checkpointable):
self.evaluate(root_checkpointable.model._named_dense.bias.assign([101.]))
def _set_sentinels(self, root_trackable):
self.evaluate(root_trackable.model._named_dense.bias.assign([101.]))
self.evaluate(
root_checkpointable.optimizer.get_slot(
var=root_checkpointable.model._named_dense.bias, name="m")
root_trackable.optimizer.get_slot(
var=root_trackable.model._named_dense.bias, name="m")
.assign([102.]))
beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
beta_1_power, _ = root_trackable.optimizer._get_beta_accumulators()
self.evaluate(beta_1_power.assign(103.))
def _check_sentinels(self, root_checkpointable):
def _check_sentinels(self, root_trackable):
self.assertAllEqual(
[1.], self.evaluate(root_checkpointable.model._named_dense.bias))
[1.], self.evaluate(root_trackable.model._named_dense.bias))
self.assertAllEqual([2.], self.evaluate(
root_checkpointable.optimizer.get_slot(
var=root_checkpointable.model._named_dense.bias, name="m")))
beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
root_trackable.optimizer.get_slot(
var=root_trackable.model._named_dense.bias, name="m")))
beta_1_power, _ = root_trackable.optimizer._get_beta_accumulators()
self.assertAllEqual(3., self.evaluate(beta_1_power))
def _write_name_based_checkpoint(self):
@ -698,7 +698,7 @@ class CheckpointCompatibilityTests(test.TestCase):
self._set_sentinels(root)
with self.assertRaises(AssertionError):
self._check_sentinels(root)
object_saver = util.CheckpointableSaver(graph_view.ObjectGraphView(root))
object_saver = util.TrackableSaver(graph_view.ObjectGraphView(root))
self._set_sentinels(root)
status = object_saver.restore(save_path)
if context.executing_eagerly():

View File

@ -38,7 +38,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import optimizer as optimizer_v1
from tensorflow.python.training import slot_creator
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
@ -223,7 +223,7 @@ class _OptimizerV2State(object):
}
self._slots = {}
self._non_slot_dict = {}
# Extra state to help Optimizers implement Checkpointable. Holds information
# Extra state to help Optimizers implement Trackable. Holds information
# about variables which will be restored as soon as they're created.
self._deferred_dependencies = {} # Non-slot variables
self._deferred_slot_restorations = {} # Slot variables
@ -366,8 +366,8 @@ class _OptimizerV2State(object):
slot variable needs to be restored).
Args:
slot_variable_position: A `checkpointable._CheckpointPosition` object
indicating the slot variable `Checkpointable` object to be restored.
slot_variable_position: A `trackable._CheckpointPosition` object
indicating the slot variable `Trackable` object to be restored.
slot_name: The name of this `Optimizer`'s slot to restore into.
variable: The variable object this slot is being created for.
optional_op_name: Name to use when scoping the Variable that needs to be
@ -385,7 +385,7 @@ class _OptimizerV2State(object):
# (aside from double initialization), and makes variable creator scopes
# behave the same way they do when graph building.
and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access
initializer = checkpointable.CheckpointInitialValue(
initializer = trackable.CheckpointInitialValue(
checkpoint_position=slot_variable_position)
slot_variable = self.create_slot(
var=variable,
@ -1259,10 +1259,10 @@ class OptimizerV2(optimizer_v1.Optimizer):
return self._per_graph_state.get(var._graph_key, None)
# --------------
# Overridden methods from Checkpointable.
# Overridden methods from Trackable.
# --------------
def _track_checkpointable(self, *args, **kwargs):
def _track_trackable(self, *args, **kwargs):
"""Optimizers may not track dependencies. Raises an error."""
raise NotImplementedError(
"Optimizers may not have dependencies. File a feature request if this "
@ -1270,7 +1270,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
@property
def _checkpoint_dependencies(self):
"""From Checkpointable. Gather graph-specific non-slot variables to save."""
"""From Trackable. Gather graph-specific non-slot variables to save."""
current_graph_non_slot_variables = []
state = self._get_per_graph_state()
if state is not None:
@ -1279,14 +1279,14 @@ class OptimizerV2(optimizer_v1.Optimizer):
# Avoid comparing variables
key=lambda item: item[0]):
current_graph_non_slot_variables.append(
checkpointable.CheckpointableReference(
trackable.TrackableReference(
name=name, ref=variable_object))
# Note: ignores super(); Optimizers may not have any dependencies outside of
# state objects.
return current_graph_non_slot_variables
def _lookup_dependency(self, name):
"""From Checkpointable. Find a non-slot variable in the current graph."""
"""From Trackable. Find a non-slot variable in the current graph."""
state = self._get_per_graph_state()
if state is None:
return None
@ -1295,10 +1295,10 @@ class OptimizerV2(optimizer_v1.Optimizer):
@property
def _deferred_dependencies(self):
"""Lets Checkpointable know where non-slot variables are created.
"""Lets Trackable know where non-slot variables are created.
If necessary, creates a new state object for the current default graph.
Checkpointable will then add entries to that state's deferred dependency
Trackable will then add entries to that state's deferred dependency
dictionary. The state object will check that dictionary when creating
non-slot variables, restoring their value if an entry is found.
@ -1311,14 +1311,14 @@ class OptimizerV2(optimizer_v1.Optimizer):
def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
variable):
"""Checkpointable: Restore a slot variable's value, possibly creating it.
"""Trackable: Restore a slot variable's value, possibly creating it.
Called when a variable which has an associated slot variable is created or
restored.
Args:
slot_variable_position: A `checkpointable._CheckpointPosition` object
indicating the slot variable `Checkpointable` object to be restored.
slot_variable_position: A `trackable._CheckpointPosition` object
indicating the slot variable `Trackable` object to be restored.
slot_name: The name of this `Optimizer`'s slot to restore into.
variable: The variable object this slot is being created for.
"""

View File

@ -35,7 +35,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import resources
from tensorflow.python.platform import resource_loader
from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking
_model_ops = loader.load_op_library(

View File

@ -32,7 +32,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import resources
from tensorflow.python.platform import resource_loader
from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking
_stats_ops = loader.load_op_library(

View File

@ -228,7 +228,7 @@ CORE_PROTO_SRCS = COMMON_PROTO_SRCS + ERROR_CODES_PROTO_SRCS
# ones with individual proto_library targets.
ADDITIONAL_CORE_PROTO_SRCS = [
"example/example_parser_configuration.proto",
"protobuf/checkpointable_object_graph.proto",
"protobuf/trackable_object_graph.proto",
"protobuf/control_flow.proto",
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# "protobuf/critical_section.proto",

View File

@ -8,10 +8,10 @@ package tensorflow;
// own variables, allowing for more robust checkpoint loading into modified
// programs.
message CheckpointableObjectGraph {
message CheckpointableObject {
message TrackableObjectGraph {
message TrackableObject {
message ObjectReference {
// An index into `CheckpointableObjectGraph.nodes`, indicating the object
// An index into `TrackableObjectGraph.nodes`, indicating the object
// being referenced.
int32 node_id = 1;
// A user-provided name for the edge.
@ -37,12 +37,12 @@ message CheckpointableObjectGraph {
}
message SlotVariableReference {
// An index into `CheckpointableObjectGraph.nodes`, indicating the
// An index into `TrackableObjectGraph.nodes`, indicating the
// variable object this slot was created for.
int32 original_variable_node_id = 1;
// The name of the slot (e.g. "m"/"v").
string slot_name = 2;
// An index into `CheckpointableObjectGraph.nodes`, indicating the
// An index into `TrackableObjectGraph.nodes`, indicating the
// `Object` with the value of the slot variable.
int32 slot_variable_node_id = 3;
}
@ -55,5 +55,5 @@ message CheckpointableObjectGraph {
repeated SlotVariableReference slot_variables = 3;
}
repeated CheckpointableObject nodes = 1;
repeated TrackableObject nodes = 1;
}

View File

@ -33,7 +33,7 @@ def main(argv):
del argv
root = tf.train.Checkpoint()
# Create a cell and attach to our checkpointable.
# Create a cell and attach to our trackable.
root.rnn_cell = tf.keras.layers.LSTMCell(units=10, recurrent_initializer=None)
# Wrap the rnn_cell.__call__ function and assign to next_state.

View File

@ -27,7 +27,7 @@ import tensorflow as tf
# TODO(vbardiovsky): remove these when symbols are public.
from tensorflow.python.ops import lookup_ops
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking
FLAGS = flags.FLAGS

View File

@ -1815,7 +1815,7 @@ tf_gen_op_wrapper_private_py(
visibility = [
"//learning/brain/python/ops:__pkg__",
"//tensorflow/python/kernel_tests:__pkg__",
"//tensorflow/python/training/checkpointable:__pkg__",
"//tensorflow/python/training/tracking:__pkg__",
],
)
@ -3371,7 +3371,7 @@ py_library(
":util",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/tracking:base",
],
)
@ -3730,7 +3730,7 @@ py_library(
["training/**/*.py"],
exclude = [
"**/*test*",
"training/checkpointable/**/*.py",
"training/tracking/**/*.py",
"training/saving/**/*.py",
# The following targets have their own build rules (same name as the
# file):
@ -3791,8 +3791,8 @@ py_library(
"//tensorflow/python/eager:context",
"//tensorflow/python/keras/optimizer_v2:learning_rate_schedule",
"//tensorflow/python/ops/losses",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/checkpointable:util",
"//tensorflow/python/training/tracking:base",
"//tensorflow/python/training/tracking:util",
"//third_party/py/numpy",
"@six_archive//:six",
],
@ -3872,9 +3872,9 @@ py_library(
":variables",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/saving:saveable_object",
"//tensorflow/python/training/saving:saveable_object_util",
"//tensorflow/python/training/tracking:base",
"//third_party/py/numpy",
"@six_archive//:six",
],

View File

@ -291,7 +291,7 @@ tf_py_test(
":test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:util",
"//tensorflow/python/training/tracking:util",
"//tensorflow/python:checkpoint_management",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
@ -344,7 +344,7 @@ cuda_py_test(
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/util:structure",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:util",
"//tensorflow/python/training/tracking:util",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",

View File

@ -28,7 +28,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.training.tracking import util as trackable_utils
@test_util.run_all_in_graph_and_eager_modes
@ -43,7 +43,7 @@ class IteratorCheckpointingTest(test_base.DatasetTestBase):
) else dataset_ops.make_one_shot_iterator(dataset)
get_next = iterator.get_next if context.executing_eagerly(
) else functools.partial(self.evaluate, iterator.get_next())
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
checkpoint = trackable_utils.Checkpoint(iterator=iterator)
self.assertAllEqual([1, 4], get_next())
save_path = checkpoint.save(checkpoint_prefix)
self.assertAllEqual([9, 16], get_next())
@ -73,7 +73,7 @@ class IteratorCheckpointingTest(test_base.DatasetTestBase):
) else dataset_ops.make_one_shot_iterator(dataset_2)
get_next_3 = iterator_3.get_next if context.executing_eagerly(
) else functools.partial(self.evaluate, iterator_3.get_next())
checkpoint = checkpointable_utils.Checkpoint(
checkpoint = trackable_utils.Checkpoint(
iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
self.assertAllEqual([1, 4], get_next_1())
self.assertAllEqual(0, get_next_3())
@ -96,7 +96,7 @@ class IteratorCheckpointingTest(test_base.DatasetTestBase):
) else dataset_ops.make_one_shot_iterator(dataset)
get_next = iterator.get_next if context.executing_eagerly(
) else functools.partial(self.evaluate, iterator.get_next())
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
checkpoint = trackable_utils.Checkpoint(iterator=iterator)
self.assertAllEqual(0, get_next())
self.assertAllEqual(1, get_next())
save_path = checkpoint.save(checkpoint_prefix)
@ -115,7 +115,7 @@ class IteratorCheckpointingTest(test_base.DatasetTestBase):
iterator = iter(dataset) if context.executing_eagerly(
) else dataset_ops.make_initializable_iterator(dataset)
get_next = iterator.get_next
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
checkpoint = trackable_utils.Checkpoint(iterator=iterator)
for i in range(5):
checkpoint.restore(
checkpoint_management.latest_checkpoint(

View File

@ -74,7 +74,7 @@ py_library(
"//tensorflow/python/data/util:sparse",
"//tensorflow/python/data/util:structure",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/tracking:base",
],
)

View File

@ -31,8 +31,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.saver import BaseSaverBuilder
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util.tf_export import tf_export
@ -68,7 +68,7 @@ def _device_stack_is_empty():
@tf_export(v1=["data.Iterator"])
class Iterator(checkpointable.Checkpointable):
class Iterator(trackable.Trackable):
"""Represents the state of iterating through a `Dataset`."""
def __init__(self, iterator_resource, initializer, output_types,
@ -491,7 +491,7 @@ def _generate_shared_name(prefix):
return "{}{}".format(prefix, uid)
class EagerIterator(checkpointable.Checkpointable):
class EagerIterator(trackable.Trackable):
"""An iterator producing tf.Tensor objects from a tf.data.Dataset."""
def __init__(self, dataset):
@ -641,7 +641,7 @@ class EagerIterator(checkpointable.Checkpointable):
return {"ITERATOR": _saveable_factory}
# TODO(b/71645805): Expose checkpointable stateful objects from dataset
# TODO(b/71645805): Expose trackable stateful objects from dataset
# attributes(potential).
class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
"""SaveableObject for saving/restoring iterator state."""

View File

@ -460,7 +460,7 @@ py_library(
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/tracking:base",
"@six_archive//:six",
],
)

View File

@ -37,7 +37,7 @@ from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
@ -630,7 +630,7 @@ class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
class MirroredVariable(DistributedVariable, Mirrored,
checkpointable.Checkpointable):
trackable.Trackable):
"""Holds a map from device to variables whose values are kept in sync."""
def __init__(
@ -710,7 +710,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
return self.get()._as_graph_element()
def _gather_saveables_for_checkpoint(self):
"""Overrides CheckpointableBase method.
"""Overrides Trackable method.
This allows both name-based and object-based save and restore of
MirroredVariables.
@ -720,7 +720,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
"""
def _saveable_factory(name=self._common_name):
return _MirroredSaveable(self, self.primary, name)
return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
# Register a conversion function which reads the value of the variable,
@ -752,7 +752,7 @@ def _enclosing_tpu_context():
# tpu.replicate() because it assumes that you're in a device context where you
# can operate on a single version of the variable, but a tpu.replicate()
# operates on all variables and is replicated during a rewrite pass.
class TPUMirroredVariable(checkpointable.Checkpointable):
class TPUMirroredVariable(trackable.Trackable):
"""Holds a map from device to TPU variables whose values are kept in sync."""
def __init__(
@ -1085,7 +1085,7 @@ class TPUMirroredVariable(checkpointable.Checkpointable):
return self._read_variable_op()
def _gather_saveables_for_checkpoint(self):
"""Overrides CheckpointableBase method.
"""Overrides Trackable method.
This allows both name-based and object-based save and restore of
MirroredVariables.
@ -1095,7 +1095,7 @@ class TPUMirroredVariable(checkpointable.Checkpointable):
"""
def _saveable_factory(name=self._common_name):
return _MirroredSaveable(self, self.primary, name)
return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
def _should_act_as_resource_variable(self):
"""Pass resource_variable_ops.is_resource_variable check."""
@ -1205,7 +1205,7 @@ def _assert_replica_context(strategy):
class ReplicaLocalVariable(DistributedVariable, PerReplica,
checkpointable.Checkpointable):
trackable.Trackable):
"""Holds a map from device to variables whose values are reduced on save."""
def __init__(
@ -1256,7 +1256,7 @@ class ReplicaLocalVariable(DistributedVariable, PerReplica,
return self.get()._as_graph_element()
def _gather_saveables_for_checkpoint(self):
"""Overrides CheckpointableBase method.
"""Overrides Trackable method.
This allows both name-based and object-based save and restore of
ReplicaLocalVariables.
@ -1266,7 +1266,7 @@ class ReplicaLocalVariable(DistributedVariable, PerReplica,
"""
def _saveable_factory(name=self._common_name):
return _ReplicaLocalSaveable(self, name)
return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
# Register a conversion function for ReplicaLocalVariable which allows as_ref to
@ -1436,7 +1436,7 @@ def value_container(val):
# TODO(josh11b): Descend from Variable.
class AggregatingVariable(checkpointable.Checkpointable):
class AggregatingVariable(trackable.Trackable):
"""A wrapper around a variable that aggregates updates across replicas."""
def __init__(self, strategy, v, aggregation):
@ -1514,7 +1514,7 @@ class AggregatingVariable(checkpointable.Checkpointable):
# TODO(josh11b): Test saving & restoring.
def _gather_saveables_for_checkpoint(self):
return {checkpointable.VARIABLE_VALUE_KEY: self._v}
return {trackable.VARIABLE_VALUE_KEY: self._v}
# pylint: disable=multiple-statements
def __add__(self, o): return self._v + o

View File

@ -500,7 +500,7 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:while_v2", # TODO(b/118513001): Imported via control_flow_ops; remove.
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/tracking:base",
],
)
@ -555,7 +555,7 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:template",
"//tensorflow/python:variable_scope",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/tracking:base",
],
)

View File

@ -31,7 +31,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export
@ -113,8 +113,8 @@ class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable):
if constraint is not None and not callable(constraint):
raise ValueError("The `constraint` argument must be a callable.")
if isinstance(initial_value, checkpointable.CheckpointInitialValue):
self._maybe_initialize_checkpointable()
if isinstance(initial_value, trackable.CheckpointInitialValue):
self._maybe_initialize_trackable()
self._update_uid = initial_value.checkpoint_position.restore_uid
initial_value = initial_value.wrapped_value

View File

@ -62,7 +62,7 @@ class VariableHolder(object):
return self._fn(*args, **kwargs)
# TODO(allenl): make this checkpointable
# TODO(allenl): make this trackable
class WrappedFunction(function.ConcreteFunction):
"""Wraps a tf V1 piece of code in a function."""

View File

@ -141,11 +141,11 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine.base_layer import Layer
# TODO(b/118385027): Dependency on keras can be problematic if Keras moves out
# of the main repo.
from tensorflow.python.keras import utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@ -162,7 +162,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
@ -3228,7 +3228,7 @@ def _raise_shared_embedding_column_error():
'`DenseFeatures` or `LinearModel` instead.')
class SharedEmbeddingColumnCreator(tracking.AutoCheckpointable):
class SharedEmbeddingColumnCreator(tracking.AutoTrackable):
def __init__(self,
dimension,

View File

@ -170,7 +170,7 @@ py_library(
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/distribute:reduce_util",
"//tensorflow/python/training/checkpointable:data_structures",
"//tensorflow/python/training/tracking:data_structures",
"//tensorflow/tools/docs:doc_controls",
"@six_archive//:six",
],

View File

@ -46,8 +46,8 @@ from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list # p
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import layer_utils as checkpointable_layer_utils
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
@ -57,7 +57,7 @@ from tensorflow.tools.docs import doc_controls
@keras_export('keras.layers.Layer')
class Layer(checkpointable.Checkpointable):
class Layer(trackable.Trackable):
"""Base layer class.
This is the class from which all layers inherit.
@ -110,7 +110,7 @@ class Layer(checkpointable.Checkpointable):
constraints on inputs that can be accepted by the layer.
"""
@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def __init__(self, trainable=True, name=None, dtype=None, dynamic=False,
**kwargs):
# These properties should be set by the user via keyword arguments.
@ -272,7 +272,7 @@ class Layer(checkpointable.Checkpointable):
marked as non-trainable. `trainable` defaults to `True` unless
`synchronization` is set to `ON_READ`.
constraint: constraint instance (callable).
partitioner: Partitioner to be passed to the `Checkpointable` API.
partitioner: Partitioner to be passed to the `Trackable` API.
use_resource: Whether to use `ResourceVariable`.
synchronization: Indicates when a distributed a variable will be
aggregated. Accepted values are constants defined in the class
@ -345,9 +345,9 @@ class Layer(checkpointable.Checkpointable):
name=name,
shape=shape,
# TODO(allenl): a `make_variable` equivalent should be added as a
# `Checkpointable` method.
# `Trackable` method.
getter=getter or base_layer_utils.make_variable,
# Manage errors in Layer rather than Checkpointable.
# Manage errors in Layer rather than Trackable.
overwrite=True,
initializer=initializer,
dtype=dtype,
@ -1629,7 +1629,7 @@ class Layer(checkpointable.Checkpointable):
# Append value to self._layers if relevant
if (isinstance(value, Layer) or
checkpointable_layer_utils.has_weights(value)):
trackable_layer_utils.has_weights(value)):
# Initialize `_layers` here in case `__init__` has not yet been called.
if not hasattr(self, '_layers'):
self._layers = []
@ -1666,7 +1666,7 @@ class Layer(checkpointable.Checkpointable):
return []
# This is a hack so that the is_layer (within
# training/checkpointable/layer_utils.py) check doesn't get the weights attr.
# training/trackable/layer_utils.py) check doesn't get the weights attr.
# TODO(b/110718070): Remove when fixed.
def _is_layer(self):
return True

View File

@ -76,7 +76,7 @@ def make_variable(name,
that has fewer constraints (`variable_scope.variable()`).
In the longer term, it seems like a similar "default variable creator" method
should exist in `CheckpointableBase` instead. When this happens, we can get
should exist in `Trackable` instead. When this happens, we can get
rid of this temporary solution.
TODO(fchollet): remove this method when no longer needed.

View File

@ -44,10 +44,10 @@ from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.checkpointable import layer_utils as checkpointable_layer_utils
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
@ -152,7 +152,7 @@ class Network(base_layer.Layer):
# empty lists shouldn't cause issues; adding or removing them will not break
# checkpoints, but may cause "all Python objects matched" assertions to fail
# (in which case less strict assertions may be substituted if necessary).
@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def _base_init(self, name=None):
# The following are implemented as property functions:
# self.trainable_weights
@ -206,10 +206,10 @@ class Network(base_layer.Layer):
self._outbound_nodes = []
self._inbound_nodes = []
self._checkpointable_saver = (
checkpointable_utils.saver_with_op_caching(self))
self._trackable_saver = (
trackable_utils.saver_with_op_caching(self))
@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def _init_graph_network(self, inputs, outputs, name=None):
self._call_convention = (base_layer_utils
.CallConvention.EXPLICIT_INPUTS_ARGUMENT)
@ -309,7 +309,7 @@ class Network(base_layer.Layer):
for layer in self._output_layers:
self.output_names.append(layer.name)
@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def _init_subclassed_network(self, name=None, dynamic=False):
self._base_init(name=name)
self._is_graph_network = False
@ -370,20 +370,20 @@ class Network(base_layer.Layer):
return base_layer_utils.CallConvention.POSITIONAL_ARGUMENTS_ARE_INPUTS
def _track_layers(self, layers):
"""Add Checkpointable dependencies on a list of Layers."""
"""Add Trackable dependencies on a list of Layers."""
weight_layer_index = 0
for layer_index, layer in enumerate(layers):
if layer.weights:
# Keep a separate index for layers which have weights. This allows users
# to insert Layers without weights anywhere in the network without
# breaking checkpoints.
self._track_checkpointable(
self._track_trackable(
layer, name='layer_with_weights-%d' % weight_layer_index,
overwrite=True)
weight_layer_index += 1
# Even if it doesn't have weights, we should still track everything in
# case it has/will have Checkpointable dependencies.
self._track_checkpointable(
# case it has/will have Trackable dependencies.
self._track_trackable(
layer, name='layer-%d' % layer_index, overwrite=True)
def __setattr__(self, name, value):
@ -393,18 +393,18 @@ class Network(base_layer.Layer):
if all(
isinstance(v, (base_layer.Layer,
data_structures.CheckpointableDataStructure)) or
checkpointable_layer_utils.has_weights(v) for v in nest.flatten(value)):
data_structures.TrackableDataStructure)) or
trackable_layer_utils.has_weights(v) for v in nest.flatten(value)):
try:
self._is_graph_network
except AttributeError:
raise RuntimeError('It looks like you are subclassing `Model` and you '
'forgot to call `super(YourClass, self).__init__()`.'
' Always start with this line.')
# Keep track of checkpointable objects,
# Keep track of trackable objects,
# for the needs of `self.save/save_weights`.
value = data_structures.sticky_attribute_assignment(
checkpointable=self, value=value, name=name)
trackable=self, value=value, name=name)
super(Network, self).__setattr__(name, value)
# Keep track of metric instance created in subclassed model/layer.
@ -481,7 +481,7 @@ class Network(base_layer.Layer):
@property
def layers(self):
return checkpointable_layer_utils.filter_empty_layer_containers(
return trackable_layer_utils.filter_empty_layer_containers(
self._layers)
def get_layer(self, name=None, index=None):
@ -542,7 +542,7 @@ class Network(base_layer.Layer):
losses += layer.losses
return losses
@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def _clear_losses(self):
"""Used every step in eager to reset losses."""
self._eager_losses = []
@ -682,14 +682,14 @@ class Network(base_layer.Layer):
@property
def trainable_weights(self):
return checkpointable_layer_utils.gather_trainable_weights(
return trackable_layer_utils.gather_trainable_weights(
trainable=self.trainable,
sub_layers=self._layers,
extra_variables=self._trainable_weights)
@property
def non_trainable_weights(self):
return checkpointable_layer_utils.gather_non_trainable_weights(
return trackable_layer_utils.gather_non_trainable_weights(
trainable=self.trainable,
sub_layers=self._layers,
extra_variables=self._non_trainable_weights + self._trainable_weights)
@ -1397,7 +1397,7 @@ class Network(base_layer.Layer):
session = backend.get_session()
optimizer = getattr(self, 'optimizer', None)
if (optimizer
and not isinstance(optimizer, checkpointable.Checkpointable)):
and not isinstance(optimizer, trackable.Trackable)):
logging.warning(
('This model was compiled with a Keras optimizer (%s) but is being '
'saved in TensorFlow format with `save_weights`. The model\'s '
@ -1405,7 +1405,7 @@ class Network(base_layer.Layer):
'the TensorFlow format the optimizer\'s state will not be '
'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.')
% (optimizer,))
self._checkpointable_saver.save(filepath, session=session)
self._trackable_saver.save(filepath, session=session)
# Record this checkpoint so it's visible from tf.train.latest_checkpoint.
checkpoint_management.update_checkpoint_state_internal(
save_dir=os.path.dirname(filepath),
@ -1464,7 +1464,7 @@ class Network(base_layer.Layer):
# The checkpoint is not readable in TensorFlow format. Try HDF5.
save_format = 'h5'
if save_format == 'tf':
status = self._checkpointable_saver.restore(filepath)
status = self._trackable_saver.restore(filepath)
if by_name:
raise NotImplementedError(
'Weights may only be loaded based on topology into Models when '
@ -1474,7 +1474,7 @@ class Network(base_layer.Layer):
session = backend.get_session()
# Restore existing variables (if any) immediately, and set up a
# streaming restore for any variables created in the future.
checkpointable_utils.streaming_restore(status=status, session=session)
trackable_utils.streaming_restore(status=status, session=session)
status.assert_nontrivial_match()
return status
if h5py is None:

View File

@ -28,7 +28,7 @@ from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export
@ -93,7 +93,7 @@ class Sequential(training.Model):
```
"""
@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def __init__(self, layers=None, name=None):
super(Sequential, self).__init__(name=name)
self.supports_masking = True
@ -112,7 +112,7 @@ class Sequential(training.Model):
# Historically, `sequential.layers` only returns layers that were added
# via `add`, and omits the auto-generated `InputLayer` that comes at the
# bottom of the stack.
# `CheckpointableBase` manages the `_layers` attributes and does filtering
# `Trackable` manages the `_layers` attributes and does filtering
# over it.
layers = super(Sequential, self).layers
if layers and isinstance(layers[0], input_layer.InputLayer):
@ -123,7 +123,7 @@ class Sequential(training.Model):
def dynamic(self):
return any(layer.dynamic for layer in self.layers)
@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def add(self, layer):
"""Adds a layer instance on top of the layer stack.
@ -193,7 +193,7 @@ class Sequential(training.Model):
self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def pop(self):
"""Removes the last layer in the model.

View File

@ -48,7 +48,7 @@ from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
@ -141,7 +141,7 @@ class Model(Network):
return super(Model, self).get_weights()
return super(Model, self).get_weights()
@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def compile(self,
optimizer,
loss=None,
@ -245,9 +245,9 @@ class Model(Network):
self.optimizer = optimizer
# We've disabled automatic dependency tracking for this method, but do want
# to add a checkpoint dependency on the optimizer if it's checkpointable.
if isinstance(self.optimizer, checkpointable.Checkpointable):
self._track_checkpointable(
# to add a checkpoint dependency on the optimizer if it's trackable.
if isinstance(self.optimizer, trackable.Trackable):
self._track_trackable(
self.optimizer, name='optimizer', overwrite=True)
self.loss = loss
self._compile_metrics = metrics or []
@ -2663,7 +2663,7 @@ class Model(Network):
'However we received `validation_data=%s`' % validation_data)
return val_x, val_y, val_sample_weight
@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def _set_inputs(self, inputs, outputs=None, training=None):
"""Set model's input and output specs based on the input data received.

View File

@ -42,7 +42,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_cudnn_rnn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
@ -78,7 +78,7 @@ class StackedRNNCells(Layer):
```
"""
@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def __init__(self, cells, **kwargs):
for cell in cells:
if not hasattr(cell, 'call'):
@ -443,7 +443,7 @@ class RNN(Layer):
```
"""
@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def __init__(self,
cell,
return_sequences=False,
@ -468,8 +468,8 @@ class RNN(Layer):
self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False)
super(RNN, self).__init__(**kwargs)
self.cell = cell
if isinstance(cell, checkpointable.Checkpointable):
self._track_checkpointable(self.cell, name='cell')
if isinstance(cell, trackable.Trackable):
self._track_trackable(self.cell, name='cell')
self.return_sequences = return_sequences
self.return_state = return_state
self.go_backwards = go_backwards

View File

@ -40,7 +40,7 @@ from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import util as checkpointable_util
from tensorflow.python.training.tracking import util as trackable_util
from tensorflow.python.util import nest
# Used for nested input/output/state RNN test.
@ -715,7 +715,7 @@ class RNNTest(keras_parameterized.TestCase):
[tuple(o.as_list()) for o in output_shape],
expected_output_shape)
def test_checkpointable_dependencies(self):
def test_trackable_dependencies(self):
rnn = keras.layers.SimpleRNN
x = np.random.random((2, 2, 2))
y = np.random.random((2, 2))
@ -728,8 +728,8 @@ class RNNTest(keras_parameterized.TestCase):
model.fit(x, y, epochs=1, batch_size=1)
# check whether the model variables are present in the
# checkpointable list of objects
checkpointed_objects = set(checkpointable_util.list_objects(model))
# trackable list of objects
checkpointed_objects = set(trackable_util.list_objects(model))
for v in model.variables:
self.assertIn(v, checkpointed_objects)

View File

@ -29,7 +29,7 @@ from tensorflow.python.keras.layers.recurrent import _standardize_args
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
@ -46,7 +46,7 @@ class Wrapper(Layer):
layer: The layer to be wrapped.
"""
@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def __init__(self, layer, **kwargs):
assert isinstance(layer, Layer)
self.layer = layer
@ -170,7 +170,7 @@ class TimeDistributed(Wrapper):
'`Layer` instance. You passed: {input}'.format(input=layer))
super(TimeDistributed, self).__init__(layer, **kwargs)
self.supports_masking = True
self._track_checkpointable(layer, name='layer')
self._track_trackable(layer, name='layer')
def _get_shape_tuple(self, init_tuple, tensor, start_idx, int_shape=None):
"""Finds non-specific dimensions in the static shapes.
@ -386,7 +386,7 @@ class Bidirectional(Wrapper):
```
"""
@checkpointable.no_automatic_dependency_tracking
@trackable.no_automatic_dependency_tracking
def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
if not isinstance(layer, Layer):
raise ValueError(
@ -419,8 +419,8 @@ class Bidirectional(Wrapper):
self._num_constants = None
super(Bidirectional, self).__init__(layer, **kwargs)
self.input_spec = layer.input_spec
self._track_checkpointable(self.forward_layer, name='forward_layer')
self._track_checkpointable(self.backward_layer, name='backward_layer')
self._track_trackable(self.forward_layer, name='forward_layer')
self._track_trackable(self.backward_layer, name='backward_layer')
@property
def trainable(self):

View File

@ -27,7 +27,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import util as checkpointable_util
from tensorflow.python.training.tracking import util as trackable_util
class _RNNCellWithConstants(keras.layers.Layer):
@ -88,8 +88,8 @@ class TimeDistributedTest(test.TestCase):
model.get_config()
# check whether the model variables are present in the
# checkpointable list of objects
checkpointed_objects = set(checkpointable_util.list_objects(model))
# trackable list of objects
checkpointed_objects = set(trackable_util.list_objects(model))
for v in model.variables:
self.assertIn(v, checkpointed_objects)
@ -303,8 +303,8 @@ class BidirectionalTest(test.TestCase):
model.fit(x, y, epochs=1, batch_size=1)
# check whether the model variables are present in the
# checkpointable list of objects
checkpointed_objects = set(checkpointable_util.list_objects(model))
# trackable list of objects
checkpointed_objects = set(trackable_util.list_objects(model))
for v in model.variables:
self.assertIn(v, checkpointed_objects)

View File

@ -37,7 +37,7 @@ from tensorflow.python.keras import testing_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.training.tracking import util as trackable_utils
@test_util.run_all_in_graph_and_eager_modes
@ -131,7 +131,7 @@ class KerasSumTest(test.TestCase):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
m = metrics.Sum()
checkpoint = checkpointable_utils.Checkpoint(sum=m)
checkpoint = trackable_utils.Checkpoint(sum=m)
self.evaluate(variables.variables_initializer(m.variables))
# update state
@ -149,7 +149,7 @@ class KerasSumTest(test.TestCase):
# restore to a different checkpoint sum object
restore_sum = metrics.Sum()
restore_checkpoint = checkpointable_utils.Checkpoint(sum=restore_sum)
restore_checkpoint = trackable_utils.Checkpoint(sum=restore_sum)
status = restore_checkpoint.restore(save_path)
restore_update = restore_sum(300.)
status.assert_consumed().run_restore_ops()
@ -267,7 +267,7 @@ class KerasMeanTest(test.TestCase):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
m = metrics.Mean()
checkpoint = checkpointable_utils.Checkpoint(mean=m)
checkpoint = trackable_utils.Checkpoint(mean=m)
self.evaluate(variables.variables_initializer(m.variables))
# update state
@ -285,7 +285,7 @@ class KerasMeanTest(test.TestCase):
# restore to a different checkpoint mean object
restore_mean = metrics.Mean()
restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean)
restore_checkpoint = trackable_utils.Checkpoint(mean=restore_mean)
status = restore_checkpoint.restore(save_path)
restore_update = restore_mean(300.)
status.assert_consumed().run_restore_ops()

View File

@ -35,7 +35,7 @@ from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.tracking import data_structures
try:
import h5py # pylint:disable=g-import-not-at-top

View File

@ -43,7 +43,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
@ -70,7 +70,7 @@ def _deduplicate_indexed_slices(values, indices):
@six.add_metaclass(abc.ABCMeta)
@keras_export("keras.optimizers.Optimizer")
class OptimizerV2(checkpointable.Checkpointable):
class OptimizerV2(trackable.Trackable):
"""Updated base class for optimizers.
This class defines the API to add Ops to train a model. You never use this
@ -244,9 +244,9 @@ class OptimizerV2(checkpointable.Checkpointable):
self._weights = []
self._iterations = None
# For implementing Checkpointable. Stores information about how to restore
# For implementing Trackable. Stores information about how to restore
# slot variables which have not yet been created
# (checkpointable._CheckpointPosition objects).
# (trackable._CheckpointPosition objects).
# {slot_name :
# {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
# ... }
@ -829,7 +829,7 @@ class OptimizerV2(checkpointable.Checkpointable):
return x.value()
# ---------------
# For implementing the checkpointable interface
# For implementing the trackable interface
# ---------------
def _restore_slot_variable(self, slot_name, variable, slot_variable):
@ -860,8 +860,8 @@ class OptimizerV2(checkpointable.Checkpointable):
slot variable needs to be restored).
Args:
slot_variable_position: A `checkpointable._CheckpointPosition` object
indicating the slot variable `Checkpointable` object to be restored.
slot_variable_position: A `trackable._CheckpointPosition` object
indicating the slot variable `Trackable` object to be restored.
slot_name: The name of this `Optimizer`'s slot to restore into.
variable: The variable object this slot is being created for.
"""
@ -879,7 +879,7 @@ class OptimizerV2(checkpointable.Checkpointable):
# (aside from double initialization), and makes variable creator scopes
# behave the same way they do when graph building.
and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access
initializer = checkpointable.CheckpointInitialValue(
initializer = trackable.CheckpointInitialValue(
checkpoint_position=slot_variable_position)
slot_variable = self.add_slot(
var=variable,

View File

@ -40,7 +40,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util.tf_export import keras_export
@ -710,19 +710,19 @@ class Nadam(Optimizer):
return dict(list(base_config.items()) + list(config.items()))
class TFOptimizer(Optimizer, checkpointable.Checkpointable):
class TFOptimizer(Optimizer, trackable.Trackable):
"""Wrapper class for native TensorFlow optimizers.
"""
def __init__(self, optimizer, iterations=None): # pylint: disable=super-init-not-called
self.optimizer = optimizer
self._track_checkpointable(optimizer, name='optimizer')
self._track_trackable(optimizer, name='optimizer')
if iterations is None:
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
else:
self.iterations = iterations
self._track_checkpointable(self.iterations, name='global_step')
self._track_trackable(self.iterations, name='global_step')
def apply_gradients(self, grads):
self.optimizer.apply_gradients(grads, global_step=self.iterations)

View File

@ -40,7 +40,7 @@ from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import training as training_module
from tensorflow.python.training.checkpointable import util as checkpointable
from tensorflow.python.training.tracking import util as trackable
try:
import h5py # pylint:disable=g-import-not-at-top
@ -994,7 +994,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_incompatible_checkpoint(self):
save_path = checkpointable.Checkpoint().save(
save_path = trackable.Checkpoint().save(
os.path.join(self.get_temp_dir(), 'ckpt'))
m = keras.Model()
with self.assertRaisesRegexp(AssertionError, 'Nothing to load'):

View File

@ -37,7 +37,7 @@ from tensorflow.python.saved_model import model_utils
from tensorflow.python.saved_model import save as save_lib
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.checkpointable import graph_view
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
@ -300,10 +300,10 @@ def _export_mode(
# not counting optimizer objects. Optimizer objects are ignored because
# if the model has not trained, the slot variables will not have been
# created yet.
# TODO(b/113179535): Replace with checkpointable equivalence.
# TODO(b/113179535): Replace with trackable equivalence.
_assert_same_non_optimizer_objects(model, model_graph, clone, g)
# TODO(b/113178242): Use value transfer for checkpointable objects.
# TODO(b/113178242): Use value transfer for trackable objects.
clone.load_weights(checkpoint_path)
# Add graph and variables to SavedModel.
@ -361,7 +361,7 @@ def _create_signature_def_map(model, mode):
def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph): # pylint: disable=unused-argument
"""Asserts model and clone contain the same checkpointable objects."""
"""Asserts model and clone contain the same trackable objects."""
# TODO(fchollet, kathywu): make sure this works in eager mode.
return True

View File

@ -38,7 +38,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.training.checkpointable import util as checkpointable
from tensorflow.python.training.tracking import util as trackable
class HashTableTest(test.TestCase):
@ -1691,7 +1691,7 @@ class MutableHashTableOpTest(test.TestCase):
table = lookup_ops.MutableHashTable(
dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)
checkpoint = checkpointable.Checkpoint(table=table, v0=v0, v1=v1)
checkpoint = trackable.Checkpoint(table=table, v0=v0, v1=v1)
self.evaluate([v0.initializer, v1.initializer])
# Check that the parameter nodes have been initialized.
@ -1716,7 +1716,7 @@ class MutableHashTableOpTest(test.TestCase):
constant_op.constant([12, 24], dtypes.int64)))
self.assertAllEqual(2, self.evaluate(table.size()))
checkpoint = checkpointable.Checkpoint(table=table, v0=v0, v1=v1)
checkpoint = trackable.Checkpoint(table=table, v0=v0, v1=v1)
# Restore the saved values in the parameter nodes.
checkpoint.restore(save_path).run_restore_ops()
@ -2512,7 +2512,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
checkpoint=True,
initial_num_buckets=32)
save_checkpoint = checkpointable.Checkpoint(table=save_table)
save_checkpoint = trackable.Checkpoint(table=save_table)
self.assertAllEqual(0, self.evaluate(save_table.size()))
self.evaluate(save_table.insert(keys, values))
@ -2538,7 +2538,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual(2, self.evaluate(load_table.size()))
self.assertAllEqual(64, len(self.evaluate(load_table.export()[0])))
restore_checkpoint = checkpointable.Checkpoint(table=load_table)
restore_checkpoint = trackable.Checkpoint(table=load_table)
# Restore the saved values in the parameter nodes.
restore_checkpoint.restore(save_path).run_restore_ops()

View File

@ -49,7 +49,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import nest
@ -2804,7 +2804,7 @@ class RNNCellTest(test.TestCase, parameterized.TestCase):
wrapper(array_ops.ones([1, 1]),
state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32))
self.evaluate([v.initializer for v in cell.variables])
checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper)
checkpoint = trackable_utils.Checkpoint(wrapper=wrapper)
prefix = os.path.join(self.get_temp_dir(), "ckpt")
self.evaluate(cell._bias.assign([40.]))
save_path = checkpoint.save(prefix)

View File

@ -26,7 +26,7 @@ from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
@ -554,7 +554,7 @@ class Layer(base_layer.Layer):
def __setattr__(self, value, name):
# By-pass the automatic dependency tracking performed by the parent Layer.
super(checkpointable.Checkpointable, self).__setattr__(value, name)
super(trackable.Trackable, self).__setattr__(value, name)
def _add_elements_to_collection(elements, collection_list):

View File

@ -13,7 +13,7 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:util",
"//tensorflow/python:variables",
"//tensorflow/python/training/checkpointable:tracking",
"//tensorflow/python/training/tracking",
"@six_archive//:six",
],
)

View File

@ -27,7 +27,7 @@ import six
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
@ -149,7 +149,7 @@ def with_name_scope(unbound_method):
@tf_export("Module", "experimental.Module")
class Module(six.with_metaclass(ModuleMetaclass, tracking.AutoCheckpointable)):
class Module(six.with_metaclass(ModuleMetaclass, tracking.AutoTrackable)):
"""Base neural network module class.
A module is a named container for `tf.Variable`s, other `tf.Module`s and
@ -375,7 +375,7 @@ def camel_to_snake(value):
return _CAMEL_TO_SNAKE_R.sub(r"_\1", value).lower()
# AutoCheckpointable adds object attributes that users will not expect us to
# AutoTrackable adds object attributes that users will not expect us to
# include when flattening (these reference dependencies reachable via other
# object attributes).
AUTO_CHECKPOINTABLE_ATTRS = ("_unconditional_checkpoint_dependencies",

View File

@ -43,7 +43,7 @@ from tensorflow.python.ops.gen_boosted_trees_ops import is_boosted_trees_quantil
# pylint: enable=unused-import
from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking
class PruningMode(object):

View File

@ -38,10 +38,10 @@ from tensorflow.python.ops import string_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_lookup_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.training.checkpointable import base as checkpointable_base
from tensorflow.python.training.checkpointable import tracking as checkpointable
from tensorflow.python.training.saver import BaseSaverBuilder
# pylint: enable=wildcard-import
from tensorflow.python.training.tracking import base as trackable_base
from tensorflow.python.training.tracking import tracking as trackable
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@ -102,7 +102,7 @@ def _check_table_dtypes(table, key_dtype, value_dtype):
(table.value_dtype, value_dtype))
class LookupInterface(checkpointable.TrackableResource):
class LookupInterface(trackable.TrackableResource):
"""Represent a lookup table that persists across different steps."""
def __init__(self, key_dtype, value_dtype):
@ -165,8 +165,8 @@ class InitializableLookupTableBase(LookupInterface):
self._default_value = ops.convert_to_tensor(
default_value, dtype=self._value_dtype)
self._default_value.get_shape().merge_with(tensor_shape.scalar())
if isinstance(initializer, checkpointable_base.Checkpointable):
self._initializer = self._track_checkpointable(
if isinstance(initializer, trackable_base.Trackable):
self._initializer = self._track_trackable(
initializer, "_initializer")
self._resource_handle = self.create_resource()
self._init_op = self.initialize()
@ -314,7 +314,7 @@ class HashTable(InitializableLookupTableBase):
return exported_keys, exported_values
class TableInitializerBase(checkpointable_base.Checkpointable):
class TableInitializerBase(trackable_base.Trackable):
"""Base class for lookup table initializers."""
def __init__(self, key_dtype, value_dtype):
@ -543,8 +543,8 @@ class TextFileInitializer(TableInitializerBase):
self._vocab_size = vocab_size
self._delimiter = delimiter
self._name = name
self._filename = self._track_checkpointable(
checkpointable.TrackableAsset(filename),
self._filename = self._track_trackable(
trackable.TrackableAsset(filename),
"_filename")
super(TextFileInitializer, self).__init__(key_dtype, value_dtype)

View File

@ -43,7 +43,7 @@ from tensorflow.python.ops import variables
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_resource_variable_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated
@ -505,8 +505,8 @@ class ResourceVariable(variables.VariableV1):
if constraint is not None and not callable(constraint):
raise ValueError("The `constraint` argument must be a callable.")
if isinstance(initial_value, checkpointable.CheckpointInitialValue):
self._maybe_initialize_checkpointable()
if isinstance(initial_value, trackable.CheckpointInitialValue):
self._maybe_initialize_trackable()
self._update_uid = initial_value.checkpoint_position.restore_uid
initial_value = initial_value.wrapped_value
@ -1684,7 +1684,7 @@ def copy_to_graph_uninitialized(var):
constraint=var._constraint,
dtype=var.dtype,
name=var._shared_name)
new_variable._maybe_initialize_checkpointable()
new_variable._maybe_initialize_trackable()
# pylint: enable=protected-access
return new_variable

View File

@ -50,7 +50,7 @@ from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@ -1095,8 +1095,8 @@ class _RNNCellWrapperV1(RNNCell):
def __init__(self, cell):
super(_RNNCellWrapperV1, self).__init__()
self._cell = cell
if isinstance(cell, checkpointable.Checkpointable):
self._track_checkpointable(self._cell, name="cell")
if isinstance(cell, trackable.Trackable):
self._track_trackable(self._cell, name="cell")
def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
"""Calls the wrapped cell and performs the wrapping logic.
@ -1611,8 +1611,8 @@ class DeviceWrapper(RNNCell):
"""
super(DeviceWrapper, self).__init__()
self._cell = cell
if isinstance(cell, checkpointable.Checkpointable):
self._track_checkpointable(self._cell, name="cell")
if isinstance(cell, trackable.Trackable):
self._track_trackable(self._cell, name="cell")
self._device = device
@property
@ -1678,11 +1678,11 @@ class MultiRNNCell(RNNCell):
self._cells = cells
for cell_number, cell in enumerate(self._cells):
# Add Checkpointable dependencies on these cells so their variables get
# Add Trackable dependencies on these cells so their variables get
# saved with this object when using object-based saving.
if isinstance(cell, checkpointable.Checkpointable):
# TODO(allenl): Track down non-Checkpointable callers.
self._track_checkpointable(cell, name="cell-%d" % (cell_number,))
if isinstance(cell, trackable.Trackable):
# TODO(allenl): Track down non-Trackable callers.
self._track_trackable(cell, name="cell-%d" % (cell_number,))
self._state_is_tuple = state_is_tuple
if not state_is_tuple:
if any(nest.is_sequence(c.state_size) for c in self._cells):

View File

@ -27,7 +27,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.checkpointable import \
from tensorflow.python.training.tracking import \
tracking
from tensorflow.python.util.tf_export import tf_export
@ -144,7 +144,7 @@ def _shape_tensor(shape):
@tf_export("random.experimental.Generator")
class Generator(tracking.AutoCheckpointable):
class Generator(tracking.AutoTrackable):
"""Random-number generator.
It uses Variable to manage its internal state.

View File

@ -26,8 +26,8 @@ from tensorflow.python.eager import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import util as checkpointable_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import util as trackable_util
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.deprecation import deprecated
@ -232,7 +232,7 @@ def _skip_common_stack_elements(stacktrace, base_case):
return stacktrace[-1:]
class Template(checkpointable.Checkpointable):
class Template(trackable.Trackable):
"""Wrap a function to aid in variable sharing.
Templates are functions that create variables the first time they are called
@ -306,8 +306,8 @@ class Template(checkpointable.Checkpointable):
result = self._func(*args, **kwargs)
else:
# The first time we run, restore variables if necessary (via
# Checkpointable).
with checkpointable_util.capture_dependencies(template=self):
# Trackable).
with trackable_util.capture_dependencies(template=self):
result = self._func(*args, **kwargs)
if self._variables_created:
@ -577,8 +577,8 @@ class EagerTemplate(Template):
result = self._func(*args, **kwargs)
else:
# The first time we run, restore variables if necessary (via
# Checkpointable).
with checkpointable_util.capture_dependencies(template=self):
# Trackable).
with trackable_util.capture_dependencies(template=self):
result = self._func(*args, **kwargs)
if self._variables_created:

View File

@ -35,7 +35,7 @@ from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import compat
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.deprecation import deprecated
@ -204,7 +204,7 @@ class VariableMetaclass(type):
@tf_export("Variable", v1=[])
class Variable(six.with_metaclass(VariableMetaclass,
checkpointable.Checkpointable)):
trackable.Trackable)):
"""See the [Variables Guide](https://tensorflow.org/guide/variables).
A variable maintains state in the graph across calls to `run()`. You add a
@ -1018,8 +1018,8 @@ class Variable(six.with_metaclass(VariableMetaclass,
return self.shape
def _gather_saveables_for_checkpoint(self):
"""For implementing `Checkpointable`. This object is saveable on its own."""
return {checkpointable.VARIABLE_VALUE_KEY: self}
"""For implementing `Trackable`. This object is saveable on its own."""
return {trackable.VARIABLE_VALUE_KEY: self}
def to_proto(self, export_scope=None):
"""Converts a `Variable` to a `VariableDef` protocol buffer.
@ -1506,8 +1506,8 @@ class RefVariable(VariableV1):
# Store the graph key so optimizers know how to only retrieve variables from
# this graph.
self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
if isinstance(initial_value, checkpointable.CheckpointInitialValue):
self._maybe_initialize_checkpointable()
if isinstance(initial_value, trackable.CheckpointInitialValue):
self._maybe_initialize_trackable()
self._update_uid = initial_value.checkpoint_position.restore_uid
initial_value = initial_value.wrapped_value

View File

@ -275,7 +275,7 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:function",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/tracking:base",
],
)
@ -310,12 +310,12 @@ py_library(
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:function",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/checkpointable:graph_view",
"//tensorflow/python/training/checkpointable:object_identity",
"//tensorflow/python/training/checkpointable:tracking",
"//tensorflow/python/training/checkpointable:util",
"//tensorflow/python/training/saving:functional_saver",
"//tensorflow/python/training/tracking",
"//tensorflow/python/training/tracking:base",
"//tensorflow/python/training/tracking:graph_view",
"//tensorflow/python/training/tracking:object_identity",
"//tensorflow/python/training/tracking:util",
],
)
@ -356,10 +356,10 @@ py_library(
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//tensorflow/python:variables",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/checkpointable:graph_view",
"//tensorflow/python/training/checkpointable:tracking",
"//tensorflow/python/training/checkpointable:util",
"//tensorflow/python/training/tracking",
"//tensorflow/python/training/tracking:base",
"//tensorflow/python/training/tracking:graph_view",
"//tensorflow/python/training/tracking:util",
],
)
@ -375,7 +375,7 @@ py_library(
"//tensorflow/python:constant_op",
"//tensorflow/python:saver",
"//tensorflow/python/eager:wrap_function",
"//tensorflow/python/training/checkpointable:tracking",
"//tensorflow/python/training/tracking",
],
)
@ -392,7 +392,7 @@ tf_py_test(
"//tensorflow/python:tensor_spec",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
"//tensorflow/python/training/checkpointable:tracking",
"//tensorflow/python/training/tracking:tracking",
],
)
@ -417,7 +417,7 @@ tf_py_test(
"//tensorflow/python:tensor_spec",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
"//tensorflow/python/training/checkpointable:tracking",
"//tensorflow/python/training/tracking:tracking",
"//tensorflow/python:variables",
],
)

View File

@ -49,7 +49,7 @@ def serialize_concrete_function(concrete_function, node_ids, coder):
"captures tensor %s which is unsupported or not reachable from root. "
"One reason could be that a stateful object or a variable that the "
"function depends on is not assigned to an attribute of the serialized "
"checkpointable object "
"trackable object "
"(see SaveTest.test_captures_unreachable_variable)."
% (concrete_function.name, capture))
concrete_function_proto = saved_object_graph_pb2.SavedConcreteFunction()

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import a checkpointable object from a SavedModel."""
"""Import a trackable object from a SavedModel."""
from __future__ import absolute_import
from __future__ import division
@ -36,10 +36,10 @@ from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training.checkpointable import base
from tensorflow.python.training.checkpointable import graph_view
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import compat
from tensorflow.python.util import nest
@ -149,16 +149,16 @@ class _Loader(object):
def _restore_checkpoint(self):
"""Load state from checkpoint into the deserialized objects."""
variables_path = saved_model_utils.get_variables_path(self._export_dir)
# TODO(andresp): Clean use of private methods of CheckpointableSaver.
# TODO(andresp): Clean use of private methods of TrackableSaver.
# pylint: disable=protected-access
saver = util.CheckpointableSaver(graph_view.ObjectGraphView(self.get(0)))
saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
saver._file_prefix_placeholder = constant_op.constant(variables_path)
load_status = saver.restore(variables_path)
load_status.assert_existing_objects_matched()
checkpoint = load_status._checkpoint
# When running in eager mode, the `restore` call above has already run and
# restored the state of checkpointables, call `position.restore_ops()` will
# restored the state of trackables, call `position.restore_ops()` will
# return an empty list as there is nothing left to do. In graph mode, that
# will return the list of ops that must run to restore the object on that
# position. We have to wire them in the initializers of the objects so that
@ -205,7 +205,7 @@ class _Loader(object):
# individually callable by adding a `__call__` method to the classes of
# the objects instances that have a `__call__` property.
class _UserObject(tracking.AutoCheckpointable):
class _UserObject(tracking.AutoTrackable):
pass
return _UserObject(), setattr
@ -282,7 +282,7 @@ def load(export_dir, tags=None):
print(f(x=tf.constant([[1.]])))
```
Objects exported with `tf.saved_model.save` additionally have checkpointable
Objects exported with `tf.saved_model.save` additionally have trackable
objects and functions assigned to attributes:
```python
@ -303,9 +303,9 @@ def load(export_dir, tags=None):
`tf.saved_model.load`.
Returns:
A checkpointable object with a `signatures` attribute mapping from signature
A trackable object with a `signatures` attribute mapping from signature
keys to functions. If the SavedModel was exported by `tf.saved_model.load`,
it also points to checkpointable objects and functions which were attached
it also points to trackable objects and functions which were attached
to the exported object.
Raises:

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for checkpointable object SavedModel loading."""
"""Tests for trackable object SavedModel loading."""
from __future__ import absolute_import
from __future__ import division
@ -40,8 +40,8 @@ from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import monitored_session
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import tf_inspect
@ -63,17 +63,17 @@ class LoadTest(test.TestCase, parameterized.TestCase):
return loaded
def test_structure_import(self, cycles):
root = tracking.AutoCheckpointable()
root.dep_one = tracking.AutoCheckpointable()
root.dep_two = tracking.AutoCheckpointable()
root.dep_two.dep = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.dep_one = tracking.AutoTrackable()
root.dep_two = tracking.AutoTrackable()
root.dep_two.dep = tracking.AutoTrackable()
root.dep_three = root.dep_two.dep
imported = self.cycle(root, cycles)
self.assertIs(imported.dep_three, imported.dep_two.dep)
self.assertIsNot(imported.dep_one, imported.dep_two)
def test_variables(self, cycles):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.v1 = variables.Variable(1., trainable=True)
root.v2 = variables.Variable(2., trainable=False)
imported = self.cycle(root, cycles)
@ -83,7 +83,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertFalse(imported.v2.trainable)
def test_capture_variables(self, cycles):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.weights = variables.Variable(2.)
root.f = def_function.function(
lambda x: root.weights * x,
@ -103,7 +103,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
file1 = self._make_asset("contents 1")
file2 = self._make_asset("contents 2")
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.asset1 = tracking.TrackableAsset(file1)
root.asset2 = tracking.TrackableAsset(file2)
@ -122,7 +122,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual("contents 2", f.read())
def test_capture_assets(self, cycles):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.vocab = tracking.TrackableAsset(self._make_asset("contents"))
root.f = def_function.function(
lambda: root.vocab.asset_path,
@ -135,7 +135,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual("contents", f.read())
def test_capture_assets_in_graph(self, cycles):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.vocab = tracking.TrackableAsset(self._make_asset("contents"))
root.f = def_function.function(
lambda: root.vocab.asset_path,
@ -159,7 +159,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def test_dedup_assets(self, cycles):
vocab = self._make_asset("contents")
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.asset1 = tracking.TrackableAsset(vocab)
root.asset2 = tracking.TrackableAsset(vocab)
imported = self.cycle(root, cycles)
@ -171,7 +171,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def func(x):
return 2 * x
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func
# Add two traces.
@ -189,7 +189,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def func(x):
return 2 * x
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func
imported = self.cycle(root, cycles)
@ -200,7 +200,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def func(x):
return 2 * x
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func
imported = self.cycle(
@ -219,7 +219,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
lambda x: f(x) + 1.0,
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.g = g
imported = self.cycle(root, cycles)
imported.g(constant_op.constant([1.0]))
@ -232,7 +232,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
else:
return 7
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(func)
self.assertEqual(20, root.f(constant_op.constant(10), True).numpy())
@ -252,7 +252,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
else:
return array_ops.zeros(shape=x.shape, dtype=dtypes.float32)
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(func)
self.assertAllEqual([0.0, 0.0, 0.0],
@ -286,17 +286,17 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def test_function_no_return(self, cycles):
class CheckpointableWithOneVariable(tracking.AutoCheckpointable):
class TrackableWithOneVariable(tracking.AutoTrackable):
def __init__(self, initial_value=0.0):
super(CheckpointableWithOneVariable, self).__init__()
super(TrackableWithOneVariable, self).__init__()
self.variable = variables.Variable(initial_value)
@def_function.function
def increase(self, by=1.0):
self.variable.assign_add(by)
obj = CheckpointableWithOneVariable(5.0)
obj = TrackableWithOneVariable(5.0)
obj.increase(constant_op.constant(10.0))
self.assertEqual(15.0, obj.variable.numpy())
@ -320,7 +320,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
else:
return 7
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(func)
x = constant_op.constant(10)
@ -352,7 +352,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
named_tuple = named_tuple_type(a=input1 + input2, b=input1 * input2)
return [named_tuple, input2, {"x": 0.5}]
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(func)
result = root.f(constant_op.constant(2), constant_op.constant(3))
@ -382,7 +382,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
else:
return 7
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(func)
self.assertEqual(20, root.f(constant_op.constant(10), True).numpy())
@ -404,7 +404,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
else:
return 7
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(func)
x = constant_op.constant(10)
@ -419,10 +419,10 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(7, imported.f(x, learning_rate=0.5, epochs=3).numpy())
def test_member_function(self, cycles):
class CheckpointableWithMember(tracking.AutoCheckpointable):
class TrackableWithMember(tracking.AutoTrackable):
def __init__(self):
super(CheckpointableWithMember, self).__init__()
super(TrackableWithMember, self).__init__()
self._some_value = 20
@def_function.function
@ -432,7 +432,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
else:
return 7 + self._some_value
root = CheckpointableWithMember()
root = TrackableWithMember()
self.assertEqual(20, root.f(constant_op.constant(10), True).numpy())
self.assertEqual(27, root.f(constant_op.constant(1)).numpy())
@ -444,7 +444,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(27, imported.f(constant_op.constant(2)).numpy())
def test_side_effect_listing(self, cycles):
class M(tracking.AutoCheckpointable):
class M(tracking.AutoTrackable):
def __init__(self):
super(M, self).__init__()
@ -468,7 +468,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
lambda x: x*weight + bias,
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.weight = weight
root.bias = bias
root.g = g
@ -508,7 +508,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def h(x):
return g(x) + bias,
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.weight = weight
root.bias = bias
root.g = h
@ -521,16 +521,16 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertAllClose(grad, [3.5, 2.0])
def test_callable(self, cycles):
class M1(tracking.AutoCheckpointable):
class M1(tracking.AutoTrackable):
@def_function.function(
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
def __call__(self, x):
return x
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.m1 = M1()
root.m2 = tracking.AutoCheckpointable()
root.m2 = tracking.AutoTrackable()
root.m2.__call__ = def_function.function(
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])(
lambda x: x*3.0)
@ -553,9 +553,9 @@ class LoadTest(test.TestCase, parameterized.TestCase):
func = def_function.function(
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])(
lambda x: x*3.0)
root = tracking.AutoCheckpointable()
root.__call__ = tracking.AutoCheckpointable()
root.__call__.__call__ = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.__call__ = tracking.AutoTrackable()
root.__call__.__call__ = tracking.AutoTrackable()
root.__call__.__call__.__call__ = func
imported = self.cycle(root, cycles)
@ -564,7 +564,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(imported(x).numpy(), 3.0)
def test_load_in_graph_mode(self, cycles):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.v1 = variables.Variable(1.)
root.v2 = variables.Variable(2.)
root.f = def_function.function(
@ -585,7 +585,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(4.0, sess.run(output))
def test_load_in_func_graph(self, cycles):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.v1 = variables.Variable(1.)
root.v2 = variables.Variable(2.)
root.f = def_function.function(
@ -597,7 +597,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
path = tempfile.mkdtemp(prefix=self.get_temp_dir())
save.save(root, path)
closure = tracking.AutoCheckpointable()
closure = tracking.AutoTrackable()
@def_function.function
def func(x):
if not hasattr(closure, "model"):
@ -614,7 +614,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def func(x):
return 2 * x
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func
self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy())
@ -650,7 +650,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
tensor_spec.TensorSpec([None], dtypes.int32), True)
func.get_concrete_function(tensor_spec.TensorSpec([None], dtypes.float32))
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func
imported = self.cycle(root, cycles)
@ -674,7 +674,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def func(x):
return 2 * x
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func.get_concrete_function()
self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy())
@ -695,7 +695,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def func(x):
return 2 * x
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func.get_concrete_function()
self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy())
@ -711,7 +711,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def func(x):
return 2 * x
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func.get_concrete_function(constant_op.constant([1]))
self.assertAllEqual([4], root.f(constant_op.constant([2])).numpy())
# TODO(andresp): Fix exporting of loaded concrete functions as signatures.
@ -724,7 +724,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
input_signature=[tensor_spec.TensorSpec([None], dtypes.float32)])
def func(x):
return x ** 2.
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func.get_concrete_function()
def _compute_gradient(function):
@ -744,7 +744,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
@def_function.function
def func(x, y):
return x * (y + 1.)
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func.get_concrete_function(
tensor_spec.TensorSpec([], dtypes.float32),
tensor_spec.TensorSpec([], dtypes.float32))
@ -761,7 +761,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def func(*args):
x, y = args
return x * (y + 1.)
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func.get_concrete_function(
tensor_spec.TensorSpec([], dtypes.float32, name="x"),
tensor_spec.TensorSpec([], dtypes.float32, name="y"))
@ -782,7 +782,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
capture.assign_sub(1)
vsave = variables.Variable(1)
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = func.get_concrete_function(vsave)
root.capture = capture
self.assertEqual(1, vsave.numpy())
@ -805,7 +805,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def func(v):
return v + 1
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.func = func
root.concrete_func = func.get_concrete_function(
tensor_spec.TensorSpec(None, dtypes.int32))
@ -817,7 +817,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(2, imported.concrete_func(one).numpy())
def test_dict(self, cycles):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.variables = dict(a=variables.Variable(1.))
root.variables["b"] = variables.Variable(2.)
root.variables["c"] = 1
@ -832,7 +832,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(100., imported.funcs["conc"]().numpy())
def test_list(self, cycles):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.variables = [variables.Variable(1.)]
root.variables.append(1)
root.variables.append(variables.Variable(3.))
@ -843,7 +843,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(3, len(imported.variables))
def test_functions_list(self, cycles):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
v1 = variables.Variable(1.)
root.losses = [def_function.function(lambda: math_ops.reduce_sum(v1 ** 2))]
root.variables = [v1]
@ -865,7 +865,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def test_captured_constant(self, cycles):
const = array_ops.zeros([100])
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(lambda: const + 1.)
root.g = def_function.function(lambda: const + 2.)
self.assertAllClose(array_ops.ones([100]), root.f())
@ -885,7 +885,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def test_functions_accessed_once(self, cycles):
class Exported(tracking.AutoCheckpointable):
class Exported(tracking.AutoTrackable):
def __init__(self):
self._counter = 0
@ -905,7 +905,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(1, exported.make_func().numpy())
def test_overwritten_signatures_error(self, cycles):
exported = tracking.AutoCheckpointable()
exported = tracking.AutoTrackable()
exported.f = def_function.function(lambda: constant_op.constant(1.))
imported = self.cycle(
exported, cycles,
@ -917,7 +917,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def test_signature_loading(self, cycles):
class Exported(tracking.AutoCheckpointable):
class Exported(tracking.AutoTrackable):
def __init__(self):
self.v = variables.Variable(3.)
@ -961,7 +961,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
return def_function.function(input_signature=signature)(
lambda x: table.lookup(x)) # pylint: disable=unnecessary-lambda
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.table1 = table1
root.lookup1 = _make_lookup_function(table1)
root.table2 = table2
@ -999,7 +999,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
original_fullargspec = tf_inspect.getfullargspec(f)
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(f)
imported = self.cycle(root, cycles)
@ -1010,7 +1010,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
class SingleCycleTests(test.TestCase, parameterized.TestCase):
def test_load_with_tags(self):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
path = tempfile.mkdtemp(prefix=self.get_temp_dir())
save.save(root, path)
with self.assertRaises(ValueError):

View File

@ -27,7 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import signature_serialization
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking
class _Initializer(tracking.TrackableResource):
@ -123,7 +123,7 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
self.restore_variables(wrapped, saver)
with wrapped.graph.as_default():
init_op = loader_impl.get_init_op(meta_graph_def)
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
if init_op is not None:
asset_feed_tensors = []
asset_paths = []
@ -141,6 +141,7 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
else:
root.asset_paths = []
signature_functions = self._extract_signatures(wrapped, meta_graph_def)
root.signatures = signature_serialization.create_signature_map(
signature_functions)
root.variables = list(wrapped.graph.variables)

View File

@ -31,7 +31,7 @@ class VersionedTypeRegistration(object):
Args:
object_factory: A callable which takes a SavedUserObject proto and returns
a checkpointable object. Dependencies are added later via `setter`.
a trackable object. Dependencies are added later via `setter`.
version: An integer, the producer version of this wrapper type. When
making incompatible changes to a wrapper, add a new
`VersionedTypeRegistration` with an incremented `version`. The most
@ -45,11 +45,11 @@ class VersionedTypeRegistration(object):
with this object. `min_consumer_version` should be set to the lowest
version number which can successfully load protos saved by this
object. If no matching registration is available on load, the object
will be revived with a generic checkpointable type.
will be revived with a generic trackable type.
`min_consumer_version` and `bad_consumers` are a blunt tool, and using
them will generally break forward compatibility: previous versions of
TensorFlow will revive newly saved objects as opaque checkpointable
TensorFlow will revive newly saved objects as opaque trackable
objects rather than wrapped objects. When updating wrappers, prefer
saving new information but preserving compatibility with previous
wrapper versions. They are, however, useful for ensuring that
@ -83,7 +83,7 @@ class VersionedTypeRegistration(object):
bad_consumers=self._bad_consumers))
def from_proto(self, proto):
"""Recreate a checkpointable object from a SavedUserObject proto."""
"""Recreate a trackable object from a SavedUserObject proto."""
return self._object_factory(proto)
def should_load(self, proto):
@ -111,7 +111,7 @@ def register_revived_type(identifier, predicate, versions):
Args:
identifier: A unique string identifying this class of objects.
predicate: A Boolean predicate for this registration. Takes a
checkpointable object as an argument. If True, `type_registration` may be
trackable object as an argument. If True, `type_registration` may be
used to save and restore the object.
versions: A list of `VersionedTypeRegistration` objects.
"""
@ -138,7 +138,7 @@ def register_revived_type(identifier, predicate, versions):
def serialize(obj):
"""Create a SavedUserObject from a checkpointable object."""
"""Create a SavedUserObject from a trackable object."""
for identifier in _TYPE_IDENTIFIERS:
predicate, versions = _REVIVED_TYPE_REGISTRY[identifier]
if predicate(obj):
@ -148,15 +148,15 @@ def serialize(obj):
def deserialize(proto):
"""Create a checkpointable object from a SavedUserObject proto.
"""Create a trackable object from a SavedUserObject proto.
Args:
proto: A SavedUserObject to deserialize.
Returns:
A tuple of (checkpointable, assignment_fn) where assignment_fn has the same
A tuple of (trackable, assignment_fn) where assignment_fn has the same
signature as setattr and should be used to add dependencies to
`checkpointable` when they are available.
`trackable` when they are available.
"""
_, type_registrations = _REVIVED_TYPE_REGISTRY.get(
proto.identifier, (None, None))

View File

@ -22,10 +22,10 @@ from tensorflow.core.framework import versions_pb2
from tensorflow.python.platform import test
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.tracking import tracking
class CustomTestClass(tracking.AutoCheckpointable):
class CustomTestClass(tracking.AutoTrackable):
def __init__(self, version):
self.version = version
@ -56,7 +56,7 @@ revived_types.register_revived_type(
class RegistrationMatchingTest(test.TestCase):
def test_save_typecheck(self):
self.assertIs(revived_types.serialize(tracking.AutoCheckpointable()), None)
self.assertIs(revived_types.serialize(tracking.AutoTrackable()), None)
def test_load_identifier_not_found(self):
nothing_matches = revived_types.deserialize(

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Exports a SavedModel from a Checkpointable Python object."""
"""Exports a SavedModel from a Trackable Python object."""
from __future__ import absolute_import
from __future__ import division
@ -46,12 +46,12 @@ from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import signature_serialization
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils_impl
from tensorflow.python.training.checkpointable import base
from tensorflow.python.training.checkpointable import graph_view
from tensorflow.python.training.checkpointable import object_identity
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import object_identity
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
@ -97,13 +97,13 @@ class _AugmentedGraphView(graph_view.ObjectGraphView):
for name, dep in super(_AugmentedGraphView, self).list_dependencies(obj):
used_names.add(name)
if name in extra_dependencies:
yield base.CheckpointableReference(name, extra_dependencies[name])
yield base.TrackableReference(name, extra_dependencies[name])
else:
yield base.CheckpointableReference(name, dep)
yield base.TrackableReference(name, dep)
for name, dep in extra_dependencies.items():
if name in used_names:
continue
yield base.CheckpointableReference(name, dep)
yield base.TrackableReference(name, dep)
def list_functions(self, obj):
obj_functions = self._functions.get(obj, None)
@ -114,12 +114,12 @@ class _AugmentedGraphView(graph_view.ObjectGraphView):
class _SaveableView(object):
"""Provides a frozen view over a checkpointable root.
"""Provides a frozen view over a trackable root.
This class helps creating a single stable view over an object to save. The
saving code should access properties and functions via this class and not via
the original object as there are cases where an object construct their
checkpointable attributes and functions dynamically per call and will yield
trackable attributes and functions dynamically per call and will yield
different objects if invoked more than once.
Changes to the graph, for example adding objects, must happen in
@ -130,9 +130,9 @@ class _SaveableView(object):
def __init__(self, checkpoint_view):
self.checkpoint_view = checkpoint_view
checkpointable_objects, node_ids, slot_variables = (
trackable_objects, node_ids, slot_variables = (
self.checkpoint_view.objects_ids_and_slot_variables())
self.nodes = checkpointable_objects
self.nodes = trackable_objects
self.node_ids = node_ids
self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
self.slot_variables = slot_variables
@ -544,7 +544,7 @@ def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions):
def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
"""Save a SavedObjectGraph proto for `root`."""
# SavedObjectGraph is similar to the CheckpointableObjectGraph proto in the
# SavedObjectGraph is similar to the TrackableObjectGraph proto in the
# checkpoint. It will eventually go into the SavedModel.
proto = saved_object_graph_pb2.SavedObjectGraph()
saveable_view.fill_object_graph_proto(proto)
@ -603,7 +603,7 @@ def _write_object_proto(obj, proto, asset_file_def_index):
@tf_export("saved_model.save", v1=["saved_model.experimental.save"])
def save(obj, export_dir, signatures=None):
# pylint: disable=line-too-long
"""Exports the Checkpointable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).
"""Exports the Trackable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).
Example usage:
@ -651,7 +651,7 @@ def save(obj, export_dir, signatures=None):
`.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
on an object with a custom `.signatures` attribute will raise an exception.
Since `tf.keras.Model` objects are also Checkpointable, this function can be
Since `tf.keras.Model` objects are also Trackable, this function can be
used to export Keras models. For example, exporting with a signature
specified:
@ -737,7 +737,7 @@ def save(obj, export_dir, signatures=None):
prior to the TensorFlow 2.0 release.
Args:
obj: A checkpointable object to export.
obj: A trackable object to export.
export_dir: A directory in which to write the SavedModel.
signatures: Optional, either a `tf.function` with an input signature
specified or the result of `f.get_concrete_function` on a
@ -750,7 +750,7 @@ def save(obj, export_dir, signatures=None):
`tf.saved_model.signature_constants` module.
Raises:
ValueError: If `obj` is not checkpointable.
ValueError: If `obj` is not trackable.
@compatibility(eager)
Not supported when graph building. From TensorFlow 1.x,
@ -771,9 +771,9 @@ def save(obj, export_dir, signatures=None):
"tf.enable_eager_execution() must run first when calling it from "
"TensorFlow 1.x.")
# pylint: enable=line-too-long
if not isinstance(obj, base.Checkpointable):
if not isinstance(obj, base.Trackable):
raise ValueError(
"Expected a Checkpointable object for export, got {}.".format(obj))
"Expected a Trackable object for export, got {}.".format(obj))
checkpoint_graph_view = _AugmentedGraphView(obj)
if signatures is None:
@ -799,7 +799,7 @@ def save(obj, export_dir, signatures=None):
# making a SavedModel proto and writing it directly.
saved_model = saved_model_pb2.SavedModel()
meta_graph_def = saved_model.meta_graphs.add()
object_saver = util.CheckpointableSaver(checkpoint_graph_view)
object_saver = util.TrackableSaver(checkpoint_graph_view)
asset_info, exported_graph = _fill_meta_graph_def(
meta_graph_def, saveable_view, signatures)
saved_model.saved_model_schema_version = (

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for checkpointable object SavedModel save."""
"""Tests for trackable object SavedModel save."""
from __future__ import absolute_import
from __future__ import division
@ -41,8 +41,8 @@ from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import compat
@ -87,7 +87,7 @@ def _import_and_infer(
class SaveTest(test.TestCase):
def test_method_save_signature(self):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(
lambda x: 2. * x,
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
@ -99,7 +99,7 @@ class SaveTest(test.TestCase):
_import_and_infer(save_dir, {"x": 1.}))
def test_method_save_concrete(self):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(
lambda z: {"out": 2. * z})
root.f(constant_op.constant(1.))
@ -115,7 +115,7 @@ class SaveTest(test.TestCase):
save_dir, {"z": 1.}, signature_key="non_default_key"))
def test_non_concrete_error(self):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(lambda x: 2. * x)
root.f(constant_op.constant(1.))
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
@ -124,7 +124,7 @@ class SaveTest(test.TestCase):
save.save(root, save_dir, root.f)
def test_captures_unreachable_variable(self):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
unreachable_variable = variables.Variable([5.0, 2.0])
root.reachable_variable = variables.Variable([1.0, 3.0])
@ -143,7 +143,7 @@ class SaveTest(test.TestCase):
save.save(root, save_dir)
def test_nested_inputs(self):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(
lambda x: 2. * x[0],
input_signature=([tensor_spec.TensorSpec(None, dtypes.float32),
@ -156,7 +156,7 @@ class SaveTest(test.TestCase):
root.f.get_concrete_function()
def test_nested_outputs(self):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(lambda x: (2. * x, (3. * x, 4. * x)))
root.f(constant_op.constant(1.))
to_save = root.f.get_concrete_function(constant_op.constant(1.))
@ -177,7 +177,7 @@ class SaveTest(test.TestCase):
save.save(root, save_dir, to_save)
def test_variable(self):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.v1 = variables.Variable(3.)
root.v2 = variables.Variable(2.)
root.f = def_function.function(
@ -214,7 +214,7 @@ class SaveTest(test.TestCase):
{"x": [[3., 4.]], "y": [2.]}))
def test_single_function_default_signature(self):
model = tracking.AutoCheckpointable()
model = tracking.AutoTrackable()
model.f = def_function.function(lambda: 3., input_signature=())
model.f()
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
@ -223,7 +223,7 @@ class SaveTest(test.TestCase):
_import_and_infer(save_dir, {}))
def test_single_function_no_signature(self):
model = tracking.AutoCheckpointable()
model = tracking.AutoTrackable()
model.f = def_function.function(lambda: 3.)
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
save.save(model, save_dir)
@ -322,7 +322,7 @@ class AssetTests(test.TestCase):
f.write("alpha\nbeta\ngamma\n")
def test_asset_path_returned(self):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.path = tracking.TrackableAsset(self._vocab_path)
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
root.get_asset = def_function.function(lambda: root.path.asset_path)
@ -362,7 +362,7 @@ class AssetTests(test.TestCase):
_import_and_infer(second_dir, {"keys": ["gamma", "beta"]}))
def test_unused_asset(self):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.f = def_function.function(
lambda x: 2. * x,
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

View File

@ -1,6 +1,6 @@
syntax = "proto3";
import "tensorflow/core/protobuf/checkpointable_object_graph.proto";
import "tensorflow/core/protobuf/trackable_object_graph.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
import "tensorflow/core/framework/versions.proto";
@ -14,9 +14,9 @@ package tensorflow;
// describes the directed graph of Python objects (or equivalent in other
// languages) that make up a model, with nodes[0] at the root.
// SavedObjectGraph shares some structure with CheckpointableObjectGraph, but
// SavedObjectGraph shares some structure with TrackableObjectGraph, but
// ObjectGraph belongs to the SavedModel and contains pointers to functions and
// type information, while CheckpointableObjectGraph lives in the checkpoint and
// type information, while TrackableObjectGraph lives in the checkpoint and
// contains pointers only to variable values.
// NOTE: This protocol buffer format is experimental and subject to change.
@ -38,10 +38,9 @@ message SavedObject {
// graph.
//
// Note: only valid if kind == "object".
repeated CheckpointableObjectGraph.CheckpointableObject.ObjectReference
children = 1;
repeated TrackableObjectGraph.TrackableObject.ObjectReference children = 1;
// Removed when forking from CheckpointableObjectGraph.
// Removed when forking from TrackableObjectGraph.
reserved "attributes";
reserved 2;
@ -50,7 +49,7 @@ message SavedObject {
// depend on the others directly.
//
// Note: only valid if kind == "object".
repeated CheckpointableObjectGraph.CheckpointableObject.SlotVariableReference
repeated TrackableObjectGraph.TrackableObject.SlotVariableReference
slot_variables = 3;
oneof kind {

View File

@ -26,7 +26,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training.checkpointable import base
from tensorflow.python.training.tracking import base
from tensorflow.python.util import compat
from tensorflow.python.util import nest
@ -170,7 +170,7 @@ def _normalize_outputs(outputs, function_name, signature_key):
# saved if they contain a _SignatureMap. A ".signatures" attribute containing
# any other type (e.g. a regular dict) will raise an exception asking the user
# to first "del obj.signatures" if they want it overwritten.
class _SignatureMap(collections.Mapping, base.Checkpointable):
class _SignatureMap(collections.Mapping, base.Trackable):
"""A collection of SavedModel signatures."""
def __init__(self):
@ -205,7 +205,7 @@ revived_types.register_revived_type(
"signature_map",
lambda obj: isinstance(obj, _SignatureMap),
versions=[revived_types.VersionedTypeRegistration(
# Standard dependencies are enough to reconstruct the checkpointable
# Standard dependencies are enough to reconstruct the trackable
# items in dictionaries, so we don't need to save any extra information.
object_factory=lambda proto: _SignatureMap(),
version=1,

View File

@ -38,7 +38,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.training.checkpointable import util
from tensorflow.python.training.tracking import util
class LatestCheckpointWithRelativePaths(test.TestCase):

View File

@ -41,8 +41,8 @@ from tensorflow.python.training import queue_runner
from tensorflow.python.training import saver as training_saver
from tensorflow.python.training import session_manager as sm
from tensorflow.python.training import session_run_hook
from tensorflow.python.training.checkpointable import graph_view
from tensorflow.python.training.checkpointable import util as checkpointable_util
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import util as trackable_util
from tensorflow.python.util import function_utils
from tensorflow.python.util.tf_export import tf_export
@ -228,7 +228,7 @@ class Scaffold(object):
if self._saver is None:
self._saver = training_saver._get_saver_or_default() # pylint: disable=protected-access
# pylint: enable=g-long-lambda
if isinstance(self._saver, checkpointable_util.Checkpoint):
if isinstance(self._saver, trackable_util.Checkpoint):
self._saver = training_saver.Saver(
var_list=graph_view.ObjectGraphView(
self._saver).frozen_saveable_objects(),

View File

@ -39,7 +39,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@ -214,10 +214,10 @@ def _get_processor(v):
@tf_export(v1=["train.Optimizer"])
class Optimizer(
# Optimizers inherit from CheckpointableBase rather than Checkpointable
# Optimizers inherit from Trackable rather than AutoTrackable
# since they do most of their dependency management themselves (slot
# variables are special-cased, and non-slot variables are keyed to graphs).
checkpointable.Checkpointable):
trackable.Trackable):
"""Base class for optimizers.
This class defines the API to add Ops to train a model. You never use this
@ -333,9 +333,9 @@ class Optimizer(
# ... }
self._slots = {}
self._non_slot_dict = {}
# For implementing Checkpointable. Stores information about how to restore
# For implementing Trackable. Stores information about how to restore
# slot variables which have not yet been created
# (checkpointable._CheckpointPosition objects).
# (trackable._CheckpointPosition objects).
# {slot_name :
# {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
# ... }
@ -796,7 +796,7 @@ class Optimizer(
key = (name, graph)
v = self._non_slot_dict.get(key, None)
if v is None:
self._maybe_initialize_checkpointable()
self._maybe_initialize_trackable()
distribution_strategy = distribute_ctx.get_strategy()
with distribution_strategy.extended.colocate_vars_with(colocate_with):
if eager:
@ -809,19 +809,19 @@ class Optimizer(
use_resource=resource_variable_ops.is_resource_variable(
colocate_with))
# Restore this variable by name if necessary, but don't add a
# Checkpointable dependency. Optimizers return the current graph's
# Trackable dependency. Optimizers return the current graph's
# non-slot variables from _checkpoint_dependencies explicitly rather
# than unconditionally adding dependencies (since there may be multiple
# non-slot variables with the same name in different graphs, trying to
# save all of them would result in errors).
self._handle_deferred_dependencies(name=name, checkpointable=v)
self._handle_deferred_dependencies(name=name, trackable=v)
self._non_slot_dict[key] = v
return v
@property
def _checkpoint_dependencies(self):
"""From Checkpointable. Gather graph-specific non-slot variables to save."""
"""From Trackable. Gather graph-specific non-slot variables to save."""
current_graph_non_slot_variables = []
current_graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
for (name, _), variable_object in sorted(self._non_slot_dict.items(),
@ -829,13 +829,13 @@ class Optimizer(
key=lambda item: item[0][0]):
if variable_object._graph_key == current_graph_key: # pylint: disable=protected-access
current_graph_non_slot_variables.append(
checkpointable.CheckpointableReference(
trackable.TrackableReference(
name=name, ref=variable_object))
return (super(Optimizer, self)._checkpoint_dependencies
+ current_graph_non_slot_variables)
def _lookup_dependency(self, name):
"""From Checkpointable. Find a non-slot variable in the current graph."""
"""From Trackable. Find a non-slot variable in the current graph."""
unconditional = super(Optimizer, self)._lookup_dependency(name)
if unconditional is not None:
return unconditional
@ -1140,7 +1140,7 @@ class Optimizer(
return named_slots[_var_key(var)]
# --------------
# For implementing the Checkpointable interface.
# For implementing the Trackable interface.
# --------------
def _restore_slot_variable(self, slot_name, variable, slot_variable):
@ -1171,8 +1171,8 @@ class Optimizer(
slot variable needs to be restored).
Args:
slot_variable_position: A `checkpointable._CheckpointPosition` object
indicating the slot variable `Checkpointable` object to be restored.
slot_variable_position: A `trackable._CheckpointPosition` object
indicating the slot variable `Trackable` object to be restored.
slot_name: The name of this `Optimizer`'s slot to restore into.
variable: The variable object this slot is being created for.
"""
@ -1190,7 +1190,7 @@ class Optimizer(
# (aside from double initialization), and makes variable creator scopes
# behave the same way they do when graph building.
and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access
initializer = checkpointable.CheckpointInitialValue(
initializer = trackable.CheckpointInitialValue(
checkpoint_position=slot_variable_position)
slot_variable = self._get_or_make_slot(
var=variable,

View File

@ -17,7 +17,7 @@
"""Save and restore variables.
Symbols in this file are deprecated. See replacements in
tensorflow/python/training/checkpointable and tensorflow/python/training/saving.
tensorflow/python/training/trackable and tensorflow/python/training/saving.
"""
from __future__ import absolute_import
from __future__ import division
@ -29,10 +29,9 @@ import time
import uuid
import numpy as np
from tensorflow.core.protobuf import checkpointable_object_graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.core.protobuf import trackable_object_graph_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.eager import context
@ -51,9 +50,9 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
@ -1605,9 +1604,9 @@ def object_graph_key_mapping(checkpoint_path):
"""
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
object_graph_string = reader.get_tensor(
checkpointable.OBJECT_GRAPH_PROTO_KEY)
trackable.OBJECT_GRAPH_PROTO_KEY)
object_graph_proto = (
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
trackable_object_graph_pb2.TrackableObjectGraph())
object_graph_proto.ParseFromString(object_graph_string)
names_to_keys = {}
for node in object_graph_proto.nodes:

View File

@ -73,9 +73,9 @@ from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training import saver_test_utils
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base as checkpointable_base
from tensorflow.python.training.checkpointable import tracking as checkpointable_tracking
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.training.tracking import base as trackable_base
from tensorflow.python.training.tracking import tracking as trackable_tracking
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import compat
@ -2775,15 +2775,15 @@ class ScopedGraphTest(test.TestCase):
self.assertEqual(2.0, self.evaluate(var_dict2["variable2:0"]))
class _OwnsAVariableSimple(checkpointable_base.Checkpointable):
"""A Checkpointable object which can be saved using a tf.train.Saver."""
class _OwnsAVariableSimple(trackable_base.Trackable):
"""A Trackable object which can be saved using a tf.train.Saver."""
def __init__(self):
self.non_dep_variable = variable_scope.get_variable(
name="non_dep_variable", initializer=6., use_resource=True)
def _gather_saveables_for_checkpoint(self):
return {checkpointable_base.VARIABLE_VALUE_KEY: self.non_dep_variable}
return {trackable_base.VARIABLE_VALUE_KEY: self.non_dep_variable}
# The Saver sorts by name before parsing, so we need a name property.
@property
@ -2808,8 +2808,8 @@ class _MirroringSaveable(
self._mirrored_variable.assign(tensor))
class _OwnsMirroredVariables(checkpointable_base.Checkpointable):
"""A Checkpointable object which returns a more complex SaveableObject."""
class _OwnsMirroredVariables(trackable_base.Trackable):
"""A Trackable object which returns a more complex SaveableObject."""
def __init__(self):
self.non_dep_variable = variable_scope.get_variable(
@ -2823,7 +2823,7 @@ class _OwnsMirroredVariables(checkpointable_base.Checkpointable):
primary_variable=self.non_dep_variable,
mirrored_variable=self.mirrored,
name=name)
return {checkpointable_base.VARIABLE_VALUE_KEY: _saveable_factory}
return {trackable_base.VARIABLE_VALUE_KEY: _saveable_factory}
# The Saver sorts by name before parsing, so we need a name property.
@property
@ -2831,11 +2831,11 @@ class _OwnsMirroredVariables(checkpointable_base.Checkpointable):
return self.non_dep_variable.name
class NonLayerCheckpointable(checkpointable_tracking.AutoCheckpointable):
class NonLayerTrackable(trackable_tracking.AutoTrackable):
def __init__(self):
super(NonLayerCheckpointable, self).__init__()
self.a_variable = checkpointable_utils.add_variable(
super(NonLayerTrackable, self).__init__()
self.a_variable = trackable_utils.add_variable(
self, name="a_variable", shape=[])
@ -2846,19 +2846,19 @@ class MyModel(training.Model):
super(MyModel, self).__init__()
self._named_dense = core.Dense(1, use_bias=True)
self._second = core.Dense(1, use_bias=False)
# We can still track Checkpointables which aren't Layers.
self._non_layer = NonLayerCheckpointable()
# We can still track Trackables which aren't Layers.
self._non_layer = NonLayerTrackable()
def call(self, values):
ret = self._second(self._named_dense(values))
return ret
class CheckpointableCompatibilityTests(test.TestCase):
class TrackableCompatibilityTests(test.TestCase):
# TODO(allenl): Track down python3 reference cycles in these tests.
@test_util.run_in_graph_and_eager_modes
def testNotSaveableButIsCheckpointable(self):
def testNotSaveableButIsTrackable(self):
v = _OwnsAVariableSimple()
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
@ -2923,13 +2923,13 @@ class CheckpointableCompatibilityTests(test.TestCase):
model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
optimizer_step = training_util.get_or_create_global_step()
root_checkpointable = checkpointable_utils.Checkpoint(
root_trackable = trackable_utils.Checkpoint(
optimizer=optimizer, model=model, optimizer_step=optimizer_step)
train_op = optimizer.minimize(
functools.partial(model, input_value),
global_step=optimizer_step)
self.evaluate(checkpointable_utils.gather_initializers(
root_checkpointable))
self.evaluate(trackable_utils.gather_initializers(
root_trackable))
self.evaluate(train_op)
# A regular variable, a slot variable, and a non-slot Optimizer variable
# with known values to check when loading.
@ -2938,24 +2938,24 @@ class CheckpointableCompatibilityTests(test.TestCase):
var=model._named_dense.bias, name="m").assign([2.]))
beta1_power, _ = optimizer._get_beta_accumulators()
self.evaluate(beta1_power.assign(3.))
return root_checkpointable
return root_trackable
def _set_sentinels(self, root_checkpointable):
self.evaluate(root_checkpointable.model._named_dense.bias.assign([101.]))
def _set_sentinels(self, root_trackable):
self.evaluate(root_trackable.model._named_dense.bias.assign([101.]))
self.evaluate(
root_checkpointable.optimizer.get_slot(
var=root_checkpointable.model._named_dense.bias, name="m")
root_trackable.optimizer.get_slot(
var=root_trackable.model._named_dense.bias, name="m")
.assign([102.]))
beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
beta1_power, _ = root_trackable.optimizer._get_beta_accumulators()
self.evaluate(beta1_power.assign(103.))
def _check_sentinels(self, root_checkpointable):
def _check_sentinels(self, root_trackable):
self.assertAllEqual(
[1.], self.evaluate(root_checkpointable.model._named_dense.bias))
[1.], self.evaluate(root_trackable.model._named_dense.bias))
self.assertAllEqual([2.], self.evaluate(
root_checkpointable.optimizer.get_slot(
var=root_checkpointable.model._named_dense.bias, name="m")))
beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
root_trackable.optimizer.get_slot(
var=root_trackable.model._named_dense.bias, name="m")))
beta1_power, _ = root_trackable.optimizer._get_beta_accumulators()
self.assertAllEqual(3., self.evaluate(beta1_power))
def testVariableNotFoundErrorRaised(self):
@ -3012,13 +3012,13 @@ class CheckpointableCompatibilityTests(test.TestCase):
save_graph = ops_lib.Graph()
with save_graph.as_default(), self.session(graph=save_graph) as sess:
root = self._initialized_model()
object_saver = checkpointable_utils.Checkpoint(root=root)
object_saver = trackable_utils.Checkpoint(root=root)
save_path = object_saver.save(file_prefix=checkpoint_prefix)
# An incompatible object-based checkpoint to check error messages
var = resource_variable_ops.ResourceVariable(1., name="a")
self.evaluate(var.initializer)
second_saver = checkpointable_utils.Checkpoint(v=var)
second_saver = trackable_utils.Checkpoint(v=var)
second_path = second_saver.save(file_prefix=os.path.join(
checkpoint_directory, "second"))
@ -3046,7 +3046,7 @@ class CheckpointableCompatibilityTests(test.TestCase):
save_graph = ops_lib.Graph()
with save_graph.as_default(), self.session(graph=save_graph):
root = self._initialized_model()
object_saver = checkpointable_utils.Checkpoint(root=root)
object_saver = trackable_utils.Checkpoint(root=root)
save_path = object_saver.save(file_prefix=checkpoint_prefix)
with context.eager_mode():

View File

@ -49,7 +49,7 @@ py_library(
deps = [
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:variables",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/tracking:base",
"@six_archive//:six",
],
)

View File

@ -26,8 +26,8 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.tracking import base as trackable
# Op names which identify variable reads which should be saved.
@ -137,7 +137,7 @@ def saveable_objects_for_op(op, name):
if not isinstance(name, six.string_types):
raise TypeError(
"names_to_saveables must be a dict mapping string names to "
"checkpointable operations. Name is not a string: %s" % name)
"trackable operations. Name is not a string: %s" % name)
if isinstance(op, saveable_object.SaveableObject):
yield op
elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
@ -165,11 +165,11 @@ def saveable_objects_for_op(op, name):
yield ResourceVariableSaveable(
variable, variable._save_slice_info.spec, name)
# pylint: enable=protected-access
elif isinstance(op, checkpointable.Checkpointable) and not isinstance(
elif isinstance(op, trackable.Trackable) and not isinstance(
op, variables.Variable):
# pylint: disable=protected-access
for attr, factory in op._gather_saveables_for_checkpoint().items():
if attr == checkpointable.VARIABLE_VALUE_KEY:
if attr == trackable.VARIABLE_VALUE_KEY:
# Keep original name for classes masquerading as variables.
full_name = name
else:
@ -250,13 +250,13 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
names_to_saveables[name].append(var)
else:
names_to_saveables[name] = [var]
elif (isinstance(var, checkpointable.Checkpointable)
elif (isinstance(var, trackable.Trackable)
and not isinstance(var, variables.Variable)):
checkpointable_saveables = [
trackable_saveables = [
(factory() if callable(factory) else factory)
for factory in var._gather_saveables_for_checkpoint().values()]
names_to_saveables.update(
op_list_to_dict(checkpointable_saveables))
op_list_to_dict(trackable_saveables))
else:
# Variables (reference and resource) have an _in_graph_mode property
# indicating whether they were created in a graph building context. We
@ -326,7 +326,7 @@ def validate_and_slice_inputs(names_to_saveables):
Raises:
TypeError: If any of the keys are not strings or any of the
values are not one of Tensor or Variable or a checkpointable operation.
values are not one of Tensor or Variable or a trackable operation.
ValueError: If the same operation is given in more than one value
(this also applies to slices of SlicedVariables).
"""

View File

@ -44,18 +44,18 @@ OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"
# A key indicating a variable's value in an object's checkpointed Tensors
# (Checkpointable._gather_saveables_for_checkpoint). If this is the only key and
# (Trackable._gather_saveables_for_checkpoint). If this is the only key and
# the object has no dependencies, then its value may be restored on object
# creation (avoiding double assignment when executing eagerly).
VARIABLE_VALUE_KEY = "VARIABLE_VALUE"
OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"
CheckpointableReference = collections.namedtuple(
"CheckpointableReference",
TrackableReference = collections.namedtuple(
"TrackableReference",
[
# The local name for this dependency.
"name",
# The Checkpointable object being referenced.
# The Trackable object being referenced.
"ref"
])
@ -195,26 +195,26 @@ class CheckpointPosition(object):
Args:
checkpoint: A _CheckpointRestoreCoordinator object.
proto_id: The index of this object in CheckpointableObjectGraph.nodes.
proto_id: The index of this object in TrackableObjectGraph.nodes.
"""
self._checkpoint = checkpoint
self._proto_id = proto_id
def restore(self, checkpointable):
"""Restore this value into `checkpointable`."""
def restore(self, trackable):
"""Restore this value into `trackable`."""
with ops.init_scope():
if self.bind_object(checkpointable):
if self.bind_object(trackable):
# This object's correspondence with a checkpointed object is new, so
# process deferred restorations for it and its dependencies.
restore_ops = checkpointable._restore_from_checkpoint_position(self) # pylint: disable=protected-access
restore_ops = trackable._restore_from_checkpoint_position(self) # pylint: disable=protected-access
if restore_ops:
self._checkpoint.new_restore_ops(restore_ops)
def bind_object(self, checkpointable):
def bind_object(self, trackable):
"""Set a checkpoint<->object correspondence and process slot variables.
Args:
checkpointable: The object to record a correspondence for.
trackable: The object to record a correspondence for.
Returns:
True if this is a new assignment, False if this object has already been
mapped to a checkpointed `Object` proto.
@ -222,13 +222,13 @@ class CheckpointPosition(object):
AssertionError: If another object is already bound to the `Object` proto.
"""
checkpoint = self.checkpoint
checkpoint.all_python_objects.add(checkpointable)
checkpoint.all_python_objects.add(trackable)
current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
if current_assignment is None:
checkpoint.object_by_proto_id[self._proto_id] = checkpointable
checkpoint.object_by_proto_id[self._proto_id] = trackable
for deferred_slot_restoration in (
checkpoint.deferred_slot_restorations.pop(self._proto_id, ())):
checkpointable._create_or_restore_slot_variable( # pylint: disable=protected-access
trackable._create_or_restore_slot_variable( # pylint: disable=protected-access
slot_variable_position=CheckpointPosition(
checkpoint=checkpoint,
proto_id=deferred_slot_restoration.slot_variable_id),
@ -244,7 +244,7 @@ class CheckpointPosition(object):
checkpoint.deferred_slot_restorations.setdefault(
slot_restoration.optimizer_id, []).append(
_DeferredSlotVariableRestoration(
original_variable=checkpointable,
original_variable=trackable,
slot_variable_id=slot_restoration.slot_variable_id,
slot_name=slot_restoration.slot_name))
else:
@ -252,7 +252,7 @@ class CheckpointPosition(object):
slot_variable_position=CheckpointPosition(
checkpoint=checkpoint,
proto_id=slot_restoration.slot_variable_id),
variable=checkpointable,
variable=trackable,
slot_name=slot_restoration.slot_name)
return True # New assignment
else:
@ -260,14 +260,14 @@ class CheckpointPosition(object):
# we don't need to do anything besides check that the mapping is
# consistent (if the dependency DAG is not a tree then there are
# multiple paths to the same object).
if current_assignment is not checkpointable:
if current_assignment is not trackable:
logging.warning(
("Inconsistent references when loading the checkpoint into this "
"object graph. Either the Checkpointable object references in the "
"object graph. Either the Trackable object references in the "
"Python program have changed in an incompatible way, or the "
"checkpoint was generated in an incompatible program.\n\nTwo "
"checkpoint references resolved to different objects (%s and %s).")
% (current_assignment, checkpointable))
% (current_assignment, trackable))
return False # Not a new assignment
def is_simple_variable(self):
@ -306,7 +306,7 @@ class CheckpointPosition(object):
def _gather_ops_or_named_saveables(self):
"""Looks up or creates SaveableObjects which don't have cached ops."""
saveables = self.checkpointable._gather_saveables_for_checkpoint() # pylint: disable=protected-access
saveables = self.trackable._gather_saveables_for_checkpoint() # pylint: disable=protected-access
# Name saveables based on the name this object had when it was checkpointed.
named_saveables = {}
python_saveables = []
@ -334,7 +334,7 @@ class CheckpointPosition(object):
# attribute, we can re-use it to avoid re-creating some ops when graph
# building.
saveable_list = saveables_cache.get(
self.checkpointable, {}).get(serialized_tensor.name, (None,))
self.trackable, {}).get(serialized_tensor.name, (None,))
if len(saveable_list) == 1:
# Almost every attribute will have exactly one SaveableObject.
saveable, = saveable_list
@ -348,7 +348,7 @@ class CheckpointPosition(object):
# the SaveableObject.
if serialized_tensor.checkpoint_key not in saveable.name:
saveable = None
del saveables_cache[self.checkpointable]
del saveables_cache[self.trackable]
break
if saveable is None:
# If there was no cached SaveableObject, we should check if the Python
@ -361,7 +361,7 @@ class CheckpointPosition(object):
# checkpoint was loaded.
if not serialized_tensor.optional_restore:
self._checkpoint.unused_attributes.setdefault(
self.checkpointable, []).append(serialized_tensor.name)
self.trackable, []).append(serialized_tensor.name)
continue
if callable(saveable_factory):
saveable = saveable_factory(name=serialized_tensor.checkpoint_key)
@ -369,7 +369,7 @@ class CheckpointPosition(object):
saveable = saveable_factory
if saveables_cache is not None:
saveables_cache.setdefault(
self.checkpointable, {})[serialized_tensor.name] = [saveable]
self.trackable, {})[serialized_tensor.name] = [saveable]
if isinstance(saveable, PythonStateSaveable):
python_saveables.append(saveable)
else:
@ -379,7 +379,7 @@ class CheckpointPosition(object):
def restore_ops(self):
"""Create or fetch restore ops for this object's attributes.
Requires that the `Checkpointable` Python object has been bound to an object
Requires that the `Trackable` Python object has been bound to an object
ID in the checkpoint.
Returns:
@ -398,7 +398,7 @@ class CheckpointPosition(object):
return self._checkpoint
@property
def checkpointable(self):
def trackable(self):
return self._checkpoint.object_by_proto_id[self._proto_id]
@property
@ -436,11 +436,11 @@ _SlotVariableRestoration = collections.namedtuple(
def no_automatic_dependency_tracking(method):
"""Disables automatic dependency tracking on attribute assignment.
Use to decorate any method of a Checkpointable object. Attribute assignment in
Use to decorate any method of a Trackable object. Attribute assignment in
that method will not add dependencies (also respected in Model). Harmless if
used in a class which does not do automatic dependency tracking (which means
it's safe to use in base classes which may have subclasses which also inherit
from Checkpointable).
from Trackable).
Args:
method: The method to decorate.
@ -461,37 +461,37 @@ def no_automatic_dependency_tracking(method):
target=method, decorator_func=_method_wrapper)
class Checkpointable(object):
"""Base class for `Checkpointable` objects without automatic dependencies.
class Trackable(object):
"""Base class for `Trackable` objects without automatic dependencies.
This class has no __setattr__ override for performance reasons. Dependencies
must be added explicitly. Unless attribute assignment is performance-critical,
use `AutoCheckpointable` instead. Use `Checkpointable` for `isinstance`
use `AutoTrackable` instead. Use `Trackable` for `isinstance`
checks.
"""
# Checkpointable does not do automatic dependency tracking, but uses the
# Trackable does not do automatic dependency tracking, but uses the
# no_automatic_dependency_tracking decorator so it can avoid adding
# dependencies if a subclass is Checkpointable / inherits from Model (both of
# dependencies if a subclass is Trackable / inherits from Model (both of
# which have __setattr__ overrides).
@no_automatic_dependency_tracking
def _maybe_initialize_checkpointable(self):
def _maybe_initialize_trackable(self):
"""Initialize dependency management.
Not __init__, since most objects will forget to call it.
"""
if hasattr(self, "_unconditional_checkpoint_dependencies"):
# __init__ already called. This check means that we don't need
# Checkpointable.__init__() in the constructor of every TensorFlow object.
# Trackable.__init__() in the constructor of every TensorFlow object.
return
# A list of CheckpointableReference objects. Some classes implementing
# `Checkpointable`, notably `Optimizer`s, may override the
# A list of TrackableReference objects. Some classes implementing
# `Trackable`, notably `Optimizer`s, may override the
# _checkpoint_dependencies property with conditional dependencies
# (e.g. based on the current graph when saving).
self._unconditional_checkpoint_dependencies = []
# Maps names -> Checkpointable objects
# Maps names -> Trackable objects
self._unconditional_dependency_names = {}
# Restorations for other Checkpointable objects on which this object may
# Restorations for other Trackable objects on which this object may
# eventually depend. Maps local name -> CheckpointPosition list. Optimizers
# tack on conditional dependencies, and so need separate management of
# deferred dependencies too.
@ -530,8 +530,8 @@ class Checkpointable(object):
May be overridden to include conditional dependencies.
Returns:
A list of `CheckpointableReference` objects indicating named
`Checkpointable` dependencies which should be saved along with this
A list of `TrackableReference` objects indicating named
`Trackable` dependencies which should be saved along with this
object.
"""
return self._unconditional_checkpoint_dependencies
@ -540,7 +540,7 @@ class Checkpointable(object):
def _deferred_dependencies(self):
"""A dictionary with deferred dependencies.
Stores restorations for other Checkpointable objects on which this object
Stores restorations for other Trackable objects on which this object
may eventually depend. May be overridden by sub-classes (e.g. Optimizers use
conditional dependencies based the current graph, and so need separate
management of deferred dependencies too).
@ -559,7 +559,7 @@ class Checkpointable(object):
Args:
name: The local name of the dependency.
Returns:
A `Checkpointable` object, or `None` if no dependency by this name was
A `Trackable` object, or `None` if no dependency by this name was
found.
"""
return self._unconditional_dependency_names.get(name, None)
@ -568,9 +568,9 @@ class Checkpointable(object):
self, name, shape=None, dtype=dtypes.float32,
initializer=None, getter=None, overwrite=False,
**kwargs_for_getter):
"""Restore-on-create for a variable be saved with this `Checkpointable`.
"""Restore-on-create for a variable be saved with this `Trackable`.
If the user has requested that this object or another `Checkpointable` which
If the user has requested that this object or another `Trackable` which
depends on this object be restored from a checkpoint (deferred loading
before variable object creation), `initializer` may be ignored and the value
from the checkpoint used instead.
@ -592,7 +592,7 @@ class Checkpointable(object):
Raises:
ValueError: If the variable name is not unique.
"""
self._maybe_initialize_checkpointable()
self._maybe_initialize_trackable()
with ops.init_scope():
if context.executing_eagerly():
# If this is a variable with a single Tensor stored in the checkpoint,
@ -608,11 +608,11 @@ class Checkpointable(object):
isinstance(initializer, CheckpointInitialValue)
and (initializer.restore_uid
> checkpoint_initializer.restore_uid))):
# If multiple Checkpointable objects are "creating" the same variable
# If multiple Trackable objects are "creating" the same variable
# via the magic of custom getters, the one with the highest restore UID
# (the one called last) has to make the final initializer. If another
# custom getter interrupts this process by overwriting the initializer,
# then we'll catch that when we call _track_checkpointable. So this is
# then we'll catch that when we call _track_trackable. So this is
# "best effort" to set the initializer with the highest restore UID.
initializer = checkpoint_initializer
shape = None
@ -624,12 +624,12 @@ class Checkpointable(object):
# assign again. It will add this variable to our dependencies, and if there
# is a non-trivial restoration queued, it will handle that. This also
# handles slot variables.
if not overwrite or isinstance(new_variable, Checkpointable):
return self._track_checkpointable(new_variable, name=name,
overwrite=overwrite)
if not overwrite or isinstance(new_variable, Trackable):
return self._track_trackable(new_variable, name=name,
overwrite=overwrite)
else:
# TODO(allenl): Some variable types are not yet supported. Remove this
# fallback once all get_variable() return types are Checkpointable.
# fallback once all get_variable() return types are Trackable.
return new_variable
def _preload_simple_restoration(self, name, shape):
@ -668,46 +668,46 @@ class Checkpointable(object):
return CheckpointInitialValue(
checkpoint_position=checkpoint_position, shape=shape)
def _track_checkpointable(self, checkpointable, name, overwrite=False):
"""Declare a dependency on another `Checkpointable` object.
def _track_trackable(self, trackable, name, overwrite=False):
"""Declare a dependency on another `Trackable` object.
Indicates that checkpoints for this object should include variables from
`checkpointable`.
`trackable`.
Variables in a checkpoint are mapped to `Checkpointable`s based on the names
Variables in a checkpoint are mapped to `Trackable`s based on the names
provided when the checkpoint was written. To avoid breaking existing
checkpoints when modifying a class, neither variable names nor dependency
names (the names passed to `_track_checkpointable`) may change.
names (the names passed to `_track_trackable`) may change.
Args:
checkpointable: A `Checkpointable` which this object depends on.
name: A local name for `checkpointable`, used for loading checkpoints into
trackable: A `Trackable` which this object depends on.
name: A local name for `trackable`, used for loading checkpoints into
the correct objects.
overwrite: Boolean, whether silently replacing dependencies is OK. Used
for __setattr__, where throwing an error on attribute reassignment would
be inappropriate.
Returns:
`checkpointable`, for convenience when declaring a dependency and
`trackable`, for convenience when declaring a dependency and
assigning to a member variable in one statement.
Raises:
TypeError: If `checkpointable` does not inherit from `Checkpointable`.
TypeError: If `trackable` does not inherit from `Trackable`.
ValueError: If another object is already tracked by this name.
"""
self._maybe_initialize_checkpointable()
if not isinstance(checkpointable, Checkpointable):
self._maybe_initialize_trackable()
if not isinstance(trackable, Trackable):
raise TypeError(
("Checkpointable._track_checkpointable() passed type %s, not a "
"Checkpointable.") % (type(checkpointable),))
new_reference = CheckpointableReference(name=name, ref=checkpointable)
("Trackable._track_trackable() passed type %s, not a "
"Trackable.") % (type(trackable),))
new_reference = TrackableReference(name=name, ref=trackable)
current_object = self._lookup_dependency(name)
if (current_object is not None
and current_object is not checkpointable):
and current_object is not trackable):
if not overwrite:
raise ValueError(
("Called Checkpointable._track_checkpointable() with name='%s', "
"but a Checkpointable with this name is already declared as a "
("Called Trackable._track_trackable() with name='%s', "
"but a Trackable with this name is already declared as a "
"dependency. Names must be unique (or overwrite=True).") % (name,))
# This is a weird thing to do, but we're not going to stop people from
# using __setattr__.
@ -718,20 +718,20 @@ class Checkpointable(object):
elif current_object is None:
self._unconditional_checkpoint_dependencies.append(new_reference)
self._handle_deferred_dependencies(
name=name, checkpointable=checkpointable)
self._unconditional_dependency_names[name] = checkpointable
return checkpointable
name=name, trackable=trackable)
self._unconditional_dependency_names[name] = trackable
return trackable
def _handle_deferred_dependencies(self, name, checkpointable):
"""Pop and load any deferred checkpoint restores into `checkpointable`.
def _handle_deferred_dependencies(self, name, trackable):
"""Pop and load any deferred checkpoint restores into `trackable`.
This method does not add a new dependency on `checkpointable`, but it does
This method does not add a new dependency on `trackable`, but it does
check if any outstanding/deferred dependencies have been queued waiting for
this dependency to be added (matched based on `name`). If so,
`checkpointable` and its dependencies are restored. The restorations are
`trackable` and its dependencies are restored. The restorations are
considered fulfilled and so are deleted.
`_track_checkpointable` is more appropriate for adding a
`_track_trackable` is more appropriate for adding a
normal/unconditional dependency, and includes handling for deferred
restorations. This method allows objects such as `Optimizer` to use the same
restoration logic while managing conditional dependencies themselves, by
@ -741,25 +741,25 @@ class Checkpointable(object):
Args:
name: The name of the dependency within this object (`self`), used to
match `checkpointable` with values saved in a checkpoint.
checkpointable: The Checkpointable object to restore (inheriting from
`Checkpointable`).
match `trackable` with values saved in a checkpoint.
trackable: The Trackable object to restore (inheriting from
`Trackable`).
"""
self._maybe_initialize_checkpointable()
checkpointable._maybe_initialize_checkpointable() # pylint: disable=protected-access
self._maybe_initialize_trackable()
trackable._maybe_initialize_trackable() # pylint: disable=protected-access
deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
for checkpoint_position in sorted(
deferred_dependencies_list,
key=lambda restore: restore.checkpoint.restore_uid,
reverse=True):
checkpoint_position.restore(checkpointable)
checkpoint_position.restore(trackable)
# Pass on any name-based restores queued in this object.
for name_based_restore in sorted(
self._name_based_restores,
key=lambda checkpoint: checkpoint.restore_uid,
reverse=True):
checkpointable._name_based_attribute_restore(name_based_restore) # pylint: disable=protected-access
trackable._name_based_attribute_restore(name_based_restore) # pylint: disable=protected-access
def _restore_from_checkpoint_position(self, checkpoint_position):
"""Restore this object and its dependencies (may be deferred)."""
@ -772,7 +772,7 @@ class Checkpointable(object):
while visit_queue:
current_position = visit_queue.popleft()
restore_ops.extend(nest.flatten(
current_position.checkpointable # pylint: disable=protected-access
current_position.trackable # pylint: disable=protected-access
._single_restoration_from_checkpoint_position(
checkpoint_position=current_position,
visit_queue=visit_queue)))
@ -781,7 +781,7 @@ class Checkpointable(object):
def _single_restoration_from_checkpoint_position(
self, checkpoint_position, visit_queue):
"""Restore this object, and either queue its dependencies or defer them."""
self._maybe_initialize_checkpointable()
self._maybe_initialize_trackable()
checkpoint = checkpoint_position.checkpoint
# If the UID of this restore is lower than our current update UID, we don't
# need to actually restore the object. However, we should pass the
@ -802,7 +802,7 @@ class Checkpointable(object):
self._deferred_dependencies.setdefault(child.local_name, []).append(
child_position)
else:
if child_position.bind_object(checkpointable=local_object):
if child_position.bind_object(trackable=local_object):
# This object's correspondence is new, so dependencies need to be
# visited. Delay doing it so that we get a breadth-first dependency
# resolution order (shallowest paths first). The caller is responsible
@ -818,7 +818,7 @@ class Checkpointable(object):
or variables easily converted to `SaveableObject`s (as in `tf.train.Saver`'s
`var_list` constructor argument).
`SaveableObjects` have a name set, which Checkpointable needs to generate
`SaveableObjects` have a name set, which Trackable needs to generate
itself. So rather than returning `SaveableObjects` directly, this method
should return a dictionary of callables which take `name` arguments and
return `SaveableObjects` with that name.
@ -861,10 +861,10 @@ class Checkpointable(object):
state_callback=_state_callback)}
def _list_functions_for_serialization(self):
"""Lists the functions of this checkpointable to serialize.
"""Lists the functions of this trackable to serialize.
Internal sub-classes can override this with specific logic. E.g.
`AutoCheckpointable` provides an implementation that returns the `attr`
`AutoTrackable` provides an implementation that returns the `attr`
that return functions.
Returns:

View File

@ -22,29 +22,29 @@ import os
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import base
from tensorflow.python.training.checkpointable import util
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import util
class InterfaceTests(test.TestCase):
def testOverwrite(self):
root = base.Checkpointable()
leaf = base.Checkpointable()
root._track_checkpointable(leaf, name="leaf")
root = base.Trackable()
leaf = base.Trackable()
root._track_trackable(leaf, name="leaf")
(current_name, current_dependency), = root._checkpoint_dependencies
self.assertIs(leaf, current_dependency)
self.assertEqual("leaf", current_name)
duplicate_name_dep = base.Checkpointable()
duplicate_name_dep = base.Trackable()
with self.assertRaises(ValueError):
root._track_checkpointable(duplicate_name_dep, name="leaf")
root._track_checkpointable(duplicate_name_dep, name="leaf", overwrite=True)
root._track_trackable(duplicate_name_dep, name="leaf")
root._track_trackable(duplicate_name_dep, name="leaf", overwrite=True)
(current_name, current_dependency), = root._checkpoint_dependencies
self.assertIs(duplicate_name_dep, current_dependency)
self.assertEqual("leaf", current_name)
def testAddVariableOverwrite(self):
root = base.Checkpointable()
root = base.Trackable()
a = root._add_variable_with_custom_getter(
name="v", shape=[], getter=variable_scope.get_variable)
self.assertEqual([root, a], util.list_objects(root))
@ -61,15 +61,15 @@ class InterfaceTests(test.TestCase):
getter=variable_scope.get_variable)
def testAssertConsumedWithUnusedPythonState(self):
has_config = base.Checkpointable()
has_config = base.Trackable()
has_config.get_config = lambda: {}
saved = util.Checkpoint(obj=has_config)
save_path = saved.save(os.path.join(self.get_temp_dir(), "ckpt"))
restored = util.Checkpoint(obj=base.Checkpointable())
restored = util.Checkpoint(obj=base.Trackable())
restored.restore(save_path).assert_consumed()
def testAssertConsumedFailsWithUsedPythonState(self):
has_config = base.Checkpointable()
has_config = base.Trackable()
attributes = {
"foo_attr": functools.partial(
base.PythonStringStateSaveable,
@ -78,7 +78,7 @@ class InterfaceTests(test.TestCase):
has_config._gather_saveables_for_checkpoint = lambda: attributes
saved = util.Checkpoint(obj=has_config)
save_path = saved.save(os.path.join(self.get_temp_dir(), "ckpt"))
restored = util.Checkpoint(obj=base.Checkpointable())
restored = util.Checkpoint(obj=base.Trackable())
status = restored.restore(save_path)
with self.assertRaisesRegexp(AssertionError, "foo_attr"):
status.assert_consumed()

View File

@ -1,4 +1,4 @@
"""Checkpointable data structures."""
"""Trackable data structures."""
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -28,16 +28,16 @@ from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.checkpointable import base
from tensorflow.python.training.checkpointable import layer_utils
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import layer_utils
class NoDependency(object):
"""Allows attribute assignment to `Checkpointable` objects with no dependency.
"""Allows attribute assignment to `Trackable` objects with no dependency.
Example usage:
```python
obj = Checkpointable()
obj = Trackable()
obj.has_dependency = tf.Variable(0., name="dep")
obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
assert obj.no_dependency.name == "nodep:0"
@ -61,8 +61,8 @@ def _wrap_or_unwrap(value):
"""Wraps basic data structures, unwraps NoDependency objects."""
if isinstance(value, NoDependency):
return value.value
if isinstance(value, base.Checkpointable):
return value # Skip conversion for already checkpointable objects.
if isinstance(value, base.Trackable):
return value # Skip conversion for already trackable objects.
elif isinstance(value, dict):
return _DictWrapper(value)
elif isinstance(value, list):
@ -77,19 +77,19 @@ def _wrap_or_unwrap(value):
# come up with names. Dictionaries should look like lists.
def sticky_attribute_assignment(checkpointable, name, value):
def sticky_attribute_assignment(trackable, name, value):
"""Adds dependencies, generally called from __setattr__.
This behavior is shared between Checkpointable and Model.
This behavior is shared between Trackable and Model.
Respects NoDependency indicators, but otherwise makes checkpointable objects
Respects NoDependency indicators, but otherwise makes trackable objects
out of common data structures and tracks objects by their attribute names.
Args:
checkpointable: The object to add dependencies to (generally the one having
trackable: The object to add dependencies to (generally the one having
an attribute assigned).
name: The attribute name being assigned.
value: The value being assigned. Not necessarily a checkpointable object.
value: The value being assigned. Not necessarily a trackable object.
Returns:
The value which should be stored in the attribute (unwrapped from a
@ -102,18 +102,18 @@ def sticky_attribute_assignment(checkpointable, name, value):
value = _wrap_or_unwrap(value)
if not add_dependency:
return value
if isinstance(value, base.Checkpointable):
checkpointable._track_checkpointable( # pylint: disable=protected-access
if isinstance(value, base.Trackable):
trackable._track_trackable( # pylint: disable=protected-access
value, name=name,
# Allow the user to switch the Checkpointable which is tracked by this
# Allow the user to switch the Trackable which is tracked by this
# name, since assigning a new variable to an attribute has
# historically been fine (e.g. Adam did this).
overwrite=True)
return value
class CheckpointableDataStructure(base.Checkpointable):
"""Base class for data structures which contain checkpointable objects."""
class TrackableDataStructure(base.Trackable):
"""Base class for data structures which contain trackable objects."""
def __init__(self):
self.trainable = True
@ -122,14 +122,14 @@ class CheckpointableDataStructure(base.Checkpointable):
def _track_value(self, value, name):
"""Add a dependency on `value`."""
value = sticky_attribute_assignment(
checkpointable=self, value=value, name=name)
trackable=self, value=value, name=name)
if isinstance(value, variables.Variable):
self._extra_variables.append(value)
if not isinstance(value, base.Checkpointable):
if not isinstance(value, base.Trackable):
raise ValueError(
("Only checkpointable objects (such as Layers or Optimizers) may be "
("Only trackable objects (such as Layers or Optimizers) may be "
"stored in a List object. Got %s, which does not inherit from "
"Checkpointable.") % (value,))
"Trackable.") % (value,))
if hasattr(value, "_use_resource_variables"):
# In subclassed models, legacy layers (tf.layers) must always use
# resource variables.
@ -138,7 +138,7 @@ class CheckpointableDataStructure(base.Checkpointable):
@property
def _values(self):
"""An iterable/sequence which may contain checkpointable objects."""
"""An iterable/sequence which may contain trackable objects."""
raise NotImplementedError("Abstract method")
@property
@ -148,7 +148,7 @@ class CheckpointableDataStructure(base.Checkpointable):
# they're wrapping if out of sync.
collected = []
for obj in self._values:
if (isinstance(obj, CheckpointableDataStructure)
if (isinstance(obj, TrackableDataStructure)
or layer_utils.is_layer(obj)
or layer_utils.has_weights(obj)):
collected.append(obj)
@ -215,19 +215,19 @@ class CheckpointableDataStructure(base.Checkpointable):
return id(self)
def __eq__(self, other):
# Similar to Tensors, checkpointable data structures use object-identity
# Similar to Tensors, trackable data structures use object-identity
# equality to support set/dict membership.
return self is other
class List(CheckpointableDataStructure, collections.Sequence):
"""An append-only sequence type which is checkpointable.
class List(TrackableDataStructure, collections.Sequence):
"""An append-only sequence type which is trackable.
Maintains checkpoint dependencies on its contents (which must also be
checkpointable), and forwards any `Layer` metadata such as updates and losses.
trackable), and forwards any `Layer` metadata such as updates and losses.
Note that `List` is purely a container. It lets a `tf.keras.Model` or
other checkpointable object know about its contents, but does not call any
other trackable object know about its contents, but does not call any
`Layer` instances which are added to it. To indicate a sequence of `Layer`
instances which should be called sequentially, use `tf.keras.Sequential`.
@ -248,7 +248,7 @@ class List(CheckpointableDataStructure, collections.Sequence):
return aggregation
```
This kind of wrapping is necessary because `Checkpointable` objects do not
This kind of wrapping is necessary because `Trackable` objects do not
(yet) deeply inspect regular Python data structures, so for example assigning
a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a
checkpoint dependency and does not add the `Layer` instance's weights to its
@ -284,12 +284,12 @@ class List(CheckpointableDataStructure, collections.Sequence):
return self
def append(self, value):
"""Add a new checkpointable value."""
"""Add a new trackable value."""
value = self._track_value(value, self._name_element(len(self._storage)))
self._storage.append(value)
def extend(self, values):
"""Add a sequence of checkpointable values."""
"""Add a sequence of trackable values."""
for value in values:
self.append(value)
@ -350,7 +350,7 @@ class _ListWrapper(List, collections.MutableSequence,
occupied, meaning both elements get the same names at different times) and
refuses to save.
On assignment to an attribute of a Model or Checkpointable object, Python
On assignment to an attribute of a Model or Trackable object, Python
lists are replaced with _ListWrapper. Wrapping a list in a
`tf.contrib.checkpoint.NoDependency` object prevents this.
"""
@ -410,7 +410,7 @@ class _ListWrapper(List, collections.MutableSequence,
if self._non_append_mutation:
raise ValueError(
("Unable to save the object %s (a list wrapper constructed to track "
"checkpointable TensorFlow objects). A list element was replaced "
"trackable TensorFlow objects). A list element was replaced "
"(__setitem__, __setslice__), deleted (__delitem__, __delslice__), "
"or moved (sort). In order to support restoration on object "
"creation, tracking is exclusively for append-only data structures."
@ -420,7 +420,7 @@ class _ListWrapper(List, collections.MutableSequence,
if self._external_modification:
raise ValueError(
("Unable to save the object %s (a list wrapper constructed to track "
"checkpointable TensorFlow objects). The wrapped list was modified "
"trackable TensorFlow objects). The wrapped list was modified "
"outside the wrapper (its final value was %s, its value when a "
"checkpoint dependency was added was %s), which breaks restoration "
"on object creation.\n\nIf you don't need this list checkpointed, "
@ -449,7 +449,7 @@ class _ListWrapper(List, collections.MutableSequence,
value_now = self._storage[i] if i < len_now else None
value_before = storage_copy[i] if i < len_before else None
if isinstance(value_before, base.Checkpointable):
if isinstance(value_before, base.Trackable):
self._non_append_mutation = True
if value_now is not None and value_now != value_before:
@ -457,20 +457,20 @@ class _ListWrapper(List, collections.MutableSequence,
self._name_element(i))
else:
if isinstance(self._storage[key], base.Checkpointable):
if isinstance(self._storage[key], base.Trackable):
self._non_append_mutation = True
self._storage[key] = self._track_value(value, self._name_element(key))
self._update_snapshot()
def append(self, value):
"""Add a new checkpointable value."""
"""Add a new trackable value."""
self._check_external_modification()
super(_ListWrapper, self).append(value)
self._update_snapshot()
def extend(self, values):
"""Add a sequence of checkpointable values."""
"""Add a sequence of trackable values."""
self._check_external_modification()
super(_ListWrapper, self).extend(values)
self._update_snapshot()
@ -514,14 +514,14 @@ class _ListWrapper(List, collections.MutableSequence,
del self._storage[slice(i, j)]
def _track_value(self, value, name):
"""Allows storage of non-checkpointable objects."""
"""Allows storage of non-trackable objects."""
try:
value = super(_ListWrapper, self)._track_value(value=value, name=name)
except ValueError:
# Even if this value isn't checkpointable, we need to make sure
# Even if this value isn't trackable, we need to make sure
# NoDependency objects get unwrapped.
value = sticky_attribute_assignment(
checkpointable=self, value=value, name=name)
trackable=self, value=value, name=name)
return value
def __repr__(self):
@ -534,11 +534,11 @@ class _ListWrapper(List, collections.MutableSequence,
}
class Mapping(CheckpointableDataStructure, collections.Mapping):
"""An append-only checkpointable mapping data structure with string keys.
class Mapping(TrackableDataStructure, collections.Mapping):
"""An append-only trackable mapping data structure with string keys.
Maintains checkpoint dependencies on its contents (which must also be
checkpointable), named based on its keys.
trackable), named based on its keys.
Note that once a key has been added, it may not be deleted or replaced. If
names may not be unique, see `tf.contrib.checkpoint.UniqueNameTracker`.
@ -615,7 +615,7 @@ class Mapping(CheckpointableDataStructure, collections.Mapping):
# patching all of the "wrapped" dict's methods instead of creating a wrapper
# object is an option, but not a very attractive one (replacing methods without
# creating reference cycles is difficult, and then dicts would need to be
# special cased everywhere as being checkpointable).
# special cased everywhere as being trackable).
class _DictWrapper(Mapping, collections.MutableMapping):
"""Wraps built-in dicts to support restore-on-create for variables.
@ -671,7 +671,7 @@ class _DictWrapper(Mapping, collections.MutableMapping):
raise ValueError(
"Unable to save the object %s (a dictionary wrapper constructed "
"automatically on attribute assignment). The wrapped dictionary "
"contains a non-string key which maps to a checkpointable object or "
"contains a non-string key which maps to a trackable object or "
"mutable data structure.\n\nIf you don't need this dictionary "
"checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency "
"object; it will be automatically un-wrapped and subsequently "
@ -680,7 +680,7 @@ class _DictWrapper(Mapping, collections.MutableMapping):
raise ValueError(
"Unable to save the object %s (a dictionary wrapper constructed "
"automatically on attribute assignment). A key mapping to a "
"checkpointable object was overwritten or deleted, which would "
"trackable object was overwritten or deleted, which would "
"cause problems for restoration.\n\nIf you don't need this "
"dictionary checkpointed, wrap it in a "
"tf.contrib.checkpoint.NoDependency object; it will be automatically "
@ -721,7 +721,7 @@ class _DictWrapper(Mapping, collections.MutableMapping):
self._last_wrapped_dict_snapshot = dict(self)
def _track_value(self, value, name):
"""Allows storage of non-checkpointable objects."""
"""Allows storage of non-trackable objects."""
if isinstance(name, six.string_types):
string_key = True
else:
@ -731,15 +731,15 @@ class _DictWrapper(Mapping, collections.MutableMapping):
no_dependency = isinstance(value, NoDependency)
value = super(_DictWrapper, self)._track_value(value=value, name=name)
if not (string_key or no_dependency):
# A non-string key maps to a checkpointable value. This data structure
# A non-string key maps to a trackable value. This data structure
# is not saveable.
self._non_string_key = True
return value
except ValueError:
# Even if this value isn't checkpointable, we need to make sure
# Even if this value isn't trackable, we need to make sure
# NoDependency objects get unwrapped.
return sticky_attribute_assignment(
checkpointable=self, value=value, name=name)
trackable=self, value=value, name=name)
def _name_element(self, key):
"""Don't throw errors for non-string keys."""
@ -758,19 +758,19 @@ class _DictWrapper(Mapping, collections.MutableMapping):
else:
value = _wrap_or_unwrap(value)
existing_dependency = None
if not no_dep and isinstance(value, base.Checkpointable):
if not no_dep and isinstance(value, base.Trackable):
# Non-string keys are OK as long as we have no reason to add a
# dependency on the value (either because the value is not
# checkpointable, or because it was wrapped in a NoDependency object).
# trackable, or because it was wrapped in a NoDependency object).
self._non_string_key = True
current_value = self._storage.setdefault(key, value)
if current_value is not value:
if ((not no_dep and isinstance(value, base.Checkpointable))
if ((not no_dep and isinstance(value, base.Trackable))
# We don't want to just check that the existing object is
# checkpointable, since it may have been wrapped in a NoDependency
# trackable, since it may have been wrapped in a NoDependency
# object.
or existing_dependency is not None):
# A checkpointable object was replaced under the same key; this means
# A trackable object was replaced under the same key; this means
# that restoring would be error-prone, so we'll throw an exception on
# save.
self._non_append_mutation = True
@ -781,8 +781,8 @@ class _DictWrapper(Mapping, collections.MutableMapping):
def __delitem__(self, key):
self._check_external_modification()
existing_value = self[key]
if isinstance(existing_value, base.Checkpointable):
# Deleting tracked checkpointable values means restoring is problematic,
if isinstance(existing_value, base.Trackable):
# Deleting tracked trackable values means restoring is problematic,
# so we'll throw an exception on save.
self._non_append_mutation = True
del self._storage[key]
@ -812,10 +812,10 @@ def _is_function(x):
return isinstance(x, (def_function.Function, defun.ConcreteFunction))
revived_types.register_revived_type(
"checkpointable_dict_wrapper",
"trackable_dict_wrapper",
lambda obj: isinstance(obj, _DictWrapper),
versions=[revived_types.VersionedTypeRegistration(
# Standard dependencies are enough to reconstruct the checkpointable
# Standard dependencies are enough to reconstruct the trackable
# items in dictionaries, so we don't need to save any extra information.
object_factory=lambda proto: _DictWrapper({}),
version=1,
@ -832,7 +832,7 @@ def _set_list_item(list_object, index_string, value):
revived_types.register_revived_type(
"checkpointable_list_wrapper",
"trackable_list_wrapper",
lambda obj: isinstance(obj, _ListWrapper),
versions=[revived_types.VersionedTypeRegistration(
object_factory=lambda proto: _ListWrapper([]),

View File

@ -34,9 +34,9 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
class HasList(training.Model):
@ -145,12 +145,12 @@ class ListTests(test.TestCase):
model.l2.append(second_layer)
self.assertEqual([first_layer, second_layer], model.layers)
def testNotCheckpointable(self):
class NotCheckpointable(object):
def testNotTrackable(self):
class NotTrackable(object):
pass
with self.assertRaises(ValueError):
data_structures.List([NotCheckpointable()])
data_structures.List([NotTrackable()])
def testCallNotImplemented(self):
with self.assertRaisesRegexp(TypeError, "not callable"):
@ -287,8 +287,8 @@ class ListWrapperTest(test.TestCase):
def testListWrapperBasic(self):
# _ListWrapper, unlike List, compares like the built-in list type (since it
# is used to automatically replace lists).
a = tracking.AutoCheckpointable()
b = tracking.AutoCheckpointable()
a = tracking.AutoTrackable()
b = tracking.AutoTrackable()
self.assertEqual([a, a],
[a, a])
self.assertEqual(data_structures._ListWrapper([a, a]),
@ -321,7 +321,7 @@ class ListWrapperTest(test.TestCase):
self.assertEqual([a, a], [a] + data_structures._ListWrapper([a]))
self.assertIsInstance(data_structures._ListWrapper([a]), list)
def testAcceptsNonCheckpointableContent(self):
def testAcceptsNonTrackableContent(self):
l = data_structures._ListWrapper([1, 2, 3])
self.assertEqual(l, [1, 2, 3])
@ -360,14 +360,14 @@ class ListWrapperTest(test.TestCase):
self.assertEqual(l, [1, 2, 4])
self.assertUnableToSave(l, "Unable to save .*__delslice__")
def testSetSlice_canSaveForNonCheckpointableItems(self):
def testSetSlice_canSaveForNonTrackableItems(self):
l = data_structures._ListWrapper([1, 2, 3, 4])
l[:] = 2, 8, 9, 0
self.assertEqual(l, [2, 8, 9, 0])
l._maybe_initialize_checkpointable() # pylint: disable=protected-access
l._maybe_initialize_trackable() # pylint: disable=protected-access
self.assertEqual(len(l._checkpoint_dependencies), 0) # pylint: disable=protected-access
def testSetSlice_cannotSaveIfCheckpointableModified(self):
def testSetSlice_cannotSaveIfTrackableModified(self):
v1 = resource_variable_ops.ResourceVariable(1.)
v2 = resource_variable_ops.ResourceVariable(1.)
l = data_structures._ListWrapper([1, 2, v1, v2])
@ -391,12 +391,12 @@ class ListWrapperTest(test.TestCase):
self.assertEqual(l, [1, 2, 3, 4])
# Regardless of being a no-op for the input list, we still refuse to save.
# This is intentional since otherwise we would end up with a hard to debug
# case for users (e.g. sometimes sort on a ListWrapper is checkpointable and
# case for users (e.g. sometimes sort on a ListWrapper is trackable and
# other times it is not).
self.assertUnableToSave(l, "Unable to save .*sort")
def assertUnableToSave(self, l, msg):
l._maybe_initialize_checkpointable() # pylint: disable=protected-access
l._maybe_initialize_trackable() # pylint: disable=protected-access
with self.assertRaisesRegexp(ValueError, msg):
return l._checkpoint_dependencies # pylint: disable=protected-access
@ -466,7 +466,7 @@ class MappingTests(test.TestCase):
def testLayerCollectionWithExternalMutation(self):
d = {}
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
root.wrapper = d
self.assertEqual([], root.wrapper.layers)
self.assertEqual([], root.wrapper.trainable_weights)
@ -484,7 +484,7 @@ class MappingTests(test.TestCase):
self.assertEqual(2, len(has_mappings))
self.assertNotIn(data_structures.Mapping(), has_mappings)
# In contrast to Mapping, dict wrappers are not hashable
a = tracking.AutoCheckpointable()
a = tracking.AutoTrackable()
a.d = {}
self.assertEqual({}, a.d)
self.assertFalse({} != a.d) # pylint: disable=g-explicit-bool-comparison
@ -493,7 +493,7 @@ class MappingTests(test.TestCase):
set([a.d])
def testDictWrapperBadKeys(self):
a = tracking.AutoCheckpointable()
a = tracking.AutoTrackable()
a.d = {}
a.d[1] = data_structures.List()
model = training.Model()
@ -503,7 +503,7 @@ class MappingTests(test.TestCase):
model.save_weights(save_path)
def testDictWrapperNoDependency(self):
a = tracking.AutoCheckpointable()
a = tracking.AutoTrackable()
a.d = data_structures.NoDependency({})
a.d[1] = [3]
self.assertEqual([a], util.list_objects(a))
@ -513,8 +513,8 @@ class MappingTests(test.TestCase):
model.save_weights(save_path)
model.load_weights(save_path)
def testNonStringKeyNotCheckpointableValue(self):
a = tracking.AutoCheckpointable()
def testNonStringKeyNotTrackableValue(self):
a = tracking.AutoTrackable()
a.d = {}
a.d["a"] = [3]
a.d[1] = data_structures.NoDependency([3])
@ -525,18 +525,18 @@ class MappingTests(test.TestCase):
model.save_weights(save_path)
model.load_weights(save_path)
def testNonAppendNotCheckpointable(self):
def testNonAppendNotTrackable(self):
# Non-append mutations (deleting or overwriting values) are OK when the
# values aren't tracked.
a = tracking.AutoCheckpointable()
a = tracking.AutoTrackable()
a.d = {}
a.d["a"] = [3]
a.d[1] = 3
a.d[1] = 2
self.assertEqual(2, a.d[1])
del a.d[1]
a.d[2] = data_structures.NoDependency(tracking.AutoCheckpointable())
second = tracking.AutoCheckpointable()
a.d[2] = data_structures.NoDependency(tracking.AutoTrackable())
second = tracking.AutoTrackable()
a.d[2] = data_structures.NoDependency(second)
self.assertIs(second, a.d[2])
self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
@ -598,7 +598,7 @@ class MappingTests(test.TestCase):
self.assertEqual({1: 3}, new_dict)
def testListShallowCopy(self):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
orig_list = [[1.]]
root.a = orig_list
copied = copy.copy(root.a)
@ -615,7 +615,7 @@ class MappingTests(test.TestCase):
util.list_objects(copy.copy(root.a))
def testListDeepCopy(self):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
orig_list = [[1.]]
root.a = orig_list
copied = copy.deepcopy(root.a)
@ -632,7 +632,7 @@ class MappingTests(test.TestCase):
util.list_objects(copy.deepcopy(root.a))
def testDictShallowCopy(self):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
orig_dict = {"a": [1.]}
root.a = orig_dict
copied = copy.copy(root.a)
@ -649,7 +649,7 @@ class MappingTests(test.TestCase):
util.list_objects(copy.copy(root.a))
def testDictDeepCopy(self):
root = tracking.AutoCheckpointable()
root = tracking.AutoTrackable()
orig_dict = {"a": [1.]}
root.a = orig_dict
copied = copy.deepcopy(root.a)
@ -665,9 +665,9 @@ class MappingTests(test.TestCase):
with self.assertRaises(ValueError):
util.list_objects(copy.deepcopy(root.a))
def testShallowCopyCheckpointable(self):
original = tracking.AutoCheckpointable()
original_sub = tracking.AutoCheckpointable()
def testShallowCopyTrackable(self):
original = tracking.AutoTrackable()
original_sub = tracking.AutoTrackable()
original.a = [[1.]]
original.b = {"a": original_sub}
shallow_copied = copy.copy(original)
@ -679,16 +679,16 @@ class MappingTests(test.TestCase):
self.assertIn(shallow_copied.b, shallow_deps)
self.assertIn(shallow_copied.b["a"], shallow_deps)
def testDeepCopyCheckpointable(self):
original = tracking.AutoCheckpointable()
original_sub = tracking.AutoCheckpointable()
def testDeepCopyTrackable(self):
original = tracking.AutoTrackable()
original_sub = tracking.AutoTrackable()
original.a = [[1.]]
original.b = {"a": original_sub}
deep_copied = copy.deepcopy(original)
self.assertIsNot(original, deep_copied)
self.assertIsNot(original_sub, deep_copied.b["a"])
self.assertEqual([[1.]], deep_copied.a)
self.assertIsInstance(deep_copied.b["a"], tracking.AutoCheckpointable)
self.assertIsInstance(deep_copied.b["a"], tracking.AutoTrackable)
deps = util.list_objects(deep_copied)
self.assertIn(deep_copied.a, deps)
self.assertIn(deep_copied.b, deps)

Some files were not shown because too many files have changed in this diff Show More