diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index d060faa4efd..eaee48bae72 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -76,6 +76,7 @@ py_library( srcs = ["training.py"], srcs_version = "PY2AND3", deps = [ + ":estimator", "//tensorflow/python:training", "@six_archive//:six", ], diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index 9a8a0db66ee..0681ebff564 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -23,6 +23,7 @@ import collections import six +from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.training import session_run_hook @@ -177,3 +178,51 @@ class EvalSpec( delay_secs=delay_secs, throttle_secs=throttle_secs) + +class UnimplementedError(Exception): + pass + + +class _TrainingExecutor(object): + """The executor to run `Estimator` training and evaluation. + + This implementation supports both distributed and non-distributed (aka local) + training and evaluation based on the setting in `tf.estimator.RunConfig`. + """ + + def __init__(self, estimator, train_spec, eval_spec): + if not isinstance(estimator, estimator_lib.Estimator): + raise TypeError('`estimator` must have type `tf.estimator.Estimator`.') + self._estimator = estimator + + if not isinstance(train_spec, TrainSpec): + raise TypeError('`train_spec` must have type `tf.estimator.TrainSpec`.') + self._train_spec = train_spec + + if not isinstance(eval_spec, EvalSpec): + raise TypeError('`eval_spec` must have type `tf.estimator.EvalSpec`.') + self._eval_spec = eval_spec + + @property + def estimator(self): + return self._estimator + + def run_chief(self): + """Runs task chief.""" + raise UnimplementedError('Method run_chief has not been implemented.') + + def run_worker(self): + """Runs task (training) worker.""" + raise UnimplementedError('Method run_worker has not been implemented.') + + def run_evaluator(self): + """Runs task evaluator.""" + raise UnimplementedError('Method run_evaluator has not been implemented.') + + def run_ps(self): + """Runs task parameter server (in training cluster spec).""" + raise UnimplementedError('Method run_ps has not been implemented.') + + def run_local(self): + """Runs training and evaluation locally (non-distributed).""" + raise UnimplementedError('Method run_local has not been implemented.') diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py index 654a1659b29..4e67d457719 100644 --- a/tensorflow/python/estimator/training_test.py +++ b/tensorflow/python/estimator/training_test.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import training from tensorflow.python.platform import test from tensorflow.python.training import session_run_hook @@ -33,6 +34,9 @@ _INVALID_STEPS_MSG = 'Must specify steps > 0' _INVALID_NAME_MSG = '`name` must be string' _INVALID_EVAL_DELAY_SECS_MSG = 'Must specify delay_secs >= 0' _INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0' +_INVALID_ESTIMATOR_MSG = '`estimator` must have type `tf.estimator.Estimator`' +_INVALID_TRAIN_SPEC_MSG = '`train_spec` must have type `tf.estimator.TrainSpec`' +_INVALID_EVAL_SPEC_MSG = '`eval_spec` must have type `tf.estimator.EvalSpec`' class _FakeHook(session_run_hook.SessionRunHook): @@ -129,5 +133,40 @@ class EvalSpecTest(test.TestCase): training.EvalSpec(input_fn=lambda: 1, throttle_secs=-1) +class TrainingExecutorTest(test.TestCase): + """Tests _TrainingExecutor.""" + + def testRequiredArgumentsSet(self): + estimator = estimator_lib.Estimator(model_fn=lambda features: features) + train_spec = training.TrainSpec(input_fn=lambda: 1) + eval_spec = training.EvalSpec(input_fn=lambda: 1) + + executor = training._TrainingExecutor(estimator, train_spec, eval_spec) + self.assertEqual(estimator, executor.estimator) + + def test_invalid_estimator(self): + invalid_estimator = object() + train_spec = training.TrainSpec(input_fn=lambda: 1) + eval_spec = training.EvalSpec(input_fn=lambda: 1) + + with self.assertRaisesRegexp(TypeError, _INVALID_ESTIMATOR_MSG): + training._TrainingExecutor(invalid_estimator, train_spec, eval_spec) + + def test_invalid_train_spec(self): + estimator = estimator_lib.Estimator(model_fn=lambda features: features) + invalid_train_spec = object() + eval_spec = training.EvalSpec(input_fn=lambda: 1) + + with self.assertRaisesRegexp(TypeError, _INVALID_TRAIN_SPEC_MSG): + training._TrainingExecutor(estimator, invalid_train_spec, eval_spec) + + def test_invalid_eval_spec(self): + estimator = estimator_lib.Estimator(model_fn=lambda features: features) + train_spec = training.TrainSpec(input_fn=lambda: 1) + invalid_eval_spec = object() + + with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG): + training._TrainingExecutor(estimator, train_spec, invalid_eval_spec) + if __name__ == '__main__': test.main()