diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 1e3d6d5755f..ea4715a97e3 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -683,9 +683,10 @@ class Estimator(object): Raises: ValueError: if input_fn takes invalid arguments. """ - del mode # unused input_fn_args = util.fn_args(input_fn) kwargs = {} + if 'mode' in input_fn_args: + kwargs['mode'] = mode if 'params' in input_fn_args: kwargs['params'] = self.params if 'config' in input_fn_args: diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index db64fbc9ccc..f9c117a790a 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -418,6 +418,7 @@ class EstimatorTrainTest(test.TestCase): self.assertEqual(1, model_fn_call_count[0]) def test_callable_input_fn(self): + expected_mode = model_fn_lib.ModeKeys.TRAIN expected_params = {'batch_size': 10} expected_config = run_config.RunConfig().replace(tf_random_seed=4321) input_fn_call_count = [0] @@ -430,8 +431,9 @@ class EstimatorTrainTest(test.TestCase): class InputFn(object): - def __call__(self, params, config): + def __call__(self, mode, params, config): input_fn_call_count[0] += 1 + test_self.assertEqual(expected_mode, mode) test_self.assertEqual(expected_params, params) test_self.assertEqual(4321, config.tf_random_seed) return dummy_input_fn() @@ -444,6 +446,7 @@ class EstimatorTrainTest(test.TestCase): self.assertEqual(1, input_fn_call_count[0]) def test_input_fn_args(self): + expected_mode = model_fn_lib.ModeKeys.TRAIN expected_params = {'batch_size': 10} expected_config = run_config.RunConfig().replace(tf_random_seed=4321) input_fn_call_count = [0] @@ -452,8 +455,9 @@ class EstimatorTrainTest(test.TestCase): del params, config return model_fn_global_step_incrementer(features, labels, mode) - def _input_fn(params, config): + def _input_fn(mode, params, config): input_fn_call_count[0] += 1 + self.assertEqual(expected_mode, mode) self.assertEqual(expected_params, params) self.assertEqual(4321, config.tf_random_seed) return dummy_input_fn() @@ -990,6 +994,7 @@ class EstimatorDatasetIntegrationTest(test.TestCase): class EstimatorEvaluateTest(test.TestCase): def test_input_fn_args(self): + expected_mode = model_fn_lib.ModeKeys.EVAL expected_params = {'batch_size': 10} expected_config = run_config.RunConfig().replace(tf_random_seed=4321) input_fn_call_count = [0] @@ -998,8 +1003,9 @@ class EstimatorEvaluateTest(test.TestCase): del params, config return model_fn_global_step_incrementer(features, labels, mode) - def _input_fn(params, config): + def _input_fn(mode, params, config): input_fn_call_count[0] += 1 + self.assertEqual(expected_mode, mode) self.assertEqual(expected_params, params) self.assertEqual(4321, config.tf_random_seed) return dummy_input_fn() @@ -1263,6 +1269,7 @@ class EstimatorEvaluateTest(test.TestCase): class EstimatorPredictTest(test.TestCase): def test_input_fn_args(self): + expected_mode = model_fn_lib.ModeKeys.PREDICT expected_params = {'batch_size': 10} expected_config = run_config.RunConfig().replace(tf_random_seed=4321) input_fn_call_count = [0] @@ -1275,8 +1282,9 @@ class EstimatorPredictTest(test.TestCase): train_op=state_ops.assign_add(training.get_global_step(), 1), predictions=constant_op.constant([[10.]])) - def _input_fn(params, config): + def _input_fn(mode, params, config): input_fn_call_count[0] += 1 + self.assertEqual(expected_mode, mode) self.assertEqual(expected_params, params) self.assertEqual(4321, config.tf_random_seed) return dummy_input_fn()