Fix memory leak in Python caused by @tf_should_use.

The issue is that Python's GC has trouble collecting objects with __del__
methods when they are part of reference cycles.

The solution is two-pronged:
* Keep track of usage state outside of the class, via a dict mapping
  id(object) => state
* Remove __del__ (this was the source: Python's GC couldn't collect wrapped
  objects), and instead use weakref.finalize to emit warnings just as the object
  is being garbage collected.

In addition, tests for garbage collection were added (they were failing before
I fixed the issue).

PiperOrigin-RevId: 158042388
This commit is contained in:
Eugene Brevdo 2017-06-05 11:54:49 -07:00 committed by TensorFlower Gardener
parent cc411f9387
commit cf238e1f2f
5 changed files with 94 additions and 39 deletions

View File

@ -17,14 +17,52 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
import itertools
import traceback
import types
import six # pylint: disable=unused-import
from backports import weakref # pylint: disable=g-bad-import-order
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import tf_decorator
class _RefInfoField(
    collections.namedtuple(
        '_RefInfoField',
        ['type_', 'repr_', 'creation_stack', 'object_used'])):
  """Immutable record describing one tracked object.

  Fields:
    type_: The type of the tracked object.
    repr_: The `repr()` string of the tracked object.
    creation_stack: Stack information captured when tracking began.
    object_used: Whether the tracked object has been used so far.
  """
# Thread-safe up to int32max/2 thanks to python's GIL; and may be safe even for
# higher values in Python 3.4+. We don't expect to ever count higher than this.
# https://mail.python.org/pipermail/python-list/2005-April/342279.html
# Source of the integer ids that wrappers store as `_tf_ref_id`.
_REF_ITER = itertools.count()
# Dictionary mapping a wrapper's `_tf_ref_id` (a value drawn from _REF_ITER,
# not the builtin id() of the object) => _RefInfoField.
_REF_INFO = {}
def _deleted(obj_id, fatal_error):
  """Forgets a tracked object and logs a message if it was never used.

  Args:
    obj_id: Key of the object's entry in `_REF_INFO`.
    fatal_error: If true, the message goes through `tf_logging.fatal`;
      otherwise it goes through `tf_logging.error`.

  Raises:
    KeyError: If `obj_id` has no entry in `_REF_INFO`.
  """
  # Remove the bookkeeping entry before deciding whether to report.
  info = _REF_INFO.pop(obj_id)
  if info.object_used:
    return
  report = tf_logging.fatal if fatal_error else tf_logging.error
  report(
      '==================================\n'
      'Object was never used (type %s):\n%s\nIf you want to mark it as '
      'used call its "mark_used()" method.\nIt was originally created '
      'here:\n%s\n'
      '==================================' %
      (info.type_, info.repr_, info.creation_stack))
def _add_should_use_warning(x, fatal_error=False):
"""Wraps object x so that if it is never used, a warning is logged.
@ -39,14 +77,14 @@ def _add_should_use_warning(x, fatal_error=False):
"""
if x is None: # special corner case where x is None
return x
has_been_used = getattr(x, '_tf_object_has_been_used', None)
if has_been_used is not None:
x._tf_object_has_been_used = has_been_used # pylint: disable=protected-access
if hasattr(x, '_tf_ref_id'): # this is already a TFShouldUseWarningWrapper
return x
def override_method(method):
def fn(self, *args, **kwargs):
self._tf_object_has_been_used = True # pylint: disable=protected-access
# pylint: disable=protected-access
_REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace(
object_used=True)
return method(self, *args, **kwargs)
return fn
@ -55,38 +93,36 @@ def _add_should_use_warning(x, fatal_error=False):
def __init__(self, true_self):
self.__dict__ = true_self.__dict__
stack = [x.strip() for x in traceback.format_stack()]
stack = [s.strip() for s in traceback.format_stack()]
# Remove top three stack entries from adding the wrapper
self._tf_object_creation_stack = '\n'.join(stack[:-3])
self._tf_object_has_been_used = False
self.creation_stack = '\n'.join(stack[:-3])
self._tf_ref_id = next(_REF_ITER)
_REF_INFO[self._tf_ref_id] = _RefInfoField(
type_=type(x),
repr_=repr(x),
creation_stack=stack,
object_used=False)
# Create a finalizer for self, which will be called when self is
# garbage collected. Can't add self as the args because the
# loop will break garbage collection. We keep track of
# ourselves via python ids.
weakref.finalize(self, _deleted, self._tf_ref_id, fatal_error)
# Not sure why this pylint warning is being used; this is not an
# old class form.
# pylint: disable=super-on-old-class
def __getattribute__(self, name):
if name != '_tf_object_has_been_used':
self._tf_object_has_been_used = True
if name == '_tf_ref_id':
return super(TFShouldUseWarningWrapper, self).__getattribute__(name)
if self._tf_ref_id in _REF_INFO:
_REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace(
object_used=True)
return super(TFShouldUseWarningWrapper, self).__getattribute__(name)
def __del__(self):
if not self._tf_object_has_been_used:
if fatal_error:
logger = tf_logging.fatal
else:
logger = tf_logging.error
logger(
'==================================\n'
'Object was never used (type %s):\n%s\nIf you want to mark it as '
'used call its "mark_used()" method.\nIt was originally created '
'here:\n%s\n'
'==================================' %
(type(x), x, self._tf_object_creation_stack))
if hasattr(super(TFShouldUseWarningWrapper, self), '__del__'):
return super(TFShouldUseWarningWrapper, self).__del__()
def mark_used(self, *args, **kwargs):
self._tf_object_has_been_used = True
_REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace(
object_used=True)
if hasattr(super(TFShouldUseWarningWrapper, self), 'mark_used'):
return super(TFShouldUseWarningWrapper, self).mark_used(*args, **kwargs)
# pylint: enable=super-on-old-class
@ -102,7 +138,8 @@ def _add_should_use_warning(x, fatal_error=False):
wrapped = TFShouldUseWarningWrapper(x)
wrapped.__doc__ = x.__doc__ # functools.wraps fails on some objects.
wrapped._tf_object_has_been_used = False # pylint: disable=protected-access
ref_id = wrapped._tf_ref_id # pylint: disable=protected-access
_REF_INFO[ref_id] = _REF_INFO[ref_id]._replace(object_used=False)
return wrapped

View File

@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import contextlib
import gc
import sys
from tensorflow.python.framework import constant_op
@ -45,7 +46,7 @@ def reroute_error(captured):
class TfShouldUseTest(test.TestCase):
def testAddShouldUseWarningWhenNotUsed(self):
c = constant_op.constant(0, name='blah')
c = constant_op.constant(0, name='blah0')
captured = []
with reroute_error(captured):
def in_this_function():
@ -53,44 +54,52 @@ class TfShouldUseTest(test.TestCase):
del h
in_this_function()
self.assertIn('Object was never used', '\n'.join(captured))
self.assertIn('blah:0', '\n'.join(captured))
self.assertIn('blah0:0', '\n'.join(captured))
self.assertIn('in_this_function', '\n'.join(captured))
gc.collect()
self.assertFalse(gc.garbage)
def _testAddShouldUseWarningWhenUsed(self, fn):
c = constant_op.constant(0, name='blah')
def _testAddShouldUseWarningWhenUsed(self, fn, name):
c = constant_op.constant(0, name=name)
captured = []
with reroute_error(captured):
h = tf_should_use._add_should_use_warning(c)
fn(h)
del h
self.assertNotIn('Object was never used', '\n'.join(captured))
self.assertNotIn('blah:0', '\n'.join(captured))
self.assertNotIn('%s:0' % name, '\n'.join(captured))
def testAddShouldUseWarningWhenUsedWithAdd(self):
def add(h):
_ = h + 1
self._testAddShouldUseWarningWhenUsed(add)
self._testAddShouldUseWarningWhenUsed(add, name='blah_add')
gc.collect()
self.assertFalse(gc.garbage)
def testAddShouldUseWarningWhenUsedWithGetName(self):
def get_name(h):
_ = h.name
self._testAddShouldUseWarningWhenUsed(get_name)
self._testAddShouldUseWarningWhenUsed(get_name, name='blah_get_name')
gc.collect()
self.assertFalse(gc.garbage)
def testShouldUseResult(self):
@tf_should_use.should_use_result
def return_const(value):
return constant_op.constant(value, name='blah')
return constant_op.constant(value, name='blah2')
captured = []
with reroute_error(captured):
return_const(0.0)
self.assertIn('Object was never used', '\n'.join(captured))
self.assertIn('blah:0', '\n'.join(captured))
self.assertIn('blah2:0', '\n'.join(captured))
self.assertIn('return_const', '\n'.join(captured))
gc.collect()
self.assertFalse(gc.garbage)
def testShouldUseResultWhenNotReallyUsed(self):
@tf_should_use.should_use_result
def return_const(value):
return constant_op.constant(value, name='blah')
return constant_op.constant(value, name='blah3')
captured = []
with reroute_error(captured):
with self.test_session():
@ -100,8 +109,10 @@ class TfShouldUseTest(test.TestCase):
v = constant_op.constant(1.0, name='meh')
v.eval()
self.assertIn('Object was never used', '\n'.join(captured))
self.assertIn('blah:0', '\n'.join(captured))
self.assertIn('blah3:0', '\n'.join(captured))
self.assertIn('return_const', '\n'.join(captured))
gc.collect()
self.assertFalse(gc.garbage)
if __name__ == '__main__':

View File

@ -85,3 +85,6 @@ pip2 install mock
pip2 install portpicker
pip3 install portpicker
pip2 install backports.weakref==1.0rc1
pip3 install backports.weakref==1.0rc1

View File

@ -89,3 +89,6 @@ pip3.5 install wheel==0.29.0
pip3.5 install portpicker
pip3.5 install werkzeug
pip3.5 install backports.weakref==1.0rc1

View File

@ -39,6 +39,7 @@ REQUIRED_PACKAGES = [
'html5lib == 0.9999999', # identical to 1.0b8
'markdown == 2.2.0',
'bleach == 1.5.0',
'backports.weakref == 1.0rc1',
]
project_name = 'tensorflow'