Add tf.contrib.checkpoint.CheckpointManager for deleting old checkpoints
Removes a bit of boilerplate from training loops (making a prefix from a directory). Also clarifies the recovery of checkpoint lists (like tf.train.Saver.recover_last_checkpoints, but automatic and more thorough). Adds a couple of fields to the CheckpointState proto to support this.

Should live in contrib until I make it work well with tf.keras.Model.save_weights. When used together, save_weights needs to number its checkpoints. (There's a TODO for this.)

PiperOrigin-RevId: 208566198
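For reference, a minimal sketch of the intended training-loop usage, adapted from the class docstring added below. The directory, `max_to_keep` value, and the single tracked variable are placeholders standing in for a real model/optimizer pair, not part of this change:

```python
import tensorflow as tf

tf.enable_eager_execution()

# A single variable stands in for a real model/optimizer; the directory and
# max_to_keep are arbitrary choices for this sketch.
step = tf.contrib.eager.Variable(0.)
checkpoint = tf.train.Checkpoint(step=step)
manager = tf.contrib.checkpoint.CheckpointManager(
    checkpoint, directory="/tmp/model", max_to_keep=5)

# Resume from the newest checkpoint in the directory (a no-op on a first run).
checkpoint.restore(manager.latest_checkpoint)

for _ in range(10):
  step.assign_add(1.)          # stand-in for a training step
  save_path = manager.save()   # writes /tmp/model/ckpt-N and prunes old ones

print(manager.checkpoints)     # at most the five newest checkpoint prefixes
```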
commit 7a81491366
parent b2dfe8a520

Changed paths: tensorflow/contrib/checkpoint, tensorflow/python/training, tensorflow/tools/api/golden/v1
@@ -31,6 +31,9 @@ Checkpointable data structures:
@@List
@@Mapping
@@UniqueNameTracker

Checkpoint management:
@@CheckpointManager
"""

from __future__ import absolute_import

@@ -41,6 +44,7 @@ from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
from tensorflow.python.training.checkpoint_management import CheckpointManager
from tensorflow.python.training.checkpointable.base import CheckpointableBase
from tensorflow.python.training.checkpointable.data_structures import List
from tensorflow.python.training.checkpointable.data_structures import Mapping

@@ -19,14 +19,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os.path
import re
import time

from google.protobuf import text_format

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.util.tf_export import tf_export

@@ -51,7 +56,9 @@ def _GetCheckpointFilename(save_dir, latest_filename):
@tf_export("train.generate_checkpoint_state_proto")
def generate_checkpoint_state_proto(save_dir,
                                    model_checkpoint_path,
                                    all_model_checkpoint_paths=None):
                                    all_model_checkpoint_paths=None,
                                    all_model_checkpoint_timestamps=None,
                                    last_preserved_timestamp=None):
  """Generates a checkpoint state proto.

  Args:
@@ -61,11 +68,20 @@ def generate_checkpoint_state_proto(save_dir,
      checkpoints, sorted from oldest to newest. If this is a non-empty list,
      the last element must be equal to model_checkpoint_path. These paths
      are also saved in the CheckpointState proto.
    all_model_checkpoint_timestamps: A list of floats, indicating the number of
      seconds since the Epoch when each checkpoint was generated.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.contrib.checkpoint.CheckpointManager` for an implementation).
  Returns:
    CheckpointState proto with model_checkpoint_path and
    all_model_checkpoint_paths updated to either absolute paths or
    relative paths to the current save_dir.

  Raises:
    ValueError: If `all_model_checkpoint_timestamps` was provided but its length
      does not match `all_model_checkpoint_paths`.
  """
  if all_model_checkpoint_paths is None:
    all_model_checkpoint_paths = []
@@ -76,6 +92,14 @@ def generate_checkpoint_state_proto(save_dir,
                 model_checkpoint_path)
    all_model_checkpoint_paths.append(model_checkpoint_path)

  if (all_model_checkpoint_timestamps
      and (len(all_model_checkpoint_timestamps)
           != len(all_model_checkpoint_paths))):
    raise ValueError(
        ("Checkpoint timestamps, if provided, must match checkpoint paths (got "
         "paths %s and timestamps %s)")
        % (all_model_checkpoint_paths, all_model_checkpoint_timestamps))

  # Relative paths need to be rewritten to be relative to the "save_dir"
  # if model_checkpoint_path already contains "save_dir".
  if not os.path.isabs(save_dir):
@@ -88,7 +112,9 @@ def generate_checkpoint_state_proto(save_dir,

  coord_checkpoint_proto = CheckpointState(
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths)
      all_model_checkpoint_paths=all_model_checkpoint_paths,
      all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
      last_preserved_timestamp=last_preserved_timestamp)

  return coord_checkpoint_proto


@@ -97,7 +123,9 @@ def generate_checkpoint_state_proto(save_dir,
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None):
                            latest_filename=None,
                            all_model_checkpoint_timestamps=None,
                            last_preserved_timestamp=None):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
@@ -112,7 +140,13 @@ def update_checkpoint_state(save_dir,
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file. Default to
      'checkpoint'.
    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
      seconds since the Epoch) indicating when the checkpoints in
      `all_model_checkpoint_paths` were created.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.contrib.checkpoint.CheckpointManager` for an implementation).
  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointState.
@@ -122,14 +156,18 @@ def update_checkpoint_state(save_dir,
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths,
      latest_filename=latest_filename,
      save_relative_paths=False)
      save_relative_paths=False,
      all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
      last_preserved_timestamp=last_preserved_timestamp)


def update_checkpoint_state_internal(save_dir,
                                     model_checkpoint_path,
                                     all_model_checkpoint_paths=None,
                                     latest_filename=None,
                                     save_relative_paths=False):
                                     save_relative_paths=False,
                                     all_model_checkpoint_timestamps=None,
                                     last_preserved_timestamp=None):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
@@ -146,6 +184,13 @@ def update_checkpoint_state_internal(save_dir,
      'checkpoint'.
    save_relative_paths: If `True`, will write relative paths to the checkpoint
      state file.
    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
      seconds since the Epoch) indicating when the checkpoints in
      `all_model_checkpoint_paths` were created.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.contrib.checkpoint.CheckpointManager` for an implementation).

  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
@@ -168,12 +213,16 @@ def update_checkpoint_state_internal(save_dir,
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        rel_model_checkpoint_path,
        all_model_checkpoint_paths=rel_all_model_checkpoint_paths)
        all_model_checkpoint_paths=rel_all_model_checkpoint_paths,
        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
        last_preserved_timestamp=last_preserved_timestamp)
  else:
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths)
        all_model_checkpoint_paths=all_model_checkpoint_paths,
        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
        last_preserved_timestamp=last_preserved_timestamp)

  if coord_checkpoint_filename == ckpt.model_checkpoint_path:
    raise RuntimeError("Save path '%s' conflicts with path used for "
@@ -404,3 +453,217 @@ def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
  basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
  suffixed_filename = ".".join([basename, meta_graph_suffix])
  return suffixed_filename


# TODO(allenl): Allow tf.keras.Model instances in the constructor directly?
class CheckpointManager(object):
  """Deletes old checkpoints.

  Example usage:
  ```python
  import tensorflow as tf
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  manager = tf.contrib.checkpoint.CheckpointManager(
      checkpoint, directory="/tmp/model", max_to_keep=5)
  status = checkpoint.restore(manager.latest_checkpoint)
  while True:
    # train
    manager.save()
  ```

  `CheckpointManager` preserves its own state across instantiations (see the
  `__init__` documentation for details). Only one should be active in a
  particular directory at a time.
  """

  def __init__(self, checkpoint, directory,
               max_to_keep, keep_checkpoint_every_n_hours=None):
    """Configure a `CheckpointManager` for use in `directory`.

    If a `CheckpointManager` was previously used in `directory`, its
    state will be restored. This includes the list of managed checkpoints and
    the timestamp bookkeeping necessary to support
    `keep_checkpoint_every_n_hours`. The behavior of the new `CheckpointManager`
    will be the same as the previous `CheckpointManager`, including cleaning up
    existing checkpoints if appropriate.

    Checkpoints are only considered for deletion just after a new checkpoint has
    been added. At that point, `max_to_keep` checkpoints will remain in an
    "active set". Once a checkpoint is preserved by
    `keep_checkpoint_every_n_hours` it will not be deleted by this
    `CheckpointManager` or any future `CheckpointManager` instantiated in
    `directory` (regardless of the new setting of
    `keep_checkpoint_every_n_hours`). The `max_to_keep` checkpoints in the
    active set may be deleted by this `CheckpointManager` or a future
    `CheckpointManager` instantiated in `directory` (subject to its
    `max_to_keep` and `keep_checkpoint_every_n_hours` settings).

    Args:
      checkpoint: The `tf.train.Checkpoint` instance to save and manage
        checkpoints for.
      directory: The path to a directory in which to write checkpoints. A
        special file named "checkpoint" is also written to this directory (in a
        human-readable text format) which contains the state of the
        `CheckpointManager`.
      max_to_keep: An integer, the number of checkpoints to keep. Unless
        preserved by `keep_checkpoint_every_n_hours`, checkpoints will be
        deleted from the active set, oldest first, until only `max_to_keep`
        checkpoints remain.
      keep_checkpoint_every_n_hours: Upon removal from the active set, a
        checkpoint will be preserved if it has been at least
        `keep_checkpoint_every_n_hours` since the last preserved checkpoint. The
        default setting of `None` does not preserve any checkpoints in this way.

    Raises:
      ValueError: If `max_to_keep` is not a positive integer.
    """
    self._checkpoint = checkpoint
    self._save_counter_assign = None
    if not max_to_keep or max_to_keep < 0:
      raise ValueError(
"Expected a positive integer for `max_to_max_to_keep`, got %d."
|
||||
          % (max_to_keep,))
    self._max_to_keep = max_to_keep
    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
    self._directory = directory
    self._checkpoint_prefix = os.path.join(directory, "ckpt")
    recovered_state = get_checkpoint_state(directory)
    current_clock = time.time()
    self._maybe_delete = collections.OrderedDict()
    if recovered_state is None:
      self._latest_checkpoint = None
      self._last_preserved_timestamp = current_clock
    else:
      self._latest_checkpoint = recovered_state.model_checkpoint_path
      self._last_preserved_timestamp = recovered_state.last_preserved_timestamp
      if current_clock < self._last_preserved_timestamp:
        # Time seems to have reversed itself. In addition to this warning, we'll
        # min() saved checkpoint timestamps with the current time to ensure that
        # old checkpoints don't get deleted accidentally.
        logging.warning(
            ("time.time() returned a value %f seconds behind the last "
             "preserved checkpoint timestamp.")
            % (self._last_preserved_timestamp - current_clock,))
        self._last_preserved_timestamp = current_clock
      all_timestamps = recovered_state.all_model_checkpoint_timestamps
      all_paths = recovered_state.all_model_checkpoint_paths
      del recovered_state  # Uses modified values from now on
      if not all_timestamps:
        all_timestamps = [self._last_preserved_timestamp] * len(all_paths)

      for filename, timestamp in zip(all_paths, all_timestamps):
        timestamp = min(timestamp, current_clock)
        if timestamp > self._last_preserved_timestamp:
          self._maybe_delete[filename] = timestamp

  @property
  def latest_checkpoint(self):
    """The prefix of the most recent checkpoint in `directory`.

    Equivalent to `tf.train.latest_checkpoint(directory)` where `directory` is
    the constructor argument to `CheckpointManager`.

    Suitable for passing to `tf.train.Checkpoint.restore` to resume training.

    Returns:
      The checkpoint prefix. If there are no checkpoints, returns `None`.
    """
    return self._latest_checkpoint

  @property
  def checkpoints(self):
    """A list of managed checkpoints.

    Note that checkpoints saved due to `keep_checkpoint_every_n_hours` will not
    show up in this list (to avoid ever-growing filename lists).

    Returns:
      A list of filenames, sorted from oldest to newest.
    """
    return list(self._maybe_delete.keys())

  def _sweep(self):
    """Deletes or preserves managed checkpoints."""
    while len(self._maybe_delete) > self._max_to_keep:
      filename, timestamp = self._maybe_delete.popitem(last=False)
      # Even if we're keeping this checkpoint due to
      # keep_checkpoint_every_n_hours, we won't reference it to avoid
      # infinitely-growing CheckpointState protos.
      if (self._keep_checkpoint_every_n_hours
          and (timestamp - self._keep_checkpoint_every_n_hours * 3600.
               >= self._last_preserved_timestamp)):
        self._last_preserved_timestamp = timestamp
        continue
      remove_checkpoint(filename)

  def _record_state(self):
    """Saves the `CheckpointManager`'s state in `directory`."""
    filenames, timestamps = zip(*self._maybe_delete.items())
    update_checkpoint_state_internal(
        self._directory,
        model_checkpoint_path=self.latest_checkpoint,
        all_model_checkpoint_paths=filenames,
        all_model_checkpoint_timestamps=timestamps,
        last_preserved_timestamp=self._last_preserved_timestamp,
        save_relative_paths=True)

  @property
  def _prefix(self):
    """A common prefix for all checkpoints saved with this manager.

    For example, if `directory` (a constructor argument) were `"/tmp/tf-model"`,
    `prefix` would be `"/tmp/tf-model/ckpt"` and checkpoints would generally be
    numbered `"/tmp/tf-model/ckpt-1"`, `"/tmp/tf-model/ckpt-2"`, and so on. Each
    checkpoint has several associated files
    (e.g. `"/tmp/tf-model/ckpt-2.index"`).

    Returns:
      A string prefix.
    """
    return self._checkpoint_prefix

  def save(self, session=None):
    """Creates a new checkpoint and manages it.

    Args:
      session: The session to evaluate variables in. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.

    Returns:
      The path to the new checkpoint. It is also recorded in the `checkpoints`
      and `latest_checkpoint` properties.
"""
|
||||
# Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
|
||||
# slightly with a custom numbering option.
|
||||
if context.executing_eagerly():
|
||||
save_counter = self._checkpoint.save_counter
|
||||
save_counter.assign_add(1)
|
||||
checkpoint_number = save_counter.numpy()
|
||||
else:
|
||||
if session is None:
|
||||
session = ops.get_default_session()
|
||||
|
||||
def _initializing_creator(next_creator, **kwargs):
|
||||
"""Initialize the save counter if it has been newly created."""
|
||||
v = next_creator(**kwargs)
|
||||
session.run(v.initializer)
|
||||
return v
|
||||
|
||||
with variable_scope.variable_creator_scope(_initializing_creator):
|
||||
save_counter = self._checkpoint.save_counter
|
||||
if self._save_counter_assign is None:
|
||||
self._save_counter_assign = save_counter.assign_add(1, read_value=True)
|
||||
checkpoint_number = session.run(self._save_counter_assign)
|
||||
prefix = "%s-%d" % (self._prefix, checkpoint_number)
|
||||
save_path = self._checkpoint.write(prefix)
|
||||
timestamp = time.time()
|
||||
# If this is an overwritten checkpoint we were previously tracking, delete
|
||||
# and reinsert it to make sure it goes to the end of the queue.
|
||||
if save_path in self._maybe_delete:
|
||||
del self._maybe_delete[save_path]
|
||||
self._maybe_delete[save_path] = timestamp
|
||||
self._latest_checkpoint = save_path
|
||||
self._sweep()
|
||||
self._record_state()
|
||||
return save_path
|
||||
|
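A hedged sketch of the graph-building path described in `save()` above: when not executing eagerly, the manager evaluates its save counter in the provided (or default) session. The variable, directory, and loop below are illustrative placeholders, not part of this change:

```python
import tensorflow as tf

# Graph-mode sketch of CheckpointManager.save(session=...): a session is
# needed to evaluate variables (including the save counter), so pass one in.
step = tf.get_variable("step", initializer=0.)
increment = step.assign_add(1.)
checkpoint = tf.train.Checkpoint(step=step)
manager = tf.contrib.checkpoint.CheckpointManager(
    checkpoint, directory="/tmp/model_graph", max_to_keep=2)

with tf.Session() as session:
  session.run(step.initializer)
  for _ in range(5):
    session.run(increment)               # stand-in for a training step
    manager.save(session=session)        # only the two newest prefixes remain
```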
@@ -27,13 +27,16 @@ from google.protobuf import text_format

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.training.checkpointable import util


class LatestCheckpointWithRelativePaths(test.TestCase):
@@ -312,5 +315,177 @@ class SaverUtilsTest(test.TestCase):
      self.assertFalse(checkpoint_management.checkpoint_exists(ckpt_prefix))


class CheckpointManagerTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testDeletion(self):
    checkpoint = util.Checkpoint()
    manager = checkpoint_management.CheckpointManager(
        checkpoint, self.get_temp_dir(), max_to_keep=3)
    first_path = manager.save()
    second_path = manager.save()
    third_path = manager.save()
    fourth_path = manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertFalse(checkpoint_management.checkpoint_exists(first_path))

  @test_util.run_in_graph_and_eager_modes
  @test.mock.patch.object(checkpoint_management, "time")
  def testSaveRestoreState(self, mock_time):
    directory = self.get_temp_dir()
    mock_time.time.return_value = 3.
    checkpoint = util.Checkpoint()
    first_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2)
    first_time = 10000.
    first_name = os.path.join(directory, "ckpt-1")
    mock_time.time.return_value = first_time
    first_manager.save()
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual([first_time], state.all_model_checkpoint_timestamps)
    self.assertEqual(3., state.last_preserved_timestamp)
    second_time = first_time + 3610.
    second_name = os.path.join(directory, "ckpt-2")
    mock_time.time.return_value = second_time
    first_manager.save()
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual([first_time, second_time],
                     state.all_model_checkpoint_timestamps)
    self.assertEqual(3., state.last_preserved_timestamp)
    self.assertEqual([first_name, second_name], first_manager.checkpoints)
    self.assertEqual(second_name, first_manager.latest_checkpoint)
    del first_manager

    second_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory,
        max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
    self.assertEqual([first_name, second_name], second_manager.checkpoints)
    self.assertEqual(second_name, second_manager.latest_checkpoint)
    third_name = os.path.join(directory, "ckpt-3")
    third_time = second_time + 3600. * 0.2
    mock_time.time.return_value = third_time
    second_manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_name))
    self.assertEqual([second_name, third_name],
                     second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(first_time, state.last_preserved_timestamp)
    fourth_time = third_time + 3600. * 0.5
    mock_time.time.return_value = fourth_time
    fourth_name = os.path.join(directory, "ckpt-4")
    second_manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
    self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
    self.assertEqual([third_name, fourth_name],
                     second_manager.checkpoints)
    fifth_time = fourth_time + 3600. * 0.5
    mock_time.time.return_value = fifth_time
    fifth_name = os.path.join(directory, "ckpt-5")
    second_manager.save()
    self.assertEqual([fourth_name, fifth_name],
                     second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(first_time, state.last_preserved_timestamp)
    del second_manager
    third_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory,
        max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
    self.assertEqual(fifth_name, third_manager.latest_checkpoint)
    mock_time.time.return_value += 10.
    third_manager.save()
    sixth_name = os.path.join(directory, "ckpt-6")
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(fourth_time, state.last_preserved_timestamp)
    self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(fifth_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(sixth_name))
    self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
    self.assertFalse(checkpoint_management.checkpoint_exists(third_name))
    self.assertEqual([fifth_name, sixth_name],
                     third_manager.checkpoints)

  @test_util.run_in_graph_and_eager_modes
  def testContinueFromUnmanaged(self):
    directory = self.get_temp_dir()
    prefix = os.path.join(directory, "unusual_prefix")
    checkpoint = util.Checkpoint()
    first_path = checkpoint.save(prefix)
    second_path = checkpoint.save(prefix)
    del checkpoint
    checkpoint = util.Checkpoint()
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2)
    checkpoint.restore(manager.latest_checkpoint).run_restore_ops()
    self.assertEqual(2, self.evaluate(checkpoint.save_counter))
    third_path = manager.save()
    self.assertEqual([third_path], manager.checkpoints)
    fourth_path = manager.save()
    self.assertEqual([third_path, fourth_path],
                     manager.checkpoints)
    fifth_path = manager.save()
    self.assertEqual([fourth_path, fifth_path],
                     manager.checkpoints)
    self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertFalse(checkpoint_management.checkpoint_exists(third_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))

  @test_util.run_in_graph_and_eager_modes
  @test.mock.patch.object(checkpoint_management, "time")
  def testClockReset(self, mock_time):
    directory = self.get_temp_dir()
    mock_time.time.return_value = 10000.
    checkpoint = util.Checkpoint()
    first_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=1, keep_checkpoint_every_n_hours=1.)
    first_path = first_manager.save()
    mock_time.time.return_value += 3600.
    second_path = first_manager.save()
    mock_time.time.return_value += 3600.
    third_path = first_manager.save()
    self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertEqual([third_path], first_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(13600., state.last_preserved_timestamp)
    # Set the clock back in time
    mock_time.time.return_value = 5000.
    del first_manager
    with test.mock.patch.object(logging, "warning") as mock_log:
      second_manager = checkpoint_management.CheckpointManager(
          checkpoint, directory, max_to_keep=1)
      self.assertRegexpMatches(
          str(mock_log.call_args),
          "behind the last preserved checkpoint timestamp")
    # We should err on the side of keeping checkpoints around when we're not
    # sure whether they were preserved or not due to clock funkiness.
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    # We know about the existing checkpoints, but they'll never be deleted and
    # so won't go in the CheckpointState proto on save.
    self.assertEqual(third_path, second_manager.latest_checkpoint)
    self.assertEqual([], second_manager.checkpoints)
    mock_time.time.return_value += 10.
    fourth_path = second_manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertEqual(fourth_path, second_manager.latest_checkpoint)
    self.assertEqual([fourth_path], second_manager.checkpoints)
    mock_time.time.return_value += 10.
    fifth_path = second_manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertEqual([fifth_path], second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(5000., state.last_preserved_timestamp)
    self.assertEqual([5020.],
                     state.all_model_checkpoint_timestamps)


if __name__ == "__main__":
  test.main()

@@ -4,8 +4,6 @@ package tensorflow;
option cc_enable_arenas = true;

// Protocol buffer representing the checkpoint state.
//
// TODO(touts): Add other attributes as needed.
message CheckpointState {
  // Path to the most-recent model checkpoint.
  string model_checkpoint_path = 1;
@@ -15,4 +13,10 @@ message CheckpointState {
  // Note that the value of model_checkpoint_path should be the last item in
  // this list.
  repeated string all_model_checkpoint_paths = 2;
  // Unix timestamps corresponding to all_model_checkpoint_paths, indicating
  // when each checkpoint was created.
  repeated double all_model_checkpoint_timestamps = 3;
  // Unix timestamp indicating the creation time for the last preserved
  // checkpoint.
  double last_preserved_timestamp = 4;
}
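The new proto fields are visible through the existing checkpoint-state helpers. A small read-only sketch, assuming `/tmp/model` already contains checkpoints written by a `CheckpointManager`:

```python
import tensorflow as tf

# Inspect the human-readable "checkpoint" file that CheckpointManager
# maintains; the directory is a placeholder for wherever checkpoints live.
state = tf.train.get_checkpoint_state("/tmp/model")
if state is not None:
  print(state.model_checkpoint_path)                  # newest checkpoint prefix
  print(list(state.all_model_checkpoint_paths))       # active set, oldest first
  print(list(state.all_model_checkpoint_timestamps))  # seconds since the Epoch
  print(state.last_preserved_timestamp)               # last preserved checkpoint
```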
@@ -35,8 +35,8 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import saveable_object as saveable_object_lib
@@ -227,10 +227,11 @@ def _default_getter(name, shape, dtype, initializer=None,
    def initial_value():
      return initializer(
          shape_object.as_list(), dtype=dtype, partition_info=partition_info)
    return resource_variable_ops.ResourceVariable(
    return variables.Variable(
        initial_value=initial_value,
        name=name,
        dtype=variable_dtype,
        use_resource=True,
        **kwargs
    )

@@ -1528,8 +1529,6 @@ class Checkpoint(tracking.Checkpointable):
    self._maybe_create_save_counter()
    return self._save_counter

  # TODO(allenl): Update save's docstring with a pointer to
  # tf.contrib.checkpoint.CheckpointManager once that's in.
  def save(self, file_prefix, session=None):
    """Saves a training checkpoint and provides basic checkpoint management.

@@ -1541,7 +1540,8 @@ class Checkpoint(tracking.Checkpointable):
    sequentially numbering checkpoints using `save_counter` and updating the
    metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint
    management, for example garbage collection and custom numbering, may be
    provided by other utilities which also wrap `write`.
    provided by other utilities which also wrap `write`
    (`tf.contrib.checkpoint.CheckpointManager` for example).

    Args:
      file_prefix: A prefix to use for the checkpoint filenames

@@ -522,7 +522,6 @@ class CheckpointingTests(test.TestCase):
    # Does create garbage when executing eagerly due to ops.Graph() creation.
    num_training_steps = 10
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    for training_continuation in range(3):
      with ops.Graph().as_default(), self.test_session(
          graph=ops.get_default_graph()), test_util.device(use_gpu=True):
@@ -531,9 +530,9 @@ class CheckpointingTests(test.TestCase):
        root = checkpointable_utils.Checkpoint(
            optimizer=optimizer, model=model,
            global_step=training_util.get_or_create_global_step())
        checkpoint_path = checkpoint_management.latest_checkpoint(
            checkpoint_directory)
        status = root.restore(save_path=checkpoint_path)
        manager = checkpoint_management.CheckpointManager(
            root, checkpoint_directory, max_to_keep=1)
        status = root.restore(save_path=manager.latest_checkpoint)
        input_value = constant_op.constant([[3.]])
        train_fn = functools.partial(
            optimizer.minimize,
@@ -544,7 +543,7 @@ class CheckpointingTests(test.TestCase):
        status.initialize_or_restore()
        for _ in range(num_training_steps):
          train_fn()
        root.save(file_prefix=checkpoint_prefix)
        manager.save()
        self.assertEqual((training_continuation + 1) * num_training_steps,
                         self.evaluate(root.global_step))
        self.assertEqual(training_continuation + 1,

@@ -298,7 +298,7 @@ tf_module {
  }
  member_method {
    name: "generate_checkpoint_state_proto"
    argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\'], varargs=None, keywords=None, defaults=[\'None\'], "
    argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\', \'all_model_checkpoint_timestamps\', \'last_preserved_timestamp\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "get_checkpoint_mtimes"
@@ -446,7 +446,7 @@ tf_module {
  }
  member_method {
    name: "update_checkpoint_state"
    argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
    argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\', \'latest_filename\', \'all_model_checkpoint_timestamps\', \'last_preserved_timestamp\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "warm_start"