Disallow comparing ObjectIdentityWrapper to others

When using the experimental_ref() API in Tensors & Variables. A common
bug I hit was incorrectly comparing a wrapped object with an unwrapped
object instead of first calling deref(). To avoid this we raise an
exception now instead of returning False. This implies that if Tensors
and Variables are kept in the same set or dictionary as other objects,
an exception can be raised if there is a hash collision.

PiperOrigin-RevId: 268837575
(cherry picked from commit 57e8769bc4)
This commit is contained in:
Gaurav Jain 2019-09-12 22:20:10 -07:00
parent c09880bd0f
commit f882f551d6
2 changed files with 30 additions and 6 deletions

View File

@ -38,9 +38,10 @@ class _ObjectIdentityWrapper(object):
return self._wrapped
def __eq__(self, other):
if isinstance(other, _ObjectIdentityWrapper):
return self._wrapped is other._wrapped # pylint: disable=protected-access
return False
if not isinstance(other, _ObjectIdentityWrapper):
raise TypeError("Cannot compare wrapped object with unwrapped object")
return self._wrapped is other._wrapped # pylint: disable=protected-access
def __ne__(self, other):
return not self.__eq__(other)

View File

@ -25,9 +25,32 @@ from tensorflow.python.util import object_identity
class ObjectIdentityWrapperTest(test.TestCase):
def testWrapperNotEqualToWrapped(self):
o = object()
self.assertNotEqual(o, object_identity._ObjectIdentityWrapper(o))
self.assertNotEqual(object_identity._ObjectIdentityWrapper(o), o)
class SettableHash(object):
def __init__(self):
self.hash_value = 8675309
def __hash__(self):
return self.hash_value
o = SettableHash()
wrap1 = object_identity._ObjectIdentityWrapper(o)
wrap2 = object_identity._ObjectIdentityWrapper(o)
self.assertEqual(wrap1, wrap1)
self.assertEqual(wrap1, wrap2)
self.assertEqual(o, wrap1.unwrapped)
self.assertEqual(o, wrap2.unwrapped)
with self.assertRaises(TypeError):
bool(o == wrap1)
with self.assertRaises(TypeError):
bool(wrap1 != o)
self.assertNotIn(o, set([wrap1]))
o.hash_value = id(o)
# Since there is now a hash collision we raise an exception
with self.assertRaises(TypeError):
bool(o in set([wrap1]))
class ObjectIdentitySetTest(test.TestCase):