diff --git a/tensorflow/python/util/object_identity.py b/tensorflow/python/util/object_identity.py
index 47de08d0bb0..15a0694b1c1 100644
--- a/tensorflow/python/util/object_identity.py
+++ b/tensorflow/python/util/object_identity.py
@@ -38,9 +38,10 @@ class _ObjectIdentityWrapper(object):
     return self._wrapped
 
   def __eq__(self, other):
-    if isinstance(other, _ObjectIdentityWrapper):
-      return self._wrapped is other._wrapped  # pylint: disable=protected-access
-    return False
+    if not isinstance(other, _ObjectIdentityWrapper):
+      raise TypeError("Cannot compare wrapped object with unwrapped object")
+
+    return self._wrapped is other._wrapped  # pylint: disable=protected-access
 
   def __ne__(self, other):
     return not self.__eq__(other)
diff --git a/tensorflow/python/util/object_identity_test.py b/tensorflow/python/util/object_identity_test.py
index 5dc8be1a25d..67d26ebdcab 100644
--- a/tensorflow/python/util/object_identity_test.py
+++ b/tensorflow/python/util/object_identity_test.py
@@ -25,9 +25,32 @@ from tensorflow.python.util import object_identity
 class ObjectIdentityWrapperTest(test.TestCase):
 
   def testWrapperNotEqualToWrapped(self):
-    o = object()
-    self.assertNotEqual(o, object_identity._ObjectIdentityWrapper(o))
-    self.assertNotEqual(object_identity._ObjectIdentityWrapper(o), o)
+    class SettableHash(object):
+
+      def __init__(self):
+        self.hash_value = 8675309
+
+      def __hash__(self):
+        return self.hash_value
+
+    o = SettableHash()
+    wrap1 = object_identity._ObjectIdentityWrapper(o)
+    wrap2 = object_identity._ObjectIdentityWrapper(o)
+
+    self.assertEqual(wrap1, wrap1)
+    self.assertEqual(wrap1, wrap2)
+    self.assertEqual(o, wrap1.unwrapped)
+    self.assertEqual(o, wrap2.unwrapped)
+    with self.assertRaises(TypeError):
+      bool(o == wrap1)
+    with self.assertRaises(TypeError):
+      bool(wrap1 != o)
+
+    self.assertNotIn(o, set([wrap1]))
+    o.hash_value = id(o)
+    # Since there is now a hash collision we raise an exception
+    with self.assertRaises(TypeError):
+      bool(o in set([wrap1]))
 
 
 class ObjectIdentitySetTest(test.TestCase):