Merge pull request #32526 from jaingaurav/cherry-1.15-2
[r1.15-CherryPick]:Disallow comparing ObjectIdentityWrapper to others
This commit is contained in:
commit
ea930781c3
tensorflow/python
@ -444,7 +444,7 @@ class ExponentialMovingAverage(object):
|
|||||||
"Variable", "VariableV2", "VarHandleOp"
|
"Variable", "VariableV2", "VarHandleOp"
|
||||||
]))
|
]))
|
||||||
if self._zero_debias:
|
if self._zero_debias:
|
||||||
zero_debias_true.add(avg)
|
zero_debias_true.add(avg.experimental_ref())
|
||||||
self._averages[var.experimental_ref()] = avg
|
self._averages[var.experimental_ref()] = avg
|
||||||
|
|
||||||
with ops.name_scope(self.name) as scope:
|
with ops.name_scope(self.name) as scope:
|
||||||
|
@ -38,9 +38,10 @@ class _ObjectIdentityWrapper(object):
|
|||||||
return self._wrapped
|
return self._wrapped
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if isinstance(other, _ObjectIdentityWrapper):
|
if not isinstance(other, _ObjectIdentityWrapper):
|
||||||
return self._wrapped is other._wrapped # pylint: disable=protected-access
|
raise TypeError("Cannot compare wrapped object with unwrapped object")
|
||||||
return False
|
|
||||||
|
return self._wrapped is other._wrapped # pylint: disable=protected-access
|
||||||
|
|
||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
return not self.__eq__(other)
|
return not self.__eq__(other)
|
||||||
|
@ -25,9 +25,32 @@ from tensorflow.python.util import object_identity
|
|||||||
class ObjectIdentityWrapperTest(test.TestCase):
|
class ObjectIdentityWrapperTest(test.TestCase):
|
||||||
|
|
||||||
def testWrapperNotEqualToWrapped(self):
|
def testWrapperNotEqualToWrapped(self):
|
||||||
o = object()
|
class SettableHash(object):
|
||||||
self.assertNotEqual(o, object_identity._ObjectIdentityWrapper(o))
|
|
||||||
self.assertNotEqual(object_identity._ObjectIdentityWrapper(o), o)
|
def __init__(self):
|
||||||
|
self.hash_value = 8675309
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return self.hash_value
|
||||||
|
|
||||||
|
o = SettableHash()
|
||||||
|
wrap1 = object_identity._ObjectIdentityWrapper(o)
|
||||||
|
wrap2 = object_identity._ObjectIdentityWrapper(o)
|
||||||
|
|
||||||
|
self.assertEqual(wrap1, wrap1)
|
||||||
|
self.assertEqual(wrap1, wrap2)
|
||||||
|
self.assertEqual(o, wrap1.unwrapped)
|
||||||
|
self.assertEqual(o, wrap2.unwrapped)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
bool(o == wrap1)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
bool(wrap1 != o)
|
||||||
|
|
||||||
|
self.assertNotIn(o, set([wrap1]))
|
||||||
|
o.hash_value = id(o)
|
||||||
|
# Since there is now a hash collision we raise an exception
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
bool(o in set([wrap1]))
|
||||||
|
|
||||||
|
|
||||||
class ObjectIdentitySetTest(test.TestCase):
|
class ObjectIdentitySetTest(test.TestCase):
|
||||||
|
Loading…
Reference in New Issue
Block a user