Prepare //tensorflow/python/eager:function_test for Tensor equality.
PiperOrigin-RevId: 263562894
(cherry picked from commit 913f565e0f
)
This commit is contained in:
parent
beff50ca87
commit
5b27c05c38
@ -2581,10 +2581,17 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
def testDecoratedMethodVariableCleanup(self):
|
def testDecoratedMethodVariableCleanup(self):
|
||||||
m = DefunnedMiniModel()
|
m = DefunnedMiniModel()
|
||||||
m(array_ops.ones([1, 2]))
|
m(array_ops.ones([1, 2]))
|
||||||
weak_variables = weakref.WeakSet(m.variables)
|
variable_refs = list({v.experimental_ref() for v in m.variables})
|
||||||
self.assertLen(weak_variables, 2)
|
self.assertLen(variable_refs, 2)
|
||||||
del m
|
del m
|
||||||
self.assertEqual([], list(weak_variables))
|
|
||||||
|
# Verifying if the variables are only referenced from variable_refs.
|
||||||
|
# We expect the reference counter to be 1, but `sys.getrefcount` reports
|
||||||
|
# one higher reference counter because a temporary is created when we call
|
||||||
|
# sys.getrefcount(). Hence check if the number returned is 2.
|
||||||
|
# https://docs.python.org/3/library/sys.html#sys.getrefcount
|
||||||
|
self.assertEqual(sys.getrefcount(variable_refs[0].deref()), 2)
|
||||||
|
self.assertEqual(sys.getrefcount(variable_refs[1].deref()), 2)
|
||||||
|
|
||||||
def testExecutorType(self):
|
def testExecutorType(self):
|
||||||
@function.defun
|
@function.defun
|
||||||
|
Loading…
Reference in New Issue
Block a user