Split checkpoint management utility functions out of saver.py
Pure refactor, in preparation for adding a higher-level checkpoint management utility. This utility will also need to work with the Checkpoint proto, and globbing it onto saver.py seems dirty.

PiperOrigin-RevId: 207179646
parent 6fbbad97e2
commit 1bf206bc82
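The mechanical change at every call site below is an import swap. A minimal before/after sketch, not part of the commit itself (the model directory is a made-up path):

# Sketch, not part of this commit; "/tmp/model_dir" is hypothetical.
# Before: the helper lived in saver.py.
from tensorflow.python.training import saver as saver_lib
ckpt_path = saver_lib.latest_checkpoint("/tmp/model_dir")

# After: the same helper lives in the new checkpoint_management.py.
from tensorflow.python.training import checkpoint_management
ckpt_path = checkpoint_management.latest_checkpoint("/tmp/model_dir")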
@@ -28,6 +28,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training import training_util
 
@@ -55,7 +56,7 @@ class CheckpointInputPipelineHookTest(test.TestCase):
   def _read_vars(self, model_dir):
     """Returns (global_step, latest_feature)."""
     with ops.Graph().as_default() as g:
-      ckpt_path = saver_lib.latest_checkpoint(model_dir)
+      ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
       meta_filename = ckpt_path + '.meta'
       saver_lib.import_meta_graph(meta_filename)
       saver = saver_lib.Saver()
@@ -32,6 +32,7 @@ from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.util import nest
 
@@ -655,7 +656,7 @@ class DatasetSerializationTestBase(test.TestCase):
     return os.path.join(self.get_temp_dir(), "iterator")

   def _latest_ckpt(self):
-    return saver_lib.latest_checkpoint(self.get_temp_dir())
+    return checkpoint_management.latest_checkpoint(self.get_temp_dir())

   def _save(self, sess, saver):
     saver.save(sess, self._ckpt_path())
@@ -20,6 +20,7 @@ from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training import session_run_hook
 
@@ -206,7 +207,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
 
     # Check if there is an existing checkpoint. If so, restore from it.
     # pylint: disable=protected-access
-    latest_checkpoint_path = saver_lib.latest_checkpoint(
+    latest_checkpoint_path = checkpoint_management.latest_checkpoint(
         self._checkpoint_saver_hook._checkpoint_dir,
         latest_filename=self._latest_filename)
     if latest_checkpoint_path:
@@ -37,7 +37,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import script_ops
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training.checkpointable import util as checkpointable_utils
 
 
@@ -314,7 +314,8 @@ class IteratorTest(test.TestCase):
     for i in range(5):
       iterator = datasets.Iterator(dataset)
       checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
-      checkpoint.restore(saver.latest_checkpoint(checkpoint_directory))
+      checkpoint.restore(checkpoint_management.latest_checkpoint(
+          checkpoint_directory))
       for j in range(2):
         self.assertEqual(i * 2 + j, iterator.get_next().numpy())
       checkpoint.save(file_prefix=checkpoint_prefix)
@@ -36,7 +36,7 @@ from third_party.examples.eager.spinn import spinn
 from tensorflow.contrib.summary import summary_test_util
 from tensorflow.python.eager import test
 from tensorflow.python.framework import test_util
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training.checkpointable import util as checkpointable_utils
 # pylint: enable=g-bad-import-order
 
@@ -422,7 +422,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
     # 5. Verify that checkpoints exist and contains all the expected variables.
     self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
     object_graph = checkpointable_utils.object_metadata(
-        saver.latest_checkpoint(config.logdir))
+        checkpoint_management.latest_checkpoint(config.logdir))
     ckpt_variable_names = set()
     for node in object_graph.nodes:
       for attribute in node.attributes:
@@ -27,7 +27,7 @@ from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import training as train
 
 __all__ = [
@@ -40,7 +40,7 @@ __all__ = [
 def _get_checkpoint_filename(filepattern):
   """Returns checkpoint filename given directory or specific filepattern."""
   if gfile.IsDirectory(filepattern):
-    return saver.latest_checkpoint(filepattern)
+    return checkpoint_management.latest_checkpoint(filepattern)
   return filepattern
 
 
@@ -72,6 +72,7 @@ from tensorflow.python.saved_model import builder as saved_model_builder
 from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.summary import summary as core_summary
 from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import device_setter
 from tensorflow.python.training import monitored_session
 from tensorflow.python.training import saver
@@ -891,7 +892,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
 
     # Check that model has been trained (if nothing has been set explicitly).
     if not checkpoint_path:
-      latest_path = saver.latest_checkpoint(self._model_dir)
+      latest_path = checkpoint_management.latest_checkpoint(self._model_dir)
       if not latest_path:
         raise NotFittedError(
             "Couldn't find trained model at %s." % self._model_dir)
@@ -956,7 +957,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
                       as_iterable=True,
                       iterate_batches=False):
     # Check that model has been trained.
-    checkpoint_path = saver.latest_checkpoint(self._model_dir)
+    checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir)
     if not checkpoint_path:
       raise NotFittedError(
           "Couldn't find trained model at %s." % self._model_dir)
@@ -1364,7 +1365,7 @@ class Estimator(BaseEstimator):
 
     if not checkpoint_path:
       # Locate the latest checkpoint
-      checkpoint_path = saver.latest_checkpoint(self._model_dir)
+      checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir)
     if not checkpoint_path:
       raise NotFittedError(
           "Couldn't find trained model at %s." % self._model_dir)
@@ -41,7 +41,7 @@ from tensorflow.python.estimator import estimator as core_estimator
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import basic_session_run_hooks
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import server_lib
 from tensorflow.python.util import compat
 from tensorflow.python.util import function_utils
@@ -95,7 +95,7 @@ class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener):
     # Load and cache the path of the most recent checkpoint to avoid duplicate
     # searches on GCS.
     logging.info("Checking for checkpoint in %s", self._model_dir)
-    latest_path = saver.latest_checkpoint(self._model_dir)
+    latest_path = checkpoint_management.latest_checkpoint(self._model_dir)
 
     if not latest_path:
       logging.warning("Skipping evaluation and export since model has not been "
@@ -516,7 +516,8 @@ class Experiment(object):
     start = time.time()
 
     error_msg = None
-    latest_path = saver.latest_checkpoint(self._estimator.model_dir)
+    latest_path = checkpoint_management.latest_checkpoint(
+        self._estimator.model_dir)
     if not latest_path:
       error_msg = ("Estimator is not fitted yet. "
                    "Will start an evaluation when a checkpoint is ready.")
@@ -778,7 +779,8 @@ class Experiment(object):
         saving_listeners=self._saving_listeners)
 
     logging.info("Evaluating model now.")
-    latest_checkpoint = saver.latest_checkpoint(self._estimator.model_dir)
+    latest_checkpoint = checkpoint_management.latest_checkpoint(
+        self._estimator.model_dir)
     eval_result = self._call_evaluate(
         input_fn=self._eval_input_fn,
         steps=self._eval_steps,
@@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.summary import summary
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import saver as saver_lib
 
 
@@ -124,7 +125,7 @@ class GraphActionsTest(test.TestCase):
 
   # TODO(ptucker): Test number and contents of checkpoint files.
   def _assert_ckpt(self, output_dir, expected=True):
-    ckpt_state = saver_lib.get_checkpoint_state(output_dir)
+    ckpt_state = checkpoint_management.get_checkpoint_state(output_dir)
     if expected:
       pattern = '%s/model.ckpt-.*' % output_dir
       primary_ckpt_path = ckpt_state.model_checkpoint_path
@@ -434,7 +435,7 @@ class GraphActionsTrainTest(test.TestCase):
 
   # TODO(ptucker): Test number and contents of checkpoint files.
   def _assert_ckpt(self, output_dir, expected=True):
-    ckpt_state = saver_lib.get_checkpoint_state(output_dir)
+    ckpt_state = checkpoint_management.get_checkpoint_state(output_dir)
     if expected:
       pattern = '%s/model.ckpt-.*' % output_dir
       primary_ckpt_path = ckpt_state.model_checkpoint_path
@@ -51,7 +51,7 @@ from tensorflow.python.estimator import estimator as core_estimator
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.summary import summary as core_summary
-from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import session_run_hook
 from tensorflow.python.training import training_util
 from tensorflow.python.util import deprecation
@@ -735,7 +735,8 @@ class ValidationMonitor(EveryN):
       return False
     self._last_checkpoint_check_time = current_time
     # Check that we are not running evaluation on the same checkpoint.
-    latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
+    latest_path = checkpoint_management.latest_checkpoint(
+        self._estimator.model_dir)
     if latest_path is None:
       logging.debug("Skipping evaluation since model has not been saved yet "
                     "at step %d.", step)
@@ -1059,7 +1060,8 @@ class ExportMonitor(EveryN):
 
   def end(self, session=None):
     super(ExportMonitor, self).end(session=session)
-    latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
+    latest_path = checkpoint_management.latest_checkpoint(
+        self._estimator.model_dir)
     if latest_path is None:
       logging.info("Skipping export at the end since model has not been saved "
                    "yet.")
@@ -39,9 +39,9 @@ from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.summary import summary
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import gradient_descent
 from tensorflow.python.training import monitored_session
-from tensorflow.python.training import saver
 from tensorflow.python.training import training_util
 
 
@@ -317,7 +317,7 @@ class MonitorsTest(test.TestCase):
     self._run_monitor(monitor)
 
   @test.mock.patch.object(estimators, 'Estimator', autospec=True)
-  @test.mock.patch.object(saver, 'latest_checkpoint')
+  @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
   def test_validation_monitor_no_ckpt(self, mock_latest_checkpoint,
                                       mock_estimator_class):
     estimator = mock_estimator_class()
@@ -336,7 +336,7 @@ class MonitorsTest(test.TestCase):
     mock_latest_checkpoint.assert_called_with(model_dir)
 
   @test.mock.patch.object(estimators, 'Estimator', autospec=True)
-  @test.mock.patch.object(saver, 'latest_checkpoint')
+  @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
   def test_validation_monitor_no_early_stopping_rounds(self,
                                                        mock_latest_checkpoint,
                                                        mock_estimator_class):
@@ -356,7 +356,7 @@ class MonitorsTest(test.TestCase):
     self._assert_validation_monitor(monitor)
 
   @test.mock.patch.object(estimators, 'Estimator', autospec=True)
-  @test.mock.patch.object(saver, 'latest_checkpoint')
+  @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
   def test_validation_monitor_invalid_metric(self, mock_latest_checkpoint,
                                              mock_estimator_class):
     estimator = mock_estimator_class()
@@ -375,7 +375,7 @@ class MonitorsTest(test.TestCase):
     self._run_monitor(monitor, num_epochs=1, num_steps_per_epoch=1)
 
   @test.mock.patch.object(estimators, 'Estimator', autospec=True)
-  @test.mock.patch.object(saver, 'latest_checkpoint')
+  @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
   def test_validation_monitor(self, mock_latest_checkpoint,
                               mock_estimator_class):
     estimator = mock_estimator_class()
@@ -464,7 +464,7 @@ class MonitorsTest(test.TestCase):
     monitor.epoch_end(epoch=0)
     monitor.end()
 
-  @test.mock.patch.object(saver, 'latest_checkpoint')
+  @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
   def test_validation_monitor_with_core_estimator(self, mock_latest_checkpoint):
     estimator = test.mock.Mock(spec=core_estimator.Estimator)
     model_dir = 'model/dir'
@@ -495,7 +495,7 @@ class MonitorsTest(test.TestCase):
         expected_best_metrics={'loss': 42.0, 'auc': 0.5})
     monitor.post_step(step=step, session=None)
 
-  @test.mock.patch.object(saver, 'latest_checkpoint')
+  @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
   def test_validation_monitor_fail_with_core_estimator_and_metrics(
       self, mock_latest_checkpoint):
     estimator = test.mock.Mock(spec=core_estimator.Estimator)
@@ -35,6 +35,7 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import saver as tf_saver
 from tensorflow.python.training import training_util
 
@@ -298,7 +299,8 @@ def _export_estimator(estimator,
 
   # If checkpoint_path is specified, use the specified checkpoint path.
   checkpoint_path = (checkpoint_path or
-                     tf_saver.latest_checkpoint(estimator._model_dir))
+                     checkpoint_management.latest_checkpoint(
+                         estimator._model_dir))
   with ops.Graph().as_default() as g:
     training_util.create_global_step(g)
 
@@ -55,7 +55,7 @@ from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.saved_model import signature_def_utils
 from tensorflow.python.summary import summary_iterator
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.util import compat
 from tensorflow.python.util.deprecation import deprecated
 
@@ -714,7 +714,8 @@ def make_best_model_export_strategy(
     # as soon as contrib is cleaned up and we can thus be sure that
     # estimator is a tf.estimator.Estimator and not a
     # tf.contrib.learn.Estimator
-    checkpoint_path = saver.latest_checkpoint(estimator.model_dir)
+    checkpoint_path = checkpoint_management.latest_checkpoint(
+        estimator.model_dir)
     export_checkpoint_path, export_eval_result = best_model_selector.update(
         checkpoint_path, eval_result)
 
@@ -41,6 +41,7 @@ from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import template
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import saver as core_saver
 from tensorflow.python.training import training_util
 from tensorflow.python.training.checkpointable import tracking
@@ -278,7 +279,8 @@ class CheckpointingTests(test.TestCase):
       root = util.Checkpoint(
           optimizer=optimizer, model=model,
           optimizer_step=training_util.get_or_create_global_step())
-      root.restore(core_saver.latest_checkpoint(checkpoint_directory))
+      root.restore(checkpoint_management.latest_checkpoint(
+          checkpoint_directory))
       for _ in range(num_training_steps):
         # TODO(allenl): Use a Dataset and serialize/checkpoint it.
        input_value = constant_op.constant([[3.]])
@@ -306,7 +308,8 @@ class CheckpointingTests(test.TestCase):
       train_op = optimizer.minimize(
           model(input_value),
           global_step=root.global_step)
-      checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
+      checkpoint_path = checkpoint_management.latest_checkpoint(
+          checkpoint_directory)
       with self.test_session(graph=ops.get_default_graph()) as session:
         status = root.restore(save_path=checkpoint_path)
         status.initialize_or_restore(session=session)
@@ -339,7 +342,8 @@ class CheckpointingTests(test.TestCase):
       root = util.Checkpoint(
           optimizer=optimizer, model=model,
           global_step=training_util.get_or_create_global_step())
-      checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
+      checkpoint_path = checkpoint_management.latest_checkpoint(
+          checkpoint_directory)
       status = root.restore(save_path=checkpoint_path)
       input_value = constant_op.constant([[3.]])
       train_fn = functools.partial(
@@ -372,7 +376,8 @@ class CheckpointingTests(test.TestCase):
       root = util.Checkpoint(
           optimizer=optimizer, model=model,
           global_step=training_util.get_or_create_global_step())
-      checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
+      checkpoint_path = checkpoint_management.latest_checkpoint(
+          checkpoint_directory)
       status = root.restore(save_path=checkpoint_path)
       def train_fn():
         @function.defun
@@ -22,8 +22,8 @@ from __future__ import print_function
 from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
 from tensorflow.contrib.predictor import predictor
 from tensorflow.python.framework import ops
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import monitored_session
-from tensorflow.python.training import saver
 
 
 class ContribEstimatorPredictor(predictor.Predictor):
@@ -57,7 +57,8 @@ class ContribEstimatorPredictor(predictor.Predictor):
     # pylint: disable=protected-access
     model_fn_ops = estimator._get_predict_ops(input_fn_ops.features)
     # pylint: enable=protected-access
-    checkpoint_path = saver.latest_checkpoint(estimator.model_dir)
+    checkpoint_path = checkpoint_management.latest_checkpoint(
+        estimator.model_dir)
     self._session = monitored_session.MonitoredSession(
         session_creator=monitored_session.ChiefSessionCreator(
             config=config,
@@ -142,9 +142,9 @@ from tensorflow.python.ops import state_ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.summary import summary
 from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import evaluation
 from tensorflow.python.training import monitored_session
-from tensorflow.python.training import saver as tf_saver
 from tensorflow.python.training import session_run_hook
 from tensorflow.python.training import training_util
 
@@ -189,7 +189,7 @@ def wait_for_new_checkpoint(checkpoint_dir,
   logging.info('Waiting for new checkpoint at %s', checkpoint_dir)
   stop_time = time.time() + timeout if timeout is not None else None
   while True:
-    checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir)
+    checkpoint_path = checkpoint_management.latest_checkpoint(checkpoint_dir)
     if checkpoint_path is None or checkpoint_path == last_checkpoint:
       if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
         return None
@@ -36,6 +36,7 @@ from tensorflow.python.ops.losses import losses
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
 from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import gradient_descent
 from tensorflow.python.training import monitored_session
 from tensorflow.python.training import saver as saver_lib
@@ -421,7 +422,7 @@ class TrainTest(test.TestCase):
       train_op = self.create_train_op()
 
       model_variables = variables_lib2.global_variables()
-      model_path = saver_lib.latest_checkpoint(logdir1)
+      model_path = checkpoint_management.latest_checkpoint(logdir1)
 
       assign_fn = variables_lib.assign_from_checkpoint_fn(
           model_path, model_variables)
@@ -3216,6 +3216,7 @@ py_library(
             "training/checkpointable/**/*.py",
             # The following targets have their own build rules (same name as the
             # file):
+            "training/checkpoint_management.py",
             "training/saveable_object.py",
             "training/saver.py",
             "training/training_util.py",
@@ -3223,8 +3224,10 @@ py_library(
     ),
     srcs_version = "PY2AND3",
     deps = [
+        "saver",
         ":array_ops",
         ":array_ops_gen",
+        ":checkpoint_management",
         ":checkpoint_ops_gen",
         ":client",
         ":control_flow_ops",
@@ -3236,25 +3239,20 @@ py_library(
        ":framework_ops",
        ":gradients",
        ":init_ops",
        ":distribute",
        ":io_ops",
        ":io_ops_gen",
        ":layers_base",
        ":lib",
        ":lookup_ops",
        ":math_ops",
        ":platform",
        ":protos_all_py",
        ":pywrap_tensorflow",
        ":random_ops",
        ":resource_variable_ops",
        ":resources",
        "saver",
        ":saveable_object",
        ":sdca_ops",
        ":session",
        ":sparse_ops",
        ":sparse_tensor",
        ":state_ops",
        ":string_ops",
        ":summary",
        ":training_ops_gen",
        ":training_util",
@@ -3264,6 +3262,7 @@ py_library(
        "//third_party/py/numpy",
        "@six_archive//:six",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/eager:backprop",
        "//tensorflow/python/eager:context",
        # `layers` dependency only exists due to the use of a small utility.
@@ -3280,12 +3279,26 @@ py_library(
     srcs_version = "PY2AND3",
 )
 
+py_library(
+    name = "checkpoint_management",
+    srcs = ["training/checkpoint_management.py"],
+    deps = [
+        ":errors",
+        ":lib",
+        ":platform",
+        ":protos_all_py",
+        ":util",
+        "//tensorflow/core:protos_all_py",
+    ],
+)
+
 py_library(
     name = "saver",
     srcs = ["training/saver.py"],
     srcs_version = "PY2AND3",
     deps = [
         ":array_ops",
+        ":checkpoint_management",
         ":constant_op",
         ":control_flow_ops",
         ":device",
@@ -3294,9 +3307,7 @@ py_library(
        ":framework_ops",
        ":io_ops",
        ":io_ops_gen",
        ":lib",
        ":platform",
        ":protos_all_py",
        ":pywrap_tensorflow",
        ":resource_variable_ops",
        ":saveable_object",
@@ -4423,6 +4434,42 @@ cuda_py_test(
     tags = ["multi_gpu"],
 )
 
+cuda_py_test(
+    name = "checkpoint_management_test",
+    size = "small",
+    srcs = [
+        "training/checkpoint_management_test.py",
+    ],
+    additional_deps = [
+        ":array_ops",
+        ":client_testlib",
+        ":control_flow_ops",
+        ":data_flow_ops",
+        ":errors",
+        ":gradients",
+        ":math_ops",
+        ":nn_grad",
+        ":nn_ops",
+        ":saver_test_utils",
+        ":partitioned_variables",
+        ":platform",
+        ":platform_test",
+        ":pywrap_tensorflow",
+        ":random_ops",
+        ":resource_variable_ops",
+        ":sparse_ops",
+        ":summary",
+        ":training",
+        ":util",
+        ":variable_scope",
+        ":variables",
+        "//third_party/py/numpy",
+        "@six_archive//:six",
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python/data/ops:dataset_ops",
+    ],
+)
+
 py_test(
     name = "saver_large_variable_test",
     size = "medium",
@@ -4489,6 +4536,7 @@ tf_py_test(
     srcs = ["training/supervisor_test.py"],
     additional_deps = [
         ":array_ops",
+        ":checkpoint_management",
         ":client_testlib",
         ":errors",
         ":framework",
@@ -4496,6 +4544,7 @@ tf_py_test(
         ":io_ops",
         ":parsing_ops",
         ":platform",
+        ":saver",
         ":summary",
         ":training",
         ":variables",
@@ -4609,10 +4658,13 @@ py_test(
     tags = ["notsan"],  # b/67945581
     deps = [
         ":array_ops",
+        ":checkpoint_management",
         ":client_testlib",
         ":control_flow_ops",
         ":errors",
         ":framework_for_generated_wrappers",
         ":resource_variable_ops",
         ":saver",
         ":session",
         ":state_ops",
         ":summary",
@@ -47,7 +47,7 @@ from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import script_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import server_lib
 from tensorflow.python.training.checkpointable import util as checkpointable_utils
 from tensorflow.python.util import compat
@@ -877,7 +877,7 @@ class IteratorCheckpointingTest(test.TestCase):
     checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
     for i in range(5):
       with self.test_session() as sess:
-        checkpoint.restore(saver.latest_checkpoint(
+        checkpoint.restore(checkpoint_management.latest_checkpoint(
             checkpoint_directory)).initialize_or_restore(sess)
         for j in range(2):
           self.assertEqual(i * 2 + j, sess.run(get_next))
@@ -53,6 +53,7 @@ from tensorflow.python.saved_model import builder as saved_model_builder
 from tensorflow.python.saved_model import constants
 from tensorflow.python.summary import summary
 from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import device_setter
 from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.training import evaluation
@@ -268,7 +269,7 @@ class Estimator(object):
       found.
     """
     with context.graph_mode():
-      return saver.latest_checkpoint(self.model_dir)
+      return checkpoint_management.latest_checkpoint(self.model_dir)
 
   def train(self,
             input_fn,
@@ -417,7 +418,7 @@ class Estimator(object):
 
     # Check that model has been trained (if nothing has been set explicitly).
     if not checkpoint_path:
-      latest_path = saver.latest_checkpoint(self._model_dir)
+      latest_path = checkpoint_management.latest_checkpoint(self._model_dir)
       if not latest_path:
         logging.info('Could not find trained model in model_dir: {}, running '
                      'initialization to evaluate.'.format(self._model_dir))
@@ -504,7 +505,8 @@ class Estimator(object):
     hooks = _check_hooks_type(hooks)
     # Check that model has been trained.
     if not checkpoint_path:
-      checkpoint_path = saver.latest_checkpoint(self._model_dir)
+      checkpoint_path = checkpoint_management.latest_checkpoint(
+          self._model_dir)
     if not checkpoint_path:
       logging.info('Could not find trained model in model_dir: {}, running '
                    'initialization to predict.'.format(self._model_dir))
@@ -769,7 +771,8 @@ class Estimator(object):
     with context.graph_mode():
       if not checkpoint_path:
         # Locate the latest checkpoint
-        checkpoint_path = saver.latest_checkpoint(self._model_dir)
+        checkpoint_path = checkpoint_management.latest_checkpoint(
+            self._model_dir)
       if not checkpoint_path:
         raise ValueError("Couldn't find trained model at %s." % self._model_dir)
 
@@ -1626,7 +1629,7 @@ def _combine_distributed_scaffold(grouped_scaffold, distribution):
 
 
 def _check_checkpoint_available(model_dir):
-  latest_path = saver.latest_checkpoint(model_dir)
+  latest_path = checkpoint_management.latest_checkpoint(model_dir)
   if not latest_path:
     raise ValueError(
         'Could not find trained model in model_dir: {}.'.format(model_dir))
@@ -69,6 +69,7 @@ from tensorflow.python.summary import summary
 from tensorflow.python.summary import summary_iterator
 from tensorflow.python.summary.writer import writer_cache
 from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import checkpoint_state_pb2
 from tensorflow.python.training import saver
 from tensorflow.python.training import saver_test_utils
@@ -1548,7 +1549,8 @@ class EstimatorPredictTest(test.TestCase):
       next(
           est.predict(
               dummy_input_fn,
-              checkpoint_path=saver.latest_checkpoint('fakedir')))
+              checkpoint_path=
+              checkpoint_management.latest_checkpoint('fakedir')))
 
   def test_tensor_predictions(self):
 
@@ -42,6 +42,7 @@ from tensorflow.python.ops import metrics as metrics_module
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training import training_util
@@ -442,7 +443,7 @@ def _save_first_checkpoint(keras_model, custom_objects, config):
   # save checkpoint into subdirectory to allow warm start
   keras_model_dir = os.path.join(config.model_dir, 'keras')
   # Load weights and save to checkpoint if there is no checkpoint
-  latest_path = saver_lib.latest_checkpoint(keras_model_dir)
+  latest_path = checkpoint_management.latest_checkpoint(keras_model_dir)
   if not latest_path:
     keras_weights = None
     if _any_weight_initialized(keras_model):
@@ -55,6 +55,7 @@ from tensorflow.python.platform import gfile
 from tensorflow.python.saved_model import loader
 from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.tools import saved_model_utils
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import saver as saver_lib
 
 
@@ -78,7 +79,7 @@ def freeze_graph_with_def_protos(input_graph_def,
 
   # 'input_checkpoint' may be a prefix if we're using Saver V2 format
   if (not input_saved_model_dir and
-      not saver_lib.checkpoint_exists(input_checkpoint)):
+      not checkpoint_management.checkpoint_exists(input_checkpoint)):
     print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
     return -1
 
tensorflow/python/training/checkpoint_management.py (new file, 406 lines)
@@ -0,0 +1,406 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=invalid-name
"""Utilities for managing checkpoint files and the CheckpointState proto."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import re

from google.protobuf import text_format

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.framework import errors
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.util.tf_export import tf_export


def _GetCheckpointFilename(save_dir, latest_filename):
  """Returns a filename for storing the CheckpointState.

  Args:
    save_dir: The directory for saving and restoring checkpoints.
    latest_filename: Name of the file in 'save_dir' that is used
      to store the CheckpointState.

  Returns:
    The path of the file that contains the CheckpointState proto.
  """
  if latest_filename is None:
    latest_filename = "checkpoint"
  return os.path.join(save_dir, latest_filename)


@tf_export("train.generate_checkpoint_state_proto")
def generate_checkpoint_state_proto(save_dir,
                                    model_checkpoint_path,
                                    all_model_checkpoint_paths=None):
  """Generates a checkpoint state proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.

  Returns:
    CheckpointState proto with model_checkpoint_path and
    all_model_checkpoint_paths updated to either absolute paths or
    relative paths to the current save_dir.
  """
  if all_model_checkpoint_paths is None:
    all_model_checkpoint_paths = []

  if (not all_model_checkpoint_paths or
      all_model_checkpoint_paths[-1] != model_checkpoint_path):
    logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
                 model_checkpoint_path)
    all_model_checkpoint_paths.append(model_checkpoint_path)

  # Relative paths need to be rewritten to be relative to the "save_dir"
  # if model_checkpoint_path already contains "save_dir".
  if not os.path.isabs(save_dir):
    if not os.path.isabs(model_checkpoint_path):
      model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
    for i in range(len(all_model_checkpoint_paths)):
      p = all_model_checkpoint_paths[i]
      if not os.path.isabs(p):
        all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)

  coord_checkpoint_proto = CheckpointState(
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths)

  return coord_checkpoint_proto
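A quick illustration of the relative-path handling above, sketched with made-up paths (not part of the commit): with a relative save_dir, checkpoint paths inside it are stored relative to that directory, matching the testRelPath case in the new test file below.

# Sketch, not part of this commit; paths are hypothetical.
from tensorflow.python.training import checkpoint_management

state = checkpoint_management.generate_checkpoint_state_proto(
    save_dir="train",  # relative, so contained paths are relativized
    model_checkpoint_path="train/model.ckpt-200",
    all_model_checkpoint_paths=["train/model.ckpt-100",
                                "train/model.ckpt-200"])
print(state.model_checkpoint_path)             # -> "model.ckpt-200"
print(list(state.all_model_checkpoint_paths))  # -> ["model.ckpt-100", "model.ckpt-200"]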
@tf_export("train.update_checkpoint_state")
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.

  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointState.
  """
  update_checkpoint_state_internal(
      save_dir=save_dir,
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths,
      latest_filename=latest_filename,
      save_relative_paths=False)


def update_checkpoint_state_internal(save_dir,
                                     model_checkpoint_path,
                                     all_model_checkpoint_paths=None,
                                     latest_filename=None,
                                     save_relative_paths=False):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.
    save_relative_paths: If `True`, will write relative paths to the checkpoint
      state file.

  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointState.
  """
  # Writes the "checkpoint" file for the coordinator for later restoration.
  coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
  if save_relative_paths:
    if os.path.isabs(model_checkpoint_path):
      rel_model_checkpoint_path = os.path.relpath(
          model_checkpoint_path, save_dir)
    else:
      rel_model_checkpoint_path = model_checkpoint_path
    rel_all_model_checkpoint_paths = []
    for p in all_model_checkpoint_paths:
      if os.path.isabs(p):
        rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
      else:
        rel_all_model_checkpoint_paths.append(p)
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        rel_model_checkpoint_path,
        all_model_checkpoint_paths=rel_all_model_checkpoint_paths)
  else:
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths)

  if coord_checkpoint_filename == ckpt.model_checkpoint_path:
    raise RuntimeError("Save path '%s' conflicts with path used for "
                       "checkpoint state.  Please use a different save path." %
                       model_checkpoint_path)

  # Preventing potential read/write race condition by *atomically* writing to a
  # file.
  file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                      text_format.MessageToString(ckpt))
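What update_checkpoint_state leaves on disk, sketched with hypothetical checkpoints (not part of the commit):

# Sketch, not part of this commit.
import os
import tempfile

from tensorflow.python.lib.io import file_io
from tensorflow.python.training import checkpoint_management

save_dir = tempfile.mkdtemp()
checkpoint_management.update_checkpoint_state(
    save_dir,
    model_checkpoint_path=os.path.join(save_dir, "model.ckpt-200"),
    all_model_checkpoint_paths=[os.path.join(save_dir, "model.ckpt-100"),
                                os.path.join(save_dir, "model.ckpt-200")])
# The state lands in a text-format CheckpointState proto named "checkpoint":
print(file_io.read_file_to_string(os.path.join(save_dir, "checkpoint")))
# model_checkpoint_path: "<save_dir>/model.ckpt-200"
# all_model_checkpoint_paths: "<save_dir>/model.ckpt-100"
# all_model_checkpoint_paths: "<save_dir>/model.ckpt-200"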
@tf_export("train.get_checkpoint_state")
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.

  Raises:
    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)
  f = None
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if file_io.file_exists(coord_checkpoint_filename):
      file_content = file_io.read_file_to_string(
          coord_checkpoint_filename)
      ckpt = CheckpointState()
      text_format.Merge(file_content, ckpt)
      if not ckpt.model_checkpoint_path:
        raise ValueError("Invalid checkpoint state loaded from "
                         + checkpoint_dir)
      # For relative model_checkpoint_path and all_model_checkpoint_paths,
      # prepend checkpoint_dir.
      if not os.path.isabs(ckpt.model_checkpoint_path):
        ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
                                                  ckpt.model_checkpoint_path)
      for i in range(len(ckpt.all_model_checkpoint_paths)):
        p = ckpt.all_model_checkpoint_paths[i]
        if not os.path.isabs(p):
          ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
  except errors.OpError as e:
    # It's ok if the file cannot be read
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  except text_format.ParseError as e:
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  finally:
    if f:
      f.close()
  return ckpt
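Reading the state back, continuing the sketch above (not part of the commit; assumes the save_dir written in the previous sketch):

# Sketch, not part of this commit.
from tensorflow.python.training import checkpoint_management

state = checkpoint_management.get_checkpoint_state(save_dir)
if state is not None:
  # Relative paths stored in the file come back joined with checkpoint_dir.
  print(state.model_checkpoint_path)             # newest checkpoint prefix
  print(list(state.all_model_checkpoint_paths))  # oldest to newest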
def _prefix_to_checkpoint_path(prefix, format_version):
  """Returns the pathname of a checkpoint file, given the checkpoint prefix.

  For V1 checkpoint, simply returns the prefix itself (the data file).  For V2,
  returns the pathname to the index file.

  Args:
    prefix: a string, the prefix of a checkpoint.
    format_version: the checkpoint format version that corresponds to the
      prefix.
  Returns:
    The pathname of a checkpoint file, taking into account the checkpoint
      format version.
  """
  if format_version == saver_pb2.SaverDef.V2:
    return prefix + ".index"  # The index file identifies a checkpoint.
  return prefix  # Just the data file.


@tf_export("train.latest_checkpoint")
def latest_checkpoint(checkpoint_dir, latest_filename=None):
  """Finds the filename of latest saved checkpoint file.

  Args:
    checkpoint_dir: Directory where the variables were saved.
    latest_filename: Optional name for the protocol buffer file that
      contains the list of most recent checkpoint filenames.
      See the corresponding argument to `Saver.save()`.

  Returns:
    The full path to the latest checkpoint or `None` if no checkpoint was found.
  """
  # Pick the latest checkpoint based on checkpoint state.
  ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
  if ckpt and ckpt.model_checkpoint_path:
    # Look for either a V2 path or a V1 path, with priority for V2.
    v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
                                         saver_pb2.SaverDef.V2)
    v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
                                         saver_pb2.SaverDef.V1)
    if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
        v1_path):
      return ckpt.model_checkpoint_path
    else:
      logging.error("Couldn't match files for checkpoint %s",
                    ckpt.model_checkpoint_path)
  return None
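This is the function that every updated call site in the diff above now resolves to. A usage sketch with a hypothetical directory (not part of the commit):

# Sketch, not part of this commit.
from tensorflow.python.training import checkpoint_management

ckpt_prefix = checkpoint_management.latest_checkpoint("/tmp/model_dir")
if ckpt_prefix is None:
  print("no checkpoint yet")  # e.g. the model has not been trained
else:
  # ckpt_prefix is a prefix such as "/tmp/model_dir/model.ckpt-200"; pass it
  # to Saver.restore() or Checkpoint.restore() exactly as before the split.
  print("newest checkpoint:", ckpt_prefix)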
@tf_export("train.checkpoint_exists")
def checkpoint_exists(checkpoint_prefix):
  """Checks whether a V1 or V2 checkpoint exists with the specified prefix.

  This is the recommended way to check if a checkpoint exists, since it takes
  into account the naming difference between V1 and V2 formats.

  Args:
    checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
  Returns:
    A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists.
  """
  pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
                                        saver_pb2.SaverDef.V2)
  if file_io.get_matching_files(pathname):
    return True
  elif file_io.get_matching_files(checkpoint_prefix):
    return True
  else:
    return False
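A sketch of the format-agnostic existence check used by the freeze_graph change above (hypothetical prefix, not part of the commit):

# Sketch, not part of this commit.
from tensorflow.python.training import checkpoint_management

prefix = "/tmp/model_dir/model.ckpt-200"
if checkpoint_management.checkpoint_exists(prefix):
  # True if either "model.ckpt-200.index" (V2) or "model.ckpt-200" (V1)
  # matches on disk.
  print("checkpoint found:", prefix)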
@tf_export("train.get_checkpoint_mtimes")
def get_checkpoint_mtimes(checkpoint_prefixes):
  """Returns the mtimes (modification timestamps) of the checkpoints.

  Globs for the checkpoints pointed to by `checkpoint_prefixes`.  If the files
  exist, collect their mtime.  Both V2 and V1 checkpoints are considered, in
  that priority.

  This is the recommended way to get the mtimes, since it takes into account
  the naming difference between V1 and V2 formats.

  Args:
    checkpoint_prefixes: a list of checkpoint paths, typically the results of
      `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
      sharded/non-sharded or V1/V2.
  Returns:
    A list of mtimes (in seconds) of the found checkpoints.
  """
  mtimes = []

  def match_maybe_append(pathname):
    fnames = file_io.get_matching_files(pathname)
    if fnames:
      mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9)
      return True
    return False

  for checkpoint_prefix in checkpoint_prefixes:
    # Tries V2's metadata file first.
    pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
                                          saver_pb2.SaverDef.V2)
    if match_maybe_append(pathname):
      continue
    # Otherwise, tries V1, where the prefix is the complete pathname.
    match_maybe_append(checkpoint_prefix)

  return mtimes
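A sketch of the mtime lookup (hypothetical prefixes, not part of the commit):

# Sketch, not part of this commit.
from tensorflow.python.training import checkpoint_management

mtimes = checkpoint_management.get_checkpoint_mtimes(
    ["/tmp/model_dir/model.ckpt-100", "/tmp/model_dir/model.ckpt-200"])
# Prefixes with no matching files are silently skipped, so the result can be
# shorter than the input list; values are seconds since the epoch.
print(mtimes)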
@tf_export("train.remove_checkpoint")
def remove_checkpoint(checkpoint_prefix,
                      checkpoint_format_version=saver_pb2.SaverDef.V2,
                      meta_graph_suffix="meta"):
  """Removes a checkpoint given by `checkpoint_prefix`.

  Args:
    checkpoint_prefix: The prefix of a V1 or V2 checkpoint.  Typically the
      result of `Saver.save()` or that of `tf.train.latest_checkpoint()`,
      regardless of sharded/non-sharded or V1/V2.
    checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to
      `SaverDef.V2`.
    meta_graph_suffix: Suffix for `MetaGraphDef` file.  Defaults to 'meta'.
  """
  _delete_file_if_exists(
      meta_graph_filename(checkpoint_prefix, meta_graph_suffix))
  if checkpoint_format_version == saver_pb2.SaverDef.V2:
    # V2 has a metadata file and some data files.
    _delete_file_if_exists(checkpoint_prefix + ".index")
    _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????")
  else:
    # V1, Legacy.  Exact match on the data file.
    _delete_file_if_exists(checkpoint_prefix)


def _delete_file_if_exists(filespec):
  """Deletes files matching `filespec`."""
  for pathname in file_io.get_matching_files(filespec):
    file_io.delete_file(pathname)


def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
  """Returns the meta graph filename.

  Args:
    checkpoint_filename: Name of the checkpoint file.
    meta_graph_suffix: Suffix for `MetaGraphDef` file.  Defaults to 'meta'.

  Returns:
    MetaGraph file name.
  """
  # If the checkpoint_filename is sharded, the checkpoint_filename could
  # be of format model.ckpt-step#-?????-of-shard#.  For example,
  # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
  basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
  suffixed_filename = ".".join([basename, meta_graph_suffix])
  return suffixed_filename
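Closing the loop on deletion: meta_graph_filename strips any shard suffix before appending ".meta", and remove_checkpoint deletes the meta graph plus the format-specific files. A sketch with a hypothetical prefix (not part of the commit):

# Sketch, not part of this commit.
from tensorflow.python.training import checkpoint_management

print(checkpoint_management.meta_graph_filename(
    "model.ckpt-123456-?????-of-00005"))  # -> "model.ckpt-123456.meta"

# For a V2 checkpoint this removes "<prefix>.meta", "<prefix>.index" and the
# "<prefix>.data-?????-of-?????" shards, if present.
checkpoint_management.remove_checkpoint("/tmp/model_dir/model.ckpt-100")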
tensorflow/python/training/checkpoint_management_test.py (new file, 316 lines)
@@ -0,0 +1,316 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Tests for tensorflow.python.training.saver.py."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
from google.protobuf import text_format
|
||||
|
||||
from tensorflow.core.protobuf import saver_pb2
|
||||
from tensorflow.python.framework import ops as ops_lib
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training import saver as saver_module
|
||||
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
|
||||
|
||||
|
||||
class LatestCheckpointWithRelativePaths(test.TestCase):
|
||||
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
def tempWorkingDir(temppath):
|
||||
cwd = os.getcwd()
|
||||
os.chdir(temppath)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
os.chdir(cwd)
|
||||
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
def tempDir():
|
||||
tempdir = tempfile.mkdtemp()
|
||||
try:
|
||||
yield tempdir
|
||||
finally:
|
||||
shutil.rmtree(tempdir)
|
||||
|
||||
def testNameCollision(self):
|
||||
# Make sure we have a clean directory to work in.
|
||||
with self.tempDir() as tempdir:
|
||||
# Jump to that directory until this test is done.
|
||||
with self.tempWorkingDir(tempdir):
|
||||
# Save training snapshots to a relative path.
|
||||
traindir = "train/"
|
||||
os.mkdir(traindir)
|
||||
# Collides with the default name of the checkpoint state file.
|
||||
filepath = os.path.join(traindir, "checkpoint")
|
||||
|
||||
with self.test_session() as sess:
|
||||
unused_a = variables.Variable(0.0) # So that Saver saves something.
|
||||
variables.global_variables_initializer().run()
|
||||
|
||||
# Should fail.
|
||||
saver = saver_module.Saver(sharded=False)
|
||||
with self.assertRaisesRegexp(ValueError, "collides with"):
|
||||
saver.save(sess, filepath)
|
||||
|
||||
# Succeeds: the file will be named "checkpoint-<step>".
|
||||
saver.save(sess, filepath, global_step=1)
|
||||
self.assertIsNotNone(
|
||||
checkpoint_management.latest_checkpoint(traindir))
|
||||
|
||||
# Succeeds: the file will be named "checkpoint-<i>-of-<n>".
|
||||
saver = saver_module.Saver(sharded=True)
|
||||
saver.save(sess, filepath)
|
||||
self.assertIsNotNone(
|
||||
checkpoint_management.latest_checkpoint(traindir))
|
||||
|
||||
# Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>".
|
||||
saver = saver_module.Saver(sharded=True)
|
||||
saver.save(sess, filepath, global_step=1)
|
||||
self.assertIsNotNone(
|
||||
checkpoint_management.latest_checkpoint(traindir))
|
||||
|
||||
def testRelativePath(self):
|
||||
# Make sure we have a clean directory to work in.
|
||||
with self.tempDir() as tempdir:
|
||||
|
||||
# Jump to that directory until this test is done.
|
||||
with self.tempWorkingDir(tempdir):
|
||||
|
||||
# Save training snapshots to a relative path.
|
||||
traindir = "train/"
|
||||
os.mkdir(traindir)
|
||||
|
||||
filename = "snapshot"
|
||||
filepath = os.path.join(traindir, filename)
|
||||
|
||||
with self.test_session() as sess:
|
||||
# Build a simple graph.
|
||||
v0 = variables.Variable(0.0)
|
||||
inc = v0.assign_add(1.0)
|
||||
|
||||
save = saver_module.Saver({"v0": v0})
|
||||
|
||||
# Record a short training history.
|
||||
variables.global_variables_initializer().run()
|
||||
save.save(sess, filepath, global_step=0)
|
||||
inc.eval()
|
||||
save.save(sess, filepath, global_step=1)
|
||||
inc.eval()
|
||||
save.save(sess, filepath, global_step=2)
|
||||
|
||||
with self.test_session() as sess:
|
||||
# Build a new graph with different initialization.
|
||||
v0 = variables.Variable(-1.0)
|
||||
|
||||
# Create a new saver.
|
||||
save = saver_module.Saver({"v0": v0})
|
||||
variables.global_variables_initializer().run()
|
||||
|
||||
# Get the most recent checkpoint name from the training history file.
|
||||
name = checkpoint_management.latest_checkpoint(traindir)
|
||||
self.assertIsNotNone(name)
|
||||
|
||||
# Restore "v0" from that checkpoint.
|
||||
save.restore(sess, name)
|
||||
self.assertEqual(v0.eval(), 2.0)
|
||||
|
||||
|
||||
class CheckpointStateTest(test.TestCase):

  def _get_test_dir(self, dirname):
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def testAbsPath(self):
    save_dir = self._get_test_dir("abs_paths")
    abs_path = os.path.join(save_dir, "model-0")
    ckpt = checkpoint_management.generate_checkpoint_state_proto(
        save_dir, abs_path)
    self.assertEqual(ckpt.model_checkpoint_path, abs_path)
    self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)

  def testRelPath(self):
    train_dir = "train"
    model = os.path.join(train_dir, "model-0")
    # model_checkpoint_path should have no "train" directory part.
    new_rel_path = "model-0"
    ckpt = checkpoint_management.generate_checkpoint_state_proto(
        train_dir, model)
    self.assertEqual(ckpt.model_checkpoint_path, new_rel_path)
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)

  def testAllModelCheckpointPaths(self):
    save_dir = self._get_test_dir("all_models_test")
    abs_path = os.path.join(save_dir, "model-0")
    for paths in [None, [], ["model-2"]]:
      ckpt = checkpoint_management.generate_checkpoint_state_proto(
          save_dir, abs_path, all_model_checkpoint_paths=paths)
      self.assertEqual(ckpt.model_checkpoint_path, abs_path)
      self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
      self.assertEqual(
          len(ckpt.all_model_checkpoint_paths), len(paths) if paths else 1)
      self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)

  def testUpdateCheckpointState(self):
    save_dir = self._get_test_dir("update_checkpoint_state")
    os.chdir(save_dir)
    # Make a temporary train directory.
    train_dir = "train"
    os.mkdir(train_dir)
    abs_path = os.path.join(save_dir, "model-0")
    rel_path = os.path.join("train", "model-2")
    checkpoint_management.update_checkpoint_state(
        train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path])
    ckpt = checkpoint_management.get_checkpoint_state(train_dir)
    self.assertEqual(ckpt.model_checkpoint_path, rel_path)
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path)
    self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)

  def testUpdateCheckpointStateSaveRelativePaths(self):
    save_dir = self._get_test_dir("update_checkpoint_state")
    os.chdir(save_dir)
    abs_path2 = os.path.join(save_dir, "model-2")
    rel_path2 = "model-2"
    abs_path0 = os.path.join(save_dir, "model-0")
    rel_path0 = "model-0"
    checkpoint_management.update_checkpoint_state_internal(
        save_dir=save_dir,
        model_checkpoint_path=abs_path2,
        all_model_checkpoint_paths=[rel_path0, abs_path2],
        save_relative_paths=True)

    # File should contain relative paths.
    file_content = file_io.read_file_to_string(
        os.path.join(save_dir, "checkpoint"))
    ckpt = CheckpointState()
    text_format.Merge(file_content, ckpt)
    self.assertEqual(ckpt.model_checkpoint_path, rel_path2)
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0)

    # get_checkpoint_state should return absolute paths.
    ckpt = checkpoint_management.get_checkpoint_state(save_dir)
    self.assertEqual(ckpt.model_checkpoint_path, abs_path2)
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0)

  def testCheckPointStateFailsWhenIncomplete(self):
    save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")
    os.chdir(save_dir)
    ckpt_path = os.path.join(save_dir, "checkpoint")
    ckpt_file = open(ckpt_path, "w")
    ckpt_file.write("")
    ckpt_file.close()
    with self.assertRaises(ValueError):
      checkpoint_management.get_checkpoint_state(save_dir)

  def testCheckPointCompletesRelativePaths(self):
    save_dir = self._get_test_dir("checkpoint_completes_relative_paths")
    os.chdir(save_dir)
    ckpt_path = os.path.join(save_dir, "checkpoint")
    ckpt_file = open(ckpt_path, "w")
    ckpt_file.write("""
        model_checkpoint_path: "./model.ckpt-687529"
        all_model_checkpoint_paths: "./model.ckpt-687500"
        all_model_checkpoint_paths: "./model.ckpt-687529"
    """)
    ckpt_file.close()
    ckpt = checkpoint_management.get_checkpoint_state(save_dir)
    self.assertEqual(ckpt.model_checkpoint_path,
                     os.path.join(save_dir, "./model.ckpt-687529"))
    self.assertEqual(ckpt.all_model_checkpoint_paths[0],
                     os.path.join(save_dir, "./model.ckpt-687500"))
    self.assertEqual(ckpt.all_model_checkpoint_paths[1],
                     os.path.join(save_dir, "./model.ckpt-687529"))


class SaverUtilsTest(test.TestCase):

  def setUp(self):
    self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test")
    gfile.MakeDirs(self._base_dir)

  def tearDown(self):
    gfile.DeleteRecursively(self._base_dir)

  def testCheckpointExists(self):
    for sharded in (False, True):
      for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
        with self.test_session(graph=ops_lib.Graph()) as sess:
          unused_v = variables.Variable(1.0, name="v")
          variables.global_variables_initializer().run()
          saver = saver_module.Saver(sharded=sharded, write_version=version)

          path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
          self.assertFalse(
              checkpoint_management.checkpoint_exists(path))  # Not saved yet.

          ckpt_prefix = saver.save(sess, path)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))

          ckpt_prefix = checkpoint_management.latest_checkpoint(self._base_dir)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))

  def testGetCheckpointMtimes(self):
    prefixes = []
    for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
      with self.test_session(graph=ops_lib.Graph()) as sess:
        unused_v = variables.Variable(1.0, name="v")
        variables.global_variables_initializer().run()
        saver = saver_module.Saver(write_version=version)
        prefixes.append(
            saver.save(sess, os.path.join(self._base_dir, str(version))))

    mtimes = checkpoint_management.get_checkpoint_mtimes(prefixes)
    self.assertEqual(2, len(mtimes))
    self.assertTrue(mtimes[1] >= mtimes[0])

  def testRemoveCheckpoint(self):
    for sharded in (False, True):
      for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
        with self.test_session(graph=ops_lib.Graph()) as sess:
          unused_v = variables.Variable(1.0, name="v")
          variables.global_variables_initializer().run()
          saver = saver_module.Saver(sharded=sharded, write_version=version)

          path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
          ckpt_prefix = saver.save(sess, path)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
          checkpoint_management.remove_checkpoint(ckpt_prefix, version)
          self.assertFalse(checkpoint_management.checkpoint_exists(ckpt_prefix))


if __name__ == "__main__":
  test.main()

@ -28,6 +28,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import tf_export
@ -277,7 +278,7 @@ def _init_from_checkpoint(_, ckpt_dir_or_file, assignment_map):
def _get_checkpoint_filename(ckpt_dir_or_file):
  """Returns checkpoint filename given directory or specific checkpoint file."""
  if gfile.IsDirectory(ckpt_dir_or_file):
    return saver.latest_checkpoint(ckpt_dir_or_file)
    return checkpoint_management.latest_checkpoint(ckpt_dir_or_file)
  return ckpt_dir_or_file


@ -124,14 +124,18 @@ py_test(
    ],
    deps = [
        ":base",
        ":tracking",
        ":util",
        "//tensorflow/python:checkpoint_management",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:pywrap_tensorflow",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:saver",
        "//tensorflow/python:session",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:template",

@ -42,6 +42,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import adam
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base
@ -467,7 +468,8 @@ class CheckpointingTests(test.TestCase):
    root = checkpointable_utils.Checkpoint(
        optimizer=optimizer, model=model,
        optimizer_step=training_util.get_or_create_global_step())
    root.restore(saver_lib.latest_checkpoint(checkpoint_directory))
    root.restore(checkpoint_management.latest_checkpoint(
        checkpoint_directory))
    for _ in range(num_training_steps):
      # TODO(allenl): Use a Dataset and serialize/checkpoint it.
      input_value = constant_op.constant([[3.]])
@ -495,7 +497,8 @@ class CheckpointingTests(test.TestCase):
      train_op = optimizer.minimize(
          model(input_value),
          global_step=root.global_step)
      checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory)
      checkpoint_path = checkpoint_management.latest_checkpoint(
          checkpoint_directory)
      with self.test_session(graph=ops.get_default_graph()) as session:
        status = root.restore(save_path=checkpoint_path)
        status.initialize_or_restore(session=session)
@ -528,7 +531,8 @@ class CheckpointingTests(test.TestCase):
    root = checkpointable_utils.Checkpoint(
        optimizer=optimizer, model=model,
        global_step=training_util.get_or_create_global_step())
    checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory)
    checkpoint_path = checkpoint_management.latest_checkpoint(
        checkpoint_directory)
    status = root.restore(save_path=checkpoint_path)
    input_value = constant_op.constant([[3.]])
    train_fn = functools.partial(
@ -561,7 +565,8 @@ class CheckpointingTests(test.TestCase):
    root = checkpointable_utils.Checkpoint(
        optimizer=optimizer, model=model,
        global_step=training_util.get_or_create_global_step())
    checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory)
    checkpoint_path = checkpoint_management.latest_checkpoint(
        checkpoint_directory)
    status = root.restore(save_path=checkpoint_path)
    def train_fn():
      @function.defun
@ -1180,7 +1185,8 @@ class CheckpointingTests(test.TestCase):
    optimizer_checkpoint = checkpointable_utils.Checkpoint(
        optimizer=optimizer)

    checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory)
    checkpoint_path = checkpoint_management.latest_checkpoint(
        checkpoint_directory)
    status = root.restore(save_path=checkpoint_path)
    input_value = constant_op.constant([[3.]])
    train_fn = functools.partial(

@ -44,6 +44,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib
@ -1364,8 +1365,8 @@ class MonitoredSessionTest(test.TestCase):
      with monitored_session.MonitoredSession(
          session_creator=monitored_session.ChiefSessionCreator(
              scaffold,
              checkpoint_filename_with_path=saver_lib.latest_checkpoint(
                  logdir))) as session:
              checkpoint_filename_with_path=
              checkpoint_management.latest_checkpoint(logdir))) as session:
        self.assertEqual(2, session.run(gstep))

  def test_retry_initialization_on_aborted_error(self):

@ -21,15 +21,12 @@ from __future__ import print_function

import collections
import os.path
import re
import time
import uuid

import numpy as np
import six

from google.protobuf import text_format

from tensorflow.core.protobuf import checkpointable_object_graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
@ -41,7 +38,6 @@ from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops
@ -52,14 +48,19 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saveable_object
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


# TODO(allenl): Remove these aliases once all users are migrated off.
get_checkpoint_state = checkpoint_management.get_checkpoint_state
update_checkpoint_state = checkpoint_management.update_checkpoint_state


# Op names which identify variable reads which should be saved.
_VARIABLE_OPS = set(["Variable",
                     "VariableV2",

@ -858,218 +859,6 @@ def _get_saver_or_default():
  return saver


def _GetCheckpointFilename(save_dir, latest_filename):
  """Returns a filename for storing the CheckpointState.

  Args:
    save_dir: The directory for saving and restoring checkpoints.
    latest_filename: Name of the file in 'save_dir' that is used
      to store the CheckpointState.

  Returns:
    The path of the file that contains the CheckpointState proto.
  """
  if latest_filename is None:
    latest_filename = "checkpoint"
  return os.path.join(save_dir, latest_filename)


@tf_export("train.generate_checkpoint_state_proto")
|
||||
def generate_checkpoint_state_proto(save_dir,
|
||||
model_checkpoint_path,
|
||||
all_model_checkpoint_paths=None):
|
||||
"""Generates a checkpoint state proto.
|
||||
|
||||
Args:
|
||||
save_dir: Directory where the model was saved.
|
||||
model_checkpoint_path: The checkpoint file.
|
||||
all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
|
||||
checkpoints, sorted from oldest to newest. If this is a non-empty list,
|
||||
the last element must be equal to model_checkpoint_path. These paths
|
||||
are also saved in the CheckpointState proto.
|
||||
|
||||
Returns:
|
||||
CheckpointState proto with model_checkpoint_path and
|
||||
all_model_checkpoint_paths updated to either absolute paths or
|
||||
relative paths to the current save_dir.
|
||||
"""
|
||||
if all_model_checkpoint_paths is None:
|
||||
all_model_checkpoint_paths = []
|
||||
|
||||
if (not all_model_checkpoint_paths or
|
||||
all_model_checkpoint_paths[-1] != model_checkpoint_path):
|
||||
logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
|
||||
model_checkpoint_path)
|
||||
all_model_checkpoint_paths.append(model_checkpoint_path)
|
||||
|
||||
# Relative paths need to be rewritten to be relative to the "save_dir"
|
||||
# if model_checkpoint_path already contains "save_dir".
|
||||
if not os.path.isabs(save_dir):
|
||||
if not os.path.isabs(model_checkpoint_path):
|
||||
model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
|
||||
for i in range(len(all_model_checkpoint_paths)):
|
||||
p = all_model_checkpoint_paths[i]
|
||||
if not os.path.isabs(p):
|
||||
all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)
|
||||
|
||||
coord_checkpoint_proto = CheckpointState(
|
||||
model_checkpoint_path=model_checkpoint_path,
|
||||
all_model_checkpoint_paths=all_model_checkpoint_paths)
|
||||
|
||||
return coord_checkpoint_proto
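
# A minimal usage sketch of the proto generation above (the directory and
# checkpoint names here are hypothetical, not taken from this diff). With an
# absolute save_dir, paths are stored as given:
#
#   ckpt = generate_checkpoint_state_proto(
#       "/tmp/train", "/tmp/train/model.ckpt-100",
#       all_model_checkpoint_paths=["/tmp/train/model.ckpt-50",
#                                   "/tmp/train/model.ckpt-100"])
#   ckpt.model_checkpoint_path        # -> "/tmp/train/model.ckpt-100"
#   ckpt.all_model_checkpoint_paths   # oldest to newest; last == current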


@tf_export("train.update_checkpoint_state")
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest. If this is a non-empty list,
      the last element must be equal to model_checkpoint_path. These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file. Defaults to
      'checkpoint'.

  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointState.
  """
  _update_checkpoint_state(
      save_dir=save_dir,
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths,
      latest_filename=latest_filename,
      save_relative_paths=False)
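
# A short sketch of how the public wrapper is typically called (hypothetical
# paths); the state lands in save_dir/'checkpoint' unless latest_filename
# overrides the name:
#
#   update_checkpoint_state(
#       "/tmp/train", "/tmp/train/model.ckpt-100",
#       all_model_checkpoint_paths=["/tmp/train/model.ckpt-50",
#                                   "/tmp/train/model.ckpt-100"])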


def _update_checkpoint_state(save_dir,
                             model_checkpoint_path,
                             all_model_checkpoint_paths=None,
                             latest_filename=None,
                             save_relative_paths=False):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest. If this is a non-empty list,
      the last element must be equal to model_checkpoint_path. These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file. Defaults to
      'checkpoint'.
    save_relative_paths: If `True`, will write relative paths to the checkpoint
      state file.

  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointState.
  """
  # Writes the "checkpoint" file for the coordinator for later restoration.
  coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
  if save_relative_paths:
    if os.path.isabs(model_checkpoint_path):
      rel_model_checkpoint_path = os.path.relpath(
          model_checkpoint_path, save_dir)
    else:
      rel_model_checkpoint_path = model_checkpoint_path
    rel_all_model_checkpoint_paths = []
    for p in all_model_checkpoint_paths:
      if os.path.isabs(p):
        rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
      else:
        rel_all_model_checkpoint_paths.append(p)
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        rel_model_checkpoint_path,
        all_model_checkpoint_paths=rel_all_model_checkpoint_paths)
  else:
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths)

  if coord_checkpoint_filename == ckpt.model_checkpoint_path:
    raise RuntimeError("Save path '%s' conflicts with path used for "
                       "checkpoint state. Please use a different save path." %
                       model_checkpoint_path)

  # Preventing potential read/write race condition by *atomically* writing to a
  # file.
  file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                      text_format.MessageToString(ckpt))
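
# Sketch of the save_relative_paths behavior above (hypothetical paths): an
# absolute prefix under save_dir is stored relative to it, so the state file
# stays valid if the whole directory is later moved:
#
#   _update_checkpoint_state(
#       save_dir="/tmp/train",
#       model_checkpoint_path="/tmp/train/model.ckpt-100",
#       all_model_checkpoint_paths=["/tmp/train/model.ckpt-100"],
#       save_relative_paths=True)
#   # /tmp/train/checkpoint now contains:
#   #   model_checkpoint_path: "model.ckpt-100"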


@tf_export("train.get_checkpoint_state")
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file. Defaults to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.

  Raises:
    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)
  f = None
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if file_io.file_exists(coord_checkpoint_filename):
      file_content = file_io.read_file_to_string(
          coord_checkpoint_filename)
      ckpt = CheckpointState()
      text_format.Merge(file_content, ckpt)
      if not ckpt.model_checkpoint_path:
        raise ValueError("Invalid checkpoint state loaded from "
                         + checkpoint_dir)
      # For relative model_checkpoint_path and all_model_checkpoint_paths,
      # prepend checkpoint_dir.
      if not os.path.isabs(ckpt.model_checkpoint_path):
        ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
                                                  ckpt.model_checkpoint_path)
      for i in range(len(ckpt.all_model_checkpoint_paths)):
        p = ckpt.all_model_checkpoint_paths[i]
        if not os.path.isabs(p):
          ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
  except errors.OpError as e:
    # It's ok if the file cannot be read
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  except text_format.ParseError as e:
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  finally:
    if f:
      f.close()
  return ckpt
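
# Reading the state back is symmetric (hypothetical directory); relative
# entries come back joined onto checkpoint_dir, as the loop above shows:
#
#   ckpt = get_checkpoint_state("/tmp/train")
#   if ckpt:
#     print(ckpt.model_checkpoint_path)             # newest checkpoint prefix
#     print(list(ckpt.all_model_checkpoint_paths))  # oldest to newest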


@tf_export("train.Saver")
class Saver(object):
  """Saves and restores variables.
@ -1412,7 +1201,7 @@ class Saver(object):

      # Otherwise delete the files.
      try:
        remove_checkpoint(
        checkpoint_management.remove_checkpoint(
            self._CheckpointFilename(p), self.saver_def.version,
            meta_graph_suffix)
      except Exception as e:  # pylint: disable=broad-except
@ -1518,7 +1307,7 @@ class Saver(object):
    Args:
      checkpoint_paths: a list of checkpoint paths.
    """
    mtimes = get_checkpoint_mtimes(checkpoint_paths)
    mtimes = checkpoint_management.get_checkpoint_mtimes(checkpoint_paths)
    self.set_last_checkpoints_with_time(list(zip(checkpoint_paths, mtimes)))

  def save(self,
@ -1624,7 +1413,7 @@ class Saver(object):
        model_checkpoint_path = compat.as_str(model_checkpoint_path)
        if write_state:
          self._RecordLastCheckpoint(model_checkpoint_path)
          _update_checkpoint_state(
          checkpoint_management.update_checkpoint_state_internal(
              save_dir=save_path_parent,
              model_checkpoint_path=model_checkpoint_path,
              all_model_checkpoint_paths=self.last_checkpoints,
@ -1639,7 +1428,7 @@ class Saver(object):
        raise exc

    if write_meta_graph:
      meta_graph_filename = _meta_graph_filename(
      meta_graph_filename = checkpoint_management.meta_graph_filename(
          checkpoint_file, meta_graph_suffix=meta_graph_suffix)
      if not context.executing_eagerly():
        with sess.graph.as_default():
@ -1714,7 +1503,7 @@ class Saver(object):
    if save_path is None:
      raise ValueError("Can't load save_path when it is None.")

    if not checkpoint_exists(compat.as_text(save_path)):
    if not checkpoint_management.checkpoint_exists(compat.as_text(save_path)):
      raise ValueError("The passed save_path is not a valid checkpoint: "
                       + compat.as_text(save_path))

@ -1800,55 +1589,6 @@ class Saver(object):
        export_scope=export_scope)


def _prefix_to_checkpoint_path(prefix, format_version):
  """Returns the pathname of a checkpoint file, given the checkpoint prefix.

  For V1 checkpoint, simply returns the prefix itself (the data file). For V2,
  returns the pathname to the index file.

  Args:
    prefix: a string, the prefix of a checkpoint.
    format_version: the checkpoint format version that corresponds to the
      prefix.
  Returns:
    The pathname of a checkpoint file, taking into account the checkpoint
      format version.
  """
  if format_version == saver_pb2.SaverDef.V2:
    return prefix + ".index"  # The index file identifies a checkpoint.
  return prefix  # Just the data file.
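
# For example (V2 is the default format; the prefix is hypothetical):
#
#   _prefix_to_checkpoint_path("model.ckpt-42", saver_pb2.SaverDef.V2)
#   # -> "model.ckpt-42.index"
#   _prefix_to_checkpoint_path("model.ckpt-42", saver_pb2.SaverDef.V1)
#   # -> "model.ckpt-42"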


@tf_export("train.latest_checkpoint")
def latest_checkpoint(checkpoint_dir, latest_filename=None):
  """Finds the filename of latest saved checkpoint file.

  Args:
    checkpoint_dir: Directory where the variables were saved.
    latest_filename: Optional name for the protocol buffer file that
      contains the list of most recent checkpoint filenames.
      See the corresponding argument to `Saver.save()`.

  Returns:
    The full path to the latest checkpoint or `None` if no checkpoint was found.
  """
  # Pick the latest checkpoint based on checkpoint state.
  ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
  if ckpt and ckpt.model_checkpoint_path:
    # Look for either a V2 path or a V1 path, with priority for V2.
    v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
                                         saver_pb2.SaverDef.V2)
    v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
                                         saver_pb2.SaverDef.V1)
    if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
        v1_path):
      return ckpt.model_checkpoint_path
    else:
      logging.error("Couldn't match files for checkpoint %s",
                    ckpt.model_checkpoint_path)
  return None
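
# Typical call (hypothetical directory; `saver` and `sess` are assumed to be
# an existing tf.train.Saver and an active session). The result is a
# checkpoint *prefix* suitable for `Saver.restore()`, not a concrete file:
#
#   ckpt_path = latest_checkpoint("/tmp/train")
#   if ckpt_path:
#     saver.restore(sess, ckpt_path)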


@tf_export("train.import_meta_graph")
def import_meta_graph(meta_graph_or_file, clear_devices=False,
                      import_scope=None, **kwargs):
@ -2056,119 +1796,6 @@ def export_meta_graph(filename=None,
  return meta_graph_def


@tf_export("train.checkpoint_exists")
def checkpoint_exists(checkpoint_prefix):
  """Checks whether a V1 or V2 checkpoint exists with the specified prefix.

  This is the recommended way to check if a checkpoint exists, since it takes
  into account the naming difference between V1 and V2 formats.

  Args:
    checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
      priority. Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
  Returns:
    A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists.
  """
  pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
                                        saver_pb2.SaverDef.V2)
  if file_io.get_matching_files(pathname):
    return True
  elif file_io.get_matching_files(checkpoint_prefix):
    return True
  else:
    return False
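
# A quick sketch of why the prefix matters (hypothetical prefix): a V2
# checkpoint "model.ckpt-42" exists on disk as "model.ckpt-42.index" plus data
# shards, so a plain file-existence test on the prefix alone would miss it:
#
#   checkpoint_exists("/tmp/train/model.ckpt-42")  # checks .index, then V1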


@tf_export("train.get_checkpoint_mtimes")
def get_checkpoint_mtimes(checkpoint_prefixes):
  """Returns the mtimes (modification timestamps) of the checkpoints.

  Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files
  exist, collect their mtime. Both V2 and V1 checkpoints are considered, in
  that priority.

  This is the recommended way to get the mtimes, since it takes into account
  the naming difference between V1 and V2 formats.

  Args:
    checkpoint_prefixes: a list of checkpoint paths, typically the results of
      `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
      sharded/non-sharded or V1/V2.
  Returns:
    A list of mtimes (in seconds since the epoch) of the found checkpoints.
  """
  mtimes = []

  def match_maybe_append(pathname):
    fnames = file_io.get_matching_files(pathname)
    if fnames:
      mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9)
      return True
    return False

  for checkpoint_prefix in checkpoint_prefixes:
    # Tries V2's metadata file first.
    pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
                                          saver_pb2.SaverDef.V2)
    if match_maybe_append(pathname):
      continue
    # Otherwise, tries V1, where the prefix is the complete pathname.
    match_maybe_append(checkpoint_prefix)

  return mtimes
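
# Sketch (hypothetical prefixes); output order follows the input list, and
# values are seconds, per the mtime_nsec / 1e9 conversion above:
#
#   mtimes = get_checkpoint_mtimes(["/tmp/train/model.ckpt-50",
#                                   "/tmp/train/model.ckpt-100"])
#   # e.g. [1532012345.1, 1532012399.7]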


@tf_export("train.remove_checkpoint")
def remove_checkpoint(checkpoint_prefix,
                      checkpoint_format_version=saver_pb2.SaverDef.V2,
                      meta_graph_suffix="meta"):
  """Removes a checkpoint given by `checkpoint_prefix`.

  Args:
    checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the
      result of `Saver.save()` or that of `tf.train.latest_checkpoint()`,
      regardless of sharded/non-sharded or V1/V2.
    checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to
      `SaverDef.V2`.
    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
  """
  _delete_file_if_exists(
      _meta_graph_filename(checkpoint_prefix, meta_graph_suffix))
  if checkpoint_format_version == saver_pb2.SaverDef.V2:
    # V2 has a metadata file and some data files.
    _delete_file_if_exists(checkpoint_prefix + ".index")
    _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????")
  else:
    # V1, Legacy. Exact match on the data file.
    _delete_file_if_exists(checkpoint_prefix)
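
# Deleting a V2 checkpoint (hypothetical prefix) removes the index file, the
# sharded data files, and the exported meta graph in one call:
#
#   remove_checkpoint("/tmp/train/model.ckpt-50")
#   # drops model.ckpt-50.index, model.ckpt-50.data-?????-of-?????,
#   # and model.ckpt-50.meta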


def _delete_file_if_exists(filespec):
  """Deletes files matching `filespec`."""
  for pathname in file_io.get_matching_files(filespec):
    file_io.delete_file(pathname)


def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
  """Returns the meta graph filename.

  Args:
    checkpoint_filename: Name of the checkpoint file.
    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.

  Returns:
    MetaGraph file name.
  """
  # If the checkpoint_filename is sharded, the checkpoint_filename could
  # be of format model.ckpt-step#-?????-of-shard#. For example,
  # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
  basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
  meta_graph_filename = ".".join([basename, meta_graph_suffix])
  return meta_graph_filename
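
# The regex above only strips a trailing shard suffix, so both sharded and
# non-sharded names map to the same meta graph file (hypothetical names):
#
#   _meta_graph_filename("model.ckpt-123456-?????-of-00005")
#   # -> "model.ckpt-123456.meta"
#   _meta_graph_filename("model.ckpt-123456")
#   # -> "model.ckpt-123456.meta"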


def _wrap_restore_error_with_msg(err, extra_verbiage):
  err_msg = ("Restoring from checkpoint failed. This is most likely "
             "due to {} from the checkpoint. Please ensure that you "

@ -18,20 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import functools
import math
import os
import random
import shutil
import tempfile
import time

import numpy as np
import six

from google.protobuf.any_pb2 import Any
from google.protobuf import text_format

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
@ -71,12 +67,12 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import adam
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training import saver_test_utils
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.training.checkpointable import base as checkpointable_base
from tensorflow.python.training.checkpointable import tracking as checkpointable_tracking
from tensorflow.python.training.checkpointable import util as checkpointable_utils
@ -343,11 +339,13 @@ class SaverTest(test.TestCase):
      self.assertTrue(isinstance(val, six.string_types))
      self.assertEqual(save_path1, val)

      self.assertEqual(saver_module.latest_checkpoint(save_dir1), save_path1)
      self.assertEqual(
          checkpoint_management.latest_checkpoint(save_dir1), save_path1)
      save_dir2 = os.path.join(self.get_temp_dir(), "save_dir2")
      os.renames(save_dir1, save_dir2)
      save_path2 = os.path.join(save_dir2, "save_copy_restore")
      self.assertEqual(saver_module.latest_checkpoint(save_dir2), save_path2)
      self.assertEqual(
          checkpoint_management.latest_checkpoint(save_dir2), save_path2)

      # Start a second session. In that session the parameter nodes
      # have not been initialized either.
@ -857,7 +855,7 @@ class SaveRestoreShardedTest(test.TestCase):
        self.assertEqual(save_path + "-?????-of-00002", val)
      else:
        self.assertEqual(save_path, val)
      meta_graph_filename = saver_module._meta_graph_filename(val)
      meta_graph_filename = checkpoint_management.meta_graph_filename(val)
      self.assertEqual(save_path + ".meta", meta_graph_filename)

    if save._write_version is saver_pb2.SaverDef.V1:
@ -951,11 +949,11 @@ class SaveRestoreShardedTest(test.TestCase):

    if save._write_version is saver_pb2.SaverDef.V1:
      self.assertEqual(
          saver_module.latest_checkpoint(self.get_temp_dir()),
          checkpoint_management.latest_checkpoint(self.get_temp_dir()),
          os.path.join(self.get_temp_dir(), "sharded_basics-?????-of-00002"))
    else:
      self.assertEqual(
          saver_module.latest_checkpoint(self.get_temp_dir()),
          checkpoint_management.latest_checkpoint(self.get_temp_dir()),
          os.path.join(self.get_temp_dir(), "sharded_basics"))

  def testSaverDef(self):
@ -1105,7 +1103,7 @@ class MaxToKeepTest(test.TestCase):

  def assertCheckpointState(self, model_checkpoint_path,
                            all_model_checkpoint_paths, save_dir):
    checkpoint_state = saver_module.get_checkpoint_state(save_dir)
    checkpoint_state = checkpoint_management.get_checkpoint_state(save_dir)
    self.assertEqual(checkpoint_state.model_checkpoint_path,
                     model_checkpoint_path)
    self.assertEqual(checkpoint_state.all_model_checkpoint_paths,
@ -1113,7 +1111,7 @@ class MaxToKeepTest(test.TestCase):

  def testMaxToKeepEager(self):
    with context.eager_mode():
      save_dir = self._get_test_dir("max_to_keep_non_sharded")
      save_dir = self._get_test_dir("max_to_keep_eager")

      v = variable_scope.variable(10.0, name="v")
      save = saver_module.Saver({"v": v}, max_to_keep=2)
@ -1123,7 +1121,7 @@ class MaxToKeepTest(test.TestCase):

      s1 = save.save(None, os.path.join(save_dir, "s1"))
      self.assertEqual([s1], save.last_checkpoints)
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s1],
@ -1131,8 +1129,8 @@ class MaxToKeepTest(test.TestCase):

      s2 = save.save(None, os.path.join(save_dir, "s2"))
      self.assertEqual([s1, s2], save.last_checkpoints)
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertTrue(saver_module.checkpoint_exists(s2))
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s1, s2],
@ -1140,9 +1138,9 @@ class MaxToKeepTest(test.TestCase):

      s3 = save.save(None, os.path.join(save_dir, "s3"))
      self.assertEqual([s2, s3], save.last_checkpoints)
      self.assertFalse(saver_module.checkpoint_exists(s1))
      self.assertTrue(saver_module.checkpoint_exists(s2))
      self.assertTrue(saver_module.checkpoint_exists(s3))
      self.assertFalse(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(checkpoint_management.checkpoint_exists(s3))
      self.assertCheckpointState(
          model_checkpoint_path=s3,
          all_model_checkpoint_paths=[s2, s3],
@ -1157,9 +1155,9 @@ class MaxToKeepTest(test.TestCase):
      # Adding s2 again (old s2 is removed first, then new s2 appended)
      s2 = save.save(None, os.path.join(save_dir, "s2"))
      self.assertEqual([s3, s2], save.last_checkpoints)
      self.assertFalse(saver_module.checkpoint_exists(s1))
      self.assertTrue(saver_module.checkpoint_exists(s3))
      self.assertTrue(saver_module.checkpoint_exists(s2))
      self.assertFalse(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s3))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s3, s2],
@ -1168,8 +1166,8 @@ class MaxToKeepTest(test.TestCase):
      # Adding s1 (s3 should now be deleted as oldest in list)
      s1 = save.save(None, os.path.join(save_dir, "s1"))
      self.assertEqual([s2, s1], save.last_checkpoints)
      self.assertFalse(saver_module.checkpoint_exists(s3))
      self.assertTrue(saver_module.checkpoint_exists(s2))
      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s2, s1],
@ -1178,9 +1176,9 @@ class MaxToKeepTest(test.TestCase):
      s2 = save2.save(None, os.path.join(save_dir, "s2"))
      self.assertEqual([s3, s2], save2.last_checkpoints)
      # Created by the first helper.
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      # Deleted by the first helper.
      self.assertFalse(saver_module.checkpoint_exists(s3))
      self.assertFalse(checkpoint_management.checkpoint_exists(s3))

  def testNonSharded(self):
    save_dir = self._get_test_dir("max_to_keep_non_sharded")
@ -1193,7 +1191,7 @@ class MaxToKeepTest(test.TestCase):

      s1 = save.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s1], save.last_checkpoints)
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s1],
@ -1201,8 +1199,8 @@ class MaxToKeepTest(test.TestCase):

      s2 = save.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s1, s2], save.last_checkpoints)
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertTrue(saver_module.checkpoint_exists(s2))
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s1, s2],
@ -1210,9 +1208,9 @@ class MaxToKeepTest(test.TestCase):

      s3 = save.save(sess, os.path.join(save_dir, "s3"))
      self.assertEqual([s2, s3], save.last_checkpoints)
      self.assertFalse(saver_module.checkpoint_exists(s1))
      self.assertTrue(saver_module.checkpoint_exists(s2))
      self.assertTrue(saver_module.checkpoint_exists(s3))
      self.assertFalse(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(checkpoint_management.checkpoint_exists(s3))
      self.assertCheckpointState(
          model_checkpoint_path=s3,
          all_model_checkpoint_paths=[s2, s3],
@ -1231,15 +1229,18 @@ class MaxToKeepTest(test.TestCase):
      # Adding s2 again (old s2 is removed first, then new s2 appended)
      s2 = save.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s3, s2], save.last_checkpoints)
      self.assertFalse(saver_module.checkpoint_exists(s1))
      self.assertFalse(checkpoint_management.checkpoint_exists(s1))
      self.assertFalse(
          saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
      self.assertTrue(saver_module.checkpoint_exists(s3))
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s1)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s3))
      self.assertTrue(
          saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
      self.assertTrue(saver_module.checkpoint_exists(s2))
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s3)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(
          saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s2)))
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s3, s2],
@ -1248,15 +1249,18 @@ class MaxToKeepTest(test.TestCase):
      # Adding s1 (s3 should now be deleted as oldest in list)
      s1 = save.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s2, s1], save.last_checkpoints)
      self.assertFalse(saver_module.checkpoint_exists(s3))
      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
      self.assertFalse(
          saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
      self.assertTrue(saver_module.checkpoint_exists(s2))
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s3)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(
          saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
      self.assertTrue(saver_module.checkpoint_exists(s1))
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s2)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(
          saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s1)))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s2, s1],
@ -1268,16 +1272,19 @@ class MaxToKeepTest(test.TestCase):
      s2 = save2.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s3, s2], save2.last_checkpoints)
      # Created by the first helper.
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(
          saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s1)))
      # Deleted by the first helper.
      self.assertFalse(saver_module.checkpoint_exists(s3))
      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
      self.assertFalse(
          saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
      self.assertTrue(saver_module.checkpoint_exists(s2))
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s3)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(
          saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s2)))
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s3, s2],
@ -1286,15 +1293,18 @@ class MaxToKeepTest(test.TestCase):
      # Adding s1 (s3 should now be deleted as oldest in list)
      s1 = save2.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s2, s1], save2.last_checkpoints)
      self.assertFalse(saver_module.checkpoint_exists(s3))
      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
      self.assertFalse(
          saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
      self.assertTrue(saver_module.checkpoint_exists(s2))
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s3)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(
          saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
      self.assertTrue(saver_module.checkpoint_exists(s1))
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s2)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(
          saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s1)))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s2, s1],
@ -1306,16 +1316,19 @@ class MaxToKeepTest(test.TestCase):
      s2 = save3.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s2], save3.last_checkpoints)
      # Created by the first helper.
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(
          saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s1)))
      # Deleted by the first helper.
      self.assertFalse(saver_module.checkpoint_exists(s3))
      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
      self.assertFalse(
          saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
      self.assertTrue(saver_module.checkpoint_exists(s2))
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s3)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(
          saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s2)))
      # Even though the file for s1 exists, this saver isn't aware of it, which
      # is why it doesn't end up in the checkpoint state.
      self.assertCheckpointState(
@ -1326,15 +1339,18 @@ class MaxToKeepTest(test.TestCase):
      # Adding s1 (s3 should not be deleted because helper is unaware of it)
      s1 = save3.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s2, s1], save3.last_checkpoints)
      self.assertFalse(saver_module.checkpoint_exists(s3))
      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
      self.assertFalse(
          saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
      self.assertTrue(saver_module.checkpoint_exists(s2))
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s3)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(
          saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
      self.assertTrue(saver_module.checkpoint_exists(s1))
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s2)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(
          saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s1)))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s2, s1],
@ -1365,7 +1381,8 @@ class MaxToKeepTest(test.TestCase):
      else:
        self.assertEqual(4, len(gfile.Glob(s1 + "*")))

      self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1)))
      self.assertTrue(
          gfile.Exists(checkpoint_management.meta_graph_filename(s1)))

      s2 = save.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s1, s2], save.last_checkpoints)
@ -1373,27 +1390,32 @@ class MaxToKeepTest(test.TestCase):
        self.assertEqual(2, len(gfile.Glob(s1)))
      else:
        self.assertEqual(4, len(gfile.Glob(s1 + "*")))
      self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1)))
      self.assertTrue(
          gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
      if save._write_version is saver_pb2.SaverDef.V1:
        self.assertEqual(2, len(gfile.Glob(s2)))
      else:
        self.assertEqual(4, len(gfile.Glob(s2 + "*")))
      self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2)))
      self.assertTrue(
          gfile.Exists(checkpoint_management.meta_graph_filename(s2)))

      s3 = save.save(sess, os.path.join(save_dir, "s3"))
      self.assertEqual([s2, s3], save.last_checkpoints)
      self.assertEqual(0, len(gfile.Glob(s1 + "*")))
      self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1)))
      self.assertFalse(
          gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
      if save._write_version is saver_pb2.SaverDef.V1:
        self.assertEqual(2, len(gfile.Glob(s2)))
      else:
        self.assertEqual(4, len(gfile.Glob(s2 + "*")))
      self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2)))
      self.assertTrue(
          gfile.Exists(checkpoint_management.meta_graph_filename(s2)))
      if save._write_version is saver_pb2.SaverDef.V1:
        self.assertEqual(2, len(gfile.Glob(s3)))
      else:
        self.assertEqual(4, len(gfile.Glob(s3 + "*")))
      self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s3)))
      self.assertTrue(
          gfile.Exists(checkpoint_management.meta_graph_filename(s3)))

  def testNoMaxToKeep(self):
    save_dir = self._get_test_dir("no_max_to_keep")
@ -1408,20 +1430,20 @@ class MaxToKeepTest(test.TestCase):
      self.assertEqual([], save.last_checkpoints)
      s1 = save.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([], save.last_checkpoints)
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      s2 = save.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([], save.last_checkpoints)
      self.assertTrue(saver_module.checkpoint_exists(s2))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))

      # Test max_to_keep being 0.
      save2 = saver_module.Saver({"v": v}, max_to_keep=0)
      self.assertEqual([], save2.last_checkpoints)
      s1 = save2.save(sess, os.path.join(save_dir2, "s1"))
      self.assertEqual([], save2.last_checkpoints)
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      s2 = save2.save(sess, os.path.join(save_dir2, "s2"))
      self.assertEqual([], save2.last_checkpoints)
      self.assertTrue(saver_module.checkpoint_exists(s2))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))

  def testNoMetaGraph(self):
    save_dir = self._get_test_dir("no_meta_graph")
@ -1432,8 +1454,9 @@ class MaxToKeepTest(test.TestCase):
      variables.global_variables_initializer().run()

      s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False)
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertFalse(
          gfile.Exists(checkpoint_management.meta_graph_filename(s1)))


class KeepCheckpointEveryNHoursTest(test.TestCase):
@ -1489,10 +1512,10 @@ class KeepCheckpointEveryNHoursTest(test.TestCase):
      self.assertEqual([s3, s4], save.last_checkpoints)

      # Check that s1 is still here, but s2 is gone.
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertFalse(saver_module.checkpoint_exists(s2))
      self.assertTrue(saver_module.checkpoint_exists(s3))
      self.assertTrue(saver_module.checkpoint_exists(s4))
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertFalse(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(checkpoint_management.checkpoint_exists(s3))
      self.assertTrue(checkpoint_management.checkpoint_exists(s4))


class SaveRestoreWithVariableNameMap(test.TestCase):
@ -1571,221 +1594,6 @@ class SaveRestoreWithVariableNameMap(test.TestCase):
    self._testNonReshape(variables.Variable)


class LatestCheckpointWithRelativePaths(test.TestCase):

  @staticmethod
  @contextlib.contextmanager
  def tempWorkingDir(temppath):
    cwd = os.getcwd()
    os.chdir(temppath)
    try:
      yield
    finally:
      os.chdir(cwd)

  @staticmethod
  @contextlib.contextmanager
  def tempDir():
    tempdir = tempfile.mkdtemp()
    try:
      yield tempdir
    finally:
      shutil.rmtree(tempdir)

  def testNameCollision(self):
    # Make sure we have a clean directory to work in.
    with self.tempDir() as tempdir:
      # Jump to that directory until this test is done.
      with self.tempWorkingDir(tempdir):
        # Save training snapshots to a relative path.
        traindir = "train/"
        os.mkdir(traindir)
        # Collides with the default name of the checkpoint state file.
        filepath = os.path.join(traindir, "checkpoint")

        with self.test_session() as sess:
          unused_a = variables.Variable(0.0)  # So that Saver saves something.
          variables.global_variables_initializer().run()

          # Should fail.
          saver = saver_module.Saver(sharded=False)
          with self.assertRaisesRegexp(ValueError, "collides with"):
            saver.save(sess, filepath)

          # Succeeds: the file will be named "checkpoint-<step>".
          saver.save(sess, filepath, global_step=1)
          self.assertIsNotNone(saver_module.latest_checkpoint(traindir))

          # Succeeds: the file will be named "checkpoint-<i>-of-<n>".
          saver = saver_module.Saver(sharded=True)
          saver.save(sess, filepath)
          self.assertIsNotNone(saver_module.latest_checkpoint(traindir))

          # Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>".
          saver = saver_module.Saver(sharded=True)
          saver.save(sess, filepath, global_step=1)
          self.assertIsNotNone(saver_module.latest_checkpoint(traindir))

def testRelativePath(self):
|
||||
# Make sure we have a clean directory to work in.
|
||||
with self.tempDir() as tempdir:
|
||||
|
||||
# Jump to that directory until this test is done.
|
||||
with self.tempWorkingDir(tempdir):
|
||||
|
||||
# Save training snapshots to a relative path.
|
||||
traindir = "train/"
|
||||
os.mkdir(traindir)
|
||||
|
||||
filename = "snapshot"
|
||||
filepath = os.path.join(traindir, filename)
|
||||
|
||||
with self.test_session() as sess:
|
||||
# Build a simple graph.
|
||||
v0 = variables.Variable(0.0)
|
||||
inc = v0.assign_add(1.0)
|
||||
|
||||
save = saver_module.Saver({"v0": v0})
|
||||
|
||||
# Record a short training history.
|
||||
variables.global_variables_initializer().run()
|
||||
save.save(sess, filepath, global_step=0)
|
||||
inc.eval()
|
||||
save.save(sess, filepath, global_step=1)
|
||||
inc.eval()
|
||||
save.save(sess, filepath, global_step=2)
|
||||
|
||||
with self.test_session() as sess:
|
||||
# Build a new graph with different initialization.
|
||||
v0 = variables.Variable(-1.0)
|
||||
|
||||
# Create a new saver.
|
||||
save = saver_module.Saver({"v0": v0})
|
||||
variables.global_variables_initializer().run()
|
||||
|
||||
# Get the most recent checkpoint name from the training history file.
|
||||
name = saver_module.latest_checkpoint(traindir)
|
||||
self.assertIsNotNone(name)
|
||||
|
||||
# Restore "v0" from that checkpoint.
|
||||
save.restore(sess, name)
|
||||
self.assertEqual(v0.eval(), 2.0)
|
||||
|
||||
|
||||
class CheckpointStateTest(test.TestCase):
|
||||
|
||||
def _get_test_dir(self, dirname):
|
||||
test_dir = os.path.join(self.get_temp_dir(), dirname)
|
||||
gfile.MakeDirs(test_dir)
|
||||
return test_dir
|
||||
|
||||
def testAbsPath(self):
|
||||
save_dir = self._get_test_dir("abs_paths")
|
||||
abs_path = os.path.join(save_dir, "model-0")
|
||||
ckpt = saver_module.generate_checkpoint_state_proto(save_dir, abs_path)
|
||||
self.assertEqual(ckpt.model_checkpoint_path, abs_path)
|
||||
self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
|
||||
self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
|
||||
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
|
||||
|
||||
def testRelPath(self):
|
||||
train_dir = "train"
|
||||
model = os.path.join(train_dir, "model-0")
|
||||
# model_checkpoint_path should have no "train" directory part.
|
||||
new_rel_path = "model-0"
|
||||
ckpt = saver_module.generate_checkpoint_state_proto(train_dir, model)
|
||||
self.assertEqual(ckpt.model_checkpoint_path, new_rel_path)
|
||||
self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
|
||||
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)
|
||||
|
||||
def testAllModelCheckpointPaths(self):
|
||||
save_dir = self._get_test_dir("all_models_test")
|
||||
abs_path = os.path.join(save_dir, "model-0")
|
||||
for paths in [None, [], ["model-2"]]:
|
||||
ckpt = saver_module.generate_checkpoint_state_proto(
|
||||
save_dir, abs_path, all_model_checkpoint_paths=paths)
|
||||
self.assertEqual(ckpt.model_checkpoint_path, abs_path)
|
||||
self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
|
||||
self.assertEqual(
|
||||
len(ckpt.all_model_checkpoint_paths), len(paths) if paths else 1)
|
||||
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
|
||||
|
||||
def testUpdateCheckpointState(self):
|
||||
save_dir = self._get_test_dir("update_checkpoint_state")
|
||||
os.chdir(save_dir)
|
||||
# Make a temporary train directory.
|
||||
train_dir = "train"
|
||||
os.mkdir(train_dir)
|
||||
abs_path = os.path.join(save_dir, "model-0")
|
||||
rel_path = os.path.join("train", "model-2")
|
||||
saver_module.update_checkpoint_state(
|
||||
train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path])
|
||||
ckpt = saver_module.get_checkpoint_state(train_dir)
|
||||
self.assertEqual(ckpt.model_checkpoint_path, rel_path)
|
||||
self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
|
||||
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path)
|
||||
self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)
|
||||
|
||||
def testUpdateCheckpointStateSaveRelativePaths(self):
|
||||
save_dir = self._get_test_dir("update_checkpoint_state")
|
||||
os.chdir(save_dir)
|
||||
abs_path2 = os.path.join(save_dir, "model-2")
|
||||
rel_path2 = "model-2"
|
||||
abs_path0 = os.path.join(save_dir, "model-0")
|
||||
rel_path0 = "model-0"
|
||||
saver_module._update_checkpoint_state( # pylint: disable=protected-access
|
||||
save_dir=save_dir,
|
||||
model_checkpoint_path=abs_path2,
|
||||
all_model_checkpoint_paths=[rel_path0, abs_path2],
|
||||
save_relative_paths=True)
|
||||
|
||||
# File should contain relative paths.
|
||||
file_content = file_io.read_file_to_string(
|
||||
os.path.join(save_dir, "checkpoint"))
|
||||
ckpt = CheckpointState()
|
||||
text_format.Merge(file_content, ckpt)
|
||||
self.assertEqual(ckpt.model_checkpoint_path, rel_path2)
|
||||
self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
|
||||
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2)
|
||||
self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0)
|
||||
|
||||
# get_checkpoint_state should return absolute paths.
|
||||
ckpt = saver_module.get_checkpoint_state(save_dir)
|
||||
self.assertEqual(ckpt.model_checkpoint_path, abs_path2)
|
||||
self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
|
||||
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2)
|
||||
self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0)
|
||||
|
||||
def testCheckPointStateFailsWhenIncomplete(self):
|
||||
save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")
|
||||
os.chdir(save_dir)
|
||||
ckpt_path = os.path.join(save_dir, "checkpoint")
|
||||
ckpt_file = open(ckpt_path, "w")
|
||||
ckpt_file.write("")
|
||||
ckpt_file.close()
|
||||
with self.assertRaises(ValueError):
|
||||
saver_module.get_checkpoint_state(save_dir)
|
||||
|
||||
def testCheckPointCompletesRelativePaths(self):
|
||||
save_dir = self._get_test_dir("checkpoint_completes_relative_paths")
|
||||
os.chdir(save_dir)
|
||||
ckpt_path = os.path.join(save_dir, "checkpoint")
|
||||
ckpt_file = open(ckpt_path, "w")
|
||||
ckpt_file.write("""
|
||||
model_checkpoint_path: "./model.ckpt-687529"
|
||||
all_model_checkpoint_paths: "./model.ckpt-687500"
|
||||
all_model_checkpoint_paths: "./model.ckpt-687529"
|
||||
""")
|
||||
ckpt_file.close()
|
||||
ckpt = saver_module.get_checkpoint_state(save_dir)
|
||||
self.assertEqual(ckpt.model_checkpoint_path,
|
||||
os.path.join(save_dir, "./model.ckpt-687529"))
|
||||
self.assertEqual(ckpt.all_model_checkpoint_paths[0],
|
||||
os.path.join(save_dir, "./model.ckpt-687500"))
|
||||
self.assertEqual(ckpt.all_model_checkpoint_paths[1],
|
||||
os.path.join(save_dir, "./model.ckpt-687529"))
|
||||
|
||||
|
||||
class MetaGraphTest(test.TestCase):
|
||||
|
||||
def _get_test_dir(self, dirname):
|
||||
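
The two classes above move out of saver_test.py wholesale. To help cross-check them at their new home, here is a compressed sketch of what they verify: the "checkpoint" state file written by save() is a CheckpointState text proto, and `latest_checkpoint` is essentially a lookup into it (paths here are illustrative, TF 1.x graph mode assumed):

import os
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()
v0 = tf.Variable(0.0, name="v0")
saver = tf.train.Saver({"v0": v0})

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for step in range(3):
    saver.save(sess, os.path.join(ckpt_dir, "snapshot"), global_step=step)

# get_checkpoint_state() parses the "checkpoint" file that save() maintains;
# generate_checkpoint_state_proto()/update_checkpoint_state() produce it.
state = tf.train.get_checkpoint_state(ckpt_dir)
print(state.model_checkpoint_path)             # .../snapshot-2
print(list(state.all_model_checkpoint_paths))  # all retained prefixes
assert tf.train.latest_checkpoint(ckpt_dir) == state.model_checkpoint_path
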
@ -2628,62 +2436,6 @@ class WriteGraphTest(test.TestCase):
    self.assertTrue(os.path.exists(path))


class SaverUtilsTest(test.TestCase):

  def setUp(self):
    self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test")
    gfile.MakeDirs(self._base_dir)

  def tearDown(self):
    gfile.DeleteRecursively(self._base_dir)

  def testCheckpointExists(self):
    for sharded in (False, True):
      for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
        with self.test_session(graph=ops_lib.Graph()) as sess:
          unused_v = variables.Variable(1.0, name="v")
          variables.global_variables_initializer().run()
          saver = saver_module.Saver(sharded=sharded, write_version=version)

          path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
          self.assertFalse(
              saver_module.checkpoint_exists(path))  # Not saved yet.

          ckpt_prefix = saver.save(sess, path)
          self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))

          ckpt_prefix = saver_module.latest_checkpoint(self._base_dir)
          self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))

  def testGetCheckpointMtimes(self):
    prefixes = []
    for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
      with self.test_session(graph=ops_lib.Graph()) as sess:
        unused_v = variables.Variable(1.0, name="v")
        variables.global_variables_initializer().run()
        saver = saver_module.Saver(write_version=version)
        prefixes.append(
            saver.save(sess, os.path.join(self._base_dir, str(version))))

    mtimes = saver_module.get_checkpoint_mtimes(prefixes)
    self.assertEqual(2, len(mtimes))
    self.assertTrue(mtimes[1] >= mtimes[0])

  def testRemoveCheckpoint(self):
    for sharded in (False, True):
      for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
        with self.test_session(graph=ops_lib.Graph()) as sess:
          unused_v = variables.Variable(1.0, name="v")
          variables.global_variables_initializer().run()
          saver = saver_module.Saver(sharded=sharded, write_version=version)

          path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
          ckpt_prefix = saver.save(sess, path)
          self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))
          saver_module.remove_checkpoint(ckpt_prefix, version)
          self.assertFalse(saver_module.checkpoint_exists(ckpt_prefix))


class ScopedGraphTest(test.TestCase):

  def _get_test_dir(self, dirname):
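
SaverUtilsTest moves out as well; the utilities it covers keep their signatures at the new location. A small sketch of calling them there (assuming the module path this commit introduces; TF 1.x graph mode):

import os
import tempfile

import tensorflow as tf
from tensorflow.python.training import checkpoint_management

base = tempfile.mkdtemp()
v = tf.Variable(1.0, name="v")
saver = tf.train.Saver()

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  p1 = saver.save(sess, os.path.join(base, "a"))
  p2 = saver.save(sess, os.path.join(base, "b"))

# Mtimes come back in the same order as the prefixes passed in.
print(checkpoint_management.get_checkpoint_mtimes([p1, p2]))

# remove_checkpoint() deletes the files behind one prefix (V2 by default).
checkpoint_management.remove_checkpoint(p1)
print(checkpoint_management.checkpoint_exists(p1))  # False
print(checkpoint_management.checkpoint_exists(p2))  # True
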
@ -24,7 +24,7 @@ from tensorflow.python.client import session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as saver_mod
from tensorflow.python.training import checkpoint_management
from tensorflow.python.util.tf_export import tf_export
@ -197,13 +197,13 @@ class SessionManager(object):

    # Waits up until max_wait_secs for checkpoint to become available.
    wait_time = 0
    ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
    ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
    while not ckpt or not ckpt.model_checkpoint_path:
      if wait_for_checkpoint and wait_time < max_wait_secs:
        logging.info("Waiting for checkpoint to be available.")
        time.sleep(self._recovery_wait_secs)
        wait_time += self._recovery_wait_secs
        ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
        ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
      else:
        return sess, False
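
The loop above is the standard wait-for-checkpoint polling pattern. Extracted into a standalone helper it looks roughly like this; the function name and default values are illustrative only, not part of any TF API:

import time

import tensorflow as tf

def wait_for_checkpoint_state(checkpoint_dir, max_wait_secs=300,
                              recovery_wait_secs=30):
  """Polls until a usable checkpoint state appears, or gives up.

  Mirrors the SessionManager loop in the hunk above; returns the
  CheckpointState proto, or None on timeout. A sketch, not TF API.
  """
  wait_time = 0
  ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
  while not ckpt or not ckpt.model_checkpoint_path:
    if wait_time >= max_wait_secs:
      return None
    tf.logging.info("Waiting for checkpoint to be available.")
    time.sleep(recovery_wait_secs)
    wait_time += recovery_wait_secs
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
  return ckpt
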
@ -30,6 +30,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_manager
@ -174,13 +175,13 @@ class SessionManagerTest(test.TestCase):
        os.path.join(checkpoint_dir, "recover_session_checkpoint"))
    self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
    self._test_recovered_variable(
        checkpoint_filename_with_path=saver_lib.latest_checkpoint(
        checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
            checkpoint_dir))
    # Cannot set both checkpoint_dir and checkpoint_filename_with_path.
    with self.assertRaises(ValueError):
      self._test_recovered_variable(
          checkpoint_dir=checkpoint_dir,
          checkpoint_filename_with_path=saver_lib.latest_checkpoint(
          checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
              checkpoint_dir))

  def testWaitForSessionReturnsNoneAfterTimeout(self):
@ -44,6 +44,7 @@ from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import server_lib
@ -83,7 +84,7 @@ class SupervisorTest(test.TestCase):
    end_time = time.time() + timeout_secs
    while time.time() < end_time:
      if for_checkpoint:
        if saver_lib.checkpoint_exists(pattern):
        if checkpoint_management.checkpoint_exists(pattern):
          return
      else:
        if len(gfile.Glob(pattern)) >= 1:
@ -82,12 +82,12 @@ from tensorflow.python.training.monitored_session import WorkerSessionCreator
from tensorflow.python.training.monitored_session import MonitoredSession
from tensorflow.python.training.monitored_session import SingularMonitoredSession
from tensorflow.python.training.saver import Saver
from tensorflow.python.training.saver import checkpoint_exists
from tensorflow.python.training.saver import generate_checkpoint_state_proto
from tensorflow.python.training.saver import get_checkpoint_mtimes
from tensorflow.python.training.saver import get_checkpoint_state
from tensorflow.python.training.saver import latest_checkpoint
from tensorflow.python.training.saver import update_checkpoint_state
from tensorflow.python.training.checkpoint_management import checkpoint_exists
from tensorflow.python.training.checkpoint_management import generate_checkpoint_state_proto
from tensorflow.python.training.checkpoint_management import get_checkpoint_mtimes
from tensorflow.python.training.checkpoint_management import get_checkpoint_state
from tensorflow.python.training.checkpoint_management import latest_checkpoint
from tensorflow.python.training.checkpoint_management import update_checkpoint_state
from tensorflow.python.training.saver import export_meta_graph
from tensorflow.python.training.saver import import_meta_graph
from tensorflow.python.training.session_run_hook import SessionRunHook
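
Because the re-exports above now come from checkpoint_management, the public `tf.train.*` aliases are unchanged: both import paths bind the very same function objects, so downstream callers see no behavior change. A quick identity check, assuming the module layout in this commit:

from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import training

# `from X import name` re-binds the same object, so the moved definitions
# are reachable under both module paths.
assert training.latest_checkpoint is checkpoint_management.latest_checkpoint
assert (training.update_checkpoint_state is
        checkpoint_management.update_checkpoint_state)
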