Change contrib estimator to save relative paths in checkpoint.

Change: 155016674
This commit is contained in:
A. Unique TensorFlower 2017-05-03 14:58:36 -08:00 committed by TensorFlower Gardener
parent 1457d7ffdc
commit 6b493f72c8
2 changed files with 38 additions and 1 deletions

View File

@ -966,7 +966,8 @@ class BaseEstimator(
saver.Saver(
sharded=True,
max_to_keep=self._config.keep_checkpoint_max,
defer_build=True))
defer_build=True,
save_relative_paths=True))
chief_hooks = []
if (self._config.save_checkpoints_secs or

View File

@ -28,6 +28,8 @@ import numpy as np
import six
from six.moves import xrange # pylint: disable=redefined-builtin
from google.protobuf import text_format
from tensorflow.contrib import learn
from tensorflow.contrib import lookup
from tensorflow.contrib.framework.python.ops import variables
@ -50,6 +52,7 @@ from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@ -61,6 +64,7 @@ from tensorflow.python.platform import test
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_state_pb2
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib
@ -674,6 +678,38 @@ class EstimatorTest(test.TestCase):
metrics={'MSE': metric_ops.streaming_mean_squared_error})
self.assertLess(scores3['MSE'], scores['MSE'])
def test_checkpoint_contains_relative_paths(self):
tmpdir = tempfile.mkdtemp()
est = estimator.Estimator(
model_dir=tmpdir,
model_fn=linear_model_fn_with_model_fn_ops)
est.fit(input_fn=boston_input_fn, steps=5)
checkpoint_file_content = file_io.read_file_to_string(
os.path.join(tmpdir, 'checkpoint'))
ckpt = checkpoint_state_pb2.CheckpointState()
text_format.Merge(checkpoint_file_content, ckpt)
self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
self.assertAllEqual(
['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)
def test_train_save_copy_reload(self):
tmpdir = tempfile.mkdtemp()
model_dir1 = os.path.join(tmpdir, 'model_dir1')
est1 = estimator.Estimator(
model_dir=model_dir1,
model_fn=linear_model_fn_with_model_fn_ops)
est1.fit(input_fn=boston_input_fn, steps=5)
model_dir2 = os.path.join(tmpdir, 'model_dir2')
os.renames(model_dir1, model_dir2)
est2 = estimator.Estimator(
model_dir=model_dir2,
model_fn=linear_model_fn_with_model_fn_ops)
self.assertEqual(5, est2.get_variable_value('global_step'))
est2.fit(input_fn=boston_input_fn, steps=5)
self.assertEqual(10, est2.get_variable_value('global_step'))
def testEstimatorParams(self):
boston = base.load_boston()
est = estimator.SKCompat(