Change contrib estimator to save relative paths in checkpoint.
Change: 155016674
This commit is contained in:
parent
1457d7ffdc
commit
6b493f72c8
@ -966,7 +966,8 @@ class BaseEstimator(
|
|||||||
saver.Saver(
|
saver.Saver(
|
||||||
sharded=True,
|
sharded=True,
|
||||||
max_to_keep=self._config.keep_checkpoint_max,
|
max_to_keep=self._config.keep_checkpoint_max,
|
||||||
defer_build=True))
|
defer_build=True,
|
||||||
|
save_relative_paths=True))
|
||||||
|
|
||||||
chief_hooks = []
|
chief_hooks = []
|
||||||
if (self._config.save_checkpoints_secs or
|
if (self._config.save_checkpoints_secs or
|
||||||
|
@ -28,6 +28,8 @@ import numpy as np
|
|||||||
import six
|
import six
|
||||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
|
from google.protobuf import text_format
|
||||||
|
|
||||||
from tensorflow.contrib import learn
|
from tensorflow.contrib import learn
|
||||||
from tensorflow.contrib import lookup
|
from tensorflow.contrib import lookup
|
||||||
from tensorflow.contrib.framework.python.ops import variables
|
from tensorflow.contrib.framework.python.ops import variables
|
||||||
@ -50,6 +52,7 @@ from tensorflow.python.client import session as session_lib
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.lib.io import file_io
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
@ -61,6 +64,7 @@ from tensorflow.python.platform import test
|
|||||||
from tensorflow.python.saved_model import loader
|
from tensorflow.python.saved_model import loader
|
||||||
from tensorflow.python.saved_model import tag_constants
|
from tensorflow.python.saved_model import tag_constants
|
||||||
from tensorflow.python.training import basic_session_run_hooks
|
from tensorflow.python.training import basic_session_run_hooks
|
||||||
|
from tensorflow.python.training import checkpoint_state_pb2
|
||||||
from tensorflow.python.training import input as input_lib
|
from tensorflow.python.training import input as input_lib
|
||||||
from tensorflow.python.training import monitored_session
|
from tensorflow.python.training import monitored_session
|
||||||
from tensorflow.python.training import saver as saver_lib
|
from tensorflow.python.training import saver as saver_lib
|
||||||
@ -674,6 +678,38 @@ class EstimatorTest(test.TestCase):
|
|||||||
metrics={'MSE': metric_ops.streaming_mean_squared_error})
|
metrics={'MSE': metric_ops.streaming_mean_squared_error})
|
||||||
self.assertLess(scores3['MSE'], scores['MSE'])
|
self.assertLess(scores3['MSE'], scores['MSE'])
|
||||||
|
|
||||||
|
def test_checkpoint_contains_relative_paths(self):
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
est = estimator.Estimator(
|
||||||
|
model_dir=tmpdir,
|
||||||
|
model_fn=linear_model_fn_with_model_fn_ops)
|
||||||
|
est.fit(input_fn=boston_input_fn, steps=5)
|
||||||
|
|
||||||
|
checkpoint_file_content = file_io.read_file_to_string(
|
||||||
|
os.path.join(tmpdir, 'checkpoint'))
|
||||||
|
ckpt = checkpoint_state_pb2.CheckpointState()
|
||||||
|
text_format.Merge(checkpoint_file_content, ckpt)
|
||||||
|
self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
|
||||||
|
self.assertAllEqual(
|
||||||
|
['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)
|
||||||
|
|
||||||
|
def test_train_save_copy_reload(self):
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
model_dir1 = os.path.join(tmpdir, 'model_dir1')
|
||||||
|
est1 = estimator.Estimator(
|
||||||
|
model_dir=model_dir1,
|
||||||
|
model_fn=linear_model_fn_with_model_fn_ops)
|
||||||
|
est1.fit(input_fn=boston_input_fn, steps=5)
|
||||||
|
|
||||||
|
model_dir2 = os.path.join(tmpdir, 'model_dir2')
|
||||||
|
os.renames(model_dir1, model_dir2)
|
||||||
|
est2 = estimator.Estimator(
|
||||||
|
model_dir=model_dir2,
|
||||||
|
model_fn=linear_model_fn_with_model_fn_ops)
|
||||||
|
self.assertEqual(5, est2.get_variable_value('global_step'))
|
||||||
|
est2.fit(input_fn=boston_input_fn, steps=5)
|
||||||
|
self.assertEqual(10, est2.get_variable_value('global_step'))
|
||||||
|
|
||||||
def testEstimatorParams(self):
|
def testEstimatorParams(self):
|
||||||
boston = base.load_boston()
|
boston = base.load_boston()
|
||||||
est = estimator.SKCompat(
|
est = estimator.SKCompat(
|
||||||
|
Loading…
Reference in New Issue
Block a user