[NFC] Optimize getting the last stack frame + stack frame conversion
Firstly, use callbacks instead of C++ data structures to avoid expensive conversion + rehashing. Secondly, when returning the last frame, do not needlessly traverse the entire stack trace. PiperOrigin-RevId: 346937177 Change-Id: Ib13d0bf0079e2cc33a31bf2d4621b1c5b8e5b5da
This commit is contained in:
parent
43a080a530
commit
20423e72df
tensorflow
core
python/util
@ -401,9 +401,7 @@ TEST_F(FunctionLibraryRuntimeTest, InstantiationStackTraceCopying) {
|
||||
return "DummyStackTrace";
|
||||
}
|
||||
|
||||
absl::optional<StackFrame> LastUserFrame() const override {
|
||||
return absl::nullopt;
|
||||
}
|
||||
StackFrame LastUserFrame() const override { return StackFrame{}; }
|
||||
};
|
||||
|
||||
FunctionDef func = test::function::XTimesTwo();
|
||||
|
@ -350,8 +350,8 @@ class AbstractStackTrace {
|
||||
virtual absl::Span<StackFrame const> ToFrames() const = 0;
|
||||
|
||||
// Returns the last stack frame from user code, attempting to ignore the
|
||||
// framework code. Returns an empty optional if no such stack frame was found.
|
||||
virtual absl::optional<StackFrame> LastUserFrame() const = 0;
|
||||
// framework code. Returns an empty frame if no such stack frame was found.
|
||||
virtual StackFrame LastUserFrame() const = 0;
|
||||
virtual std::string ToString(const TracePrintingOptions& opts) const = 0;
|
||||
};
|
||||
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/python/util/stack_trace.h"
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "tensorflow/core/platform/str_util.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
|
||||
@ -40,29 +42,38 @@ const char* GetPythonString(PyObject* o) {
|
||||
namespace tensorflow {
|
||||
|
||||
std::vector<StackFrame> StackTrace::ToStackFrames(
|
||||
const StackTraceMap& mapper, const StackTraceFilter& filtered) const {
|
||||
const StackTraceMap& mapper, const StackTraceFilter& filtered,
|
||||
bool reverse_traversal, int limit) const {
|
||||
DCheckPyGilStateForStackTrace();
|
||||
std::vector<StackFrame> result;
|
||||
result.reserve(code_objs_.size());
|
||||
|
||||
for (int i = code_objs_.size() - 1; i >= 0; --i) {
|
||||
const std::pair<PyCodeObject*, int>& code_obj = code_objs_[i];
|
||||
const char* file_name = GetPythonString(code_obj.first->co_filename);
|
||||
const int line_number =
|
||||
PyCode_Addr2Line(code_objs_[i].first, code_obj.second);
|
||||
if (limit == -1) limit = std::numeric_limits<int>::max();
|
||||
|
||||
if (!result.empty() && filtered.count(file_name)) {
|
||||
continue; // Never filter the innermost frame.
|
||||
for (int i = 0; i < code_objs_.size(); i++) {
|
||||
int idx = reverse_traversal ? i : code_objs_.size() - 1 - i;
|
||||
|
||||
const std::pair<PyCodeObject*, int>& code_obj = code_objs_[idx];
|
||||
const char* file_name = GetPythonString(code_obj.first->co_filename);
|
||||
const int line_number = PyCode_Addr2Line(code_obj.first, code_obj.second);
|
||||
|
||||
if (filtered && filtered(file_name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto it = mapper.find(std::make_pair(file_name, line_number));
|
||||
absl::optional<StackFrame> mapped =
|
||||
mapper ? mapper(std::make_pair(file_name, line_number)) : absl::nullopt;
|
||||
|
||||
if (it != mapper.end()) {
|
||||
result.push_back(it->second);
|
||||
if (mapped) {
|
||||
result.push_back(*mapped);
|
||||
} else {
|
||||
result.emplace_back(StackFrame{file_name, line_number,
|
||||
GetPythonString(code_obj.first->co_name)});
|
||||
}
|
||||
|
||||
if (result.size() == limit) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
|
@ -44,10 +44,10 @@ inline void DCheckPyGilStateForStackTrace() {
|
||||
|
||||
// Maps filename/line_no combination into a stack frame.
|
||||
using StackTraceMap =
|
||||
absl::flat_hash_map<std::pair<std::string, int>, StackFrame>;
|
||||
std::function<absl::optional<StackFrame>(std::pair<const char*, int>)>;
|
||||
|
||||
// Contains filenames which should be skipped.
|
||||
using StackTraceFilter = absl::flat_hash_set<std::string>;
|
||||
// Returns "true" on filenames which should be skipped.
|
||||
using StackTraceFilter = std::function<bool(const char*)>;
|
||||
|
||||
// A class for capturing Python stack trace.
|
||||
class StackTrace final {
|
||||
@ -95,11 +95,14 @@ class StackTrace final {
|
||||
|
||||
// Returns a structured representation of the captured stack trace.
|
||||
// `mapper` provides a custom mapping for translating stack frames, `filter`
|
||||
// returns `true` for the stack frames which should be omitted, and if
|
||||
// `drop_last` is set, the last stack frame is dropped.
|
||||
std::vector<StackFrame> ToStackFrames(
|
||||
const StackTraceMap& mapper = {},
|
||||
const StackTraceFilter& filtered = {}) const;
|
||||
// returns `true` for the stack frames which should be omitted.
|
||||
//
|
||||
// `reverse_traversal` changes the traversal order of the stack trace, and
|
||||
// `limit` bounds the number of returned frames (after filtering).
|
||||
std::vector<StackFrame> ToStackFrames(const StackTraceMap& mapper = {},
|
||||
const StackTraceFilter& filtered = {},
|
||||
bool reverse_traversal = false,
|
||||
int limit = -1) const;
|
||||
|
||||
// Python GIL must be acquired beforehand.
|
||||
ABSL_ATTRIBUTE_HOT
|
||||
|
@ -135,15 +135,9 @@ class StackTraceWrapper : public AbstractStackTrace {
|
||||
return *stack_frames_cache_;
|
||||
}
|
||||
|
||||
absl::optional<StackFrame> LastUserFrame() const override {
|
||||
GenerateCache();
|
||||
for (int i = stack_frames_cache_->size() - 1; i >= 0; i--) {
|
||||
const StackFrame& frame = stack_frames_cache_->at(i);
|
||||
if (!IsInternalFrame(frame)) {
|
||||
return frame;
|
||||
}
|
||||
}
|
||||
return absl::nullopt;
|
||||
StackFrame LastUserFrame() const override {
|
||||
GenerateLastFrameCache();
|
||||
return *last_stack_frame_cache_;
|
||||
}
|
||||
|
||||
std::string ToString(const TracePrintingOptions& opts) const override {
|
||||
@ -165,7 +159,7 @@ class StackTraceWrapper : public AbstractStackTrace {
|
||||
|
||||
std::vector<StackFrame> filtered_frames;
|
||||
for (const StackFrame& frame : *stack_frames_cache_) {
|
||||
if (!IsInternalFrame(frame)) {
|
||||
if (!IsInternalFrameForFilename(frame.file_name)) {
|
||||
filtered_frames.push_back(frame);
|
||||
}
|
||||
}
|
||||
@ -175,36 +169,46 @@ class StackTraceWrapper : public AbstractStackTrace {
|
||||
bool IsCacheGenerated() const { return stack_frames_cache_.has_value(); }
|
||||
|
||||
void GenerateCache() const {
|
||||
// Grabbing the GIL solves two purposes: 1) makes the class thread-safe, and
|
||||
// 2) ToStackFrames and LineContents actually need it.
|
||||
if (stack_frames_cache_) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Grabbing the GIL solves two purposes: 1) makes the class thread-safe, and
|
||||
// 2) ToStackFrames and LineContents actually need it.
|
||||
PyGILState_STATE state = PyGILState_Ensure();
|
||||
absl::flat_hash_map<std::pair<std::string, int>, StackFrame> m;
|
||||
absl::flat_hash_set<std::string> f;
|
||||
|
||||
for (const std::pair<py::handle, py::handle>& p : *source_map_) {
|
||||
const py::tuple& key = py::cast<py::tuple>(p.first);
|
||||
const py::tuple& value = py::cast<py::tuple>(p.second);
|
||||
|
||||
m.emplace(std::make_pair(std::string(py::cast<py::str>(key[0])),
|
||||
py::cast<ssize_t>(key[1])),
|
||||
StackFrame{std::string(py::cast<py::str>(value[0])),
|
||||
py::cast<py::int_>(value[1]),
|
||||
std::string(py::cast<py::str>(value[2]))});
|
||||
}
|
||||
|
||||
for (const py::handle& h : *filtered_filenames_) {
|
||||
f.emplace(py::cast<py::str>(h));
|
||||
}
|
||||
|
||||
stack_frames_cache_ = captured_.ToStackFrames(m, f);
|
||||
stack_frames_cache_ = captured_.ToStackFrames(
|
||||
[&](std::pair<const char*, int> p) { return StackTraceMapping(p); },
|
||||
[&](const char* f) { return StackTraceFiltering(f); });
|
||||
stack_frames_cache_->pop_back(); // Drop last stack frame.
|
||||
PyGILState_Release(state);
|
||||
}
|
||||
|
||||
// Populates `last_stack_frame_cache_` on first use; later calls return early.
//
// Asks `captured_.ToStackFrames` for at most one frame, traversing in reverse
// order, and skipping every file for which either `StackTraceFiltering` or
// `IsInternalFrameForFilename` returns true. When no frame survives the
// filter, a default-constructed `StackFrame` is cached instead.
//
// The GIL is acquired for the duration of the conversion and released before
// returning.
//
// NOTE(review): the "already cached" check below runs before the GIL is
// acquired, so it is not itself protected by the GIL -- confirm this is
// acceptable if the method can be reached from several threads.
void GenerateLastFrameCache() const {
|
||||
if (last_stack_frame_cache_) {
|
||||
return;
|
||||
}
|
||||
|
||||
// NOTE(review): the matching PyGILState_Release at the end is a plain call,
// not a scope guard; if anything in between throws (the mapping callback
// performs pybind11 casts), the release would be skipped -- confirm.
PyGILState_STATE state = PyGILState_Ensure();
|
||||
// Combined skip predicate: explicitly filtered files plus files the
// heuristic classifies as internal.
auto f = [&](const char* file_name) -> bool {
|
||||
return StackTraceFiltering(file_name) ||
|
||||
IsInternalFrameForFilename(file_name);
|
||||
};
|
||||
|
||||
std::vector<StackFrame> last_frame = captured_.ToStackFrames(
|
||||
[&](std::pair<const char*, int> p) { return StackTraceMapping(p); }, f,
|
||||
/*reverse_traversal=*/true,
|
||||
/*limit=*/1);
|
||||
|
||||
// An empty result means every frame was filtered out; cache an empty frame
// so the lookup is not repeated.
if (last_frame.empty()) {
|
||||
last_stack_frame_cache_ = StackFrame{};
|
||||
} else {
|
||||
DCHECK(last_frame.size() == 1);
|
||||
last_stack_frame_cache_ = last_frame[0];
|
||||
}
|
||||
PyGILState_Release(state);
|
||||
}
|
||||
|
||||
StackTraceWrapper(StackTraceWrapper&&) = default;
|
||||
~StackTraceWrapper() override {
|
||||
PyGILState_STATE state = PyGILState_Ensure();
|
||||
@ -225,19 +229,43 @@ class StackTraceWrapper : public AbstractStackTrace {
|
||||
});
|
||||
}
|
||||
|
||||
static bool IsInternalFrame(const StackFrame& frame) {
|
||||
static bool IsInternalFrameForFilename(absl::string_view file_name) {
|
||||
// Use a simple heuristic for now.
|
||||
// TODO(cheshire): Build a more sophisticated mechanism, rely on @tf.export.
|
||||
return absl::StrContains(frame.file_name, "tensorflow/python") &&
|
||||
!absl::StrContains(frame.file_name, "keras") &&
|
||||
!absl::StrContains(frame.file_name, "test.py");
|
||||
return absl::StrContains(file_name, "tensorflow/python") &&
|
||||
!absl::StrContains(file_name, "keras") &&
|
||||
!absl::StrContains(file_name, "test.py");
|
||||
}
|
||||
|
||||
// Looks up a (file name, line number) pair in the Python `source_map_` dict.
//
// `p.first` is the file name and `p.second` the line number. Returns the
// mapped `StackFrame`, built from the three elements of the tuple stored
// under that key (cast to str, int, str in that order), or `absl::nullopt`
// when the dict is empty or has no entry for the key.
//
// NOTE(review): `source_map_` is dereferenced without a has-value check;
// presumably it is always set while this object is in use -- confirm.
// NOTE(review): this creates Python objects, so it presumably requires the
// GIL to be held by the caller -- the visible callers invoke it between
// PyGILState_Ensure/PyGILState_Release.
absl::optional<StackFrame> StackTraceMapping(
|
||||
std::pair<const char*, int> p) const {
|
||||
// Early exit for an empty map avoids constructing the Python tuple key.
if (source_map_->empty()) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
auto key = py::make_tuple(py::str(p.first), py::int_(p.second));
|
||||
|
||||
if (source_map_->contains(key)) {
|
||||
const py::tuple& value = (*source_map_)[key];
|
||||
return StackFrame{std::string(py::cast<py::str>(value[0])),
|
||||
py::cast<py::int_>(value[1]),
|
||||
std::string(py::cast<py::str>(value[2]))};
|
||||
}
|
||||
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
// Returns true when `file_name` is a member of the Python
// `filtered_filenames_` set, i.e. frames from that file should be skipped.
//
// NOTE(review): `filtered_filenames_` is dereferenced without a has-value
// check; presumably it is always set while this object is in use -- confirm.
bool StackTraceFiltering(const char* file_name) const {
|
||||
return filtered_filenames_->contains(file_name);
|
||||
}
|
||||
|
||||
mutable absl::optional<std::vector<StackFrame>> stack_frames_cache_;
|
||||
StackTrace captured_;
|
||||
// Using optional to force destruction while we hold a GIL.
|
||||
absl::optional<py::dict> source_map_;
|
||||
absl::optional<py::set> filtered_filenames_;
|
||||
|
||||
mutable absl::optional<std::vector<StackFrame>> stack_frames_cache_;
|
||||
mutable absl::optional<StackFrame> last_stack_frame_cache_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@ -336,12 +364,8 @@ PYBIND11_MODULE(_tf_stack, m) {
|
||||
self.GenerateCache();
|
||||
return py::str(self.ToString({}));
|
||||
})
|
||||
.def("last_user_frame", [](const StackTraceWrapper& self) {
|
||||
if (absl::optional<StackFrame> frame = self.LastUserFrame()) {
|
||||
return *frame;
|
||||
}
|
||||
return StackFrame{};
|
||||
});
|
||||
.def("last_user_frame",
|
||||
[](const StackTraceWrapper& self) { return self.LastUserFrame(); });
|
||||
|
||||
m.def(
|
||||
"extract_stack_for_node",
|
||||
|
Loading…
Reference in New Issue
Block a user