From 97b6a54c41106fe43cfb8d7124928cde0627353f Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Thu, 12 Sep 2019 22:20:10 -0700 Subject: [PATCH] Disallow comparing ObjectIdentityWrapper to others When using the experimental_ref() API in Tensors & Variables. A common bug I hit was incorrectly comparing a wrapped object with an unwrapped object instead of first calling deref(). To avoid this we raise an exception now instead of returning False. This implies that if Tensors and Variables are kept in the same set or dictionary as other objects, an exception can be raised if there is a hash collision. PiperOrigin-RevId: 268837575 (cherry picked from commit 57e8769bc4ef1c94ddbcfbe4a39afe8f73b433c5) --- tensorflow/python/util/object_identity.py | 7 +++-- .../python/util/object_identity_test.py | 29 +++++++++++++++++-- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/util/object_identity.py b/tensorflow/python/util/object_identity.py index 47de08d0bb0..15a0694b1c1 100644 --- a/tensorflow/python/util/object_identity.py +++ b/tensorflow/python/util/object_identity.py @@ -38,9 +38,10 @@ class _ObjectIdentityWrapper(object): return self._wrapped def __eq__(self, other): - if isinstance(other, _ObjectIdentityWrapper): - return self._wrapped is other._wrapped # pylint: disable=protected-access - return False + if not isinstance(other, _ObjectIdentityWrapper): + raise TypeError("Cannot compare wrapped object with unwrapped object") + + return self._wrapped is other._wrapped # pylint: disable=protected-access def __ne__(self, other): return not self.__eq__(other) diff --git a/tensorflow/python/util/object_identity_test.py b/tensorflow/python/util/object_identity_test.py index 5dc8be1a25d..67d26ebdcab 100644 --- a/tensorflow/python/util/object_identity_test.py +++ b/tensorflow/python/util/object_identity_test.py @@ -25,9 +25,32 @@ from tensorflow.python.util import object_identity class ObjectIdentityWrapperTest(test.TestCase): def testWrapperNotEqualToWrapped(self): - o = object() - self.assertNotEqual(o, object_identity._ObjectIdentityWrapper(o)) - self.assertNotEqual(object_identity._ObjectIdentityWrapper(o), o) + class SettableHash(object): + + def __init__(self): + self.hash_value = 8675309 + + def __hash__(self): + return self.hash_value + + o = SettableHash() + wrap1 = object_identity._ObjectIdentityWrapper(o) + wrap2 = object_identity._ObjectIdentityWrapper(o) + + self.assertEqual(wrap1, wrap1) + self.assertEqual(wrap1, wrap2) + self.assertEqual(o, wrap1.unwrapped) + self.assertEqual(o, wrap2.unwrapped) + with self.assertRaises(TypeError): + bool(o == wrap1) + with self.assertRaises(TypeError): + bool(wrap1 != o) + + self.assertNotIn(o, set([wrap1])) + o.hash_value = id(o) + # Since there is now a hash collision we raise an exception + with self.assertRaises(TypeError): + bool(o in set([wrap1])) class ObjectIdentitySetTest(test.TestCase):