From cc0cf49a0d0cfdb23073810260ca1af480d08850 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Tue, 16 Oct 2018 16:20:58 -0700 Subject: [PATCH] Check for the presence of a Worker machine when reassigning hooks in distributed training jobs. PiperOrigin-RevId: 217407558 --- tensorflow/python/estimator/estimator.py | 6 ++ tensorflow/python/estimator/estimator_test.py | 61 +++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 3c1be9dbad3..c44413090a6 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -1423,7 +1423,13 @@ class Estimator(object): # evaluations. save_summary_steps = self._config.save_summary_steps log_step_count_steps = self._config.log_step_count_steps + + # Check existence of appropriate cluster spec fields, as well as master and + # worker nodes. As master also performs evaluation, summary writing must + # occur on a different node. The presence of a worker is also checked to + # prevent reassigning hooks for single-replica jobs with just a master node. if (self._config.cluster_spec and self._config.cluster_spec.jobs and + (run_config.TaskType.WORKER in self._config.cluster_spec.jobs) and (run_config.TaskType.MASTER in self._config.cluster_spec.jobs)): # Update config values to prevent the default hooks from being created on # the master or other workers. diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 246dfb1a4bd..c26b3e65098 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -1063,6 +1063,67 @@ class EstimatorTrainTest(test.TestCase): self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps']) self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps']) + def test_master_hooks_single_replica(self): + tf_config = json.dumps({ + 'cluster': { + run_config.TaskType.MASTER: ['localhost:1234'] + }, + 'task': { + 'type': run_config.TaskType.MASTER, + 'index': 0 + } + }) + with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}): + est = estimator.Estimator( + model_fn=model_fn_global_step_incrementer, + config=run_config.RunConfig( + save_summary_steps=100, log_step_count_steps=200)) + + with test.mock.patch.object(training, + 'MonitoredTrainingSession') as mock_sess: + est.train(dummy_input_fn, steps=1) + self.assertFalse( + any( + isinstance(hook, basic_session_run_hooks.SummarySaverHook) + for hook in mock_sess.call_args[1]['hooks'])) + self.assertFalse( + any( + isinstance(hook, basic_session_run_hooks.StepCounterHook) + for hook in mock_sess.call_args[1]['hooks'])) + self.assertEqual(100, mock_sess.call_args[1]['save_summaries_steps']) + self.assertEqual(200, mock_sess.call_args[1]['log_step_count_steps']) + + def test_master_hooks_single_replica_with_ps(self): + tf_config = json.dumps({ + 'cluster': { + run_config.TaskType.MASTER: ['localhost:1234'], + run_config.TaskType.PS: ['localhost: 1235'], + }, + 'task': { + 'type': run_config.TaskType.MASTER, + 'index': 0 + } + }) + with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}): + est = estimator.Estimator( + model_fn=model_fn_global_step_incrementer, + config=run_config.RunConfig( + save_summary_steps=100, log_step_count_steps=200)) + + with test.mock.patch.object(training, + 'MonitoredTrainingSession') as mock_sess: + est.train(dummy_input_fn, steps=1) + self.assertFalse( + any( + isinstance(hook, basic_session_run_hooks.SummarySaverHook) + for hook in mock_sess.call_args[1]['hooks'])) + self.assertFalse( + any( + isinstance(hook, basic_session_run_hooks.StepCounterHook) + for hook in mock_sess.call_args[1]['hooks'])) + self.assertEqual(100, mock_sess.call_args[1]['save_summaries_steps']) + self.assertEqual(200, mock_sess.call_args[1]['log_step_count_steps']) + def _model_fn_with_eval_metric_ops(features, labels, mode, params): _, _ = features, labels