Gracefully handle missing checkpoints in recover_last_checkpoints

If some checkpoints listed in CheckpointState are absent on disk, recover_last_checkpoints incorrectly initializes the Saver's internal state.
For example:
(1) CheckpointState.all_model_checkpoint_paths = ['ckpt-1', 'ckpt-2', 'ckpt-3']
(2) Actual checkpoints on disk: ['ckpt-2', 'ckpt-3']
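last_checkpoints gets incorrectly initialized to ['ckpt-1', 'ckpt-2']. This is because get_checkpoint_mtimes silently ignores any absent checkpoints and returns a list of length 2 corresponding to the checkpoints on disk, which then gets zipped with (1). Because zip stops at the end of the shorter list, the first two paths in (1) are paired with the mtimes of the two checkpoints on disk and 'ckpt-3' is dropped. After the fix, last_checkpoints is ['ckpt-2', 'ckpt-3'].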

PiperOrigin-RevId: 245983586
This commit is contained in:
Goutham Bhat 2019-04-30 11:10:33 -07:00 committed by TensorFlower Gardener
parent 8ef1021608
commit a8cd934cde
3 changed files with 71 additions and 2 deletions

View File

@ -391,6 +391,10 @@ def get_checkpoint_mtimes(checkpoint_prefixes):
This is the recommended way to get the mtimes, since it takes into account This is the recommended way to get the mtimes, since it takes into account
the naming difference between V1 and V2 formats. the naming difference between V1 and V2 formats.
Note: If not all checkpoints exist, the length of the returned mtimes list
will be smaller than the length of `checkpoint_prefixes` list, so mapping
checkpoints to corresponding mtimes will not be possible.
Args: Args:
checkpoint_prefixes: a list of checkpoint paths, typically the results of checkpoint_prefixes: a list of checkpoint paths, typically the results of
`Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of

View File

@ -1061,8 +1061,12 @@ class Saver(object):
Args: Args:
checkpoint_paths: a list of checkpoint paths. checkpoint_paths: a list of checkpoint paths.
""" """
mtimes = checkpoint_management.get_checkpoint_mtimes(checkpoint_paths) checkpoints_with_mtimes = []
self.set_last_checkpoints_with_time(list(zip(checkpoint_paths, mtimes))) for checkpoint_path in checkpoint_paths:
mtime = checkpoint_management.get_checkpoint_mtimes([checkpoint_path])
if mtime:
checkpoints_with_mtimes.append((checkpoint_path, mtime[0]))
self.set_last_checkpoints_with_time(checkpoints_with_mtimes)
def save(self, def save(self,
sess, sess,

View File

@ -1475,6 +1475,67 @@ class MaxToKeepTest(test.TestCase):
gfile.Exists(checkpoint_management.meta_graph_filename(s1))) gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
class RecoverLastCheckpointsTest(test.TestCase):
  """Tests `Saver.recover_last_checkpoints`, including a checkpoint removed from disk."""

  def _get_test_dir(self, dirname):
    """Creates `dirname` under the test temp dir and returns its path."""
    path = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(path)
    return path

  def assertCheckpointState(self, model_checkpoint_path,
                            all_model_checkpoint_paths, save_dir):
    """Asserts that the checkpoint state read from `save_dir` matches.

    Args:
      model_checkpoint_path: expected `model_checkpoint_path` of the state.
      all_model_checkpoint_paths: expected `all_model_checkpoint_paths`.
      save_dir: directory passed to `get_checkpoint_state`.
    """
    state = checkpoint_management.get_checkpoint_state(save_dir)
    self.assertEqual(state.model_checkpoint_path, model_checkpoint_path)
    self.assertEqual(state.all_model_checkpoint_paths,
                     all_model_checkpoint_paths)

  def test_recover_last_checkpoints(self):
    with context.eager_mode():
      save_dir = self._get_test_dir("recover_last_checkpoints")

      var = variable_scope.variable(10.0, name="v")
      first_saver = saver_module.Saver({"v": var}, max_to_keep=10)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual([], first_saver.last_checkpoints)

      # Write three checkpoints, in order, with the first saver.
      saved = [
          first_saver.save(None, os.path.join(save_dir, name))
          for name in ("ckpt-1", "ckpt-2", "ckpt-3")
      ]
      s1, s2, s3 = saved
      self.assertEqual(saved, first_saver.last_checkpoints)
      for prefix in saved:
        self.assertTrue(checkpoint_management.checkpoint_exists(prefix))
      self.assertCheckpointState(
          model_checkpoint_path=s3,
          all_model_checkpoint_paths=saved,
          save_dir=save_dir)

      # A fresh saver recovers all three checkpoints while they all exist.
      second_saver = saver_module.Saver({"v": var}, max_to_keep=10)
      self.assertEqual([], second_saver.last_checkpoints)
      second_saver.recover_last_checkpoints(saved)
      self.assertEqual(saved, second_saver.last_checkpoints)

      # Delete every file belonging to the first checkpoint.
      for stale_file in gfile.Glob("{}*".format(s1)):
        gfile.Remove(stale_file)
      self.assertFalse(checkpoint_management.checkpoint_exists(s1))

      # Another fresh saver, given the same three paths, must leave out the
      # checkpoint that is no longer on disk.
      third_saver = saver_module.Saver({"v": var}, max_to_keep=10)
      self.assertEqual([], third_saver.last_checkpoints)
      third_saver.recover_last_checkpoints(saved)
      self.assertEqual([s2, s3], third_saver.last_checkpoints)

      # Saving again records only the surviving checkpoints plus the new one.
      s4 = third_saver.save(None, os.path.join(save_dir, "ckpt-4"))
      self.assertCheckpointState(
          model_checkpoint_path=s4,
          all_model_checkpoint_paths=[s2, s3, s4],
          save_dir=save_dir)
class KeepCheckpointEveryNHoursTest(test.TestCase): class KeepCheckpointEveryNHoursTest(test.TestCase):
def _get_test_dir(self, dirname): def _get_test_dir(self, dirname):