Add tf.contrib.checkpoint.CheckpointManager for deleting old checkpoints

Removes a bit of boilerplate from training loops (constructing a checkpoint prefix from a directory).

Also clarifies the recovery of checkpoint lists (like tf.train.Saver.recover_last_checkpoints, but automatic and more thorough). Adds a couple of fields to the CheckpointState proto to support this.

Should live in contrib until I make it work well with tf.keras.Model.save_weights. When the two are used together, save_weights needs to number its checkpoints. (There's a TODO for this.)
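Roughly, the boilerplate this removes (a hedged before/after sketch; `optimizer`, `model`, `train_step`, `num_steps`, and the directory are placeholders, not part of this change):

```python
import os
import tensorflow as tf

# Before: build the checkpoint prefix from the directory by hand.
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
prefix = os.path.join("/tmp/model", "ckpt")
checkpoint.restore(tf.train.latest_checkpoint("/tmp/model"))
for _ in range(num_steps):
  train_step()
  checkpoint.save(file_prefix=prefix)

# After: CheckpointManager owns the prefix, numbering, and deletion of old
# checkpoints.
manager = tf.contrib.checkpoint.CheckpointManager(
    checkpoint, directory="/tmp/model", max_to_keep=5)
checkpoint.restore(manager.latest_checkpoint)
for _ in range(num_steps):
  train_step()
  manager.save()
```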

PiperOrigin-RevId: 208566198
Allen Lavoie 2018-08-13 16:48:08 -07:00 committed by TensorFlower Gardener
parent b2dfe8a520
commit 7a81491366
7 changed files with 468 additions and 23 deletions


@ -31,6 +31,9 @@ Checkpointable data structures:
@@List
@@Mapping
@@UniqueNameTracker
Checkpoint management:
@@CheckpointManager
"""
from __future__ import absolute_import
@ -41,6 +44,7 @@ from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
from tensorflow.python.training.checkpoint_management import CheckpointManager
from tensorflow.python.training.checkpointable.base import CheckpointableBase
from tensorflow.python.training.checkpointable.data_structures import List
from tensorflow.python.training.checkpointable.data_structures import Mapping


@ -19,14 +19,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os.path
import re
import time
from google.protobuf import text_format
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.util.tf_export import tf_export
@ -51,7 +56,9 @@ def _GetCheckpointFilename(save_dir, latest_filename):
@tf_export("train.generate_checkpoint_state_proto")
def generate_checkpoint_state_proto(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None):
all_model_checkpoint_paths=None,
all_model_checkpoint_timestamps=None,
last_preserved_timestamp=None):
"""Generates a checkpoint state proto.
Args:
@ -61,11 +68,20 @@ def generate_checkpoint_state_proto(save_dir,
checkpoints, sorted from oldest to newest. If this is a non-empty list,
the last element must be equal to model_checkpoint_path. These paths
are also saved in the CheckpointState proto.
all_model_checkpoint_timestamps: A list of floats, indicating the number of
seconds since the Epoch when each checkpoint was generated.
last_preserved_timestamp: A float, indicating the number of seconds since
the Epoch when the last preserved checkpoint was written, e.g. due to a
`keep_checkpoint_every_n_hours` parameter (see
`tf.contrib.checkpoint.CheckpointManager` for an implementation).
Returns:
CheckpointState proto with model_checkpoint_path and
all_model_checkpoint_paths updated to either absolute paths or
relative paths to the current save_dir.
Raises:
ValueError: If `all_model_checkpoint_timestamps` was provided but its length
does not match `all_model_checkpoint_paths`.
"""
if all_model_checkpoint_paths is None:
all_model_checkpoint_paths = []
@ -76,6 +92,14 @@ def generate_checkpoint_state_proto(save_dir,
model_checkpoint_path)
all_model_checkpoint_paths.append(model_checkpoint_path)
if (all_model_checkpoint_timestamps
and (len(all_model_checkpoint_timestamps)
!= len(all_model_checkpoint_paths))):
raise ValueError(
("Checkpoint timestamps, if provided, must match checkpoint paths (got "
"paths %s and timestamps %s)")
% (all_model_checkpoint_paths, all_model_checkpoint_timestamps))
# Relative paths need to be rewritten to be relative to the "save_dir"
# if model_checkpoint_path already contains "save_dir".
if not os.path.isabs(save_dir):
@ -88,7 +112,9 @@ def generate_checkpoint_state_proto(save_dir,
coord_checkpoint_proto = CheckpointState(
model_checkpoint_path=model_checkpoint_path,
all_model_checkpoint_paths=all_model_checkpoint_paths)
all_model_checkpoint_paths=all_model_checkpoint_paths,
all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
last_preserved_timestamp=last_preserved_timestamp)
return coord_checkpoint_proto
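For reference, a hedged sketch of calling the updated signature; the paths and timestamp values below are made up:

```python
import tensorflow as tf

# Paths and timestamps are illustrative placeholders.
state = tf.train.generate_checkpoint_state_proto(
    "/tmp/model",
    model_checkpoint_path="/tmp/model/ckpt-2",
    all_model_checkpoint_paths=["/tmp/model/ckpt-1", "/tmp/model/ckpt-2"],
    all_model_checkpoint_timestamps=[1534200000., 1534203600.],
    last_preserved_timestamp=1534196400.)
# Passing a timestamp list whose length does not match
# all_model_checkpoint_paths raises ValueError.
```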
@ -97,7 +123,9 @@ def generate_checkpoint_state_proto(save_dir,
def update_checkpoint_state(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None,
latest_filename=None):
latest_filename=None,
all_model_checkpoint_timestamps=None,
last_preserved_timestamp=None):
"""Updates the content of the 'checkpoint' file.
This updates the checkpoint file containing a CheckpointState
@ -112,7 +140,13 @@ def update_checkpoint_state(save_dir,
are also saved in the CheckpointState proto.
latest_filename: Optional name of the checkpoint file. Default to
'checkpoint'.
all_model_checkpoint_timestamps: Optional list of timestamps (floats,
seconds since the Epoch) indicating when the checkpoints in
`all_model_checkpoint_paths` were created.
last_preserved_timestamp: A float, indicating the number of seconds since
the Epoch when the last preserved checkpoint was written, e.g. due to a
`keep_checkpoint_every_n_hours` parameter (see
`tf.contrib.checkpoint.CheckpointManager` for an implementation).
Raises:
RuntimeError: If any of the model checkpoint paths conflict with the file
containing CheckpointState.
@ -122,14 +156,18 @@ def update_checkpoint_state(save_dir,
model_checkpoint_path=model_checkpoint_path,
all_model_checkpoint_paths=all_model_checkpoint_paths,
latest_filename=latest_filename,
save_relative_paths=False)
save_relative_paths=False,
all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
last_preserved_timestamp=last_preserved_timestamp)
def update_checkpoint_state_internal(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None,
latest_filename=None,
save_relative_paths=False):
save_relative_paths=False,
all_model_checkpoint_timestamps=None,
last_preserved_timestamp=None):
"""Updates the content of the 'checkpoint' file.
This updates the checkpoint file containing a CheckpointState
@ -146,6 +184,13 @@ def update_checkpoint_state_internal(save_dir,
'checkpoint'.
save_relative_paths: If `True`, will write relative paths to the checkpoint
state file.
all_model_checkpoint_timestamps: Optional list of timestamps (floats,
seconds since the Epoch) indicating when the checkpoints in
`all_model_checkpoint_paths` were created.
last_preserved_timestamp: A float, indicating the number of seconds since
the Epoch when the last preserved checkpoint was written, e.g. due to a
`keep_checkpoint_every_n_hours` parameter (see
`tf.contrib.checkpoint.CheckpointManager` for an implementation).
Raises:
RuntimeError: If any of the model checkpoint paths conflict with the file
@ -168,12 +213,16 @@ def update_checkpoint_state_internal(save_dir,
ckpt = generate_checkpoint_state_proto(
save_dir,
rel_model_checkpoint_path,
all_model_checkpoint_paths=rel_all_model_checkpoint_paths)
all_model_checkpoint_paths=rel_all_model_checkpoint_paths,
all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
last_preserved_timestamp=last_preserved_timestamp)
else:
ckpt = generate_checkpoint_state_proto(
save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=all_model_checkpoint_paths)
all_model_checkpoint_paths=all_model_checkpoint_paths,
all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
last_preserved_timestamp=last_preserved_timestamp)
if coord_checkpoint_filename == ckpt.model_checkpoint_path:
raise RuntimeError("Save path '%s' conflicts with path used for "
@ -404,3 +453,217 @@ def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
suffixed_filename = ".".join([basename, meta_graph_suffix])
return suffixed_filename
# TODO(allenl): Allow tf.keras.Model instances in the constructor directly?
class CheckpointManager(object):
"""Deletes old checkpoints.
Example usage:
```python
import tensorflow as tf
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.contrib.checkpoint.CheckpointManager(
checkpoint, directory="/tmp/model", max_to_keep=5)
status = checkpoint.restore(manager.latest_checkpoint)
while True:
# train
manager.save()
```
`CheckpointManager` preserves its own state across instantiations (see the
`__init__` documentation for details). Only one should be active in a
particular directory at a time.
"""
def __init__(self, checkpoint, directory,
max_to_keep, keep_checkpoint_every_n_hours=None):
"""Configure a `CheckpointManager` for use in `directory`.
If a `CheckpointManager` was previously used in `directory`, its
state will be restored. This includes the list of managed checkpoints and
the timestamp bookkeeping necessary to support
`keep_checkpoint_every_n_hours`. The behavior of the new `CheckpointManager`
will be the same as the previous `CheckpointManager`, including cleaning up
existing checkpoints if appropriate.
Checkpoints are only considered for deletion just after a new checkpoint has
been added. At that point, `max_to_keep` checkpoints will remain in an
"active set". Once a checkpoint is preserved by
`keep_checkpoint_every_n_hours` it will not be deleted by this
`CheckpointManager` or any future `CheckpointManager` instantiated in
`directory` (regardless of the new setting of
`keep_checkpoint_every_n_hours`). The `max_to_keep` checkpoints in the
active set may be deleted by this `CheckpointManager` or a future
`CheckpointManager` instantiated in `directory` (subject to its
`max_to_keep` and `keep_checkpoint_every_n_hours` settings).
Args:
checkpoint: The `tf.train.Checkpoint` instance to save and manage
checkpoints for.
directory: The path to a directory in which to write checkpoints. A
special file named "checkpoint" is also written to this directory; it is a
human-readable text file which records the state of the
`CheckpointManager`.
max_to_keep: An integer, the number of checkpoints to keep. Unless
preserved by `keep_checkpoint_every_n_hours`, checkpoints will be
deleted from the active set, oldest first, until only `max_to_keep`
checkpoints remain.
keep_checkpoint_every_n_hours: Upon removal from the active set, a
checkpoint will be preserved if at least `keep_checkpoint_every_n_hours`
hours have passed since the last preserved checkpoint. The default
setting of `None` does not preserve any checkpoints in this way.
Raises:
ValueError: If `max_to_keep` is not a positive integer.
"""
self._checkpoint = checkpoint
self._save_counter_assign = None
if not max_to_keep or max_to_keep < 0:
raise ValueError(
"Expected a positive integer for `max_to_max_to_keep`, got %d."
% (max_to_keep,))
self._max_to_keep = max_to_keep
self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
self._directory = directory
self._checkpoint_prefix = os.path.join(directory, "ckpt")
recovered_state = get_checkpoint_state(directory)
current_clock = time.time()
self._maybe_delete = collections.OrderedDict()
if recovered_state is None:
self._latest_checkpoint = None
self._last_preserved_timestamp = current_clock
else:
self._latest_checkpoint = recovered_state.model_checkpoint_path
self._last_preserved_timestamp = recovered_state.last_preserved_timestamp
if current_clock < self._last_preserved_timestamp:
# Time seems to have reversed itself. In addition to this warning, we'll
# min() saved checkpoint timestamps with the current time to ensure that
# old checkpoints don't get deleted accidentally.
logging.warning(
("time.time() returned a value %f seconds behind the last "
"preserved checkpoint timestamp.")
% (self._last_preserved_timestamp - current_clock,))
self._last_preserved_timestamp = current_clock
all_timestamps = recovered_state.all_model_checkpoint_timestamps
all_paths = recovered_state.all_model_checkpoint_paths
del recovered_state # Uses modified values from now on
if not all_timestamps:
all_timestamps = [self._last_preserved_timestamp] * len(all_paths)
for filename, timestamp in zip(all_paths, all_timestamps):
timestamp = min(timestamp, current_clock)
if timestamp > self._last_preserved_timestamp:
self._maybe_delete[filename] = timestamp
@property
def latest_checkpoint(self):
"""The prefix of the most recent checkpoint in `directory`.
Equivalent to `tf.train.latest_checkpoint(directory)` where `directory` is
the constructor argument to `CheckpointManager`.
Suitable for passing to `tf.train.Checkpoint.restore` to resume training.
Returns:
The checkpoint prefix. If there are no checkpoints, returns `None`.
"""
return self._latest_checkpoint
@property
def checkpoints(self):
"""A list of managed checkpoints.
Note that checkpoints preserved via `keep_checkpoint_every_n_hours` will not
show up in this list (to avoid ever-growing filename lists).
Returns:
A list of filenames, sorted from oldest to newest.
"""
return list(self._maybe_delete.keys())
def _sweep(self):
"""Deletes or preserves managed checkpoints."""
while len(self._maybe_delete) > self._max_to_keep:
filename, timestamp = self._maybe_delete.popitem(last=False)
# Even if we're keeping this checkpoint due to
# keep_checkpoint_every_n_hours, we won't reference it to avoid
# infinitely-growing CheckpointState protos.
if (self._keep_checkpoint_every_n_hours
and (timestamp - self._keep_checkpoint_every_n_hours * 3600.
>= self._last_preserved_timestamp)):
self._last_preserved_timestamp = timestamp
continue
remove_checkpoint(filename)
def _record_state(self):
"""Saves the `CheckpointManager`'s state in `directory`."""
filenames, timestamps = zip(*self._maybe_delete.items())
update_checkpoint_state_internal(
self._directory,
model_checkpoint_path=self.latest_checkpoint,
all_model_checkpoint_paths=filenames,
all_model_checkpoint_timestamps=timestamps,
last_preserved_timestamp=self._last_preserved_timestamp,
save_relative_paths=True)
@property
def _prefix(self):
"""A common prefix for all checkpoints saved with this manager.
For example, if `directory` (a constructor argument) were `"/tmp/tf-model"`,
`prefix` would be `"/tmp/tf-model/ckpt"` and checkpoints would generally be
numbered `"/tmp/tf-model/ckpt-1"`, `"/tmp/tf-model/ckpt-2"`, and so on. Each
checkpoint has several associated files
(e.g. `"/tmp/tf-model/ckpt-2.index"`).
Returns:
A string prefix.
"""
return self._checkpoint_prefix
def save(self, session=None):
"""Creates a new checkpoint and manages it.
Args:
session: The session to evaluate variables in. Ignored when executing
eagerly. If not provided when graph building, the default session is
used.
Returns:
The path to the new checkpoint. It is also recorded in the `checkpoints`
and `latest_checkpoint` properties.
"""
# Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
# slightly with a custom numbering option.
if context.executing_eagerly():
save_counter = self._checkpoint.save_counter
save_counter.assign_add(1)
checkpoint_number = save_counter.numpy()
else:
if session is None:
session = ops.get_default_session()
def _initializing_creator(next_creator, **kwargs):
"""Initialize the save counter if it has been newly created."""
v = next_creator(**kwargs)
session.run(v.initializer)
return v
with variable_scope.variable_creator_scope(_initializing_creator):
save_counter = self._checkpoint.save_counter
if self._save_counter_assign is None:
self._save_counter_assign = save_counter.assign_add(1, read_value=True)
checkpoint_number = session.run(self._save_counter_assign)
prefix = "%s-%d" % (self._prefix, checkpoint_number)
save_path = self._checkpoint.write(prefix)
timestamp = time.time()
# If this is an overwritten checkpoint we were previously tracking, delete
# and reinsert it to make sure it goes to the end of the queue.
if save_path in self._maybe_delete:
del self._maybe_delete[save_path]
self._maybe_delete[save_path] = timestamp
self._latest_checkpoint = save_path
self._sweep()
self._record_state()
return save_path
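A hedged sketch of the graph-mode path through `save(session=...)`; `optimizer`, `model`, `train_op`, and `num_steps` are placeholders:

```python
import tensorflow as tf

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.contrib.checkpoint.CheckpointManager(
    checkpoint, directory="/tmp/model", max_to_keep=3)
with tf.Session() as session:
  # Restores from the latest managed checkpoint if one exists, otherwise
  # runs initializers.
  status = checkpoint.restore(manager.latest_checkpoint)
  status.initialize_or_restore(session)
  for _ in range(num_steps):
    session.run(train_op)
    manager.save(session=session)
```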


@ -27,13 +27,16 @@ from google.protobuf import text_format
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.training.checkpointable import util
class LatestCheckpointWithRelativePaths(test.TestCase):
@ -312,5 +315,177 @@ class SaverUtilsTest(test.TestCase):
self.assertFalse(checkpoint_management.checkpoint_exists(ckpt_prefix))
class CheckpointManagerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testDeletion(self):
checkpoint = util.Checkpoint()
manager = checkpoint_management.CheckpointManager(
checkpoint, self.get_temp_dir(), max_to_keep=3)
first_path = manager.save()
second_path = manager.save()
third_path = manager.save()
fourth_path = manager.save()
self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
@test_util.run_in_graph_and_eager_modes
@test.mock.patch.object(checkpoint_management, "time")
def testSaveRestoreState(self, mock_time):
directory = self.get_temp_dir()
mock_time.time.return_value = 3.
checkpoint = util.Checkpoint()
first_manager = checkpoint_management.CheckpointManager(
checkpoint, directory, max_to_keep=2)
first_time = 10000.
first_name = os.path.join(directory, "ckpt-1")
mock_time.time.return_value = first_time
first_manager.save()
state = checkpoint_management.get_checkpoint_state(directory)
self.assertEqual([first_time], state.all_model_checkpoint_timestamps)
self.assertEqual(3., state.last_preserved_timestamp)
second_time = first_time + 3610.
second_name = os.path.join(directory, "ckpt-2")
mock_time.time.return_value = second_time
first_manager.save()
state = checkpoint_management.get_checkpoint_state(directory)
self.assertEqual([first_time, second_time],
state.all_model_checkpoint_timestamps)
self.assertEqual(3., state.last_preserved_timestamp)
self.assertEqual([first_name, second_name], first_manager.checkpoints)
self.assertEqual(second_name, first_manager.latest_checkpoint)
del first_manager
second_manager = checkpoint_management.CheckpointManager(
checkpoint, directory,
max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
self.assertEqual([first_name, second_name], second_manager.checkpoints)
self.assertEqual(second_name, second_manager.latest_checkpoint)
third_name = os.path.join(directory, "ckpt-3")
third_time = second_time + 3600. * 0.2
mock_time.time.return_value = third_time
second_manager.save()
self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
self.assertTrue(checkpoint_management.checkpoint_exists(second_name))
self.assertEqual([second_name, third_name],
second_manager.checkpoints)
state = checkpoint_management.get_checkpoint_state(directory)
self.assertEqual(first_time, state.last_preserved_timestamp)
fourth_time = third_time + 3600. * 0.5
mock_time.time.return_value = fourth_time
fourth_name = os.path.join(directory, "ckpt-4")
second_manager.save()
self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
self.assertEqual([third_name, fourth_name],
second_manager.checkpoints)
fifth_time = fourth_time + 3600. * 0.5
mock_time.time.return_value = fifth_time
fifth_name = os.path.join(directory, "ckpt-5")
second_manager.save()
self.assertEqual([fourth_name, fifth_name],
second_manager.checkpoints)
state = checkpoint_management.get_checkpoint_state(directory)
self.assertEqual(first_time, state.last_preserved_timestamp)
del second_manager
third_manager = checkpoint_management.CheckpointManager(
checkpoint, directory,
max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
self.assertEqual(fifth_name, third_manager.latest_checkpoint)
mock_time.time.return_value += 10.
third_manager.save()
sixth_name = os.path.join(directory, "ckpt-6")
state = checkpoint_management.get_checkpoint_state(directory)
self.assertEqual(fourth_time, state.last_preserved_timestamp)
self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
self.assertTrue(checkpoint_management.checkpoint_exists(fourth_name))
self.assertTrue(checkpoint_management.checkpoint_exists(fifth_name))
self.assertTrue(checkpoint_management.checkpoint_exists(sixth_name))
self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
self.assertFalse(checkpoint_management.checkpoint_exists(third_name))
self.assertEqual([fifth_name, sixth_name],
third_manager.checkpoints)
@test_util.run_in_graph_and_eager_modes
def testContinueFromUnmanaged(self):
directory = self.get_temp_dir()
prefix = os.path.join(directory, "unusual_prefix")
checkpoint = util.Checkpoint()
first_path = checkpoint.save(prefix)
second_path = checkpoint.save(prefix)
del checkpoint
checkpoint = util.Checkpoint()
manager = checkpoint_management.CheckpointManager(
checkpoint, directory, max_to_keep=2)
checkpoint.restore(manager.latest_checkpoint).run_restore_ops()
self.assertEqual(2, self.evaluate(checkpoint.save_counter))
third_path = manager.save()
self.assertEqual([third_path], manager.checkpoints)
fourth_path = manager.save()
self.assertEqual([third_path, fourth_path],
manager.checkpoints)
fifth_path = manager.save()
self.assertEqual([fourth_path, fifth_path],
manager.checkpoints)
self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
self.assertFalse(checkpoint_management.checkpoint_exists(third_path))
self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))
@test_util.run_in_graph_and_eager_modes
@test.mock.patch.object(checkpoint_management, "time")
def testClockReset(self, mock_time):
directory = self.get_temp_dir()
mock_time.time.return_value = 10000.
checkpoint = util.Checkpoint()
first_manager = checkpoint_management.CheckpointManager(
checkpoint, directory, max_to_keep=1, keep_checkpoint_every_n_hours=1.)
first_path = first_manager.save()
mock_time.time.return_value += 3600.
second_path = first_manager.save()
mock_time.time.return_value += 3600.
third_path = first_manager.save()
self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
self.assertEqual([third_path], first_manager.checkpoints)
state = checkpoint_management.get_checkpoint_state(directory)
self.assertEqual(13600., state.last_preserved_timestamp)
# Set the clock back in time
mock_time.time.return_value = 5000.
del first_manager
with test.mock.patch.object(logging, "warning") as mock_log:
second_manager = checkpoint_management.CheckpointManager(
checkpoint, directory, max_to_keep=1)
self.assertRegexpMatches(
str(mock_log.call_args),
"behind the last preserved checkpoint timestamp")
# We should err on the side of keeping checkpoints around when we're not
# sure whether they were preserved or not due to clock funkiness.
self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
# We know about the existing checkpoints, but they'll never be deleted and
# so won't go in the CheckpointState proto on save.
self.assertEqual(third_path, second_manager.latest_checkpoint)
self.assertEqual([], second_manager.checkpoints)
mock_time.time.return_value += 10.
fourth_path = second_manager.save()
self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
self.assertEqual(fourth_path, second_manager.latest_checkpoint)
self.assertEqual([fourth_path], second_manager.checkpoints)
mock_time.time.return_value += 10.
fifth_path = second_manager.save()
self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
self.assertEqual([fifth_path], second_manager.checkpoints)
state = checkpoint_management.get_checkpoint_state(directory)
self.assertEqual(5000., state.last_preserved_timestamp)
self.assertEqual([5020.],
state.all_model_checkpoint_timestamps)
if __name__ == "__main__":
test.main()


@ -4,8 +4,6 @@ package tensorflow;
option cc_enable_arenas = true;
// Protocol buffer representing the checkpoint state.
//
// TODO(touts): Add other attributes as needed.
message CheckpointState {
// Path to the most-recent model checkpoint.
string model_checkpoint_path = 1;
@ -15,4 +13,10 @@ message CheckpointState {
// Note that the value of model_checkpoint_path should be the last item in
// this list.
repeated string all_model_checkpoint_paths = 2;
// Unix timestamps corresponding to all_model_checkpoint_paths, indicating
// when each checkpoint was created.
repeated double all_model_checkpoint_timestamps = 3;
// Unix timestamp indicating the creation time for the last preserved
// checkpoint.
double last_preserved_timestamp = 4;
}
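For context, once a `CheckpointManager` has written to a directory, the new fields can be inspected through the state returned by `tf.train.get_checkpoint_state` (the directory path below is illustrative):

```python
import tensorflow as tf

state = tf.train.get_checkpoint_state("/tmp/model")
if state is not None:
  print(state.model_checkpoint_path)
  print(list(state.all_model_checkpoint_timestamps))  # seconds since the Epoch
  print(state.last_preserved_timestamp)
```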


@ -35,8 +35,8 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import saveable_object as saveable_object_lib
@ -227,10 +227,11 @@ def _default_getter(name, shape, dtype, initializer=None,
def initial_value():
return initializer(
shape_object.as_list(), dtype=dtype, partition_info=partition_info)
return resource_variable_ops.ResourceVariable(
return variables.Variable(
initial_value=initial_value,
name=name,
dtype=variable_dtype,
use_resource=True,
**kwargs
)
@ -1528,8 +1529,6 @@ class Checkpoint(tracking.Checkpointable):
self._maybe_create_save_counter()
return self._save_counter
# TODO(allenl): Update save's docstring with a pointer to
# tf.contrib.checkpoint.CheckpointManager once that's in.
def save(self, file_prefix, session=None):
"""Saves a training checkpoint and provides basic checkpoint management.
@ -1541,7 +1540,8 @@ class Checkpoint(tracking.Checkpointable):
sequentially numbering checkpoints using `save_counter` and updating the
metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint
management, for example garbage collection and custom numbering, may be
provided by other utilities which also wrap `write`.
provided by other utilities which also wrap `write`
(`tf.contrib.checkpoint.CheckpointManager` for example).
Args:
file_prefix: A prefix to use for the checkpoint filenames
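For comparison, a hedged sketch of the kind of wrapper around `write` this refers to; `model`, `train_step`, and `num_steps` are placeholders:

```python
import tensorflow as tf

# A bare numbering wrapper around Checkpoint.write. Unlike save() or
# CheckpointManager.save(), write() does not update the metadata read by
# tf.train.latest_checkpoint and never deletes old checkpoints.
checkpoint = tf.train.Checkpoint(model=model)
for step in range(num_steps):
  train_step()
  save_path = checkpoint.write("/tmp/model/ckpt-%d" % step)
```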


@ -522,7 +522,6 @@ class CheckpointingTests(test.TestCase):
# Does create garbage when executing eagerly due to ops.Graph() creation.
num_training_steps = 10
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
for training_continuation in range(3):
with ops.Graph().as_default(), self.test_session(
graph=ops.get_default_graph()), test_util.device(use_gpu=True):
@ -531,9 +530,9 @@ class CheckpointingTests(test.TestCase):
root = checkpointable_utils.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
checkpoint_path = checkpoint_management.latest_checkpoint(
checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
manager = checkpoint_management.CheckpointManager(
root, checkpoint_directory, max_to_keep=1)
status = root.restore(save_path=manager.latest_checkpoint)
input_value = constant_op.constant([[3.]])
train_fn = functools.partial(
optimizer.minimize,
@ -544,7 +543,7 @@ class CheckpointingTests(test.TestCase):
status.initialize_or_restore()
for _ in range(num_training_steps):
train_fn()
root.save(file_prefix=checkpoint_prefix)
manager.save()
self.assertEqual((training_continuation + 1) * num_training_steps,
self.evaluate(root.global_step))
self.assertEqual(training_continuation + 1,


@ -298,7 +298,7 @@ tf_module {
}
member_method {
name: "generate_checkpoint_state_proto"
argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\', \'all_model_checkpoint_timestamps\', \'last_preserved_timestamp\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "get_checkpoint_mtimes"
@ -446,7 +446,7 @@ tf_module {
}
member_method {
name: "update_checkpoint_state"
argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\', \'latest_filename\', \'all_model_checkpoint_timestamps\', \'last_preserved_timestamp\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "warm_start"