Use six.string_types instead of str in estimator/export.

Change: 155087824
This commit is contained in:
A. Unique TensorFlower 2017-05-04 08:29:59 -08:00 committed by TensorFlower Gardener
parent 42c7659edd
commit 1e4899035e
4 changed files with 100 additions and 6 deletions

View File

@ -23,6 +23,8 @@ import collections
import os import os
import time import time
import six
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
@ -56,7 +58,7 @@ class ServingInputReceiver(collections.namedtuple('ServingInputReceiver',
if not isinstance(features, dict): if not isinstance(features, dict):
features = {_SINGLE_FEATURE_DEFAULT_NAME: features} features = {_SINGLE_FEATURE_DEFAULT_NAME: features}
for name, tensor in features.items(): for name, tensor in features.items():
if not isinstance(name, str): if not isinstance(name, six.string_types):
raise ValueError('feature keys must be strings: {}.'.format(name)) raise ValueError('feature keys must be strings: {}.'.format(name))
if not (isinstance(tensor, ops.Tensor) if not (isinstance(tensor, ops.Tensor)
or isinstance(tensor, sparse_tensor.SparseTensor)): or isinstance(tensor, sparse_tensor.SparseTensor)):
@ -68,7 +70,7 @@ class ServingInputReceiver(collections.namedtuple('ServingInputReceiver',
if not isinstance(receiver_tensors, dict): if not isinstance(receiver_tensors, dict):
receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors}
for name, tensor in receiver_tensors.items(): for name, tensor in receiver_tensors.items():
if not isinstance(name, str): if not isinstance(name, six.string_types):
raise ValueError( raise ValueError(
'receiver_tensors keys must be strings: {}.'.format(name)) 'receiver_tensors keys must be strings: {}.'.format(name))
if not isinstance(tensor, ops.Tensor): if not isinstance(tensor, ops.Tensor):

View File

@ -20,6 +20,8 @@ from __future__ import print_function
import abc import abc
import six
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
@ -171,7 +173,7 @@ class PredictOutput(ExportOutput):
'Prediction outputs must be given as a dict of string to Tensor; ' 'Prediction outputs must be given as a dict of string to Tensor; '
'got {}'.format(outputs)) 'got {}'.format(outputs))
for key, value in outputs.items(): for key, value in outputs.items():
if not isinstance(key, str): if not isinstance(key, six.string_types):
raise ValueError( raise ValueError(
'Prediction output key must be a string; got {}.'.format(key)) 'Prediction output key must be a string; got {}.'.format(key))
if not isinstance(value, ops.Tensor): if not isinstance(value, ops.Tensor):

View File

@ -22,7 +22,9 @@ from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2 from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_constants
@ -197,6 +199,33 @@ class ExportOutputTest(test.TestCase):
signature_constants.CLASSIFY_METHOD_NAME) signature_constants.CLASSIFY_METHOD_NAME)
self.assertEqual(actual_signature_def, expected_signature_def) self.assertEqual(actual_signature_def, expected_signature_def)
def test_predict_output_constructor(self):
"""Tests that no errors are raised when input is expected."""
outputs = {
"output0": constant_op.constant([0]),
u"output1": constant_op.constant([1]),
}
export_output_lib.PredictOutput(outputs)
def test_predict_output_outputs_invalid(self):
with self.assertRaisesRegexp(
ValueError,
"Prediction outputs must be given as a dict of string to Tensor"):
export_output_lib.PredictOutput(constant_op.constant([0]))
with self.assertRaisesRegexp(
ValueError,
"Prediction output key must be a string"):
export_output_lib.PredictOutput({1: constant_op.constant([0])})
with self.assertRaisesRegexp(
ValueError,
"Prediction output value must be a Tensor"):
export_output_lib.PredictOutput({
"prediction1": sparse_tensor.SparseTensor(
indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
})
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -28,13 +28,11 @@ from tensorflow.core.example import example_pb2
from tensorflow.python.estimator.export import export from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.export import export_output from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_constants
@ -43,6 +41,69 @@ from tensorflow.python.saved_model import signature_def_utils
class ExportTest(test_util.TensorFlowTestCase): class ExportTest(test_util.TensorFlowTestCase):
def test_serving_input_receiver_constructor(self):
"""Tests that no errors are raised when input is expected."""
features = {
"feature0": constant_op.constant([0]),
u"feature1": constant_op.constant([1]),
"feature2": sparse_tensor.SparseTensor(
indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
}
receiver_tensors = {
"example0": array_ops.placeholder(dtypes.string, name="example0"),
u"example1": array_ops.placeholder(dtypes.string, name="example1"),
}
export.ServingInputReceiver(features, receiver_tensors)
def test_serving_input_receiver_features_invalid(self):
receiver_tensors = {
"example0": array_ops.placeholder(dtypes.string, name="example0"),
u"example1": array_ops.placeholder(dtypes.string, name="example1"),
}
with self.assertRaisesRegexp(ValueError, "features must be defined"):
export.ServingInputReceiver(
features=None,
receiver_tensors=receiver_tensors)
with self.assertRaisesRegexp(ValueError, "feature keys must be strings"):
export.ServingInputReceiver(
features={1: constant_op.constant([1])},
receiver_tensors=receiver_tensors)
with self.assertRaisesRegexp(
ValueError, "feature feature1 must be a Tensor or SparseTensor"):
export.ServingInputReceiver(
features={"feature1": [1]},
receiver_tensors=receiver_tensors)
def test_serving_input_receiver_receiver_tensors_invalid(self):
features = {
"feature0": constant_op.constant([0]),
u"feature1": constant_op.constant([1]),
"feature2": sparse_tensor.SparseTensor(
indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
}
with self.assertRaisesRegexp(
ValueError, "receiver_tensors must be defined"):
export.ServingInputReceiver(
features=features,
receiver_tensors=None)
with self.assertRaisesRegexp(
ValueError, "receiver_tensors keys must be strings"):
export.ServingInputReceiver(
features=features,
receiver_tensors={
1: array_ops.placeholder(dtypes.string, name="example0")})
with self.assertRaisesRegexp(
ValueError, "receiver_tensor example1 must be a Tensor"):
export.ServingInputReceiver(
features=features,
receiver_tensors={"example1": [1]})
def test_single_feature_single_receiver(self): def test_single_feature_single_receiver(self):
feature = constant_op.constant(5) feature = constant_op.constant(5)
receiver_tensor = array_ops.placeholder(dtypes.string) receiver_tensor = array_ops.placeholder(dtypes.string)