Introduce a memory leak detection utility.
PiperOrigin-RevId: 294780979 Change-Id: I27b18224dbb49535beaa7ce81906f5686cebb7ef
This commit is contained in:
parent
f04915d2c8
commit
52281ba252
@ -1915,6 +1915,7 @@ py_library(
|
||||
":errors",
|
||||
":framework_for_generated_wrappers",
|
||||
":gpu_util",
|
||||
":memory_checker",
|
||||
":platform",
|
||||
":platform_test",
|
||||
":pywrap_tf_session",
|
||||
@ -2001,6 +2002,25 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "memory_checker",
|
||||
srcs = [
|
||||
"framework/memory_checker.py",
|
||||
"framework/python_memory_checker.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [":_python_memory_checker_helper"],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_python_memory_checker_helper",
|
||||
srcs = ["framework/python_memory_checker_helper.cc"],
|
||||
module_name = "_python_memory_checker_helper",
|
||||
deps = [
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "framework_constant_op_test",
|
||||
size = "small",
|
||||
@ -2522,6 +2542,22 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "framework_memory_checker_test",
|
||||
size = "medium",
|
||||
srcs = ["framework/memory_checker_test.py"],
|
||||
main = "framework/memory_checker_test.py",
|
||||
python_version = "PY3",
|
||||
shard_count = 8,
|
||||
tags = [
|
||||
"no_oss_py2",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
":framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "framework_dtypes_test",
|
||||
size = "small",
|
||||
|
141
tensorflow/python/framework/memory_checker.py
Normal file
141
tensorflow/python/framework/memory_checker.py
Normal file
@ -0,0 +1,141 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Memory leak detection utility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework.python_memory_checker import _PythonMemoryChecker
|
||||
from tensorflow.python.profiler.traceme import TraceMe
|
||||
from tensorflow.python.profiler.traceme import traceme_wrapper
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
try:
|
||||
from tensorflow.python.platform.cpp_memory_checker import _CppMemoryChecker # pylint:disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
_CppMemoryChecker = None
|
||||
|
||||
|
||||
def _get_test_name_best_effort():
|
||||
"""If available, return the current test name. Otherwise, `None`."""
|
||||
for stack in tf_inspect.stack():
|
||||
function_name = stack[3]
|
||||
if function_name.startswith('test'):
|
||||
try:
|
||||
class_name = stack[0].f_locals['self'].__class__.__name__
|
||||
return class_name + '.' + function_name
|
||||
except: # pylint:disable=bare-except
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# TODO(kkb): Also create decorator versions for convenience.
|
||||
class MemoryChecker(object):
|
||||
"""Memory leak detection class.
|
||||
|
||||
This is a utility class to detect Python and C++ memory leaks. It's intended
|
||||
for both testing and debugging. Basic usage:
|
||||
|
||||
>>> # MemoryChecker() context manager tracks memory status inside its scope.
|
||||
>>> with MemoryChecker() as memory_checker:
|
||||
>>> tensors = []
|
||||
>>> for _ in range(10):
|
||||
>>> # Simulating `tf.constant(1)` object leak every iteration.
|
||||
>>> tensors.append(tf.constant(1))
|
||||
>>>
|
||||
>>> # Take a memory snapshot for later analysis.
|
||||
>>> memory_checker.record_snapshot()
|
||||
>>>
|
||||
>>> # `report()` generates a html graph file showing allocations over
|
||||
>>> # snapshots per every stack trace.
|
||||
>>> memory_checker.report()
|
||||
>>>
|
||||
>>> # This assertion will detect `tf.constant(1)` object leak.
|
||||
>>> memory_checker.assert_no_leak_if_all_possibly_except_one()
|
||||
|
||||
`record_snapshot()` must be called once every iteration at the same location.
|
||||
This is because the detection algorithm relies on the assumption that if there
|
||||
is a leak, it's happening similarly on every snapshot.
|
||||
"""
|
||||
|
||||
@traceme_wrapper
|
||||
def __enter__(self):
|
||||
self._trace_me = TraceMe('with MemoryChecker():')
|
||||
self._trace_me.__enter__()
|
||||
self._python_memory_checker = _PythonMemoryChecker()
|
||||
if _CppMemoryChecker:
|
||||
self._cpp_memory_checker = _CppMemoryChecker(_get_test_name_best_effort())
|
||||
return self
|
||||
|
||||
@traceme_wrapper
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
if _CppMemoryChecker:
|
||||
self._cpp_memory_checker.stop()
|
||||
self._trace_me.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
@traceme_wrapper
|
||||
def record_snapshot(self):
|
||||
"""Take a memory snapshot for later analysis.
|
||||
|
||||
`record_snapshot()` must be called once every iteration at the same
|
||||
location. This is because the detection algorithm relies on the assumption
|
||||
that if there is a leak, it's happening similarly on every snapshot.
|
||||
|
||||
The recommended number of `record_snapshot()` call depends on the testing
|
||||
code complexity and the allcoation pattern.
|
||||
"""
|
||||
self._python_memory_checker.record_snapshot()
|
||||
if _CppMemoryChecker:
|
||||
self._cpp_memory_checker.record_snapshot()
|
||||
|
||||
@traceme_wrapper
|
||||
def report(self):
|
||||
"""Generates a html graph file showing allocations over snapshots.
|
||||
|
||||
It create a temporary directory and put all the output files there.
|
||||
If this is running under Google internal testing infra, it will use the
|
||||
directory provided the infra instead.
|
||||
"""
|
||||
self._python_memory_checker.report()
|
||||
if _CppMemoryChecker:
|
||||
self._cpp_memory_checker.report()
|
||||
|
||||
@traceme_wrapper
|
||||
def assert_no_leak_if_all_possibly_except_one(self):
|
||||
"""Raises an exception if a leak is detected.
|
||||
|
||||
This algorithm classifies a series of allocations as a leak if it's the same
|
||||
type(Python) orit happens at the same stack trace(C++) at every snapshot,
|
||||
but possibly except one snapshot.
|
||||
"""
|
||||
|
||||
self._python_memory_checker.assert_no_leak_if_all_possibly_except_one()
|
||||
if _CppMemoryChecker:
|
||||
self._cpp_memory_checker.assert_no_leak_if_all_possibly_except_one()
|
||||
|
||||
@traceme_wrapper
|
||||
def assert_no_new_python_objects(self, threshold=None):
|
||||
"""Raises an exception if there are new Python objects created.
|
||||
|
||||
It computes the number of new Python objects per type using the first and
|
||||
the last snapshots.
|
||||
|
||||
Args:
|
||||
threshold: A dictionary of [Type name string], [count] pair. It won't
|
||||
raise an exception if the new Python objects are under this threshold.
|
||||
"""
|
||||
self._python_memory_checker.assert_no_new_objects(threshold=threshold)
|
176
tensorflow/python/framework/memory_checker_test.py
Normal file
176
tensorflow/python/framework/memory_checker_test.py
Normal file
@ -0,0 +1,176 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework.memory_checker import MemoryChecker
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class MemoryCheckerTest(test.TestCase):
|
||||
|
||||
def testNoLeakEmpty(self):
|
||||
with MemoryChecker() as memory_checker:
|
||||
memory_checker.record_snapshot()
|
||||
memory_checker.record_snapshot()
|
||||
memory_checker.record_snapshot()
|
||||
memory_checker.record_snapshot()
|
||||
|
||||
memory_checker.report()
|
||||
memory_checker.assert_no_leak_if_all_possibly_except_one()
|
||||
|
||||
def testNoLeak1(self):
|
||||
with MemoryChecker() as memory_checker:
|
||||
memory_checker.record_snapshot()
|
||||
x = constant_op.constant(1) # pylint: disable=unused-variable
|
||||
memory_checker.record_snapshot()
|
||||
memory_checker.record_snapshot()
|
||||
memory_checker.record_snapshot()
|
||||
|
||||
memory_checker.report()
|
||||
memory_checker.assert_no_leak_if_all_possibly_except_one()
|
||||
|
||||
def testNoLeak2(self):
|
||||
with MemoryChecker() as memory_checker:
|
||||
tensors = []
|
||||
for i in range(10):
|
||||
if i not in (5, 7):
|
||||
tensors.append(constant_op.constant(1))
|
||||
memory_checker.record_snapshot()
|
||||
|
||||
memory_checker.report()
|
||||
memory_checker.assert_no_leak_if_all_possibly_except_one()
|
||||
|
||||
def testLeak1(self):
|
||||
with MemoryChecker() as memory_checker:
|
||||
memory_checker.record_snapshot()
|
||||
x = constant_op.constant(1) # pylint: disable=unused-variable
|
||||
memory_checker.record_snapshot()
|
||||
y = constant_op.constant(1) # pylint: disable=unused-variable
|
||||
memory_checker.record_snapshot()
|
||||
memory_checker.record_snapshot()
|
||||
|
||||
memory_checker.report()
|
||||
with self.assertRaises(AssertionError):
|
||||
memory_checker.assert_no_leak_if_all_possibly_except_one()
|
||||
|
||||
def testLeak2(self):
|
||||
with MemoryChecker() as memory_checker:
|
||||
tensors = []
|
||||
for _ in range(10):
|
||||
tensors.append(constant_op.constant(1))
|
||||
memory_checker.record_snapshot()
|
||||
|
||||
memory_checker.report()
|
||||
with self.assertRaises(AssertionError):
|
||||
memory_checker.assert_no_leak_if_all_possibly_except_one()
|
||||
|
||||
def testNoNewPythonObjectsEmpty(self):
|
||||
with MemoryChecker() as memory_checker:
|
||||
memory_checker.record_snapshot()
|
||||
memory_checker.record_snapshot()
|
||||
|
||||
# TODO(kkb): `{'builtins.weakref': 1}` is unexpected, locate and fix it.
|
||||
memory_checker.assert_no_new_python_objects(
|
||||
threshold={'builtins.weakref': 1})
|
||||
|
||||
def testNewPythonObjects(self):
|
||||
with MemoryChecker() as memory_checker:
|
||||
memory_checker.record_snapshot()
|
||||
x = constant_op.constant(1) # pylint: disable=unused-variable
|
||||
memory_checker.record_snapshot()
|
||||
|
||||
with self.assertRaisesRegexp(AssertionError, 'New Python objects'):
|
||||
memory_checker.assert_no_new_python_objects()
|
||||
|
||||
def testNewPythonObjectBelowThreshold(self):
|
||||
|
||||
class Foo(object):
|
||||
pass
|
||||
|
||||
with MemoryChecker() as memory_checker:
|
||||
memory_checker.record_snapshot()
|
||||
foo = Foo() # pylint: disable=unused-variable
|
||||
memory_checker.record_snapshot()
|
||||
|
||||
# TODO(kkb): `{'builtins.weakref': 1}` is unexpected, locate and fix it.
|
||||
memory_checker.assert_no_new_python_objects(threshold={
|
||||
'__main__.Foo': 1,
|
||||
'builtins.weakref': 1
|
||||
})
|
||||
memory_checker.assert_no_new_python_objects(threshold={
|
||||
'__main__.Foo': 2,
|
||||
'builtins.weakref': 1
|
||||
})
|
||||
|
||||
def testKerasBasic(self):
|
||||
# TODO(kkb): Fix the the slowness on Forge.
|
||||
self.skipTest('This test is too slow on Forge so disabled for now.')
|
||||
|
||||
x = array_ops.zeros([1, 1])
|
||||
y = constant_op.constant([[3]])
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(1, input_dim=1))
|
||||
model.compile(loss='mean_squared_error')
|
||||
|
||||
with MemoryChecker() as memory_checker:
|
||||
for _ in range(10):
|
||||
model.fit(x, y)
|
||||
model.evaluate(x, y)
|
||||
memory_checker.record_snapshot()
|
||||
|
||||
memory_checker.report()
|
||||
memory_checker.assert_no_leak_if_all_possibly_except_one()
|
||||
|
||||
def testKerasAdvanced(self):
|
||||
# TODO(kkb): Fix the the slowness on Forge.
|
||||
self.skipTest('This test is too slow on Forge so disabled for now.')
|
||||
|
||||
# A real world example taken from the following.
|
||||
# https://github.com/tensorflow/tensorflow/issues/32500
|
||||
# b/142150794
|
||||
|
||||
with MemoryChecker() as memory_checker:
|
||||
rows = 6
|
||||
columns = 7
|
||||
model = keras.Sequential([
|
||||
keras.layers.Flatten(input_shape=[rows * columns, 3]),
|
||||
keras.layers.Dense(7, input_shape=[rows * columns * 3]),
|
||||
])
|
||||
|
||||
model.compile(
|
||||
optimizer=keras.optimizer_v2.gradient_descent.SGD(lr=0.01),
|
||||
loss='mean_squared_error',
|
||||
metrics=['accuracy'])
|
||||
states = [[1] * rows * columns for _ in range(20)]
|
||||
f = array_ops.one_hot(states, dtype='float32', depth=3)
|
||||
|
||||
for _ in range(20):
|
||||
model.predict(f, steps=10)
|
||||
memory_checker.record_snapshot()
|
||||
|
||||
memory_checker.report()
|
||||
memory_checker.assert_no_leak_if_all_possibly_except_one()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
152
tensorflow/python/framework/python_memory_checker.py
Normal file
152
tensorflow/python/framework/python_memory_checker.py
Normal file
@ -0,0 +1,152 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Python memory leak detection utility.
|
||||
|
||||
Please don't use this class directly. Instead, use `MemoryChecker` wrapper.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import gc
|
||||
|
||||
from tensorflow.python import _python_memory_checker_helper
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.profiler.traceme import traceme_wrapper
|
||||
|
||||
|
||||
def _get_typename(obj):
|
||||
"""Return human readable pretty type name string."""
|
||||
objtype = type(obj)
|
||||
name = objtype.__name__
|
||||
module = getattr(objtype, '__module__', None)
|
||||
if module:
|
||||
return '{}.{}'.format(module, name)
|
||||
else:
|
||||
return name
|
||||
|
||||
|
||||
def _create_python_object_snapshot():
|
||||
gc.collect()
|
||||
all_objects = gc.get_objects()
|
||||
result = collections.defaultdict(set)
|
||||
for obj in all_objects:
|
||||
result[_get_typename(obj)].add(id(obj))
|
||||
return result
|
||||
|
||||
|
||||
def _snapshot_diff(old_snapshot, new_snapshot, exclude_ids):
|
||||
result = collections.Counter()
|
||||
for new_name, new_ids in new_snapshot.items():
|
||||
old_ids = old_snapshot[new_name]
|
||||
result[new_name] = len(new_ids - exclude_ids) - len(old_ids - exclude_ids)
|
||||
|
||||
# This removes zero or negative value entries.
|
||||
result += collections.Counter()
|
||||
return result
|
||||
|
||||
|
||||
class _PythonMemoryChecker(object):
|
||||
"""Python memory leak detection class."""
|
||||
|
||||
def __init__(self):
|
||||
self._snapshots = []
|
||||
|
||||
@traceme_wrapper
|
||||
def record_snapshot(self):
|
||||
# Function called using `mark_stack_trace_and_call` will have
|
||||
# "_python_memory_checker_helper" string in the C++ stack trace. This will
|
||||
# be used to filter out C++ memory allocations caused by this function,
|
||||
# because we are not interested in detecting memory growth caused by memory
|
||||
# checker itself.
|
||||
_python_memory_checker_helper.mark_stack_trace_and_call(
|
||||
lambda: self._snapshots.append(_create_python_object_snapshot()))
|
||||
|
||||
@traceme_wrapper
|
||||
def report(self):
|
||||
# TODO(kkb): Implement.
|
||||
pass
|
||||
|
||||
@traceme_wrapper
|
||||
def assert_no_leak_if_all_possibly_except_one(self):
|
||||
"""Raises an exception if a leak is detected.
|
||||
|
||||
This algorithm classifies a series of allocations as a leak if it's the same
|
||||
type at every snapshot, but possibly except one snapshot.
|
||||
"""
|
||||
|
||||
snapshot_diffs = []
|
||||
for i in range(0, len(self._snapshots) - 1):
|
||||
snapshot_diffs.append(self._snapshot_diff(i, i + 1))
|
||||
|
||||
allocation_counter = collections.Counter()
|
||||
for diff in snapshot_diffs:
|
||||
for name, count in diff.items():
|
||||
if count > 0:
|
||||
allocation_counter[name] += 1
|
||||
|
||||
leaking_object_names = {
|
||||
name for name, count in allocation_counter.items()
|
||||
if count >= len(snapshot_diffs) - 1
|
||||
}
|
||||
|
||||
if leaking_object_names:
|
||||
object_list_to_print = '\n'.join(
|
||||
[' - ' + name for name in leaking_object_names])
|
||||
raise AssertionError(
|
||||
'These Python objects were allocated every snapshot '
|
||||
'possibly except one.\n\n{}'.format(object_list_to_print))
|
||||
|
||||
@traceme_wrapper
|
||||
def assert_no_new_objects(self, threshold=None):
|
||||
"""Assert no new Python objects."""
|
||||
|
||||
if not threshold:
|
||||
threshold = {}
|
||||
|
||||
count_diff = self._snapshot_diff(0, -1)
|
||||
original_count_diff = copy.deepcopy(count_diff)
|
||||
count_diff.subtract(collections.Counter(threshold))
|
||||
|
||||
if max(count_diff.values() or [0]) > 0:
|
||||
raise AssertionError('New Python objects exceeded the threshold.\n'
|
||||
'Python object threshold:\n'
|
||||
'{}\n\n'
|
||||
'New Python objects:\n{}'.format(
|
||||
threshold, original_count_diff.most_common()))
|
||||
elif min(count_diff.values(), default=0) < 0:
|
||||
logging.warning('New Python objects were less than the threshold.\n'
|
||||
'Python object threshold:\n'
|
||||
'{}\n\n'
|
||||
'New Python objects:\n'
|
||||
'{}'.format(threshold, original_count_diff.most_common()))
|
||||
|
||||
@traceme_wrapper
|
||||
def _snapshot_diff(self, old_index, new_index):
|
||||
return _snapshot_diff(self._snapshots[old_index],
|
||||
self._snapshots[new_index],
|
||||
self._get_internal_object_ids())
|
||||
|
||||
@traceme_wrapper
|
||||
def _get_internal_object_ids(self):
|
||||
ids = set()
|
||||
for snapshot in self._snapshots:
|
||||
ids.add(id(snapshot))
|
||||
for v in snapshot.values():
|
||||
ids.add(id(v))
|
||||
return ids
|
22
tensorflow/python/framework/python_memory_checker_helper.cc
Normal file
22
tensorflow/python/framework/python_memory_checker_helper.cc
Normal file
@ -0,0 +1,22 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");;
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "include/pybind11/pybind11.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
PYBIND11_MODULE(_python_memory_checker_helper, m) {
|
||||
m.def("mark_stack_trace_and_call", [](py::function func) { func(); });
|
||||
};
|
@ -50,7 +50,10 @@ class TraceMe(object):
|
||||
|
||||
|
||||
def traceme_wrapper(func):
|
||||
name = func.__qualname__
|
||||
name = getattr(func, '__qualname__', None)
|
||||
if not name:
|
||||
name = func.__name__
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
with TraceMe(name):
|
||||
return func(*args, **kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user