added mode to input_fn argument, and modified existing estimator unit tests (#14671)

This commit is contained in:
k-w-w 2017-12-28 16:44:59 -08:00 committed by drpngx
parent 3a3b7530eb
commit c17d085f34
2 changed files with 14 additions and 5 deletions

View File

@ -683,9 +683,10 @@ class Estimator(object):
Raises:
ValueError: if input_fn takes invalid arguments.
"""
del mode # unused
input_fn_args = util.fn_args(input_fn)
kwargs = {}
if 'mode' in input_fn_args:
kwargs['mode'] = mode
if 'params' in input_fn_args:
kwargs['params'] = self.params
if 'config' in input_fn_args:

View File

@ -418,6 +418,7 @@ class EstimatorTrainTest(test.TestCase):
self.assertEqual(1, model_fn_call_count[0])
def test_callable_input_fn(self):
expected_mode = model_fn_lib.ModeKeys.TRAIN
expected_params = {'batch_size': 10}
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
input_fn_call_count = [0]
@ -430,8 +431,9 @@ class EstimatorTrainTest(test.TestCase):
class InputFn(object):
def __call__(self, params, config):
def __call__(self, mode, params, config):
input_fn_call_count[0] += 1
test_self.assertEqual(expected_mode, mode)
test_self.assertEqual(expected_params, params)
test_self.assertEqual(4321, config.tf_random_seed)
return dummy_input_fn()
@ -444,6 +446,7 @@ class EstimatorTrainTest(test.TestCase):
self.assertEqual(1, input_fn_call_count[0])
def test_input_fn_args(self):
expected_mode = model_fn_lib.ModeKeys.TRAIN
expected_params = {'batch_size': 10}
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
input_fn_call_count = [0]
@ -452,8 +455,9 @@ class EstimatorTrainTest(test.TestCase):
del params, config
return model_fn_global_step_incrementer(features, labels, mode)
def _input_fn(params, config):
def _input_fn(mode, params, config):
input_fn_call_count[0] += 1
self.assertEqual(expected_mode, mode)
self.assertEqual(expected_params, params)
self.assertEqual(4321, config.tf_random_seed)
return dummy_input_fn()
@ -990,6 +994,7 @@ class EstimatorDatasetIntegrationTest(test.TestCase):
class EstimatorEvaluateTest(test.TestCase):
def test_input_fn_args(self):
expected_mode = model_fn_lib.ModeKeys.EVAL
expected_params = {'batch_size': 10}
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
input_fn_call_count = [0]
@ -998,8 +1003,9 @@ class EstimatorEvaluateTest(test.TestCase):
del params, config
return model_fn_global_step_incrementer(features, labels, mode)
def _input_fn(params, config):
def _input_fn(mode, params, config):
input_fn_call_count[0] += 1
self.assertEqual(expected_mode, mode)
self.assertEqual(expected_params, params)
self.assertEqual(4321, config.tf_random_seed)
return dummy_input_fn()
@ -1263,6 +1269,7 @@ class EstimatorEvaluateTest(test.TestCase):
class EstimatorPredictTest(test.TestCase):
def test_input_fn_args(self):
expected_mode = model_fn_lib.ModeKeys.PREDICT
expected_params = {'batch_size': 10}
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
input_fn_call_count = [0]
@ -1275,8 +1282,9 @@ class EstimatorPredictTest(test.TestCase):
train_op=state_ops.assign_add(training.get_global_step(), 1),
predictions=constant_op.constant([[10.]]))
def _input_fn(params, config):
def _input_fn(mode, params, config):
input_fn_call_count[0] += 1
self.assertEqual(expected_mode, mode)
self.assertEqual(expected_params, params)
self.assertEqual(4321, config.tf_random_seed)
return dummy_input_fn()