Support numpy scalar types in serialized TypeSpecs.
Currently, code of the form ``` v = tf.get_static_value(some_int_tensor) # type(v) == np.int32 hash(SomeTypeSpec(some_arg=v)) ``` breaks because the checks for Python scalar types don't recognize Numpy scalars. PiperOrigin-RevId: 350458098 Change-Id: I906a4e59b404b65d62cb3aa38e7fbecba9519814
This commit is contained in:
parent
c5ae7c2526
commit
06607a0f2c
tensorflow/python/framework
@ -351,7 +351,8 @@ class TypeSpec(object):
|
|||||||
|
|
||||||
def __make_cmp_key(self, value):
|
def __make_cmp_key(self, value):
|
||||||
"""Converts `value` to a hashable key."""
|
"""Converts `value` to a hashable key."""
|
||||||
if isinstance(value, (int, float, bool, dtypes.DType, TypeSpec)):
|
if isinstance(value,
|
||||||
|
(int, float, bool, np.generic, dtypes.DType, TypeSpec)):
|
||||||
return value
|
return value
|
||||||
if isinstance(value, compat.bytes_or_text_types):
|
if isinstance(value, compat.bytes_or_text_types):
|
||||||
return value
|
return value
|
||||||
|
@ -151,9 +151,11 @@ class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
TwoTensorsSpec([5, 3], dtypes.int32, [3], dtypes.bool, "blue")),
|
TwoTensorsSpec([5, 3], dtypes.int32, [3], dtypes.bool, "blue")),
|
||||||
("NumpyMetadata",
|
("NumpyMetadata",
|
||||||
TwoTensorsSpec([5, 3], dtypes.int32, [3], dtypes.bool,
|
TwoTensorsSpec([5, 3], dtypes.int32, [3], dtypes.bool,
|
||||||
np.array([[1, 2], [3, 4]])),
|
(np.int32(1), np.float32(1.),
|
||||||
|
np.array([[1, 2], [3, 4]]))),
|
||||||
TwoTensorsSpec([5, 3], dtypes.int32, [3], dtypes.bool,
|
TwoTensorsSpec([5, 3], dtypes.int32, [3], dtypes.bool,
|
||||||
np.array([[1, 2], [3, 4]]))),
|
(np.int32(1), np.float32(1.),
|
||||||
|
np.array([[1, 2], [3, 4]])))),
|
||||||
)
|
)
|
||||||
def testEquality(self, v1, v2):
|
def testEquality(self, v1, v2):
|
||||||
# pylint: disable=g-generic-assert
|
# pylint: disable=g-generic-assert
|
||||||
|
Loading…
Reference in New Issue
Block a user