[NFC] Optimize getting the last stack frame + stack frame conversion

First, use callbacks instead of C++ data structures to avoid expensive
conversion and rehashing.

Secondly, when returning the last frame, do not needlessly traverse the entire
stack trace.

PiperOrigin-RevId: 346937177
Change-Id: Ib13d0bf0079e2cc33a31bf2d4621b1c5b8e5b5da
This commit is contained in:
George Karpenkov 2020-12-10 22:45:04 -08:00 committed by TensorFlower Gardener
parent 43a080a530
commit 20423e72df
5 changed files with 101 additions and 65 deletions
tensorflow

View File

@ -401,9 +401,7 @@ TEST_F(FunctionLibraryRuntimeTest, InstantiationStackTraceCopying) {
return "DummyStackTrace";
}
absl::optional<StackFrame> LastUserFrame() const override {
return absl::nullopt;
}
StackFrame LastUserFrame() const override { return StackFrame{}; }
};
FunctionDef func = test::function::XTimesTwo();

View File

@ -350,8 +350,8 @@ class AbstractStackTrace {
virtual absl::Span<StackFrame const> ToFrames() const = 0;
// Returns the last stack frame from user code, attempting to ignore the
// framework code. Returns an empty optional if no such stack frame was found.
virtual absl::optional<StackFrame> LastUserFrame() const = 0;
// framework code. Returns an empty frame if no such stack frame was found.
virtual StackFrame LastUserFrame() const = 0;
virtual std::string ToString(const TracePrintingOptions& opts) const = 0;
};

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/python/util/stack_trace.h"
#include <limits>
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/stringpiece.h"
@ -40,29 +42,38 @@ const char* GetPythonString(PyObject* o) {
namespace tensorflow {
std::vector<StackFrame> StackTrace::ToStackFrames(
const StackTraceMap& mapper, const StackTraceFilter& filtered) const {
const StackTraceMap& mapper, const StackTraceFilter& filtered,
bool reverse_traversal, int limit) const {
DCheckPyGilStateForStackTrace();
std::vector<StackFrame> result;
result.reserve(code_objs_.size());
for (int i = code_objs_.size() - 1; i >= 0; --i) {
const std::pair<PyCodeObject*, int>& code_obj = code_objs_[i];
const char* file_name = GetPythonString(code_obj.first->co_filename);
const int line_number =
PyCode_Addr2Line(code_objs_[i].first, code_obj.second);
if (limit == -1) limit = std::numeric_limits<int>::max();
if (!result.empty() && filtered.count(file_name)) {
continue; // Never filter the innermost frame.
for (int i = 0; i < code_objs_.size(); i++) {
int idx = reverse_traversal ? i : code_objs_.size() - 1 - i;
const std::pair<PyCodeObject*, int>& code_obj = code_objs_[idx];
const char* file_name = GetPythonString(code_obj.first->co_filename);
const int line_number = PyCode_Addr2Line(code_obj.first, code_obj.second);
if (filtered && filtered(file_name)) {
continue;
}
auto it = mapper.find(std::make_pair(file_name, line_number));
absl::optional<StackFrame> mapped =
mapper ? mapper(std::make_pair(file_name, line_number)) : absl::nullopt;
if (it != mapper.end()) {
result.push_back(it->second);
if (mapped) {
result.push_back(*mapped);
} else {
result.emplace_back(StackFrame{file_name, line_number,
GetPythonString(code_obj.first->co_name)});
}
if (result.size() == limit) {
break;
}
}
return result;

View File

@ -44,10 +44,10 @@ inline void DCheckPyGilStateForStackTrace() {
// Maps filename/line_no combination into a stack frame.
using StackTraceMap =
absl::flat_hash_map<std::pair<std::string, int>, StackFrame>;
std::function<absl::optional<StackFrame>(std::pair<const char*, int>)>;
// Contains filenames which should be skipped.
using StackTraceFilter = absl::flat_hash_set<std::string>;
// Returns "true" on filenames which should be skipped.
using StackTraceFilter = std::function<bool(const char*)>;
// A class for capturing Python stack trace.
class StackTrace final {
@ -95,11 +95,14 @@ class StackTrace final {
// Returns a structured representation of the captured stack trace.
// `mapper` provides a custom mapping for translating stack frames, `filter`
// returns `true` for the stack frames which should be omitted, and if
// `drop_last` is set, the last stack frame is dropped.
std::vector<StackFrame> ToStackFrames(
const StackTraceMap& mapper = {},
const StackTraceFilter& filtered = {}) const;
// returns `true` for the stack frames which should be omitted.
//
// `reverse_traversal` changes the traversal order of the stack trace, and
// `limit` bounds the number of returned frames (after filtering).
std::vector<StackFrame> ToStackFrames(const StackTraceMap& mapper = {},
const StackTraceFilter& filtered = {},
bool reverse_traversal = false,
int limit = -1) const;
// Python GIL must be acquired beforehand.
ABSL_ATTRIBUTE_HOT

View File

@ -135,15 +135,9 @@ class StackTraceWrapper : public AbstractStackTrace {
return *stack_frames_cache_;
}
absl::optional<StackFrame> LastUserFrame() const override {
GenerateCache();
for (int i = stack_frames_cache_->size() - 1; i >= 0; i--) {
const StackFrame& frame = stack_frames_cache_->at(i);
if (!IsInternalFrame(frame)) {
return frame;
}
}
return absl::nullopt;
StackFrame LastUserFrame() const override {
GenerateLastFrameCache();
return *last_stack_frame_cache_;
}
std::string ToString(const TracePrintingOptions& opts) const override {
@ -165,7 +159,7 @@ class StackTraceWrapper : public AbstractStackTrace {
std::vector<StackFrame> filtered_frames;
for (const StackFrame& frame : *stack_frames_cache_) {
if (!IsInternalFrame(frame)) {
if (!IsInternalFrameForFilename(frame.file_name)) {
filtered_frames.push_back(frame);
}
}
@ -175,36 +169,46 @@ class StackTraceWrapper : public AbstractStackTrace {
bool IsCacheGenerated() const { return stack_frames_cache_.has_value(); }
void GenerateCache() const {
// Grabbing the GIL solves two purposes: 1) makes the class thread-safe, and
// 2) ToStackFrames and LineContents actually need it.
if (stack_frames_cache_) {
return;
}
// Grabbing the GIL solves two purposes: 1) makes the class thread-safe, and
// 2) ToStackFrames and LineContents actually need it.
PyGILState_STATE state = PyGILState_Ensure();
absl::flat_hash_map<std::pair<std::string, int>, StackFrame> m;
absl::flat_hash_set<std::string> f;
for (const std::pair<py::handle, py::handle>& p : *source_map_) {
const py::tuple& key = py::cast<py::tuple>(p.first);
const py::tuple& value = py::cast<py::tuple>(p.second);
m.emplace(std::make_pair(std::string(py::cast<py::str>(key[0])),
py::cast<ssize_t>(key[1])),
StackFrame{std::string(py::cast<py::str>(value[0])),
py::cast<py::int_>(value[1]),
std::string(py::cast<py::str>(value[2]))});
}
for (const py::handle& h : *filtered_filenames_) {
f.emplace(py::cast<py::str>(h));
}
stack_frames_cache_ = captured_.ToStackFrames(m, f);
stack_frames_cache_ = captured_.ToStackFrames(
[&](std::pair<const char*, int> p) { return StackTraceMapping(p); },
[&](const char* f) { return StackTraceFiltering(f); });
stack_frames_cache_->pop_back(); // Drop last stack frame.
PyGILState_Release(state);
}
void GenerateLastFrameCache() const {
if (last_stack_frame_cache_) {
return;
}
PyGILState_STATE state = PyGILState_Ensure();
auto f = [&](const char* file_name) -> bool {
return StackTraceFiltering(file_name) ||
IsInternalFrameForFilename(file_name);
};
std::vector<StackFrame> last_frame = captured_.ToStackFrames(
[&](std::pair<const char*, int> p) { return StackTraceMapping(p); }, f,
/*reverse_traversal=*/true,
/*limit=*/1);
if (last_frame.empty()) {
last_stack_frame_cache_ = StackFrame{};
} else {
DCHECK(last_frame.size() == 1);
last_stack_frame_cache_ = last_frame[0];
}
PyGILState_Release(state);
}
StackTraceWrapper(StackTraceWrapper&&) = default;
~StackTraceWrapper() override {
PyGILState_STATE state = PyGILState_Ensure();
@ -225,19 +229,43 @@ class StackTraceWrapper : public AbstractStackTrace {
});
}
static bool IsInternalFrame(const StackFrame& frame) {
static bool IsInternalFrameForFilename(absl::string_view file_name) {
// Use a simple heuristic for now.
// TODO(cheshire): Build a more sophisticated mechanism, rely on @tf.export.
return absl::StrContains(frame.file_name, "tensorflow/python") &&
!absl::StrContains(frame.file_name, "keras") &&
!absl::StrContains(frame.file_name, "test.py");
return absl::StrContains(file_name, "tensorflow/python") &&
!absl::StrContains(file_name, "keras") &&
!absl::StrContains(file_name, "test.py");
}
// Looks up the (file name, line number) pair `p` in `source_map_`.
// Returns the StackFrame built from the mapped 3-element value when the key is
// present, and absl::nullopt when the map is empty or has no such key.
absl::optional<StackFrame> StackTraceMapping(
    std::pair<const char*, int> p) const {
  if (!source_map_->empty()) {
    // The dictionary is keyed by Python tuples, so build one for the lookup.
    auto lookup_key = py::make_tuple(py::str(p.first), py::int_(p.second));
    if (source_map_->contains(lookup_key)) {
      const py::tuple& mapped = (*source_map_)[lookup_key];
      return StackFrame{std::string(py::cast<py::str>(mapped[0])),
                        py::cast<py::int_>(mapped[1]),
                        std::string(py::cast<py::str>(mapped[2]))};
    }
  }
  return absl::nullopt;
}
// Returns true when `file_name` is a member of `filtered_filenames_`.
bool StackTraceFiltering(const char* file_name) const {
  const bool is_filtered = filtered_filenames_->contains(file_name);
  return is_filtered;
}
mutable absl::optional<std::vector<StackFrame>> stack_frames_cache_;
StackTrace captured_;
// Using optional to force destruction while we hold a GIL.
absl::optional<py::dict> source_map_;
absl::optional<py::set> filtered_filenames_;
mutable absl::optional<std::vector<StackFrame>> stack_frames_cache_;
mutable absl::optional<StackFrame> last_stack_frame_cache_;
};
} // namespace
@ -336,12 +364,8 @@ PYBIND11_MODULE(_tf_stack, m) {
self.GenerateCache();
return py::str(self.ToString({}));
})
.def("last_user_frame", [](const StackTraceWrapper& self) {
if (absl::optional<StackFrame> frame = self.LastUserFrame()) {
return *frame;
}
return StackFrame{};
});
.def("last_user_frame",
[](const StackTraceWrapper& self) { return self.LastUserFrame(); });
m.def(
"extract_stack_for_node",