Use six.string_types instead of str in estimator/export.

Change: 155087824
This commit is contained in:
A. Unique TensorFlower 2017-05-04 08:29:59 -08:00 committed by TensorFlower Gardener
parent 42c7659edd
commit 1e4899035e
4 changed files with 100 additions and 6 deletions

View File

@ -23,6 +23,8 @@ import collections
import os
import time
import six
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@ -56,7 +58,7 @@ class ServingInputReceiver(collections.namedtuple('ServingInputReceiver',
if not isinstance(features, dict):
features = {_SINGLE_FEATURE_DEFAULT_NAME: features}
for name, tensor in features.items():
if not isinstance(name, str):
if not isinstance(name, six.string_types):
raise ValueError('feature keys must be strings: {}.'.format(name))
if not (isinstance(tensor, ops.Tensor)
or isinstance(tensor, sparse_tensor.SparseTensor)):
@ -68,7 +70,7 @@ class ServingInputReceiver(collections.namedtuple('ServingInputReceiver',
if not isinstance(receiver_tensors, dict):
receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors}
for name, tensor in receiver_tensors.items():
if not isinstance(name, str):
if not isinstance(name, six.string_types):
raise ValueError(
'receiver_tensors keys must be strings: {}.'.format(name))
if not isinstance(tensor, ops.Tensor):

View File

@ -20,6 +20,8 @@ from __future__ import print_function
import abc
import six
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -171,7 +173,7 @@ class PredictOutput(ExportOutput):
'Prediction outputs must be given as a dict of string to Tensor; '
'got {}'.format(outputs))
for key, value in outputs.items():
if not isinstance(key, str):
if not isinstance(key, six.string_types):
raise ValueError(
'Prediction output key must be a string; got {}.'.format(key))
if not isinstance(value, ops.Tensor):

View File

@ -22,7 +22,9 @@ from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
@ -197,6 +199,33 @@ class ExportOutputTest(test.TestCase):
signature_constants.CLASSIFY_METHOD_NAME)
self.assertEqual(actual_signature_def, expected_signature_def)
def test_predict_output_constructor(self):
"""Tests that no errors are raised when input is expected."""
outputs = {
"output0": constant_op.constant([0]),
u"output1": constant_op.constant([1]),
}
export_output_lib.PredictOutput(outputs)
def test_predict_output_outputs_invalid(self):
with self.assertRaisesRegexp(
ValueError,
"Prediction outputs must be given as a dict of string to Tensor"):
export_output_lib.PredictOutput(constant_op.constant([0]))
with self.assertRaisesRegexp(
ValueError,
"Prediction output key must be a string"):
export_output_lib.PredictOutput({1: constant_op.constant([0])})
with self.assertRaisesRegexp(
ValueError,
"Prediction output value must be a Tensor"):
export_output_lib.PredictOutput({
"prediction1": sparse_tensor.SparseTensor(
indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
})
if __name__ == "__main__":
test.main()

View File

@ -28,13 +28,11 @@ from tensorflow.core.example import example_pb2
from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
@ -43,6 +41,69 @@ from tensorflow.python.saved_model import signature_def_utils
class ExportTest(test_util.TensorFlowTestCase):
def test_serving_input_receiver_constructor(self):
"""Tests that no errors are raised when input is expected."""
features = {
"feature0": constant_op.constant([0]),
u"feature1": constant_op.constant([1]),
"feature2": sparse_tensor.SparseTensor(
indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
}
receiver_tensors = {
"example0": array_ops.placeholder(dtypes.string, name="example0"),
u"example1": array_ops.placeholder(dtypes.string, name="example1"),
}
export.ServingInputReceiver(features, receiver_tensors)
def test_serving_input_receiver_features_invalid(self):
receiver_tensors = {
"example0": array_ops.placeholder(dtypes.string, name="example0"),
u"example1": array_ops.placeholder(dtypes.string, name="example1"),
}
with self.assertRaisesRegexp(ValueError, "features must be defined"):
export.ServingInputReceiver(
features=None,
receiver_tensors=receiver_tensors)
with self.assertRaisesRegexp(ValueError, "feature keys must be strings"):
export.ServingInputReceiver(
features={1: constant_op.constant([1])},
receiver_tensors=receiver_tensors)
with self.assertRaisesRegexp(
ValueError, "feature feature1 must be a Tensor or SparseTensor"):
export.ServingInputReceiver(
features={"feature1": [1]},
receiver_tensors=receiver_tensors)
def test_serving_input_receiver_receiver_tensors_invalid(self):
features = {
"feature0": constant_op.constant([0]),
u"feature1": constant_op.constant([1]),
"feature2": sparse_tensor.SparseTensor(
indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
}
with self.assertRaisesRegexp(
ValueError, "receiver_tensors must be defined"):
export.ServingInputReceiver(
features=features,
receiver_tensors=None)
with self.assertRaisesRegexp(
ValueError, "receiver_tensors keys must be strings"):
export.ServingInputReceiver(
features=features,
receiver_tensors={
1: array_ops.placeholder(dtypes.string, name="example0")})
with self.assertRaisesRegexp(
ValueError, "receiver_tensor example1 must be a Tensor"):
export.ServingInputReceiver(
features=features,
receiver_tensors={"example1": [1]})
def test_single_feature_single_receiver(self):
feature = constant_op.constant(5)
receiver_tensor = array_ops.placeholder(dtypes.string)