Correctly read tensor values.
PiperOrigin-RevId: 229152770
This commit is contained in:
parent
74a6cca5d8
commit
8f593c48c8
@ -598,88 +598,53 @@ def MakeNdarray(tensor):
|
||||
dtype = tensor_dtype.as_numpy_dtype
|
||||
|
||||
if tensor.tensor_content:
|
||||
return (np.frombuffer(tensor.tensor_content, dtype=dtype).copy()
|
||||
.reshape(shape))
|
||||
elif tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16:
|
||||
return (np.frombuffer(tensor.tensor_content,
|
||||
dtype=dtype).copy().reshape(shape))
|
||||
|
||||
if tensor_dtype == dtypes.string:
|
||||
# np.pad throws on these arrays of type np.object.
|
||||
values = list(tensor.string_val)
|
||||
padding = num_elements - len(values)
|
||||
if padding > 0:
|
||||
last = values[-1] if values else ""
|
||||
values.extend([last] * padding)
|
||||
return np.array(values, dtype=dtype).reshape(shape)
|
||||
|
||||
if tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16:
|
||||
# the half_val field of the TensorProto stores the binary representation
|
||||
# of the fp16: we need to reinterpret this as a proper float16
|
||||
if len(tensor.half_val) == 1:
|
||||
tmp = np.array(tensor.half_val[0], dtype=np.uint16)
|
||||
tmp.dtype = tensor_dtype.as_numpy_dtype
|
||||
return np.repeat(tmp, num_elements).reshape(shape)
|
||||
else:
|
||||
tmp = np.fromiter(tensor.half_val, dtype=np.uint16)
|
||||
tmp.dtype = tensor_dtype.as_numpy_dtype
|
||||
return tmp.reshape(shape)
|
||||
values = np.fromiter(tensor.half_val, dtype=np.uint16)
|
||||
values.dtype = tensor_dtype.as_numpy_dtype
|
||||
elif tensor_dtype == dtypes.float32:
|
||||
if len(tensor.float_val) == 1:
|
||||
return np.repeat(
|
||||
np.array(tensor.float_val[0], dtype=dtype),
|
||||
num_elements).reshape(shape)
|
||||
else:
|
||||
return np.fromiter(tensor.float_val, dtype=dtype).reshape(shape)
|
||||
values = np.fromiter(tensor.float_val, dtype=dtype)
|
||||
elif tensor_dtype == dtypes.float64:
|
||||
if len(tensor.double_val) == 1:
|
||||
return np.repeat(
|
||||
np.array(tensor.double_val[0], dtype=dtype),
|
||||
num_elements).reshape(shape)
|
||||
else:
|
||||
return np.fromiter(tensor.double_val, dtype=dtype).reshape(shape)
|
||||
values = np.fromiter(tensor.double_val, dtype=dtype)
|
||||
elif tensor_dtype in [
|
||||
dtypes.int32, dtypes.uint8, dtypes.uint16, dtypes.int16, dtypes.int8,
|
||||
dtypes.qint32, dtypes.quint8, dtypes.qint8, dtypes.qint16, dtypes.quint16
|
||||
]:
|
||||
if len(tensor.int_val) == 1:
|
||||
return np.repeat(np.array(tensor.int_val[0], dtype=dtype),
|
||||
num_elements).reshape(shape)
|
||||
else:
|
||||
return np.fromiter(tensor.int_val, dtype=dtype).reshape(shape)
|
||||
values = np.fromiter(tensor.int_val, dtype=dtype)
|
||||
elif tensor_dtype == dtypes.int64:
|
||||
if len(tensor.int64_val) == 1:
|
||||
return np.repeat(
|
||||
np.array(tensor.int64_val[0], dtype=dtype),
|
||||
num_elements).reshape(shape)
|
||||
else:
|
||||
return np.fromiter(tensor.int64_val, dtype=dtype).reshape(shape)
|
||||
elif tensor_dtype == dtypes.string:
|
||||
if len(tensor.string_val) == 1:
|
||||
return np.repeat(
|
||||
np.array(tensor.string_val[0], dtype=dtype),
|
||||
num_elements).reshape(shape)
|
||||
else:
|
||||
return np.array(
|
||||
[x for x in tensor.string_val], dtype=dtype).reshape(shape)
|
||||
values = np.fromiter(tensor.int64_val, dtype=dtype)
|
||||
elif tensor_dtype == dtypes.complex64:
|
||||
it = iter(tensor.scomplex_val)
|
||||
if len(tensor.scomplex_val) == 2:
|
||||
return np.repeat(
|
||||
np.array(
|
||||
complex(tensor.scomplex_val[0], tensor.scomplex_val[1]),
|
||||
dtype=dtype), num_elements).reshape(shape)
|
||||
else:
|
||||
return np.array(
|
||||
[complex(x[0], x[1]) for x in zip(it, it)],
|
||||
dtype=dtype).reshape(shape)
|
||||
values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype)
|
||||
elif tensor_dtype == dtypes.complex128:
|
||||
it = iter(tensor.dcomplex_val)
|
||||
if len(tensor.dcomplex_val) == 2:
|
||||
return np.repeat(
|
||||
np.array(
|
||||
complex(tensor.dcomplex_val[0], tensor.dcomplex_val[1]),
|
||||
dtype=dtype), num_elements).reshape(shape)
|
||||
else:
|
||||
return np.array(
|
||||
[complex(x[0], x[1]) for x in zip(it, it)],
|
||||
dtype=dtype).reshape(shape)
|
||||
values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype)
|
||||
elif tensor_dtype == dtypes.bool:
|
||||
if len(tensor.bool_val) == 1:
|
||||
return np.repeat(np.array(tensor.bool_val[0], dtype=dtype),
|
||||
num_elements).reshape(shape)
|
||||
else:
|
||||
return np.fromiter(tensor.bool_val, dtype=dtype).reshape(shape)
|
||||
values = np.fromiter(tensor.bool_val, dtype=dtype)
|
||||
else:
|
||||
raise TypeError("Unsupported tensor type: %s" % tensor.dtype)
|
||||
|
||||
if values.size == 0:
|
||||
return np.zeros(shape, dtype)
|
||||
|
||||
if values.size != num_elements:
|
||||
values = np.pad(values, (0, num_elements - values.size), "edge")
|
||||
|
||||
return values.reshape(shape)
|
||||
|
||||
|
||||
def ShapeEquals(tensor_proto, shape):
|
||||
"""Returns True if "tensor_proto" has the given "shape".
|
||||
|
@ -336,23 +336,16 @@ class TensorUtilTest(test.TestCase):
|
||||
self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)
|
||||
|
||||
def testIntTypesWithImplicitRepeat(self):
|
||||
for dtype, nptype in [(dtypes.int64, np.int64),
|
||||
(dtypes.int32, np.int32),
|
||||
(dtypes.uint8, np.uint8),
|
||||
(dtypes.uint16, np.uint16),
|
||||
(dtypes.int16, np.int16),
|
||||
(dtypes.int8, np.int8)]:
|
||||
for dtype, nptype in [(dtypes.int64, np.int64), (dtypes.int32, np.int32),
|
||||
(dtypes.uint8, np.uint8), (dtypes.uint16, np.uint16),
|
||||
(dtypes.int16, np.int16), (dtypes.int8, np.int8)]:
|
||||
self.assertAllEqual(
|
||||
np.array(
|
||||
[[10, 10, 10, 10],
|
||||
[10, 10, 10, 10],
|
||||
[10, 10, 10, 10]],
|
||||
dtype=nptype),
|
||||
np.array([[10, 11, 12, 12], [12, 12, 12, 12], [12, 12, 12, 12]],
|
||||
dtype=nptype),
|
||||
tensor_util.MakeNdarray(
|
||||
tensor_util.make_tensor_proto(
|
||||
[10],
|
||||
shape=[3, 4],
|
||||
dtype=dtype)))
|
||||
tensor_util.make_tensor_proto([10, 11, 12],
|
||||
shape=[3, 4],
|
||||
dtype=dtype)))
|
||||
|
||||
def testIntMixedWithDimension(self):
|
||||
# Github issue: 11974
|
||||
@ -500,9 +493,12 @@ class TensorUtilTest(test.TestCase):
|
||||
self.assertEquals([b"foo"], a)
|
||||
|
||||
def testStringWithImplicitRepeat(self):
|
||||
t = tensor_util.make_tensor_proto("f", shape=[3, 4])
|
||||
t = tensor_util.make_tensor_proto(["f", "g"], shape=[3, 4])
|
||||
a = tensor_util.MakeNdarray(t)
|
||||
self.assertAllEqual(np.array([[b"f"] * 4] * 3, dtype=np.object), a)
|
||||
self.assertAllEqual(
|
||||
np.array([[b"f", b"g", b"g", b"g"], [b"g", b"g", b"g", b"g"],
|
||||
[b"g", b"g", b"g", b"g"]],
|
||||
dtype=np.object), a)
|
||||
|
||||
def testStringN(self):
|
||||
t = tensor_util.make_tensor_proto([b"foo", b"bar", b"baz"], shape=[1, 3])
|
||||
|
Loading…
x
Reference in New Issue
Block a user