Use six.string_types instead of str in estimator/export.
Change: 155087824
This commit is contained in:
parent
42c7659edd
commit
1e4899035e
@ -23,6 +23,8 @@ import collections
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
@ -56,7 +58,7 @@ class ServingInputReceiver(collections.namedtuple('ServingInputReceiver',
|
|||||||
if not isinstance(features, dict):
|
if not isinstance(features, dict):
|
||||||
features = {_SINGLE_FEATURE_DEFAULT_NAME: features}
|
features = {_SINGLE_FEATURE_DEFAULT_NAME: features}
|
||||||
for name, tensor in features.items():
|
for name, tensor in features.items():
|
||||||
if not isinstance(name, str):
|
if not isinstance(name, six.string_types):
|
||||||
raise ValueError('feature keys must be strings: {}.'.format(name))
|
raise ValueError('feature keys must be strings: {}.'.format(name))
|
||||||
if not (isinstance(tensor, ops.Tensor)
|
if not (isinstance(tensor, ops.Tensor)
|
||||||
or isinstance(tensor, sparse_tensor.SparseTensor)):
|
or isinstance(tensor, sparse_tensor.SparseTensor)):
|
||||||
@ -68,7 +70,7 @@ class ServingInputReceiver(collections.namedtuple('ServingInputReceiver',
|
|||||||
if not isinstance(receiver_tensors, dict):
|
if not isinstance(receiver_tensors, dict):
|
||||||
receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors}
|
receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors}
|
||||||
for name, tensor in receiver_tensors.items():
|
for name, tensor in receiver_tensors.items():
|
||||||
if not isinstance(name, str):
|
if not isinstance(name, six.string_types):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'receiver_tensors keys must be strings: {}.'.format(name))
|
'receiver_tensors keys must be strings: {}.'.format(name))
|
||||||
if not isinstance(tensor, ops.Tensor):
|
if not isinstance(tensor, ops.Tensor):
|
||||||
|
@ -20,6 +20,8 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -171,7 +173,7 @@ class PredictOutput(ExportOutput):
|
|||||||
'Prediction outputs must be given as a dict of string to Tensor; '
|
'Prediction outputs must be given as a dict of string to Tensor; '
|
||||||
'got {}'.format(outputs))
|
'got {}'.format(outputs))
|
||||||
for key, value in outputs.items():
|
for key, value in outputs.items():
|
||||||
if not isinstance(key, str):
|
if not isinstance(key, six.string_types):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Prediction output key must be a string; got {}.'.format(key))
|
'Prediction output key must be a string; got {}.'.format(key))
|
||||||
if not isinstance(value, ops.Tensor):
|
if not isinstance(value, ops.Tensor):
|
||||||
|
@ -22,7 +22,9 @@ from tensorflow.core.framework import tensor_shape_pb2
|
|||||||
from tensorflow.core.framework import types_pb2
|
from tensorflow.core.framework import types_pb2
|
||||||
from tensorflow.core.protobuf import meta_graph_pb2
|
from tensorflow.core.protobuf import meta_graph_pb2
|
||||||
from tensorflow.python.estimator.export import export_output as export_output_lib
|
from tensorflow.python.estimator.export import export_output as export_output_lib
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.saved_model import signature_constants
|
from tensorflow.python.saved_model import signature_constants
|
||||||
@ -197,6 +199,33 @@ class ExportOutputTest(test.TestCase):
|
|||||||
signature_constants.CLASSIFY_METHOD_NAME)
|
signature_constants.CLASSIFY_METHOD_NAME)
|
||||||
self.assertEqual(actual_signature_def, expected_signature_def)
|
self.assertEqual(actual_signature_def, expected_signature_def)
|
||||||
|
|
||||||
|
def test_predict_output_constructor(self):
|
||||||
|
"""Tests that no errors are raised when input is expected."""
|
||||||
|
outputs = {
|
||||||
|
"output0": constant_op.constant([0]),
|
||||||
|
u"output1": constant_op.constant([1]),
|
||||||
|
}
|
||||||
|
export_output_lib.PredictOutput(outputs)
|
||||||
|
|
||||||
|
def test_predict_output_outputs_invalid(self):
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError,
|
||||||
|
"Prediction outputs must be given as a dict of string to Tensor"):
|
||||||
|
export_output_lib.PredictOutput(constant_op.constant([0]))
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError,
|
||||||
|
"Prediction output key must be a string"):
|
||||||
|
export_output_lib.PredictOutput({1: constant_op.constant([0])})
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError,
|
||||||
|
"Prediction output value must be a Tensor"):
|
||||||
|
export_output_lib.PredictOutput({
|
||||||
|
"prediction1": sparse_tensor.SparseTensor(
|
||||||
|
indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -28,13 +28,11 @@ from tensorflow.core.example import example_pb2
|
|||||||
from tensorflow.python.estimator.export import export
|
from tensorflow.python.estimator.export import export
|
||||||
from tensorflow.python.estimator.export import export_output
|
from tensorflow.python.estimator.export import export_output
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import constant_op
|
|
||||||
from tensorflow.python.framework import dtypes
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import array_ops
|
|
||||||
from tensorflow.python.ops import parsing_ops
|
from tensorflow.python.ops import parsing_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.saved_model import signature_constants
|
from tensorflow.python.saved_model import signature_constants
|
||||||
@ -43,6 +41,69 @@ from tensorflow.python.saved_model import signature_def_utils
|
|||||||
|
|
||||||
class ExportTest(test_util.TensorFlowTestCase):
|
class ExportTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
def test_serving_input_receiver_constructor(self):
|
||||||
|
"""Tests that no errors are raised when input is expected."""
|
||||||
|
features = {
|
||||||
|
"feature0": constant_op.constant([0]),
|
||||||
|
u"feature1": constant_op.constant([1]),
|
||||||
|
"feature2": sparse_tensor.SparseTensor(
|
||||||
|
indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
|
||||||
|
}
|
||||||
|
receiver_tensors = {
|
||||||
|
"example0": array_ops.placeholder(dtypes.string, name="example0"),
|
||||||
|
u"example1": array_ops.placeholder(dtypes.string, name="example1"),
|
||||||
|
}
|
||||||
|
export.ServingInputReceiver(features, receiver_tensors)
|
||||||
|
|
||||||
|
def test_serving_input_receiver_features_invalid(self):
|
||||||
|
receiver_tensors = {
|
||||||
|
"example0": array_ops.placeholder(dtypes.string, name="example0"),
|
||||||
|
u"example1": array_ops.placeholder(dtypes.string, name="example1"),
|
||||||
|
}
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, "features must be defined"):
|
||||||
|
export.ServingInputReceiver(
|
||||||
|
features=None,
|
||||||
|
receiver_tensors=receiver_tensors)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, "feature keys must be strings"):
|
||||||
|
export.ServingInputReceiver(
|
||||||
|
features={1: constant_op.constant([1])},
|
||||||
|
receiver_tensors=receiver_tensors)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, "feature feature1 must be a Tensor or SparseTensor"):
|
||||||
|
export.ServingInputReceiver(
|
||||||
|
features={"feature1": [1]},
|
||||||
|
receiver_tensors=receiver_tensors)
|
||||||
|
|
||||||
|
def test_serving_input_receiver_receiver_tensors_invalid(self):
|
||||||
|
features = {
|
||||||
|
"feature0": constant_op.constant([0]),
|
||||||
|
u"feature1": constant_op.constant([1]),
|
||||||
|
"feature2": sparse_tensor.SparseTensor(
|
||||||
|
indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
|
||||||
|
}
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, "receiver_tensors must be defined"):
|
||||||
|
export.ServingInputReceiver(
|
||||||
|
features=features,
|
||||||
|
receiver_tensors=None)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, "receiver_tensors keys must be strings"):
|
||||||
|
export.ServingInputReceiver(
|
||||||
|
features=features,
|
||||||
|
receiver_tensors={
|
||||||
|
1: array_ops.placeholder(dtypes.string, name="example0")})
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, "receiver_tensor example1 must be a Tensor"):
|
||||||
|
export.ServingInputReceiver(
|
||||||
|
features=features,
|
||||||
|
receiver_tensors={"example1": [1]})
|
||||||
|
|
||||||
def test_single_feature_single_receiver(self):
|
def test_single_feature_single_receiver(self):
|
||||||
feature = constant_op.constant(5)
|
feature = constant_op.constant(5)
|
||||||
receiver_tensor = array_ops.placeholder(dtypes.string)
|
receiver_tensor = array_ops.placeholder(dtypes.string)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user