Change contrib estimator to save relative paths in checkpoint.
Change: 155016674
This commit is contained in:
parent
1457d7ffdc
commit
6b493f72c8
@ -966,7 +966,8 @@ class BaseEstimator(
|
||||
saver.Saver(
|
||||
sharded=True,
|
||||
max_to_keep=self._config.keep_checkpoint_max,
|
||||
defer_build=True))
|
||||
defer_build=True,
|
||||
save_relative_paths=True))
|
||||
|
||||
chief_hooks = []
|
||||
if (self._config.save_checkpoints_secs or
|
||||
|
@ -28,6 +28,8 @@ import numpy as np
|
||||
import six
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from google.protobuf import text_format
|
||||
|
||||
from tensorflow.contrib import learn
|
||||
from tensorflow.contrib import lookup
|
||||
from tensorflow.contrib.framework.python.ops import variables
|
||||
@ -50,6 +52,7 @@ from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -61,6 +64,7 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import loader
|
||||
from tensorflow.python.saved_model import tag_constants
|
||||
from tensorflow.python.training import basic_session_run_hooks
|
||||
from tensorflow.python.training import checkpoint_state_pb2
|
||||
from tensorflow.python.training import input as input_lib
|
||||
from tensorflow.python.training import monitored_session
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
@ -674,6 +678,38 @@ class EstimatorTest(test.TestCase):
|
||||
metrics={'MSE': metric_ops.streaming_mean_squared_error})
|
||||
self.assertLess(scores3['MSE'], scores['MSE'])
|
||||
|
||||
def test_checkpoint_contains_relative_paths(self):
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
est = estimator.Estimator(
|
||||
model_dir=tmpdir,
|
||||
model_fn=linear_model_fn_with_model_fn_ops)
|
||||
est.fit(input_fn=boston_input_fn, steps=5)
|
||||
|
||||
checkpoint_file_content = file_io.read_file_to_string(
|
||||
os.path.join(tmpdir, 'checkpoint'))
|
||||
ckpt = checkpoint_state_pb2.CheckpointState()
|
||||
text_format.Merge(checkpoint_file_content, ckpt)
|
||||
self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
|
||||
self.assertAllEqual(
|
||||
['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)
|
||||
|
||||
def test_train_save_copy_reload(self):
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
model_dir1 = os.path.join(tmpdir, 'model_dir1')
|
||||
est1 = estimator.Estimator(
|
||||
model_dir=model_dir1,
|
||||
model_fn=linear_model_fn_with_model_fn_ops)
|
||||
est1.fit(input_fn=boston_input_fn, steps=5)
|
||||
|
||||
model_dir2 = os.path.join(tmpdir, 'model_dir2')
|
||||
os.renames(model_dir1, model_dir2)
|
||||
est2 = estimator.Estimator(
|
||||
model_dir=model_dir2,
|
||||
model_fn=linear_model_fn_with_model_fn_ops)
|
||||
self.assertEqual(5, est2.get_variable_value('global_step'))
|
||||
est2.fit(input_fn=boston_input_fn, steps=5)
|
||||
self.assertEqual(10, est2.get_variable_value('global_step'))
|
||||
|
||||
def testEstimatorParams(self):
|
||||
boston = base.load_boston()
|
||||
est = estimator.SKCompat(
|
||||
|
Loading…
Reference in New Issue
Block a user