diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 123db50d325..74a6da20d4e 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -966,7 +966,8 @@ class BaseEstimator( saver.Saver( sharded=True, max_to_keep=self._config.keep_checkpoint_max, - defer_build=True)) + defer_build=True, + save_relative_paths=True)) chief_hooks = [] if (self._config.save_checkpoints_secs or diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index 8c61ffad553..c95df75356b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -28,6 +28,8 @@ import numpy as np import six from six.moves import xrange # pylint: disable=redefined-builtin +from google.protobuf import text_format + from tensorflow.contrib import learn from tensorflow.contrib import lookup from tensorflow.contrib.framework.python.ops import variables @@ -50,6 +52,7 @@ from tensorflow.python.client import session as session_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -61,6 +64,7 @@ from tensorflow.python.platform import test from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_state_pb2 from tensorflow.python.training import input as input_lib from tensorflow.python.training import monitored_session from tensorflow.python.training import saver as saver_lib @@ -674,6 +678,38 @@ class EstimatorTest(test.TestCase): metrics={'MSE': metric_ops.streaming_mean_squared_error}) self.assertLess(scores3['MSE'], scores['MSE']) + def test_checkpoint_contains_relative_paths(self): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator( + model_dir=tmpdir, + model_fn=linear_model_fn_with_model_fn_ops) + est.fit(input_fn=boston_input_fn, steps=5) + + checkpoint_file_content = file_io.read_file_to_string( + os.path.join(tmpdir, 'checkpoint')) + ckpt = checkpoint_state_pb2.CheckpointState() + text_format.Merge(checkpoint_file_content, ckpt) + self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5') + self.assertAllEqual( + ['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths) + + def test_train_save_copy_reload(self): + tmpdir = tempfile.mkdtemp() + model_dir1 = os.path.join(tmpdir, 'model_dir1') + est1 = estimator.Estimator( + model_dir=model_dir1, + model_fn=linear_model_fn_with_model_fn_ops) + est1.fit(input_fn=boston_input_fn, steps=5) + + model_dir2 = os.path.join(tmpdir, 'model_dir2') + os.renames(model_dir1, model_dir2) + est2 = estimator.Estimator( + model_dir=model_dir2, + model_fn=linear_model_fn_with_model_fn_ops) + self.assertEqual(5, est2.get_variable_value('global_step')) + est2.fit(input_fn=boston_input_fn, steps=5) + self.assertEqual(10, est2.get_variable_value('global_step')) + def testEstimatorParams(self): boston = base.load_boston() est = estimator.SKCompat(