From 1e4899035ed82c85f6c85b7349211528f161402c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Thu, 4 May 2017 08:29:59 -0800 Subject: [PATCH] Use six.string_types instead of str in estimator/export. Change: 155087824 --- tensorflow/python/estimator/export/export.py | 6 +- .../python/estimator/export/export_output.py | 4 +- .../estimator/export/export_output_test.py | 29 ++++++++ .../python/estimator/export/export_test.py | 67 ++++++++++++++++++- 4 files changed, 100 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index 37a98cf4815..a1ecd794df6 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -23,6 +23,8 @@ import collections import os import time +import six + from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -56,7 +58,7 @@ class ServingInputReceiver(collections.namedtuple('ServingInputReceiver', if not isinstance(features, dict): features = {_SINGLE_FEATURE_DEFAULT_NAME: features} for name, tensor in features.items(): - if not isinstance(name, str): + if not isinstance(name, six.string_types): raise ValueError('feature keys must be strings: {}.'.format(name)) if not (isinstance(tensor, ops.Tensor) or isinstance(tensor, sparse_tensor.SparseTensor)): @@ -68,7 +70,7 @@ class ServingInputReceiver(collections.namedtuple('ServingInputReceiver', if not isinstance(receiver_tensors, dict): receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} for name, tensor in receiver_tensors.items(): - if not isinstance(name, str): + if not isinstance(name, six.string_types): raise ValueError( 'receiver_tensors keys must be strings: {}.'.format(name)) if not isinstance(tensor, ops.Tensor): diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py index 69be0f687c1..49bcd06d504 100644 --- a/tensorflow/python/estimator/export/export_output.py +++ b/tensorflow/python/estimator/export/export_output.py @@ -20,6 +20,8 @@ from __future__ import print_function import abc +import six + from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -171,7 +173,7 @@ class PredictOutput(ExportOutput): 'Prediction outputs must be given as a dict of string to Tensor; ' 'got {}'.format(outputs)) for key, value in outputs.items(): - if not isinstance(key, str): + if not isinstance(key, six.string_types): raise ValueError( 'Prediction output key must be a string; got {}.'.format(key)) if not isinstance(value, ops.Tensor): diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py index 27a088e551c..035a9a143e6 100644 --- a/tensorflow/python/estimator/export/export_output_test.py +++ b/tensorflow/python/estimator/export/export_output_test.py @@ -22,7 +22,9 @@ from tensorflow.core.framework import tensor_shape_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.estimator.export import export_output as export_output_lib +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -197,6 +199,33 @@ class ExportOutputTest(test.TestCase): signature_constants.CLASSIFY_METHOD_NAME) self.assertEqual(actual_signature_def, expected_signature_def) + def test_predict_output_constructor(self): + """Tests that no errors are raised when input is expected.""" + outputs = { + "output0": constant_op.constant([0]), + u"output1": constant_op.constant([1]), + } + export_output_lib.PredictOutput(outputs) + + def test_predict_output_outputs_invalid(self): + with self.assertRaisesRegexp( + ValueError, + "Prediction outputs must be given as a dict of string to Tensor"): + export_output_lib.PredictOutput(constant_op.constant([0])) + + with self.assertRaisesRegexp( + ValueError, + "Prediction output key must be a string"): + export_output_lib.PredictOutput({1: constant_op.constant([0])}) + + with self.assertRaisesRegexp( + ValueError, + "Prediction output value must be a Tensor"): + export_output_lib.PredictOutput({ + "prediction1": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + }) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py index fdd924f2e1c..7946bd88ba0 100644 --- a/tensorflow/python/estimator/export/export_test.py +++ b/tensorflow/python/estimator/export/export_test.py @@ -28,13 +28,11 @@ from tensorflow.core.example import example_pb2 from tensorflow.python.estimator.export import export from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import constant_op -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -43,6 +41,69 @@ from tensorflow.python.saved_model import signature_def_utils class ExportTest(test_util.TensorFlowTestCase): + def test_serving_input_receiver_constructor(self): + """Tests that no errors are raised when input is expected.""" + features = { + "feature0": constant_op.constant([0]), + u"feature1": constant_op.constant([1]), + "feature2": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + } + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + export.ServingInputReceiver(features, receiver_tensors) + + def test_serving_input_receiver_features_invalid(self): + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + + with self.assertRaisesRegexp(ValueError, "features must be defined"): + export.ServingInputReceiver( + features=None, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp(ValueError, "feature keys must be strings"): + export.ServingInputReceiver( + features={1: constant_op.constant([1])}, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp( + ValueError, "feature feature1 must be a Tensor or SparseTensor"): + export.ServingInputReceiver( + features={"feature1": [1]}, + receiver_tensors=receiver_tensors) + + def test_serving_input_receiver_receiver_tensors_invalid(self): + features = { + "feature0": constant_op.constant([0]), + u"feature1": constant_op.constant([1]), + "feature2": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + } + + with self.assertRaisesRegexp( + ValueError, "receiver_tensors must be defined"): + export.ServingInputReceiver( + features=features, + receiver_tensors=None) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensors keys must be strings"): + export.ServingInputReceiver( + features=features, + receiver_tensors={ + 1: array_ops.placeholder(dtypes.string, name="example0")}) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensor example1 must be a Tensor"): + export.ServingInputReceiver( + features=features, + receiver_tensors={"example1": [1]}) + def test_single_feature_single_receiver(self): feature = constant_op.constant(5) receiver_tensor = array_ops.placeholder(dtypes.string)