Check for the presence of a Worker machine when reassigning hooks in distributed
training jobs. PiperOrigin-RevId: 217407558
This commit is contained in:
parent
a3f855aca2
commit
cc0cf49a0d
@ -1423,7 +1423,13 @@ class Estimator(object):
|
|||||||
# evaluations.
|
# evaluations.
|
||||||
save_summary_steps = self._config.save_summary_steps
|
save_summary_steps = self._config.save_summary_steps
|
||||||
log_step_count_steps = self._config.log_step_count_steps
|
log_step_count_steps = self._config.log_step_count_steps
|
||||||
|
|
||||||
|
# Check existence of appropriate cluster spec fields, as well as master and
|
||||||
|
# worker nodes. As master also performs evaluation, summary writing must
|
||||||
|
# occur on a different node. The presence of a worker is also checked to
|
||||||
|
# prevent reassigning hooks for single-replica jobs with just a master node.
|
||||||
if (self._config.cluster_spec and self._config.cluster_spec.jobs and
|
if (self._config.cluster_spec and self._config.cluster_spec.jobs and
|
||||||
|
(run_config.TaskType.WORKER in self._config.cluster_spec.jobs) and
|
||||||
(run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):
|
(run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):
|
||||||
# Update config values to prevent the default hooks from being created on
|
# Update config values to prevent the default hooks from being created on
|
||||||
# the master or other workers.
|
# the master or other workers.
|
||||||
|
@ -1063,6 +1063,67 @@ class EstimatorTrainTest(test.TestCase):
|
|||||||
self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
|
self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
|
||||||
self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
|
self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
|
||||||
|
|
||||||
|
def test_master_hooks_single_replica(self):
  """Checks that a master-only cluster keeps its configured hook settings.

  The cluster spec built here contains a single master task and no worker
  job. Training must then neither pass explicit SummarySaverHook or
  StepCounterHook instances to the session, nor override the configured
  `save_summary_steps` / `log_step_count_steps` values.
  """
  master = run_config.TaskType.MASTER
  cluster = {master: ['localhost:1234']}
  task = {'type': master, 'index': 0}
  tf_config = json.dumps({'cluster': cluster, 'task': task})

  # TF_CONFIG only needs to be patched while the estimator is constructed.
  with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
    config = run_config.RunConfig(
        save_summary_steps=100, log_step_count_steps=200)
    est = estimator.Estimator(
        model_fn=model_fn_global_step_incrementer, config=config)

  with test.mock.patch.object(
      training, 'MonitoredTrainingSession') as mock_sess:
    est.train(dummy_input_fn, steps=1)
    session_kwargs = mock_sess.call_args[1]
    hooks = session_kwargs['hooks']

    # Neither default hook may be handed to the session explicitly.
    for unexpected_hook_type in (basic_session_run_hooks.SummarySaverHook,
                                 basic_session_run_hooks.StepCounterHook):
      self.assertFalse(
          any(isinstance(hook, unexpected_hook_type) for hook in hooks))

    # The values from RunConfig must reach the session unchanged.
    self.assertEqual(100, session_kwargs['save_summaries_steps'])
    self.assertEqual(200, session_kwargs['log_step_count_steps'])
|
||||||
|
|
||||||
|
def test_master_hooks_single_replica_with_ps(self):
  """Checks that a master + PS cluster (no worker) keeps its hook settings.

  The cluster spec built here has a master task and a parameter server but
  no worker job. Training must then neither pass explicit SummarySaverHook
  or StepCounterHook instances to the session, nor override the configured
  `save_summary_steps` / `log_step_count_steps` values.
  """
  tf_config = json.dumps({
      'cluster': {
          run_config.TaskType.MASTER: ['localhost:1234'],
          # Well-formed host:port, with no whitespace after the colon.
          run_config.TaskType.PS: ['localhost:1235'],
      },
      'task': {
          'type': run_config.TaskType.MASTER,
          'index': 0
      }
  })
  # TF_CONFIG only needs to be patched while the estimator is constructed.
  with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
    est = estimator.Estimator(
        model_fn=model_fn_global_step_incrementer,
        config=run_config.RunConfig(
            save_summary_steps=100, log_step_count_steps=200))

  with test.mock.patch.object(training,
                              'MonitoredTrainingSession') as mock_sess:
    est.train(dummy_input_fn, steps=1)
    # Neither default hook may be handed to the session explicitly.
    self.assertFalse(
        any(
            isinstance(hook, basic_session_run_hooks.SummarySaverHook)
            for hook in mock_sess.call_args[1]['hooks']))
    self.assertFalse(
        any(
            isinstance(hook, basic_session_run_hooks.StepCounterHook)
            for hook in mock_sess.call_args[1]['hooks']))
    # The values from RunConfig must reach the session unchanged.
    self.assertEqual(100, mock_sess.call_args[1]['save_summaries_steps'])
    self.assertEqual(200, mock_sess.call_args[1]['log_step_count_steps'])
|
||||||
|
|
||||||
|
|
||||||
def _model_fn_with_eval_metric_ops(features, labels, mode, params):
|
def _model_fn_with_eval_metric_ops(features, labels, mode, params):
|
||||||
_, _ = features, labels
|
_, _ = features, labels
|
||||||
|
Loading…
Reference in New Issue
Block a user