Check for the presence of a Worker machine when reassigning hooks in distributed training jobs.

PiperOrigin-RevId: 217407558
This commit is contained in:
A. Unique TensorFlower 2018-10-16 16:20:58 -07:00 committed by TensorFlower Gardener
parent a3f855aca2
commit cc0cf49a0d
2 changed files with 67 additions and 0 deletions

View File

@ -1423,7 +1423,13 @@ class Estimator(object):
# evaluations.
save_summary_steps = self._config.save_summary_steps
log_step_count_steps = self._config.log_step_count_steps
# Check existence of appropriate cluster spec fields, as well as master and
# worker nodes. As master also performs evaluation, summary writing must
# occur on a different node. The presence of a worker is also checked to
# prevent reassigning hooks for single-replica jobs with just a master node.
if (self._config.cluster_spec and self._config.cluster_spec.jobs and
(run_config.TaskType.WORKER in self._config.cluster_spec.jobs) and
(run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):
# Update config values to prevent the default hooks from being created on
# the master or other workers.

View File

@ -1063,6 +1063,67 @@ class EstimatorTrainTest(test.TestCase):
self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
def test_master_hooks_single_replica(self):
  """Checks hook handling when the cluster contains only a master task.

  TF_CONFIG describes a cluster whose single job is the master. After one
  training step, the hooks handed to the mocked MonitoredTrainingSession must
  contain neither a SummarySaverHook nor a StepCounterHook, and the
  save_summary_steps / log_step_count_steps values from RunConfig must be
  forwarded unchanged (100 and 200).
  """
  master = run_config.TaskType.MASTER
  cluster = {master: ['localhost:1234']}
  task = {'type': master, 'index': 0}
  tf_config = json.dumps({'cluster': cluster, 'task': task})

  # NOTE(review): assumes TF_CONFIG only has to be patched while the
  # RunConfig/Estimator are constructed -- confirm against RunConfig.
  with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
    config = run_config.RunConfig(
        save_summary_steps=100, log_step_count_steps=200)
    est = estimator.Estimator(
        model_fn=model_fn_global_step_incrementer, config=config)

  with test.mock.patch.object(training,
                              'MonitoredTrainingSession') as mock_sess:
    est.train(dummy_input_fn, steps=1)

    # call_args[1] is the keyword-argument dict of the recorded call.
    session_kwargs = mock_sess.call_args[1]
    hooks = session_kwargs['hooks']

    has_summary_saver = any(
        isinstance(hook, basic_session_run_hooks.SummarySaverHook)
        for hook in hooks)
    self.assertFalse(has_summary_saver)

    has_step_counter = any(
        isinstance(hook, basic_session_run_hooks.StepCounterHook)
        for hook in hooks)
    self.assertFalse(has_step_counter)

    self.assertEqual(100, session_kwargs['save_summaries_steps'])
    self.assertEqual(200, session_kwargs['log_step_count_steps'])
def test_master_hooks_single_replica_with_ps(self):
  """Checks hook handling for a master + parameter-server cluster (no worker).

  TF_CONFIG describes a cluster with one master task and one PS task. After
  one training step, the hooks handed to the mocked MonitoredTrainingSession
  must contain neither a SummarySaverHook nor a StepCounterHook, and the
  save_summary_steps / log_step_count_steps values from RunConfig must be
  forwarded unchanged (100 and 200).
  """
  master = run_config.TaskType.MASTER
  cluster = {
      master: ['localhost:1234'],
      # Well-formed host:port (the address previously contained a stray space).
      run_config.TaskType.PS: ['localhost:1235'],
  }
  task = {'type': master, 'index': 0}
  tf_config = json.dumps({'cluster': cluster, 'task': task})

  # NOTE(review): assumes TF_CONFIG only has to be patched while the
  # RunConfig/Estimator are constructed -- confirm against RunConfig.
  with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
    config = run_config.RunConfig(
        save_summary_steps=100, log_step_count_steps=200)
    est = estimator.Estimator(
        model_fn=model_fn_global_step_incrementer, config=config)

  with test.mock.patch.object(training,
                              'MonitoredTrainingSession') as mock_sess:
    est.train(dummy_input_fn, steps=1)

    # call_args[1] is the keyword-argument dict of the recorded call.
    session_kwargs = mock_sess.call_args[1]
    hooks = session_kwargs['hooks']

    has_summary_saver = any(
        isinstance(hook, basic_session_run_hooks.SummarySaverHook)
        for hook in hooks)
    self.assertFalse(has_summary_saver)

    has_step_counter = any(
        isinstance(hook, basic_session_run_hooks.StepCounterHook)
        for hook in hooks)
    self.assertFalse(has_step_counter)

    self.assertEqual(100, session_kwargs['save_summaries_steps'])
    self.assertEqual(200, session_kwargs['log_step_count_steps'])
def _model_fn_with_eval_metric_ops(features, labels, mode, params):
_, _ = features, labels