Checking that TPUEstimator model function features have static shapes.
PiperOrigin-RevId: 200306833
This commit is contained in:
parent
abc55107eb
commit
db2f9fd007
@ -1343,8 +1343,55 @@ class _ModelFnWrapper(object):
|
|||||||
key, tensor))
|
key, tensor))
|
||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
|
def _validate_model_features_and_labels(self,
|
||||||
|
features,
|
||||||
|
labels,
|
||||||
|
is_export_mode):
|
||||||
|
"""Validates that the features and labels for the model function are valid.
|
||||||
|
|
||||||
|
A valid features/labels object is the one with:
|
||||||
|
- Type: Tensor or a dictionary of Tensors
|
||||||
|
- Static shape if is_export_mode is False.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features: the features that would be input to the model function.
|
||||||
|
labels: the labels that would be input to the model function.
|
||||||
|
is_export_mode: boolean value specifying if in export mode.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If features/labels are not of the correct type.
|
||||||
|
ValueError: If features/labels have dynamic shape.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def validate(obj, obj_name):
|
||||||
|
"""Helper validate function."""
|
||||||
|
if not isinstance(obj, ops.Tensor) and not isinstance(obj, dict):
|
||||||
|
raise TypeError(
|
||||||
|
'The {} to the model returned by input_fn must be either a Tensor '
|
||||||
|
'or a dictionary of Tensors. {}: {}'.format(obj_name, obj_name,
|
||||||
|
obj))
|
||||||
|
if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode):
|
||||||
|
return
|
||||||
|
if isinstance(obj, ops.Tensor):
|
||||||
|
if not obj.get_shape().is_fully_defined():
|
||||||
|
raise ValueError(
|
||||||
|
'The {} to the model returned by input_fn must have static shape.'
|
||||||
|
' Tensor: {}'.format(obj_name, obj))
|
||||||
|
else:
|
||||||
|
for (key, tensor) in obj.items():
|
||||||
|
if not tensor.get_shape().is_fully_defined():
|
||||||
|
raise ValueError(
|
||||||
|
'The {} to the model returned by input_fn must have static '
|
||||||
|
'shape. Key: \'{}\', Tensor: {}'.format(
|
||||||
|
obj_name, key, tensor))
|
||||||
|
|
||||||
|
validate(features, 'features')
|
||||||
|
if labels is not None:
|
||||||
|
validate(labels, 'labels')
|
||||||
|
|
||||||
def _call_model_fn(self, features, labels, is_export_mode=False):
|
def _call_model_fn(self, features, labels, is_export_mode=False):
|
||||||
"""Calls the model_fn with required parameters."""
|
"""Calls the model_fn with required parameters."""
|
||||||
|
self._validate_model_features_and_labels(features, labels, is_export_mode)
|
||||||
model_fn_args = function_utils.fn_args(self._model_fn)
|
model_fn_args = function_utils.fn_args(self._model_fn)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user