From f17392238e8d66db339de43efedc8e6ab016b750 Mon Sep 17 00:00:00 2001 From: Edward Loper Date: Wed, 6 May 2020 13:04:40 -0700 Subject: [PATCH] Update make_tensor_proto to support nested nparray values when dtype is specified. This removes an inconsistency between graph & eager mode. In particular, prior to this CL, the following succeeds in eager mode but fails in graph mode: tf.convert_to_tensor([22, np.array(1)], dtype=tf.int32) PiperOrigin-RevId: 310212356 Change-Id: Ib836171b7ebc6ef5974c364b8ded172503120eba --- tensorflow/python/framework/tensor_util.py | 8 ++++++-- tensorflow/python/framework/tensor_util_test.py | 13 +++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 63365b815aa..50388595c3d 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -260,8 +260,12 @@ def _check_quantized(values): def _generate_isinstance_check(expected_types): def inner(values): - _ = [_check_failed(v) for v in nest.flatten(values) - if not isinstance(v, expected_types)] + for v in nest.flatten(values): + if not (isinstance(v, expected_types) or + (isinstance(v, np.ndarray) and + issubclass(v.dtype.type, expected_types))): + _check_failed(v) + return inner _check_int = _generate_isinstance_check( diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index 6df20b54aa0..ad0aec1623d 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -713,6 +713,19 @@ class TensorUtilTest(test.TestCase): self.assertAllEqual( np.array([[(1 + 2j), (3 + 4j)], [(5 + 6j), (7 + 8j)]]), a) + def testNestedNumpyArrayWithoutDType(self): + t = tensor_util.make_tensor_proto([10.0, 20.0, np.array(30.0)]) + a = tensor_util.MakeNdarray(t) + self.assertEqual(np.float32, a.dtype) + self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a) + + def testNestedNumpyArrayWithDType(self): + t = tensor_util.make_tensor_proto([10.0, 20.0, np.array(30.0)], + dtype=dtypes.float32) + a = tensor_util.MakeNdarray(t) + self.assertEqual(np.float32, a.dtype) + self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a) + def testUnsupportedDTypes(self): with self.assertRaises(TypeError): tensor_util.make_tensor_proto(np.array([1]), 0)