Always return the interned version from as_dtype, so that dtypes created dynamically using DType(value) retain the exact value and type of their original counterparts. This allows more brittle code performing exact type checks to continue to work as expected once #40132 lands.

PiperOrigin-RevId: 317643144
Change-Id: Ia17008a65f9300b28a0f2c7bf18a2213b2f407af
This commit is contained in:
Dan Moldovan 2020-06-22 06:37:01 -07:00 committed by TensorFlower Gardener
parent b00a7808a7
commit 3f28ac1b63
2 changed files with 11 additions and 4 deletions

View File

@ -607,6 +607,9 @@ assert len(_ANY_TO_TF) == sum(
def as_dtype(type_value):
"""Converts the given `type_value` to a `DType`.
Note: `DType` values are interned. When passed a new `DType` object,
`as_dtype` always returns the interned value.
Args:
type_value: A value that can be converted to a `tf.DType` object. This may
currently be a `tf.DType` object, a [`DataType`
@ -620,7 +623,7 @@ def as_dtype(type_value):
TypeError: If `type_value` cannot be converted to a `DType`.
"""
if isinstance(type_value, DType):
return type_value
return _INTERN_TABLE[type_value.as_datatype_enum]
if isinstance(type_value, np.dtype):
try:

View File

@ -325,15 +325,19 @@ class TypesTest(test_util.TensorFlowTestCase):
for enum in dtypes._TYPE_TO_STRING:
dtype = dtypes.DType(enum)
ctor, args = dtype.__reduce__()
self.assertEquals(ctor, dtypes.as_dtype)
self.assertEquals(args, (dtype.name,))
self.assertEqual(ctor, dtypes.as_dtype)
self.assertEqual(args, (dtype.name,))
reconstructed = ctor(*args)
self.assertEquals(reconstructed, dtype)
self.assertEqual(reconstructed, dtype)
def testAsDtypeInvalidArgument(self):
with self.assertRaises(TypeError):
dtypes.as_dtype((dtypes.int32, dtypes.float32))
def testAsDtypeReturnsInternedVersion(self):
dt = dtypes.DType(types_pb2.DT_VARIANT)
self.assertIs(dtypes.as_dtype(dt), dtypes.variant)
if __name__ == "__main__":
googletest.main()