From 283789792f3a4df2e54ee6b3b62771809894f7aa Mon Sep 17 00:00:00 2001
From: Jianwei Xie <xiejw@google.com>
Date: Fri, 14 Apr 2017 19:16:03 -0800
Subject: [PATCH] Add model_dir in RunConfig. Estimator will read it also.
 Change: 153233683

---
 tensorflow/python/estimator/estimator.py      | 25 +++++---
 tensorflow/python/estimator/estimator_test.py | 62 ++++++++++++++++++-
 tensorflow/python/estimator/run_config.py     |  4 ++
 .../tensorflow.estimator.-run-config.pbtxt    |  4 ++
 4 files changed, 86 insertions(+), 9 deletions(-)

diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 80c5bbf6848..449cb54c841 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -111,7 +111,9 @@ class Estimator(object):
           `EstimatorSpec`
       model_dir: Directory to save model parameters, graph and etc. This can
         also be used to load checkpoints from the directory into a estimator to
-        continue training a previously saved model.
+        continue training a previously saved model. If `None`, the model_dir in
+        `config` will be used if set. If both are set, they must be same. If
+        both are `None`, a temporary directory will be used.
       config: Configuration object.
       params: `dict` of hyper parameters that will be passed into `model_fn`.
               Keys are names of parameters, values are basic python types.
@@ -122,12 +124,6 @@ class Estimator(object):
         a member of `Estimator`.
     """
     Estimator._assert_members_are_not_overridden(self)
-    # Model directory.
-    self._model_dir = model_dir
-    if self._model_dir is None:
-      self._model_dir = tempfile.mkdtemp()
-      logging.warning('Using temporary folder as model directory: %s',
-                      self._model_dir)
 
     if config is None:
       self._config = run_config.RunConfig()
@@ -139,6 +135,21 @@ class Estimator(object):
             config)
       self._config = config
 
+    # Model directory.
+    if (model_dir is not None) and (self._config.model_dir is not None):
+      if model_dir != self._config.model_dir:
+        # pylint: disable=g-doc-exception
+        raise ValueError(
+            "model_dir are set both in constructor and RunConfig, but with "
+            "different values. In constructor: '{}', in RunConfig: "
+            "'{}' ".format(model_dir, self._config.model_dir))
+        # pylint: enable=g-doc-exception
+
+    self._model_dir = model_dir or self._config.model_dir
+    if self._model_dir is None:
+      self._model_dir = tempfile.mkdtemp()
+      logging.warning('Using temporary folder as model directory: %s',
+                      self._model_dir)
     logging.info('Using config: %s', str(vars(self._config)))
 
     if self._config.session_config is None:
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 84813073d35..89a9483e201 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -58,6 +58,9 @@ from tensorflow.python.training import session_run_hook
 from tensorflow.python.training import training
 from tensorflow.python.util import compat
 
+_TMP_DIR = '/tmp'
+_ANOTHER_TMP_DIR = '/another_tmp'
+
 
 def dummy_model_fn(features, labels, params):
   _, _, _ = features, labels, params
@@ -122,7 +125,7 @@ class EstimatorConstructorTest(test.TestCase):
     def model_fn(features, labels, params):
       _, _, _ = features, labels, params
 
-    class FakeConfig(run_config.RunConfig):  # pylint: disable=g-wrong-blank-lines
+    class FakeConfig(run_config.RunConfig):
       pass
 
     params = {'hidden_layers': [3, 4]}
@@ -148,6 +151,61 @@ class EstimatorConstructorTest(test.TestCase):
     est = estimator.Estimator(model_fn=model_fn)
     self.assertTrue(est.model_dir is not None)
 
+  def test_model_dir_in_constructor(self):
+
+    def model_fn(features, labels):
+      _, _ = features, labels
+
+    est = estimator.Estimator(model_fn=model_fn, model_dir=_TMP_DIR)
+    self.assertEqual(_TMP_DIR, est.model_dir)
+
+  def test_model_dir_in_run_config(self):
+
+    class FakeConfig(run_config.RunConfig):
+
+      @property
+      def model_dir(self):
+        return _TMP_DIR
+
+    def model_fn(features, labels):
+      _, _ = features, labels
+
+    est = estimator.Estimator(model_fn=model_fn, config=FakeConfig())
+    self.assertEqual(_TMP_DIR, est.model_dir)
+
+  def test_same_model_dir_in_constructor_and_run_config(self):
+
+    class FakeConfig(run_config.RunConfig):
+
+      @property
+      def model_dir(self):
+        return _TMP_DIR
+
+    def model_fn(features, labels):
+      _, _ = features, labels
+
+    est = estimator.Estimator(
+        model_fn=model_fn, config=FakeConfig(), model_dir=_TMP_DIR)
+    self.assertEqual(_TMP_DIR, est.model_dir)
+
+  def test_different_model_dir_in_constructor_and_run_config(self):
+
+    class FakeConfig(run_config.RunConfig):
+
+      @property
+      def model_dir(self):
+        return _TMP_DIR
+
+    def model_fn(features, labels):
+      _, _ = features, labels
+
+    with self.assertRaisesRegexp(
+        ValueError,
+        'model_dir are set both in constructor and RunConfig, but '
+        'with different values'):
+      estimator.Estimator(
+          model_fn=model_fn, config=FakeConfig(), model_dir=_ANOTHER_TMP_DIR)
+
   def test_model_fn_args_must_include_features(self):
 
     def model_fn(x, labels):
@@ -359,7 +417,7 @@ class EstimatorTrainTest(test.TestCase):
           training_chief_hooks=[chief_hook],
           training_hooks=[hook])
 
-    class NonChiefRunConfig(run_config.RunConfig):  # pylint: disable=g-wrong-blank-lines
+    class NonChiefRunConfig(run_config.RunConfig):
       @property
       def is_chief(self):  # pylint: disable=g-wrong-blank-lines
         return False
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index 79b55c68532..504bf5c3fe5 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -87,3 +87,7 @@ class RunConfig(object):
   @property
   def keep_checkpoint_every_n_hours(self):
     return 10000
+
+  @property
+  def model_dir(self):
+    return None
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt
index 8fd991a317b..32082fc10bb 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt
@@ -26,6 +26,10 @@ tf_class {
     name: "master"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "model_dir"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "num_ps_replicas"
     mtype: "<type \'property\'>"