From 2356c0ff4630284f6168c3edbe43ee6a0a77d200 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 11 Sep 2017 11:33:20 -0700 Subject: [PATCH] Delete ScopedTFStatus to avoid leaking it for long-running trainers (1+ day). PiperOrigin-RevId: 168259652 --- tensorflow/python/framework/errors_impl.py | 15 ++++++++++----- tensorflow/python/framework/errors_test.py | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/framework/errors_impl.py b/tensorflow/python/framework/errors_impl.py index 9b1f0a0cfb2..fa956c3d292 100644 --- a/tensorflow/python/framework/errors_impl.py +++ b/tensorflow/python/framework/errors_impl.py @@ -459,8 +459,13 @@ def _make_specific_exception(node_def, op, message, error_code): def raise_exception_on_not_ok_status(): status = c_api_util.ScopedTFStatus() yield status.status - if c_api.TF_GetCode(status) != 0: - raise _make_specific_exception( - None, None, - compat.as_text(c_api.TF_Message(status)), - c_api.TF_GetCode(status)) + try: + if c_api.TF_GetCode(status.status) != 0: + raise _make_specific_exception( + None, None, + compat.as_text(c_api.TF_Message(status.status)), + c_api.TF_GetCode(status.status)) + # Delete the underlying status object from memory; otherwise it stays alive + # because the traceback created by the raise keeps a reference to status. 
+ finally: + del status diff --git a/tensorflow/python/framework/errors_test.py b/tensorflow/python/framework/errors_test.py index 4b96b49b35e..8ed4492d7b6 100644 --- a/tensorflow/python/framework/errors_test.py +++ b/tensorflow/python/framework/errors_test.py @@ -18,16 +18,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gc import warnings from tensorflow.core.lib.core import error_codes_pb2 +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.framework import c_api_util from tensorflow.python.framework import errors from tensorflow.python.framework import errors_impl from tensorflow.python.platform import test +from tensorflow.python.util import compat class ErrorsTest(test.TestCase): + def _CountReferences(self, typeof): + return len([o for o in gc.get_objects() if isinstance(o, typeof)]) + def testUniqueClassForEachErrorCode(self): for error_code, exc_type in [ (errors.CANCELLED, errors_impl.CancelledError), @@ -80,6 +87,16 @@ class ErrorsTest(test.TestCase): self.assertTrue("Unknown error code: 37" in str(w[0].message)) self.assertTrue(isinstance(exc, errors_impl.OpError)) + def testStatusDoesNotLeak(self): + try: + with errors.raise_exception_on_not_ok_status() as status: + pywrap_tensorflow.DeleteFile( + compat.as_bytes("/DOES_NOT_EXIST/"), status) + except: + pass + gc.collect() + self.assertEqual(0, self._CountReferences(c_api_util.ScopedTFStatus)) + if __name__ == "__main__": test.main()