Create an internal undeprecated function to check checkpoint exists.

This commit is contained in:
Abdullah Selek 2019-06-28 13:13:52 +01:00
parent 638f250db8
commit 2d9851f9b0
2 changed files with 26 additions and 9 deletions

View File

@ -347,6 +347,30 @@ def latest_checkpoint(checkpoint_dir, latest_filename=None):
return None return None
def checkpoint_exists_internal(checkpoint_prefix):
"""Checks whether a V1 or V2 checkpoint exists with the specified prefix.
This is an internal function to check if a checkpoint exists,
since it takes into account the naming difference between V1 and V2 formats.
Args:
checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
priority. Typically the result of `Saver.save()` or that of
`tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
V1/V2.
Returns:
A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists.
"""
pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
saver_pb2.SaverDef.V2)
if file_io.get_matching_files(pathname):
return True
elif file_io.get_matching_files(checkpoint_prefix):
return True
else:
return False
@deprecation.deprecated( @deprecation.deprecated(
date=None, date=None,
instructions="Use standard file APIs to check for files with this prefix.") instructions="Use standard file APIs to check for files with this prefix.")
@ -365,14 +389,7 @@ def checkpoint_exists(checkpoint_prefix):
Returns: Returns:
A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists. A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists.
""" """
pathname = _prefix_to_checkpoint_path(checkpoint_prefix, return checkpoint_exists_internal(checkpoint_exists)
saver_pb2.SaverDef.V2)
if file_io.get_matching_files(pathname):
return True
elif file_io.get_matching_files(checkpoint_prefix):
return True
else:
return False
@deprecation.deprecated( @deprecation.deprecated(

View File

@ -1276,7 +1276,7 @@ class Saver(object):
if save_path is None: if save_path is None:
raise ValueError("Can't load save_path when it is None.") raise ValueError("Can't load save_path when it is None.")
if not checkpoint_management.checkpoint_exists(compat.as_text(save_path)): if not checkpoint_management.checkpoint_exists_internal(compat.as_text(save_path)):
raise ValueError("The passed save_path is not a valid checkpoint: " + raise ValueError("The passed save_path is not a valid checkpoint: " +
compat.as_text(save_path)) compat.as_text(save_path))