Add model_dir in RunConfig. Estimator will read it also.
Change: 153233683
This commit is contained in:
parent
572556d271
commit
283789792f
@ -111,7 +111,9 @@ class Estimator(object):
|
||||
`EstimatorSpec`
|
||||
model_dir: Directory to save model parameters, graph and etc. This can
|
||||
also be used to load checkpoints from the directory into a estimator to
|
||||
continue training a previously saved model.
|
||||
continue training a previously saved model. If `None`, the model_dir in
|
||||
`config` will be used if set. If both are set, they must be same. If
|
||||
both are `None`, a temporary directory will be used.
|
||||
config: Configuration object.
|
||||
params: `dict` of hyper parameters that will be passed into `model_fn`.
|
||||
Keys are names of parameters, values are basic python types.
|
||||
@ -122,12 +124,6 @@ class Estimator(object):
|
||||
a member of `Estimator`.
|
||||
"""
|
||||
Estimator._assert_members_are_not_overridden(self)
|
||||
# Model directory.
|
||||
self._model_dir = model_dir
|
||||
if self._model_dir is None:
|
||||
self._model_dir = tempfile.mkdtemp()
|
||||
logging.warning('Using temporary folder as model directory: %s',
|
||||
self._model_dir)
|
||||
|
||||
if config is None:
|
||||
self._config = run_config.RunConfig()
|
||||
@ -139,6 +135,21 @@ class Estimator(object):
|
||||
config)
|
||||
self._config = config
|
||||
|
||||
# Model directory.
|
||||
if (model_dir is not None) and (self._config.model_dir is not None):
|
||||
if model_dir != self._config.model_dir:
|
||||
# pylint: disable=g-doc-exception
|
||||
raise ValueError(
|
||||
"model_dir are set both in constructor and RunConfig, but with "
|
||||
"different values. In constructor: '{}', in RunConfig: "
|
||||
"'{}' ".format(model_dir, self._config.model_dir))
|
||||
# pylint: enable=g-doc-exception
|
||||
|
||||
self._model_dir = model_dir or self._config.model_dir
|
||||
if self._model_dir is None:
|
||||
self._model_dir = tempfile.mkdtemp()
|
||||
logging.warning('Using temporary folder as model directory: %s',
|
||||
self._model_dir)
|
||||
logging.info('Using config: %s', str(vars(self._config)))
|
||||
|
||||
if self._config.session_config is None:
|
||||
|
@ -58,6 +58,9 @@ from tensorflow.python.training import session_run_hook
|
||||
from tensorflow.python.training import training
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
_TMP_DIR = '/tmp'
|
||||
_ANOTHER_TMP_DIR = '/another_tmp'
|
||||
|
||||
|
||||
def dummy_model_fn(features, labels, params):
|
||||
_, _, _ = features, labels, params
|
||||
@ -122,7 +125,7 @@ class EstimatorConstructorTest(test.TestCase):
|
||||
def model_fn(features, labels, params):
|
||||
_, _, _ = features, labels, params
|
||||
|
||||
class FakeConfig(run_config.RunConfig): # pylint: disable=g-wrong-blank-lines
|
||||
class FakeConfig(run_config.RunConfig):
|
||||
pass
|
||||
|
||||
params = {'hidden_layers': [3, 4]}
|
||||
@ -148,6 +151,61 @@ class EstimatorConstructorTest(test.TestCase):
|
||||
est = estimator.Estimator(model_fn=model_fn)
|
||||
self.assertTrue(est.model_dir is not None)
|
||||
|
||||
def test_model_dir_in_constructor(self):
|
||||
|
||||
def model_fn(features, labels):
|
||||
_, _ = features, labels
|
||||
|
||||
est = estimator.Estimator(model_fn=model_fn, model_dir=_TMP_DIR)
|
||||
self.assertEqual(_TMP_DIR, est.model_dir)
|
||||
|
||||
def test_model_dir_in_run_config(self):
|
||||
|
||||
class FakeConfig(run_config.RunConfig):
|
||||
|
||||
@property
|
||||
def model_dir(self):
|
||||
return _TMP_DIR
|
||||
|
||||
def model_fn(features, labels):
|
||||
_, _ = features, labels
|
||||
|
||||
est = estimator.Estimator(model_fn=model_fn, config=FakeConfig())
|
||||
self.assertEqual(_TMP_DIR, est.model_dir)
|
||||
|
||||
def test_same_model_dir_in_constructor_and_run_config(self):
|
||||
|
||||
class FakeConfig(run_config.RunConfig):
|
||||
|
||||
@property
|
||||
def model_dir(self):
|
||||
return _TMP_DIR
|
||||
|
||||
def model_fn(features, labels):
|
||||
_, _ = features, labels
|
||||
|
||||
est = estimator.Estimator(
|
||||
model_fn=model_fn, config=FakeConfig(), model_dir=_TMP_DIR)
|
||||
self.assertEqual(_TMP_DIR, est.model_dir)
|
||||
|
||||
def test_different_model_dir_in_constructor_and_run_config(self):
|
||||
|
||||
class FakeConfig(run_config.RunConfig):
|
||||
|
||||
@property
|
||||
def model_dir(self):
|
||||
return _TMP_DIR
|
||||
|
||||
def model_fn(features, labels):
|
||||
_, _ = features, labels
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
'model_dir are set both in constructor and RunConfig, but '
|
||||
'with different values'):
|
||||
estimator.Estimator(
|
||||
model_fn=model_fn, config=FakeConfig(), model_dir=_ANOTHER_TMP_DIR)
|
||||
|
||||
def test_model_fn_args_must_include_features(self):
|
||||
|
||||
def model_fn(x, labels):
|
||||
@ -359,7 +417,7 @@ class EstimatorTrainTest(test.TestCase):
|
||||
training_chief_hooks=[chief_hook],
|
||||
training_hooks=[hook])
|
||||
|
||||
class NonChiefRunConfig(run_config.RunConfig): # pylint: disable=g-wrong-blank-lines
|
||||
class NonChiefRunConfig(run_config.RunConfig):
|
||||
@property
|
||||
def is_chief(self): # pylint: disable=g-wrong-blank-lines
|
||||
return False
|
||||
|
@ -87,3 +87,7 @@ class RunConfig(object):
|
||||
@property
|
||||
def keep_checkpoint_every_n_hours(self):
|
||||
return 10000
|
||||
|
||||
@property
|
||||
def model_dir(self):
|
||||
return None
|
||||
|
@ -26,6 +26,10 @@ tf_class {
|
||||
name: "master"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "model_dir"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "num_ps_replicas"
|
||||
mtype: "<type \'property\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user