Factor out tf.train.remove_checkpoint utility function.

PiperOrigin-RevId: 200276735
This commit is contained in:
Goutham Bhat 2018-06-12 14:05:03 -07:00 committed by TensorFlower Gardener
parent abfdf45dcd
commit 52af244989
3 changed files with 97 additions and 59 deletions

View File

@ -1373,23 +1373,6 @@ class Saver(object):
name, _ = p
return name
def _MetaGraphFilename(self, checkpoint_filename, meta_graph_suffix="meta"):
"""Returns the meta graph filename.
Args:
checkpoint_filename: Name of the checkpoint file.
meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
Returns:
MetaGraph file name.
"""
# If the checkpoint_filename is sharded, the checkpoint_filename could
# be of format model.ckpt-step#-?????-of-shard#. For example,
# model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
meta_graph_filename = ".".join([basename, meta_graph_suffix])
return meta_graph_filename
def _RecordLastCheckpoint(self, latest_save_path):
"""Manages the list of the latest checkpoints."""
if not self.saver_def.max_to_keep:
@ -1430,24 +1413,12 @@ class Saver(object):
# Otherwise delete the files.
try:
checkpoint_prefix = self._CheckpointFilename(p)
self._delete_file_if_exists(
self._MetaGraphFilename(checkpoint_prefix, meta_graph_suffix))
if self.saver_def.version == saver_pb2.SaverDef.V2:
# V2 has a metadata file and some data files.
self._delete_file_if_exists(checkpoint_prefix + ".index")
self._delete_file_if_exists(checkpoint_prefix +
".data-?????-of-?????")
else:
# V1, Legacy. Exact match on the data file.
self._delete_file_if_exists(checkpoint_prefix)
remove_checkpoint(
self._CheckpointFilename(p), self.saver_def.version,
meta_graph_suffix)
except Exception as e: # pylint: disable=broad-except
logging.warning("Ignoring: %s", str(e))
def _delete_file_if_exists(self, filespec):
for pathname in file_io.get_matching_files(filespec):
file_io.delete_file(pathname)
def as_saver_def(self):
"""Generates a `SaverDef` representation of this saver.
@ -1669,7 +1640,7 @@ class Saver(object):
raise exc
if write_meta_graph:
meta_graph_filename = self._MetaGraphFilename(
meta_graph_filename = _meta_graph_filename(
checkpoint_file, meta_graph_suffix=meta_graph_suffix)
if not context.executing_eagerly():
with sess.graph.as_default():
@ -2121,6 +2092,55 @@ def get_checkpoint_mtimes(checkpoint_prefixes):
return mtimes
@tf_export("train.remove_checkpoint")
def remove_checkpoint(checkpoint_prefix,
checkpoint_format_version=saver_pb2.SaverDef.V2,
meta_graph_suffix="meta"):
"""Removes a checkpoint given by `checkpoint_prefix`.
Args:
checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result
of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of
sharded/non-sharded or V1/V2.
checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to
`SaverDef.V2`.
meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
"""
_delete_file_if_exists(
_meta_graph_filename(checkpoint_prefix, meta_graph_suffix))
if checkpoint_format_version == saver_pb2.SaverDef.V2:
# V2 has a metadata file and some data files.
_delete_file_if_exists(checkpoint_prefix + ".index")
_delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????")
else:
# V1, Legacy. Exact match on the data file.
_delete_file_if_exists(checkpoint_prefix)
def _delete_file_if_exists(filespec):
"""Deletes files matching `filespec`."""
for pathname in file_io.get_matching_files(filespec):
file_io.delete_file(pathname)
def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
"""Returns the meta graph filename.
Args:
checkpoint_filename: Name of the checkpoint file.
meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
Returns:
MetaGraph file name.
"""
# If the checkpoint_filename is sharded, the checkpoint_filename could
# be of format model.ckpt-step#-?????-of-shard#. For example,
# model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
meta_graph_filename = ".".join([basename, meta_graph_suffix])
return meta_graph_filename
ops.register_proto_function(
ops.GraphKeys.SAVERS,
proto_type=saver_pb2.SaverDef,

View File

@ -809,7 +809,7 @@ class SaveRestoreShardedTest(test.TestCase):
self.assertEqual(save_path + "-?????-of-00002", val)
else:
self.assertEqual(save_path, val)
meta_graph_filename = save._MetaGraphFilename(val)
meta_graph_filename = saver_module._meta_graph_filename(val)
self.assertEqual(save_path + ".meta", meta_graph_filename)
if save._write_version is saver_pb2.SaverDef.V1:
@ -1185,13 +1185,13 @@ class MaxToKeepTest(test.TestCase):
self.assertEqual([s3, s2], save.last_checkpoints)
self.assertFalse(saver_module.checkpoint_exists(s1))
self.assertFalse(
saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
self.assertTrue(saver_module.checkpoint_exists(s3))
self.assertTrue(
saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertTrue(
saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s3, s2],
@ -1202,13 +1202,13 @@ class MaxToKeepTest(test.TestCase):
self.assertEqual([s2, s1], save.last_checkpoints)
self.assertFalse(saver_module.checkpoint_exists(s3))
self.assertFalse(
saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertTrue(
saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(
saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
@ -1222,14 +1222,14 @@ class MaxToKeepTest(test.TestCase):
# Created by the first helper.
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(
saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
# Deleted by the first helper.
self.assertFalse(saver_module.checkpoint_exists(s3))
self.assertFalse(
saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertTrue(
saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s3, s2],
@ -1240,13 +1240,13 @@ class MaxToKeepTest(test.TestCase):
self.assertEqual([s2, s1], save2.last_checkpoints)
self.assertFalse(saver_module.checkpoint_exists(s3))
self.assertFalse(
saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertTrue(
saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(
saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
@ -1260,14 +1260,14 @@ class MaxToKeepTest(test.TestCase):
# Created by the first helper.
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(
saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
# Deleted by the first helper.
self.assertFalse(saver_module.checkpoint_exists(s3))
self.assertFalse(
saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertTrue(
saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
# Even though the file for s1 exists, this saver isn't aware of it, which
# is why it doesn't end up in the checkpoint state.
self.assertCheckpointState(
@ -1280,13 +1280,13 @@ class MaxToKeepTest(test.TestCase):
self.assertEqual([s2, s1], save3.last_checkpoints)
self.assertFalse(saver_module.checkpoint_exists(s3))
self.assertFalse(
saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertTrue(
saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(
saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
@ -1317,7 +1317,7 @@ class MaxToKeepTest(test.TestCase):
else:
self.assertEqual(4, len(gfile.Glob(s1 + "*")))
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1)))
self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1)))
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s1, s2], save.last_checkpoints)
@ -1325,27 +1325,27 @@ class MaxToKeepTest(test.TestCase):
self.assertEqual(2, len(gfile.Glob(s1)))
else:
self.assertEqual(4, len(gfile.Glob(s1 + "*")))
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1)))
self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1)))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s2)))
else:
self.assertEqual(4, len(gfile.Glob(s2 + "*")))
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s2)))
self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2)))
s3 = save.save(sess, os.path.join(save_dir, "s3"))
self.assertEqual([s2, s3], save.last_checkpoints)
self.assertEqual(0, len(gfile.Glob(s1 + "*")))
self.assertFalse(gfile.Exists(save._MetaGraphFilename(s1)))
self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1)))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s2)))
else:
self.assertEqual(4, len(gfile.Glob(s2 + "*")))
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s2)))
self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2)))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s3)))
else:
self.assertEqual(4, len(gfile.Glob(s3 + "*")))
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3)))
self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s3)))
def testNoMaxToKeep(self):
save_dir = self._get_test_dir("no_max_to_keep")
@ -1385,7 +1385,7 @@ class MaxToKeepTest(test.TestCase):
s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False)
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertFalse(gfile.Exists(save._MetaGraphFilename(s1)))
self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1)))
class KeepCheckpointEveryNHoursTest(test.TestCase):
@ -2621,6 +2621,20 @@ class SaverUtilsTest(test.TestCase):
self.assertEqual(2, len(mtimes))
self.assertTrue(mtimes[1] >= mtimes[0])
def testRemoveCheckpoint(self):
for sharded in (False, True):
for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
with self.test_session(graph=ops_lib.Graph()) as sess:
unused_v = variables.Variable(1.0, name="v")
variables.global_variables_initializer().run()
saver = saver_module.Saver(sharded=sharded, write_version=version)
path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
ckpt_prefix = saver.save(sess, path)
self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))
saver_module.remove_checkpoint(ckpt_prefix, version)
self.assertFalse(saver_module.checkpoint_exists(ckpt_prefix))
class ScopedGraphTest(test.TestCase):

View File

@ -400,6 +400,10 @@ tf_module {
name: "range_input_producer"
argspec: "args=[\'limit\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\'], "
}
member_method {
name: "remove_checkpoint"
argspec: "args=[\'checkpoint_prefix\', \'checkpoint_format_version\', \'meta_graph_suffix\'], varargs=None, keywords=None, defaults=[\'2\', \'meta\'], "
}
member_method {
name: "replica_device_setter"
argspec: "args=[\'ps_tasks\', \'ps_device\', \'worker_device\', \'merge_devices\', \'cluster\', \'ps_ops\', \'ps_strategy\'], varargs=None, keywords=None, defaults=[\'0\', \'/job:ps\', \'/job:worker\', \'True\', \'None\', \'None\', \'None\'], "