Delete ScopedTFStatus to avoid leaking it in long-running trainers (1+ day).
PiperOrigin-RevId: 168259652
This commit is contained in:
parent
e15f4cae23
commit
2356c0ff46
@ -459,8 +459,13 @@ def _make_specific_exception(node_def, op, message, error_code):
|
|||||||
def raise_exception_on_not_ok_status():
|
def raise_exception_on_not_ok_status():
|
||||||
status = c_api_util.ScopedTFStatus()
|
status = c_api_util.ScopedTFStatus()
|
||||||
yield status.status
|
yield status.status
|
||||||
if c_api.TF_GetCode(status) != 0:
|
try:
|
||||||
raise _make_specific_exception(
|
if c_api.TF_GetCode(status.status) != 0:
|
||||||
None, None,
|
raise _make_specific_exception(
|
||||||
compat.as_text(c_api.TF_Message(status)),
|
None, None,
|
||||||
c_api.TF_GetCode(status))
|
compat.as_text(c_api.TF_Message(status.status)),
|
||||||
|
c_api.TF_GetCode(status.status))
|
||||||
|
# Delete the underlying status object from memory, otherwise it stays alive
|
||||||
|
# as there is a reference to status from the traceback due to raise.
|
||||||
|
finally:
|
||||||
|
del status
|
||||||
|
@ -18,16 +18,23 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import gc
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from tensorflow.core.lib.core import error_codes_pb2
|
from tensorflow.core.lib.core import error_codes_pb2
|
||||||
|
from tensorflow.python import pywrap_tensorflow
|
||||||
|
from tensorflow.python.framework import c_api_util
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.util import compat
|
||||||
|
|
||||||
|
|
||||||
class ErrorsTest(test.TestCase):
|
class ErrorsTest(test.TestCase):
|
||||||
|
|
||||||
|
def _CountReferences(self, typeof):
|
||||||
|
return len([o for o in gc.get_objects() if isinstance(o, typeof)])
|
||||||
|
|
||||||
def testUniqueClassForEachErrorCode(self):
|
def testUniqueClassForEachErrorCode(self):
|
||||||
for error_code, exc_type in [
|
for error_code, exc_type in [
|
||||||
(errors.CANCELLED, errors_impl.CancelledError),
|
(errors.CANCELLED, errors_impl.CancelledError),
|
||||||
@ -80,6 +87,16 @@ class ErrorsTest(test.TestCase):
|
|||||||
self.assertTrue("Unknown error code: 37" in str(w[0].message))
|
self.assertTrue("Unknown error code: 37" in str(w[0].message))
|
||||||
self.assertTrue(isinstance(exc, errors_impl.OpError))
|
self.assertTrue(isinstance(exc, errors_impl.OpError))
|
||||||
|
|
||||||
|
def testStatusDoesNotLeak(self):
|
||||||
|
try:
|
||||||
|
with errors.raise_exception_on_not_ok_status() as status:
|
||||||
|
pywrap_tensorflow.DeleteFile(
|
||||||
|
compat.as_bytes("/DOES_NOT_EXIST/"), status)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
gc.collect()
|
||||||
|
self.assertEqual(0, self._CountReferences(c_api_util.ScopedTFStatus))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user