Check for the presence of a Worker machine when reassigning hooks in distributed
training jobs. PiperOrigin-RevId: 217407558
This commit is contained in:
parent
a3f855aca2
commit
cc0cf49a0d
@ -1423,7 +1423,13 @@ class Estimator(object):
|
||||
# evaluations.
|
||||
save_summary_steps = self._config.save_summary_steps
|
||||
log_step_count_steps = self._config.log_step_count_steps
|
||||
|
||||
# Check existence of appropriate cluster spec fields, as well as master and
|
||||
# worker nodes. As master also performs evaluation, summary writing must
|
||||
# occur on a different node. The presence of a worker is also checked to
|
||||
# prevent reassigning hooks for single-replica jobs with just a master node.
|
||||
if (self._config.cluster_spec and self._config.cluster_spec.jobs and
|
||||
(run_config.TaskType.WORKER in self._config.cluster_spec.jobs) and
|
||||
(run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):
|
||||
# Update config values to prevent the default hooks from being created on
|
||||
# the master or other workers.
|
||||
|
@ -1063,6 +1063,67 @@ class EstimatorTrainTest(test.TestCase):
|
||||
self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
|
||||
self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
|
||||
|
||||
def test_master_hooks_single_replica(self):
|
||||
tf_config = json.dumps({
|
||||
'cluster': {
|
||||
run_config.TaskType.MASTER: ['localhost:1234']
|
||||
},
|
||||
'task': {
|
||||
'type': run_config.TaskType.MASTER,
|
||||
'index': 0
|
||||
}
|
||||
})
|
||||
with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
|
||||
est = estimator.Estimator(
|
||||
model_fn=model_fn_global_step_incrementer,
|
||||
config=run_config.RunConfig(
|
||||
save_summary_steps=100, log_step_count_steps=200))
|
||||
|
||||
with test.mock.patch.object(training,
|
||||
'MonitoredTrainingSession') as mock_sess:
|
||||
est.train(dummy_input_fn, steps=1)
|
||||
self.assertFalse(
|
||||
any(
|
||||
isinstance(hook, basic_session_run_hooks.SummarySaverHook)
|
||||
for hook in mock_sess.call_args[1]['hooks']))
|
||||
self.assertFalse(
|
||||
any(
|
||||
isinstance(hook, basic_session_run_hooks.StepCounterHook)
|
||||
for hook in mock_sess.call_args[1]['hooks']))
|
||||
self.assertEqual(100, mock_sess.call_args[1]['save_summaries_steps'])
|
||||
self.assertEqual(200, mock_sess.call_args[1]['log_step_count_steps'])
|
||||
|
||||
def test_master_hooks_single_replica_with_ps(self):
|
||||
tf_config = json.dumps({
|
||||
'cluster': {
|
||||
run_config.TaskType.MASTER: ['localhost:1234'],
|
||||
run_config.TaskType.PS: ['localhost: 1235'],
|
||||
},
|
||||
'task': {
|
||||
'type': run_config.TaskType.MASTER,
|
||||
'index': 0
|
||||
}
|
||||
})
|
||||
with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
|
||||
est = estimator.Estimator(
|
||||
model_fn=model_fn_global_step_incrementer,
|
||||
config=run_config.RunConfig(
|
||||
save_summary_steps=100, log_step_count_steps=200))
|
||||
|
||||
with test.mock.patch.object(training,
|
||||
'MonitoredTrainingSession') as mock_sess:
|
||||
est.train(dummy_input_fn, steps=1)
|
||||
self.assertFalse(
|
||||
any(
|
||||
isinstance(hook, basic_session_run_hooks.SummarySaverHook)
|
||||
for hook in mock_sess.call_args[1]['hooks']))
|
||||
self.assertFalse(
|
||||
any(
|
||||
isinstance(hook, basic_session_run_hooks.StepCounterHook)
|
||||
for hook in mock_sess.call_args[1]['hooks']))
|
||||
self.assertEqual(100, mock_sess.call_args[1]['save_summaries_steps'])
|
||||
self.assertEqual(200, mock_sess.call_args[1]['log_step_count_steps'])
|
||||
|
||||
|
||||
def _model_fn_with_eval_metric_ops(features, labels, mode, params):
|
||||
_, _ = features, labels
|
||||
|
Loading…
Reference in New Issue
Block a user