Gracefully handle missing checkpoints in recover_last_checkpoints
If some checkpoints present in CheckpointState are absent on disk, recover_last_checkpoints incorrectly initializes Saver internal state. In this example: (1) CheckpointState.all_model_checkpoint_paths = ['ckpt-1', 'ckpt-2', 'ckpt-3'] (2) Actual checkpoints on disk: ['ckpt-2', 'ckpt-3'] last_checkpoints gets incorrectly initialized to ['ckpt-1', 'ckpt-2']. This is because get_checkpoint_mtimes silently ignores any absent checkpoints and returns a list of length 2 corresponding to checkpoints on disk, which then gets zipped with (1). After the fix, last_checkpoints would be ['ckpt-2', 'ckpt-3']. PiperOrigin-RevId: 245983586
This commit is contained in:
parent
8ef1021608
commit
a8cd934cde
@ -391,6 +391,10 @@ def get_checkpoint_mtimes(checkpoint_prefixes):
|
|||||||
This is the recommended way to get the mtimes, since it takes into account
|
This is the recommended way to get the mtimes, since it takes into account
|
||||||
the naming difference between V1 and V2 formats.
|
the naming difference between V1 and V2 formats.
|
||||||
|
|
||||||
|
Note: If not all checkpoints exist, the length of the returned mtimes list
|
||||||
|
will be smaller than the length of `checkpoint_prefixes` list, so mapping
|
||||||
|
checkpoints to corresponding mtimes will not be possible.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
checkpoint_prefixes: a list of checkpoint paths, typically the results of
|
checkpoint_prefixes: a list of checkpoint paths, typically the results of
|
||||||
`Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
|
`Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
|
||||||
|
@ -1061,8 +1061,12 @@ class Saver(object):
|
|||||||
Args:
|
Args:
|
||||||
checkpoint_paths: a list of checkpoint paths.
|
checkpoint_paths: a list of checkpoint paths.
|
||||||
"""
|
"""
|
||||||
mtimes = checkpoint_management.get_checkpoint_mtimes(checkpoint_paths)
|
checkpoints_with_mtimes = []
|
||||||
self.set_last_checkpoints_with_time(list(zip(checkpoint_paths, mtimes)))
|
for checkpoint_path in checkpoint_paths:
|
||||||
|
mtime = checkpoint_management.get_checkpoint_mtimes([checkpoint_path])
|
||||||
|
if mtime:
|
||||||
|
checkpoints_with_mtimes.append((checkpoint_path, mtime[0]))
|
||||||
|
self.set_last_checkpoints_with_time(checkpoints_with_mtimes)
|
||||||
|
|
||||||
def save(self,
|
def save(self,
|
||||||
sess,
|
sess,
|
||||||
|
@ -1475,6 +1475,67 @@ class MaxToKeepTest(test.TestCase):
|
|||||||
gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
|
gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
|
||||||
|
|
||||||
|
|
||||||
|
class RecoverLastCheckpointsTest(test.TestCase):
|
||||||
|
|
||||||
|
def _get_test_dir(self, dirname):
|
||||||
|
test_dir = os.path.join(self.get_temp_dir(), dirname)
|
||||||
|
gfile.MakeDirs(test_dir)
|
||||||
|
return test_dir
|
||||||
|
|
||||||
|
def assertCheckpointState(self, model_checkpoint_path,
|
||||||
|
all_model_checkpoint_paths, save_dir):
|
||||||
|
checkpoint_state = checkpoint_management.get_checkpoint_state(save_dir)
|
||||||
|
self.assertEqual(checkpoint_state.model_checkpoint_path,
|
||||||
|
model_checkpoint_path)
|
||||||
|
self.assertEqual(checkpoint_state.all_model_checkpoint_paths,
|
||||||
|
all_model_checkpoint_paths)
|
||||||
|
|
||||||
|
def test_recover_last_checkpoints(self):
|
||||||
|
with context.eager_mode():
|
||||||
|
save_dir = self._get_test_dir("recover_last_checkpoints")
|
||||||
|
|
||||||
|
v = variable_scope.variable(10.0, name="v")
|
||||||
|
save = saver_module.Saver({"v": v}, max_to_keep=10)
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
self.assertEqual([], save.last_checkpoints)
|
||||||
|
|
||||||
|
s1 = save.save(None, os.path.join(save_dir, "ckpt-1"))
|
||||||
|
s2 = save.save(None, os.path.join(save_dir, "ckpt-2"))
|
||||||
|
s3 = save.save(None, os.path.join(save_dir, "ckpt-3"))
|
||||||
|
self.assertEqual([s1, s2, s3], save.last_checkpoints)
|
||||||
|
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
|
||||||
|
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
|
||||||
|
self.assertTrue(checkpoint_management.checkpoint_exists(s3))
|
||||||
|
self.assertCheckpointState(
|
||||||
|
model_checkpoint_path=s3,
|
||||||
|
all_model_checkpoint_paths=[s1, s2, s3],
|
||||||
|
save_dir=save_dir)
|
||||||
|
|
||||||
|
# Create another saver and recover last checkpoints.
|
||||||
|
save2 = saver_module.Saver({"v": v}, max_to_keep=10)
|
||||||
|
self.assertEqual([], save2.last_checkpoints)
|
||||||
|
save2.recover_last_checkpoints([s1, s2, s3])
|
||||||
|
self.assertEqual([s1, s2, s3], save2.last_checkpoints)
|
||||||
|
|
||||||
|
# Remove a checkpoint and check that last checkpoints are
|
||||||
|
# restored correctly.
|
||||||
|
for fname in gfile.Glob("{}*".format(s1)):
|
||||||
|
gfile.Remove(fname)
|
||||||
|
self.assertFalse(checkpoint_management.checkpoint_exists(s1))
|
||||||
|
|
||||||
|
# Create another saver and recover last checkpoints. The removed
|
||||||
|
# checkpoint would be correctly omitted.
|
||||||
|
save3 = saver_module.Saver({"v": v}, max_to_keep=10)
|
||||||
|
self.assertEqual([], save3.last_checkpoints)
|
||||||
|
save3.recover_last_checkpoints([s1, s2, s3])
|
||||||
|
self.assertEqual([s2, s3], save3.last_checkpoints)
|
||||||
|
s4 = save3.save(None, os.path.join(save_dir, "ckpt-4"))
|
||||||
|
self.assertCheckpointState(
|
||||||
|
model_checkpoint_path=s4,
|
||||||
|
all_model_checkpoint_paths=[s2, s3, s4],
|
||||||
|
save_dir=save_dir)
|
||||||
|
|
||||||
|
|
||||||
class KeepCheckpointEveryNHoursTest(test.TestCase):
|
class KeepCheckpointEveryNHoursTest(test.TestCase):
|
||||||
|
|
||||||
def _get_test_dir(self, dirname):
|
def _get_test_dir(self, dirname):
|
||||||
|
Loading…
Reference in New Issue
Block a user