Introduces the placeholder for _TrainingExecutor, which serves the implementation of tf.estimator.train_and_evaluate.
PiperOrigin-RevId: 168240151
This commit is contained in:
parent
10ba148f77
commit
a4042cd2a4
@ -76,6 +76,7 @@ py_library(
|
|||||||
srcs = ["training.py"],
|
srcs = ["training.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
":estimator",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
|
@ -23,6 +23,7 @@ import collections
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from tensorflow.python.estimator import estimator as estimator_lib
|
||||||
from tensorflow.python.training import session_run_hook
|
from tensorflow.python.training import session_run_hook
|
||||||
|
|
||||||
|
|
||||||
@ -177,3 +178,51 @@ class EvalSpec(
|
|||||||
delay_secs=delay_secs,
|
delay_secs=delay_secs,
|
||||||
throttle_secs=throttle_secs)
|
throttle_secs=throttle_secs)
|
||||||
|
|
||||||
|
|
||||||
|
class UnimplementedError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class _TrainingExecutor(object):
|
||||||
|
"""The executor to run `Estimator` training and evaluation.
|
||||||
|
|
||||||
|
This implementation supports both distributed and non-distributed (aka local)
|
||||||
|
training and evaluation based on the setting in `tf.estimator.RunConfig`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, estimator, train_spec, eval_spec):
|
||||||
|
if not isinstance(estimator, estimator_lib.Estimator):
|
||||||
|
raise TypeError('`estimator` must have type `tf.estimator.Estimator`.')
|
||||||
|
self._estimator = estimator
|
||||||
|
|
||||||
|
if not isinstance(train_spec, TrainSpec):
|
||||||
|
raise TypeError('`train_spec` must have type `tf.estimator.TrainSpec`.')
|
||||||
|
self._train_spec = train_spec
|
||||||
|
|
||||||
|
if not isinstance(eval_spec, EvalSpec):
|
||||||
|
raise TypeError('`eval_spec` must have type `tf.estimator.EvalSpec`.')
|
||||||
|
self._eval_spec = eval_spec
|
||||||
|
|
||||||
|
@property
|
||||||
|
def estimator(self):
|
||||||
|
return self._estimator
|
||||||
|
|
||||||
|
def run_chief(self):
|
||||||
|
"""Runs task chief."""
|
||||||
|
raise UnimplementedError('Method run_chief has not been implemented.')
|
||||||
|
|
||||||
|
def run_worker(self):
|
||||||
|
"""Runs task (training) worker."""
|
||||||
|
raise UnimplementedError('Method run_worker has not been implemented.')
|
||||||
|
|
||||||
|
def run_evaluator(self):
|
||||||
|
"""Runs task evaluator."""
|
||||||
|
raise UnimplementedError('Method run_evaluator has not been implemented.')
|
||||||
|
|
||||||
|
def run_ps(self):
|
||||||
|
"""Runs task parameter server (in training cluster spec)."""
|
||||||
|
raise UnimplementedError('Method run_ps has not been implemented.')
|
||||||
|
|
||||||
|
def run_local(self):
|
||||||
|
"""Runs training and evaluation locally (non-distributed)."""
|
||||||
|
raise UnimplementedError('Method run_local has not been implemented.')
|
||||||
|
@ -19,6 +19,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.estimator import estimator as estimator_lib
|
||||||
from tensorflow.python.estimator import training
|
from tensorflow.python.estimator import training
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.training import session_run_hook
|
from tensorflow.python.training import session_run_hook
|
||||||
@ -33,6 +34,9 @@ _INVALID_STEPS_MSG = 'Must specify steps > 0'
|
|||||||
_INVALID_NAME_MSG = '`name` must be string'
|
_INVALID_NAME_MSG = '`name` must be string'
|
||||||
_INVALID_EVAL_DELAY_SECS_MSG = 'Must specify delay_secs >= 0'
|
_INVALID_EVAL_DELAY_SECS_MSG = 'Must specify delay_secs >= 0'
|
||||||
_INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0'
|
_INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0'
|
||||||
|
_INVALID_ESTIMATOR_MSG = '`estimator` must have type `tf.estimator.Estimator`'
|
||||||
|
_INVALID_TRAIN_SPEC_MSG = '`train_spec` must have type `tf.estimator.TrainSpec`'
|
||||||
|
_INVALID_EVAL_SPEC_MSG = '`eval_spec` must have type `tf.estimator.EvalSpec`'
|
||||||
|
|
||||||
|
|
||||||
class _FakeHook(session_run_hook.SessionRunHook):
|
class _FakeHook(session_run_hook.SessionRunHook):
|
||||||
@ -129,5 +133,40 @@ class EvalSpecTest(test.TestCase):
|
|||||||
training.EvalSpec(input_fn=lambda: 1, throttle_secs=-1)
|
training.EvalSpec(input_fn=lambda: 1, throttle_secs=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingExecutorTest(test.TestCase):
|
||||||
|
"""Tests _TrainingExecutor."""
|
||||||
|
|
||||||
|
def testRequiredArgumentsSet(self):
|
||||||
|
estimator = estimator_lib.Estimator(model_fn=lambda features: features)
|
||||||
|
train_spec = training.TrainSpec(input_fn=lambda: 1)
|
||||||
|
eval_spec = training.EvalSpec(input_fn=lambda: 1)
|
||||||
|
|
||||||
|
executor = training._TrainingExecutor(estimator, train_spec, eval_spec)
|
||||||
|
self.assertEqual(estimator, executor.estimator)
|
||||||
|
|
||||||
|
def test_invalid_estimator(self):
|
||||||
|
invalid_estimator = object()
|
||||||
|
train_spec = training.TrainSpec(input_fn=lambda: 1)
|
||||||
|
eval_spec = training.EvalSpec(input_fn=lambda: 1)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(TypeError, _INVALID_ESTIMATOR_MSG):
|
||||||
|
training._TrainingExecutor(invalid_estimator, train_spec, eval_spec)
|
||||||
|
|
||||||
|
def test_invalid_train_spec(self):
|
||||||
|
estimator = estimator_lib.Estimator(model_fn=lambda features: features)
|
||||||
|
invalid_train_spec = object()
|
||||||
|
eval_spec = training.EvalSpec(input_fn=lambda: 1)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(TypeError, _INVALID_TRAIN_SPEC_MSG):
|
||||||
|
training._TrainingExecutor(estimator, invalid_train_spec, eval_spec)
|
||||||
|
|
||||||
|
def test_invalid_eval_spec(self):
|
||||||
|
estimator = estimator_lib.Estimator(model_fn=lambda features: features)
|
||||||
|
train_spec = training.TrainSpec(input_fn=lambda: 1)
|
||||||
|
invalid_eval_spec = object()
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
|
||||||
|
training._TrainingExecutor(estimator, train_spec, invalid_eval_spec)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user