Correctly read tensor values.

PiperOrigin-RevId: 229152770
This commit is contained in:
A. Unique TensorFlower 2019-01-14 03:07:27 -08:00 committed by TensorFlower Gardener
parent 74a6cca5d8
commit 8f593c48c8
2 changed files with 43 additions and 82 deletions

View File

@ -598,88 +598,53 @@ def MakeNdarray(tensor):
dtype = tensor_dtype.as_numpy_dtype
if tensor.tensor_content:
return (np.frombuffer(tensor.tensor_content, dtype=dtype).copy()
.reshape(shape))
elif tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16:
return (np.frombuffer(tensor.tensor_content,
dtype=dtype).copy().reshape(shape))
if tensor_dtype == dtypes.string:
# np.pad throws on these arrays of type np.object.
values = list(tensor.string_val)
padding = num_elements - len(values)
if padding > 0:
last = values[-1] if values else ""
values.extend([last] * padding)
return np.array(values, dtype=dtype).reshape(shape)
if tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16:
# the half_val field of the TensorProto stores the binary representation
# of the fp16: we need to reinterpret this as a proper float16
if len(tensor.half_val) == 1:
tmp = np.array(tensor.half_val[0], dtype=np.uint16)
tmp.dtype = tensor_dtype.as_numpy_dtype
return np.repeat(tmp, num_elements).reshape(shape)
else:
tmp = np.fromiter(tensor.half_val, dtype=np.uint16)
tmp.dtype = tensor_dtype.as_numpy_dtype
return tmp.reshape(shape)
values = np.fromiter(tensor.half_val, dtype=np.uint16)
values.dtype = tensor_dtype.as_numpy_dtype
elif tensor_dtype == dtypes.float32:
if len(tensor.float_val) == 1:
return np.repeat(
np.array(tensor.float_val[0], dtype=dtype),
num_elements).reshape(shape)
else:
return np.fromiter(tensor.float_val, dtype=dtype).reshape(shape)
values = np.fromiter(tensor.float_val, dtype=dtype)
elif tensor_dtype == dtypes.float64:
if len(tensor.double_val) == 1:
return np.repeat(
np.array(tensor.double_val[0], dtype=dtype),
num_elements).reshape(shape)
else:
return np.fromiter(tensor.double_val, dtype=dtype).reshape(shape)
values = np.fromiter(tensor.double_val, dtype=dtype)
elif tensor_dtype in [
dtypes.int32, dtypes.uint8, dtypes.uint16, dtypes.int16, dtypes.int8,
dtypes.qint32, dtypes.quint8, dtypes.qint8, dtypes.qint16, dtypes.quint16
]:
if len(tensor.int_val) == 1:
return np.repeat(np.array(tensor.int_val[0], dtype=dtype),
num_elements).reshape(shape)
else:
return np.fromiter(tensor.int_val, dtype=dtype).reshape(shape)
values = np.fromiter(tensor.int_val, dtype=dtype)
elif tensor_dtype == dtypes.int64:
if len(tensor.int64_val) == 1:
return np.repeat(
np.array(tensor.int64_val[0], dtype=dtype),
num_elements).reshape(shape)
else:
return np.fromiter(tensor.int64_val, dtype=dtype).reshape(shape)
elif tensor_dtype == dtypes.string:
if len(tensor.string_val) == 1:
return np.repeat(
np.array(tensor.string_val[0], dtype=dtype),
num_elements).reshape(shape)
else:
return np.array(
[x for x in tensor.string_val], dtype=dtype).reshape(shape)
values = np.fromiter(tensor.int64_val, dtype=dtype)
elif tensor_dtype == dtypes.complex64:
it = iter(tensor.scomplex_val)
if len(tensor.scomplex_val) == 2:
return np.repeat(
np.array(
complex(tensor.scomplex_val[0], tensor.scomplex_val[1]),
dtype=dtype), num_elements).reshape(shape)
else:
return np.array(
[complex(x[0], x[1]) for x in zip(it, it)],
dtype=dtype).reshape(shape)
values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype)
elif tensor_dtype == dtypes.complex128:
it = iter(tensor.dcomplex_val)
if len(tensor.dcomplex_val) == 2:
return np.repeat(
np.array(
complex(tensor.dcomplex_val[0], tensor.dcomplex_val[1]),
dtype=dtype), num_elements).reshape(shape)
else:
return np.array(
[complex(x[0], x[1]) for x in zip(it, it)],
dtype=dtype).reshape(shape)
values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype)
elif tensor_dtype == dtypes.bool:
if len(tensor.bool_val) == 1:
return np.repeat(np.array(tensor.bool_val[0], dtype=dtype),
num_elements).reshape(shape)
else:
return np.fromiter(tensor.bool_val, dtype=dtype).reshape(shape)
values = np.fromiter(tensor.bool_val, dtype=dtype)
else:
raise TypeError("Unsupported tensor type: %s" % tensor.dtype)
if values.size == 0:
return np.zeros(shape, dtype)
if values.size != num_elements:
values = np.pad(values, (0, num_elements - values.size), "edge")
return values.reshape(shape)
def ShapeEquals(tensor_proto, shape):
"""Returns True if "tensor_proto" has the given "shape".

View File

@ -336,23 +336,16 @@ class TensorUtilTest(test.TestCase):
self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)
def testIntTypesWithImplicitRepeat(self):
for dtype, nptype in [(dtypes.int64, np.int64),
(dtypes.int32, np.int32),
(dtypes.uint8, np.uint8),
(dtypes.uint16, np.uint16),
(dtypes.int16, np.int16),
(dtypes.int8, np.int8)]:
for dtype, nptype in [(dtypes.int64, np.int64), (dtypes.int32, np.int32),
(dtypes.uint8, np.uint8), (dtypes.uint16, np.uint16),
(dtypes.int16, np.int16), (dtypes.int8, np.int8)]:
self.assertAllEqual(
np.array(
[[10, 10, 10, 10],
[10, 10, 10, 10],
[10, 10, 10, 10]],
dtype=nptype),
np.array([[10, 11, 12, 12], [12, 12, 12, 12], [12, 12, 12, 12]],
dtype=nptype),
tensor_util.MakeNdarray(
tensor_util.make_tensor_proto(
[10],
shape=[3, 4],
dtype=dtype)))
tensor_util.make_tensor_proto([10, 11, 12],
shape=[3, 4],
dtype=dtype)))
def testIntMixedWithDimension(self):
# Github issue: 11974
@ -500,9 +493,12 @@ class TensorUtilTest(test.TestCase):
self.assertEquals([b"foo"], a)
def testStringWithImplicitRepeat(self):
t = tensor_util.make_tensor_proto("f", shape=[3, 4])
t = tensor_util.make_tensor_proto(["f", "g"], shape=[3, 4])
a = tensor_util.MakeNdarray(t)
self.assertAllEqual(np.array([[b"f"] * 4] * 3, dtype=np.object), a)
self.assertAllEqual(
np.array([[b"f", b"g", b"g", b"g"], [b"g", b"g", b"g", b"g"],
[b"g", b"g", b"g", b"g"]],
dtype=np.object), a)
def testStringN(self):
t = tensor_util.make_tensor_proto([b"foo", b"bar", b"baz"], shape=[1, 3])