Always return the interned version from as_dtype, so that dtypes created dynamically using DType(value)
retain the exact value and type of their original counterparts. This allows more brittle code performing exact type checks to continue to work as expected once #40132 lands.
PiperOrigin-RevId: 317643144 Change-Id: Ia17008a65f9300b28a0f2c7bf18a2213b2f407af
This commit is contained in:
parent
b00a7808a7
commit
3f28ac1b63
@ -607,6 +607,9 @@ assert len(_ANY_TO_TF) == sum(
|
||||
def as_dtype(type_value):
|
||||
"""Converts the given `type_value` to a `DType`.
|
||||
|
||||
Note: `DType` values are interned. When passed a new `DType` object,
|
||||
`as_dtype` always returns the interned value.
|
||||
|
||||
Args:
|
||||
type_value: A value that can be converted to a `tf.DType` object. This may
|
||||
currently be a `tf.DType` object, a [`DataType`
|
||||
@ -620,7 +623,7 @@ def as_dtype(type_value):
|
||||
TypeError: If `type_value` cannot be converted to a `DType`.
|
||||
"""
|
||||
if isinstance(type_value, DType):
|
||||
return type_value
|
||||
return _INTERN_TABLE[type_value.as_datatype_enum]
|
||||
|
||||
if isinstance(type_value, np.dtype):
|
||||
try:
|
||||
|
@ -325,15 +325,19 @@ class TypesTest(test_util.TensorFlowTestCase):
|
||||
for enum in dtypes._TYPE_TO_STRING:
|
||||
dtype = dtypes.DType(enum)
|
||||
ctor, args = dtype.__reduce__()
|
||||
self.assertEquals(ctor, dtypes.as_dtype)
|
||||
self.assertEquals(args, (dtype.name,))
|
||||
self.assertEqual(ctor, dtypes.as_dtype)
|
||||
self.assertEqual(args, (dtype.name,))
|
||||
reconstructed = ctor(*args)
|
||||
self.assertEquals(reconstructed, dtype)
|
||||
self.assertEqual(reconstructed, dtype)
|
||||
|
||||
def testAsDtypeInvalidArgument(self):
|
||||
with self.assertRaises(TypeError):
|
||||
dtypes.as_dtype((dtypes.int32, dtypes.float32))
|
||||
|
||||
def testAsDtypeReturnsInternedVersion(self):
|
||||
dt = dtypes.DType(types_pb2.DT_VARIANT)
|
||||
self.assertIs(dtypes.as_dtype(dt), dtypes.variant)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user