Delete the ScopedTFStatus object explicitly to avoid leaking it in long-running trainers (1+ day).
PiperOrigin-RevId: 168259652
This commit is contained in:
parent
e15f4cae23
commit
2356c0ff46
@ -459,8 +459,13 @@ def _make_specific_exception(node_def, op, message, error_code):
|
||||
def raise_exception_on_not_ok_status():
  """Yields a fresh TF status and raises if it is not OK afterwards.

  This generator is used as a context manager (its decorator sits above the
  visible part of the file): the caller hands the yielded status to a C API
  call inside the `with` body, and once the body finishes a non-zero status
  code is converted into a Python exception.

  Yields:
    The `status` attribute of a newly created `c_api_util.ScopedTFStatus`.

  Raises:
    The exception built by `_make_specific_exception` from the status message
    and code, when `c_api.TF_GetCode` reports a non-zero code.
  """
  status = c_api_util.ScopedTFStatus()
  try:
    # The yield lives inside the try block so that the status is also
    # released when the caller's `with` body raises and that error is thrown
    # into this generator.
    yield status.status
    error_code = c_api.TF_GetCode(status.status)
    if error_code != 0:
      raise _make_specific_exception(
          None, None,
          compat.as_text(c_api.TF_Message(status.status)),
          error_code)
  finally:
    # Drop the local reference explicitly. A raised exception's traceback
    # keeps this frame alive, and the frame would otherwise keep the
    # underlying status object alive with it.
    del status
|
||||
|
@ -18,16 +18,23 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gc
|
||||
import warnings
|
||||
|
||||
from tensorflow.core.lib.core import error_codes_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
class ErrorsTest(test.TestCase):
|
||||
|
||||
def _CountReferences(self, typeof):
|
||||
return len([o for o in gc.get_objects() if isinstance(o, typeof)])
|
||||
|
||||
def testUniqueClassForEachErrorCode(self):
|
||||
for error_code, exc_type in [
|
||||
(errors.CANCELLED, errors_impl.CancelledError),
|
||||
@ -80,6 +87,16 @@ class ErrorsTest(test.TestCase):
|
||||
self.assertTrue("Unknown error code: 37" in str(w[0].message))
|
||||
self.assertTrue(isinstance(exc, errors_impl.OpError))
|
||||
|
||||
def testStatusDoesNotLeak(self):
  """Checks that no ScopedTFStatus object survives a failed C API call."""
  try:
    with errors.raise_exception_on_not_ok_status() as status:
      # Deleting a path that does not exist is expected to fail, which makes
      # the context manager raise on exit.
      pywrap_tensorflow.DeleteFile(
          compat.as_bytes("/DOES_NOT_EXIST/"), status)
  except errors_impl.OpError:
    # The failure itself is expected; only the leak check below matters.
    # Catching OpError (rather than everything) lets unrelated errors and
    # interrupts surface instead of being silently swallowed.
    pass
  # Force a collection so that only genuinely leaked objects are counted.
  gc.collect()
  self.assertEqual(0, self._CountReferences(c_api_util.ScopedTFStatus))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user