Factor out tf.train.remove_checkpoint utility function.
PiperOrigin-RevId: 200276735
This commit is contained in:
parent
abfdf45dcd
commit
52af244989
@ -1373,23 +1373,6 @@ class Saver(object):
|
||||
name, _ = p
|
||||
return name
|
||||
|
||||
def _MetaGraphFilename(self, checkpoint_filename, meta_graph_suffix="meta"):
  """Returns the meta graph filename.

  Args:
    checkpoint_filename: Name of the checkpoint file. May end with a shard
      pattern such as `-?????-of-00005` or `-00001-of-00002`.
    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.

  Returns:
    MetaGraph file name: `checkpoint_filename` with any trailing shard
    pattern removed, followed by "." and `meta_graph_suffix`.
  """
  # If the checkpoint_filename is sharded, the checkpoint_filename could
  # be of format model.ckpt-step#-?????-of-shard#. For example,
  # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
  # Dropping that trailing pattern maps every shard name to one meta graph.
  basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
  return ".".join([basename, meta_graph_suffix])
|
||||
|
||||
def _RecordLastCheckpoint(self, latest_save_path):
|
||||
"""Manages the list of the latest checkpoints."""
|
||||
if not self.saver_def.max_to_keep:
|
||||
@ -1430,24 +1413,12 @@ class Saver(object):
|
||||
|
||||
# Otherwise delete the files.
|
||||
try:
|
||||
checkpoint_prefix = self._CheckpointFilename(p)
|
||||
self._delete_file_if_exists(
|
||||
self._MetaGraphFilename(checkpoint_prefix, meta_graph_suffix))
|
||||
if self.saver_def.version == saver_pb2.SaverDef.V2:
|
||||
# V2 has a metadata file and some data files.
|
||||
self._delete_file_if_exists(checkpoint_prefix + ".index")
|
||||
self._delete_file_if_exists(checkpoint_prefix +
|
||||
".data-?????-of-?????")
|
||||
else:
|
||||
# V1, Legacy. Exact match on the data file.
|
||||
self._delete_file_if_exists(checkpoint_prefix)
|
||||
remove_checkpoint(
|
||||
self._CheckpointFilename(p), self.saver_def.version,
|
||||
meta_graph_suffix)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
logging.warning("Ignoring: %s", str(e))
|
||||
|
||||
def _delete_file_if_exists(self, filespec):
  """Deletes every file matching `filespec`.

  Args:
    filespec: File pattern handed to `file_io.get_matching_files`. Nothing is
      deleted when the pattern matches no files.
  """
  for pathname in file_io.get_matching_files(filespec):
    file_io.delete_file(pathname)
|
||||
|
||||
def as_saver_def(self):
|
||||
"""Generates a `SaverDef` representation of this saver.
|
||||
|
||||
@ -1669,7 +1640,7 @@ class Saver(object):
|
||||
raise exc
|
||||
|
||||
if write_meta_graph:
|
||||
meta_graph_filename = self._MetaGraphFilename(
|
||||
meta_graph_filename = _meta_graph_filename(
|
||||
checkpoint_file, meta_graph_suffix=meta_graph_suffix)
|
||||
if not context.executing_eagerly():
|
||||
with sess.graph.as_default():
|
||||
@ -2121,6 +2092,55 @@ def get_checkpoint_mtimes(checkpoint_prefixes):
|
||||
return mtimes
|
||||
|
||||
|
||||
@tf_export("train.remove_checkpoint")
def remove_checkpoint(checkpoint_prefix,
                      checkpoint_format_version=saver_pb2.SaverDef.V2,
                      meta_graph_suffix="meta"):
  """Removes a checkpoint given by `checkpoint_prefix`.

  The meta graph file derived from `checkpoint_prefix` is always removed.
  For `SaverDef.V2`, the `.index` file and every file matching
  `.data-?????-of-?????` under the prefix are removed as well; for any other
  format version, the files matching `checkpoint_prefix` itself are removed.
  Files that do not exist are skipped.

  Args:
    checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result
      of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of
      sharded/non-sharded or V1/V2.
    checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to
      `SaverDef.V2`.
    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
  """
  _delete_file_if_exists(
      _meta_graph_filename(checkpoint_prefix, meta_graph_suffix))
  if checkpoint_format_version == saver_pb2.SaverDef.V2:
    # V2 has a metadata file and some data files.
    _delete_file_if_exists(checkpoint_prefix + ".index")
    _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????")
  else:
    # V1, Legacy. Exact match on the data file.
    _delete_file_if_exists(checkpoint_prefix)
|
||||
|
||||
|
||||
def _delete_file_if_exists(filespec):
  """Deletes files matching `filespec`.

  Args:
    filespec: File pattern handed to `file_io.get_matching_files`. Nothing is
      deleted when the pattern matches no files.
  """
  for pathname in file_io.get_matching_files(filespec):
    file_io.delete_file(pathname)
|
||||
|
||||
|
||||
def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
  """Returns the meta graph filename.

  Args:
    checkpoint_filename: Name of the checkpoint file. May end with a shard
      pattern such as `-?????-of-00005` or `-00001-of-00002`.
    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.

  Returns:
    MetaGraph file name: `checkpoint_filename` with any trailing shard
    pattern removed, followed by "." and `meta_graph_suffix`.
  """
  # If the checkpoint_filename is sharded, the checkpoint_filename could
  # be of format model.ckpt-step#-?????-of-shard#. For example,
  # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
  # Dropping that trailing pattern maps every shard name to one meta graph.
  basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
  return ".".join([basename, meta_graph_suffix])
|
||||
|
||||
|
||||
ops.register_proto_function(
|
||||
ops.GraphKeys.SAVERS,
|
||||
proto_type=saver_pb2.SaverDef,
|
||||
|
@ -809,7 +809,7 @@ class SaveRestoreShardedTest(test.TestCase):
|
||||
self.assertEqual(save_path + "-?????-of-00002", val)
|
||||
else:
|
||||
self.assertEqual(save_path, val)
|
||||
meta_graph_filename = save._MetaGraphFilename(val)
|
||||
meta_graph_filename = saver_module._meta_graph_filename(val)
|
||||
self.assertEqual(save_path + ".meta", meta_graph_filename)
|
||||
|
||||
if save._write_version is saver_pb2.SaverDef.V1:
|
||||
@ -1185,13 +1185,13 @@ class MaxToKeepTest(test.TestCase):
|
||||
self.assertEqual([s3, s2], save.last_checkpoints)
|
||||
self.assertFalse(saver_module.checkpoint_exists(s1))
|
||||
self.assertFalse(
|
||||
saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
|
||||
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
|
||||
self.assertTrue(saver_module.checkpoint_exists(s3))
|
||||
self.assertTrue(
|
||||
saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
|
||||
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
|
||||
self.assertTrue(saver_module.checkpoint_exists(s2))
|
||||
self.assertTrue(
|
||||
saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
|
||||
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
|
||||
self.assertCheckpointState(
|
||||
model_checkpoint_path=s2,
|
||||
all_model_checkpoint_paths=[s3, s2],
|
||||
@ -1202,13 +1202,13 @@ class MaxToKeepTest(test.TestCase):
|
||||
self.assertEqual([s2, s1], save.last_checkpoints)
|
||||
self.assertFalse(saver_module.checkpoint_exists(s3))
|
||||
self.assertFalse(
|
||||
saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
|
||||
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
|
||||
self.assertTrue(saver_module.checkpoint_exists(s2))
|
||||
self.assertTrue(
|
||||
saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
|
||||
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
|
||||
self.assertTrue(saver_module.checkpoint_exists(s1))
|
||||
self.assertTrue(
|
||||
saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
|
||||
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
|
||||
self.assertCheckpointState(
|
||||
model_checkpoint_path=s1,
|
||||
all_model_checkpoint_paths=[s2, s1],
|
||||
@ -1222,14 +1222,14 @@ class MaxToKeepTest(test.TestCase):
|
||||
# Created by the first helper.
|
||||
self.assertTrue(saver_module.checkpoint_exists(s1))
|
||||
self.assertTrue(
|
||||
saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
|
||||
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
|
||||
# Deleted by the first helper.
|
||||
self.assertFalse(saver_module.checkpoint_exists(s3))
|
||||
self.assertFalse(
|
||||
saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
|
||||
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
|
||||
self.assertTrue(saver_module.checkpoint_exists(s2))
|
||||
self.assertTrue(
|
||||
saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
|
||||
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
|
||||
self.assertCheckpointState(
|
||||
model_checkpoint_path=s2,
|
||||
all_model_checkpoint_paths=[s3, s2],
|
||||
@ -1240,13 +1240,13 @@ class MaxToKeepTest(test.TestCase):
|
||||
self.assertEqual([s2, s1], save2.last_checkpoints)
|
||||
self.assertFalse(saver_module.checkpoint_exists(s3))
|
||||
self.assertFalse(
|
||||
saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
|
||||
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
|
||||
self.assertTrue(saver_module.checkpoint_exists(s2))
|
||||
self.assertTrue(
|
||||
saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
|
||||
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
|
||||
self.assertTrue(saver_module.checkpoint_exists(s1))
|
||||
self.assertTrue(
|
||||
saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
|
||||
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
|
||||
self.assertCheckpointState(
|
||||
model_checkpoint_path=s1,
|
||||
all_model_checkpoint_paths=[s2, s1],
|
||||
@ -1260,14 +1260,14 @@ class MaxToKeepTest(test.TestCase):
|
||||
# Created by the first helper.
|
||||
self.assertTrue(saver_module.checkpoint_exists(s1))
|
||||
self.assertTrue(
|
||||
saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
|
||||
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
|
||||
# Deleted by the first helper.
|
||||
self.assertFalse(saver_module.checkpoint_exists(s3))
|
||||
self.assertFalse(
|
||||
saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
|
||||
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
|
||||
self.assertTrue(saver_module.checkpoint_exists(s2))
|
||||
self.assertTrue(
|
||||
saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
|
||||
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
|
||||
# Even though the file for s1 exists, this saver isn't aware of it, which
|
||||
# is why it doesn't end up in the checkpoint state.
|
||||
self.assertCheckpointState(
|
||||
@ -1280,13 +1280,13 @@ class MaxToKeepTest(test.TestCase):
|
||||
self.assertEqual([s2, s1], save3.last_checkpoints)
|
||||
self.assertFalse(saver_module.checkpoint_exists(s3))
|
||||
self.assertFalse(
|
||||
saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
|
||||
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
|
||||
self.assertTrue(saver_module.checkpoint_exists(s2))
|
||||
self.assertTrue(
|
||||
saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
|
||||
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
|
||||
self.assertTrue(saver_module.checkpoint_exists(s1))
|
||||
self.assertTrue(
|
||||
saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
|
||||
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
|
||||
self.assertCheckpointState(
|
||||
model_checkpoint_path=s1,
|
||||
all_model_checkpoint_paths=[s2, s1],
|
||||
@ -1317,7 +1317,7 @@ class MaxToKeepTest(test.TestCase):
|
||||
else:
|
||||
self.assertEqual(4, len(gfile.Glob(s1 + "*")))
|
||||
|
||||
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1)))
|
||||
self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1)))
|
||||
|
||||
s2 = save.save(sess, os.path.join(save_dir, "s2"))
|
||||
self.assertEqual([s1, s2], save.last_checkpoints)
|
||||
@ -1325,27 +1325,27 @@ class MaxToKeepTest(test.TestCase):
|
||||
self.assertEqual(2, len(gfile.Glob(s1)))
|
||||
else:
|
||||
self.assertEqual(4, len(gfile.Glob(s1 + "*")))
|
||||
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1)))
|
||||
self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1)))
|
||||
if save._write_version is saver_pb2.SaverDef.V1:
|
||||
self.assertEqual(2, len(gfile.Glob(s2)))
|
||||
else:
|
||||
self.assertEqual(4, len(gfile.Glob(s2 + "*")))
|
||||
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s2)))
|
||||
self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2)))
|
||||
|
||||
s3 = save.save(sess, os.path.join(save_dir, "s3"))
|
||||
self.assertEqual([s2, s3], save.last_checkpoints)
|
||||
self.assertEqual(0, len(gfile.Glob(s1 + "*")))
|
||||
self.assertFalse(gfile.Exists(save._MetaGraphFilename(s1)))
|
||||
self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1)))
|
||||
if save._write_version is saver_pb2.SaverDef.V1:
|
||||
self.assertEqual(2, len(gfile.Glob(s2)))
|
||||
else:
|
||||
self.assertEqual(4, len(gfile.Glob(s2 + "*")))
|
||||
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s2)))
|
||||
self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2)))
|
||||
if save._write_version is saver_pb2.SaverDef.V1:
|
||||
self.assertEqual(2, len(gfile.Glob(s3)))
|
||||
else:
|
||||
self.assertEqual(4, len(gfile.Glob(s3 + "*")))
|
||||
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3)))
|
||||
self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s3)))
|
||||
|
||||
def testNoMaxToKeep(self):
|
||||
save_dir = self._get_test_dir("no_max_to_keep")
|
||||
@ -1385,7 +1385,7 @@ class MaxToKeepTest(test.TestCase):
|
||||
|
||||
s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False)
|
||||
self.assertTrue(saver_module.checkpoint_exists(s1))
|
||||
self.assertFalse(gfile.Exists(save._MetaGraphFilename(s1)))
|
||||
self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1)))
|
||||
|
||||
|
||||
class KeepCheckpointEveryNHoursTest(test.TestCase):
|
||||
@ -2621,6 +2621,20 @@ class SaverUtilsTest(test.TestCase):
|
||||
self.assertEqual(2, len(mtimes))
|
||||
self.assertTrue(mtimes[1] >= mtimes[0])
|
||||
|
||||
def testRemoveCheckpoint(self):
  """Checks that `remove_checkpoint` deletes a freshly saved checkpoint.

  Runs for sharded and non-sharded savers in both the V2 and V1 formats, and
  verifies that the meta graph file is removed along with the checkpoint.
  """
  for sharded in (False, True):
    for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
      with self.test_session(graph=ops_lib.Graph()) as sess:
        unused_v = variables.Variable(1.0, name="v")
        variables.global_variables_initializer().run()
        saver = saver_module.Saver(sharded=sharded, write_version=version)

        # A distinct path per (sharded, version) combination keeps the
        # iterations from touching each other's files.
        path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
        ckpt_prefix = saver.save(sess, path)
        meta_graph = saver_module._meta_graph_filename(ckpt_prefix)
        self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))
        self.assertTrue(gfile.Exists(meta_graph))
        saver_module.remove_checkpoint(ckpt_prefix, version)
        self.assertFalse(saver_module.checkpoint_exists(ckpt_prefix))
        # checkpoint_exists() is only asserted on the checkpoint prefix, so
        # check the meta graph file explicitly.
        self.assertFalse(gfile.Exists(meta_graph))
|
||||
|
||||
|
||||
class ScopedGraphTest(test.TestCase):
|
||||
|
||||
|
@ -400,6 +400,10 @@ tf_module {
|
||||
name: "range_input_producer"
|
||||
argspec: "args=[\'limit\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "remove_checkpoint"
|
||||
argspec: "args=[\'checkpoint_prefix\', \'checkpoint_format_version\', \'meta_graph_suffix\'], varargs=None, keywords=None, defaults=[\'2\', \'meta\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "replica_device_setter"
|
||||
argspec: "args=[\'ps_tasks\', \'ps_device\', \'worker_device\', \'merge_devices\', \'cluster\', \'ps_ops\', \'ps_strategy\'], varargs=None, keywords=None, defaults=[\'0\', \'/job:ps\', \'/job:worker\', \'True\', \'None\', \'None\', \'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user