Fix memory leak in python caused by @tf_should_use.
The issue is that Python's GC has trouble collecting objects with __del__ methods. The solution is two-pronged: * Keep track of usage state outside of the class, via a dict mapping id(object) => state. * Remove __del__ (this was the source of the leak: Python's GC couldn't collect wrapped objects), and instead use weakref.finalize to emit warnings just as the object is being garbage collected. Also added tests for garbage collection (they were failing before I fixed the issue). PiperOrigin-RevId: 158042388
This commit is contained in:
parent
cc411f9387
commit
cf238e1f2f
@ -17,14 +17,52 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import itertools
|
||||
import traceback
|
||||
import types
|
||||
|
||||
import six # pylint: disable=unused-import
|
||||
|
||||
from backports import weakref # pylint: disable=g-bad-import-order
|
||||
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.util import tf_decorator
|
||||
|
||||
|
||||
class _RefInfoField(
|
||||
collections.namedtuple(
|
||||
'_RefInfoField', ('type_', 'repr_', 'creation_stack', 'object_used'))):
|
||||
pass
|
||||
|
||||
|
||||
# Thread-safe up to int32max/2 thanks to python's GIL; and may be safe even for
|
||||
# higher values in Python 3.4+. We don't expect to ever count higher than this.
|
||||
# https://mail.python.org/pipermail/python-list/2005-April/342279.html
|
||||
_REF_ITER = itertools.count()
|
||||
|
||||
# Dictionary mapping a wrapper's `_tf_ref_id` (a unique integer drawn from _REF_ITER) => _RefInfoField.
|
||||
_REF_INFO = {}
|
||||
|
||||
|
||||
def _deleted(obj_id, fatal_error):
  """Drops the `_REF_INFO` record for `obj_id` and reports it if unused.

  Registered through `weakref.finalize` as the callback for a wrapper object.

  Args:
    obj_id: Key of the record in `_REF_INFO`.
    fatal_error: If true, the report goes through `tf_logging.fatal`;
      otherwise through `tf_logging.error`.

  Raises:
    KeyError: If `obj_id` has no entry in `_REF_INFO`.
  """
  # Fetch and remove in one step; the record is discarded whether or not a
  # report is emitted.
  info = _REF_INFO.pop(obj_id)
  if info.object_used:
    return

  report = tf_logging.fatal if fatal_error else tf_logging.error
  message = (
      '==================================\n'
      'Object was never used (type %s):\n%s\nIf you want to mark it as '
      'used call its "mark_used()" method.\nIt was originally created '
      'here:\n%s\n'
      '==================================' %
      (info.type_, info.repr_, info.creation_stack))
  report(message)
|
||||
|
||||
|
||||
def _add_should_use_warning(x, fatal_error=False):
|
||||
"""Wraps object x so that if it is never used, a warning is logged.
|
||||
|
||||
@ -39,14 +77,14 @@ def _add_should_use_warning(x, fatal_error=False):
|
||||
"""
|
||||
if x is None: # special corner case where x is None
|
||||
return x
|
||||
has_been_used = getattr(x, '_tf_object_has_been_used', None)
|
||||
if has_been_used is not None:
|
||||
x._tf_object_has_been_used = has_been_used # pylint: disable=protected-access
|
||||
if hasattr(x, '_tf_ref_id'): # this is already a TFShouldUseWarningWrapper
|
||||
return x
|
||||
|
||||
def override_method(method):
|
||||
def fn(self, *args, **kwargs):
|
||||
self._tf_object_has_been_used = True # pylint: disable=protected-access
|
||||
# pylint: disable=protected-access
|
||||
_REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace(
|
||||
object_used=True)
|
||||
return method(self, *args, **kwargs)
|
||||
return fn
|
||||
|
||||
@ -55,38 +93,36 @@ def _add_should_use_warning(x, fatal_error=False):
|
||||
|
||||
def __init__(self, true_self):
|
||||
self.__dict__ = true_self.__dict__
|
||||
stack = [x.strip() for x in traceback.format_stack()]
|
||||
stack = [s.strip() for s in traceback.format_stack()]
|
||||
# Remove top three stack entries from adding the wrapper
|
||||
self._tf_object_creation_stack = '\n'.join(stack[:-3])
|
||||
self._tf_object_has_been_used = False
|
||||
self.creation_stack = '\n'.join(stack[:-3])
|
||||
self._tf_ref_id = next(_REF_ITER)
|
||||
_REF_INFO[self._tf_ref_id] = _RefInfoField(
|
||||
type_=type(x),
|
||||
repr_=repr(x),
|
||||
creation_stack=stack,
|
||||
object_used=False)
|
||||
|
||||
# Create a finalizer for self, which will be called when self is
|
||||
# garbage collected. Can't add self as the args because the
|
||||
# loop will break garbage collection. We keep track of
|
||||
# ourselves via python ids.
|
||||
weakref.finalize(self, _deleted, self._tf_ref_id, fatal_error)
|
||||
|
||||
# Not sure why this pylint warning is being used; this is not an
|
||||
# old class form.
|
||||
# pylint: disable=super-on-old-class
|
||||
def __getattribute__(self, name):
|
||||
if name != '_tf_object_has_been_used':
|
||||
self._tf_object_has_been_used = True
|
||||
if name == '_tf_ref_id':
|
||||
return super(TFShouldUseWarningWrapper, self).__getattribute__(name)
|
||||
if self._tf_ref_id in _REF_INFO:
|
||||
_REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace(
|
||||
object_used=True)
|
||||
return super(TFShouldUseWarningWrapper, self).__getattribute__(name)
|
||||
|
||||
def __del__(self):
|
||||
if not self._tf_object_has_been_used:
|
||||
if fatal_error:
|
||||
logger = tf_logging.fatal
|
||||
else:
|
||||
logger = tf_logging.error
|
||||
logger(
|
||||
'==================================\n'
|
||||
'Object was never used (type %s):\n%s\nIf you want to mark it as '
|
||||
'used call its "mark_used()" method.\nIt was originally created '
|
||||
'here:\n%s\n'
|
||||
'==================================' %
|
||||
(type(x), x, self._tf_object_creation_stack))
|
||||
|
||||
if hasattr(super(TFShouldUseWarningWrapper, self), '__del__'):
|
||||
return super(TFShouldUseWarningWrapper, self).__del__()
|
||||
|
||||
def mark_used(self, *args, **kwargs):
|
||||
self._tf_object_has_been_used = True
|
||||
_REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace(
|
||||
object_used=True)
|
||||
if hasattr(super(TFShouldUseWarningWrapper, self), 'mark_used'):
|
||||
return super(TFShouldUseWarningWrapper, self).mark_used(*args, **kwargs)
|
||||
# pylint: enable=super-on-old-class
|
||||
@ -102,7 +138,8 @@ def _add_should_use_warning(x, fatal_error=False):
|
||||
|
||||
wrapped = TFShouldUseWarningWrapper(x)
|
||||
wrapped.__doc__ = x.__doc__ # functools.wraps fails on some objects.
|
||||
wrapped._tf_object_has_been_used = False # pylint: disable=protected-access
|
||||
ref_id = wrapped._tf_ref_id # pylint: disable=protected-access
|
||||
_REF_INFO[ref_id] = _REF_INFO[ref_id]._replace(object_used=False)
|
||||
return wrapped
|
||||
|
||||
|
||||
|
@ -20,6 +20,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
import gc
|
||||
import sys
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -45,7 +46,7 @@ def reroute_error(captured):
|
||||
class TfShouldUseTest(test.TestCase):
|
||||
|
||||
def testAddShouldUseWarningWhenNotUsed(self):
|
||||
c = constant_op.constant(0, name='blah')
|
||||
c = constant_op.constant(0, name='blah0')
|
||||
captured = []
|
||||
with reroute_error(captured):
|
||||
def in_this_function():
|
||||
@ -53,44 +54,52 @@ class TfShouldUseTest(test.TestCase):
|
||||
del h
|
||||
in_this_function()
|
||||
self.assertIn('Object was never used', '\n'.join(captured))
|
||||
self.assertIn('blah:0', '\n'.join(captured))
|
||||
self.assertIn('blah0:0', '\n'.join(captured))
|
||||
self.assertIn('in_this_function', '\n'.join(captured))
|
||||
gc.collect()
|
||||
self.assertFalse(gc.garbage)
|
||||
|
||||
def _testAddShouldUseWarningWhenUsed(self, fn):
|
||||
c = constant_op.constant(0, name='blah')
|
||||
def _testAddShouldUseWarningWhenUsed(self, fn, name):
|
||||
c = constant_op.constant(0, name=name)
|
||||
captured = []
|
||||
with reroute_error(captured):
|
||||
h = tf_should_use._add_should_use_warning(c)
|
||||
fn(h)
|
||||
del h
|
||||
self.assertNotIn('Object was never used', '\n'.join(captured))
|
||||
self.assertNotIn('blah:0', '\n'.join(captured))
|
||||
self.assertNotIn('%s:0' % name, '\n'.join(captured))
|
||||
|
||||
def testAddShouldUseWarningWhenUsedWithAdd(self):
|
||||
def add(h):
|
||||
_ = h + 1
|
||||
self._testAddShouldUseWarningWhenUsed(add)
|
||||
self._testAddShouldUseWarningWhenUsed(add, name='blah_add')
|
||||
gc.collect()
|
||||
self.assertFalse(gc.garbage)
|
||||
|
||||
def testAddShouldUseWarningWhenUsedWithGetName(self):
|
||||
def get_name(h):
|
||||
_ = h.name
|
||||
self._testAddShouldUseWarningWhenUsed(get_name)
|
||||
self._testAddShouldUseWarningWhenUsed(get_name, name='blah_get_name')
|
||||
gc.collect()
|
||||
self.assertFalse(gc.garbage)
|
||||
|
||||
def testShouldUseResult(self):
|
||||
@tf_should_use.should_use_result
|
||||
def return_const(value):
|
||||
return constant_op.constant(value, name='blah')
|
||||
return constant_op.constant(value, name='blah2')
|
||||
captured = []
|
||||
with reroute_error(captured):
|
||||
return_const(0.0)
|
||||
self.assertIn('Object was never used', '\n'.join(captured))
|
||||
self.assertIn('blah:0', '\n'.join(captured))
|
||||
self.assertIn('blah2:0', '\n'.join(captured))
|
||||
self.assertIn('return_const', '\n'.join(captured))
|
||||
gc.collect()
|
||||
self.assertFalse(gc.garbage)
|
||||
|
||||
def testShouldUseResultWhenNotReallyUsed(self):
|
||||
@tf_should_use.should_use_result
|
||||
def return_const(value):
|
||||
return constant_op.constant(value, name='blah')
|
||||
return constant_op.constant(value, name='blah3')
|
||||
captured = []
|
||||
with reroute_error(captured):
|
||||
with self.test_session():
|
||||
@ -100,8 +109,10 @@ class TfShouldUseTest(test.TestCase):
|
||||
v = constant_op.constant(1.0, name='meh')
|
||||
v.eval()
|
||||
self.assertIn('Object was never used', '\n'.join(captured))
|
||||
self.assertIn('blah:0', '\n'.join(captured))
|
||||
self.assertIn('blah3:0', '\n'.join(captured))
|
||||
self.assertIn('return_const', '\n'.join(captured))
|
||||
gc.collect()
|
||||
self.assertFalse(gc.garbage)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -85,3 +85,6 @@ pip2 install mock
|
||||
|
||||
pip2 install portpicker
|
||||
pip3 install portpicker
|
||||
|
||||
pip2 install backports.weakref==1.0rc1
|
||||
pip3 install backports.weakref==1.0rc1
|
||||
|
@ -89,3 +89,6 @@ pip3.5 install wheel==0.29.0
|
||||
pip3.5 install portpicker
|
||||
|
||||
pip3.5 install werkzeug
|
||||
|
||||
pip3.5 install backports.weakref==1.0rc1
|
||||
|
||||
|
@ -39,6 +39,7 @@ REQUIRED_PACKAGES = [
|
||||
'html5lib == 0.9999999', # identical to 1.0b8
|
||||
'markdown == 2.2.0',
|
||||
'bleach == 1.5.0',
|
||||
'backports.weakref == 1.0rc1',
|
||||
]
|
||||
|
||||
project_name = 'tensorflow'
|
||||
|
Loading…
Reference in New Issue
Block a user