From 8f593c48c84a2d52d2ba8becf2eaef20250325a0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Jan 2019 03:07:27 -0800 Subject: [PATCH] Correctly read tensor values. PiperOrigin-RevId: 229152770 --- tensorflow/python/framework/tensor_util.py | 95 ++++++------------- .../python/framework/tensor_util_test.py | 30 +++--- 2 files changed, 43 insertions(+), 82 deletions(-) diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 51f71616a1b..ca8b067935c 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -598,88 +598,53 @@ def MakeNdarray(tensor): dtype = tensor_dtype.as_numpy_dtype if tensor.tensor_content: - return (np.frombuffer(tensor.tensor_content, dtype=dtype).copy() - .reshape(shape)) - elif tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16: + return (np.frombuffer(tensor.tensor_content, + dtype=dtype).copy().reshape(shape)) + + if tensor_dtype == dtypes.string: + # np.pad throws on these arrays of type np.object. + values = list(tensor.string_val) + padding = num_elements - len(values) + if padding > 0: + last = values[-1] if values else "" + values.extend([last] * padding) + return np.array(values, dtype=dtype).reshape(shape) + + if tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16: # the half_val field of the TensorProto stores the binary representation # of the fp16: we need to reinterpret this as a proper float16 - if len(tensor.half_val) == 1: - tmp = np.array(tensor.half_val[0], dtype=np.uint16) - tmp.dtype = tensor_dtype.as_numpy_dtype - return np.repeat(tmp, num_elements).reshape(shape) - else: - tmp = np.fromiter(tensor.half_val, dtype=np.uint16) - tmp.dtype = tensor_dtype.as_numpy_dtype - return tmp.reshape(shape) + values = np.fromiter(tensor.half_val, dtype=np.uint16) + values.dtype = tensor_dtype.as_numpy_dtype elif tensor_dtype == dtypes.float32: - if len(tensor.float_val) == 1: - return np.repeat( - np.array(tensor.float_val[0], dtype=dtype), - num_elements).reshape(shape) - else: - return np.fromiter(tensor.float_val, dtype=dtype).reshape(shape) + values = np.fromiter(tensor.float_val, dtype=dtype) elif tensor_dtype == dtypes.float64: - if len(tensor.double_val) == 1: - return np.repeat( - np.array(tensor.double_val[0], dtype=dtype), - num_elements).reshape(shape) - else: - return np.fromiter(tensor.double_val, dtype=dtype).reshape(shape) + values = np.fromiter(tensor.double_val, dtype=dtype) elif tensor_dtype in [ dtypes.int32, dtypes.uint8, dtypes.uint16, dtypes.int16, dtypes.int8, dtypes.qint32, dtypes.quint8, dtypes.qint8, dtypes.qint16, dtypes.quint16 ]: - if len(tensor.int_val) == 1: - return np.repeat(np.array(tensor.int_val[0], dtype=dtype), - num_elements).reshape(shape) - else: - return np.fromiter(tensor.int_val, dtype=dtype).reshape(shape) + values = np.fromiter(tensor.int_val, dtype=dtype) elif tensor_dtype == dtypes.int64: - if len(tensor.int64_val) == 1: - return np.repeat( - np.array(tensor.int64_val[0], dtype=dtype), - num_elements).reshape(shape) - else: - return np.fromiter(tensor.int64_val, dtype=dtype).reshape(shape) - elif tensor_dtype == dtypes.string: - if len(tensor.string_val) == 1: - return np.repeat( - np.array(tensor.string_val[0], dtype=dtype), - num_elements).reshape(shape) - else: - return np.array( - [x for x in tensor.string_val], dtype=dtype).reshape(shape) + values = np.fromiter(tensor.int64_val, dtype=dtype) elif tensor_dtype == dtypes.complex64: it = iter(tensor.scomplex_val) - if len(tensor.scomplex_val) == 2: - return np.repeat( - np.array( - complex(tensor.scomplex_val[0], tensor.scomplex_val[1]), - dtype=dtype), num_elements).reshape(shape) - else: - return np.array( - [complex(x[0], x[1]) for x in zip(it, it)], - dtype=dtype).reshape(shape) + values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype) elif tensor_dtype == dtypes.complex128: it = iter(tensor.dcomplex_val) - if len(tensor.dcomplex_val) == 2: - return np.repeat( - np.array( - complex(tensor.dcomplex_val[0], tensor.dcomplex_val[1]), - dtype=dtype), num_elements).reshape(shape) - else: - return np.array( - [complex(x[0], x[1]) for x in zip(it, it)], - dtype=dtype).reshape(shape) + values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype) elif tensor_dtype == dtypes.bool: - if len(tensor.bool_val) == 1: - return np.repeat(np.array(tensor.bool_val[0], dtype=dtype), - num_elements).reshape(shape) - else: - return np.fromiter(tensor.bool_val, dtype=dtype).reshape(shape) + values = np.fromiter(tensor.bool_val, dtype=dtype) else: raise TypeError("Unsupported tensor type: %s" % tensor.dtype) + if values.size == 0: + return np.zeros(shape, dtype) + + if values.size != num_elements: + values = np.pad(values, (0, num_elements - values.size), "edge") + + return values.reshape(shape) + def ShapeEquals(tensor_proto, shape): """Returns True if "tensor_proto" has the given "shape". diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index 00337546186..cdacdfaaada 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -336,23 +336,16 @@ class TensorUtilTest(test.TestCase): self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a) def testIntTypesWithImplicitRepeat(self): - for dtype, nptype in [(dtypes.int64, np.int64), - (dtypes.int32, np.int32), - (dtypes.uint8, np.uint8), - (dtypes.uint16, np.uint16), - (dtypes.int16, np.int16), - (dtypes.int8, np.int8)]: + for dtype, nptype in [(dtypes.int64, np.int64), (dtypes.int32, np.int32), + (dtypes.uint8, np.uint8), (dtypes.uint16, np.uint16), + (dtypes.int16, np.int16), (dtypes.int8, np.int8)]: self.assertAllEqual( - np.array( - [[10, 10, 10, 10], - [10, 10, 10, 10], - [10, 10, 10, 10]], - dtype=nptype), + np.array([[10, 11, 12, 12], [12, 12, 12, 12], [12, 12, 12, 12]], + dtype=nptype), tensor_util.MakeNdarray( - tensor_util.make_tensor_proto( - [10], - shape=[3, 4], - dtype=dtype))) + tensor_util.make_tensor_proto([10, 11, 12], + shape=[3, 4], + dtype=dtype))) def testIntMixedWithDimension(self): # Github issue: 11974 @@ -500,9 +493,12 @@ class TensorUtilTest(test.TestCase): self.assertEquals([b"foo"], a) def testStringWithImplicitRepeat(self): - t = tensor_util.make_tensor_proto("f", shape=[3, 4]) + t = tensor_util.make_tensor_proto(["f", "g"], shape=[3, 4]) a = tensor_util.MakeNdarray(t) - self.assertAllEqual(np.array([[b"f"] * 4] * 3, dtype=np.object), a) + self.assertAllEqual( + np.array([[b"f", b"g", b"g", b"g"], [b"g", b"g", b"g", b"g"], + [b"g", b"g", b"g", b"g"]], + dtype=np.object), a) def testStringN(self): t = tensor_util.make_tensor_proto([b"foo", b"bar", b"baz"], shape=[1, 3])