Introduces TrainSpec and EvalSpec.
PiperOrigin-RevId: 168040435
This commit is contained in:
parent
c8b9e92f07
commit
86f1713e51
@ -22,6 +22,7 @@ py_library(
|
|||||||
":model_fn",
|
":model_fn",
|
||||||
":parsing_utils",
|
":parsing_utils",
|
||||||
":run_config",
|
":run_config",
|
||||||
|
":training",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -70,6 +71,27 @@ py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "training",
|
||||||
|
srcs = ["training.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:training",
|
||||||
|
"@six_archive//:six",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "training_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["training_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":training",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "run_config",
|
name = "run_config",
|
||||||
srcs = ["run_config.py"],
|
srcs = ["run_config.py"],
|
||||||
|
179
tensorflow/python/estimator/training.py
Normal file
179
tensorflow/python/estimator/training.py
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Classes and functions related to train_and_evaluate."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
|
from tensorflow.python.training import session_run_hook
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_input_fn(input_fn):
|
||||||
|
"""Validates the `input_fn`."""
|
||||||
|
if not callable(input_fn):
|
||||||
|
raise TypeError(
|
||||||
|
'`input_fn` must be callable, given: {}'.format(input_fn))
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_hooks(hooks):
|
||||||
|
"""Validates the `hooks`."""
|
||||||
|
hooks = tuple(hooks or [])
|
||||||
|
for hook in hooks:
|
||||||
|
if not isinstance(hook, session_run_hook.SessionRunHook):
|
||||||
|
raise TypeError(
|
||||||
|
'All hooks must be `SessionRunHook` instances, given: {}'.format(
|
||||||
|
hook))
|
||||||
|
return hooks
|
||||||
|
|
||||||
|
|
||||||
|
class TrainSpec(
|
||||||
|
collections.namedtuple('TrainSpec', ['input_fn', 'max_steps', 'hooks'])):
|
||||||
|
"""Objects passed to `train_and_evaluate`.
|
||||||
|
|
||||||
|
`TrainSpec` fully defines the objects to be run by `Estimator.train`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __new__(cls,
|
||||||
|
input_fn,
|
||||||
|
max_steps=None,
|
||||||
|
hooks=None):
|
||||||
|
"""Creates a validated `TrainSpec` instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_fn: Training input function returning a tuple of:
|
||||||
|
features - `Tensor` or dictionary of string feature name to `Tensor`.
|
||||||
|
labels - `Tensor` or dictionary of `Tensor` with labels.
|
||||||
|
max_steps: Int. Number of total steps for which to train model. If `None`,
|
||||||
|
train forever or train until `input_fn` generates the `OutOfRange` error
|
||||||
|
or `StopIteration` exception. See `Estimator.train` for details.
|
||||||
|
hooks: Iterable of `tf.train.SessionRunHook` objects to run
|
||||||
|
on all workers (including chief) during training.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A validated `TrainSpec` object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If validation fails.
|
||||||
|
TypeError: If any of the arguments is not the expected type.
|
||||||
|
"""
|
||||||
|
# Validate input_fn.
|
||||||
|
_validate_input_fn(input_fn)
|
||||||
|
|
||||||
|
# Validate max_steps.
|
||||||
|
if max_steps is not None and max_steps <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
'Must specify max_steps > 0, given: {}'.format(max_steps))
|
||||||
|
|
||||||
|
# Validate hooks.
|
||||||
|
hooks = _validate_hooks(hooks)
|
||||||
|
|
||||||
|
return super(TrainSpec, cls).__new__(
|
||||||
|
cls,
|
||||||
|
input_fn=input_fn,
|
||||||
|
max_steps=max_steps,
|
||||||
|
hooks=hooks)
|
||||||
|
|
||||||
|
|
||||||
|
class EvalSpec(
|
||||||
|
collections.namedtuple('EvalSpec', [
|
||||||
|
'input_fn', 'steps', 'name', 'hooks', 'export_strategies',
|
||||||
|
'delay_secs', 'throttle_secs'
|
||||||
|
])):
|
||||||
|
"""Objects passed to `train_and_evaluate`.
|
||||||
|
|
||||||
|
`EvalSpec` fully defines the objects to be run by `Estimator.evaluate` and
|
||||||
|
`Estimator.export_savedmodel`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __new__(cls,
|
||||||
|
input_fn,
|
||||||
|
steps=100,
|
||||||
|
name=None,
|
||||||
|
hooks=None,
|
||||||
|
export_strategies=None,
|
||||||
|
delay_secs=120,
|
||||||
|
throttle_secs=60):
|
||||||
|
"""Creates a validated `EvalSpec` instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_fn: Training input function returning a tuple of:
|
||||||
|
features - `Tensor` or dictionary of string feature name to `Tensor`.
|
||||||
|
labels - `Tensor` or dictionary of `Tensor` with labels.
|
||||||
|
steps: Int. Number of total steps for which to train model. If `None`,
|
||||||
|
train forever or train until `input_fn` generates the `OutOfRange` error
|
||||||
|
or `StopIteration` exception. See `Estimator.train` for details.
|
||||||
|
name: String. Name of the evaluation if user needs to run multiple
|
||||||
|
evaluations on different data sets. Metrics for different evaluations
|
||||||
|
are saved in separate folders, and appear separately in tensorboard.
|
||||||
|
hooks: Iterable of `tf.train.SessionRunHook` objects to run
|
||||||
|
on all workers (including chief) during training.
|
||||||
|
export_strategies: Iterable of `ExportStrategy`s, or a single one, or
|
||||||
|
`None`. `export_strategies` will be invoked after each evaluation.
|
||||||
|
delay_secs: Int. Start evaluating after waiting for this many seconds.
|
||||||
|
throttle_secs: Int. Do not re-evaluate unless the last evaluation was
|
||||||
|
started at least this many seconds ago. Of course, evaluation does not
|
||||||
|
occur if no new checkpoint is available, hence, this is the minimum.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A validated `TrainSpec` object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If validation fails.
|
||||||
|
TypeError: If any of the arguments is not the expected type.
|
||||||
|
"""
|
||||||
|
# Validate input_fn.
|
||||||
|
_validate_input_fn(input_fn)
|
||||||
|
|
||||||
|
# Validate steps.
|
||||||
|
if steps is not None and steps <= 0:
|
||||||
|
raise ValueError('Must specify steps > 0, given: {}'.format(steps))
|
||||||
|
|
||||||
|
# Validate name.
|
||||||
|
if name is not None and not isinstance(name, six.string_types):
|
||||||
|
raise TypeError('`name` must be string, given: {}'.format(name))
|
||||||
|
|
||||||
|
# Validate hooks.
|
||||||
|
hooks = _validate_hooks(hooks)
|
||||||
|
|
||||||
|
# Validate export_strategies.
|
||||||
|
export_strategies = tuple(export_strategies or [])
|
||||||
|
# TODO(b/65169058): Validate export_strategies once `ExportStratey` defined.
|
||||||
|
|
||||||
|
# Validate delay_secs.
|
||||||
|
if delay_secs < 0:
|
||||||
|
raise ValueError(
|
||||||
|
'Must specify delay_secs >= 0, given: {}'.format(delay_secs))
|
||||||
|
|
||||||
|
# Validate throttle_secs.
|
||||||
|
if throttle_secs < 0:
|
||||||
|
raise ValueError(
|
||||||
|
'Must specify throttle_secs >= 0, given: {}'.format(throttle_secs))
|
||||||
|
|
||||||
|
return super(EvalSpec, cls).__new__(
|
||||||
|
cls,
|
||||||
|
input_fn=input_fn,
|
||||||
|
steps=steps,
|
||||||
|
name=name,
|
||||||
|
hooks=hooks,
|
||||||
|
export_strategies=export_strategies,
|
||||||
|
delay_secs=delay_secs,
|
||||||
|
throttle_secs=throttle_secs)
|
||||||
|
|
133
tensorflow/python/estimator/training_test.py
Normal file
133
tensorflow/python/estimator/training_test.py
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Tests for training.py."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.estimator import training
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.training import session_run_hook
|
||||||
|
|
||||||
|
_DEFAULT_EVAL_STEPS = 100
|
||||||
|
_DEFAULT_EVAL_DELAY_SECS = 120
|
||||||
|
_DEFAULT_EVAL_THROTTLE_SECS = 60
|
||||||
|
_INVALID_INPUT_FN_MSG = '`input_fn` must be callable'
|
||||||
|
_INVALID_HOOK_MSG = 'All hooks must be `SessionRunHook` instances'
|
||||||
|
_INVALID_MAX_STEPS_MSG = 'Must specify max_steps > 0'
|
||||||
|
_INVALID_STEPS_MSG = 'Must specify steps > 0'
|
||||||
|
_INVALID_NAME_MSG = '`name` must be string'
|
||||||
|
_INVALID_EVAL_DELAY_SECS_MSG = 'Must specify delay_secs >= 0'
|
||||||
|
_INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0'
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeHook(session_run_hook.SessionRunHook):
|
||||||
|
"""Fake implementation of `SessionRunHook`."""
|
||||||
|
|
||||||
|
|
||||||
|
class _InvalidHook(object):
|
||||||
|
"""Invalid hook (not a subclass of `SessionRunHook`)."""
|
||||||
|
|
||||||
|
|
||||||
|
class TrainSpecTest(test.TestCase):
|
||||||
|
"""Tests TrainSpec."""
|
||||||
|
|
||||||
|
def testRequiredArgumentsSet(self):
|
||||||
|
"""Tests that no errors are raised when all required arguments are set."""
|
||||||
|
spec = training.TrainSpec(input_fn=lambda: 1)
|
||||||
|
self.assertEqual(1, spec.input_fn())
|
||||||
|
self.assertIsNone(spec.max_steps)
|
||||||
|
self.assertEqual(0, len(spec.hooks))
|
||||||
|
|
||||||
|
def testAllArgumentsSet(self):
|
||||||
|
"""Tests that no errors are raised when all arguments are set."""
|
||||||
|
hooks = [_FakeHook()]
|
||||||
|
spec = training.TrainSpec(input_fn=lambda: 1, max_steps=2, hooks=hooks)
|
||||||
|
self.assertEqual(1, spec.input_fn())
|
||||||
|
self.assertEqual(2, spec.max_steps)
|
||||||
|
self.assertEqual(tuple(hooks), spec.hooks)
|
||||||
|
|
||||||
|
def testInvalidInputFn(self):
|
||||||
|
with self.assertRaisesRegexp(TypeError, _INVALID_INPUT_FN_MSG):
|
||||||
|
training.TrainSpec(input_fn='invalid')
|
||||||
|
|
||||||
|
def testInvalidMaxStep(self):
|
||||||
|
with self.assertRaisesRegexp(ValueError, _INVALID_MAX_STEPS_MSG):
|
||||||
|
training.TrainSpec(input_fn=lambda: 1, max_steps=0)
|
||||||
|
|
||||||
|
def testInvalidHook(self):
|
||||||
|
with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):
|
||||||
|
training.TrainSpec(input_fn=lambda: 1, hooks=[_InvalidHook()])
|
||||||
|
|
||||||
|
|
||||||
|
class EvalSpecTest(test.TestCase):
|
||||||
|
"""Tests EvalSpec."""
|
||||||
|
|
||||||
|
def testRequiredArgumentsSet(self):
|
||||||
|
"""Tests that no errors are raised when all required arguments are set."""
|
||||||
|
spec = training.EvalSpec(input_fn=lambda: 1)
|
||||||
|
self.assertEqual(1, spec.input_fn())
|
||||||
|
self.assertEqual(_DEFAULT_EVAL_STEPS, spec.steps)
|
||||||
|
self.assertIsNone(spec.name)
|
||||||
|
self.assertEqual(0, len(spec.hooks))
|
||||||
|
self.assertEqual(0, len(spec.export_strategies))
|
||||||
|
self.assertEqual(_DEFAULT_EVAL_DELAY_SECS, spec.delay_secs)
|
||||||
|
self.assertEqual(_DEFAULT_EVAL_THROTTLE_SECS, spec.throttle_secs)
|
||||||
|
|
||||||
|
def testAllArgumentsSet(self):
|
||||||
|
"""Tests that no errors are raised when all arguments are set."""
|
||||||
|
hooks = [_FakeHook()]
|
||||||
|
|
||||||
|
# TODO(b/65169058): Replace the export_strategies with valid instances.
|
||||||
|
spec = training.EvalSpec(input_fn=lambda: 1, steps=2, name='name',
|
||||||
|
hooks=hooks, export_strategies=hooks,
|
||||||
|
delay_secs=3, throttle_secs=4)
|
||||||
|
self.assertEqual(1, spec.input_fn())
|
||||||
|
self.assertEqual(2, spec.steps)
|
||||||
|
self.assertEqual('name', spec.name)
|
||||||
|
self.assertEqual(tuple(hooks), spec.hooks)
|
||||||
|
self.assertEqual(tuple(hooks), spec.export_strategies)
|
||||||
|
self.assertEqual(3, spec.delay_secs)
|
||||||
|
self.assertEqual(4, spec.throttle_secs)
|
||||||
|
|
||||||
|
def testInvalidInputFn(self):
|
||||||
|
with self.assertRaisesRegexp(TypeError, _INVALID_INPUT_FN_MSG):
|
||||||
|
training.EvalSpec(input_fn='invalid')
|
||||||
|
|
||||||
|
def testInvalidMaxStep(self):
|
||||||
|
with self.assertRaisesRegexp(ValueError, _INVALID_STEPS_MSG):
|
||||||
|
training.EvalSpec(input_fn=lambda: 1, steps=0)
|
||||||
|
|
||||||
|
def testInvalidName(self):
|
||||||
|
with self.assertRaisesRegexp(TypeError, _INVALID_NAME_MSG):
|
||||||
|
training.EvalSpec(input_fn=lambda: 1, name=123)
|
||||||
|
|
||||||
|
def testInvalidHook(self):
|
||||||
|
with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):
|
||||||
|
training.EvalSpec(input_fn=lambda: 1, hooks=[_InvalidHook()])
|
||||||
|
|
||||||
|
def testInvalidDelaySecs(self):
|
||||||
|
with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_DELAY_SECS_MSG):
|
||||||
|
training.EvalSpec(input_fn=lambda: 1, delay_secs=-1)
|
||||||
|
|
||||||
|
def testInvalidThrottleSecs(self):
|
||||||
|
with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_THROTTLE_SECS_MSG):
|
||||||
|
training.EvalSpec(input_fn=lambda: 1, throttle_secs=-1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
Loading…
Reference in New Issue
Block a user