Gracefully handle missing checkpoints in recover_last_checkpoints

If some checkpoints present in CheckpointState are absent on disk, recover_last_checkpoints incorrectly initializes Saver internal state.
In this example:
(1) CheckpointState.all_model_checkpoint_paths = ['ckpt-1', 'ckpt-2', 'ckpt-3']
(2) Actual checkpoints on disk: ['ckpt-2', 'ckpt-3']
last_checkpoints gets incorrectly initialized to ['ckpt-1', 'ckpt-2']. This is because get_checkpoint_mtimes silently ignores any absent checkpoints and returns a list of length 2 corresponding to checkpoints on disk, which then gets zipped with (1). After the fix, last_checkpoints would be ['ckpt-2', 'ckpt-3'].

PiperOrigin-RevId: 245983586
This commit is contained in:
Goutham Bhat 2019-04-30 11:10:33 -07:00 committed by TensorFlower Gardener
parent 8ef1021608
commit a8cd934cde
3 changed files with 71 additions and 2 deletions

View File

@ -391,6 +391,10 @@ def get_checkpoint_mtimes(checkpoint_prefixes):
This is the recommended way to get the mtimes, since it takes into account
the naming difference between V1 and V2 formats.
Note: If not all checkpoints exist, the length of the returned mtimes list
will be smaller than the length of `checkpoint_prefixes` list, so mapping
checkpoints to corresponding mtimes will not be possible.
Args:
checkpoint_prefixes: a list of checkpoint paths, typically the results of
`Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of

View File

@ -1061,8 +1061,12 @@ class Saver(object):
Args:
checkpoint_paths: a list of checkpoint paths.
"""
mtimes = checkpoint_management.get_checkpoint_mtimes(checkpoint_paths)
self.set_last_checkpoints_with_time(list(zip(checkpoint_paths, mtimes)))
checkpoints_with_mtimes = []
for checkpoint_path in checkpoint_paths:
mtime = checkpoint_management.get_checkpoint_mtimes([checkpoint_path])
if mtime:
checkpoints_with_mtimes.append((checkpoint_path, mtime[0]))
self.set_last_checkpoints_with_time(checkpoints_with_mtimes)
def save(self,
sess,

View File

@ -1475,6 +1475,67 @@ class MaxToKeepTest(test.TestCase):
gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
class RecoverLastCheckpointsTest(test.TestCase):
def _get_test_dir(self, dirname):
test_dir = os.path.join(self.get_temp_dir(), dirname)
gfile.MakeDirs(test_dir)
return test_dir
def assertCheckpointState(self, model_checkpoint_path,
all_model_checkpoint_paths, save_dir):
checkpoint_state = checkpoint_management.get_checkpoint_state(save_dir)
self.assertEqual(checkpoint_state.model_checkpoint_path,
model_checkpoint_path)
self.assertEqual(checkpoint_state.all_model_checkpoint_paths,
all_model_checkpoint_paths)
def test_recover_last_checkpoints(self):
with context.eager_mode():
save_dir = self._get_test_dir("recover_last_checkpoints")
v = variable_scope.variable(10.0, name="v")
save = saver_module.Saver({"v": v}, max_to_keep=10)
self.evaluate(variables.global_variables_initializer())
self.assertEqual([], save.last_checkpoints)
s1 = save.save(None, os.path.join(save_dir, "ckpt-1"))
s2 = save.save(None, os.path.join(save_dir, "ckpt-2"))
s3 = save.save(None, os.path.join(save_dir, "ckpt-3"))
self.assertEqual([s1, s2, s3], save.last_checkpoints)
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(checkpoint_management.checkpoint_exists(s3))
self.assertCheckpointState(
model_checkpoint_path=s3,
all_model_checkpoint_paths=[s1, s2, s3],
save_dir=save_dir)
# Create another saver and recover last checkpoints.
save2 = saver_module.Saver({"v": v}, max_to_keep=10)
self.assertEqual([], save2.last_checkpoints)
save2.recover_last_checkpoints([s1, s2, s3])
self.assertEqual([s1, s2, s3], save2.last_checkpoints)
# Remove a checkpoint and check that last checkpoints are
# restored correctly.
for fname in gfile.Glob("{}*".format(s1)):
gfile.Remove(fname)
self.assertFalse(checkpoint_management.checkpoint_exists(s1))
# Create another saver and recover last checkpoints. The removed
# checkpoint would be correctly omitted.
save3 = saver_module.Saver({"v": v}, max_to_keep=10)
self.assertEqual([], save3.last_checkpoints)
save3.recover_last_checkpoints([s1, s2, s3])
self.assertEqual([s2, s3], save3.last_checkpoints)
s4 = save3.save(None, os.path.join(save_dir, "ckpt-4"))
self.assertCheckpointState(
model_checkpoint_path=s4,
all_model_checkpoint_paths=[s2, s3, s4],
save_dir=save_dir)
class KeepCheckpointEveryNHoursTest(test.TestCase):
def _get_test_dir(self, dirname):