Support uint32
and uint64
in tf.get_static_value()
from within a tf.function()
Attempting to do so without this change raises an exception: ``` File ".../tensorflow/python/ops/check_ops.py", line 672, in assert_equal data, summarize, message, name) File ".../tensorflow/python/ops/check_ops.py", line 369, in _binary_assert x_static = tensor_util.constant_value(x) File ".../tensorflow/python/framework/tensor_util.py", line 876, in constant_value ret = _ConstantValue(tensor, partial) File ".../tensorflow/python/framework/tensor_util.py", line 681, in _ConstantValue return MakeNdarray(tensor.op.get_attr("value")) File ".../tensorflow/python/framework/tensor_util.py", line 641, in MakeNdarray raise TypeError("Unsupported tensor type: %s" % tensor.dtype) TypeError: Unsupported tensor type: 23 ``` PiperOrigin-RevId: 356437863 Change-Id: I51d9adf8727bcf01ac3a862a608fa93e2cb6a88d
This commit is contained in:
parent
f47acb1257
commit
aae0fc015b
tensorflow/python/framework
@ -628,6 +628,10 @@ def MakeNdarray(tensor):
|
|||||||
values = np.fromiter(tensor.int_val, dtype=dtype)
|
values = np.fromiter(tensor.int_val, dtype=dtype)
|
||||||
elif tensor_dtype == dtypes.int64:
|
elif tensor_dtype == dtypes.int64:
|
||||||
values = np.fromiter(tensor.int64_val, dtype=dtype)
|
values = np.fromiter(tensor.int64_val, dtype=dtype)
|
||||||
|
elif tensor_dtype == dtypes.uint32:
|
||||||
|
values = np.fromiter(tensor.uint32_val, dtype=dtype)
|
||||||
|
elif tensor_dtype == dtypes.uint64:
|
||||||
|
values = np.fromiter(tensor.uint64_val, dtype=dtype)
|
||||||
elif tensor_dtype == dtypes.complex64:
|
elif tensor_dtype == dtypes.complex64:
|
||||||
it = iter(tensor.scomplex_val)
|
it = iter(tensor.scomplex_val)
|
||||||
values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype)
|
values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype)
|
||||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
import contextlib
|
import contextlib
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
@ -41,7 +42,7 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class TensorUtilTest(test.TestCase):
|
class TensorUtilTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def testFloat(self):
|
def testFloat(self):
|
||||||
value = 10.0
|
value = 10.0
|
||||||
@ -318,12 +319,13 @@ class TensorUtilTest(test.TestCase):
|
|||||||
self.assertEqual(np.int32, a.dtype)
|
self.assertEqual(np.int32, a.dtype)
|
||||||
self.assertAllClose(np.array([[10, 20], [30, 40]], dtype=np.int32), a)
|
self.assertAllClose(np.array([[10, 20], [30, 40]], dtype=np.int32), a)
|
||||||
|
|
||||||
def testIntTypes(self):
|
@parameterized.named_parameters(
|
||||||
for dtype, nptype in [(dtypes.int32, np.int32),
|
("_int8", dtypes.int8, np.int8), ("_int16", dtypes.int16, np.int16),
|
||||||
(dtypes.uint8, np.uint8),
|
("_int32", dtypes.int32, np.int32), ("_int64", dtypes.int64, np.int64),
|
||||||
(dtypes.uint16, np.uint16),
|
("_uint8", dtypes.uint8, np.uint8), ("_uint16", dtypes.uint16, np.uint16),
|
||||||
(dtypes.int16, np.int16),
|
("_uint32", dtypes.uint32, np.uint32),
|
||||||
(dtypes.int8, np.int8)]:
|
("_uint64", dtypes.uint64, np.uint64))
|
||||||
|
def testIntTypes(self, dtype, nptype):
|
||||||
# Test with array.
|
# Test with array.
|
||||||
t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtype)
|
t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtype)
|
||||||
self.assertEqual(dtype, t.dtype)
|
self.assertEqual(dtype, t.dtype)
|
||||||
@ -339,10 +341,13 @@ class TensorUtilTest(test.TestCase):
|
|||||||
self.assertEqual(nptype, a.dtype)
|
self.assertEqual(nptype, a.dtype)
|
||||||
self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)
|
self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)
|
||||||
|
|
||||||
def testIntTypesWithImplicitRepeat(self):
|
@parameterized.named_parameters(
|
||||||
for dtype, nptype in [(dtypes.int64, np.int64), (dtypes.int32, np.int32),
|
("_int8", dtypes.int8, np.int8), ("_int16", dtypes.int16, np.int16),
|
||||||
(dtypes.uint8, np.uint8), (dtypes.uint16, np.uint16),
|
("_int32", dtypes.int32, np.int32), ("_int64", dtypes.int64, np.int64),
|
||||||
(dtypes.int16, np.int16), (dtypes.int8, np.int8)]:
|
("_uint8", dtypes.uint8, np.uint8), ("_uint16", dtypes.uint16, np.uint16),
|
||||||
|
("_uint32", dtypes.uint32, np.uint32),
|
||||||
|
("_uint64", dtypes.uint64, np.uint64))
|
||||||
|
def testIntTypesWithImplicitRepeat(self, dtype, nptype):
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
np.array([[10, 11, 12, 12], [12, 12, 12, 12], [12, 12, 12, 12]],
|
np.array([[10, 11, 12, 12], [12, 12, 12, 12], [12, 12, 12, 12]],
|
||||||
dtype=nptype),
|
dtype=nptype),
|
||||||
@ -362,53 +367,73 @@ class TensorUtilTest(test.TestCase):
|
|||||||
self.assertEqual(nptype, a.dtype)
|
self.assertEqual(nptype, a.dtype)
|
||||||
self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)
|
self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)
|
||||||
|
|
||||||
def testLong(self):
|
@parameterized.named_parameters(
|
||||||
t = tensor_util.make_tensor_proto(10, dtype=dtypes.int64)
|
("_int64", dtypes.int64, np.int64, "DT_INT64", "int64_val"),
|
||||||
self.assertProtoEquals("""
|
("_uint64", dtypes.uint64, np.uint64, "DT_UINT64", "uint64_val"))
|
||||||
dtype: DT_INT64
|
def testLong(self, dtype, nptype, proto_dtype, proto_value_name):
|
||||||
|
t = tensor_util.make_tensor_proto(10, dtype=dtype)
|
||||||
|
self.assertProtoEquals(
|
||||||
|
"""
|
||||||
|
dtype: %s
|
||||||
tensor_shape {}
|
tensor_shape {}
|
||||||
int64_val: 10
|
%s: 10
|
||||||
""", t)
|
""" % (proto_dtype, proto_value_name), t)
|
||||||
a = tensor_util.MakeNdarray(t)
|
a = tensor_util.MakeNdarray(t)
|
||||||
self.assertEqual(np.int64, a.dtype)
|
self.assertEqual(nptype, a.dtype)
|
||||||
self.assertAllClose(np.array(10, dtype=np.int64), a)
|
self.assertAllClose(np.array(10, dtype=nptype), a)
|
||||||
|
|
||||||
def testLongN(self):
|
@parameterized.named_parameters(
|
||||||
t = tensor_util.make_tensor_proto(
|
("_int64", dtypes.int64, np.int64, "DT_INT64"),
|
||||||
[10, 20, 30], shape=[1, 3], dtype=dtypes.int64)
|
("_uint64", dtypes.uint64, np.uint64, "DT_UINT64"))
|
||||||
|
def testLongN(self, dtype, nptype, proto_dtype):
|
||||||
|
t = tensor_util.make_tensor_proto([10, 20, 30], shape=[1, 3], dtype=dtype)
|
||||||
if sys.byteorder == "big":
|
if sys.byteorder == "big":
|
||||||
self.assertProtoEquals(r"""
|
# pylint: disable=line-too-long
|
||||||
dtype: DT_INT64
|
self.assertProtoEquals(
|
||||||
|
r"""
|
||||||
|
dtype: %s
|
||||||
tensor_shape { dim { size: 1 } dim { size: 3 } }
|
tensor_shape { dim { size: 1 } dim { size: 3 } }
|
||||||
tensor_content: "\000\000\000\000\000\000\000\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
|
tensor_content: "\000\000\000\000\000\000\000\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
|
||||||
""", t)
|
""" % proto_dtype, t)
|
||||||
|
# pylint: enable=line-too-long
|
||||||
else:
|
else:
|
||||||
self.assertProtoEquals(r"""
|
# pylint: disable=line-too-long
|
||||||
dtype: DT_INT64
|
self.assertProtoEquals(
|
||||||
|
r"""
|
||||||
|
dtype: %s
|
||||||
tensor_shape { dim { size: 1 } dim { size: 3 } }
|
tensor_shape { dim { size: 1 } dim { size: 3 } }
|
||||||
tensor_content: "\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
|
tensor_content: "\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
|
||||||
""", t)
|
""" % proto_dtype, t)
|
||||||
|
# pylint: enable=line-too-long
|
||||||
a = tensor_util.MakeNdarray(t)
|
a = tensor_util.MakeNdarray(t)
|
||||||
self.assertEqual(np.int64, a.dtype)
|
self.assertEqual(nptype, a.dtype)
|
||||||
self.assertAllClose(np.array([[10, 20, 30]], dtype=np.int64), a)
|
self.assertAllClose(np.array([[10, 20, 30]], dtype=nptype), a)
|
||||||
|
|
||||||
def testLongNpArray(self):
|
@parameterized.named_parameters(("_int64", np.int64, "DT_INT64"),
|
||||||
t = tensor_util.make_tensor_proto(np.array([10, 20, 30]))
|
("_uint64", np.uint64, "DT_UINT64"))
|
||||||
|
def testLongNpArray(self, nptype, proto_dtype):
|
||||||
|
t = tensor_util.make_tensor_proto(np.array([10, 20, 30], dtype=nptype))
|
||||||
if sys.byteorder == "big":
|
if sys.byteorder == "big":
|
||||||
self.assertProtoEquals(r"""
|
# pylint: disable=line-too-long
|
||||||
dtype: DT_INT64
|
self.assertProtoEquals(
|
||||||
|
r"""
|
||||||
|
dtype: %s
|
||||||
tensor_shape { dim { size: 3 } }
|
tensor_shape { dim { size: 3 } }
|
||||||
tensor_content: "\000\000\000\000\000\000\000\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
|
tensor_content: "\000\000\000\000\000\000\000\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
|
||||||
""", t)
|
""" % proto_dtype, t)
|
||||||
|
# pylint: enable=line-too-long
|
||||||
else:
|
else:
|
||||||
self.assertProtoEquals(r"""
|
# pylint: disable=line-too-long
|
||||||
dtype: DT_INT64
|
self.assertProtoEquals(
|
||||||
|
r"""
|
||||||
|
dtype: %s
|
||||||
tensor_shape { dim { size: 3 } }
|
tensor_shape { dim { size: 3 } }
|
||||||
tensor_content: "\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
|
tensor_content: "\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
|
||||||
""", t)
|
""" % proto_dtype, t)
|
||||||
|
# pylint: enable=line-too-long
|
||||||
a = tensor_util.MakeNdarray(t)
|
a = tensor_util.MakeNdarray(t)
|
||||||
self.assertEqual(np.int64, a.dtype)
|
self.assertEqual(nptype, a.dtype)
|
||||||
self.assertAllClose(np.array([10, 20, 30], dtype=np.int64), a)
|
self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)
|
||||||
|
|
||||||
def testQuantizedTypes(self):
|
def testQuantizedTypes(self):
|
||||||
# Test with array.
|
# Test with array.
|
||||||
|
Loading…
Reference in New Issue
Block a user