diff --git a/tensorflow/python/util/object_identity.py b/tensorflow/python/util/object_identity.py index 47de08d0bb0..15a0694b1c1 100644 --- a/tensorflow/python/util/object_identity.py +++ b/tensorflow/python/util/object_identity.py @@ -38,9 +38,10 @@ class _ObjectIdentityWrapper(object): return self._wrapped def __eq__(self, other): - if isinstance(other, _ObjectIdentityWrapper): - return self._wrapped is other._wrapped # pylint: disable=protected-access - return False + if not isinstance(other, _ObjectIdentityWrapper): + raise TypeError("Cannot compare wrapped object with unwrapped object") + + return self._wrapped is other._wrapped # pylint: disable=protected-access def __ne__(self, other): return not self.__eq__(other) diff --git a/tensorflow/python/util/object_identity_test.py b/tensorflow/python/util/object_identity_test.py index 5dc8be1a25d..67d26ebdcab 100644 --- a/tensorflow/python/util/object_identity_test.py +++ b/tensorflow/python/util/object_identity_test.py @@ -25,9 +25,32 @@ from tensorflow.python.util import object_identity class ObjectIdentityWrapperTest(test.TestCase): def testWrapperNotEqualToWrapped(self): - o = object() - self.assertNotEqual(o, object_identity._ObjectIdentityWrapper(o)) - self.assertNotEqual(object_identity._ObjectIdentityWrapper(o), o) + class SettableHash(object): + + def __init__(self): + self.hash_value = 8675309 + + def __hash__(self): + return self.hash_value + + o = SettableHash() + wrap1 = object_identity._ObjectIdentityWrapper(o) + wrap2 = object_identity._ObjectIdentityWrapper(o) + + self.assertEqual(wrap1, wrap1) + self.assertEqual(wrap1, wrap2) + self.assertEqual(o, wrap1.unwrapped) + self.assertEqual(o, wrap2.unwrapped) + with self.assertRaises(TypeError): + bool(o == wrap1) + with self.assertRaises(TypeError): + bool(wrap1 != o) + + self.assertNotIn(o, set([wrap1])) + o.hash_value = id(o) + # Since there is now a hash collision we raise an exception + with self.assertRaises(TypeError): + bool(o in set([wrap1])) class ObjectIdentitySetTest(test.TestCase):