From cc0cf49a0d0cfdb23073810260ca1af480d08850 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 16 Oct 2018 16:20:58 -0700
Subject: [PATCH] Check for the presence of a Worker machine when reassigning
 hooks in distributed training jobs.

PiperOrigin-RevId: 217407558
---
 tensorflow/python/estimator/estimator.py      |  6 ++
 tensorflow/python/estimator/estimator_test.py | 61 +++++++++++++++++++
 2 files changed, 67 insertions(+)

diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 3c1be9dbad3..c44413090a6 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -1423,7 +1423,13 @@ class Estimator(object):
     # evaluations.
     save_summary_steps = self._config.save_summary_steps
     log_step_count_steps = self._config.log_step_count_steps
+
+    # Check existence of appropriate cluster spec fields, as well as master and
+    # worker nodes. As master also performs evaluation, summary writing must
+    # occur on a different node. The presence of a worker is also checked to
+    # prevent reassigning hooks for single-replica jobs with just a master node.
     if (self._config.cluster_spec and self._config.cluster_spec.jobs and
+        (run_config.TaskType.WORKER in self._config.cluster_spec.jobs) and
         (run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):
       # Update config values to prevent the default hooks from being created on
       # the master or other workers.
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 246dfb1a4bd..c26b3e65098 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -1063,6 +1063,67 @@ class EstimatorTrainTest(test.TestCase):
       self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
       self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
 
+  def test_master_hooks_single_replica(self):
+    # A cluster with only a master task (no worker job) is a single-replica
+    # job, so the configured summary/step-count settings must be kept as-is.
+    master = run_config.TaskType.MASTER
+    tf_config = json.dumps({
+        'cluster': {master: ['localhost:1234']},
+        'task': {'type': master, 'index': 0}
+    })
+    # Build the estimator while TF_CONFIG is patched; the training call below
+    # runs after the environment has been restored.
+    with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
+      config = run_config.RunConfig(
+          save_summary_steps=100, log_step_count_steps=200)
+      est = estimator.Estimator(
+          model_fn=model_fn_global_step_incrementer, config=config)
+
+    with test.mock.patch.object(training,
+                                'MonitoredTrainingSession') as mock_sess:
+      est.train(dummy_input_fn, steps=1)
+      session_kwargs = mock_sess.call_args[1]
+      hooks = session_kwargs['hooks']
+      # Neither default hook may have been added to the session's hook list.
+      unexpected_hook_types = (basic_session_run_hooks.SummarySaverHook,
+                               basic_session_run_hooks.StepCounterHook)
+      for hook in hooks:
+        self.assertNotIsInstance(hook, unexpected_hook_types)
+      # The values from RunConfig are handed to the session unchanged.
+      self.assertEqual(100, session_kwargs['save_summaries_steps'])
+      self.assertEqual(200, session_kwargs['log_step_count_steps'])
+
+  def test_master_hooks_single_replica_with_ps(self):
+    # A master plus a parameter server, with no worker job, is still a
+    # single-replica setup: no summary/step-counter hook is added and the
+    # configured step settings are passed to the training session unchanged.
+    tf_config = json.dumps({
+        'cluster': {
+            run_config.TaskType.MASTER: ['localhost:1234'],
+            run_config.TaskType.PS: ['localhost:1235'],
+        },
+        'task': {'type': run_config.TaskType.MASTER, 'index': 0}
+    })
+    with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
+      est = estimator.Estimator(
+          model_fn=model_fn_global_step_incrementer,
+          config=run_config.RunConfig(
+              save_summary_steps=100, log_step_count_steps=200))
+
+    with test.mock.patch.object(training,
+                                'MonitoredTrainingSession') as mock_sess:
+      est.train(dummy_input_fn, steps=1)
+      self.assertFalse(
+          any(
+              isinstance(hook, basic_session_run_hooks.SummarySaverHook)
+              for hook in mock_sess.call_args[1]['hooks']))
+      self.assertFalse(
+          any(
+              isinstance(hook, basic_session_run_hooks.StepCounterHook)
+              for hook in mock_sess.call_args[1]['hooks']))
+      self.assertEqual(100, mock_sess.call_args[1]['save_summaries_steps'])
+      self.assertEqual(200, mock_sess.call_args[1]['log_step_count_steps'])
+
 
 def _model_fn_with_eval_metric_ops(features, labels, mode, params):
   _, _ = features, labels