Support numpy scalar types in serialized TypeSpecs.
Currently, code of the form ``` v = tf.get_static_value(some_int_tensor) # type(v) == np.int32 hash(SomeTypeSpec(some_arg=v)) ``` breaks because the checks for Python scalar types don't recognize Numpy scalars. PiperOrigin-RevId: 350458098 Change-Id: I906a4e59b404b65d62cb3aa38e7fbecba9519814
This commit is contained in:
parent
c5ae7c2526
commit
06607a0f2c
@ -351,7 +351,8 @@ class TypeSpec(object):
|
||||
|
||||
def __make_cmp_key(self, value):
|
||||
"""Converts `value` to a hashable key."""
|
||||
if isinstance(value, (int, float, bool, dtypes.DType, TypeSpec)):
|
||||
if isinstance(value,
|
||||
(int, float, bool, np.generic, dtypes.DType, TypeSpec)):
|
||||
return value
|
||||
if isinstance(value, compat.bytes_or_text_types):
|
||||
return value
|
||||
|
@ -151,9 +151,11 @@ class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
TwoTensorsSpec([5, 3], dtypes.int32, [3], dtypes.bool, "blue")),
|
||||
("NumpyMetadata",
|
||||
TwoTensorsSpec([5, 3], dtypes.int32, [3], dtypes.bool,
|
||||
np.array([[1, 2], [3, 4]])),
|
||||
(np.int32(1), np.float32(1.),
|
||||
np.array([[1, 2], [3, 4]]))),
|
||||
TwoTensorsSpec([5, 3], dtypes.int32, [3], dtypes.bool,
|
||||
np.array([[1, 2], [3, 4]]))),
|
||||
(np.int32(1), np.float32(1.),
|
||||
np.array([[1, 2], [3, 4]])))),
|
||||
)
|
||||
def testEquality(self, v1, v2):
|
||||
# pylint: disable=g-generic-assert
|
||||
|
Loading…
Reference in New Issue
Block a user