From 1e4899035ed82c85f6c85b7349211528f161402c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 4 May 2017 08:29:59 -0800
Subject: [PATCH] Use six.string_types instead of str in estimator/export.
 Change: 155087824

---
 tensorflow/python/estimator/export/export.py  |  6 +-
 .../python/estimator/export/export_output.py  |  4 +-
 .../estimator/export/export_output_test.py    | 29 ++++++++
 .../python/estimator/export/export_test.py    | 67 ++++++++++++++++++-
 4 files changed, 100 insertions(+), 6 deletions(-)

diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py
index 37a98cf4815..a1ecd794df6 100644
--- a/tensorflow/python/estimator/export/export.py
+++ b/tensorflow/python/estimator/export/export.py
@@ -23,6 +23,8 @@ import collections
 import os
 import time
 
+import six
+
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
@@ -56,7 +58,7 @@ class ServingInputReceiver(collections.namedtuple('ServingInputReceiver',
     if not isinstance(features, dict):
       features = {_SINGLE_FEATURE_DEFAULT_NAME: features}
     for name, tensor in features.items():
-      if not isinstance(name, str):
+      if not isinstance(name, six.string_types):
         raise ValueError('feature keys must be strings: {}.'.format(name))
       if not (isinstance(tensor, ops.Tensor)
               or isinstance(tensor, sparse_tensor.SparseTensor)):
@@ -68,7 +70,7 @@ class ServingInputReceiver(collections.namedtuple('ServingInputReceiver',
     if not isinstance(receiver_tensors, dict):
       receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors}
     for name, tensor in receiver_tensors.items():
-      if not isinstance(name, str):
+      if not isinstance(name, six.string_types):
         raise ValueError(
             'receiver_tensors keys must be strings: {}.'.format(name))
       if not isinstance(tensor, ops.Tensor):
diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py
index 69be0f687c1..49bcd06d504 100644
--- a/tensorflow/python/estimator/export/export_output.py
+++ b/tensorflow/python/estimator/export/export_output.py
@@ -20,6 +20,8 @@ from __future__ import print_function
 
 import abc
 
+import six
+
 
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -171,7 +173,7 @@ class PredictOutput(ExportOutput):
           'Prediction outputs must be given as a dict of string to Tensor; '
           'got {}'.format(outputs))
     for key, value in outputs.items():
-      if not isinstance(key, str):
+      if not isinstance(key, six.string_types):
         raise ValueError(
             'Prediction output key must be a string; got {}.'.format(key))
       if not isinstance(value, ops.Tensor):
diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py
index 27a088e551c..035a9a143e6 100644
--- a/tensorflow/python/estimator/export/export_output_test.py
+++ b/tensorflow/python/estimator/export/export_output_test.py
@@ -22,7 +22,9 @@ from tensorflow.core.framework import tensor_shape_pb2
 from tensorflow.core.framework import types_pb2
 from tensorflow.core.protobuf import meta_graph_pb2
 from tensorflow.python.estimator.export import export_output as export_output_lib
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 from tensorflow.python.saved_model import signature_constants
@@ -197,6 +199,33 @@ class ExportOutputTest(test.TestCase):
         signature_constants.CLASSIFY_METHOD_NAME)
     self.assertEqual(actual_signature_def, expected_signature_def)
 
+  def test_predict_output_constructor(self):
+    """Tests that no errors are raised when input is expected."""
+    outputs = {
+        "output0": constant_op.constant([0]),
+        u"output1": constant_op.constant([1]),
+    }
+    export_output_lib.PredictOutput(outputs)
+
+  def test_predict_output_outputs_invalid(self):
+    with self.assertRaisesRegexp(
+        ValueError,
+        "Prediction outputs must be given as a dict of string to Tensor"):
+      export_output_lib.PredictOutput(constant_op.constant([0]))
+
+    with self.assertRaisesRegexp(
+        ValueError,
+        "Prediction output key must be a string"):
+      export_output_lib.PredictOutput({1: constant_op.constant([0])})
+
+    with self.assertRaisesRegexp(
+        ValueError,
+        "Prediction output value must be a Tensor"):
+      export_output_lib.PredictOutput({
+          "prediction1": sparse_tensor.SparseTensor(
+              indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+      })
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py
index fdd924f2e1c..7946bd88ba0 100644
--- a/tensorflow/python/estimator/export/export_test.py
+++ b/tensorflow/python/estimator/export/export_test.py
@@ -28,13 +28,11 @@ from tensorflow.core.example import example_pb2
 from tensorflow.python.estimator.export import export
 from tensorflow.python.estimator.export import export_output
 from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.platform import test
 from tensorflow.python.saved_model import signature_constants
@@ -43,6 +41,69 @@ from tensorflow.python.saved_model import signature_def_utils
 
 class ExportTest(test_util.TensorFlowTestCase):
 
+  def test_serving_input_receiver_constructor(self):
+    """Tests that no errors are raised when input is expected."""
+    features = {
+        "feature0": constant_op.constant([0]),
+        u"feature1": constant_op.constant([1]),
+        "feature2": sparse_tensor.SparseTensor(
+            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+    }
+    receiver_tensors = {
+        "example0": array_ops.placeholder(dtypes.string, name="example0"),
+        u"example1": array_ops.placeholder(dtypes.string, name="example1"),
+    }
+    export.ServingInputReceiver(features, receiver_tensors)
+
+  def test_serving_input_receiver_features_invalid(self):
+    receiver_tensors = {
+        "example0": array_ops.placeholder(dtypes.string, name="example0"),
+        u"example1": array_ops.placeholder(dtypes.string, name="example1"),
+    }
+
+    with self.assertRaisesRegexp(ValueError, "features must be defined"):
+      export.ServingInputReceiver(
+          features=None,
+          receiver_tensors=receiver_tensors)
+
+    with self.assertRaisesRegexp(ValueError, "feature keys must be strings"):
+      export.ServingInputReceiver(
+          features={1: constant_op.constant([1])},
+          receiver_tensors=receiver_tensors)
+
+    with self.assertRaisesRegexp(
+        ValueError, "feature feature1 must be a Tensor or SparseTensor"):
+      export.ServingInputReceiver(
+          features={"feature1": [1]},
+          receiver_tensors=receiver_tensors)
+
+  def test_serving_input_receiver_receiver_tensors_invalid(self):
+    features = {
+        "feature0": constant_op.constant([0]),
+        u"feature1": constant_op.constant([1]),
+        "feature2": sparse_tensor.SparseTensor(
+            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+    }
+
+    with self.assertRaisesRegexp(
+        ValueError, "receiver_tensors must be defined"):
+      export.ServingInputReceiver(
+          features=features,
+          receiver_tensors=None)
+
+    with self.assertRaisesRegexp(
+        ValueError, "receiver_tensors keys must be strings"):
+      export.ServingInputReceiver(
+          features=features,
+          receiver_tensors={
+              1: array_ops.placeholder(dtypes.string, name="example0")})
+
+    with self.assertRaisesRegexp(
+        ValueError, "receiver_tensor example1 must be a Tensor"):
+      export.ServingInputReceiver(
+          features=features,
+          receiver_tensors={"example1": [1]})
+
   def test_single_feature_single_receiver(self):
     feature = constant_op.constant(5)
     receiver_tensor = array_ops.placeholder(dtypes.string)