Always return the interned version from as_dtype, so that dtypes created dynamically using DType(value)
retain the exact value and type of their original counterparts. This allows more brittle code performing exact type checks to continue to work as expected once #40132 lands.
PiperOrigin-RevId: 317643144 Change-Id: Ia17008a65f9300b28a0f2c7bf18a2213b2f407af
This commit is contained in:
parent
b00a7808a7
commit
3f28ac1b63
@ -607,6 +607,9 @@ assert len(_ANY_TO_TF) == sum(
|
|||||||
def as_dtype(type_value):
|
def as_dtype(type_value):
|
||||||
"""Converts the given `type_value` to a `DType`.
|
"""Converts the given `type_value` to a `DType`.
|
||||||
|
|
||||||
|
Note: `DType` values are interned. When passed a new `DType` object,
|
||||||
|
`as_dtype` always returns the interned value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
type_value: A value that can be converted to a `tf.DType` object. This may
|
type_value: A value that can be converted to a `tf.DType` object. This may
|
||||||
currently be a `tf.DType` object, a [`DataType`
|
currently be a `tf.DType` object, a [`DataType`
|
||||||
@ -620,7 +623,7 @@ def as_dtype(type_value):
|
|||||||
TypeError: If `type_value` cannot be converted to a `DType`.
|
TypeError: If `type_value` cannot be converted to a `DType`.
|
||||||
"""
|
"""
|
||||||
if isinstance(type_value, DType):
|
if isinstance(type_value, DType):
|
||||||
return type_value
|
return _INTERN_TABLE[type_value.as_datatype_enum]
|
||||||
|
|
||||||
if isinstance(type_value, np.dtype):
|
if isinstance(type_value, np.dtype):
|
||||||
try:
|
try:
|
||||||
|
@ -325,15 +325,19 @@ class TypesTest(test_util.TensorFlowTestCase):
|
|||||||
for enum in dtypes._TYPE_TO_STRING:
|
for enum in dtypes._TYPE_TO_STRING:
|
||||||
dtype = dtypes.DType(enum)
|
dtype = dtypes.DType(enum)
|
||||||
ctor, args = dtype.__reduce__()
|
ctor, args = dtype.__reduce__()
|
||||||
self.assertEquals(ctor, dtypes.as_dtype)
|
self.assertEqual(ctor, dtypes.as_dtype)
|
||||||
self.assertEquals(args, (dtype.name,))
|
self.assertEqual(args, (dtype.name,))
|
||||||
reconstructed = ctor(*args)
|
reconstructed = ctor(*args)
|
||||||
self.assertEquals(reconstructed, dtype)
|
self.assertEqual(reconstructed, dtype)
|
||||||
|
|
||||||
def testAsDtypeInvalidArgument(self):
|
def testAsDtypeInvalidArgument(self):
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
dtypes.as_dtype((dtypes.int32, dtypes.float32))
|
dtypes.as_dtype((dtypes.int32, dtypes.float32))
|
||||||
|
|
||||||
|
def testAsDtypeReturnsInternedVersion(self):
|
||||||
|
dt = dtypes.DType(types_pb2.DT_VARIANT)
|
||||||
|
self.assertIs(dtypes.as_dtype(dt), dtypes.variant)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
googletest.main()
|
googletest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user