Local run option of estimator training.

PiperOrigin-RevId: 169756384
This commit is contained in:
Mustafa Ispir 2017-09-22 16:55:58 -07:00 committed by TensorFlower Gardener
parent 1dc2fe7acd
commit 2957cd8948
2 changed files with 147 additions and 1 deletions

View File

@ -203,6 +203,22 @@ class EvalSpec(
throttle_secs=throttle_secs)
class _StopAtSecsHook(session_run_hook.SessionRunHook):
"""Stops given secs after begin is called."""
def __init__(self, stop_after_secs):
self._stop_after_secs = stop_after_secs
self._start_time = None
def begin(self):
self._start_time = time.time()
def after_run(self, run_context, run_values):
del run_values
if time.time() - self._start_time >= self._stop_after_secs:
run_context.request_stop()
class UnimplementedError(Exception):
pass
@ -254,7 +270,38 @@ class _TrainingExecutor(object):
def run_local(self):
"""Runs training and evaluation locally (non-distributed)."""
raise UnimplementedError('Method run_local has not been implemented.')
def _should_stop_local_train(global_step):
if self._train_spec.max_steps is None:
return False
if global_step >= self._train_spec.max_steps:
return True
return False
if self._eval_spec.throttle_secs <= 0:
raise ValueError('eval_spec.throttle_secs should be positive, given: {}.'
'It is used do determine how long each training '
'iteration should go when train and evaluate '
'locally.'.format(
self._eval_spec.throttle_secs))
stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs)
train_hooks = list(self._train_spec.hooks) + [stop_hook]
logging.info('Start train and evaluate loop. The evaluate will happen '
'after {} secs (eval_spec.throttle_secs) or training is '
'finished.'.format(self._eval_spec.throttle_secs))
while True:
self._estimator.train(
input_fn=self._train_spec.input_fn,
max_steps=self._train_spec.max_steps,
hooks=train_hooks)
metrics = self._estimator.evaluate(
input_fn=self._eval_spec.input_fn,
steps=self._eval_spec.steps,
hooks=self._eval_spec.hooks,
name=self._eval_spec.name)
if _should_stop_local_train(metrics[ops.GraphKeys.GLOBAL_STEP]):
break
def _start_std_server(self, config):
"""Creates, starts, and returns a server_lib.Server."""

View File

@ -27,8 +27,10 @@ from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.estimator import training
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook
@ -614,5 +616,102 @@ class TrainingExecutorRunPsTest(test.TestCase):
mock_eval_spec).run_ps()
class StopAtSecsHookTest(test.TestCase):
"""Tests StopAtSecsHook."""
@test.mock.patch.object(time, 'time')
def test_stops_after_time(self, mock_time):
mock_time.return_value = 1484695987.209386
hook = training._StopAtSecsHook(1000)
with ops.Graph().as_default():
no_op = control_flow_ops.no_op()
# some time passed before training starts
mock_time.return_value += 250
with monitored_session.MonitoredSession(hooks=[hook]) as sess:
self.assertFalse(sess.should_stop())
sess.run(no_op)
self.assertFalse(sess.should_stop())
mock_time.return_value += 500
sess.run(no_op)
self.assertFalse(sess.should_stop())
mock_time.return_value += 400
sess.run(no_op)
self.assertFalse(sess.should_stop())
mock_time.return_value += 200
sess.run(no_op)
self.assertTrue(sess.should_stop())
class TrainingExecutorRunLocalTest(test.TestCase):
"""Tests run_local of _TrainingExecutor."""
def test_send_stop_at_secs_to_train(self):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
train_spec = training.TrainSpec(
input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
eval_spec = training.EvalSpec(
input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100)
mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
executor.run_local()
stop_hook = mock_est.train.call_args[1]['hooks'][-1]
self.assertIsInstance(stop_hook, training._StopAtSecsHook)
self.assertEqual(eval_spec.throttle_secs, stop_hook._stop_after_secs)
def test_runs_in_a_loop_until_max_steps(self):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
train_spec = training.TrainSpec(
input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
eval_spec = training.EvalSpec(
input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100)
# should be called 3 times.
mock_est.evaluate.side_effect = [{
_GLOBAL_STEP_KEY: train_spec.max_steps - 100
}, {
_GLOBAL_STEP_KEY: train_spec.max_steps - 50
}, {
_GLOBAL_STEP_KEY: train_spec.max_steps
}]
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
executor.run_local()
self.assertEqual(3, mock_est.train.call_count)
self.assertEqual(3, mock_est.evaluate.call_count)
def test_train_and_evaluate_args(self):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
train_spec = training.TrainSpec(
input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
eval_spec = training.EvalSpec(
input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='local_eval')
mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
executor.run_local()
mock_est.evaluate.assert_called_with(
name=eval_spec.name,
input_fn=eval_spec.input_fn,
steps=eval_spec.steps,
hooks=eval_spec.hooks)
train_args = mock_est.train.call_args[1]
self.assertEqual(list(train_spec.hooks), list(train_args['hooks'][:-1]))
self.assertEqual(train_spec.input_fn, train_args['input_fn'])
self.assertEqual(train_spec.max_steps, train_args['max_steps'])
def test_errors_out_if_throttle_secs_is_zero(self):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
train_spec = training.TrainSpec(input_fn=lambda: 1)
eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=0)
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
with self.assertRaisesRegexp(ValueError, 'throttle_secs'):
executor.run_local()
if __name__ == '__main__':
test.main()