diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index bd2d78b025e..b8f58a288c7 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -1373,23 +1373,6 @@ class Saver(object): name, _ = p return name - def _MetaGraphFilename(self, checkpoint_filename, meta_graph_suffix="meta"): - """Returns the meta graph filename. - - Args: - checkpoint_filename: Name of the checkpoint file. - meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. - - Returns: - MetaGraph file name. - """ - # If the checkpoint_filename is sharded, the checkpoint_filename could - # be of format model.ckpt-step#-?????-of-shard#. For example, - # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002. - basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename) - meta_graph_filename = ".".join([basename, meta_graph_suffix]) - return meta_graph_filename - def _RecordLastCheckpoint(self, latest_save_path): """Manages the list of the latest checkpoints.""" if not self.saver_def.max_to_keep: @@ -1430,24 +1413,12 @@ class Saver(object): # Otherwise delete the files. try: - checkpoint_prefix = self._CheckpointFilename(p) - self._delete_file_if_exists( - self._MetaGraphFilename(checkpoint_prefix, meta_graph_suffix)) - if self.saver_def.version == saver_pb2.SaverDef.V2: - # V2 has a metadata file and some data files. - self._delete_file_if_exists(checkpoint_prefix + ".index") - self._delete_file_if_exists(checkpoint_prefix + - ".data-?????-of-?????") - else: - # V1, Legacy. Exact match on the data file. - self._delete_file_if_exists(checkpoint_prefix) + remove_checkpoint( + self._CheckpointFilename(p), self.saver_def.version, + meta_graph_suffix) except Exception as e: # pylint: disable=broad-except logging.warning("Ignoring: %s", str(e)) - def _delete_file_if_exists(self, filespec): - for pathname in file_io.get_matching_files(filespec): - file_io.delete_file(pathname) - def as_saver_def(self): """Generates a `SaverDef` representation of this saver. @@ -1669,7 +1640,7 @@ class Saver(object): raise exc if write_meta_graph: - meta_graph_filename = self._MetaGraphFilename( + meta_graph_filename = _meta_graph_filename( checkpoint_file, meta_graph_suffix=meta_graph_suffix) if not context.executing_eagerly(): with sess.graph.as_default(): @@ -2121,6 +2092,55 @@ def get_checkpoint_mtimes(checkpoint_prefixes): return mtimes +@tf_export("train.remove_checkpoint") +def remove_checkpoint(checkpoint_prefix, + checkpoint_format_version=saver_pb2.SaverDef.V2, + meta_graph_suffix="meta"): + """Removes a checkpoint given by `checkpoint_prefix`. + + Args: + checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result + of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of + sharded/non-sharded or V1/V2. + checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to + `SaverDef.V2`. + meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. + """ + _delete_file_if_exists( + _meta_graph_filename(checkpoint_prefix, meta_graph_suffix)) + if checkpoint_format_version == saver_pb2.SaverDef.V2: + # V2 has a metadata file and some data files. + _delete_file_if_exists(checkpoint_prefix + ".index") + _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????") + else: + # V1, Legacy. Exact match on the data file. + _delete_file_if_exists(checkpoint_prefix) + + +def _delete_file_if_exists(filespec): + """Deletes files matching `filespec`.""" + for pathname in file_io.get_matching_files(filespec): + file_io.delete_file(pathname) + + +def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"): + """Returns the meta graph filename. + + Args: + checkpoint_filename: Name of the checkpoint file. + meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. + + Returns: + MetaGraph file name. + """ + # If the checkpoint_filename is sharded, the checkpoint_filename could + # be of format model.ckpt-step#-?????-of-shard#. For example, + # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002. + basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename) + meta_graph_filename = ".".join([basename, meta_graph_suffix]) + return meta_graph_filename + + ops.register_proto_function( ops.GraphKeys.SAVERS, proto_type=saver_pb2.SaverDef, diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index b228cb85d74..e3be7d868e5 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -809,7 +809,7 @@ class SaveRestoreShardedTest(test.TestCase): self.assertEqual(save_path + "-?????-of-00002", val) else: self.assertEqual(save_path, val) - meta_graph_filename = save._MetaGraphFilename(val) + meta_graph_filename = saver_module._meta_graph_filename(val) self.assertEqual(save_path + ".meta", meta_graph_filename) if save._write_version is saver_pb2.SaverDef.V1: @@ -1185,13 +1185,13 @@ class MaxToKeepTest(test.TestCase): self.assertEqual([s3, s2], save.last_checkpoints) self.assertFalse(saver_module.checkpoint_exists(s1)) self.assertFalse( - saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) self.assertTrue(saver_module.checkpoint_exists(s3)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s3))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) self.assertTrue(saver_module.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s2))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) self.assertCheckpointState( model_checkpoint_path=s2, all_model_checkpoint_paths=[s3, s2], @@ -1202,13 +1202,13 @@ class MaxToKeepTest(test.TestCase): self.assertEqual([s2, s1], save.last_checkpoints) self.assertFalse(saver_module.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(save._MetaGraphFilename(s3))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) self.assertTrue(saver_module.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s2))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) self.assertTrue(saver_module.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) self.assertCheckpointState( model_checkpoint_path=s1, all_model_checkpoint_paths=[s2, s1], @@ -1222,14 +1222,14 @@ class MaxToKeepTest(test.TestCase): # Created by the first helper. self.assertTrue(saver_module.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) # Deleted by the first helper. self.assertFalse(saver_module.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(save._MetaGraphFilename(s3))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) self.assertTrue(saver_module.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s2))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) self.assertCheckpointState( model_checkpoint_path=s2, all_model_checkpoint_paths=[s3, s2], @@ -1240,13 +1240,13 @@ class MaxToKeepTest(test.TestCase): self.assertEqual([s2, s1], save2.last_checkpoints) self.assertFalse(saver_module.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(save._MetaGraphFilename(s3))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) self.assertTrue(saver_module.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s2))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) self.assertTrue(saver_module.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) self.assertCheckpointState( model_checkpoint_path=s1, all_model_checkpoint_paths=[s2, s1], @@ -1260,14 +1260,14 @@ class MaxToKeepTest(test.TestCase): # Created by the first helper. self.assertTrue(saver_module.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) # Deleted by the first helper. self.assertFalse(saver_module.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(save._MetaGraphFilename(s3))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) self.assertTrue(saver_module.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s2))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) # Even though the file for s1 exists, this saver isn't aware of it, which # is why it doesn't end up in the checkpoint state. self.assertCheckpointState( @@ -1280,13 +1280,13 @@ class MaxToKeepTest(test.TestCase): self.assertEqual([s2, s1], save3.last_checkpoints) self.assertFalse(saver_module.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(save._MetaGraphFilename(s3))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) self.assertTrue(saver_module.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s2))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) self.assertTrue(saver_module.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) + saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) self.assertCheckpointState( model_checkpoint_path=s1, all_model_checkpoint_paths=[s2, s1], @@ -1317,7 +1317,7 @@ class MaxToKeepTest(test.TestCase): else: self.assertEqual(4, len(gfile.Glob(s1 + "*"))) - self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1))) + self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1))) s2 = save.save(sess, os.path.join(save_dir, "s2")) self.assertEqual([s1, s2], save.last_checkpoints) @@ -1325,27 +1325,27 @@ class MaxToKeepTest(test.TestCase): self.assertEqual(2, len(gfile.Glob(s1))) else: self.assertEqual(4, len(gfile.Glob(s1 + "*"))) - self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1))) + self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1))) if save._write_version is saver_pb2.SaverDef.V1: self.assertEqual(2, len(gfile.Glob(s2))) else: self.assertEqual(4, len(gfile.Glob(s2 + "*"))) - self.assertTrue(gfile.Exists(save._MetaGraphFilename(s2))) + self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2))) s3 = save.save(sess, os.path.join(save_dir, "s3")) self.assertEqual([s2, s3], save.last_checkpoints) self.assertEqual(0, len(gfile.Glob(s1 + "*"))) - self.assertFalse(gfile.Exists(save._MetaGraphFilename(s1))) + self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1))) if save._write_version is saver_pb2.SaverDef.V1: self.assertEqual(2, len(gfile.Glob(s2))) else: self.assertEqual(4, len(gfile.Glob(s2 + "*"))) - self.assertTrue(gfile.Exists(save._MetaGraphFilename(s2))) + self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2))) if save._write_version is saver_pb2.SaverDef.V1: self.assertEqual(2, len(gfile.Glob(s3))) else: self.assertEqual(4, len(gfile.Glob(s3 + "*"))) - self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3))) + self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s3))) def testNoMaxToKeep(self): save_dir = self._get_test_dir("no_max_to_keep") @@ -1385,7 +1385,7 @@ class MaxToKeepTest(test.TestCase): s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False) self.assertTrue(saver_module.checkpoint_exists(s1)) - self.assertFalse(gfile.Exists(save._MetaGraphFilename(s1))) + self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1))) class KeepCheckpointEveryNHoursTest(test.TestCase): @@ -2621,6 +2621,20 @@ class SaverUtilsTest(test.TestCase): self.assertEqual(2, len(mtimes)) self.assertTrue(mtimes[1] >= mtimes[0]) + def testRemoveCheckpoint(self): + for sharded in (False, True): + for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): + with self.test_session(graph=ops_lib.Graph()) as sess: + unused_v = variables.Variable(1.0, name="v") + variables.global_variables_initializer().run() + saver = saver_module.Saver(sharded=sharded, write_version=version) + + path = os.path.join(self._base_dir, "%s-%s" % (sharded, version)) + ckpt_prefix = saver.save(sess, path) + self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix)) + saver_module.remove_checkpoint(ckpt_prefix, version) + self.assertFalse(saver_module.checkpoint_exists(ckpt_prefix)) + class ScopedGraphTest(test.TestCase): diff --git a/tensorflow/tools/api/golden/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.pbtxt index 9fb18e77afd..5f45b3b1ad9 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.pbtxt @@ -400,6 +400,10 @@ tf_module { name: "range_input_producer" argspec: "args=[\'limit\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\'], " } + member_method { + name: "remove_checkpoint" + argspec: "args=[\'checkpoint_prefix\', \'checkpoint_format_version\', \'meta_graph_suffix\'], varargs=None, keywords=None, defaults=[\'2\', \'meta\'], " + } member_method { name: "replica_device_setter" argspec: "args=[\'ps_tasks\', \'ps_device\', \'worker_device\', \'merge_devices\', \'cluster\', \'ps_ops\', \'ps_strategy\'], varargs=None, keywords=None, defaults=[\'0\', \'/job:ps\', \'/job:worker\', \'True\', \'None\', \'None\', \'None\'], "