Checking that TPUEstimator model function features have static shapes.

PiperOrigin-RevId: 200306833
This commit is contained in:
A. Unique TensorFlower 2018-06-12 17:03:59 -07:00 committed by TensorFlower Gardener
parent abc55107eb
commit db2f9fd007

View File

@ -1343,8 +1343,55 @@ class _ModelFnWrapper(object):
key, tensor))
return predictions
def _validate_model_features_and_labels(self,
features,
labels,
is_export_mode):
"""Validates that the features and labels for the model function are valid.
A valid features/labels object is the one with:
- Type: Tensor or a dictionary of Tensors
- Static shape if is_export_mode is False.
Args:
features: the features that would be input to the model function.
labels: the labels that would be input to the model function.
is_export_mode: boolean value specifying if in export mode.
Raises:
TypeError: If features/labels are not of the correct type.
ValueError: If features/labels have dynamic shape.
"""
def validate(obj, obj_name):
"""Helper validate function."""
if not isinstance(obj, ops.Tensor) and not isinstance(obj, dict):
raise TypeError(
'The {} to the model returned by input_fn must be either a Tensor '
'or a dictionary of Tensors. {}: {}'.format(obj_name, obj_name,
obj))
if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode):
return
if isinstance(obj, ops.Tensor):
if not obj.get_shape().is_fully_defined():
raise ValueError(
'The {} to the model returned by input_fn must have static shape.'
' Tensor: {}'.format(obj_name, obj))
else:
for (key, tensor) in obj.items():
if not tensor.get_shape().is_fully_defined():
raise ValueError(
'The {} to the model returned by input_fn must have static '
'shape. Key: \'{}\', Tensor: {}'.format(
obj_name, key, tensor))
validate(features, 'features')
if labels is not None:
validate(labels, 'labels')
def _call_model_fn(self, features, labels, is_export_mode=False):
"""Calls the model_fn with required parameters."""
self._validate_model_features_and_labels(features, labels, is_export_mode)
model_fn_args = function_utils.fn_args(self._model_fn)
kwargs = {}