Automated g4 rollback of changelist 159583264

PiperOrigin-RevId: 159630408
This commit is contained in:
Jonathan Hseu 2017-06-20 16:18:12 -07:00 committed by TensorFlower Gardener
parent 35af7113de
commit 5856f9ea6d
2 changed files with 8 additions and 128 deletions

View File

@ -357,7 +357,7 @@ class Estimator(object):
}
def _assert_members_are_not_overridden(self):
allowed_overrides = set(['_call_input_fn', '_create_global_step'])
allowed_overrides = set(['_create_global_step'])
estimator_members = set([m for m in Estimator.__dict__.keys()
if not m.startswith('__')])
subclass_members = set(self.__class__.__dict__.keys())
@ -485,7 +485,7 @@ class Estimator(object):
return export_dir
def _get_features_from_input_fn(self, input_fn):
result = self._call_input_fn(input_fn)
result = input_fn()
if not ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
logging.warning('Input graph does not contain a QueueRunner. '
'That means predict yields forever. '
@ -549,32 +549,6 @@ class Estimator(object):
assert step.dtype.is_integer
return step
def _call_input_fn(self, input_fn):
"""Calls the input function.
Args:
input_fn: The input function.
Returns:
Either features or (features, labels) where features and labels are:
features - `Tensor` or dictionary of string feature name to `Tensor`.
labels - `Tensor` or dictionary of `Tensor` with labels.
Raises:
ValueError: if input_fn takes invalid arguments.
"""
input_fn_args = _fn_args(input_fn)
for arg in input_fn_args:
if arg not in ('config', 'params'):
raise ValueError('input_fn should not include argument {}.'.format(arg))
kwargs = {}
if 'params' in input_fn_args:
kwargs['params'] = self.params
if 'config' in input_fn_args:
kwargs['config'] = self.config
with ops.device('/cpu:0'):
return input_fn(**kwargs)
def _call_model_fn(self, features, labels, mode):
"""Calls model function.
@ -589,7 +563,7 @@ class Estimator(object):
Raises:
ValueError: if model_fn returns invalid objects.
"""
model_fn_args = _fn_args(self._model_fn)
model_fn_args = _model_fn_args(self._model_fn)
kwargs = {}
if 'mode' in model_fn_args:
kwargs['mode'] = mode
@ -610,7 +584,8 @@ class Estimator(object):
with ops.Graph().as_default() as g, g.device(self._device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
global_step_tensor = self._create_and_assert_global_step(g)
features, labels = self._call_input_fn(input_fn)
with ops.device('/cpu:0'):
features, labels = input_fn()
estimator_spec = self._call_model_fn(features, labels,
model_fn_lib.ModeKeys.TRAIN)
ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
@ -691,7 +666,7 @@ class Estimator(object):
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
global_step_tensor = self._create_and_assert_global_step(g)
features, labels = self._call_input_fn(input_fn)
features, labels = input_fn()
estimator_spec = self._call_model_fn(
features, labels, model_fn_lib.ModeKeys.EVAL)
@ -774,7 +749,7 @@ def _get_replica_device_setter(config):
return None
def _fn_args(fn):
def _model_fn_args(fn):
"""Get argument names for function-like object.
Args:
@ -799,7 +774,7 @@ def _fn_args(fn):
def _verify_model_fn_args(model_fn, params):
"""Verifies model fn arguments."""
args = set(_fn_args(model_fn))
args = set(_model_fn_args(model_fn))
if 'features' not in args:
raise ValueError('model_fn (%s) must include features argument.' % model_fn)
if 'labels' not in args:

View File

@ -120,9 +120,6 @@ class EstimatorInheritanceConstraintTest(test.TestCase):
def __init__(self):
super(_Estimator, self).__init__(model_fn=dummy_model_fn)
def _call_input_fn(self, input_fn):
return input_fn()
def _create_global_step(self, graph):
pass
@ -328,48 +325,6 @@ def _make_input_fn(features, labels):
class EstimatorTrainTest(test.TestCase):
def test_bad_input_fn_args(self):
expected_params = {'batch_size': 10}
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
def _model_fn(features, labels, mode, params, config):
del params, config
return model_fn_global_step_incrementer(features, labels, mode)
def _input_fn(params, config, not_allowed):
del not_allowed
self.assertEqual(expected_params, params)
self.assertEqual(4321, config.tf_random_seed)
return dummy_input_fn()
est = estimator.Estimator(model_fn=_model_fn,
params=expected_params,
config=expected_config)
with self.assertRaisesRegexp(ValueError, 'should not include argument'):
est.train(_input_fn, steps=1)
def test_input_fn_args(self):
expected_params = {'batch_size': 10}
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
input_fn_call_count = [0]
def _model_fn(features, labels, mode, params, config):
del params, config
return model_fn_global_step_incrementer(features, labels, mode)
def _input_fn(params, config):
input_fn_call_count[0] += 1
self.assertEqual(expected_params, params)
self.assertEqual(4321, config.tf_random_seed)
return dummy_input_fn()
est = estimator.Estimator(model_fn=_model_fn,
params=expected_params,
config=expected_config)
self.assertEqual(0, input_fn_call_count[0])
est.train(_input_fn, steps=1)
self.assertEqual(1, input_fn_call_count[0])
def test_minimal_model_fn_args(self):
expected_features = {'x': 42., 'y': 43.}
expected_labels = 44.
@ -710,29 +665,6 @@ class _StepCounterHook(session_run_hook.SessionRunHook):
class EstimatorEvaluateTest(test.TestCase):
def test_input_fn_args(self):
expected_params = {'batch_size': 10}
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
input_fn_call_count = [0]
def _model_fn(features, labels, mode, params, config):
del params, config
return model_fn_global_step_incrementer(features, labels, mode)
def _input_fn(params, config):
input_fn_call_count[0] += 1
self.assertEqual(expected_params, params)
self.assertEqual(4321, config.tf_random_seed)
return dummy_input_fn()
est = estimator.Estimator(model_fn=_model_fn,
params=expected_params,
config=expected_config)
est.train(dummy_input_fn, steps=1)
self.assertEqual(0, input_fn_call_count[0])
est.evaluate(_input_fn, steps=1)
self.assertEqual(1, input_fn_call_count[0])
def test_model_fn_must_return_estimator_spec(self):
def _model_fn(features, labels, mode):
_, _ = features, labels
@ -934,33 +866,6 @@ class EstimatorEvaluateTest(test.TestCase):
class EstimatorPredictTest(test.TestCase):
def test_input_fn_args(self):
expected_params = {'batch_size': 10}
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
input_fn_call_count = [0]
def _model_fn(features, labels, mode, params, config):
del features, labels, params, config
return model_fn_lib.EstimatorSpec(
mode,
loss=constant_op.constant(0.),
train_op=state_ops.assign_add(training.get_global_step(), 1),
predictions=constant_op.constant([[10.]]))
def _input_fn(params, config):
input_fn_call_count[0] += 1
self.assertEqual(expected_params, params)
self.assertEqual(4321, config.tf_random_seed)
return dummy_input_fn()
est = estimator.Estimator(model_fn=_model_fn,
params=expected_params,
config=expected_config)
est.train(dummy_input_fn, steps=1)
self.assertEqual(0, input_fn_call_count[0])
next(est.predict(_input_fn))
self.assertEqual(1, input_fn_call_count[0])
def test_no_trained_model_in_model_dir(self):
est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
with self.assertRaisesRegexp(ValueError,