added mode to input_fn argument, and modified existing estimator unit tests (#14671)
This commit is contained in:
parent
3a3b7530eb
commit
c17d085f34
@ -683,9 +683,10 @@ class Estimator(object):
|
||||
Raises:
|
||||
ValueError: if input_fn takes invalid arguments.
|
||||
"""
|
||||
del mode # unused
|
||||
input_fn_args = util.fn_args(input_fn)
|
||||
kwargs = {}
|
||||
if 'mode' in input_fn_args:
|
||||
kwargs['mode'] = mode
|
||||
if 'params' in input_fn_args:
|
||||
kwargs['params'] = self.params
|
||||
if 'config' in input_fn_args:
|
||||
|
@ -418,6 +418,7 @@ class EstimatorTrainTest(test.TestCase):
|
||||
self.assertEqual(1, model_fn_call_count[0])
|
||||
|
||||
def test_callable_input_fn(self):
|
||||
expected_mode = model_fn_lib.ModeKeys.TRAIN
|
||||
expected_params = {'batch_size': 10}
|
||||
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
|
||||
input_fn_call_count = [0]
|
||||
@ -430,8 +431,9 @@ class EstimatorTrainTest(test.TestCase):
|
||||
|
||||
class InputFn(object):
|
||||
|
||||
def __call__(self, params, config):
|
||||
def __call__(self, mode, params, config):
|
||||
input_fn_call_count[0] += 1
|
||||
test_self.assertEqual(expected_mode, mode)
|
||||
test_self.assertEqual(expected_params, params)
|
||||
test_self.assertEqual(4321, config.tf_random_seed)
|
||||
return dummy_input_fn()
|
||||
@ -444,6 +446,7 @@ class EstimatorTrainTest(test.TestCase):
|
||||
self.assertEqual(1, input_fn_call_count[0])
|
||||
|
||||
def test_input_fn_args(self):
|
||||
expected_mode = model_fn_lib.ModeKeys.TRAIN
|
||||
expected_params = {'batch_size': 10}
|
||||
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
|
||||
input_fn_call_count = [0]
|
||||
@ -452,8 +455,9 @@ class EstimatorTrainTest(test.TestCase):
|
||||
del params, config
|
||||
return model_fn_global_step_incrementer(features, labels, mode)
|
||||
|
||||
def _input_fn(params, config):
|
||||
def _input_fn(mode, params, config):
|
||||
input_fn_call_count[0] += 1
|
||||
self.assertEqual(expected_mode, mode)
|
||||
self.assertEqual(expected_params, params)
|
||||
self.assertEqual(4321, config.tf_random_seed)
|
||||
return dummy_input_fn()
|
||||
@ -990,6 +994,7 @@ class EstimatorDatasetIntegrationTest(test.TestCase):
|
||||
class EstimatorEvaluateTest(test.TestCase):
|
||||
|
||||
def test_input_fn_args(self):
|
||||
expected_mode = model_fn_lib.ModeKeys.EVAL
|
||||
expected_params = {'batch_size': 10}
|
||||
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
|
||||
input_fn_call_count = [0]
|
||||
@ -998,8 +1003,9 @@ class EstimatorEvaluateTest(test.TestCase):
|
||||
del params, config
|
||||
return model_fn_global_step_incrementer(features, labels, mode)
|
||||
|
||||
def _input_fn(params, config):
|
||||
def _input_fn(mode, params, config):
|
||||
input_fn_call_count[0] += 1
|
||||
self.assertEqual(expected_mode, mode)
|
||||
self.assertEqual(expected_params, params)
|
||||
self.assertEqual(4321, config.tf_random_seed)
|
||||
return dummy_input_fn()
|
||||
@ -1263,6 +1269,7 @@ class EstimatorEvaluateTest(test.TestCase):
|
||||
class EstimatorPredictTest(test.TestCase):
|
||||
|
||||
def test_input_fn_args(self):
|
||||
expected_mode = model_fn_lib.ModeKeys.PREDICT
|
||||
expected_params = {'batch_size': 10}
|
||||
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
|
||||
input_fn_call_count = [0]
|
||||
@ -1275,8 +1282,9 @@ class EstimatorPredictTest(test.TestCase):
|
||||
train_op=state_ops.assign_add(training.get_global_step(), 1),
|
||||
predictions=constant_op.constant([[10.]]))
|
||||
|
||||
def _input_fn(params, config):
|
||||
def _input_fn(mode, params, config):
|
||||
input_fn_call_count[0] += 1
|
||||
self.assertEqual(expected_mode, mode)
|
||||
self.assertEqual(expected_params, params)
|
||||
self.assertEqual(4321, config.tf_random_seed)
|
||||
return dummy_input_fn()
|
||||
|
Loading…
Reference in New Issue
Block a user