From 53604916ed2acbe931d7773a6f6edcfc191956e5 Mon Sep 17 00:00:00 2001 From: Jianwei Xie Date: Thu, 6 Jul 2017 14:29:37 -0700 Subject: [PATCH] Fixed the missing labels test in TPUEstimator. PiperOrigin-RevId: 161131282 --- tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index e001d866c35..193d14e1ce8 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -360,6 +360,12 @@ def _call_model_fn(model_fn, features, labels, mode, config, params, """Calls the model_fn with required parameters.""" model_fn_args = util.fn_args(model_fn) kwargs = {} + if 'labels' in model_fn_args: + kwargs['labels'] = labels + else: + if labels is not None: + raise ValueError( + 'model_fn does not take labels, but input_fn returns labels.') if 'mode' in model_fn_args: kwargs['mode'] = mode if 'config' in model_fn_args: @@ -371,7 +377,7 @@ def _call_model_fn(model_fn, features, labels, mode, config, params, 'model_fn ({}) does not include params argument, ' 'required by TPUEstimator to pass batch size as ' 'params[\'batch_size\']'.format(model_fn)) - return model_fn(features=features, labels=labels, **kwargs) + return model_fn(features=features, **kwargs) def _call_model_fn_with_tpu(model_fn, features, labels, mode, config, params):