Use six.string_types instead of str in estimator/export.
Change: 155087824
This commit is contained in:
parent
42c7659edd
commit
1e4899035e
tensorflow/python/estimator/export
@ -23,6 +23,8 @@ import collections
|
||||
import os
|
||||
import time
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
@ -56,7 +58,7 @@ class ServingInputReceiver(collections.namedtuple('ServingInputReceiver',
|
||||
if not isinstance(features, dict):
|
||||
features = {_SINGLE_FEATURE_DEFAULT_NAME: features}
|
||||
for name, tensor in features.items():
|
||||
if not isinstance(name, str):
|
||||
if not isinstance(name, six.string_types):
|
||||
raise ValueError('feature keys must be strings: {}.'.format(name))
|
||||
if not (isinstance(tensor, ops.Tensor)
|
||||
or isinstance(tensor, sparse_tensor.SparseTensor)):
|
||||
@ -68,7 +70,7 @@ class ServingInputReceiver(collections.namedtuple('ServingInputReceiver',
|
||||
if not isinstance(receiver_tensors, dict):
|
||||
receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors}
|
||||
for name, tensor in receiver_tensors.items():
|
||||
if not isinstance(name, str):
|
||||
if not isinstance(name, six.string_types):
|
||||
raise ValueError(
|
||||
'receiver_tensors keys must be strings: {}.'.format(name))
|
||||
if not isinstance(tensor, ops.Tensor):
|
||||
|
@ -20,6 +20,8 @@ from __future__ import print_function
|
||||
|
||||
import abc
|
||||
|
||||
import six
|
||||
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -171,7 +173,7 @@ class PredictOutput(ExportOutput):
|
||||
'Prediction outputs must be given as a dict of string to Tensor; '
|
||||
'got {}'.format(outputs))
|
||||
for key, value in outputs.items():
|
||||
if not isinstance(key, str):
|
||||
if not isinstance(key, six.string_types):
|
||||
raise ValueError(
|
||||
'Prediction output key must be a string; got {}.'.format(key))
|
||||
if not isinstance(value, ops.Tensor):
|
||||
|
@ -22,7 +22,9 @@ from tensorflow.core.framework import tensor_shape_pb2
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.python.estimator.export import export_output as export_output_lib
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
@ -197,6 +199,33 @@ class ExportOutputTest(test.TestCase):
|
||||
signature_constants.CLASSIFY_METHOD_NAME)
|
||||
self.assertEqual(actual_signature_def, expected_signature_def)
|
||||
|
||||
def test_predict_output_constructor(self):
|
||||
"""Tests that no errors are raised when input is expected."""
|
||||
outputs = {
|
||||
"output0": constant_op.constant([0]),
|
||||
u"output1": constant_op.constant([1]),
|
||||
}
|
||||
export_output_lib.PredictOutput(outputs)
|
||||
|
||||
def test_predict_output_outputs_invalid(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
"Prediction outputs must be given as a dict of string to Tensor"):
|
||||
export_output_lib.PredictOutput(constant_op.constant([0]))
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
"Prediction output key must be a string"):
|
||||
export_output_lib.PredictOutput({1: constant_op.constant([0])})
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
"Prediction output value must be a Tensor"):
|
||||
export_output_lib.PredictOutput({
|
||||
"prediction1": sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
|
||||
})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -28,13 +28,11 @@ from tensorflow.core.example import example_pb2
|
||||
from tensorflow.python.estimator.export import export
|
||||
from tensorflow.python.estimator.export import export_output
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
@ -43,6 +41,69 @@ from tensorflow.python.saved_model import signature_def_utils
|
||||
|
||||
class ExportTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def test_serving_input_receiver_constructor(self):
|
||||
"""Tests that no errors are raised when input is expected."""
|
||||
features = {
|
||||
"feature0": constant_op.constant([0]),
|
||||
u"feature1": constant_op.constant([1]),
|
||||
"feature2": sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
|
||||
}
|
||||
receiver_tensors = {
|
||||
"example0": array_ops.placeholder(dtypes.string, name="example0"),
|
||||
u"example1": array_ops.placeholder(dtypes.string, name="example1"),
|
||||
}
|
||||
export.ServingInputReceiver(features, receiver_tensors)
|
||||
|
||||
def test_serving_input_receiver_features_invalid(self):
|
||||
receiver_tensors = {
|
||||
"example0": array_ops.placeholder(dtypes.string, name="example0"),
|
||||
u"example1": array_ops.placeholder(dtypes.string, name="example1"),
|
||||
}
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "features must be defined"):
|
||||
export.ServingInputReceiver(
|
||||
features=None,
|
||||
receiver_tensors=receiver_tensors)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "feature keys must be strings"):
|
||||
export.ServingInputReceiver(
|
||||
features={1: constant_op.constant([1])},
|
||||
receiver_tensors=receiver_tensors)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "feature feature1 must be a Tensor or SparseTensor"):
|
||||
export.ServingInputReceiver(
|
||||
features={"feature1": [1]},
|
||||
receiver_tensors=receiver_tensors)
|
||||
|
||||
def test_serving_input_receiver_receiver_tensors_invalid(self):
|
||||
features = {
|
||||
"feature0": constant_op.constant([0]),
|
||||
u"feature1": constant_op.constant([1]),
|
||||
"feature2": sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
|
||||
}
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "receiver_tensors must be defined"):
|
||||
export.ServingInputReceiver(
|
||||
features=features,
|
||||
receiver_tensors=None)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "receiver_tensors keys must be strings"):
|
||||
export.ServingInputReceiver(
|
||||
features=features,
|
||||
receiver_tensors={
|
||||
1: array_ops.placeholder(dtypes.string, name="example0")})
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "receiver_tensor example1 must be a Tensor"):
|
||||
export.ServingInputReceiver(
|
||||
features=features,
|
||||
receiver_tensors={"example1": [1]})
|
||||
|
||||
def test_single_feature_single_receiver(self):
|
||||
feature = constant_op.constant(5)
|
||||
receiver_tensor = array_ops.placeholder(dtypes.string)
|
||||
|
Loading…
Reference in New Issue
Block a user