Add model_dir in RunConfig. Estimator will read it also.

Change: 153233683
This commit is contained in:
Jianwei Xie 2017-04-14 19:16:03 -08:00 committed by TensorFlower Gardener
parent 572556d271
commit 283789792f
4 changed files with 86 additions and 9 deletions

View File

@ -111,7 +111,9 @@ class Estimator(object):
`EstimatorSpec`
model_dir: Directory to save model parameters, graph and etc. This can
also be used to load checkpoints from the directory into a estimator to
continue training a previously saved model.
continue training a previously saved model. If `None`, the model_dir in
`config` will be used if set. If both are set, they must be same. If
both are `None`, a temporary directory will be used.
config: Configuration object.
params: `dict` of hyper parameters that will be passed into `model_fn`.
Keys are names of parameters, values are basic python types.
@ -122,12 +124,6 @@ class Estimator(object):
a member of `Estimator`.
"""
Estimator._assert_members_are_not_overridden(self)
# Model directory.
self._model_dir = model_dir
if self._model_dir is None:
self._model_dir = tempfile.mkdtemp()
logging.warning('Using temporary folder as model directory: %s',
self._model_dir)
if config is None:
self._config = run_config.RunConfig()
@ -139,6 +135,21 @@ class Estimator(object):
config)
self._config = config
# Model directory.
if (model_dir is not None) and (self._config.model_dir is not None):
if model_dir != self._config.model_dir:
# pylint: disable=g-doc-exception
raise ValueError(
"model_dir are set both in constructor and RunConfig, but with "
"different values. In constructor: '{}', in RunConfig: "
"'{}' ".format(model_dir, self._config.model_dir))
# pylint: enable=g-doc-exception
self._model_dir = model_dir or self._config.model_dir
if self._model_dir is None:
self._model_dir = tempfile.mkdtemp()
logging.warning('Using temporary folder as model directory: %s',
self._model_dir)
logging.info('Using config: %s', str(vars(self._config)))
if self._config.session_config is None:

View File

@ -58,6 +58,9 @@ from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
from tensorflow.python.util import compat
_TMP_DIR = '/tmp'
_ANOTHER_TMP_DIR = '/another_tmp'
def dummy_model_fn(features, labels, params):
_, _, _ = features, labels, params
@ -122,7 +125,7 @@ class EstimatorConstructorTest(test.TestCase):
def model_fn(features, labels, params):
_, _, _ = features, labels, params
class FakeConfig(run_config.RunConfig): # pylint: disable=g-wrong-blank-lines
class FakeConfig(run_config.RunConfig):
pass
params = {'hidden_layers': [3, 4]}
@ -148,6 +151,61 @@ class EstimatorConstructorTest(test.TestCase):
est = estimator.Estimator(model_fn=model_fn)
self.assertTrue(est.model_dir is not None)
def test_model_dir_in_constructor(self):
def model_fn(features, labels):
_, _ = features, labels
est = estimator.Estimator(model_fn=model_fn, model_dir=_TMP_DIR)
self.assertEqual(_TMP_DIR, est.model_dir)
def test_model_dir_in_run_config(self):
class FakeConfig(run_config.RunConfig):
@property
def model_dir(self):
return _TMP_DIR
def model_fn(features, labels):
_, _ = features, labels
est = estimator.Estimator(model_fn=model_fn, config=FakeConfig())
self.assertEqual(_TMP_DIR, est.model_dir)
def test_same_model_dir_in_constructor_and_run_config(self):
class FakeConfig(run_config.RunConfig):
@property
def model_dir(self):
return _TMP_DIR
def model_fn(features, labels):
_, _ = features, labels
est = estimator.Estimator(
model_fn=model_fn, config=FakeConfig(), model_dir=_TMP_DIR)
self.assertEqual(_TMP_DIR, est.model_dir)
def test_different_model_dir_in_constructor_and_run_config(self):
class FakeConfig(run_config.RunConfig):
@property
def model_dir(self):
return _TMP_DIR
def model_fn(features, labels):
_, _ = features, labels
with self.assertRaisesRegexp(
ValueError,
'model_dir are set both in constructor and RunConfig, but '
'with different values'):
estimator.Estimator(
model_fn=model_fn, config=FakeConfig(), model_dir=_ANOTHER_TMP_DIR)
def test_model_fn_args_must_include_features(self):
def model_fn(x, labels):
@ -359,7 +417,7 @@ class EstimatorTrainTest(test.TestCase):
training_chief_hooks=[chief_hook],
training_hooks=[hook])
class NonChiefRunConfig(run_config.RunConfig): # pylint: disable=g-wrong-blank-lines
class NonChiefRunConfig(run_config.RunConfig):
@property
def is_chief(self): # pylint: disable=g-wrong-blank-lines
return False

View File

@ -87,3 +87,7 @@ class RunConfig(object):
@property
def keep_checkpoint_every_n_hours(self):
return 10000
@property
def model_dir(self):
return None

View File

@ -26,6 +26,10 @@ tf_class {
name: "master"
mtype: "<type \'property\'>"
}
member {
name: "model_dir"
mtype: "<type \'property\'>"
}
member {
name: "num_ps_replicas"
mtype: "<type \'property\'>"