Disallow comparing ObjectIdentityWrapper to others
When using the experimental_ref() API in Tensors & Variables. A common bug I hit was incorrectly comparing a wrapped object with an unwrapped object instead of first calling deref(). To avoid this we raise an exception now instead of returning False. This implies that if Tensors and Variables are kept in the same set or dictionary as other objects, an exception can be raised if there is a hash collision. PiperOrigin-RevId: 268837575
This commit is contained in:
parent
37ea74a7c5
commit
57e8769bc4
tensorflow/python/util
@ -40,9 +40,10 @@ class _ObjectIdentityWrapper(object):
|
||||
return self._wrapped
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, _ObjectIdentityWrapper):
|
||||
return self._wrapped is other._wrapped # pylint: disable=protected-access
|
||||
return False
|
||||
if not isinstance(other, _ObjectIdentityWrapper):
|
||||
raise TypeError("Cannot compare wrapped object with unwrapped object")
|
||||
|
||||
return self._wrapped is other._wrapped # pylint: disable=protected-access
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
@ -25,9 +25,32 @@ from tensorflow.python.util import object_identity
|
||||
class ObjectIdentityWrapperTest(test.TestCase):
|
||||
|
||||
def testWrapperNotEqualToWrapped(self):
|
||||
o = object()
|
||||
self.assertNotEqual(o, object_identity._ObjectIdentityWrapper(o))
|
||||
self.assertNotEqual(object_identity._ObjectIdentityWrapper(o), o)
|
||||
class SettableHash(object):
|
||||
|
||||
def __init__(self):
|
||||
self.hash_value = 8675309
|
||||
|
||||
def __hash__(self):
|
||||
return self.hash_value
|
||||
|
||||
o = SettableHash()
|
||||
wrap1 = object_identity._ObjectIdentityWrapper(o)
|
||||
wrap2 = object_identity._ObjectIdentityWrapper(o)
|
||||
|
||||
self.assertEqual(wrap1, wrap1)
|
||||
self.assertEqual(wrap1, wrap2)
|
||||
self.assertEqual(o, wrap1.unwrapped)
|
||||
self.assertEqual(o, wrap2.unwrapped)
|
||||
with self.assertRaises(TypeError):
|
||||
bool(o == wrap1)
|
||||
with self.assertRaises(TypeError):
|
||||
bool(wrap1 != o)
|
||||
|
||||
self.assertNotIn(o, set([wrap1]))
|
||||
o.hash_value = id(o)
|
||||
# Since there is now a hash collision we raise an exception
|
||||
with self.assertRaises(TypeError):
|
||||
bool(o in set([wrap1]))
|
||||
|
||||
|
||||
class ObjectIdentitySetTest(test.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user