Fixed the missing labels test in TPUEstimator.

PiperOrigin-RevId: 161131282
This commit is contained in:
Jianwei Xie 2017-07-06 14:29:37 -07:00 committed by TensorFlower Gardener
parent 7d5c74a9c8
commit 53604916ed

View File

@ -360,6 +360,12 @@ def _call_model_fn(model_fn, features, labels, mode, config, params,
"""Calls the model_fn with required parameters."""
model_fn_args = util.fn_args(model_fn)
kwargs = {}
if 'labels' in model_fn_args:
kwargs['labels'] = labels
else:
if labels is not None:
raise ValueError(
'model_fn does not take labels, but input_fn returns labels.')
if 'mode' in model_fn_args:
kwargs['mode'] = mode
if 'config' in model_fn_args:
@ -371,7 +377,7 @@ def _call_model_fn(model_fn, features, labels, mode, config, params,
'model_fn ({}) does not include params argument, '
'required by TPUEstimator to pass batch size as '
'params[\'batch_size\']'.format(model_fn))
return model_fn(features=features, labels=labels, **kwargs)
return model_fn(features=features, **kwargs)
def _call_model_fn_with_tpu(model_fn, features, labels, mode, config, params):