diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py
index 32b9c023aae..e2dbde57a18 100644
--- a/tensorflow/python/training/checkpoint_management.py
+++ b/tensorflow/python/training/checkpoint_management.py
@@ -347,6 +347,30 @@ def latest_checkpoint(checkpoint_dir, latest_filename=None):
   return None
 
 
+def checkpoint_exists_internal(checkpoint_prefix):
+  """Checks whether a V1 or V2 checkpoint exists with the specified prefix.
+
+  This is an internal function to check if a checkpoint exists,
+  since it takes into account the naming difference between V1 and V2 formats.
+
+  Args:
+    checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
+      priority. Typically the result of `Saver.save()` or that of
+      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
+      V1/V2.
+  Returns:
+    A bool, true if a checkpoint referred to by `checkpoint_prefix` exists.
+  """
+  pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
+                                        saver_pb2.SaverDef.V2)
+  if file_io.get_matching_files(pathname):
+    return True
+  elif file_io.get_matching_files(checkpoint_prefix):
+    return True
+  else:
+    return False
+
+
 @deprecation.deprecated(
     date=None,
     instructions="Use standard file APIs to check for files with this prefix.")
@@ -362,17 +386,11 @@ def checkpoint_exists(checkpoint_prefix):
       priority. Typically the result of `Saver.save()` or that of
       `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
       V1/V2.
+
   Returns:
-    A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists.
+    A bool, true if a checkpoint referred to by `checkpoint_prefix` exists.
   """
-  pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
-                                        saver_pb2.SaverDef.V2)
-  if file_io.get_matching_files(pathname):
-    return True
-  elif file_io.get_matching_files(checkpoint_prefix):
-    return True
-  else:
-    return False
+  return checkpoint_exists_internal(checkpoint_prefix)
 
 
 @deprecation.deprecated(
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 7b502bffa38..d65297fb30d 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -1276,11 +1276,12 @@ class Saver(object):
     if save_path is None:
       raise ValueError("Can't load save_path when it is None.")
 
-    if not checkpoint_management.checkpoint_exists(compat.as_text(save_path)):
+    checkpoint_prefix = compat.as_text(save_path)
+    if not checkpoint_management.checkpoint_exists_internal(checkpoint_prefix):
       raise ValueError("The passed save_path is not a valid checkpoint: " +
-                       compat.as_text(save_path))
+                       checkpoint_prefix)
 
-    logging.info("Restoring parameters from %s", compat.as_text(save_path))
+    logging.info("Restoring parameters from %s", checkpoint_prefix)
     try:
       if context.executing_eagerly():
         self._build_eager(save_path, build_save=False, build_restore=True)
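
Context for reviewers (not part of the patch): `checkpoint_exists()` is wrapped
in `@deprecation.deprecated`, so every `Saver.restore()` call that went through
it logged a deprecation warning on a routine internal check. Routing
`restore()` through the new `checkpoint_exists_internal()` keeps the same
existence check (V2 prefix first, then V1) without triggering the warning,
while direct callers of the public API still see it. A minimal sketch of the
user-visible effect, assuming TF 1.x-style graph mode and an illustrative
/tmp path:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

v = tf.get_variable("v", shape=[2], initializer=tf.zeros_initializer())
saver = tf.train.Saver()

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  prefix = saver.save(sess, "/tmp/model.ckpt")  # returns the checkpoint prefix
  # After this patch, restore() validates the prefix via the internal helper,
  # so no deprecation warning is logged here.
  saver.restore(sess, prefix)

# The public API remains deprecated; calling it directly still logs the
# "Use standard file APIs to check for files with this prefix." warning.
tf.train.checkpoint_exists(prefix)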