Delete ScopedTFStatus to avoid leaking it in long-running trainers (1+ day).

PiperOrigin-RevId: 168259652
This commit is contained in:
A. Unique TensorFlower 2017-09-11 11:33:20 -07:00 committed by TensorFlower Gardener
parent e15f4cae23
commit 2356c0ff46
2 changed files with 27 additions and 5 deletions

View File

@ -459,8 +459,13 @@ def _make_specific_exception(node_def, op, message, error_code):
# Creates a ScopedTFStatus, yields its underlying `status` handle to the
# caller, and afterwards raises the exception built by
# _make_specific_exception when the recorded error code is non-zero.
# NOTE(review): presumably turned into a context manager by a decorator that
# sits above this hunk -- confirm against the full file.
def raise_exception_on_not_ok_status():
status = c_api_util.ScopedTFStatus()
yield status.status
# NOTE(review): the next five lines look like the pre-change (removed) side of
# this diff hunk -- they pass the wrapper `status` instead of `status.status`.
# Confirm before treating them as live code.
if c_api.TF_GetCode(status) != 0:
raise _make_specific_exception(
None, None,
compat.as_text(c_api.TF_Message(status)),
c_api.TF_GetCode(status))
try:
# A non-zero code means the wrapped call failed; surface it as an exception
# carrying the status message and code.
if c_api.TF_GetCode(status.status) != 0:
raise _make_specific_exception(
None, None,
compat.as_text(c_api.TF_Message(status.status)),
c_api.TF_GetCode(status.status))
# Delete the underlying status object from memory; otherwise it stays alive
# because the traceback created by the raise keeps a reference to `status`.
finally:
del status

View File

@ -18,16 +18,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gc
import warnings
from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.platform import test
from tensorflow.python.util import compat
class ErrorsTest(test.TestCase):
def _CountReferences(self, typeof):
return len([o for o in gc.get_objects() if isinstance(o, typeof)])
def testUniqueClassForEachErrorCode(self):
for error_code, exc_type in [
(errors.CANCELLED, errors_impl.CancelledError),
@ -80,6 +87,16 @@ class ErrorsTest(test.TestCase):
self.assertTrue("Unknown error code: 37" in str(w[0].message))
self.assertTrue(isinstance(exc, errors_impl.OpError))
def testStatusDoesNotLeak(self):
  """Checks that no ScopedTFStatus object survives a failed status-checked call.

  Attempts to delete the path "/DOES_NOT_EXIST/" inside
  `errors.raise_exception_on_not_ok_status()`, tolerates the resulting
  TensorFlow error, forces a garbage collection, and then asserts that the
  collector tracks zero `c_api_util.ScopedTFStatus` instances.
  """
  try:
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.DeleteFile(
          compat.as_bytes("/DOES_NOT_EXIST/"), status)
  except errors_impl.OpError:
    # Only the TensorFlow error from the failed delete is tolerated here. A
    # bare `except:` would also swallow KeyboardInterrupt/SystemExit and hide
    # unrelated failures in the call above.
    pass
  # Collect first so that only objects that are still referenced are counted.
  gc.collect()
  self.assertEqual(0, self._CountReferences(c_api_util.ScopedTFStatus))
# Hand off to tensorflow.python.platform.test's main() when run as a script.
if __name__ == "__main__":
test.main()