diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py index 88df3351e66..05c99856d27 100644 --- a/tensorflow/python/util/tf_should_use.py +++ b/tensorflow/python/util/tf_should_use.py @@ -17,14 +17,52 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import functools +import itertools import traceback import types +import six # pylint: disable=unused-import + +from backports import weakref # pylint: disable=g-bad-import-order + from tensorflow.python.platform import tf_logging from tensorflow.python.util import tf_decorator +class _RefInfoField( + collections.namedtuple( + '_RefInfoField', ('type_', 'repr_', 'creation_stack', 'object_used'))): + pass + + +# Thread-safe up to int32max/2 thanks to python's GIL; and may be safe even for +# higher values in Python 3.4+. We don't expect to ever count higher than this. +# https://mail.python.org/pipermail/python-list/2005-April/342279.html +_REF_ITER = itertools.count() + +# Dictionary mapping id(obj) => _RefInfoField. +_REF_INFO = {} + + +def _deleted(obj_id, fatal_error): + obj = _REF_INFO[obj_id] + del _REF_INFO[obj_id] + if not obj.object_used: + if fatal_error: + logger = tf_logging.fatal + else: + logger = tf_logging.error + logger( + '==================================\n' + 'Object was never used (type %s):\n%s\nIf you want to mark it as ' + 'used call its "mark_used()" method.\nIt was originally created ' + 'here:\n%s\n' + '==================================' % + (obj.type_, obj.repr_, obj.creation_stack)) + + def _add_should_use_warning(x, fatal_error=False): """Wraps object x so that if it is never used, a warning is logged. @@ -39,14 +77,14 @@ def _add_should_use_warning(x, fatal_error=False): """ if x is None: # special corner case where x is None return x - has_been_used = getattr(x, '_tf_object_has_been_used', None) - if has_been_used is not None: - x._tf_object_has_been_used = has_been_used # pylint: disable=protected-access + if hasattr(x, '_tf_ref_id'): # this is already a TFShouldUseWarningWrapper return x def override_method(method): def fn(self, *args, **kwargs): - self._tf_object_has_been_used = True # pylint: disable=protected-access + # pylint: disable=protected-access + _REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace( + object_used=True) return method(self, *args, **kwargs) return fn @@ -55,38 +93,36 @@ def _add_should_use_warning(x, fatal_error=False): def __init__(self, true_self): self.__dict__ = true_self.__dict__ - stack = [x.strip() for x in traceback.format_stack()] + stack = [s.strip() for s in traceback.format_stack()] # Remove top three stack entries from adding the wrapper - self._tf_object_creation_stack = '\n'.join(stack[:-3]) - self._tf_object_has_been_used = False + self.creation_stack = '\n'.join(stack[:-3]) + self._tf_ref_id = next(_REF_ITER) + _REF_INFO[self._tf_ref_id] = _RefInfoField( + type_=type(x), + repr_=repr(x), + creation_stack=stack, + object_used=False) + + # Create a finalizer for self, which will be called when self is + # garbage collected. Can't add self as the args because the + # loop will break garbage collection. We keep track of + # ourselves via python ids. + weakref.finalize(self, _deleted, self._tf_ref_id, fatal_error) # Not sure why this pylint warning is being used; this is not an # old class form. # pylint: disable=super-on-old-class def __getattribute__(self, name): - if name != '_tf_object_has_been_used': - self._tf_object_has_been_used = True + if name == '_tf_ref_id': + return super(TFShouldUseWarningWrapper, self).__getattribute__(name) + if self._tf_ref_id in _REF_INFO: + _REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace( + object_used=True) return super(TFShouldUseWarningWrapper, self).__getattribute__(name) - def __del__(self): - if not self._tf_object_has_been_used: - if fatal_error: - logger = tf_logging.fatal - else: - logger = tf_logging.error - logger( - '==================================\n' - 'Object was never used (type %s):\n%s\nIf you want to mark it as ' - 'used call its "mark_used()" method.\nIt was originally created ' - 'here:\n%s\n' - '==================================' % - (type(x), x, self._tf_object_creation_stack)) - - if hasattr(super(TFShouldUseWarningWrapper, self), '__del__'): - return super(TFShouldUseWarningWrapper, self).__del__() - def mark_used(self, *args, **kwargs): - self._tf_object_has_been_used = True + _REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace( + object_used=True) if hasattr(super(TFShouldUseWarningWrapper, self), 'mark_used'): return super(TFShouldUseWarningWrapper, self).mark_used(*args, **kwargs) # pylint: enable=super-on-old-class @@ -102,7 +138,8 @@ def _add_should_use_warning(x, fatal_error=False): wrapped = TFShouldUseWarningWrapper(x) wrapped.__doc__ = x.__doc__ # functools.wraps fails on some objects. - wrapped._tf_object_has_been_used = False # pylint: disable=protected-access + ref_id = wrapped._tf_ref_id # pylint: disable=protected-access + _REF_INFO[ref_id] = _REF_INFO[ref_id]._replace(object_used=False) return wrapped diff --git a/tensorflow/python/util/tf_should_use_test.py b/tensorflow/python/util/tf_should_use_test.py index 71d48e3dde3..c8268744004 100644 --- a/tensorflow/python/util/tf_should_use_test.py +++ b/tensorflow/python/util/tf_should_use_test.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import contextlib +import gc import sys from tensorflow.python.framework import constant_op @@ -45,7 +46,7 @@ def reroute_error(captured): class TfShouldUseTest(test.TestCase): def testAddShouldUseWarningWhenNotUsed(self): - c = constant_op.constant(0, name='blah') + c = constant_op.constant(0, name='blah0') captured = [] with reroute_error(captured): def in_this_function(): @@ -53,44 +54,52 @@ class TfShouldUseTest(test.TestCase): del h in_this_function() self.assertIn('Object was never used', '\n'.join(captured)) - self.assertIn('blah:0', '\n'.join(captured)) + self.assertIn('blah0:0', '\n'.join(captured)) self.assertIn('in_this_function', '\n'.join(captured)) + gc.collect() + self.assertFalse(gc.garbage) - def _testAddShouldUseWarningWhenUsed(self, fn): - c = constant_op.constant(0, name='blah') + def _testAddShouldUseWarningWhenUsed(self, fn, name): + c = constant_op.constant(0, name=name) captured = [] with reroute_error(captured): h = tf_should_use._add_should_use_warning(c) fn(h) del h self.assertNotIn('Object was never used', '\n'.join(captured)) - self.assertNotIn('blah:0', '\n'.join(captured)) + self.assertNotIn('%s:0' % name, '\n'.join(captured)) def testAddShouldUseWarningWhenUsedWithAdd(self): def add(h): _ = h + 1 - self._testAddShouldUseWarningWhenUsed(add) + self._testAddShouldUseWarningWhenUsed(add, name='blah_add') + gc.collect() + self.assertFalse(gc.garbage) def testAddShouldUseWarningWhenUsedWithGetName(self): def get_name(h): _ = h.name - self._testAddShouldUseWarningWhenUsed(get_name) + self._testAddShouldUseWarningWhenUsed(get_name, name='blah_get_name') + gc.collect() + self.assertFalse(gc.garbage) def testShouldUseResult(self): @tf_should_use.should_use_result def return_const(value): - return constant_op.constant(value, name='blah') + return constant_op.constant(value, name='blah2') captured = [] with reroute_error(captured): return_const(0.0) self.assertIn('Object was never used', '\n'.join(captured)) - self.assertIn('blah:0', '\n'.join(captured)) + self.assertIn('blah2:0', '\n'.join(captured)) self.assertIn('return_const', '\n'.join(captured)) + gc.collect() + self.assertFalse(gc.garbage) def testShouldUseResultWhenNotReallyUsed(self): @tf_should_use.should_use_result def return_const(value): - return constant_op.constant(value, name='blah') + return constant_op.constant(value, name='blah3') captured = [] with reroute_error(captured): with self.test_session(): @@ -100,8 +109,10 @@ class TfShouldUseTest(test.TestCase): v = constant_op.constant(1.0, name='meh') v.eval() self.assertIn('Object was never used', '\n'.join(captured)) - self.assertIn('blah:0', '\n'.join(captured)) + self.assertIn('blah3:0', '\n'.join(captured)) self.assertIn('return_const', '\n'.join(captured)) + gc.collect() + self.assertFalse(gc.garbage) if __name__ == '__main__': diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh index b8f9fc84539..8768852dc7e 100755 --- a/tensorflow/tools/ci_build/install/install_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh @@ -85,3 +85,6 @@ pip2 install mock pip2 install portpicker pip3 install portpicker + +pip2 install backports.weakref==1.0rc1 +pip3 install backports.weakref==1.0rc1 diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh index e7e2d256cd9..edfc4e3a98f 100755 --- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh @@ -89,3 +89,6 @@ pip3.5 install wheel==0.29.0 pip3.5 install portpicker pip3.5 install werkzeug + +pip3.5 install backports.weakref==1.0rc1 + diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index ae6516db891..0ce6d729069 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -39,6 +39,7 @@ REQUIRED_PACKAGES = [ 'html5lib == 0.9999999', # identical to 1.0b8 'markdown == 2.2.0', 'bleach == 1.5.0', + 'backports.weakref == 1.0rc1', ] project_name = 'tensorflow'