Introduces the placeholder for _TrainingExecutor, which serves the implementation of tf.estimator.train_and_evaluate.

PiperOrigin-RevId: 168240151
This commit is contained in:
Jianwei Xie 2017-09-11 09:25:22 -07:00 committed by TensorFlower Gardener
parent 10ba148f77
commit a4042cd2a4
3 changed files with 89 additions and 0 deletions

View File

@ -76,6 +76,7 @@ py_library(
srcs = ["training.py"],
srcs_version = "PY2AND3",
deps = [
":estimator",
"//tensorflow/python:training",
"@six_archive//:six",
],

View File

@ -23,6 +23,7 @@ import collections
import six
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.training import session_run_hook
@ -177,3 +178,51 @@ class EvalSpec(
delay_secs=delay_secs,
throttle_secs=throttle_secs)
class UnimplementedError(Exception):
pass
class _TrainingExecutor(object):
"""The executor to run `Estimator` training and evaluation.
This implementation supports both distributed and non-distributed (aka local)
training and evaluation based on the setting in `tf.estimator.RunConfig`.
"""
def __init__(self, estimator, train_spec, eval_spec):
if not isinstance(estimator, estimator_lib.Estimator):
raise TypeError('`estimator` must have type `tf.estimator.Estimator`.')
self._estimator = estimator
if not isinstance(train_spec, TrainSpec):
raise TypeError('`train_spec` must have type `tf.estimator.TrainSpec`.')
self._train_spec = train_spec
if not isinstance(eval_spec, EvalSpec):
raise TypeError('`eval_spec` must have type `tf.estimator.EvalSpec`.')
self._eval_spec = eval_spec
@property
def estimator(self):
return self._estimator
def run_chief(self):
"""Runs task chief."""
raise UnimplementedError('Method run_chief has not been implemented.')
def run_worker(self):
"""Runs task (training) worker."""
raise UnimplementedError('Method run_worker has not been implemented.')
def run_evaluator(self):
"""Runs task evaluator."""
raise UnimplementedError('Method run_evaluator has not been implemented.')
def run_ps(self):
"""Runs task parameter server (in training cluster spec)."""
raise UnimplementedError('Method run_ps has not been implemented.')
def run_local(self):
"""Runs training and evaluation locally (non-distributed)."""
raise UnimplementedError('Method run_local has not been implemented.')

View File

@ -19,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import training
from tensorflow.python.platform import test
from tensorflow.python.training import session_run_hook
@ -33,6 +34,9 @@ _INVALID_STEPS_MSG = 'Must specify steps > 0'
_INVALID_NAME_MSG = '`name` must be string'
_INVALID_EVAL_DELAY_SECS_MSG = 'Must specify delay_secs >= 0'
_INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0'
_INVALID_ESTIMATOR_MSG = '`estimator` must have type `tf.estimator.Estimator`'
_INVALID_TRAIN_SPEC_MSG = '`train_spec` must have type `tf.estimator.TrainSpec`'
_INVALID_EVAL_SPEC_MSG = '`eval_spec` must have type `tf.estimator.EvalSpec`'
class _FakeHook(session_run_hook.SessionRunHook):
@ -129,5 +133,40 @@ class EvalSpecTest(test.TestCase):
training.EvalSpec(input_fn=lambda: 1, throttle_secs=-1)
class TrainingExecutorTest(test.TestCase):
"""Tests _TrainingExecutor."""
def testRequiredArgumentsSet(self):
estimator = estimator_lib.Estimator(model_fn=lambda features: features)
train_spec = training.TrainSpec(input_fn=lambda: 1)
eval_spec = training.EvalSpec(input_fn=lambda: 1)
executor = training._TrainingExecutor(estimator, train_spec, eval_spec)
self.assertEqual(estimator, executor.estimator)
def test_invalid_estimator(self):
invalid_estimator = object()
train_spec = training.TrainSpec(input_fn=lambda: 1)
eval_spec = training.EvalSpec(input_fn=lambda: 1)
with self.assertRaisesRegexp(TypeError, _INVALID_ESTIMATOR_MSG):
training._TrainingExecutor(invalid_estimator, train_spec, eval_spec)
def test_invalid_train_spec(self):
estimator = estimator_lib.Estimator(model_fn=lambda features: features)
invalid_train_spec = object()
eval_spec = training.EvalSpec(input_fn=lambda: 1)
with self.assertRaisesRegexp(TypeError, _INVALID_TRAIN_SPEC_MSG):
training._TrainingExecutor(estimator, invalid_train_spec, eval_spec)
def test_invalid_eval_spec(self):
estimator = estimator_lib.Estimator(model_fn=lambda features: features)
train_spec = training.TrainSpec(input_fn=lambda: 1)
invalid_eval_spec = object()
with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
training._TrainingExecutor(estimator, train_spec, invalid_eval_spec)
if __name__ == '__main__':
test.main()