From 283789792f3a4df2e54ee6b3b62771809894f7aa Mon Sep 17 00:00:00 2001 From: Jianwei Xie <xiejw@google.com> Date: Fri, 14 Apr 2017 19:16:03 -0800 Subject: [PATCH] Add model_dir in RunConfig. Estimator will read it also. Change: 153233683 --- tensorflow/python/estimator/estimator.py | 25 +++++--- tensorflow/python/estimator/estimator_test.py | 62 ++++++++++++++++++- tensorflow/python/estimator/run_config.py | 4 ++ .../tensorflow.estimator.-run-config.pbtxt | 4 ++ 4 files changed, 86 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 80c5bbf6848..449cb54c841 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -111,7 +111,9 @@ class Estimator(object): `EstimatorSpec` model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to - continue training a previously saved model. + continue training a previously saved model. If `None`, the model_dir in + `config` will be used if set. If both are set, they must be same. If + both are `None`, a temporary directory will be used. config: Configuration object. params: `dict` of hyper parameters that will be passed into `model_fn`. Keys are names of parameters, values are basic python types. @@ -122,12 +124,6 @@ class Estimator(object): a member of `Estimator`. """ Estimator._assert_members_are_not_overridden(self) - # Model directory. - self._model_dir = model_dir - if self._model_dir is None: - self._model_dir = tempfile.mkdtemp() - logging.warning('Using temporary folder as model directory: %s', - self._model_dir) if config is None: self._config = run_config.RunConfig() @@ -139,6 +135,21 @@ class Estimator(object): config) self._config = config + # Model directory. + if (model_dir is not None) and (self._config.model_dir is not None): + if model_dir != self._config.model_dir: + # pylint: disable=g-doc-exception + raise ValueError( + "model_dir are set both in constructor and RunConfig, but with " + "different values. In constructor: '{}', in RunConfig: " + "'{}' ".format(model_dir, self._config.model_dir)) + # pylint: enable=g-doc-exception + + self._model_dir = model_dir or self._config.model_dir + if self._model_dir is None: + self._model_dir = tempfile.mkdtemp() + logging.warning('Using temporary folder as model directory: %s', + self._model_dir) logging.info('Using config: %s', str(vars(self._config))) if self._config.session_config is None: diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 84813073d35..89a9483e201 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -58,6 +58,9 @@ from tensorflow.python.training import session_run_hook from tensorflow.python.training import training from tensorflow.python.util import compat +_TMP_DIR = '/tmp' +_ANOTHER_TMP_DIR = '/another_tmp' + def dummy_model_fn(features, labels, params): _, _, _ = features, labels, params @@ -122,7 +125,7 @@ class EstimatorConstructorTest(test.TestCase): def model_fn(features, labels, params): _, _, _ = features, labels, params - class FakeConfig(run_config.RunConfig): # pylint: disable=g-wrong-blank-lines + class FakeConfig(run_config.RunConfig): pass params = {'hidden_layers': [3, 4]} @@ -148,6 +151,61 @@ class EstimatorConstructorTest(test.TestCase): est = estimator.Estimator(model_fn=model_fn) self.assertTrue(est.model_dir is not None) + def test_model_dir_in_constructor(self): + + def model_fn(features, labels): + _, _ = features, labels + + est = estimator.Estimator(model_fn=model_fn, model_dir=_TMP_DIR) + self.assertEqual(_TMP_DIR, est.model_dir) + + def test_model_dir_in_run_config(self): + + class FakeConfig(run_config.RunConfig): + + @property + def model_dir(self): + return _TMP_DIR + + def model_fn(features, labels): + _, _ = features, labels + + est = estimator.Estimator(model_fn=model_fn, config=FakeConfig()) + self.assertEqual(_TMP_DIR, est.model_dir) + + def test_same_model_dir_in_constructor_and_run_config(self): + + class FakeConfig(run_config.RunConfig): + + @property + def model_dir(self): + return _TMP_DIR + + def model_fn(features, labels): + _, _ = features, labels + + est = estimator.Estimator( + model_fn=model_fn, config=FakeConfig(), model_dir=_TMP_DIR) + self.assertEqual(_TMP_DIR, est.model_dir) + + def test_different_model_dir_in_constructor_and_run_config(self): + + class FakeConfig(run_config.RunConfig): + + @property + def model_dir(self): + return _TMP_DIR + + def model_fn(features, labels): + _, _ = features, labels + + with self.assertRaisesRegexp( + ValueError, + 'model_dir are set both in constructor and RunConfig, but ' + 'with different values'): + estimator.Estimator( + model_fn=model_fn, config=FakeConfig(), model_dir=_ANOTHER_TMP_DIR) + def test_model_fn_args_must_include_features(self): def model_fn(x, labels): @@ -359,7 +417,7 @@ class EstimatorTrainTest(test.TestCase): training_chief_hooks=[chief_hook], training_hooks=[hook]) - class NonChiefRunConfig(run_config.RunConfig): # pylint: disable=g-wrong-blank-lines + class NonChiefRunConfig(run_config.RunConfig): @property def is_chief(self): # pylint: disable=g-wrong-blank-lines return False diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py index 79b55c68532..504bf5c3fe5 100644 --- a/tensorflow/python/estimator/run_config.py +++ b/tensorflow/python/estimator/run_config.py @@ -87,3 +87,7 @@ class RunConfig(object): @property def keep_checkpoint_every_n_hours(self): return 10000 + + @property + def model_dir(self): + return None diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt index 8fd991a317b..32082fc10bb 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt @@ -26,6 +26,10 @@ tf_class { name: "master" mtype: "<type \'property\'>" } + member { + name: "model_dir" + mtype: "<type \'property\'>" + } member { name: "num_ps_replicas" mtype: "<type \'property\'>"