Introduces TrainSpec and EvalSpec.
PiperOrigin-RevId: 168040435
This commit is contained in:
parent
c8b9e92f07
commit
86f1713e51
@ -22,6 +22,7 @@ py_library(
|
||||
":model_fn",
|
||||
":parsing_utils",
|
||||
":run_config",
|
||||
":training",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
@ -70,6 +71,27 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "training",
|
||||
srcs = ["training.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:training",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "training_test",
|
||||
size = "small",
|
||||
srcs = ["training_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":training",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "run_config",
|
||||
srcs = ["run_config.py"],
|
||||
|
179
tensorflow/python/estimator/training.py
Normal file
179
tensorflow/python/estimator/training.py
Normal file
@ -0,0 +1,179 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Classes and functions related to train_and_evaluate."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.training import session_run_hook
|
||||
|
||||
|
||||
def _validate_input_fn(input_fn):
|
||||
"""Validates the `input_fn`."""
|
||||
if not callable(input_fn):
|
||||
raise TypeError(
|
||||
'`input_fn` must be callable, given: {}'.format(input_fn))
|
||||
|
||||
|
||||
def _validate_hooks(hooks):
|
||||
"""Validates the `hooks`."""
|
||||
hooks = tuple(hooks or [])
|
||||
for hook in hooks:
|
||||
if not isinstance(hook, session_run_hook.SessionRunHook):
|
||||
raise TypeError(
|
||||
'All hooks must be `SessionRunHook` instances, given: {}'.format(
|
||||
hook))
|
||||
return hooks
|
||||
|
||||
|
||||
class TrainSpec(
|
||||
collections.namedtuple('TrainSpec', ['input_fn', 'max_steps', 'hooks'])):
|
||||
"""Objects passed to `train_and_evaluate`.
|
||||
|
||||
`TrainSpec` fully defines the objects to be run by `Estimator.train`.
|
||||
"""
|
||||
|
||||
def __new__(cls,
|
||||
input_fn,
|
||||
max_steps=None,
|
||||
hooks=None):
|
||||
"""Creates a validated `TrainSpec` instance.
|
||||
|
||||
Args:
|
||||
input_fn: Training input function returning a tuple of:
|
||||
features - `Tensor` or dictionary of string feature name to `Tensor`.
|
||||
labels - `Tensor` or dictionary of `Tensor` with labels.
|
||||
max_steps: Int. Number of total steps for which to train model. If `None`,
|
||||
train forever or train until `input_fn` generates the `OutOfRange` error
|
||||
or `StopIteration` exception. See `Estimator.train` for details.
|
||||
hooks: Iterable of `tf.train.SessionRunHook` objects to run
|
||||
on all workers (including chief) during training.
|
||||
|
||||
Returns:
|
||||
A validated `TrainSpec` object.
|
||||
|
||||
Raises:
|
||||
ValueError: If validation fails.
|
||||
TypeError: If any of the arguments is not the expected type.
|
||||
"""
|
||||
# Validate input_fn.
|
||||
_validate_input_fn(input_fn)
|
||||
|
||||
# Validate max_steps.
|
||||
if max_steps is not None and max_steps <= 0:
|
||||
raise ValueError(
|
||||
'Must specify max_steps > 0, given: {}'.format(max_steps))
|
||||
|
||||
# Validate hooks.
|
||||
hooks = _validate_hooks(hooks)
|
||||
|
||||
return super(TrainSpec, cls).__new__(
|
||||
cls,
|
||||
input_fn=input_fn,
|
||||
max_steps=max_steps,
|
||||
hooks=hooks)
|
||||
|
||||
|
||||
class EvalSpec(
|
||||
collections.namedtuple('EvalSpec', [
|
||||
'input_fn', 'steps', 'name', 'hooks', 'export_strategies',
|
||||
'delay_secs', 'throttle_secs'
|
||||
])):
|
||||
"""Objects passed to `train_and_evaluate`.
|
||||
|
||||
`EvalSpec` fully defines the objects to be run by `Estimator.evaluate` and
|
||||
`Estimator.export_savedmodel`.
|
||||
"""
|
||||
|
||||
def __new__(cls,
|
||||
input_fn,
|
||||
steps=100,
|
||||
name=None,
|
||||
hooks=None,
|
||||
export_strategies=None,
|
||||
delay_secs=120,
|
||||
throttle_secs=60):
|
||||
"""Creates a validated `EvalSpec` instance.
|
||||
|
||||
Args:
|
||||
input_fn: Training input function returning a tuple of:
|
||||
features - `Tensor` or dictionary of string feature name to `Tensor`.
|
||||
labels - `Tensor` or dictionary of `Tensor` with labels.
|
||||
steps: Int. Number of total steps for which to train model. If `None`,
|
||||
train forever or train until `input_fn` generates the `OutOfRange` error
|
||||
or `StopIteration` exception. See `Estimator.train` for details.
|
||||
name: String. Name of the evaluation if user needs to run multiple
|
||||
evaluations on different data sets. Metrics for different evaluations
|
||||
are saved in separate folders, and appear separately in tensorboard.
|
||||
hooks: Iterable of `tf.train.SessionRunHook` objects to run
|
||||
on all workers (including chief) during training.
|
||||
export_strategies: Iterable of `ExportStrategy`s, or a single one, or
|
||||
`None`. `export_strategies` will be invoked after each evaluation.
|
||||
delay_secs: Int. Start evaluating after waiting for this many seconds.
|
||||
throttle_secs: Int. Do not re-evaluate unless the last evaluation was
|
||||
started at least this many seconds ago. Of course, evaluation does not
|
||||
occur if no new checkpoint is available, hence, this is the minimum.
|
||||
|
||||
Returns:
|
||||
A validated `TrainSpec` object.
|
||||
|
||||
Raises:
|
||||
ValueError: If validation fails.
|
||||
TypeError: If any of the arguments is not the expected type.
|
||||
"""
|
||||
# Validate input_fn.
|
||||
_validate_input_fn(input_fn)
|
||||
|
||||
# Validate steps.
|
||||
if steps is not None and steps <= 0:
|
||||
raise ValueError('Must specify steps > 0, given: {}'.format(steps))
|
||||
|
||||
# Validate name.
|
||||
if name is not None and not isinstance(name, six.string_types):
|
||||
raise TypeError('`name` must be string, given: {}'.format(name))
|
||||
|
||||
# Validate hooks.
|
||||
hooks = _validate_hooks(hooks)
|
||||
|
||||
# Validate export_strategies.
|
||||
export_strategies = tuple(export_strategies or [])
|
||||
# TODO(b/65169058): Validate export_strategies once `ExportStratey` defined.
|
||||
|
||||
# Validate delay_secs.
|
||||
if delay_secs < 0:
|
||||
raise ValueError(
|
||||
'Must specify delay_secs >= 0, given: {}'.format(delay_secs))
|
||||
|
||||
# Validate throttle_secs.
|
||||
if throttle_secs < 0:
|
||||
raise ValueError(
|
||||
'Must specify throttle_secs >= 0, given: {}'.format(throttle_secs))
|
||||
|
||||
return super(EvalSpec, cls).__new__(
|
||||
cls,
|
||||
input_fn=input_fn,
|
||||
steps=steps,
|
||||
name=name,
|
||||
hooks=hooks,
|
||||
export_strategies=export_strategies,
|
||||
delay_secs=delay_secs,
|
||||
throttle_secs=throttle_secs)
|
||||
|
133
tensorflow/python/estimator/training_test.py
Normal file
133
tensorflow/python/estimator/training_test.py
Normal file
@ -0,0 +1,133 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for training.py."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.estimator import training
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import session_run_hook
|
||||
|
||||
_DEFAULT_EVAL_STEPS = 100
|
||||
_DEFAULT_EVAL_DELAY_SECS = 120
|
||||
_DEFAULT_EVAL_THROTTLE_SECS = 60
|
||||
_INVALID_INPUT_FN_MSG = '`input_fn` must be callable'
|
||||
_INVALID_HOOK_MSG = 'All hooks must be `SessionRunHook` instances'
|
||||
_INVALID_MAX_STEPS_MSG = 'Must specify max_steps > 0'
|
||||
_INVALID_STEPS_MSG = 'Must specify steps > 0'
|
||||
_INVALID_NAME_MSG = '`name` must be string'
|
||||
_INVALID_EVAL_DELAY_SECS_MSG = 'Must specify delay_secs >= 0'
|
||||
_INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0'
|
||||
|
||||
|
||||
class _FakeHook(session_run_hook.SessionRunHook):
|
||||
"""Fake implementation of `SessionRunHook`."""
|
||||
|
||||
|
||||
class _InvalidHook(object):
|
||||
"""Invalid hook (not a subclass of `SessionRunHook`)."""
|
||||
|
||||
|
||||
class TrainSpecTest(test.TestCase):
|
||||
"""Tests TrainSpec."""
|
||||
|
||||
def testRequiredArgumentsSet(self):
|
||||
"""Tests that no errors are raised when all required arguments are set."""
|
||||
spec = training.TrainSpec(input_fn=lambda: 1)
|
||||
self.assertEqual(1, spec.input_fn())
|
||||
self.assertIsNone(spec.max_steps)
|
||||
self.assertEqual(0, len(spec.hooks))
|
||||
|
||||
def testAllArgumentsSet(self):
|
||||
"""Tests that no errors are raised when all arguments are set."""
|
||||
hooks = [_FakeHook()]
|
||||
spec = training.TrainSpec(input_fn=lambda: 1, max_steps=2, hooks=hooks)
|
||||
self.assertEqual(1, spec.input_fn())
|
||||
self.assertEqual(2, spec.max_steps)
|
||||
self.assertEqual(tuple(hooks), spec.hooks)
|
||||
|
||||
def testInvalidInputFn(self):
|
||||
with self.assertRaisesRegexp(TypeError, _INVALID_INPUT_FN_MSG):
|
||||
training.TrainSpec(input_fn='invalid')
|
||||
|
||||
def testInvalidMaxStep(self):
|
||||
with self.assertRaisesRegexp(ValueError, _INVALID_MAX_STEPS_MSG):
|
||||
training.TrainSpec(input_fn=lambda: 1, max_steps=0)
|
||||
|
||||
def testInvalidHook(self):
|
||||
with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):
|
||||
training.TrainSpec(input_fn=lambda: 1, hooks=[_InvalidHook()])
|
||||
|
||||
|
||||
class EvalSpecTest(test.TestCase):
|
||||
"""Tests EvalSpec."""
|
||||
|
||||
def testRequiredArgumentsSet(self):
|
||||
"""Tests that no errors are raised when all required arguments are set."""
|
||||
spec = training.EvalSpec(input_fn=lambda: 1)
|
||||
self.assertEqual(1, spec.input_fn())
|
||||
self.assertEqual(_DEFAULT_EVAL_STEPS, spec.steps)
|
||||
self.assertIsNone(spec.name)
|
||||
self.assertEqual(0, len(spec.hooks))
|
||||
self.assertEqual(0, len(spec.export_strategies))
|
||||
self.assertEqual(_DEFAULT_EVAL_DELAY_SECS, spec.delay_secs)
|
||||
self.assertEqual(_DEFAULT_EVAL_THROTTLE_SECS, spec.throttle_secs)
|
||||
|
||||
def testAllArgumentsSet(self):
|
||||
"""Tests that no errors are raised when all arguments are set."""
|
||||
hooks = [_FakeHook()]
|
||||
|
||||
# TODO(b/65169058): Replace the export_strategies with valid instances.
|
||||
spec = training.EvalSpec(input_fn=lambda: 1, steps=2, name='name',
|
||||
hooks=hooks, export_strategies=hooks,
|
||||
delay_secs=3, throttle_secs=4)
|
||||
self.assertEqual(1, spec.input_fn())
|
||||
self.assertEqual(2, spec.steps)
|
||||
self.assertEqual('name', spec.name)
|
||||
self.assertEqual(tuple(hooks), spec.hooks)
|
||||
self.assertEqual(tuple(hooks), spec.export_strategies)
|
||||
self.assertEqual(3, spec.delay_secs)
|
||||
self.assertEqual(4, spec.throttle_secs)
|
||||
|
||||
def testInvalidInputFn(self):
|
||||
with self.assertRaisesRegexp(TypeError, _INVALID_INPUT_FN_MSG):
|
||||
training.EvalSpec(input_fn='invalid')
|
||||
|
||||
def testInvalidMaxStep(self):
|
||||
with self.assertRaisesRegexp(ValueError, _INVALID_STEPS_MSG):
|
||||
training.EvalSpec(input_fn=lambda: 1, steps=0)
|
||||
|
||||
def testInvalidName(self):
|
||||
with self.assertRaisesRegexp(TypeError, _INVALID_NAME_MSG):
|
||||
training.EvalSpec(input_fn=lambda: 1, name=123)
|
||||
|
||||
def testInvalidHook(self):
|
||||
with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):
|
||||
training.EvalSpec(input_fn=lambda: 1, hooks=[_InvalidHook()])
|
||||
|
||||
def testInvalidDelaySecs(self):
|
||||
with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_DELAY_SECS_MSG):
|
||||
training.EvalSpec(input_fn=lambda: 1, delay_secs=-1)
|
||||
|
||||
def testInvalidThrottleSecs(self):
|
||||
with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_THROTTLE_SECS_MSG):
|
||||
training.EvalSpec(input_fn=lambda: 1, throttle_secs=-1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
Loading…
Reference in New Issue
Block a user