diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index dc6076294e3..1663136507a 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -178,7 +178,7 @@ def _GetDenseDimensions(list_of_lists): def _FlattenToStrings(nested_strings): - if isinstance(nested_strings, list): + if isinstance(nested_strings, (list, tuple)): for inner in nested_strings: for flattened_string in _FlattenToStrings(inner): yield flattened_string diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index 0a73abde15a..47a1335374f 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -376,6 +376,34 @@ class TensorUtilTest(tf.test.TestCase): self.assertEquals(np.object, a.dtype) self.assertAllEqual(np.array([[b"a", b"ab"], [b"abc", b"abcd"]]), a) + def testStringTuple(self): + t = tensor_util.make_tensor_proto((b"a", b"ab", b"abc", b"abcd")) + self.assertProtoEquals(""" + dtype: DT_STRING + tensor_shape { dim { size: 4 } } + string_val: "a" + string_val: "ab" + string_val: "abc" + string_val: "abcd" + """, t) + a = tensor_util.MakeNdarray(t) + self.assertEquals(np.object, a.dtype) + self.assertAllEqual(np.array((b"a", b"ab", b"abc", b"abcd")), a) + + def testStringNestedTuple(self): + t = tensor_util.make_tensor_proto(((b"a", b"ab"), (b"abc", b"abcd"))) + self.assertProtoEquals(""" + dtype: DT_STRING + tensor_shape { dim { size: 2 } dim { size: 2 } } + string_val: "a" + string_val: "ab" + string_val: "abc" + string_val: "abcd" + """, t) + a = tensor_util.MakeNdarray(t) + self.assertEquals(np.object, a.dtype) + self.assertAllEqual(np.array(((b"a", b"ab"), (b"abc", b"abcd"))), a) + def testComplex64(self): t = tensor_util.make_tensor_proto((1+2j), dtype=tf.complex64) self.assertProtoEquals("""