Check for the presence of a Worker machine when reassigning hooks in distributed
training jobs. PiperOrigin-RevId: 217407558
This commit is contained in:
		
							parent
							
								
									a3f855aca2
								
							
						
					
					
						commit
						cc0cf49a0d
					
				| @ -1423,7 +1423,13 @@ class Estimator(object): | ||||
|     # evaluations. | ||||
|     save_summary_steps = self._config.save_summary_steps | ||||
|     log_step_count_steps = self._config.log_step_count_steps | ||||
| 
 | ||||
|     # Check existence of appropriate cluster spec fields, as well as master and | ||||
|     # worker nodes. As master also performs evaluation, summary writing must | ||||
|     # occur on a different node. The presence of a worker is also checked to | ||||
|     # prevent reassigning hooks for single-replica jobs with just a master node. | ||||
|     if (self._config.cluster_spec and self._config.cluster_spec.jobs and | ||||
|         (run_config.TaskType.WORKER in self._config.cluster_spec.jobs) and | ||||
|         (run_config.TaskType.MASTER in self._config.cluster_spec.jobs)): | ||||
|       # Update config values to prevent the default hooks from being created on | ||||
|       # the master or other workers. | ||||
|  | ||||
| @ -1063,6 +1063,67 @@ class EstimatorTrainTest(test.TestCase): | ||||
|       self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps']) | ||||
|       self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps']) | ||||
| 
 | ||||
|   def test_master_hooks_single_replica(self): | ||||
|     tf_config = json.dumps({ | ||||
|         'cluster': { | ||||
|             run_config.TaskType.MASTER: ['localhost:1234'] | ||||
|         }, | ||||
|         'task': { | ||||
|             'type': run_config.TaskType.MASTER, | ||||
|             'index': 0 | ||||
|         } | ||||
|     }) | ||||
|     with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}): | ||||
|       est = estimator.Estimator( | ||||
|           model_fn=model_fn_global_step_incrementer, | ||||
|           config=run_config.RunConfig( | ||||
|               save_summary_steps=100, log_step_count_steps=200)) | ||||
| 
 | ||||
|     with test.mock.patch.object(training, | ||||
|                                 'MonitoredTrainingSession') as mock_sess: | ||||
|       est.train(dummy_input_fn, steps=1) | ||||
|       self.assertFalse( | ||||
|           any( | ||||
|               isinstance(hook, basic_session_run_hooks.SummarySaverHook) | ||||
|               for hook in mock_sess.call_args[1]['hooks'])) | ||||
|       self.assertFalse( | ||||
|           any( | ||||
|               isinstance(hook, basic_session_run_hooks.StepCounterHook) | ||||
|               for hook in mock_sess.call_args[1]['hooks'])) | ||||
|       self.assertEqual(100, mock_sess.call_args[1]['save_summaries_steps']) | ||||
|       self.assertEqual(200, mock_sess.call_args[1]['log_step_count_steps']) | ||||
| 
 | ||||
|   def test_master_hooks_single_replica_with_ps(self): | ||||
|     tf_config = json.dumps({ | ||||
|         'cluster': { | ||||
|             run_config.TaskType.MASTER: ['localhost:1234'], | ||||
|             run_config.TaskType.PS: ['localhost: 1235'], | ||||
|         }, | ||||
|         'task': { | ||||
|             'type': run_config.TaskType.MASTER, | ||||
|             'index': 0 | ||||
|         } | ||||
|     }) | ||||
|     with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}): | ||||
|       est = estimator.Estimator( | ||||
|           model_fn=model_fn_global_step_incrementer, | ||||
|           config=run_config.RunConfig( | ||||
|               save_summary_steps=100, log_step_count_steps=200)) | ||||
| 
 | ||||
|     with test.mock.patch.object(training, | ||||
|                                 'MonitoredTrainingSession') as mock_sess: | ||||
|       est.train(dummy_input_fn, steps=1) | ||||
|       self.assertFalse( | ||||
|           any( | ||||
|               isinstance(hook, basic_session_run_hooks.SummarySaverHook) | ||||
|               for hook in mock_sess.call_args[1]['hooks'])) | ||||
|       self.assertFalse( | ||||
|           any( | ||||
|               isinstance(hook, basic_session_run_hooks.StepCounterHook) | ||||
|               for hook in mock_sess.call_args[1]['hooks'])) | ||||
|       self.assertEqual(100, mock_sess.call_args[1]['save_summaries_steps']) | ||||
|       self.assertEqual(200, mock_sess.call_args[1]['log_step_count_steps']) | ||||
| 
 | ||||
| 
 | ||||
| def _model_fn_with_eval_metric_ops(features, labels, mode, params): | ||||
|   _, _ = features, labels | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user