Fix memory leak in python caused by @tf_should_use.
The issue is that Python's GC has trouble collecting objects with __del__ methods. The solution is two-pronged: * Keep track of usage state outside of the class, via a dict mapping a unique per-object reference id => state. * Remove __del__ (this was the source of the leak: Python's GC couldn't collect wrapped objects), and instead use weakref.finalize to emit warnings just as the object is being garbage collected. Additionally, tests for garbage collection were added (they were failing before the issue was fixed). PiperOrigin-RevId: 158042388
This commit is contained in:
parent
cc411f9387
commit
cf238e1f2f
@ -17,14 +17,52 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
import functools
|
import functools
|
||||||
|
import itertools
|
||||||
import traceback
|
import traceback
|
||||||
import types
|
import types
|
||||||
|
|
||||||
|
import six # pylint: disable=unused-import
|
||||||
|
|
||||||
|
from backports import weakref # pylint: disable=g-bad-import-order
|
||||||
|
|
||||||
from tensorflow.python.platform import tf_logging
|
from tensorflow.python.platform import tf_logging
|
||||||
from tensorflow.python.util import tf_decorator
|
from tensorflow.python.util import tf_decorator
|
||||||
|
|
||||||
|
|
||||||
|
class _RefInfoField(
|
||||||
|
collections.namedtuple(
|
||||||
|
'_RefInfoField', ('type_', 'repr_', 'creation_stack', 'object_used'))):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Thread-safe up to int32max/2 thanks to python's GIL; and may be safe even for
|
||||||
|
# higher values in Python 3.4+. We don't expect to ever count higher than this.
|
||||||
|
# https://mail.python.org/pipermail/python-list/2005-April/342279.html
|
||||||
|
_REF_ITER = itertools.count()
|
||||||
|
|
||||||
|
# Dictionary mapping id(obj) => _RefInfoField.
|
||||||
|
_REF_INFO = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _deleted(obj_id, fatal_error):
|
||||||
|
obj = _REF_INFO[obj_id]
|
||||||
|
del _REF_INFO[obj_id]
|
||||||
|
if not obj.object_used:
|
||||||
|
if fatal_error:
|
||||||
|
logger = tf_logging.fatal
|
||||||
|
else:
|
||||||
|
logger = tf_logging.error
|
||||||
|
logger(
|
||||||
|
'==================================\n'
|
||||||
|
'Object was never used (type %s):\n%s\nIf you want to mark it as '
|
||||||
|
'used call its "mark_used()" method.\nIt was originally created '
|
||||||
|
'here:\n%s\n'
|
||||||
|
'==================================' %
|
||||||
|
(obj.type_, obj.repr_, obj.creation_stack))
|
||||||
|
|
||||||
|
|
||||||
def _add_should_use_warning(x, fatal_error=False):
|
def _add_should_use_warning(x, fatal_error=False):
|
||||||
"""Wraps object x so that if it is never used, a warning is logged.
|
"""Wraps object x so that if it is never used, a warning is logged.
|
||||||
|
|
||||||
@ -39,14 +77,14 @@ def _add_should_use_warning(x, fatal_error=False):
|
|||||||
"""
|
"""
|
||||||
if x is None: # special corner case where x is None
|
if x is None: # special corner case where x is None
|
||||||
return x
|
return x
|
||||||
has_been_used = getattr(x, '_tf_object_has_been_used', None)
|
if hasattr(x, '_tf_ref_id'): # this is already a TFShouldUseWarningWrapper
|
||||||
if has_been_used is not None:
|
|
||||||
x._tf_object_has_been_used = has_been_used # pylint: disable=protected-access
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def override_method(method):
|
def override_method(method):
|
||||||
def fn(self, *args, **kwargs):
|
def fn(self, *args, **kwargs):
|
||||||
self._tf_object_has_been_used = True # pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
|
_REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace(
|
||||||
|
object_used=True)
|
||||||
return method(self, *args, **kwargs)
|
return method(self, *args, **kwargs)
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
@ -55,38 +93,36 @@ def _add_should_use_warning(x, fatal_error=False):
|
|||||||
|
|
||||||
def __init__(self, true_self):
|
def __init__(self, true_self):
|
||||||
self.__dict__ = true_self.__dict__
|
self.__dict__ = true_self.__dict__
|
||||||
stack = [x.strip() for x in traceback.format_stack()]
|
stack = [s.strip() for s in traceback.format_stack()]
|
||||||
# Remove top three stack entries from adding the wrapper
|
# Remove top three stack entries from adding the wrapper
|
||||||
self._tf_object_creation_stack = '\n'.join(stack[:-3])
|
self.creation_stack = '\n'.join(stack[:-3])
|
||||||
self._tf_object_has_been_used = False
|
self._tf_ref_id = next(_REF_ITER)
|
||||||
|
_REF_INFO[self._tf_ref_id] = _RefInfoField(
|
||||||
|
type_=type(x),
|
||||||
|
repr_=repr(x),
|
||||||
|
creation_stack=stack,
|
||||||
|
object_used=False)
|
||||||
|
|
||||||
|
# Create a finalizer for self, which will be called when self is
|
||||||
|
# garbage collected. Can't add self as the args because the
|
||||||
|
# loop will break garbage collection. We keep track of
|
||||||
|
# ourselves via python ids.
|
||||||
|
weakref.finalize(self, _deleted, self._tf_ref_id, fatal_error)
|
||||||
|
|
||||||
# Not sure why this pylint warning is being used; this is not an
|
# Not sure why this pylint warning is being used; this is not an
|
||||||
# old class form.
|
# old class form.
|
||||||
# pylint: disable=super-on-old-class
|
# pylint: disable=super-on-old-class
|
||||||
def __getattribute__(self, name):
|
def __getattribute__(self, name):
|
||||||
if name != '_tf_object_has_been_used':
|
if name == '_tf_ref_id':
|
||||||
self._tf_object_has_been_used = True
|
return super(TFShouldUseWarningWrapper, self).__getattribute__(name)
|
||||||
|
if self._tf_ref_id in _REF_INFO:
|
||||||
|
_REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace(
|
||||||
|
object_used=True)
|
||||||
return super(TFShouldUseWarningWrapper, self).__getattribute__(name)
|
return super(TFShouldUseWarningWrapper, self).__getattribute__(name)
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
if not self._tf_object_has_been_used:
|
|
||||||
if fatal_error:
|
|
||||||
logger = tf_logging.fatal
|
|
||||||
else:
|
|
||||||
logger = tf_logging.error
|
|
||||||
logger(
|
|
||||||
'==================================\n'
|
|
||||||
'Object was never used (type %s):\n%s\nIf you want to mark it as '
|
|
||||||
'used call its "mark_used()" method.\nIt was originally created '
|
|
||||||
'here:\n%s\n'
|
|
||||||
'==================================' %
|
|
||||||
(type(x), x, self._tf_object_creation_stack))
|
|
||||||
|
|
||||||
if hasattr(super(TFShouldUseWarningWrapper, self), '__del__'):
|
|
||||||
return super(TFShouldUseWarningWrapper, self).__del__()
|
|
||||||
|
|
||||||
def mark_used(self, *args, **kwargs):
|
def mark_used(self, *args, **kwargs):
|
||||||
self._tf_object_has_been_used = True
|
_REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace(
|
||||||
|
object_used=True)
|
||||||
if hasattr(super(TFShouldUseWarningWrapper, self), 'mark_used'):
|
if hasattr(super(TFShouldUseWarningWrapper, self), 'mark_used'):
|
||||||
return super(TFShouldUseWarningWrapper, self).mark_used(*args, **kwargs)
|
return super(TFShouldUseWarningWrapper, self).mark_used(*args, **kwargs)
|
||||||
# pylint: enable=super-on-old-class
|
# pylint: enable=super-on-old-class
|
||||||
@ -102,7 +138,8 @@ def _add_should_use_warning(x, fatal_error=False):
|
|||||||
|
|
||||||
wrapped = TFShouldUseWarningWrapper(x)
|
wrapped = TFShouldUseWarningWrapper(x)
|
||||||
wrapped.__doc__ = x.__doc__ # functools.wraps fails on some objects.
|
wrapped.__doc__ = x.__doc__ # functools.wraps fails on some objects.
|
||||||
wrapped._tf_object_has_been_used = False # pylint: disable=protected-access
|
ref_id = wrapped._tf_ref_id # pylint: disable=protected-access
|
||||||
|
_REF_INFO[ref_id] = _REF_INFO[ref_id]._replace(object_used=False)
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import gc
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
@ -45,7 +46,7 @@ def reroute_error(captured):
|
|||||||
class TfShouldUseTest(test.TestCase):
|
class TfShouldUseTest(test.TestCase):
|
||||||
|
|
||||||
def testAddShouldUseWarningWhenNotUsed(self):
|
def testAddShouldUseWarningWhenNotUsed(self):
|
||||||
c = constant_op.constant(0, name='blah')
|
c = constant_op.constant(0, name='blah0')
|
||||||
captured = []
|
captured = []
|
||||||
with reroute_error(captured):
|
with reroute_error(captured):
|
||||||
def in_this_function():
|
def in_this_function():
|
||||||
@ -53,44 +54,52 @@ class TfShouldUseTest(test.TestCase):
|
|||||||
del h
|
del h
|
||||||
in_this_function()
|
in_this_function()
|
||||||
self.assertIn('Object was never used', '\n'.join(captured))
|
self.assertIn('Object was never used', '\n'.join(captured))
|
||||||
self.assertIn('blah:0', '\n'.join(captured))
|
self.assertIn('blah0:0', '\n'.join(captured))
|
||||||
self.assertIn('in_this_function', '\n'.join(captured))
|
self.assertIn('in_this_function', '\n'.join(captured))
|
||||||
|
gc.collect()
|
||||||
|
self.assertFalse(gc.garbage)
|
||||||
|
|
||||||
def _testAddShouldUseWarningWhenUsed(self, fn):
|
def _testAddShouldUseWarningWhenUsed(self, fn, name):
|
||||||
c = constant_op.constant(0, name='blah')
|
c = constant_op.constant(0, name=name)
|
||||||
captured = []
|
captured = []
|
||||||
with reroute_error(captured):
|
with reroute_error(captured):
|
||||||
h = tf_should_use._add_should_use_warning(c)
|
h = tf_should_use._add_should_use_warning(c)
|
||||||
fn(h)
|
fn(h)
|
||||||
del h
|
del h
|
||||||
self.assertNotIn('Object was never used', '\n'.join(captured))
|
self.assertNotIn('Object was never used', '\n'.join(captured))
|
||||||
self.assertNotIn('blah:0', '\n'.join(captured))
|
self.assertNotIn('%s:0' % name, '\n'.join(captured))
|
||||||
|
|
||||||
def testAddShouldUseWarningWhenUsedWithAdd(self):
|
def testAddShouldUseWarningWhenUsedWithAdd(self):
|
||||||
def add(h):
|
def add(h):
|
||||||
_ = h + 1
|
_ = h + 1
|
||||||
self._testAddShouldUseWarningWhenUsed(add)
|
self._testAddShouldUseWarningWhenUsed(add, name='blah_add')
|
||||||
|
gc.collect()
|
||||||
|
self.assertFalse(gc.garbage)
|
||||||
|
|
||||||
def testAddShouldUseWarningWhenUsedWithGetName(self):
|
def testAddShouldUseWarningWhenUsedWithGetName(self):
|
||||||
def get_name(h):
|
def get_name(h):
|
||||||
_ = h.name
|
_ = h.name
|
||||||
self._testAddShouldUseWarningWhenUsed(get_name)
|
self._testAddShouldUseWarningWhenUsed(get_name, name='blah_get_name')
|
||||||
|
gc.collect()
|
||||||
|
self.assertFalse(gc.garbage)
|
||||||
|
|
||||||
def testShouldUseResult(self):
|
def testShouldUseResult(self):
|
||||||
@tf_should_use.should_use_result
|
@tf_should_use.should_use_result
|
||||||
def return_const(value):
|
def return_const(value):
|
||||||
return constant_op.constant(value, name='blah')
|
return constant_op.constant(value, name='blah2')
|
||||||
captured = []
|
captured = []
|
||||||
with reroute_error(captured):
|
with reroute_error(captured):
|
||||||
return_const(0.0)
|
return_const(0.0)
|
||||||
self.assertIn('Object was never used', '\n'.join(captured))
|
self.assertIn('Object was never used', '\n'.join(captured))
|
||||||
self.assertIn('blah:0', '\n'.join(captured))
|
self.assertIn('blah2:0', '\n'.join(captured))
|
||||||
self.assertIn('return_const', '\n'.join(captured))
|
self.assertIn('return_const', '\n'.join(captured))
|
||||||
|
gc.collect()
|
||||||
|
self.assertFalse(gc.garbage)
|
||||||
|
|
||||||
def testShouldUseResultWhenNotReallyUsed(self):
|
def testShouldUseResultWhenNotReallyUsed(self):
|
||||||
@tf_should_use.should_use_result
|
@tf_should_use.should_use_result
|
||||||
def return_const(value):
|
def return_const(value):
|
||||||
return constant_op.constant(value, name='blah')
|
return constant_op.constant(value, name='blah3')
|
||||||
captured = []
|
captured = []
|
||||||
with reroute_error(captured):
|
with reroute_error(captured):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
@ -100,8 +109,10 @@ class TfShouldUseTest(test.TestCase):
|
|||||||
v = constant_op.constant(1.0, name='meh')
|
v = constant_op.constant(1.0, name='meh')
|
||||||
v.eval()
|
v.eval()
|
||||||
self.assertIn('Object was never used', '\n'.join(captured))
|
self.assertIn('Object was never used', '\n'.join(captured))
|
||||||
self.assertIn('blah:0', '\n'.join(captured))
|
self.assertIn('blah3:0', '\n'.join(captured))
|
||||||
self.assertIn('return_const', '\n'.join(captured))
|
self.assertIn('return_const', '\n'.join(captured))
|
||||||
|
gc.collect()
|
||||||
|
self.assertFalse(gc.garbage)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -85,3 +85,6 @@ pip2 install mock
|
|||||||
|
|
||||||
pip2 install portpicker
|
pip2 install portpicker
|
||||||
pip3 install portpicker
|
pip3 install portpicker
|
||||||
|
|
||||||
|
pip2 install backports.weakref==1.0rc1
|
||||||
|
pip3 install backports.weakref==1.0rc1
|
||||||
|
@ -89,3 +89,6 @@ pip3.5 install wheel==0.29.0
|
|||||||
pip3.5 install portpicker
|
pip3.5 install portpicker
|
||||||
|
|
||||||
pip3.5 install werkzeug
|
pip3.5 install werkzeug
|
||||||
|
|
||||||
|
pip3.5 install backports.weakref==1.0rc1
|
||||||
|
|
||||||
|
@ -39,6 +39,7 @@ REQUIRED_PACKAGES = [
|
|||||||
'html5lib == 0.9999999', # identical to 1.0b8
|
'html5lib == 0.9999999', # identical to 1.0b8
|
||||||
'markdown == 2.2.0',
|
'markdown == 2.2.0',
|
||||||
'bleach == 1.5.0',
|
'bleach == 1.5.0',
|
||||||
|
'backports.weakref == 1.0rc1',
|
||||||
]
|
]
|
||||||
|
|
||||||
project_name = 'tensorflow'
|
project_name = 'tensorflow'
|
||||||
|
Loading…
Reference in New Issue
Block a user