Fixed the missing labels test in TPUEstimator.
PiperOrigin-RevId: 161131282
This commit is contained in:
parent
7d5c74a9c8
commit
53604916ed
@ -360,6 +360,12 @@ def _call_model_fn(model_fn, features, labels, mode, config, params,
|
||||
"""Calls the model_fn with required parameters."""
|
||||
model_fn_args = util.fn_args(model_fn)
|
||||
kwargs = {}
|
||||
if 'labels' in model_fn_args:
|
||||
kwargs['labels'] = labels
|
||||
else:
|
||||
if labels is not None:
|
||||
raise ValueError(
|
||||
'model_fn does not take labels, but input_fn returns labels.')
|
||||
if 'mode' in model_fn_args:
|
||||
kwargs['mode'] = mode
|
||||
if 'config' in model_fn_args:
|
||||
@ -371,7 +377,7 @@ def _call_model_fn(model_fn, features, labels, mode, config, params,
|
||||
'model_fn ({}) does not include params argument, '
|
||||
'required by TPUEstimator to pass batch size as '
|
||||
'params[\'batch_size\']'.format(model_fn))
|
||||
return model_fn(features=features, labels=labels, **kwargs)
|
||||
return model_fn(features=features, **kwargs)
|
||||
|
||||
|
||||
def _call_model_fn_with_tpu(model_fn, features, labels, mode, config, params):
|
||||
|
Loading…
Reference in New Issue
Block a user