Automated g4 rollback of changelist 159583264
PiperOrigin-RevId: 159630408
This commit is contained in:
parent
35af7113de
commit
5856f9ea6d
@ -357,7 +357,7 @@ class Estimator(object):
|
||||
}
|
||||
|
||||
def _assert_members_are_not_overridden(self):
|
||||
allowed_overrides = set(['_call_input_fn', '_create_global_step'])
|
||||
allowed_overrides = set(['_create_global_step'])
|
||||
estimator_members = set([m for m in Estimator.__dict__.keys()
|
||||
if not m.startswith('__')])
|
||||
subclass_members = set(self.__class__.__dict__.keys())
|
||||
@ -485,7 +485,7 @@ class Estimator(object):
|
||||
return export_dir
|
||||
|
||||
def _get_features_from_input_fn(self, input_fn):
|
||||
result = self._call_input_fn(input_fn)
|
||||
result = input_fn()
|
||||
if not ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
|
||||
logging.warning('Input graph does not contain a QueueRunner. '
|
||||
'That means predict yields forever. '
|
||||
@ -549,32 +549,6 @@ class Estimator(object):
|
||||
assert step.dtype.is_integer
|
||||
return step
|
||||
|
||||
def _call_input_fn(self, input_fn):
|
||||
"""Calls the input function.
|
||||
|
||||
Args:
|
||||
input_fn: The input function.
|
||||
|
||||
Returns:
|
||||
Either features or (features, labels) where features and labels are:
|
||||
features - `Tensor` or dictionary of string feature name to `Tensor`.
|
||||
labels - `Tensor` or dictionary of `Tensor` with labels.
|
||||
|
||||
Raises:
|
||||
ValueError: if input_fn takes invalid arguments.
|
||||
"""
|
||||
input_fn_args = _fn_args(input_fn)
|
||||
for arg in input_fn_args:
|
||||
if arg not in ('config', 'params'):
|
||||
raise ValueError('input_fn should not include argument {}.'.format(arg))
|
||||
kwargs = {}
|
||||
if 'params' in input_fn_args:
|
||||
kwargs['params'] = self.params
|
||||
if 'config' in input_fn_args:
|
||||
kwargs['config'] = self.config
|
||||
with ops.device('/cpu:0'):
|
||||
return input_fn(**kwargs)
|
||||
|
||||
def _call_model_fn(self, features, labels, mode):
|
||||
"""Calls model function.
|
||||
|
||||
@ -589,7 +563,7 @@ class Estimator(object):
|
||||
Raises:
|
||||
ValueError: if model_fn returns invalid objects.
|
||||
"""
|
||||
model_fn_args = _fn_args(self._model_fn)
|
||||
model_fn_args = _model_fn_args(self._model_fn)
|
||||
kwargs = {}
|
||||
if 'mode' in model_fn_args:
|
||||
kwargs['mode'] = mode
|
||||
@ -610,7 +584,8 @@ class Estimator(object):
|
||||
with ops.Graph().as_default() as g, g.device(self._device_fn):
|
||||
random_seed.set_random_seed(self._config.tf_random_seed)
|
||||
global_step_tensor = self._create_and_assert_global_step(g)
|
||||
features, labels = self._call_input_fn(input_fn)
|
||||
with ops.device('/cpu:0'):
|
||||
features, labels = input_fn()
|
||||
estimator_spec = self._call_model_fn(features, labels,
|
||||
model_fn_lib.ModeKeys.TRAIN)
|
||||
ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
|
||||
@ -691,7 +666,7 @@ class Estimator(object):
|
||||
with ops.Graph().as_default() as g:
|
||||
random_seed.set_random_seed(self._config.tf_random_seed)
|
||||
global_step_tensor = self._create_and_assert_global_step(g)
|
||||
features, labels = self._call_input_fn(input_fn)
|
||||
features, labels = input_fn()
|
||||
estimator_spec = self._call_model_fn(
|
||||
features, labels, model_fn_lib.ModeKeys.EVAL)
|
||||
|
||||
@ -774,7 +749,7 @@ def _get_replica_device_setter(config):
|
||||
return None
|
||||
|
||||
|
||||
def _fn_args(fn):
|
||||
def _model_fn_args(fn):
|
||||
"""Get argument names for function-like object.
|
||||
|
||||
Args:
|
||||
@ -799,7 +774,7 @@ def _fn_args(fn):
|
||||
|
||||
def _verify_model_fn_args(model_fn, params):
|
||||
"""Verifies model fn arguments."""
|
||||
args = set(_fn_args(model_fn))
|
||||
args = set(_model_fn_args(model_fn))
|
||||
if 'features' not in args:
|
||||
raise ValueError('model_fn (%s) must include features argument.' % model_fn)
|
||||
if 'labels' not in args:
|
||||
|
@ -120,9 +120,6 @@ class EstimatorInheritanceConstraintTest(test.TestCase):
|
||||
def __init__(self):
|
||||
super(_Estimator, self).__init__(model_fn=dummy_model_fn)
|
||||
|
||||
def _call_input_fn(self, input_fn):
|
||||
return input_fn()
|
||||
|
||||
def _create_global_step(self, graph):
|
||||
pass
|
||||
|
||||
@ -328,48 +325,6 @@ def _make_input_fn(features, labels):
|
||||
|
||||
class EstimatorTrainTest(test.TestCase):
|
||||
|
||||
def test_bad_input_fn_args(self):
|
||||
expected_params = {'batch_size': 10}
|
||||
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
|
||||
|
||||
def _model_fn(features, labels, mode, params, config):
|
||||
del params, config
|
||||
return model_fn_global_step_incrementer(features, labels, mode)
|
||||
|
||||
def _input_fn(params, config, not_allowed):
|
||||
del not_allowed
|
||||
self.assertEqual(expected_params, params)
|
||||
self.assertEqual(4321, config.tf_random_seed)
|
||||
return dummy_input_fn()
|
||||
|
||||
est = estimator.Estimator(model_fn=_model_fn,
|
||||
params=expected_params,
|
||||
config=expected_config)
|
||||
with self.assertRaisesRegexp(ValueError, 'should not include argument'):
|
||||
est.train(_input_fn, steps=1)
|
||||
|
||||
def test_input_fn_args(self):
|
||||
expected_params = {'batch_size': 10}
|
||||
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
|
||||
input_fn_call_count = [0]
|
||||
|
||||
def _model_fn(features, labels, mode, params, config):
|
||||
del params, config
|
||||
return model_fn_global_step_incrementer(features, labels, mode)
|
||||
|
||||
def _input_fn(params, config):
|
||||
input_fn_call_count[0] += 1
|
||||
self.assertEqual(expected_params, params)
|
||||
self.assertEqual(4321, config.tf_random_seed)
|
||||
return dummy_input_fn()
|
||||
|
||||
est = estimator.Estimator(model_fn=_model_fn,
|
||||
params=expected_params,
|
||||
config=expected_config)
|
||||
self.assertEqual(0, input_fn_call_count[0])
|
||||
est.train(_input_fn, steps=1)
|
||||
self.assertEqual(1, input_fn_call_count[0])
|
||||
|
||||
def test_minimal_model_fn_args(self):
|
||||
expected_features = {'x': 42., 'y': 43.}
|
||||
expected_labels = 44.
|
||||
@ -710,29 +665,6 @@ class _StepCounterHook(session_run_hook.SessionRunHook):
|
||||
|
||||
class EstimatorEvaluateTest(test.TestCase):
|
||||
|
||||
def test_input_fn_args(self):
|
||||
expected_params = {'batch_size': 10}
|
||||
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
|
||||
input_fn_call_count = [0]
|
||||
|
||||
def _model_fn(features, labels, mode, params, config):
|
||||
del params, config
|
||||
return model_fn_global_step_incrementer(features, labels, mode)
|
||||
|
||||
def _input_fn(params, config):
|
||||
input_fn_call_count[0] += 1
|
||||
self.assertEqual(expected_params, params)
|
||||
self.assertEqual(4321, config.tf_random_seed)
|
||||
return dummy_input_fn()
|
||||
|
||||
est = estimator.Estimator(model_fn=_model_fn,
|
||||
params=expected_params,
|
||||
config=expected_config)
|
||||
est.train(dummy_input_fn, steps=1)
|
||||
self.assertEqual(0, input_fn_call_count[0])
|
||||
est.evaluate(_input_fn, steps=1)
|
||||
self.assertEqual(1, input_fn_call_count[0])
|
||||
|
||||
def test_model_fn_must_return_estimator_spec(self):
|
||||
def _model_fn(features, labels, mode):
|
||||
_, _ = features, labels
|
||||
@ -934,33 +866,6 @@ class EstimatorEvaluateTest(test.TestCase):
|
||||
|
||||
class EstimatorPredictTest(test.TestCase):
|
||||
|
||||
def test_input_fn_args(self):
|
||||
expected_params = {'batch_size': 10}
|
||||
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
|
||||
input_fn_call_count = [0]
|
||||
|
||||
def _model_fn(features, labels, mode, params, config):
|
||||
del features, labels, params, config
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode,
|
||||
loss=constant_op.constant(0.),
|
||||
train_op=state_ops.assign_add(training.get_global_step(), 1),
|
||||
predictions=constant_op.constant([[10.]]))
|
||||
|
||||
def _input_fn(params, config):
|
||||
input_fn_call_count[0] += 1
|
||||
self.assertEqual(expected_params, params)
|
||||
self.assertEqual(4321, config.tf_random_seed)
|
||||
return dummy_input_fn()
|
||||
|
||||
est = estimator.Estimator(model_fn=_model_fn,
|
||||
params=expected_params,
|
||||
config=expected_config)
|
||||
est.train(dummy_input_fn, steps=1)
|
||||
self.assertEqual(0, input_fn_call_count[0])
|
||||
next(est.predict(_input_fn))
|
||||
self.assertEqual(1, input_fn_call_count[0])
|
||||
|
||||
def test_no_trained_model_in_model_dir(self):
|
||||
est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
|
Loading…
Reference in New Issue
Block a user