From c17d085f3414a5da7bab90d3caae12ade2b02ed1 Mon Sep 17 00:00:00 2001
From: k-w-w <31663267+k-w-w@users.noreply.github.com>
Date: Thu, 28 Dec 2017 16:44:59 -0800
Subject: [PATCH] added mode to input_fn argument, and modified existing
 estimator unit tests (#14671)

---
 tensorflow/python/estimator/estimator.py      |  3 ++-
 tensorflow/python/estimator/estimator_test.py | 16 ++++++++++++----
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 1e3d6d5755f..ea4715a97e3 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -683,9 +683,10 @@ class Estimator(object):
     Raises:
       ValueError: if input_fn takes invalid arguments.
     """
-    del mode  # unused
     input_fn_args = util.fn_args(input_fn)
     kwargs = {}
+    if 'mode' in input_fn_args:
+      kwargs['mode'] = mode
     if 'params' in input_fn_args:
       kwargs['params'] = self.params
     if 'config' in input_fn_args:
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index db64fbc9ccc..f9c117a790a 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -418,6 +418,7 @@ class EstimatorTrainTest(test.TestCase):
     self.assertEqual(1, model_fn_call_count[0])
 
   def test_callable_input_fn(self):
+    expected_mode = model_fn_lib.ModeKeys.TRAIN
     expected_params = {'batch_size': 10}
     expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
     input_fn_call_count = [0]
@@ -430,8 +431,9 @@ class EstimatorTrainTest(test.TestCase):
 
     class InputFn(object):
 
-      def __call__(self, params, config):
+      def __call__(self, mode, params, config):
         input_fn_call_count[0] += 1
+        test_self.assertEqual(expected_mode, mode)
         test_self.assertEqual(expected_params, params)
         test_self.assertEqual(4321, config.tf_random_seed)
         return dummy_input_fn()
@@ -444,6 +446,7 @@ class EstimatorTrainTest(test.TestCase):
     self.assertEqual(1, input_fn_call_count[0])
 
   def test_input_fn_args(self):
+    expected_mode = model_fn_lib.ModeKeys.TRAIN
     expected_params = {'batch_size': 10}
     expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
     input_fn_call_count = [0]
@@ -452,8 +455,9 @@ class EstimatorTrainTest(test.TestCase):
       del params, config
       return model_fn_global_step_incrementer(features, labels, mode)
 
-    def _input_fn(params, config):
+    def _input_fn(mode, params, config):
       input_fn_call_count[0] += 1
+      self.assertEqual(expected_mode, mode)
       self.assertEqual(expected_params, params)
       self.assertEqual(4321, config.tf_random_seed)
       return dummy_input_fn()
@@ -990,6 +994,7 @@ class EstimatorDatasetIntegrationTest(test.TestCase):
 class EstimatorEvaluateTest(test.TestCase):
 
   def test_input_fn_args(self):
+    expected_mode = model_fn_lib.ModeKeys.EVAL
     expected_params = {'batch_size': 10}
     expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
     input_fn_call_count = [0]
@@ -998,8 +1003,9 @@ class EstimatorEvaluateTest(test.TestCase):
       del params, config
       return model_fn_global_step_incrementer(features, labels, mode)
 
-    def _input_fn(params, config):
+    def _input_fn(mode, params, config):
       input_fn_call_count[0] += 1
+      self.assertEqual(expected_mode, mode)
       self.assertEqual(expected_params, params)
       self.assertEqual(4321, config.tf_random_seed)
       return dummy_input_fn()
@@ -1263,6 +1269,7 @@ class EstimatorEvaluateTest(test.TestCase):
 class EstimatorPredictTest(test.TestCase):
 
   def test_input_fn_args(self):
+    expected_mode = model_fn_lib.ModeKeys.PREDICT
     expected_params = {'batch_size': 10}
     expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
     input_fn_call_count = [0]
@@ -1275,8 +1282,9 @@ class EstimatorPredictTest(test.TestCase):
           train_op=state_ops.assign_add(training.get_global_step(), 1),
           predictions=constant_op.constant([[10.]]))
 
-    def _input_fn(params, config):
+    def _input_fn(mode, params, config):
       input_fn_call_count[0] += 1
+      self.assertEqual(expected_mode, mode)
       self.assertEqual(expected_params, params)
       self.assertEqual(4321, config.tf_random_seed)
       return dummy_input_fn()