[NFC] Optimize getting the last stack frame + stack frame conversion

Firstly, use callbacks instead of C++ datastructures to avoid expensive
conversion + rehashing.

Secondly, when returning the last frame, do not needlessly traverse the entire
stack trace.

PiperOrigin-RevId: 346937177
Change-Id: Ib13d0bf0079e2cc33a31bf2d4621b1c5b8e5b5da
This commit is contained in:
George Karpenkov 2020-12-10 22:45:04 -08:00 committed by TensorFlower Gardener
parent 43a080a530
commit 20423e72df
5 changed files with 101 additions and 65 deletions

View File

@ -401,9 +401,7 @@ TEST_F(FunctionLibraryRuntimeTest, InstantiationStackTraceCopying) {
return "DummyStackTrace"; return "DummyStackTrace";
} }
absl::optional<StackFrame> LastUserFrame() const override { StackFrame LastUserFrame() const override { return StackFrame{}; }
return absl::nullopt;
}
}; };
FunctionDef func = test::function::XTimesTwo(); FunctionDef func = test::function::XTimesTwo();

View File

@ -350,8 +350,8 @@ class AbstractStackTrace {
virtual absl::Span<StackFrame const> ToFrames() const = 0; virtual absl::Span<StackFrame const> ToFrames() const = 0;
// Returns the last stack frame from user code, attempting to ignore the // Returns the last stack frame from user code, attempting to ignore the
// framework code. Returns an empty optional if no such stack frame was found. // framework code. Returns an empty frame if no such stack frame was found.
virtual absl::optional<StackFrame> LastUserFrame() const = 0; virtual StackFrame LastUserFrame() const = 0;
virtual std::string ToString(const TracePrintingOptions& opts) const = 0; virtual std::string ToString(const TracePrintingOptions& opts) const = 0;
}; };

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/python/util/stack_trace.h" #include "tensorflow/python/util/stack_trace.h"
#include <limits>
#include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/stringpiece.h"
@ -40,29 +42,38 @@ const char* GetPythonString(PyObject* o) {
namespace tensorflow { namespace tensorflow {
std::vector<StackFrame> StackTrace::ToStackFrames( std::vector<StackFrame> StackTrace::ToStackFrames(
const StackTraceMap& mapper, const StackTraceFilter& filtered) const { const StackTraceMap& mapper, const StackTraceFilter& filtered,
bool reverse_traversal, int limit) const {
DCheckPyGilStateForStackTrace(); DCheckPyGilStateForStackTrace();
std::vector<StackFrame> result; std::vector<StackFrame> result;
result.reserve(code_objs_.size()); result.reserve(code_objs_.size());
for (int i = code_objs_.size() - 1; i >= 0; --i) { if (limit == -1) limit = std::numeric_limits<int>::max();
const std::pair<PyCodeObject*, int>& code_obj = code_objs_[i];
const char* file_name = GetPythonString(code_obj.first->co_filename);
const int line_number =
PyCode_Addr2Line(code_objs_[i].first, code_obj.second);
if (!result.empty() && filtered.count(file_name)) { for (int i = 0; i < code_objs_.size(); i++) {
continue; // Never filter the innermost frame. int idx = reverse_traversal ? i : code_objs_.size() - 1 - i;
const std::pair<PyCodeObject*, int>& code_obj = code_objs_[idx];
const char* file_name = GetPythonString(code_obj.first->co_filename);
const int line_number = PyCode_Addr2Line(code_obj.first, code_obj.second);
if (filtered && filtered(file_name)) {
continue;
} }
auto it = mapper.find(std::make_pair(file_name, line_number)); absl::optional<StackFrame> mapped =
mapper ? mapper(std::make_pair(file_name, line_number)) : absl::nullopt;
if (it != mapper.end()) { if (mapped) {
result.push_back(it->second); result.push_back(*mapped);
} else { } else {
result.emplace_back(StackFrame{file_name, line_number, result.emplace_back(StackFrame{file_name, line_number,
GetPythonString(code_obj.first->co_name)}); GetPythonString(code_obj.first->co_name)});
} }
if (result.size() == limit) {
break;
}
} }
return result; return result;

View File

@ -44,10 +44,10 @@ inline void DCheckPyGilStateForStackTrace() {
// Maps filename/line_no combination into a stack frame. // Maps filename/line_no combination into a stack frame.
using StackTraceMap = using StackTraceMap =
absl::flat_hash_map<std::pair<std::string, int>, StackFrame>; std::function<absl::optional<StackFrame>(std::pair<const char*, int>)>;
// Contains filenames which should be skipped. // Returns "true" on filenames which should be skipped.
using StackTraceFilter = absl::flat_hash_set<std::string>; using StackTraceFilter = std::function<bool(const char*)>;
// A class for capturing Python stack trace. // A class for capturing Python stack trace.
class StackTrace final { class StackTrace final {
@ -95,11 +95,14 @@ class StackTrace final {
// Returns a structured representation of the captured stack trace. // Returns a structured representation of the captured stack trace.
// `mapper` provides a custom mapping for translating stack frames, `filter` // `mapper` provides a custom mapping for translating stack frames, `filter`
// returns `true` for the stack frames which should be omitted, and if // returns `true` for the stack frames which should be omitted.
// `drop_last` is set, the last stack frame is dropped. //
std::vector<StackFrame> ToStackFrames( // `reverse_traversal` changes the traversal order of the stack trace, and
const StackTraceMap& mapper = {}, // `limit` bounds the number of returned frames (after filtering).
const StackTraceFilter& filtered = {}) const; std::vector<StackFrame> ToStackFrames(const StackTraceMap& mapper = {},
const StackTraceFilter& filtered = {},
bool reverse_traversal = false,
int limit = -1) const;
// Python GIL must be acquired beforehand. // Python GIL must be acquired beforehand.
ABSL_ATTRIBUTE_HOT ABSL_ATTRIBUTE_HOT

View File

@ -135,15 +135,9 @@ class StackTraceWrapper : public AbstractStackTrace {
return *stack_frames_cache_; return *stack_frames_cache_;
} }
absl::optional<StackFrame> LastUserFrame() const override { StackFrame LastUserFrame() const override {
GenerateCache(); GenerateLastFrameCache();
for (int i = stack_frames_cache_->size() - 1; i >= 0; i--) { return *last_stack_frame_cache_;
const StackFrame& frame = stack_frames_cache_->at(i);
if (!IsInternalFrame(frame)) {
return frame;
}
}
return absl::nullopt;
} }
std::string ToString(const TracePrintingOptions& opts) const override { std::string ToString(const TracePrintingOptions& opts) const override {
@ -165,7 +159,7 @@ class StackTraceWrapper : public AbstractStackTrace {
std::vector<StackFrame> filtered_frames; std::vector<StackFrame> filtered_frames;
for (const StackFrame& frame : *stack_frames_cache_) { for (const StackFrame& frame : *stack_frames_cache_) {
if (!IsInternalFrame(frame)) { if (!IsInternalFrameForFilename(frame.file_name)) {
filtered_frames.push_back(frame); filtered_frames.push_back(frame);
} }
} }
@ -175,36 +169,46 @@ class StackTraceWrapper : public AbstractStackTrace {
bool IsCacheGenerated() const { return stack_frames_cache_.has_value(); } bool IsCacheGenerated() const { return stack_frames_cache_.has_value(); }
void GenerateCache() const { void GenerateCache() const {
// Grabbing the GIL solves two purposes: 1) makes the class thread-safe, and
// 2) ToStackFrames and LineContents actually need it.
if (stack_frames_cache_) { if (stack_frames_cache_) {
return; return;
} }
// Grabbing the GIL solves two purposes: 1) makes the class thread-safe, and
// 2) ToStackFrames and LineContents actually need it.
PyGILState_STATE state = PyGILState_Ensure(); PyGILState_STATE state = PyGILState_Ensure();
absl::flat_hash_map<std::pair<std::string, int>, StackFrame> m;
absl::flat_hash_set<std::string> f;
for (const std::pair<py::handle, py::handle>& p : *source_map_) { stack_frames_cache_ = captured_.ToStackFrames(
const py::tuple& key = py::cast<py::tuple>(p.first); [&](std::pair<const char*, int> p) { return StackTraceMapping(p); },
const py::tuple& value = py::cast<py::tuple>(p.second); [&](const char* f) { return StackTraceFiltering(f); });
m.emplace(std::make_pair(std::string(py::cast<py::str>(key[0])),
py::cast<ssize_t>(key[1])),
StackFrame{std::string(py::cast<py::str>(value[0])),
py::cast<py::int_>(value[1]),
std::string(py::cast<py::str>(value[2]))});
}
for (const py::handle& h : *filtered_filenames_) {
f.emplace(py::cast<py::str>(h));
}
stack_frames_cache_ = captured_.ToStackFrames(m, f);
stack_frames_cache_->pop_back(); // Drop last stack frame. stack_frames_cache_->pop_back(); // Drop last stack frame.
PyGILState_Release(state); PyGILState_Release(state);
} }
void GenerateLastFrameCache() const {
if (last_stack_frame_cache_) {
return;
}
PyGILState_STATE state = PyGILState_Ensure();
auto f = [&](const char* file_name) -> bool {
return StackTraceFiltering(file_name) ||
IsInternalFrameForFilename(file_name);
};
std::vector<StackFrame> last_frame = captured_.ToStackFrames(
[&](std::pair<const char*, int> p) { return StackTraceMapping(p); }, f,
/*reverse_traversal=*/true,
/*limit=*/1);
if (last_frame.empty()) {
last_stack_frame_cache_ = StackFrame{};
} else {
DCHECK(last_frame.size() == 1);
last_stack_frame_cache_ = last_frame[0];
}
PyGILState_Release(state);
}
StackTraceWrapper(StackTraceWrapper&&) = default; StackTraceWrapper(StackTraceWrapper&&) = default;
~StackTraceWrapper() override { ~StackTraceWrapper() override {
PyGILState_STATE state = PyGILState_Ensure(); PyGILState_STATE state = PyGILState_Ensure();
@ -225,19 +229,43 @@ class StackTraceWrapper : public AbstractStackTrace {
}); });
} }
static bool IsInternalFrame(const StackFrame& frame) { static bool IsInternalFrameForFilename(absl::string_view file_name) {
// Use a simple heuristic for now. // Use a simple heuristic for now.
// TODO(cheshire): Build a more sophisticated mechanism, rely on @tf.export. // TODO(cheshire): Build a more sophisticated mechanism, rely on @tf.export.
return absl::StrContains(frame.file_name, "tensorflow/python") && return absl::StrContains(file_name, "tensorflow/python") &&
!absl::StrContains(frame.file_name, "keras") && !absl::StrContains(file_name, "keras") &&
!absl::StrContains(frame.file_name, "test.py"); !absl::StrContains(file_name, "test.py");
}
absl::optional<StackFrame> StackTraceMapping(
std::pair<const char*, int> p) const {
if (source_map_->empty()) {
return absl::nullopt;
}
auto key = py::make_tuple(py::str(p.first), py::int_(p.second));
if (source_map_->contains(key)) {
const py::tuple& value = (*source_map_)[key];
return StackFrame{std::string(py::cast<py::str>(value[0])),
py::cast<py::int_>(value[1]),
std::string(py::cast<py::str>(value[2]))};
}
return absl::nullopt;
}
bool StackTraceFiltering(const char* file_name) const {
return filtered_filenames_->contains(file_name);
} }
mutable absl::optional<std::vector<StackFrame>> stack_frames_cache_;
StackTrace captured_; StackTrace captured_;
// Using optional to force destruction while we hold a GIL. // Using optional to force destruction while we hold a GIL.
absl::optional<py::dict> source_map_; absl::optional<py::dict> source_map_;
absl::optional<py::set> filtered_filenames_; absl::optional<py::set> filtered_filenames_;
mutable absl::optional<std::vector<StackFrame>> stack_frames_cache_;
mutable absl::optional<StackFrame> last_stack_frame_cache_;
}; };
} // namespace } // namespace
@ -336,12 +364,8 @@ PYBIND11_MODULE(_tf_stack, m) {
self.GenerateCache(); self.GenerateCache();
return py::str(self.ToString({})); return py::str(self.ToString({}));
}) })
.def("last_user_frame", [](const StackTraceWrapper& self) { .def("last_user_frame",
if (absl::optional<StackFrame> frame = self.LastUserFrame()) { [](const StackTraceWrapper& self) { return self.LastUserFrame(); });
return *frame;
}
return StackFrame{};
});
m.def( m.def(
"extract_stack_for_node", "extract_stack_for_node",