Split checkpoint management utility functions out of saver.py

Pure refactor, in preparation for adding a higher-level checkpoint management utility. This utility will also need to work with the CheckpointState proto, and bolting it onto saver.py seems dirty.

PiperOrigin-RevId: 207179646
Allen Lavoie 2018-08-02 15:47:43 -07:00 committed by TensorFlower Gardener
parent 6fbbad97e2
commit 1bf206bc82
35 changed files with 1011 additions and 817 deletions
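
In practice the change at each call site is mechanical: the moved helpers keep their names and signatures, and only the importing module changes. A minimal sketch of the migration pattern (the model directory path is illustrative):

# Before this commit: checkpoint-state helpers were imported from saver.py.
from tensorflow.python.training import saver as saver_lib
ckpt_path = saver_lib.latest_checkpoint("/tmp/model_dir")

# After this commit: the same helpers live in checkpoint_management.py.
from tensorflow.python.training import checkpoint_management
ckpt_path = checkpoint_management.latest_checkpoint("/tmp/model_dir")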

View File

@ -28,6 +28,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
@ -55,7 +56,7 @@ class CheckpointInputPipelineHookTest(test.TestCase):
def _read_vars(self, model_dir):
"""Returns (global_step, latest_feature)."""
with ops.Graph().as_default() as g:
ckpt_path = saver_lib.latest_checkpoint(model_dir)
ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
meta_filename = ckpt_path + '.meta'
saver_lib.import_meta_graph(meta_filename)
saver = saver_lib.Saver()

View File

@ -32,6 +32,7 @@ from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import nest
@ -655,7 +656,7 @@ class DatasetSerializationTestBase(test.TestCase):
return os.path.join(self.get_temp_dir(), "iterator")
def _latest_ckpt(self):
return saver_lib.latest_checkpoint(self.get_temp_dir())
return checkpoint_management.latest_checkpoint(self.get_temp_dir())
def _save(self, sess, saver):
saver.save(sess, self._ckpt_path())

View File

@ -20,6 +20,7 @@ from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
@ -206,7 +207,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
# Check if there is an existing checkpoint. If so, restore from it.
# pylint: disable=protected-access
latest_checkpoint_path = saver_lib.latest_checkpoint(
latest_checkpoint_path = checkpoint_management.latest_checkpoint(
self._checkpoint_saver_hook._checkpoint_dir,
latest_filename=self._latest_filename)
if latest_checkpoint_path:

View File

@ -37,7 +37,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.training import saver
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.checkpointable import util as checkpointable_utils
@ -314,7 +314,8 @@ class IteratorTest(test.TestCase):
for i in range(5):
iterator = datasets.Iterator(dataset)
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
checkpoint.restore(saver.latest_checkpoint(checkpoint_directory))
checkpoint.restore(checkpoint_management.latest_checkpoint(
checkpoint_directory))
for j in range(2):
self.assertEqual(i * 2 + j, iterator.get_next().numpy())
checkpoint.save(file_prefix=checkpoint_prefix)

View File

@ -36,7 +36,7 @@ from third_party.examples.eager.spinn import spinn
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
from tensorflow.python.training import saver
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.checkpointable import util as checkpointable_utils
# pylint: enable=g-bad-import-order
@ -422,7 +422,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
# 5. Verify that checkpoints exist and contain all the expected variables.
self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
object_graph = checkpointable_utils.object_metadata(
saver.latest_checkpoint(config.logdir))
checkpoint_management.latest_checkpoint(config.logdir))
ckpt_variable_names = set()
for node in object_graph.nodes:
for attribute in node.attributes:

View File

@ -27,7 +27,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import training as train
__all__ = [
@ -40,7 +40,7 @@ __all__ = [
def _get_checkpoint_filename(filepattern):
"""Returns checkpoint filename given directory or specific filepattern."""
if gfile.IsDirectory(filepattern):
return saver.latest_checkpoint(filepattern)
return checkpoint_management.latest_checkpoint(filepattern)
return filepattern
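
This directory-or-file resolution pattern recurs in several files touched by this commit (checkpoint_utils below has an analogous helper). A hedged sketch of the behavior, with illustrative paths:

# If given a directory, defer to latest_checkpoint to resolve the newest
# checkpoint prefix recorded there; anything else is passed through unchanged.
_get_checkpoint_filename("/tmp/train")               # e.g. "/tmp/train/model.ckpt-1000"
_get_checkpoint_filename("/tmp/train/model.ckpt-5")  # "/tmp/train/model.ckpt-5"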

View File

@ -72,6 +72,7 @@ from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary as core_summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import device_setter
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
@ -891,7 +892,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
# Check that model has been trained (if nothing has been set explicitly).
if not checkpoint_path:
latest_path = saver.latest_checkpoint(self._model_dir)
latest_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not latest_path:
raise NotFittedError(
"Couldn't find trained model at %s." % self._model_dir)
@ -956,7 +957,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
as_iterable=True,
iterate_batches=False):
# Check that model has been trained.
checkpoint_path = saver.latest_checkpoint(self._model_dir)
checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not checkpoint_path:
raise NotFittedError(
"Couldn't find trained model at %s." % self._model_dir)
@ -1364,7 +1365,7 @@ class Estimator(BaseEstimator):
if not checkpoint_path:
# Locate the latest checkpoint
checkpoint_path = saver.latest_checkpoint(self._model_dir)
checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not checkpoint_path:
raise NotFittedError(
"Couldn't find trained model at %s." % self._model_dir)

View File

@ -41,7 +41,7 @@ from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import saver
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
@ -95,7 +95,7 @@ class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener):
# Load and cache the path of the most recent checkpoint to avoid duplicate
# searches on GCS.
logging.info("Checking for checkpoint in %s", self._model_dir)
latest_path = saver.latest_checkpoint(self._model_dir)
latest_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not latest_path:
logging.warning("Skipping evaluation and export since model has not been "
@ -516,7 +516,8 @@ class Experiment(object):
start = time.time()
error_msg = None
latest_path = saver.latest_checkpoint(self._estimator.model_dir)
latest_path = checkpoint_management.latest_checkpoint(
self._estimator.model_dir)
if not latest_path:
error_msg = ("Estimator is not fitted yet. "
"Will start an evaluation when a checkpoint is ready.")
@ -778,7 +779,8 @@ class Experiment(object):
saving_listeners=self._saving_listeners)
logging.info("Evaluating model now.")
latest_checkpoint = saver.latest_checkpoint(self._estimator.model_dir)
latest_checkpoint = checkpoint_management.latest_checkpoint(
self._estimator.model_dir)
eval_result = self._call_evaluate(
input_fn=self._eval_input_fn,
steps=self._eval_steps,

View File

@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
@ -124,7 +125,7 @@ class GraphActionsTest(test.TestCase):
# TODO(ptucker): Test number and contents of checkpoint files.
def _assert_ckpt(self, output_dir, expected=True):
ckpt_state = saver_lib.get_checkpoint_state(output_dir)
ckpt_state = checkpoint_management.get_checkpoint_state(output_dir)
if expected:
pattern = '%s/model.ckpt-.*' % output_dir
primary_ckpt_path = ckpt_state.model_checkpoint_path
@ -434,7 +435,7 @@ class GraphActionsTrainTest(test.TestCase):
# TODO(ptucker): Test number and contents of checkpoint files.
def _assert_ckpt(self, output_dir, expected=True):
ckpt_state = saver_lib.get_checkpoint_state(output_dir)
ckpt_state = checkpoint_management.get_checkpoint_state(output_dir)
if expected:
pattern = '%s/model.ckpt-.*' % output_dir
primary_ckpt_path = ckpt_state.model_checkpoint_path

View File

@ -51,7 +51,7 @@ from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as core_summary
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.util import deprecation
@ -735,7 +735,8 @@ class ValidationMonitor(EveryN):
return False
self._last_checkpoint_check_time = current_time
# Check that we are not running evaluation on the same checkpoint.
latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
latest_path = checkpoint_management.latest_checkpoint(
self._estimator.model_dir)
if latest_path is None:
logging.debug("Skipping evaluation since model has not been saved yet "
"at step %d.", step)
@ -1059,7 +1060,8 @@ class ExportMonitor(EveryN):
def end(self, session=None):
super(ExportMonitor, self).end(session=session)
latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
latest_path = checkpoint_management.latest_checkpoint(
self._estimator.model_dir)
if latest_path is None:
logging.info("Skipping export at the end since model has not been saved "
"yet.")

View File

@ -39,9 +39,9 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
from tensorflow.python.training import training_util
@ -317,7 +317,7 @@ class MonitorsTest(test.TestCase):
self._run_monitor(monitor)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
@test.mock.patch.object(saver, 'latest_checkpoint')
@test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_no_ckpt(self, mock_latest_checkpoint,
mock_estimator_class):
estimator = mock_estimator_class()
@ -336,7 +336,7 @@ class MonitorsTest(test.TestCase):
mock_latest_checkpoint.assert_called_with(model_dir)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
@test.mock.patch.object(saver, 'latest_checkpoint')
@test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_no_early_stopping_rounds(self,
mock_latest_checkpoint,
mock_estimator_class):
@ -356,7 +356,7 @@ class MonitorsTest(test.TestCase):
self._assert_validation_monitor(monitor)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
@test.mock.patch.object(saver, 'latest_checkpoint')
@test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_invalid_metric(self, mock_latest_checkpoint,
mock_estimator_class):
estimator = mock_estimator_class()
@ -375,7 +375,7 @@ class MonitorsTest(test.TestCase):
self._run_monitor(monitor, num_epochs=1, num_steps_per_epoch=1)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
@test.mock.patch.object(saver, 'latest_checkpoint')
@test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor(self, mock_latest_checkpoint,
mock_estimator_class):
estimator = mock_estimator_class()
@ -464,7 +464,7 @@ class MonitorsTest(test.TestCase):
monitor.epoch_end(epoch=0)
monitor.end()
@test.mock.patch.object(saver, 'latest_checkpoint')
@test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_with_core_estimator(self, mock_latest_checkpoint):
estimator = test.mock.Mock(spec=core_estimator.Estimator)
model_dir = 'model/dir'
@ -495,7 +495,7 @@ class MonitorsTest(test.TestCase):
expected_best_metrics={'loss': 42.0, 'auc': 0.5})
monitor.post_step(step=step, session=None)
@test.mock.patch.object(saver, 'latest_checkpoint')
@test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_fail_with_core_estimator_and_metrics(
self, mock_latest_checkpoint):
estimator = test.mock.Mock(spec=core_estimator.Estimator)

View File

@ -35,6 +35,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import training_util
@ -298,7 +299,8 @@ def _export_estimator(estimator,
# If checkpoint_path is specified, use the specified checkpoint path.
checkpoint_path = (checkpoint_path or
tf_saver.latest_checkpoint(estimator._model_dir))
checkpoint_management.latest_checkpoint(
estimator._model_dir))
with ops.Graph().as_default() as g:
training_util.create_global_step(g)

View File

@ -55,7 +55,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.summary import summary_iterator
from tensorflow.python.training import saver
from tensorflow.python.training import checkpoint_management
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated
@ -714,7 +714,8 @@ def make_best_model_export_strategy(
# as soon as contrib is cleaned up and we can thus be sure that
# estimator is a tf.estimator.Estimator and not a
# tf.contrib.learn.Estimator
checkpoint_path = saver.latest_checkpoint(estimator.model_dir)
checkpoint_path = checkpoint_management.latest_checkpoint(
estimator.model_dir)
export_checkpoint_path, export_eval_result = best_model_selector.update(
checkpoint_path, eval_result)

View File

@ -41,6 +41,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as core_saver
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import tracking
@ -278,7 +279,8 @@ class CheckpointingTests(test.TestCase):
root = util.Checkpoint(
optimizer=optimizer, model=model,
optimizer_step=training_util.get_or_create_global_step())
root.restore(core_saver.latest_checkpoint(checkpoint_directory))
root.restore(checkpoint_management.latest_checkpoint(
checkpoint_directory))
for _ in range(num_training_steps):
# TODO(allenl): Use a Dataset and serialize/checkpoint it.
input_value = constant_op.constant([[3.]])
@ -306,7 +308,8 @@ class CheckpointingTests(test.TestCase):
train_op = optimizer.minimize(
model(input_value),
global_step=root.global_step)
checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
checkpoint_path = checkpoint_management.latest_checkpoint(
checkpoint_directory)
with self.test_session(graph=ops.get_default_graph()) as session:
status = root.restore(save_path=checkpoint_path)
status.initialize_or_restore(session=session)
@ -339,7 +342,8 @@ class CheckpointingTests(test.TestCase):
root = util.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
checkpoint_path = checkpoint_management.latest_checkpoint(
checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
input_value = constant_op.constant([[3.]])
train_fn = functools.partial(
@ -372,7 +376,8 @@ class CheckpointingTests(test.TestCase):
root = util.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
checkpoint_path = checkpoint_management.latest_checkpoint(
checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
def train_fn():
@function.defun

View File

@ -22,8 +22,8 @@ from __future__ import print_function
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.contrib.predictor import predictor
from tensorflow.python.framework import ops
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
class ContribEstimatorPredictor(predictor.Predictor):
@ -57,7 +57,8 @@ class ContribEstimatorPredictor(predictor.Predictor):
# pylint: disable=protected-access
model_fn_ops = estimator._get_predict_ops(input_fn_ops.features)
# pylint: enable=protected-access
checkpoint_path = saver.latest_checkpoint(estimator.model_dir)
checkpoint_path = checkpoint_management.latest_checkpoint(
estimator.model_dir)
self._session = monitored_session.MonitoredSession(
session_creator=monitored_session.ChiefSessionCreator(
config=config,

View File

@ -142,9 +142,9 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import evaluation
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
@ -189,7 +189,7 @@ def wait_for_new_checkpoint(checkpoint_dir,
logging.info('Waiting for new checkpoint at %s', checkpoint_dir)
stop_time = time.time() + timeout if timeout is not None else None
while True:
checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir)
checkpoint_path = checkpoint_management.latest_checkpoint(checkpoint_dir)
if checkpoint_path is None or checkpoint_path == last_checkpoint:
if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
return None

View File

@ -36,6 +36,7 @@ from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib
@ -421,7 +422,7 @@ class TrainTest(test.TestCase):
train_op = self.create_train_op()
model_variables = variables_lib2.global_variables()
model_path = saver_lib.latest_checkpoint(logdir1)
model_path = checkpoint_management.latest_checkpoint(logdir1)
assign_fn = variables_lib.assign_from_checkpoint_fn(
model_path, model_variables)

View File

@ -3216,6 +3216,7 @@ py_library(
"training/checkpointable/**/*.py",
# The following targets have their own build rules (same name as the
# file):
"training/checkpoint_management.py",
"training/saveable_object.py",
"training/saver.py",
"training/training_util.py",
@ -3223,8 +3224,10 @@ py_library(
),
srcs_version = "PY2AND3",
deps = [
"saver",
":array_ops",
":array_ops_gen",
":checkpoint_management",
":checkpoint_ops_gen",
":client",
":control_flow_ops",
@ -3236,25 +3239,20 @@ py_library(
":framework_ops",
":gradients",
":init_ops",
":distribute",
":io_ops",
":io_ops_gen",
":layers_base",
":lib",
":lookup_ops",
":math_ops",
":platform",
":protos_all_py",
":pywrap_tensorflow",
":random_ops",
":resource_variable_ops",
":resources",
"saver",
":saveable_object",
":sdca_ops",
":session",
":sparse_ops",
":sparse_tensor",
":state_ops",
":string_ops",
":summary",
":training_ops_gen",
":training_util",
@ -3264,6 +3262,7 @@ py_library(
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
# `layers` dependency only exists due to the use of a small utility.
@ -3280,12 +3279,26 @@ py_library(
srcs_version = "PY2AND3",
)
py_library(
name = "checkpoint_management",
srcs = ["training/checkpoint_management.py"],
deps = [
":errors",
":lib",
":platform",
":protos_all_py",
":util",
"//tensorflow/core:protos_all_py",
],
)
py_library(
name = "saver",
srcs = ["training/saver.py"],
srcs_version = "PY2AND3",
deps = [
":array_ops",
":checkpoint_management",
":constant_op",
":control_flow_ops",
":device",
@ -3294,9 +3307,7 @@ py_library(
":framework_ops",
":io_ops",
":io_ops_gen",
":lib",
":platform",
":protos_all_py",
":pywrap_tensorflow",
":resource_variable_ops",
":saveable_object",
@ -4423,6 +4434,42 @@ cuda_py_test(
tags = ["multi_gpu"],
)
cuda_py_test(
name = "checkpoint_management_test",
size = "small",
srcs = [
"training/checkpoint_management_test.py",
],
additional_deps = [
":array_ops",
":client_testlib",
":control_flow_ops",
":data_flow_ops",
":errors",
":gradients",
":math_ops",
":nn_grad",
":nn_ops",
":saver_test_utils",
":partitioned_variables",
":platform",
":platform_test",
":pywrap_tensorflow",
":random_ops",
":resource_variable_ops",
":sparse_ops",
":summary",
":training",
":util",
":variable_scope",
":variables",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/data/ops:dataset_ops",
],
)
py_test(
name = "saver_large_variable_test",
size = "medium",
@ -4489,6 +4536,7 @@ tf_py_test(
srcs = ["training/supervisor_test.py"],
additional_deps = [
":array_ops",
":checkpoint_management",
":client_testlib",
":errors",
":framework",
@ -4496,6 +4544,7 @@ tf_py_test(
":io_ops",
":parsing_ops",
":platform",
":saver",
":summary",
":training",
":variables",
@ -4609,10 +4658,13 @@ py_test(
tags = ["notsan"], # b/67945581
deps = [
":array_ops",
":checkpoint_management",
":client_testlib",
":control_flow_ops",
":errors",
":framework_for_generated_wrappers",
":resource_variable_ops",
":saver",
":session",
":state_ops",
":summary",

View File

@ -47,7 +47,7 @@ from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import server_lib
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.util import compat
@ -877,7 +877,7 @@ class IteratorCheckpointingTest(test.TestCase):
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
for i in range(5):
with self.test_session() as sess:
checkpoint.restore(saver.latest_checkpoint(
checkpoint.restore(checkpoint_management.latest_checkpoint(
checkpoint_directory)).initialize_or_restore(sess)
for j in range(2):
self.assertEqual(i * 2 + j, sess.run(get_next))

View File

@ -53,6 +53,7 @@ from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
from tensorflow.python.summary import summary
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import device_setter
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import evaluation
@ -268,7 +269,7 @@ class Estimator(object):
found.
"""
with context.graph_mode():
return saver.latest_checkpoint(self.model_dir)
return checkpoint_management.latest_checkpoint(self.model_dir)
def train(self,
input_fn,
@ -417,7 +418,7 @@ class Estimator(object):
# Check that model has been trained (if nothing has been set explicitly).
if not checkpoint_path:
latest_path = saver.latest_checkpoint(self._model_dir)
latest_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not latest_path:
logging.info('Could not find trained model in model_dir: {}, running '
'initialization to evaluate.'.format(self._model_dir))
@ -504,7 +505,8 @@ class Estimator(object):
hooks = _check_hooks_type(hooks)
# Check that model has been trained.
if not checkpoint_path:
checkpoint_path = saver.latest_checkpoint(self._model_dir)
checkpoint_path = checkpoint_management.latest_checkpoint(
self._model_dir)
if not checkpoint_path:
logging.info('Could not find trained model in model_dir: {}, running '
'initialization to predict.'.format(self._model_dir))
@ -769,7 +771,8 @@ class Estimator(object):
with context.graph_mode():
if not checkpoint_path:
# Locate the latest checkpoint
checkpoint_path = saver.latest_checkpoint(self._model_dir)
checkpoint_path = checkpoint_management.latest_checkpoint(
self._model_dir)
if not checkpoint_path:
raise ValueError("Couldn't find trained model at %s." % self._model_dir)
@ -1626,7 +1629,7 @@ def _combine_distributed_scaffold(grouped_scaffold, distribution):
def _check_checkpoint_available(model_dir):
latest_path = saver.latest_checkpoint(model_dir)
latest_path = checkpoint_management.latest_checkpoint(model_dir)
if not latest_path:
raise ValueError(
'Could not find trained model in model_dir: {}.'.format(model_dir))

View File

@ -69,6 +69,7 @@ from tensorflow.python.summary import summary
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import checkpoint_state_pb2
from tensorflow.python.training import saver
from tensorflow.python.training import saver_test_utils
@ -1548,7 +1549,8 @@ class EstimatorPredictTest(test.TestCase):
next(
est.predict(
dummy_input_fn,
checkpoint_path=saver.latest_checkpoint('fakedir')))
checkpoint_path=
checkpoint_management.latest_checkpoint('fakedir')))
def test_tensor_predictions(self):

View File

@ -42,6 +42,7 @@ from tensorflow.python.ops import metrics as metrics_module
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
@ -442,7 +443,7 @@ def _save_first_checkpoint(keras_model, custom_objects, config):
# save checkpoint into subdirectory to allow warm start
keras_model_dir = os.path.join(config.model_dir, 'keras')
# Load weights and save to checkpoint if there is no checkpoint
latest_path = saver_lib.latest_checkpoint(keras_model_dir)
latest_path = checkpoint_management.latest_checkpoint(keras_model_dir)
if not latest_path:
keras_weights = None
if _any_weight_initialized(keras_model):

View File

@ -55,6 +55,7 @@ from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
@ -78,7 +79,7 @@ def freeze_graph_with_def_protos(input_graph_def,
# 'input_checkpoint' may be a prefix if we're using Saver V2 format
if (not input_saved_model_dir and
not saver_lib.checkpoint_exists(input_checkpoint)):
not checkpoint_management.checkpoint_exists(input_checkpoint)):
print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
return -1

View File

@ -0,0 +1,406 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""Save and restore variables."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import re
from google.protobuf import text_format
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.framework import errors
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.util.tf_export import tf_export
def _GetCheckpointFilename(save_dir, latest_filename):
"""Returns a filename for storing the CheckpointState.
Args:
save_dir: The directory for saving and restoring checkpoints.
latest_filename: Name of the file in 'save_dir' that is used
to store the CheckpointState.
Returns:
The path of the file that contains the CheckpointState proto.
"""
if latest_filename is None:
latest_filename = "checkpoint"
return os.path.join(save_dir, latest_filename)
@tf_export("train.generate_checkpoint_state_proto")
def generate_checkpoint_state_proto(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None):
"""Generates a checkpoint state proto.
Args:
save_dir: Directory where the model was saved.
model_checkpoint_path: The checkpoint file.
all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
checkpoints, sorted from oldest to newest. If this is a non-empty list,
the last element must be equal to model_checkpoint_path. These paths
are also saved in the CheckpointState proto.
Returns:
CheckpointState proto with model_checkpoint_path and
all_model_checkpoint_paths updated to either absolute paths or
relative paths to the current save_dir.
"""
if all_model_checkpoint_paths is None:
all_model_checkpoint_paths = []
if (not all_model_checkpoint_paths or
all_model_checkpoint_paths[-1] != model_checkpoint_path):
logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
model_checkpoint_path)
all_model_checkpoint_paths.append(model_checkpoint_path)
# Relative paths need to be rewritten to be relative to the "save_dir"
# if model_checkpoint_path already contains "save_dir".
if not os.path.isabs(save_dir):
if not os.path.isabs(model_checkpoint_path):
model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
for i in range(len(all_model_checkpoint_paths)):
p = all_model_checkpoint_paths[i]
if not os.path.isabs(p):
all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)
coord_checkpoint_proto = CheckpointState(
model_checkpoint_path=model_checkpoint_path,
all_model_checkpoint_paths=all_model_checkpoint_paths)
return coord_checkpoint_proto
@tf_export("train.update_checkpoint_state")
def update_checkpoint_state(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None,
latest_filename=None):
"""Updates the content of the 'checkpoint' file.
This updates the checkpoint file containing a CheckpointState
proto.
Args:
save_dir: Directory where the model was saved.
model_checkpoint_path: The checkpoint file.
all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
checkpoints, sorted from oldest to newest. If this is a non-empty list,
the last element must be equal to model_checkpoint_path. These paths
are also saved in the CheckpointState proto.
latest_filename: Optional name of the checkpoint file. Defaults to
'checkpoint'.
Raises:
RuntimeError: If any of the model checkpoint paths conflict with the file
containing CheckpointState.
"""
update_checkpoint_state_internal(
save_dir=save_dir,
model_checkpoint_path=model_checkpoint_path,
all_model_checkpoint_paths=all_model_checkpoint_paths,
latest_filename=latest_filename,
save_relative_paths=False)
def update_checkpoint_state_internal(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None,
latest_filename=None,
save_relative_paths=False):
"""Updates the content of the 'checkpoint' file.
This updates the checkpoint file containing a CheckpointState
proto.
Args:
save_dir: Directory where the model was saved.
model_checkpoint_path: The checkpoint file.
all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
checkpoints, sorted from oldest to newest. If this is a non-empty list,
the last element must be equal to model_checkpoint_path. These paths
are also saved in the CheckpointState proto.
latest_filename: Optional name of the checkpoint file. Defaults to
'checkpoint'.
save_relative_paths: If `True`, will write relative paths to the checkpoint
state file.
Raises:
RuntimeError: If any of the model checkpoint paths conflict with the file
containing CheckpointState.
"""
# Writes the "checkpoint" file for the coordinator for later restoration.
coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
if save_relative_paths:
if os.path.isabs(model_checkpoint_path):
rel_model_checkpoint_path = os.path.relpath(
model_checkpoint_path, save_dir)
else:
rel_model_checkpoint_path = model_checkpoint_path
rel_all_model_checkpoint_paths = []
for p in all_model_checkpoint_paths:
if os.path.isabs(p):
rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
else:
rel_all_model_checkpoint_paths.append(p)
ckpt = generate_checkpoint_state_proto(
save_dir,
rel_model_checkpoint_path,
all_model_checkpoint_paths=rel_all_model_checkpoint_paths)
else:
ckpt = generate_checkpoint_state_proto(
save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=all_model_checkpoint_paths)
if coord_checkpoint_filename == ckpt.model_checkpoint_path:
raise RuntimeError("Save path '%s' conflicts with path used for "
"checkpoint state. Please use a different save path." %
model_checkpoint_path)
# Prevent a potential read/write race condition by *atomically* writing to a
# file.
file_io.atomic_write_string_to_file(coord_checkpoint_filename,
text_format.MessageToString(ckpt))
@tf_export("train.get_checkpoint_state")
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
"""Returns CheckpointState proto from the "checkpoint" file.
If the "checkpoint" file contains a valid CheckpointState
proto, returns it.
Args:
checkpoint_dir: The directory of checkpoints.
latest_filename: Optional name of the checkpoint file. Defaults to
'checkpoint'.
Returns:
A CheckpointState if the state was available, None
otherwise.
Raises:
ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
"""
ckpt = None
coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
latest_filename)
f = None
try:
# Check that the file exists before opening it to avoid
# many lines of errors from colossus in the logs.
if file_io.file_exists(coord_checkpoint_filename):
file_content = file_io.read_file_to_string(
coord_checkpoint_filename)
ckpt = CheckpointState()
text_format.Merge(file_content, ckpt)
if not ckpt.model_checkpoint_path:
raise ValueError("Invalid checkpoint state loaded from "
+ checkpoint_dir)
# For relative model_checkpoint_path and all_model_checkpoint_paths,
# prepend checkpoint_dir.
if not os.path.isabs(ckpt.model_checkpoint_path):
ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
ckpt.model_checkpoint_path)
for i in range(len(ckpt.all_model_checkpoint_paths)):
p = ckpt.all_model_checkpoint_paths[i]
if not os.path.isabs(p):
ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
except errors.OpError as e:
# It's ok if the file cannot be read
logging.warning("%s: %s", type(e).__name__, e)
logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
return None
except text_format.ParseError as e:
logging.warning("%s: %s", type(e).__name__, e)
logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
return None
finally:
if f:
f.close()
return ckpt
def _prefix_to_checkpoint_path(prefix, format_version):
"""Returns the pathname of a checkpoint file, given the checkpoint prefix.
For a V1 checkpoint, simply returns the prefix itself (the data file). For V2,
returns the pathname to the index file.
Args:
prefix: a string, the prefix of a checkpoint.
format_version: the checkpoint format version that corresponds to the
prefix.
Returns:
The pathname of a checkpoint file, taking into account the checkpoint
format version.
"""
if format_version == saver_pb2.SaverDef.V2:
return prefix + ".index" # The index file identifies a checkpoint.
return prefix # Just the data file.
@tf_export("train.latest_checkpoint")
def latest_checkpoint(checkpoint_dir, latest_filename=None):
"""Finds the filename of latest saved checkpoint file.
Args:
checkpoint_dir: Directory where the variables were saved.
latest_filename: Optional name for the protocol buffer file that
contains the list of most recent checkpoint filenames.
See the corresponding argument to `Saver.save()`.
Returns:
The full path to the latest checkpoint or `None` if no checkpoint was found.
"""
# Pick the latest checkpoint based on checkpoint state.
ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
if ckpt and ckpt.model_checkpoint_path:
# Look for either a V2 path or a V1 path, with priority for V2.
v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
saver_pb2.SaverDef.V2)
v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
saver_pb2.SaverDef.V1)
if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
v1_path):
return ckpt.model_checkpoint_path
else:
logging.error("Couldn't match files for checkpoint %s",
ckpt.model_checkpoint_path)
return None
@tf_export("train.checkpoint_exists")
def checkpoint_exists(checkpoint_prefix):
"""Checks whether a V1 or V2 checkpoint exists with the specified prefix.
This is the recommended way to check if a checkpoint exists, since it takes
into account the naming difference between V1 and V2 formats.
Args:
checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
priority. Typically the result of `Saver.save()` or that of
`tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
V1/V2.
Returns:
A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists.
"""
pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
saver_pb2.SaverDef.V2)
if file_io.get_matching_files(pathname):
return True
elif file_io.get_matching_files(checkpoint_prefix):
return True
else:
return False
@tf_export("train.get_checkpoint_mtimes")
def get_checkpoint_mtimes(checkpoint_prefixes):
"""Returns the mtimes (modification timestamps) of the checkpoints.
Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files
exist, their mtimes are collected. Both V2 and V1 checkpoints are considered,
in that order of priority.
This is the recommended way to get the mtimes, since it takes into account
the naming difference between V1 and V2 formats.
Args:
checkpoint_prefixes: a list of checkpoint paths, typically the results of
`Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
sharded/non-sharded or V1/V2.
Returns:
A list of mtimes (in seconds since the epoch) of the found checkpoints.
"""
mtimes = []
def match_maybe_append(pathname):
fnames = file_io.get_matching_files(pathname)
if fnames:
mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9)
return True
return False
for checkpoint_prefix in checkpoint_prefixes:
# Tries V2's metadata file first.
pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
saver_pb2.SaverDef.V2)
if match_maybe_append(pathname):
continue
# Otherwise, tries V1, where the prefix is the complete pathname.
match_maybe_append(checkpoint_prefix)
return mtimes
@tf_export("train.remove_checkpoint")
def remove_checkpoint(checkpoint_prefix,
checkpoint_format_version=saver_pb2.SaverDef.V2,
meta_graph_suffix="meta"):
"""Removes a checkpoint given by `checkpoint_prefix`.
Args:
checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result
of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of
sharded/non-sharded or V1/V2.
checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to
`SaverDef.V2`.
meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
"""
_delete_file_if_exists(
meta_graph_filename(checkpoint_prefix, meta_graph_suffix))
if checkpoint_format_version == saver_pb2.SaverDef.V2:
# V2 has a metadata file and some data files.
_delete_file_if_exists(checkpoint_prefix + ".index")
_delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????")
else:
# V1, Legacy. Exact match on the data file.
_delete_file_if_exists(checkpoint_prefix)
def _delete_file_if_exists(filespec):
"""Deletes files matching `filespec`."""
for pathname in file_io.get_matching_files(filespec):
file_io.delete_file(pathname)
def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
"""Returns the meta graph filename.
Args:
checkpoint_filename: Name of the checkpoint file.
meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
Returns:
MetaGraph file name.
"""
# If the checkpoint_filename is sharded, the checkpoint_filename could
# be of format model.ckpt-step#-?????-of-shard#. For example,
# model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
suffixed_filename = ".".join([basename, meta_graph_suffix])
return suffixed_filename
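
Taken together, the moved helpers cover the whole lifecycle of the "checkpoint" state file. A minimal usage sketch, under the assumption that a trained model already lives in the directory (paths illustrative):

from tensorflow.python.training import checkpoint_management

save_dir = "/tmp/train"  # assumed to contain a "checkpoint" state file

# Resolve the newest checkpoint prefix recorded in the CheckpointState proto.
prefix = checkpoint_management.latest_checkpoint(save_dir)
if prefix is not None and checkpoint_management.checkpoint_exists(prefix):
  # Inspect the full state: the current prefix plus older, undeleted ones.
  state = checkpoint_management.get_checkpoint_state(save_dir)
  print(state.model_checkpoint_path, list(state.all_model_checkpoint_paths))
  # Remove older checkpoints (data, index, and .meta files) by prefix.
  for old_prefix in state.all_model_checkpoint_paths[:-1]:
    checkpoint_management.remove_checkpoint(old_prefix)

Note the sharding-aware filename handling this relies on: meta_graph_filename("model.ckpt-123456-?????-of-00005") strips the shard suffix and yields "model.ckpt-123456.meta".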

View File

@ -0,0 +1,316 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for tensorflow.python.training.saver.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import os
import shutil
import tempfile
from google.protobuf import text_format
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
class LatestCheckpointWithRelativePaths(test.TestCase):
@staticmethod
@contextlib.contextmanager
def tempWorkingDir(temppath):
cwd = os.getcwd()
os.chdir(temppath)
try:
yield
finally:
os.chdir(cwd)
@staticmethod
@contextlib.contextmanager
def tempDir():
tempdir = tempfile.mkdtemp()
try:
yield tempdir
finally:
shutil.rmtree(tempdir)
def testNameCollision(self):
# Make sure we have a clean directory to work in.
with self.tempDir() as tempdir:
# Jump to that directory until this test is done.
with self.tempWorkingDir(tempdir):
# Save training snapshots to a relative path.
traindir = "train/"
os.mkdir(traindir)
# Collides with the default name of the checkpoint state file.
filepath = os.path.join(traindir, "checkpoint")
with self.test_session() as sess:
unused_a = variables.Variable(0.0) # So that Saver saves something.
variables.global_variables_initializer().run()
# Should fail.
saver = saver_module.Saver(sharded=False)
with self.assertRaisesRegexp(ValueError, "collides with"):
saver.save(sess, filepath)
# Succeeds: the file will be named "checkpoint-<step>".
saver.save(sess, filepath, global_step=1)
self.assertIsNotNone(
checkpoint_management.latest_checkpoint(traindir))
# Succeeds: the file will be named "checkpoint-<i>-of-<n>".
saver = saver_module.Saver(sharded=True)
saver.save(sess, filepath)
self.assertIsNotNone(
checkpoint_management.latest_checkpoint(traindir))
# Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>".
saver = saver_module.Saver(sharded=True)
saver.save(sess, filepath, global_step=1)
self.assertIsNotNone(
checkpoint_management.latest_checkpoint(traindir))
def testRelativePath(self):
# Make sure we have a clean directory to work in.
with self.tempDir() as tempdir:
# Jump to that directory until this test is done.
with self.tempWorkingDir(tempdir):
# Save training snapshots to a relative path.
traindir = "train/"
os.mkdir(traindir)
filename = "snapshot"
filepath = os.path.join(traindir, filename)
with self.test_session() as sess:
# Build a simple graph.
v0 = variables.Variable(0.0)
inc = v0.assign_add(1.0)
save = saver_module.Saver({"v0": v0})
# Record a short training history.
variables.global_variables_initializer().run()
save.save(sess, filepath, global_step=0)
inc.eval()
save.save(sess, filepath, global_step=1)
inc.eval()
save.save(sess, filepath, global_step=2)
with self.test_session() as sess:
# Build a new graph with different initialization.
v0 = variables.Variable(-1.0)
# Create a new saver.
save = saver_module.Saver({"v0": v0})
variables.global_variables_initializer().run()
# Get the most recent checkpoint name from the training history file.
name = checkpoint_management.latest_checkpoint(traindir)
self.assertIsNotNone(name)
# Restore "v0" from that checkpoint.
save.restore(sess, name)
self.assertEqual(v0.eval(), 2.0)
class CheckpointStateTest(test.TestCase):
def _get_test_dir(self, dirname):
test_dir = os.path.join(self.get_temp_dir(), dirname)
gfile.MakeDirs(test_dir)
return test_dir
def testAbsPath(self):
save_dir = self._get_test_dir("abs_paths")
abs_path = os.path.join(save_dir, "model-0")
ckpt = checkpoint_management.generate_checkpoint_state_proto(
save_dir, abs_path)
self.assertEqual(ckpt.model_checkpoint_path, abs_path)
self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
def testRelPath(self):
train_dir = "train"
model = os.path.join(train_dir, "model-0")
# model_checkpoint_path should have no "train" directory part.
new_rel_path = "model-0"
ckpt = checkpoint_management.generate_checkpoint_state_proto(
train_dir, model)
self.assertEqual(ckpt.model_checkpoint_path, new_rel_path)
self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)
def testAllModelCheckpointPaths(self):
save_dir = self._get_test_dir("all_models_test")
abs_path = os.path.join(save_dir, "model-0")
for paths in [None, [], ["model-2"]]:
ckpt = checkpoint_management.generate_checkpoint_state_proto(
save_dir, abs_path, all_model_checkpoint_paths=paths)
self.assertEqual(ckpt.model_checkpoint_path, abs_path)
self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
self.assertEqual(
len(ckpt.all_model_checkpoint_paths), len(paths) if paths else 1)
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
def testUpdateCheckpointState(self):
save_dir = self._get_test_dir("update_checkpoint_state")
os.chdir(save_dir)
# Make a temporary train directory.
train_dir = "train"
os.mkdir(train_dir)
abs_path = os.path.join(save_dir, "model-0")
rel_path = os.path.join("train", "model-2")
checkpoint_management.update_checkpoint_state(
train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path])
ckpt = checkpoint_management.get_checkpoint_state(train_dir)
self.assertEqual(ckpt.model_checkpoint_path, rel_path)
self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path)
self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)
def testUpdateCheckpointStateSaveRelativePaths(self):
save_dir = self._get_test_dir("update_checkpoint_state")
os.chdir(save_dir)
abs_path2 = os.path.join(save_dir, "model-2")
rel_path2 = "model-2"
abs_path0 = os.path.join(save_dir, "model-0")
rel_path0 = "model-0"
checkpoint_management.update_checkpoint_state_internal(
save_dir=save_dir,
model_checkpoint_path=abs_path2,
all_model_checkpoint_paths=[rel_path0, abs_path2],
save_relative_paths=True)
# File should contain relative paths.
file_content = file_io.read_file_to_string(
os.path.join(save_dir, "checkpoint"))
ckpt = CheckpointState()
text_format.Merge(file_content, ckpt)
self.assertEqual(ckpt.model_checkpoint_path, rel_path2)
self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2)
self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0)
# get_checkpoint_state should return absolute paths.
ckpt = checkpoint_management.get_checkpoint_state(save_dir)
self.assertEqual(ckpt.model_checkpoint_path, abs_path2)
self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2)
self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0)
def testCheckPointStateFailsWhenIncomplete(self):
save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")
os.chdir(save_dir)
ckpt_path = os.path.join(save_dir, "checkpoint")
ckpt_file = open(ckpt_path, "w")
ckpt_file.write("")
ckpt_file.close()
with self.assertRaises(ValueError):
checkpoint_management.get_checkpoint_state(save_dir)
def testCheckPointCompletesRelativePaths(self):
save_dir = self._get_test_dir("checkpoint_completes_relative_paths")
os.chdir(save_dir)
ckpt_path = os.path.join(save_dir, "checkpoint")
ckpt_file = open(ckpt_path, "w")
ckpt_file.write("""
model_checkpoint_path: "./model.ckpt-687529"
all_model_checkpoint_paths: "./model.ckpt-687500"
all_model_checkpoint_paths: "./model.ckpt-687529"
""")
ckpt_file.close()
ckpt = checkpoint_management.get_checkpoint_state(save_dir)
self.assertEqual(ckpt.model_checkpoint_path,
os.path.join(save_dir, "./model.ckpt-687529"))
self.assertEqual(ckpt.all_model_checkpoint_paths[0],
os.path.join(save_dir, "./model.ckpt-687500"))
self.assertEqual(ckpt.all_model_checkpoint_paths[1],
os.path.join(save_dir, "./model.ckpt-687529"))
class SaverUtilsTest(test.TestCase):
def setUp(self):
self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test")
gfile.MakeDirs(self._base_dir)
def tearDown(self):
gfile.DeleteRecursively(self._base_dir)
def testCheckpointExists(self):
for sharded in (False, True):
for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
with self.test_session(graph=ops_lib.Graph()) as sess:
unused_v = variables.Variable(1.0, name="v")
variables.global_variables_initializer().run()
saver = saver_module.Saver(sharded=sharded, write_version=version)
path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
self.assertFalse(
checkpoint_management.checkpoint_exists(path)) # Not saved yet.
ckpt_prefix = saver.save(sess, path)
self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
ckpt_prefix = checkpoint_management.latest_checkpoint(self._base_dir)
self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
def testGetCheckpointMtimes(self):
prefixes = []
for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
with self.test_session(graph=ops_lib.Graph()) as sess:
unused_v = variables.Variable(1.0, name="v")
variables.global_variables_initializer().run()
saver = saver_module.Saver(write_version=version)
prefixes.append(
saver.save(sess, os.path.join(self._base_dir, str(version))))
mtimes = checkpoint_management.get_checkpoint_mtimes(prefixes)
self.assertEqual(2, len(mtimes))
self.assertTrue(mtimes[1] >= mtimes[0])
def testRemoveCheckpoint(self):
for sharded in (False, True):
for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
with self.test_session(graph=ops_lib.Graph()) as sess:
unused_v = variables.Variable(1.0, name="v")
variables.global_variables_initializer().run()
saver = saver_module.Saver(sharded=sharded, write_version=version)
path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
ckpt_prefix = saver.save(sess, path)
self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
checkpoint_management.remove_checkpoint(ckpt_prefix, version)
self.assertFalse(checkpoint_management.checkpoint_exists(ckpt_prefix))
if __name__ == "__main__":
test.main()
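
For completeness, a hedged sketch of the mtime helper these tests exercise. The returned values are seconds (the implementation divides mtime_nsec by 1e9), and prefixes that match no files are silently skipped:

from tensorflow.python.training import checkpoint_management

# Illustrative prefixes; one float per checkpoint actually found on disk.
mtimes = checkpoint_management.get_checkpoint_mtimes(
    ["/tmp/ckpts/model.ckpt-1", "/tmp/ckpts/model.ckpt-2"])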

View File

@ -28,6 +28,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import tf_export
@ -277,7 +278,7 @@ def _init_from_checkpoint(_, ckpt_dir_or_file, assignment_map):
def _get_checkpoint_filename(ckpt_dir_or_file):
"""Returns checkpoint filename given directory or specific checkpoint file."""
if gfile.IsDirectory(ckpt_dir_or_file):
return saver.latest_checkpoint(ckpt_dir_or_file)
return checkpoint_management.latest_checkpoint(ckpt_dir_or_file)
return ckpt_dir_or_file

View File

@ -124,14 +124,18 @@ py_test(
],
deps = [
":base",
":tracking",
":util",
"//tensorflow/python:checkpoint_management",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:init_ops",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:saver",
"//tensorflow/python:session",
"//tensorflow/python:state_ops",
"//tensorflow/python:template",

View File

@ -42,6 +42,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import adam
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base
@ -467,7 +468,8 @@ class CheckpointingTests(test.TestCase):
root = checkpointable_utils.Checkpoint(
optimizer=optimizer, model=model,
optimizer_step=training_util.get_or_create_global_step())
root.restore(saver_lib.latest_checkpoint(checkpoint_directory))
root.restore(checkpoint_management.latest_checkpoint(
checkpoint_directory))
for _ in range(num_training_steps):
# TODO(allenl): Use a Dataset and serialize/checkpoint it.
input_value = constant_op.constant([[3.]])
@ -495,7 +497,8 @@ class CheckpointingTests(test.TestCase):
train_op = optimizer.minimize(
model(input_value),
global_step=root.global_step)
checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory)
checkpoint_path = checkpoint_management.latest_checkpoint(
checkpoint_directory)
with self.test_session(graph=ops.get_default_graph()) as session:
status = root.restore(save_path=checkpoint_path)
status.initialize_or_restore(session=session)
@ -528,7 +531,8 @@ class CheckpointingTests(test.TestCase):
root = checkpointable_utils.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory)
checkpoint_path = checkpoint_management.latest_checkpoint(
checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
input_value = constant_op.constant([[3.]])
train_fn = functools.partial(
@ -561,7 +565,8 @@ class CheckpointingTests(test.TestCase):
root = checkpointable_utils.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory)
checkpoint_path = checkpoint_management.latest_checkpoint(
checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
def train_fn():
@function.defun
@ -1180,7 +1185,8 @@ class CheckpointingTests(test.TestCase):
optimizer_checkpoint = checkpointable_utils.Checkpoint(
optimizer=optimizer)
checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory)
checkpoint_path = checkpoint_management.latest_checkpoint(
checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
input_value = constant_op.constant([[3.]])
train_fn = functools.partial(

View File

@ -44,6 +44,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib
@ -1364,8 +1365,8 @@ class MonitoredSessionTest(test.TestCase):
with monitored_session.MonitoredSession(
session_creator=monitored_session.ChiefSessionCreator(
scaffold,
checkpoint_filename_with_path=saver_lib.latest_checkpoint(
logdir))) as session:
checkpoint_filename_with_path=
checkpoint_management.latest_checkpoint(logdir))) as session:
self.assertEqual(2, session.run(gstep))
def test_retry_initialization_on_aborted_error(self):

View File

@ -21,15 +21,12 @@ from __future__ import print_function
import collections
import os.path
import re
import time
import uuid
import numpy as np
import six
from google.protobuf import text_format
from tensorflow.core.protobuf import checkpointable_object_graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
@ -41,7 +38,6 @@ from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops
@ -52,14 +48,19 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saveable_object
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
# TODO(allenl): Remove these aliases once all users are migrated off.
get_checkpoint_state = checkpoint_management.get_checkpoint_state
update_checkpoint_state = checkpoint_management.update_checkpoint_state
# Op names that identify variable reads which should be saved.
_VARIABLE_OPS = set(["Variable",
"VariableV2",
@ -858,218 +859,6 @@ def _get_saver_or_default():
return saver
def _GetCheckpointFilename(save_dir, latest_filename):
"""Returns a filename for storing the CheckpointState.
Args:
save_dir: The directory for saving and restoring checkpoints.
latest_filename: Name of the file in 'save_dir' that is used
to store the CheckpointState.
Returns:
The path of the file that contains the CheckpointState proto.
"""
if latest_filename is None:
latest_filename = "checkpoint"
return os.path.join(save_dir, latest_filename)
@tf_export("train.generate_checkpoint_state_proto")
def generate_checkpoint_state_proto(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None):
"""Generates a checkpoint state proto.
Args:
save_dir: Directory where the model was saved.
model_checkpoint_path: The checkpoint file.
all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
checkpoints, sorted from oldest to newest. If this is a non-empty list,
the last element must be equal to model_checkpoint_path. These paths
are also saved in the CheckpointState proto.
Returns:
CheckpointState proto with model_checkpoint_path and
all_model_checkpoint_paths updated to either absolute paths or
relative paths to the current save_dir.
"""
if all_model_checkpoint_paths is None:
all_model_checkpoint_paths = []
if (not all_model_checkpoint_paths or
all_model_checkpoint_paths[-1] != model_checkpoint_path):
logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
model_checkpoint_path)
all_model_checkpoint_paths.append(model_checkpoint_path)
# Relative paths need to be rewritten to be relative to the "save_dir"
# if model_checkpoint_path already contains "save_dir".
if not os.path.isabs(save_dir):
if not os.path.isabs(model_checkpoint_path):
model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
for i in range(len(all_model_checkpoint_paths)):
p = all_model_checkpoint_paths[i]
if not os.path.isabs(p):
all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)
coord_checkpoint_proto = CheckpointState(
model_checkpoint_path=model_checkpoint_path,
all_model_checkpoint_paths=all_model_checkpoint_paths)
return coord_checkpoint_proto
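# A minimal sketch of typical usage for the helper above; the directory and
# checkpoint names are hypothetical. With an absolute save_dir, the returned
# proto keeps absolute paths.
import tensorflow as tf

state = tf.train.generate_checkpoint_state_proto(
    save_dir="/tmp/train",
    model_checkpoint_path="/tmp/train/model.ckpt-200",
    all_model_checkpoint_paths=["/tmp/train/model.ckpt-100",
                                "/tmp/train/model.ckpt-200"])
assert state.model_checkpoint_path == "/tmp/train/model.ckpt-200"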
@tf_export("train.update_checkpoint_state")
def update_checkpoint_state(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None,
latest_filename=None):
"""Updates the content of the 'checkpoint' file.
This updates the checkpoint file containing a CheckpointState
proto.
Args:
save_dir: Directory where the model was saved.
model_checkpoint_path: The checkpoint file.
all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
checkpoints, sorted from oldest to newest. If this is a non-empty list,
the last element must be equal to model_checkpoint_path. These paths
are also saved in the CheckpointState proto.
latest_filename: Optional name of the checkpoint file. Defaults to
'checkpoint'.
Raises:
RuntimeError: If any of the model checkpoint paths conflict with the file
containing CheckpointState.
"""
_update_checkpoint_state(
save_dir=save_dir,
model_checkpoint_path=model_checkpoint_path,
all_model_checkpoint_paths=all_model_checkpoint_paths,
latest_filename=latest_filename,
save_relative_paths=False)
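# A minimal sketch of typical usage: rewrite the "checkpoint" state file by
# hand. The paths are hypothetical; in normal training, Saver.save() updates
# this file itself.
import tensorflow as tf

tf.train.update_checkpoint_state(
    save_dir="/tmp/train",
    model_checkpoint_path="/tmp/train/model.ckpt-200",
    all_model_checkpoint_paths=["/tmp/train/model.ckpt-100",
                                "/tmp/train/model.ckpt-200"])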
def _update_checkpoint_state(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None,
latest_filename=None,
save_relative_paths=False):
"""Updates the content of the 'checkpoint' file.
This updates the checkpoint file containing a CheckpointState
proto.
Args:
save_dir: Directory where the model was saved.
model_checkpoint_path: The checkpoint file.
all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
checkpoints, sorted from oldest to newest. If this is a non-empty list,
the last element must be equal to model_checkpoint_path. These paths
are also saved in the CheckpointState proto.
latest_filename: Optional name of the checkpoint file. Defaults to
'checkpoint'.
save_relative_paths: If `True`, will write relative paths to the checkpoint
state file.
Raises:
RuntimeError: If any of the model checkpoint paths conflict with the file
containing CheckpointState.
"""
# Writes the "checkpoint" file for the coordinator for later restoration.
coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
if save_relative_paths:
if os.path.isabs(model_checkpoint_path):
rel_model_checkpoint_path = os.path.relpath(
model_checkpoint_path, save_dir)
else:
rel_model_checkpoint_path = model_checkpoint_path
rel_all_model_checkpoint_paths = []
for p in all_model_checkpoint_paths:
if os.path.isabs(p):
rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
else:
rel_all_model_checkpoint_paths.append(p)
ckpt = generate_checkpoint_state_proto(
save_dir,
rel_model_checkpoint_path,
all_model_checkpoint_paths=rel_all_model_checkpoint_paths)
else:
ckpt = generate_checkpoint_state_proto(
save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=all_model_checkpoint_paths)
if coord_checkpoint_filename == ckpt.model_checkpoint_path:
raise RuntimeError("Save path '%s' conflicts with path used for "
"checkpoint state. Please use a different save path." %
model_checkpoint_path)
# Preventing potential read/write race condition by *atomically* writing to a
# file.
file_io.atomic_write_string_to_file(coord_checkpoint_filename,
text_format.MessageToString(ckpt))
@tf_export("train.get_checkpoint_state")
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
"""Returns CheckpointState proto from the "checkpoint" file.
If the "checkpoint" file contains a valid CheckpointState
proto, returns it.
Args:
checkpoint_dir: The directory of checkpoints.
latest_filename: Optional name of the checkpoint file. Defaults to
'checkpoint'.
Returns:
A CheckpointState if the state was available, None
otherwise.
Raises:
ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
"""
ckpt = None
coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
latest_filename)
f = None
try:
# Check that the file exists before opening it to avoid
# many lines of errors from colossus in the logs.
if file_io.file_exists(coord_checkpoint_filename):
file_content = file_io.read_file_to_string(
coord_checkpoint_filename)
ckpt = CheckpointState()
text_format.Merge(file_content, ckpt)
if not ckpt.model_checkpoint_path:
raise ValueError("Invalid checkpoint state loaded from "
+ checkpoint_dir)
# For relative model_checkpoint_path and all_model_checkpoint_paths,
# prepend checkpoint_dir.
if not os.path.isabs(ckpt.model_checkpoint_path):
ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
ckpt.model_checkpoint_path)
for i in range(len(ckpt.all_model_checkpoint_paths)):
p = ckpt.all_model_checkpoint_paths[i]
if not os.path.isabs(p):
ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
except errors.OpError as e:
# It's ok if the file cannot be read
logging.warning("%s: %s", type(e).__name__, e)
logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
return None
except text_format.ParseError as e:
logging.warning("%s: %s", type(e).__name__, e)
logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
return None
finally:
if f:
f.close()
return ckpt
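# A minimal sketch of reading the state back; the directory is hypothetical.
# Relative entries in the proto have already been joined with the directory.
import tensorflow as tf

ckpt = tf.train.get_checkpoint_state("/tmp/train")
if ckpt and ckpt.model_checkpoint_path:
  for path in ckpt.all_model_checkpoint_paths:
    print(path)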
@tf_export("train.Saver")
class Saver(object):
"""Saves and restores variables.
@ -1412,7 +1201,7 @@ class Saver(object):
# Otherwise delete the files.
try:
remove_checkpoint(
checkpoint_management.remove_checkpoint(
self._CheckpointFilename(p), self.saver_def.version,
meta_graph_suffix)
except Exception as e: # pylint: disable=broad-except
@ -1518,7 +1307,7 @@ class Saver(object):
Args:
checkpoint_paths: a list of checkpoint paths.
"""
mtimes = get_checkpoint_mtimes(checkpoint_paths)
mtimes = checkpoint_management.get_checkpoint_mtimes(checkpoint_paths)
self.set_last_checkpoints_with_time(list(zip(checkpoint_paths, mtimes)))
def save(self,
@ -1624,7 +1413,7 @@ class Saver(object):
model_checkpoint_path = compat.as_str(model_checkpoint_path)
if write_state:
self._RecordLastCheckpoint(model_checkpoint_path)
_update_checkpoint_state(
checkpoint_management.update_checkpoint_state_internal(
save_dir=save_path_parent,
model_checkpoint_path=model_checkpoint_path,
all_model_checkpoint_paths=self.last_checkpoints,
@ -1639,7 +1428,7 @@ class Saver(object):
raise exc
if write_meta_graph:
meta_graph_filename = _meta_graph_filename(
meta_graph_filename = checkpoint_management.meta_graph_filename(
checkpoint_file, meta_graph_suffix=meta_graph_suffix)
if not context.executing_eagerly():
with sess.graph.as_default():
@ -1714,7 +1503,7 @@ class Saver(object):
if save_path is None:
raise ValueError("Can't load save_path when it is None.")
if not checkpoint_exists(compat.as_text(save_path)):
if not checkpoint_management.checkpoint_exists(compat.as_text(save_path)):
raise ValueError("The passed save_path is not a valid checkpoint: "
+ compat.as_text(save_path))
@ -1800,55 +1589,6 @@ class Saver(object):
export_scope=export_scope)
def _prefix_to_checkpoint_path(prefix, format_version):
"""Returns the pathname of a checkpoint file, given the checkpoint prefix.
For a V1 checkpoint, simply returns the prefix itself (the data file). For V2,
returns the pathname to the index file.
Args:
prefix: a string, the prefix of a checkpoint.
format_version: the checkpoint format version that corresponds to the
prefix.
Returns:
The pathname of a checkpoint file, taking into account the checkpoint
format version.
"""
if format_version == saver_pb2.SaverDef.V2:
return prefix + ".index" # The index file identifies a checkpoint.
return prefix # Just the data file.
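# Quick illustrative checks of the mapping above: V2 is identified by its
# ".index" file, while a V1 prefix is already the data file.
assert _prefix_to_checkpoint_path(
    "model.ckpt-1", saver_pb2.SaverDef.V2) == "model.ckpt-1.index"
assert _prefix_to_checkpoint_path(
    "model.ckpt-1", saver_pb2.SaverDef.V1) == "model.ckpt-1"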
@tf_export("train.latest_checkpoint")
def latest_checkpoint(checkpoint_dir, latest_filename=None):
"""Finds the filename of latest saved checkpoint file.
Args:
checkpoint_dir: Directory where the variables were saved.
latest_filename: Optional name for the protocol buffer file that
contains the list of most recent checkpoint filenames.
See the corresponding argument to `Saver.save()`.
Returns:
The full path to the latest checkpoint or `None` if no checkpoint was found.
"""
# Pick the latest checkpoint based on checkpoint state.
ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
if ckpt and ckpt.model_checkpoint_path:
# Look for either a V2 path or a V1 path, with priority for V2.
v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
saver_pb2.SaverDef.V2)
v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
saver_pb2.SaverDef.V1)
if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
v1_path):
return ckpt.model_checkpoint_path
else:
logging.error("Couldn't match files for checkpoint %s",
ckpt.model_checkpoint_path)
return None
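# A minimal usage sketch; the directory is hypothetical. The result is None
# when there is no checkpoint state file or its data files cannot be matched.
import tensorflow as tf

ckpt_path = tf.train.latest_checkpoint("/tmp/train")
if ckpt_path is None:
  print("No checkpoint found; starting from scratch.")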
@tf_export("train.import_meta_graph")
def import_meta_graph(meta_graph_or_file, clear_devices=False,
import_scope=None, **kwargs):
@ -2056,119 +1796,6 @@ def export_meta_graph(filename=None,
return meta_graph_def
@tf_export("train.checkpoint_exists")
def checkpoint_exists(checkpoint_prefix):
"""Checks whether a V1 or V2 checkpoint exists with the specified prefix.
This is the recommended way to check if a checkpoint exists, since it takes
into account the naming difference between V1 and V2 formats.
Args:
checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
priority. Typically the result of `Saver.save()` or that of
`tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
V1/V2.
Returns:
A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists.
"""
pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
saver_pb2.SaverDef.V2)
if file_io.get_matching_files(pathname):
return True
elif file_io.get_matching_files(checkpoint_prefix):
return True
else:
return False
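# A minimal sketch; the prefix is hypothetical (e.g. a Saver.save() result).
import tensorflow as tf

if tf.train.checkpoint_exists("/tmp/train/model.ckpt-200"):
  print("Found a V1 or V2 checkpoint for this prefix.")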
@tf_export("train.get_checkpoint_mtimes")
def get_checkpoint_mtimes(checkpoint_prefixes):
"""Returns the mtimes (modification timestamps) of the checkpoints.
Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files
exist, their mtimes are collected. Both V2 and V1 checkpoints are considered,
in that priority.
This is the recommended way to get the mtimes, since it takes into account
the naming difference between V1 and V2 formats.
Args:
checkpoint_prefixes: a list of checkpoint paths, typically the results of
`Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
sharded/non-sharded or V1/V2.
Returns:
A list of mtimes (in seconds) of the found checkpoints.
"""
mtimes = []
def match_maybe_append(pathname):
fnames = file_io.get_matching_files(pathname)
if fnames:
mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9)
return True
return False
for checkpoint_prefix in checkpoint_prefixes:
# Tries V2's metadata file first.
pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
saver_pb2.SaverDef.V2)
if match_maybe_append(pathname):
continue
# Otherwise, tries V1, where the prefix is the complete pathname.
match_maybe_append(checkpoint_prefix)
return mtimes
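# A minimal sketch: order hypothetical checkpoint prefixes oldest to newest
# by their collected mtimes.
import tensorflow as tf

prefixes = ["/tmp/train/model.ckpt-100", "/tmp/train/model.ckpt-200"]
mtimes = tf.train.get_checkpoint_mtimes(prefixes)
oldest_first = [p for _, p in sorted(zip(mtimes, prefixes))]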
@tf_export("train.remove_checkpoint")
def remove_checkpoint(checkpoint_prefix,
checkpoint_format_version=saver_pb2.SaverDef.V2,
meta_graph_suffix="meta"):
"""Removes a checkpoint given by `checkpoint_prefix`.
Args:
checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result
of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of
sharded/non-sharded or V1/V2.
checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to
`SaverDef.V2`.
meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
"""
_delete_file_if_exists(
_meta_graph_filename(checkpoint_prefix, meta_graph_suffix))
if checkpoint_format_version == saver_pb2.SaverDef.V2:
# V2 has a metadata file and some data files.
_delete_file_if_exists(checkpoint_prefix + ".index")
_delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????")
else:
# V1, Legacy. Exact match on the data file.
_delete_file_if_exists(checkpoint_prefix)
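# A minimal sketch: drop a stale checkpoint (index, data shards and meta
# graph file). The prefix is hypothetical; the format version defaults to V2.
import tensorflow as tf

tf.train.remove_checkpoint("/tmp/train/model.ckpt-100")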
def _delete_file_if_exists(filespec):
"""Deletes files matching `filespec`."""
for pathname in file_io.get_matching_files(filespec):
file_io.delete_file(pathname)
def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
"""Returns the meta graph filename.
Args:
checkpoint_filename: Name of the checkpoint file.
meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
Returns:
MetaGraph file name.
"""
# If the checkpoint_filename is sharded, the checkpoint_filename could
# be of format model.ckpt-step#-?????-of-shard#. For example,
# model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
meta_graph_filename = ".".join([basename, meta_graph_suffix])
return meta_graph_filename
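# Illustrative checks of the shard-suffix stripping above.
assert _meta_graph_filename(
    "model.ckpt-123456-?????-of-00005") == "model.ckpt-123456.meta"
assert _meta_graph_filename("model.ckpt-123456") == "model.ckpt-123456.meta"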
def _wrap_restore_error_with_msg(err, extra_verbiage):
err_msg = ("Restoring from checkpoint failed. This is most likely "
"due to {} from the checkpoint. Please ensure that you "

View File

@ -18,20 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import functools
import math
import os
import random
import shutil
import tempfile
import time
import numpy as np
import six
from google.protobuf.any_pb2 import Any
from google.protobuf import text_format
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
@ -71,12 +67,12 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import adam
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training import saver_test_utils
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.training.checkpointable import base as checkpointable_base
from tensorflow.python.training.checkpointable import tracking as checkpointable_tracking
from tensorflow.python.training.checkpointable import util as checkpointable_utils
@ -343,11 +339,13 @@ class SaverTest(test.TestCase):
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path1, val)
self.assertEqual(saver_module.latest_checkpoint(save_dir1), save_path1)
self.assertEqual(
checkpoint_management.latest_checkpoint(save_dir1), save_path1)
save_dir2 = os.path.join(self.get_temp_dir(), "save_dir2")
os.renames(save_dir1, save_dir2)
save_path2 = os.path.join(save_dir2, "save_copy_restore")
self.assertEqual(saver_module.latest_checkpoint(save_dir2), save_path2)
self.assertEqual(
checkpoint_management.latest_checkpoint(save_dir2), save_path2)
# Start a second session. In that session the parameter nodes
# have not been initialized either.
@ -857,7 +855,7 @@ class SaveRestoreShardedTest(test.TestCase):
self.assertEqual(save_path + "-?????-of-00002", val)
else:
self.assertEqual(save_path, val)
meta_graph_filename = saver_module._meta_graph_filename(val)
meta_graph_filename = checkpoint_management.meta_graph_filename(val)
self.assertEqual(save_path + ".meta", meta_graph_filename)
if save._write_version is saver_pb2.SaverDef.V1:
@ -951,11 +949,11 @@ class SaveRestoreShardedTest(test.TestCase):
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(
saver_module.latest_checkpoint(self.get_temp_dir()),
checkpoint_management.latest_checkpoint(self.get_temp_dir()),
os.path.join(self.get_temp_dir(), "sharded_basics-?????-of-00002"))
else:
self.assertEqual(
saver_module.latest_checkpoint(self.get_temp_dir()),
checkpoint_management.latest_checkpoint(self.get_temp_dir()),
os.path.join(self.get_temp_dir(), "sharded_basics"))
def testSaverDef(self):
@ -1105,7 +1103,7 @@ class MaxToKeepTest(test.TestCase):
def assertCheckpointState(self, model_checkpoint_path,
all_model_checkpoint_paths, save_dir):
checkpoint_state = saver_module.get_checkpoint_state(save_dir)
checkpoint_state = checkpoint_management.get_checkpoint_state(save_dir)
self.assertEqual(checkpoint_state.model_checkpoint_path,
model_checkpoint_path)
self.assertEqual(checkpoint_state.all_model_checkpoint_paths,
@ -1113,7 +1111,7 @@ class MaxToKeepTest(test.TestCase):
def testMaxToKeepEager(self):
with context.eager_mode():
save_dir = self._get_test_dir("max_to_keep_non_sharded")
save_dir = self._get_test_dir("max_to_keep_eager")
v = variable_scope.variable(10.0, name="v")
save = saver_module.Saver({"v": v}, max_to_keep=2)
@ -1123,7 +1121,7 @@ class MaxToKeepTest(test.TestCase):
s1 = save.save(None, os.path.join(save_dir, "s1"))
self.assertEqual([s1], save.last_checkpoints)
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s1],
@ -1131,8 +1129,8 @@ class MaxToKeepTest(test.TestCase):
s2 = save.save(None, os.path.join(save_dir, "s2"))
self.assertEqual([s1, s2], save.last_checkpoints)
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s1, s2],
@ -1140,9 +1138,9 @@ class MaxToKeepTest(test.TestCase):
s3 = save.save(None, os.path.join(save_dir, "s3"))
self.assertEqual([s2, s3], save.last_checkpoints)
self.assertFalse(saver_module.checkpoint_exists(s1))
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertTrue(saver_module.checkpoint_exists(s3))
self.assertFalse(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(checkpoint_management.checkpoint_exists(s3))
self.assertCheckpointState(
model_checkpoint_path=s3,
all_model_checkpoint_paths=[s2, s3],
@ -1157,9 +1155,9 @@ class MaxToKeepTest(test.TestCase):
# Adding s2 again (old s2 is removed first, then new s2 appended)
s2 = save.save(None, os.path.join(save_dir, "s2"))
self.assertEqual([s3, s2], save.last_checkpoints)
self.assertFalse(saver_module.checkpoint_exists(s1))
self.assertTrue(saver_module.checkpoint_exists(s3))
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertFalse(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s3))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s3, s2],
@ -1168,8 +1166,8 @@ class MaxToKeepTest(test.TestCase):
# Adding s1 (s3 should now be deleted as oldest in list)
s1 = save.save(None, os.path.join(save_dir, "s1"))
self.assertEqual([s2, s1], save.last_checkpoints)
self.assertFalse(saver_module.checkpoint_exists(s3))
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
@ -1178,9 +1176,9 @@ class MaxToKeepTest(test.TestCase):
s2 = save2.save(None, os.path.join(save_dir, "s2"))
self.assertEqual([s3, s2], save2.last_checkpoints)
# Created by the first helper.
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
# Deleted by the first helper.
self.assertFalse(saver_module.checkpoint_exists(s3))
self.assertFalse(checkpoint_management.checkpoint_exists(s3))
def testNonSharded(self):
save_dir = self._get_test_dir("max_to_keep_non_sharded")
@ -1193,7 +1191,7 @@ class MaxToKeepTest(test.TestCase):
s1 = save.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([s1], save.last_checkpoints)
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s1],
@ -1201,8 +1199,8 @@ class MaxToKeepTest(test.TestCase):
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s1, s2], save.last_checkpoints)
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s1, s2],
@ -1210,9 +1208,9 @@ class MaxToKeepTest(test.TestCase):
s3 = save.save(sess, os.path.join(save_dir, "s3"))
self.assertEqual([s2, s3], save.last_checkpoints)
self.assertFalse(saver_module.checkpoint_exists(s1))
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertTrue(saver_module.checkpoint_exists(s3))
self.assertFalse(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(checkpoint_management.checkpoint_exists(s3))
self.assertCheckpointState(
model_checkpoint_path=s3,
all_model_checkpoint_paths=[s2, s3],
@ -1231,15 +1229,18 @@ class MaxToKeepTest(test.TestCase):
# Adding s2 again (old s2 is removed first, then new s2 appended)
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s3, s2], save.last_checkpoints)
self.assertFalse(saver_module.checkpoint_exists(s1))
self.assertFalse(checkpoint_management.checkpoint_exists(s1))
self.assertFalse(
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
self.assertTrue(saver_module.checkpoint_exists(s3))
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s1)))
self.assertTrue(checkpoint_management.checkpoint_exists(s3))
self.assertTrue(
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(saver_module.checkpoint_exists(s2))
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s3)))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s2)))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s3, s2],
@ -1248,15 +1249,18 @@ class MaxToKeepTest(test.TestCase):
# Adding s1 (s3 should now be deleted as oldest in list)
s1 = save.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([s2, s1], save.last_checkpoints)
self.assertFalse(saver_module.checkpoint_exists(s3))
self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(saver_module.checkpoint_exists(s2))
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s3)))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
self.assertTrue(saver_module.checkpoint_exists(s1))
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s2)))
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s1)))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
@ -1268,16 +1272,19 @@ class MaxToKeepTest(test.TestCase):
s2 = save2.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s3, s2], save2.last_checkpoints)
# Created by the first helper.
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s1)))
# Deleted by the first helper.
self.assertFalse(saver_module.checkpoint_exists(s3))
self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(saver_module.checkpoint_exists(s2))
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s3)))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s2)))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s3, s2],
@ -1286,15 +1293,18 @@ class MaxToKeepTest(test.TestCase):
# Adding s1 (s3 should now be deleted as oldest in list)
s1 = save2.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([s2, s1], save2.last_checkpoints)
self.assertFalse(saver_module.checkpoint_exists(s3))
self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(saver_module.checkpoint_exists(s2))
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s3)))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
self.assertTrue(saver_module.checkpoint_exists(s1))
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s2)))
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s1)))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
@ -1306,16 +1316,19 @@ class MaxToKeepTest(test.TestCase):
s2 = save3.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s2], save3.last_checkpoints)
# Created by the first helper.
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s1)))
# Deleted by the first helper.
self.assertFalse(saver_module.checkpoint_exists(s3))
self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(saver_module.checkpoint_exists(s2))
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s3)))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s2)))
# Even though the file for s1 exists, this saver isn't aware of it, which
# is why it doesn't end up in the checkpoint state.
self.assertCheckpointState(
@ -1326,15 +1339,18 @@ class MaxToKeepTest(test.TestCase):
# Adding s1 (s3 should not be deleted because helper is unaware of it)
s1 = save3.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([s2, s1], save3.last_checkpoints)
self.assertFalse(saver_module.checkpoint_exists(s3))
self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(saver_module.checkpoint_exists(s2))
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s3)))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
self.assertTrue(saver_module.checkpoint_exists(s1))
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s2)))
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s1)))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
@ -1365,7 +1381,8 @@ class MaxToKeepTest(test.TestCase):
else:
self.assertEqual(4, len(gfile.Glob(s1 + "*")))
self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1)))
self.assertTrue(
gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s1, s2], save.last_checkpoints)
@ -1373,27 +1390,32 @@ class MaxToKeepTest(test.TestCase):
self.assertEqual(2, len(gfile.Glob(s1)))
else:
self.assertEqual(4, len(gfile.Glob(s1 + "*")))
self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1)))
self.assertTrue(
gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s2)))
else:
self.assertEqual(4, len(gfile.Glob(s2 + "*")))
self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2)))
self.assertTrue(
gfile.Exists(checkpoint_management.meta_graph_filename(s2)))
s3 = save.save(sess, os.path.join(save_dir, "s3"))
self.assertEqual([s2, s3], save.last_checkpoints)
self.assertEqual(0, len(gfile.Glob(s1 + "*")))
self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1)))
self.assertFalse(
gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s2)))
else:
self.assertEqual(4, len(gfile.Glob(s2 + "*")))
self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2)))
self.assertTrue(
gfile.Exists(checkpoint_management.meta_graph_filename(s2)))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s3)))
else:
self.assertEqual(4, len(gfile.Glob(s3 + "*")))
self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(
gfile.Exists(checkpoint_management.meta_graph_filename(s3)))
def testNoMaxToKeep(self):
save_dir = self._get_test_dir("no_max_to_keep")
@ -1408,20 +1430,20 @@ class MaxToKeepTest(test.TestCase):
self.assertEqual([], save.last_checkpoints)
s1 = save.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([], save.last_checkpoints)
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([], save.last_checkpoints)
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
# Test max_to_keep being 0.
save2 = saver_module.Saver({"v": v}, max_to_keep=0)
self.assertEqual([], save2.last_checkpoints)
s1 = save2.save(sess, os.path.join(save_dir2, "s1"))
self.assertEqual([], save2.last_checkpoints)
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
s2 = save2.save(sess, os.path.join(save_dir2, "s2"))
self.assertEqual([], save2.last_checkpoints)
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
def testNoMetaGraph(self):
save_dir = self._get_test_dir("no_meta_graph")
@ -1432,8 +1454,9 @@ class MaxToKeepTest(test.TestCase):
variables.global_variables_initializer().run()
s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False)
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1)))
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertFalse(
gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
class KeepCheckpointEveryNHoursTest(test.TestCase):
@ -1489,10 +1512,10 @@ class KeepCheckpointEveryNHoursTest(test.TestCase):
self.assertEqual([s3, s4], save.last_checkpoints)
# Check that s1 is still here, but s2 is gone.
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertFalse(saver_module.checkpoint_exists(s2))
self.assertTrue(saver_module.checkpoint_exists(s3))
self.assertTrue(saver_module.checkpoint_exists(s4))
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertFalse(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(checkpoint_management.checkpoint_exists(s3))
self.assertTrue(checkpoint_management.checkpoint_exists(s4))
class SaveRestoreWithVariableNameMap(test.TestCase):
@ -1571,221 +1594,6 @@ class SaveRestoreWithVariableNameMap(test.TestCase):
self._testNonReshape(variables.Variable)
class LatestCheckpointWithRelativePaths(test.TestCase):
@staticmethod
@contextlib.contextmanager
def tempWorkingDir(temppath):
cwd = os.getcwd()
os.chdir(temppath)
try:
yield
finally:
os.chdir(cwd)
@staticmethod
@contextlib.contextmanager
def tempDir():
tempdir = tempfile.mkdtemp()
try:
yield tempdir
finally:
shutil.rmtree(tempdir)
def testNameCollision(self):
# Make sure we have a clean directory to work in.
with self.tempDir() as tempdir:
# Jump to that directory until this test is done.
with self.tempWorkingDir(tempdir):
# Save training snapshots to a relative path.
traindir = "train/"
os.mkdir(traindir)
# Collides with the default name of the checkpoint state file.
filepath = os.path.join(traindir, "checkpoint")
with self.test_session() as sess:
unused_a = variables.Variable(0.0) # So that Saver saves something.
variables.global_variables_initializer().run()
# Should fail.
saver = saver_module.Saver(sharded=False)
with self.assertRaisesRegexp(ValueError, "collides with"):
saver.save(sess, filepath)
# Succeeds: the file will be named "checkpoint-<step>".
saver.save(sess, filepath, global_step=1)
self.assertIsNotNone(saver_module.latest_checkpoint(traindir))
# Succeeds: the file will be named "checkpoint-<i>-of-<n>".
saver = saver_module.Saver(sharded=True)
saver.save(sess, filepath)
self.assertIsNotNone(saver_module.latest_checkpoint(traindir))
# Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>".
saver = saver_module.Saver(sharded=True)
saver.save(sess, filepath, global_step=1)
self.assertIsNotNone(saver_module.latest_checkpoint(traindir))
def testRelativePath(self):
# Make sure we have a clean directory to work in.
with self.tempDir() as tempdir:
# Jump to that directory until this test is done.
with self.tempWorkingDir(tempdir):
# Save training snapshots to a relative path.
traindir = "train/"
os.mkdir(traindir)
filename = "snapshot"
filepath = os.path.join(traindir, filename)
with self.test_session() as sess:
# Build a simple graph.
v0 = variables.Variable(0.0)
inc = v0.assign_add(1.0)
save = saver_module.Saver({"v0": v0})
# Record a short training history.
variables.global_variables_initializer().run()
save.save(sess, filepath, global_step=0)
inc.eval()
save.save(sess, filepath, global_step=1)
inc.eval()
save.save(sess, filepath, global_step=2)
with self.test_session() as sess:
# Build a new graph with different initialization.
v0 = variables.Variable(-1.0)
# Create a new saver.
save = saver_module.Saver({"v0": v0})
variables.global_variables_initializer().run()
# Get the most recent checkpoint name from the training history file.
name = saver_module.latest_checkpoint(traindir)
self.assertIsNotNone(name)
# Restore "v0" from that checkpoint.
save.restore(sess, name)
self.assertEqual(v0.eval(), 2.0)
class CheckpointStateTest(test.TestCase):
def _get_test_dir(self, dirname):
test_dir = os.path.join(self.get_temp_dir(), dirname)
gfile.MakeDirs(test_dir)
return test_dir
def testAbsPath(self):
save_dir = self._get_test_dir("abs_paths")
abs_path = os.path.join(save_dir, "model-0")
ckpt = saver_module.generate_checkpoint_state_proto(save_dir, abs_path)
self.assertEqual(ckpt.model_checkpoint_path, abs_path)
self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
def testRelPath(self):
train_dir = "train"
model = os.path.join(train_dir, "model-0")
# model_checkpoint_path should have no "train" directory part.
new_rel_path = "model-0"
ckpt = saver_module.generate_checkpoint_state_proto(train_dir, model)
self.assertEqual(ckpt.model_checkpoint_path, new_rel_path)
self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)
def testAllModelCheckpointPaths(self):
save_dir = self._get_test_dir("all_models_test")
abs_path = os.path.join(save_dir, "model-0")
for paths in [None, [], ["model-2"]]:
ckpt = saver_module.generate_checkpoint_state_proto(
save_dir, abs_path, all_model_checkpoint_paths=paths)
self.assertEqual(ckpt.model_checkpoint_path, abs_path)
self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
self.assertEqual(
len(ckpt.all_model_checkpoint_paths), len(paths) if paths else 1)
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
def testUpdateCheckpointState(self):
save_dir = self._get_test_dir("update_checkpoint_state")
os.chdir(save_dir)
# Make a temporary train directory.
train_dir = "train"
os.mkdir(train_dir)
abs_path = os.path.join(save_dir, "model-0")
rel_path = os.path.join("train", "model-2")
saver_module.update_checkpoint_state(
train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path])
ckpt = saver_module.get_checkpoint_state(train_dir)
self.assertEqual(ckpt.model_checkpoint_path, rel_path)
self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path)
self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)
def testUpdateCheckpointStateSaveRelativePaths(self):
save_dir = self._get_test_dir("update_checkpoint_state")
os.chdir(save_dir)
abs_path2 = os.path.join(save_dir, "model-2")
rel_path2 = "model-2"
abs_path0 = os.path.join(save_dir, "model-0")
rel_path0 = "model-0"
saver_module._update_checkpoint_state( # pylint: disable=protected-access
save_dir=save_dir,
model_checkpoint_path=abs_path2,
all_model_checkpoint_paths=[rel_path0, abs_path2],
save_relative_paths=True)
# File should contain relative paths.
file_content = file_io.read_file_to_string(
os.path.join(save_dir, "checkpoint"))
ckpt = CheckpointState()
text_format.Merge(file_content, ckpt)
self.assertEqual(ckpt.model_checkpoint_path, rel_path2)
self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2)
self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0)
# get_checkpoint_state should return absolute paths.
ckpt = saver_module.get_checkpoint_state(save_dir)
self.assertEqual(ckpt.model_checkpoint_path, abs_path2)
self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2)
self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0)
def testCheckPointStateFailsWhenIncomplete(self):
save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")
os.chdir(save_dir)
ckpt_path = os.path.join(save_dir, "checkpoint")
ckpt_file = open(ckpt_path, "w")
ckpt_file.write("")
ckpt_file.close()
with self.assertRaises(ValueError):
saver_module.get_checkpoint_state(save_dir)
def testCheckPointCompletesRelativePaths(self):
save_dir = self._get_test_dir("checkpoint_completes_relative_paths")
os.chdir(save_dir)
ckpt_path = os.path.join(save_dir, "checkpoint")
ckpt_file = open(ckpt_path, "w")
ckpt_file.write("""
model_checkpoint_path: "./model.ckpt-687529"
all_model_checkpoint_paths: "./model.ckpt-687500"
all_model_checkpoint_paths: "./model.ckpt-687529"
""")
ckpt_file.close()
ckpt = saver_module.get_checkpoint_state(save_dir)
self.assertEqual(ckpt.model_checkpoint_path,
os.path.join(save_dir, "./model.ckpt-687529"))
self.assertEqual(ckpt.all_model_checkpoint_paths[0],
os.path.join(save_dir, "./model.ckpt-687500"))
self.assertEqual(ckpt.all_model_checkpoint_paths[1],
os.path.join(save_dir, "./model.ckpt-687529"))
class MetaGraphTest(test.TestCase):
def _get_test_dir(self, dirname):
@ -2628,62 +2436,6 @@ class WriteGraphTest(test.TestCase):
self.assertTrue(os.path.exists(path))
class SaverUtilsTest(test.TestCase):
def setUp(self):
self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test")
gfile.MakeDirs(self._base_dir)
def tearDown(self):
gfile.DeleteRecursively(self._base_dir)
def testCheckpointExists(self):
for sharded in (False, True):
for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
with self.test_session(graph=ops_lib.Graph()) as sess:
unused_v = variables.Variable(1.0, name="v")
variables.global_variables_initializer().run()
saver = saver_module.Saver(sharded=sharded, write_version=version)
path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
self.assertFalse(
saver_module.checkpoint_exists(path)) # Not saved yet.
ckpt_prefix = saver.save(sess, path)
self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))
ckpt_prefix = saver_module.latest_checkpoint(self._base_dir)
self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))
def testGetCheckpointMtimes(self):
prefixes = []
for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
with self.test_session(graph=ops_lib.Graph()) as sess:
unused_v = variables.Variable(1.0, name="v")
variables.global_variables_initializer().run()
saver = saver_module.Saver(write_version=version)
prefixes.append(
saver.save(sess, os.path.join(self._base_dir, str(version))))
mtimes = saver_module.get_checkpoint_mtimes(prefixes)
self.assertEqual(2, len(mtimes))
self.assertTrue(mtimes[1] >= mtimes[0])
def testRemoveCheckpoint(self):
for sharded in (False, True):
for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
with self.test_session(graph=ops_lib.Graph()) as sess:
unused_v = variables.Variable(1.0, name="v")
variables.global_variables_initializer().run()
saver = saver_module.Saver(sharded=sharded, write_version=version)
path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
ckpt_prefix = saver.save(sess, path)
self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))
saver_module.remove_checkpoint(ckpt_prefix, version)
self.assertFalse(saver_module.checkpoint_exists(ckpt_prefix))
class ScopedGraphTest(test.TestCase):
def _get_test_dir(self, dirname):

View File

@ -24,7 +24,7 @@ from tensorflow.python.client import session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as saver_mod
from tensorflow.python.training import checkpoint_management
from tensorflow.python.util.tf_export import tf_export
@ -197,13 +197,13 @@ class SessionManager(object):
# Waits up until max_wait_secs for checkpoint to become available.
wait_time = 0
ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
while not ckpt or not ckpt.model_checkpoint_path:
if wait_for_checkpoint and wait_time < max_wait_secs:
logging.info("Waiting for checkpoint to be available.")
time.sleep(self._recovery_wait_secs)
wait_time += self._recovery_wait_secs
ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
else:
return sess, False

View File

@ -30,6 +30,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_manager
@ -174,13 +175,13 @@ class SessionManagerTest(test.TestCase):
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
self._test_recovered_variable(
checkpoint_filename_with_path=saver_lib.latest_checkpoint(
checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
checkpoint_dir))
# Cannot set both checkpoint_dir and checkpoint_filename_with_path.
with self.assertRaises(ValueError):
self._test_recovered_variable(
checkpoint_dir=checkpoint_dir,
checkpoint_filename_with_path=saver_lib.latest_checkpoint(
checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
checkpoint_dir))
def testWaitForSessionReturnsNoneAfterTimeout(self):

View File

@ -44,6 +44,7 @@ from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import server_lib
@ -83,7 +84,7 @@ class SupervisorTest(test.TestCase):
end_time = time.time() + timeout_secs
while time.time() < end_time:
if for_checkpoint:
if saver_lib.checkpoint_exists(pattern):
if checkpoint_management.checkpoint_exists(pattern):
return
else:
if len(gfile.Glob(pattern)) >= 1:

View File

@ -82,12 +82,12 @@ from tensorflow.python.training.monitored_session import WorkerSessionCreator
from tensorflow.python.training.monitored_session import MonitoredSession
from tensorflow.python.training.monitored_session import SingularMonitoredSession
from tensorflow.python.training.saver import Saver
from tensorflow.python.training.saver import checkpoint_exists
from tensorflow.python.training.saver import generate_checkpoint_state_proto
from tensorflow.python.training.saver import get_checkpoint_mtimes
from tensorflow.python.training.saver import get_checkpoint_state
from tensorflow.python.training.saver import latest_checkpoint
from tensorflow.python.training.saver import update_checkpoint_state
from tensorflow.python.training.checkpoint_management import checkpoint_exists
from tensorflow.python.training.checkpoint_management import generate_checkpoint_state_proto
from tensorflow.python.training.checkpoint_management import get_checkpoint_mtimes
from tensorflow.python.training.checkpoint_management import get_checkpoint_state
from tensorflow.python.training.checkpoint_management import latest_checkpoint
from tensorflow.python.training.checkpoint_management import update_checkpoint_state
from tensorflow.python.training.saver import export_meta_graph
from tensorflow.python.training.saver import import_meta_graph
from tensorflow.python.training.session_run_hook import SessionRunHook