Support numpy scalar types in serialized TypeSpecs.

Currently, code of the form

```
v = tf.get_static_value(some_int_tensor)
# type(v) == np.int32
hash(SomeTypeSpec(some_arg=v))
```

breaks because the checks for Python scalar types don't recognize Numpy scalars.

PiperOrigin-RevId: 350458098
Change-Id: I906a4e59b404b65d62cb3aa38e7fbecba9519814
This commit is contained in:
Dave Moore 2021-01-06 17:21:32 -08:00 committed by TensorFlower Gardener
parent c5ae7c2526
commit 06607a0f2c
2 changed files with 6 additions and 3 deletions

View File

@ -351,7 +351,8 @@ class TypeSpec(object):
def __make_cmp_key(self, value):
"""Converts `value` to a hashable key."""
if isinstance(value, (int, float, bool, dtypes.DType, TypeSpec)):
if isinstance(value,
(int, float, bool, np.generic, dtypes.DType, TypeSpec)):
return value
if isinstance(value, compat.bytes_or_text_types):
return value

View File

@ -151,9 +151,11 @@ class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
TwoTensorsSpec([5, 3], dtypes.int32, [3], dtypes.bool, "blue")),
("NumpyMetadata",
TwoTensorsSpec([5, 3], dtypes.int32, [3], dtypes.bool,
np.array([[1, 2], [3, 4]])),
(np.int32(1), np.float32(1.),
np.array([[1, 2], [3, 4]]))),
TwoTensorsSpec([5, 3], dtypes.int32, [3], dtypes.bool,
np.array([[1, 2], [3, 4]]))),
(np.int32(1), np.float32(1.),
np.array([[1, 2], [3, 4]])))),
)
def testEquality(self, v1, v2):
# pylint: disable=g-generic-assert