Create an internal undeprecated function to check checkpoint exists.
This commit is contained in:
parent
638f250db8
commit
2d9851f9b0
@ -347,6 +347,30 @@ def latest_checkpoint(checkpoint_dir, latest_filename=None):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def checkpoint_exists_internal(checkpoint_prefix):
|
||||||
|
"""Checks whether a V1 or V2 checkpoint exists with the specified prefix.
|
||||||
|
|
||||||
|
This is an internal function to check if a checkpoint exists,
|
||||||
|
since it takes into account the naming difference between V1 and V2 formats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
|
||||||
|
priority. Typically the result of `Saver.save()` or that of
|
||||||
|
`tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
|
||||||
|
V1/V2.
|
||||||
|
Returns:
|
||||||
|
A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists.
|
||||||
|
"""
|
||||||
|
pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
|
||||||
|
saver_pb2.SaverDef.V2)
|
||||||
|
if file_io.get_matching_files(pathname):
|
||||||
|
return True
|
||||||
|
elif file_io.get_matching_files(checkpoint_prefix):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
@deprecation.deprecated(
|
@deprecation.deprecated(
|
||||||
date=None,
|
date=None,
|
||||||
instructions="Use standard file APIs to check for files with this prefix.")
|
instructions="Use standard file APIs to check for files with this prefix.")
|
||||||
@ -365,14 +389,7 @@ def checkpoint_exists(checkpoint_prefix):
|
|||||||
Returns:
|
Returns:
|
||||||
A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists.
|
A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists.
|
||||||
"""
|
"""
|
||||||
pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
|
return checkpoint_exists_internal(checkpoint_exists)
|
||||||
saver_pb2.SaverDef.V2)
|
|
||||||
if file_io.get_matching_files(pathname):
|
|
||||||
return True
|
|
||||||
elif file_io.get_matching_files(checkpoint_prefix):
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
@deprecation.deprecated(
|
@deprecation.deprecated(
|
||||||
|
@ -1276,7 +1276,7 @@ class Saver(object):
|
|||||||
if save_path is None:
|
if save_path is None:
|
||||||
raise ValueError("Can't load save_path when it is None.")
|
raise ValueError("Can't load save_path when it is None.")
|
||||||
|
|
||||||
if not checkpoint_management.checkpoint_exists(compat.as_text(save_path)):
|
if not checkpoint_management.checkpoint_exists_internal(compat.as_text(save_path)):
|
||||||
raise ValueError("The passed save_path is not a valid checkpoint: " +
|
raise ValueError("The passed save_path is not a valid checkpoint: " +
|
||||||
compat.as_text(save_path))
|
compat.as_text(save_path))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user