Checking that TPUEstimator model function features have static shapes.
PiperOrigin-RevId: 200306833
This commit is contained in:
parent
abc55107eb
commit
db2f9fd007
@ -1343,8 +1343,55 @@ class _ModelFnWrapper(object):
|
||||
key, tensor))
|
||||
return predictions
|
||||
|
||||
def _validate_model_features_and_labels(self,
|
||||
features,
|
||||
labels,
|
||||
is_export_mode):
|
||||
"""Validates that the features and labels for the model function are valid.
|
||||
|
||||
A valid features/labels object is the one with:
|
||||
- Type: Tensor or a dictionary of Tensors
|
||||
- Static shape if is_export_mode is False.
|
||||
|
||||
Args:
|
||||
features: the features that would be input to the model function.
|
||||
labels: the labels that would be input to the model function.
|
||||
is_export_mode: boolean value specifying if in export mode.
|
||||
|
||||
Raises:
|
||||
TypeError: If features/labels are not of the correct type.
|
||||
ValueError: If features/labels have dynamic shape.
|
||||
"""
|
||||
|
||||
def validate(obj, obj_name):
|
||||
"""Helper validate function."""
|
||||
if not isinstance(obj, ops.Tensor) and not isinstance(obj, dict):
|
||||
raise TypeError(
|
||||
'The {} to the model returned by input_fn must be either a Tensor '
|
||||
'or a dictionary of Tensors. {}: {}'.format(obj_name, obj_name,
|
||||
obj))
|
||||
if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode):
|
||||
return
|
||||
if isinstance(obj, ops.Tensor):
|
||||
if not obj.get_shape().is_fully_defined():
|
||||
raise ValueError(
|
||||
'The {} to the model returned by input_fn must have static shape.'
|
||||
' Tensor: {}'.format(obj_name, obj))
|
||||
else:
|
||||
for (key, tensor) in obj.items():
|
||||
if not tensor.get_shape().is_fully_defined():
|
||||
raise ValueError(
|
||||
'The {} to the model returned by input_fn must have static '
|
||||
'shape. Key: \'{}\', Tensor: {}'.format(
|
||||
obj_name, key, tensor))
|
||||
|
||||
validate(features, 'features')
|
||||
if labels is not None:
|
||||
validate(labels, 'labels')
|
||||
|
||||
def _call_model_fn(self, features, labels, is_export_mode=False):
|
||||
"""Calls the model_fn with required parameters."""
|
||||
self._validate_model_features_and_labels(features, labels, is_export_mode)
|
||||
model_fn_args = function_utils.fn_args(self._model_fn)
|
||||
kwargs = {}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user