Introduces the placeholder for _TrainingExecutor, which serves the implementation of tf.estimator.train_and_evaluate.
PiperOrigin-RevId: 168240151
This commit is contained in:
parent
10ba148f77
commit
a4042cd2a4
@ -76,6 +76,7 @@ py_library(
|
||||
srcs = ["training.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":estimator",
|
||||
"//tensorflow/python:training",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
|
@ -23,6 +23,7 @@ import collections
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.estimator import estimator as estimator_lib
|
||||
from tensorflow.python.training import session_run_hook
|
||||
|
||||
|
||||
@ -177,3 +178,51 @@ class EvalSpec(
|
||||
delay_secs=delay_secs,
|
||||
throttle_secs=throttle_secs)
|
||||
|
||||
|
||||
class UnimplementedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class _TrainingExecutor(object):
|
||||
"""The executor to run `Estimator` training and evaluation.
|
||||
|
||||
This implementation supports both distributed and non-distributed (aka local)
|
||||
training and evaluation based on the setting in `tf.estimator.RunConfig`.
|
||||
"""
|
||||
|
||||
def __init__(self, estimator, train_spec, eval_spec):
|
||||
if not isinstance(estimator, estimator_lib.Estimator):
|
||||
raise TypeError('`estimator` must have type `tf.estimator.Estimator`.')
|
||||
self._estimator = estimator
|
||||
|
||||
if not isinstance(train_spec, TrainSpec):
|
||||
raise TypeError('`train_spec` must have type `tf.estimator.TrainSpec`.')
|
||||
self._train_spec = train_spec
|
||||
|
||||
if not isinstance(eval_spec, EvalSpec):
|
||||
raise TypeError('`eval_spec` must have type `tf.estimator.EvalSpec`.')
|
||||
self._eval_spec = eval_spec
|
||||
|
||||
@property
|
||||
def estimator(self):
|
||||
return self._estimator
|
||||
|
||||
def run_chief(self):
|
||||
"""Runs task chief."""
|
||||
raise UnimplementedError('Method run_chief has not been implemented.')
|
||||
|
||||
def run_worker(self):
|
||||
"""Runs task (training) worker."""
|
||||
raise UnimplementedError('Method run_worker has not been implemented.')
|
||||
|
||||
def run_evaluator(self):
|
||||
"""Runs task evaluator."""
|
||||
raise UnimplementedError('Method run_evaluator has not been implemented.')
|
||||
|
||||
def run_ps(self):
|
||||
"""Runs task parameter server (in training cluster spec)."""
|
||||
raise UnimplementedError('Method run_ps has not been implemented.')
|
||||
|
||||
def run_local(self):
|
||||
"""Runs training and evaluation locally (non-distributed)."""
|
||||
raise UnimplementedError('Method run_local has not been implemented.')
|
||||
|
@ -19,6 +19,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.estimator import estimator as estimator_lib
|
||||
from tensorflow.python.estimator import training
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import session_run_hook
|
||||
@ -33,6 +34,9 @@ _INVALID_STEPS_MSG = 'Must specify steps > 0'
|
||||
_INVALID_NAME_MSG = '`name` must be string'
|
||||
_INVALID_EVAL_DELAY_SECS_MSG = 'Must specify delay_secs >= 0'
|
||||
_INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0'
|
||||
_INVALID_ESTIMATOR_MSG = '`estimator` must have type `tf.estimator.Estimator`'
|
||||
_INVALID_TRAIN_SPEC_MSG = '`train_spec` must have type `tf.estimator.TrainSpec`'
|
||||
_INVALID_EVAL_SPEC_MSG = '`eval_spec` must have type `tf.estimator.EvalSpec`'
|
||||
|
||||
|
||||
class _FakeHook(session_run_hook.SessionRunHook):
|
||||
@ -129,5 +133,40 @@ class EvalSpecTest(test.TestCase):
|
||||
training.EvalSpec(input_fn=lambda: 1, throttle_secs=-1)
|
||||
|
||||
|
||||
class TrainingExecutorTest(test.TestCase):
|
||||
"""Tests _TrainingExecutor."""
|
||||
|
||||
def testRequiredArgumentsSet(self):
|
||||
estimator = estimator_lib.Estimator(model_fn=lambda features: features)
|
||||
train_spec = training.TrainSpec(input_fn=lambda: 1)
|
||||
eval_spec = training.EvalSpec(input_fn=lambda: 1)
|
||||
|
||||
executor = training._TrainingExecutor(estimator, train_spec, eval_spec)
|
||||
self.assertEqual(estimator, executor.estimator)
|
||||
|
||||
def test_invalid_estimator(self):
|
||||
invalid_estimator = object()
|
||||
train_spec = training.TrainSpec(input_fn=lambda: 1)
|
||||
eval_spec = training.EvalSpec(input_fn=lambda: 1)
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, _INVALID_ESTIMATOR_MSG):
|
||||
training._TrainingExecutor(invalid_estimator, train_spec, eval_spec)
|
||||
|
||||
def test_invalid_train_spec(self):
|
||||
estimator = estimator_lib.Estimator(model_fn=lambda features: features)
|
||||
invalid_train_spec = object()
|
||||
eval_spec = training.EvalSpec(input_fn=lambda: 1)
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, _INVALID_TRAIN_SPEC_MSG):
|
||||
training._TrainingExecutor(estimator, invalid_train_spec, eval_spec)
|
||||
|
||||
def test_invalid_eval_spec(self):
|
||||
estimator = estimator_lib.Estimator(model_fn=lambda features: features)
|
||||
train_spec = training.TrainSpec(input_fn=lambda: 1)
|
||||
invalid_eval_spec = object()
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
|
||||
training._TrainingExecutor(estimator, train_spec, invalid_eval_spec)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user