From a8cd934cde19d36c1ff0c7812670150bf2e226c3 Mon Sep 17 00:00:00 2001 From: Goutham Bhat Date: Tue, 30 Apr 2019 11:10:33 -0700 Subject: [PATCH] Gracefully handle missing checkpoints in recover_last_checkpoints If some checkpoints present in CheckpointState are absent on disk, recover_last_checkpoints incorrectly initializes Saver internal state. In this example: (1) CheckpointState.all_model_checkpoint_paths = ['ckpt-1', 'ckpt-2', 'ckpt-3'] (2) Actual checkpoints on disk: ['ckpt-2', 'ckpt-3'] last_checkpoints gets incorrectly initialized to ['ckpt-1', 'ckpt-2']. This is because get_checkpoint_mtimes silently ignores any absent checkpoints and returns a list of length 2 corresponding to checkpoints on disk, which then gets zipped with (1). After the fix, last_checkpoints would be ['ckpt-2', 'ckpt-3']. PiperOrigin-RevId: 245983586 --- .../python/training/checkpoint_management.py | 4 ++ tensorflow/python/training/saver.py | 8 ++- tensorflow/python/training/saver_test.py | 61 +++++++++++++++++++ 3 files changed, 71 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py index 131ecf71ba5..061ef0f71e4 100644 --- a/tensorflow/python/training/checkpoint_management.py +++ b/tensorflow/python/training/checkpoint_management.py @@ -391,6 +391,10 @@ def get_checkpoint_mtimes(checkpoint_prefixes): This is the recommended way to get the mtimes, since it takes into account the naming difference between V1 and V2 formats. + Note: If not all checkpoints exist, the length of the returned mtimes list + will be smaller than the length of `checkpoint_prefixes` list, so mapping + checkpoints to corresponding mtimes will not be possible. + Args: checkpoint_prefixes: a list of checkpoint paths, typically the results of `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index d4c4578d361..26eafaa06a8 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -1061,8 +1061,12 @@ class Saver(object): Args: checkpoint_paths: a list of checkpoint paths. 
""" - mtimes = checkpoint_management.get_checkpoint_mtimes(checkpoint_paths) - self.set_last_checkpoints_with_time(list(zip(checkpoint_paths, mtimes))) + checkpoints_with_mtimes = [] + for checkpoint_path in checkpoint_paths: + mtime = checkpoint_management.get_checkpoint_mtimes([checkpoint_path]) + if mtime: + checkpoints_with_mtimes.append((checkpoint_path, mtime[0])) + self.set_last_checkpoints_with_time(checkpoints_with_mtimes) def save(self, sess, diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index 89e64a5d3b1..99492bc5890 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -1475,6 +1475,67 @@ class MaxToKeepTest(test.TestCase): gfile.Exists(checkpoint_management.meta_graph_filename(s1))) +class RecoverLastCheckpointsTest(test.TestCase): + + def _get_test_dir(self, dirname): + test_dir = os.path.join(self.get_temp_dir(), dirname) + gfile.MakeDirs(test_dir) + return test_dir + + def assertCheckpointState(self, model_checkpoint_path, + all_model_checkpoint_paths, save_dir): + checkpoint_state = checkpoint_management.get_checkpoint_state(save_dir) + self.assertEqual(checkpoint_state.model_checkpoint_path, + model_checkpoint_path) + self.assertEqual(checkpoint_state.all_model_checkpoint_paths, + all_model_checkpoint_paths) + + def test_recover_last_checkpoints(self): + with context.eager_mode(): + save_dir = self._get_test_dir("recover_last_checkpoints") + + v = variable_scope.variable(10.0, name="v") + save = saver_module.Saver({"v": v}, max_to_keep=10) + self.evaluate(variables.global_variables_initializer()) + self.assertEqual([], save.last_checkpoints) + + s1 = save.save(None, os.path.join(save_dir, "ckpt-1")) + s2 = save.save(None, os.path.join(save_dir, "ckpt-2")) + s3 = save.save(None, os.path.join(save_dir, "ckpt-3")) + self.assertEqual([s1, s2, s3], save.last_checkpoints) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) + self.assertTrue(checkpoint_management.checkpoint_exists(s3)) + self.assertCheckpointState( + model_checkpoint_path=s3, + all_model_checkpoint_paths=[s1, s2, s3], + save_dir=save_dir) + + # Create another saver and recover last checkpoints. + save2 = saver_module.Saver({"v": v}, max_to_keep=10) + self.assertEqual([], save2.last_checkpoints) + save2.recover_last_checkpoints([s1, s2, s3]) + self.assertEqual([s1, s2, s3], save2.last_checkpoints) + + # Remove a checkpoint and check that last checkpoints are + # restored correctly. + for fname in gfile.Glob("{}*".format(s1)): + gfile.Remove(fname) + self.assertFalse(checkpoint_management.checkpoint_exists(s1)) + + # Create another saver and recover last checkpoints. The removed + # checkpoint would be correctly omitted. + save3 = saver_module.Saver({"v": v}, max_to_keep=10) + self.assertEqual([], save3.last_checkpoints) + save3.recover_last_checkpoints([s1, s2, s3]) + self.assertEqual([s2, s3], save3.last_checkpoints) + s4 = save3.save(None, os.path.join(save_dir, "ckpt-4")) + self.assertCheckpointState( + model_checkpoint_path=s4, + all_model_checkpoint_paths=[s2, s3, s4], + save_dir=save_dir) + + class KeepCheckpointEveryNHoursTest(test.TestCase): def _get_test_dir(self, dirname):