diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index 994a7eea494..9eeae83c68a 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -607,6 +607,9 @@ assert len(_ANY_TO_TF) == sum( def as_dtype(type_value): """Converts the given `type_value` to a `DType`. + Note: `DType` values are interned. When passed a new `DType` object, + `as_dtype` always returns the interned value. + Args: type_value: A value that can be converted to a `tf.DType` object. This may currently be a `tf.DType` object, a [`DataType` @@ -620,7 +623,7 @@ def as_dtype(type_value): TypeError: If `type_value` cannot be converted to a `DType`. """ if isinstance(type_value, DType): - return type_value + return _INTERN_TABLE[type_value.as_datatype_enum] if isinstance(type_value, np.dtype): try: diff --git a/tensorflow/python/framework/dtypes_test.py b/tensorflow/python/framework/dtypes_test.py index 041cc5280cd..1b7e02b6179 100644 --- a/tensorflow/python/framework/dtypes_test.py +++ b/tensorflow/python/framework/dtypes_test.py @@ -325,15 +325,19 @@ class TypesTest(test_util.TensorFlowTestCase): for enum in dtypes._TYPE_TO_STRING: dtype = dtypes.DType(enum) ctor, args = dtype.__reduce__() - self.assertEquals(ctor, dtypes.as_dtype) - self.assertEquals(args, (dtype.name,)) + self.assertEqual(ctor, dtypes.as_dtype) + self.assertEqual(args, (dtype.name,)) reconstructed = ctor(*args) - self.assertEquals(reconstructed, dtype) + self.assertEqual(reconstructed, dtype) def testAsDtypeInvalidArgument(self): with self.assertRaises(TypeError): dtypes.as_dtype((dtypes.int32, dtypes.float32)) + def testAsDtypeReturnsInternedVersion(self): + dt = dtypes.DType(types_pb2.DT_VARIANT) + self.assertIs(dtypes.as_dtype(dt), dtypes.variant) + if __name__ == "__main__": googletest.main()