Move more of the stack trace mappers and filters to C++. Simplify the API surface of extract_stack, and clean up its docstring.

PiperOrigin-RevId: 348459279
Change-Id: I1ae8afde96a7220df5fb6a95ace62f7c940a1e71
This commit is contained in:
Dan Moldovan 2020-12-21 06:38:42 -08:00 committed by TensorFlower Gardener
parent ad8083afac
commit 9b83b647b9
8 changed files with 209 additions and 176 deletions

View File

@ -1591,6 +1591,7 @@ py_library(
deps = [
":platform",
"//tensorflow/python/util",
# TODO(mdan): Remove this once the transitive dependency is fixed.
"//tensorflow/python/util:tf_stack",
],
)

View File

@ -167,31 +167,35 @@ class StackTraceMapper(tf_stack.StackTraceMapper):
"""Remaps generated code to code it originated from."""
def __init__(self, converted_fn):
super().__init__()
self._source_map = converted_fn.ag_source_map
# This may be called repeatedly: once on entry, by the superclass, then by
# each child context manager.
self._cached_map = None
def get_effective_source_map(self):
effective_source_map = self._effective_source_map
if effective_source_map is None:
if self.parent is not None:
parent_map = self.parent.get_effective_source_map()
if self._cached_map is not None:
return self._cached_map
parent_map = self.parent.get_effective_source_map()
effective_source_map = {}
for loc, origin in self._source_map.items():
effective_source_map[(loc.filename, loc.lineno)] = (origin.loc.filename,
origin.loc.lineno,
origin.function_name)
for key, value in parent_map.items():
filename, lineno, _ = value
value_loc = origin_info.LineLocation(filename=filename, lineno=lineno)
if value_loc in self._source_map:
origin = self._source_map[value_loc]
effective_source_map[key] = (origin.loc.filename, origin.loc.lineno,
origin.function_name)
else:
parent_map = {}
effective_source_map[key] = value
effective_source_map = {}
for loc, origin in self._source_map.items():
effective_source_map[(loc.filename, loc.lineno)] = (
origin.loc.filename, origin.loc.lineno, origin.function_name)
for key, value in parent_map.items():
filename, lineno, _ = value
value_loc = origin_info.LineLocation(filename=filename, lineno=lineno)
if value_loc in self._source_map:
origin = self._source_map[value_loc]
effective_source_map[key] = (
origin.loc.filename, origin.loc.lineno, origin.function_name)
else:
effective_source_map[key] = value
self._effective_source_map = effective_source_map
self._cached_map = effective_source_map
return effective_source_map

View File

@ -23,9 +23,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import traceback
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import tf_stack
# Registry mechanism below is based on mapreduce.python.mrpython.Register.
@ -65,8 +66,8 @@ class Registry(object):
logging.vlog(1, "Registering %s (%s) in %s.", name, candidate, self._name)
# stack trace is [this_function, Register(), user_function,...]
# so the user function is #2.
stack = tf_stack.extract_stack(limit=3)
stack_index = min(2, len(stack)-1)
stack = traceback.extract_stack(limit=3)
stack_index = min(2, len(stack) - 1)
if stack_index >= 0:
location_tag = stack[stack_index]
else:

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.util import tf_stack
import inspect
class TraceableObject(object):
@ -51,26 +51,20 @@ class TraceableObject(object):
TraceableObject.HEURISTIC_USED if the offset was larger than the stack,
and TraceableObject.FAILURE if the stack was empty.
"""
# Offset is defined in "Args" as relative to the caller. We are one frame
retcode = self.SUCCESS
frame = inspect.currentframe()
# Offset is defined in "Args" as relative to the caller. We are one frame
# beyond the caller.
local_offset = offset + 1
frame_records = tf_stack.extract_stack(
limit=local_offset + 1)
if not frame_records:
return self.FAILURE
if len(frame_records) > local_offset:
frame = frame_records[len(frame_records) - (local_offset + 1)]
self.filename = frame.filename
self.lineno = frame.lineno
return self.SUCCESS
else:
# If the offset is too large then we use the largest offset possible,
# meaning we use the outermost stack frame at index 0.
frame = frame_records[0]
self.filename = frame.filename
self.lineno = frame.lineno
return self.HEURISTIC_USED
for _ in range(offset + 1):
parent = frame.f_back
if parent is None:
# If the offset is too large then we use the largest offset possible.
retcode = self.HEURISTIC_USED
break
frame = parent
self.filename = frame.f_code.co_filename
self.lineno = frame.f_lineno
return retcode
def copy_metadata(self):
"""Return a TraceableObject like this one, but without the object."""

View File

@ -19,12 +19,12 @@ from __future__ import division
from __future__ import print_function
import importlib
import inspect
import types
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util import tf_stack
from tensorflow.tools.compatibility import all_renames_v2
@ -41,11 +41,12 @@ def _call_location():
# We want to get stack frame 3 frames up from current frame,
# i.e. above __getattr__, _tfmw_add_deprecation_warning,
# and _call_location calls.
stack = tf_stack.extract_stack(limit=4)
if not stack: # should never happen as we're in a function
return 'UNKNOWN'
frame = stack[0]
return '{}:{}'.format(frame.filename, frame.lineno)
frame = inspect.currentframe()
for _ in range(4):
parent = frame.f_back
if parent is None:
break
return '{}:{}'.format(frame.f_code.co_filename, frame.f_lineno)
def contains_deprecation_decorator(decorators):

View File

@ -54,6 +54,30 @@ namespace {
namespace py = pybind11;
using SourceLoc = std::tuple<std::string, int>;
using SourceMap = absl::flat_hash_map<SourceLoc, StackFrame>;
using StringSet = absl::flat_hash_set<std::string>;
// Python wrapper for a SourceMap.
//
// A SourceMap is keyed by SourceLoc, a (filename, lineno) tuple, and maps each
// location to the StackFrame that should be reported in its place. The map is
// held through a shared_ptr so it can be handed to other holders without being
// copied.
class PyBindSourceMap {
public:
// Starts out owning a freshly allocated, empty map.
PyBindSourceMap() : source_map_(std::make_shared<SourceMap>()) {}
// Shares ownership with whoever captures traces in the scope of this map.
std::shared_ptr<SourceMap> source_map_;
};
// Python wrapper for a set of file names (a StringSet).
//
// The set is held through a shared_ptr so it can be handed to other holders
// without being copied.
class PyBindFileSet {
public:
// Starts out owning a freshly allocated, empty set.
PyBindFileSet() : file_set_(std::make_shared<StringSet>()) {}
// Shares ownership with whoever captures traces in the scope of this set.
std::shared_ptr<StringSet> file_set_;
};
// Returns contents of the line corresponding to the given frame.
//
// Precondition: must be holding Python GIL.
@ -98,36 +122,21 @@ std::string StackFrameToString(
class StackTraceWrapper : public AbstractStackTrace {
public:
StackTraceWrapper(StackTrace&& captured, const py::dict& source_map,
const py::set& filtered_filenames)
StackTraceWrapper(StackTrace&& captured,
const std::shared_ptr<SourceMap>& source_map,
const std::shared_ptr<StringSet>& filter)
: captured_(std::move(captured)),
source_map_(source_map),
filtered_filenames_(filtered_filenames) {}
filter_(filter) {}
explicit StackTraceWrapper(absl::Span<StackFrame const> stack_frames)
: stack_frames_cache_(std::vector<StackFrame>(stack_frames.begin(),
stack_frames.end())) {}
static StackTraceWrapper ExtractStack(const py::object& limit,
const py::list& mappers,
const py::list& filters) {
// In Python 3.X ``traceback.extract_stack`` allows ``limit`` to
// either be None or -1.
int casted_limit = limit.is_none() ? -1 : py::cast<ssize_t>(limit);
// Raise limit by one since we are dropping the last frame.
if (casted_limit != -1) casted_limit++;
const py::dict& source_map =
mappers.empty()
? py::dict()
: mappers[mappers.size() - 1].attr("get_effective_source_map")();
const py::set& filtered_filenames =
filters.empty()
? py::set()
: filters[filters.size() - 1].attr("get_filtered_filenames")();
return StackTraceWrapper{StackTrace::Capture(casted_limit), source_map,
filtered_filenames};
static StackTraceWrapper ExtractStack(
const std::shared_ptr<SourceMap>& source_map,
const std::shared_ptr<StringSet>& filter) {
return StackTraceWrapper{StackTrace::Capture(-1), source_map, filter};
}
absl::Span<StackFrame const> ToFrames() const override {
@ -169,6 +178,7 @@ class StackTraceWrapper : public AbstractStackTrace {
bool IsCacheGenerated() const { return stack_frames_cache_.has_value(); }
void GenerateCache() const {
// TODO(mdan): We don't really need random access; this can be removed.
if (stack_frames_cache_) {
return;
}
@ -214,7 +224,7 @@ class StackTraceWrapper : public AbstractStackTrace {
PyGILState_STATE state = PyGILState_Ensure();
captured_.Clear();
source_map_.reset();
filtered_filenames_.reset();
filter_.reset();
PyGILState_Release(state);
}
@ -237,33 +247,23 @@ class StackTraceWrapper : public AbstractStackTrace {
!absl::StrContains(file_name, "test.py");
}
absl::optional<StackFrame> StackTraceMapping(
std::pair<const char*, int> p) const {
if (source_map_->empty()) {
return absl::nullopt;
}
auto key = py::make_tuple(py::str(p.first), py::int_(p.second));
if (source_map_->contains(key)) {
const py::tuple& value = (*source_map_)[key];
return StackFrame{std::string(py::cast<py::str>(value[0])),
py::cast<py::int_>(value[1]),
std::string(py::cast<py::str>(value[2]))};
absl::optional<StackFrame> StackTraceMapping(SourceLoc loc) const {
if (source_map_->contains(loc)) {
return source_map_->at(loc);
}
return absl::nullopt;
}
bool StackTraceFiltering(const char* file_name) const {
return filtered_filenames_->contains(file_name);
return filter_->contains(file_name);
}
StackTrace captured_;
// Using optional to force destruction while we hold a GIL.
absl::optional<py::dict> source_map_;
absl::optional<py::set> filtered_filenames_;
std::shared_ptr<SourceMap> source_map_;
std::shared_ptr<StringSet> filter_;
// Using optional to force destruction while we hold a GIL.
mutable absl::optional<std::vector<StackFrame>> stack_frames_cache_;
mutable absl::optional<StackFrame> last_stack_frame_cache_;
};
@ -271,6 +271,42 @@ class StackTraceWrapper : public AbstractStackTrace {
} // namespace
PYBIND11_MODULE(_tf_stack, m) {
py::class_<PyBindSourceMap>(m, "PyBindSourceMap")
.def(py::init())
.def("update_to",
[](const PyBindSourceMap& self, const py::tuple& source_map) {
self.source_map_->clear();
for (const auto& item : source_map) {
const auto& tuple_item = py::cast<py::tuple>(item);
const auto& key = py::cast<py::tuple>(tuple_item[0]);
std::string&& k_filename = py::cast<std::string>(key[0]);
int k_lineno = py::cast<int>(key[1]);
const auto& value = py::cast<py::tuple>(tuple_item[1]);
std::string&& v_filename = py::cast<std::string>(value[0]);
int v_lineno = py::cast<int>(value[1]);
const auto& function_name_val = value[2];
std::string&& v_function_name =
function_name_val.is_none()
? ""
: py::cast<std::string>(function_name_val);
self.source_map_->emplace(
SourceLoc(k_filename, k_lineno),
StackFrame({v_filename, v_lineno, v_function_name}));
}
});
py::class_<PyBindFileSet>(m, "PyBindFileSet")
.def(py::init())
.def("update_to", [](const PyBindFileSet& self, const py::set& file_set) {
self.file_set_->clear();
for (const auto& item : file_set) {
self.file_set_->insert(py::cast<std::string>(item));
}
});
py::class_<StackFrame>(m, "StackFrame")
.def_property_readonly(
"filename",
@ -369,22 +405,22 @@ PYBIND11_MODULE(_tf_stack, m) {
m.def(
"extract_stack_for_node",
[](const py::object& limit, const py::list& mappers,
const py::list& filters,
[](const PyBindSourceMap& source_map, const PyBindFileSet& file_set,
TF_Operation* op) -> const AbstractStackTrace& {
Node* node = reinterpret_cast<Node*>(op);
DCHECK(!node->GetStackTrace()) << "Should not reset the stack trace";
node->SetStackTrace(std::make_shared<StackTraceWrapper>(
StackTraceWrapper::ExtractStack(limit, mappers, filters)));
node->SetStackTrace(
std::make_shared<StackTraceWrapper>(StackTraceWrapper::ExtractStack(
source_map.source_map_, file_set.file_set_)));
return *node->GetStackTrace();
},
py::return_value_policy::reference);
m.def(
"extract_stack",
[](const py::object& limit, const py::list& mappers,
const py::list& filters) {
return StackTraceWrapper::ExtractStack(limit, mappers, filters);
[](const PyBindSourceMap& source_map, const PyBindFileSet& file_set) {
return StackTraceWrapper::ExtractStack(source_map.source_map_,
file_set.file_set_);
},
py::return_value_policy::move);
}

View File

@ -40,8 +40,10 @@ else:
_get_thread_key = threading.get_ident
_source_mapper_stacks = collections.defaultdict(list)
_source_filter_stacks = collections.defaultdict(list)
# TODO(mdan): Move these to C++ as well.
# Moving to C++ can further avoid extra copies made by get_effective_map.
_source_mapper_stacks = collections.defaultdict(lambda: [SentinelMapper()])
_source_filter_stacks = collections.defaultdict(lambda: [SentinelFilter()])
class StackTraceTransform(object):
@ -51,8 +53,6 @@ class StackTraceTransform(object):
_thread_key = None
def __enter__(self):
self.reset()
# Any given instance is assumed to be used by a single thread, which reduces
# expensive thread local lookups.
if self._thread_key is None:
@ -61,48 +61,71 @@ class StackTraceTransform(object):
assert self._thread_key == _get_thread_key(), 'Shared across threads?'
stack = self._stack_dict[self._thread_key]
if stack:
self.parent = stack[-1]
else:
self.parent = None
self.parent = stack[-1]
stack.append(self)
self.update()
return self
def __exit__(self, unused_type, unused_value, unused_traceback):
top = self._stack_dict[self._thread_key].pop()
assert top is self, 'Concurrent access?'
def reset(self):
pass
def update(self):
raise NotImplementedError('subclasses need to override this')
class StackTraceMapper(StackTraceTransform):
"""Allows remapping traceback information to different source code."""
_stack_dict = _source_mapper_stacks
def reset(self):
self._effective_source_map = None
def __init__(self):
self.internal_map = _tf_stack.PyBindSourceMap()
def update(self):
self.internal_map.update_to(tuple(self.get_effective_source_map().items()))
def get_effective_source_map(self):
"""Returns a map (filename, lineno) -> (filename, lineno, function_name)."""
raise NotImplementedError('subclasses need to override this')
EMPTY_DICT = {}
# Mapper that contributes no remappings: its effective source map is the shared
# empty dict. One instance seeds every per-thread mapper stack, so a mapper
# entered on that stack always finds a parent to consult.
class SentinelMapper(StackTraceMapper):
def get_effective_source_map(self):
return EMPTY_DICT
class StackTraceFilter(StackTraceTransform):
"""Allows filtering traceback information by removing superfluous frames."""
_stack_dict = _source_filter_stacks
def reset(self):
self._filtered_filenames = None
def __init__(self):
self.internal_set = _tf_stack.PyBindFileSet()
def update(self):
self.internal_set.update_to(set(self.get_filtered_filenames()))
def get_filtered_filenames(self):
raise NotImplementedError('subclasses need to override this')
EMPTY_SET = frozenset()
# Filter that excludes no files: its filtered-filename set is the shared empty
# frozenset. One instance seeds every per-thread filter stack, so a filter
# entered on that stack always finds a parent to consult.
class SentinelFilter(StackTraceFilter):
def get_filtered_filenames(self):
return EMPTY_SET
class CurrentModuleFilter(StackTraceFilter):
"""Filters stack frames from the module where this is used (best effort)."""
def __init__(self):
super().__init__()
filter_filename = None
outer_f = None
f = inspect.currentframe()
@ -114,6 +137,9 @@ class CurrentModuleFilter(StackTraceFilter):
if outer_f is not None:
filter_filename = inspect.getsourcefile(outer_f)
self._filename = filter_filename
# This may be called repeatedly: once on entry by the superclass, then by
# each child context manager.
self._cached_set = None
finally:
# Avoid reference cycles, see:
# https://docs.python.org/3.7/library/inspect.html#the-interpreter-stack
@ -121,58 +147,52 @@ class CurrentModuleFilter(StackTraceFilter):
del outer_f
def get_filtered_filenames(self):
if self._filtered_filenames is None:
self._filtered_filenames = frozenset((self._filename,))
if self.parent is not None:
self._filtered_filenames |= self.parent.get_filtered_filenames()
return self._filtered_filenames
if self._cached_set is not None:
return self._cached_set
filtered_filenames = frozenset((self._filename,))
if self.parent is not None:
filtered_filenames |= self.parent.get_filtered_filenames()
self._cached_set = filtered_filenames
return filtered_filenames
def extract_stack(limit=-1):
"""A lightweight, extensible re-implementation of traceback.extract_stack.
NOTE(mrry): traceback.extract_stack eagerly retrieves the line of code for
each stack frame using linecache, which results in an abundance of stat()
calls. This implementation does not retrieve the code, and any consumer
should apply _convert_stack to the result to obtain a traceback that can
be formatted etc. using traceback methods.
Args:
limit: A limit on the number of frames to return.
def extract_stack():
"""An eager-friendly alternative to traceback.extract_stack.
Returns:
An object wrapping the sequence of StackFrame objects (filename, lineno,
name, line) corresponding to the call stack of the current thread. The
returned object can be indexed as a Python list.
A list-like FrameSummary containing StackFrame-like objects, which are
namedtuple-like objects with the following fields: filename, lineno, name,
line, meant to masquerade as traceback.FrameSummary objects.
"""
# N.B ExtractStack in tf_stack.cc will drop this frame prior to
# traversing the stack.
# TODO(cheshire): Remove this function, use extract_stack_for_node or Python
# traceback module.
thread_key = _get_thread_key()
return _tf_stack.extract_stack(limit, _source_mapper_stacks[thread_key],
_source_filter_stacks[thread_key])
return _tf_stack.extract_stack(
_source_mapper_stacks[thread_key][-1].internal_map,
_source_filter_stacks[thread_key][-1].internal_set)
def extract_stack_for_node(node, limit=-1):
"""Same as extract_stack, but also saves the retrieved stack in `node`.
# TODO(mdan): Revisit these - a single location is almost always sufficient.
def extract_stack_for_node(node):
"""Attaches the current stack trace to `node`.
Args:
node: Pointer to the Node object.
limit: A limit on the number of frames to return.
node: a Node object.
Returns:
An object wrapping the sequence of StackFrame objects (filename, lineno,
name, line) corresponding to the call stack of the current thread. The
returned object can be indexed as a Python list.
A list-like FrameSummary containing StackFrame-like objects, which are
namedtuple-like objects with the following fields: filename, lineno, name,
line, meant to masquerade as traceback.FrameSummary objects.
"""
# N.B ExtractStack in tf_stack.cc will drop this frame prior to
# traversing the stack.
thread_key = _get_thread_key()
return _tf_stack.extract_stack_for_node(limit,
_source_mapper_stacks[thread_key],
_source_filter_stacks[thread_key],
node)
return _tf_stack.extract_stack_for_node(
_source_mapper_stacks[thread_key][-1].internal_map,
_source_filter_stacks[thread_key][-1].internal_set, node)
StackSummary = _tf_stack.StackTraceWrapper

View File

@ -26,31 +26,19 @@ from tensorflow.python.util import tf_stack
class TFStackTest(test.TestCase):
def testLimit(self):
self.assertEmpty(tf_stack.extract_stack(limit=0))
self.assertLen(tf_stack.extract_stack(limit=1), 1)
def testFormatStackSelfConsistency(self):
# Both defined on the same line to produce identical stacks.
stacks = tf_stack.extract_stack(), traceback.extract_stack()
self.assertEqual(
len(tf_stack.extract_stack(limit=-1)),
len(tf_stack.extract_stack()))
def testConsistencyWithTraceback(self):
stack, expected_stack = extract_stack()
for frame, expected in zip(stack, expected_stack):
self.assertEqual(convert_stack_frame(frame), expected)
def testFormatStack(self):
stack, expected_stack = extract_stack()
self.assertEqual(
traceback.format_list(stack),
traceback.format_list(expected_stack))
traceback.format_list(stacks[0]), traceback.format_list(stacks[1]))
def testFrameSummaryEquality(self):
frame0, frame1 = tf_stack.extract_stack(limit=2)
self.assertNotEqual(frame0, frame1)
self.assertEqual(frame0, frame0)
frames1 = tf_stack.extract_stack()
frames2 = tf_stack.extract_stack()
another_frame0, _ = tf_stack.extract_stack(limit=2)
self.assertEqual(frame0, another_frame0)
self.assertNotEqual(frames1[0], frames1[1])
self.assertEqual(frames1[0], frames1[0])
self.assertEqual(frames1[0], frames2[0])
def testFrameSummaryEqualityAndHash(self):
# Both defined on the same line to produce identical stacks.
@ -74,17 +62,5 @@ def extract_stack(limit=None):
return tf_stack.extract_stack(limit), traceback.extract_stack(limit)
def convert_stack_frame(frame):
"""Converts a TF stack frame into Python's."""
# TODO(mihaimaruseac): Remove except case when dropping support for py2
try:
return traceback.FrameSummary(
frame.filename, frame.lineno, frame.name, line=frame.line)
except AttributeError:
# On Python < 3.5 (i.e., Python2), we don't have traceback.FrameSummary so
# we don't need to match with that class. Instead, just a tuple is enough.
return tuple(frame)
if __name__ == "__main__":
test.main()