Fixed the missing labels test in TPUEstimator.
PiperOrigin-RevId: 161131282
This commit is contained in:
parent
7d5c74a9c8
commit
53604916ed
@ -360,6 +360,12 @@ def _call_model_fn(model_fn, features, labels, mode, config, params,
|
|||||||
"""Calls the model_fn with required parameters."""
|
"""Calls the model_fn with required parameters."""
|
||||||
model_fn_args = util.fn_args(model_fn)
|
model_fn_args = util.fn_args(model_fn)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
if 'labels' in model_fn_args:
|
||||||
|
kwargs['labels'] = labels
|
||||||
|
else:
|
||||||
|
if labels is not None:
|
||||||
|
raise ValueError(
|
||||||
|
'model_fn does not take labels, but input_fn returns labels.')
|
||||||
if 'mode' in model_fn_args:
|
if 'mode' in model_fn_args:
|
||||||
kwargs['mode'] = mode
|
kwargs['mode'] = mode
|
||||||
if 'config' in model_fn_args:
|
if 'config' in model_fn_args:
|
||||||
@ -371,7 +377,7 @@ def _call_model_fn(model_fn, features, labels, mode, config, params,
|
|||||||
'model_fn ({}) does not include params argument, '
|
'model_fn ({}) does not include params argument, '
|
||||||
'required by TPUEstimator to pass batch size as '
|
'required by TPUEstimator to pass batch size as '
|
||||||
'params[\'batch_size\']'.format(model_fn))
|
'params[\'batch_size\']'.format(model_fn))
|
||||||
return model_fn(features=features, labels=labels, **kwargs)
|
return model_fn(features=features, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _call_model_fn_with_tpu(model_fn, features, labels, mode, config, params):
|
def _call_model_fn_with_tpu(model_fn, features, labels, mode, config, params):
|
||||||
|
Loading…
Reference in New Issue
Block a user