Fix AbstractStackTrace::GetLastFrame on Windows

+ Minor tf_stack cleanup

PiperOrigin-RevId: 348537476
Change-Id: Id551360886c648ab2580592eb1268e51a300a9c5
This commit is contained in:
George Karpenkov 2020-12-21 15:28:13 -08:00 committed by TensorFlower Gardener
parent c81b0bf3cc
commit 8ddb063d4d
2 changed files with 37 additions and 52 deletions

View File

@ -319,7 +319,6 @@ tf_py_test(
name = "tf_stack_test",
srcs = ["tf_stack_test.py"],
python_version = "PY3",
tags = ["no_windows"], # TODO(b/175726972)
deps = [
":tf_export",
":tf_stack",

View File

@ -140,19 +140,50 @@ class StackTraceWrapper : public AbstractStackTrace {
}
absl::Span<StackFrame const> ToFrames() const override {
GenerateCache();
if (stack_frames_cache_) {
return *stack_frames_cache_;
}
// Grabbing the GIL solves two purposes: 1) makes the class thread-safe,
// and 2) ToStackFrames and LineContents actually need it.
PyGILState_STATE state = PyGILState_Ensure();
stack_frames_cache_ = captured_.ToStackFrames(
[&](std::pair<const char*, int> p) { return StackTraceMapping(p); },
[&](const char* f) { return StackTraceFiltering(f); });
stack_frames_cache_->pop_back(); // Drop last stack frame.
PyGILState_Release(state);
return *stack_frames_cache_;
}
StackFrame LastUserFrame() const override {
GenerateLastFrameCache();
if (last_stack_frame_cache_) {
return *last_stack_frame_cache_;
}
PyGILState_STATE state = PyGILState_Ensure();
std::vector<StackFrame> last_frame = captured_.ToStackFrames(
[&](std::pair<const char*, int> p) { return StackTraceMapping(p); },
[&](const char* file_name) {
return StackTraceFiltering(file_name) ||
IsInternalFrameForFilename(file_name);
},
/*reverse_traversal=*/true,
/*limit=*/1);
if (last_frame.empty()) {
last_stack_frame_cache_ = StackFrame{"", -1, ""};
} else {
DCHECK_EQ(last_frame.size(), 1);
last_stack_frame_cache_ = last_frame[0];
}
PyGILState_Release(state);
return *last_stack_frame_cache_;
}
std::string ToString(const TracePrintingOptions& opts) const override {
GenerateCache();
std::vector<std::string> files_to_find_prefix;
for (const StackFrame& frame : *stack_frames_cache_) {
for (const StackFrame& frame : ToFrames()) {
if (!absl::StrContains(frame.file_name, kFilenameToIgnorePrefix)) {
files_to_find_prefix.push_back(frame.file_name);
}
@ -175,50 +206,6 @@ class StackTraceWrapper : public AbstractStackTrace {
return ToStringHelper(filtered_frames, opts, shared_prefix_size);
}
bool IsCacheGenerated() const { return stack_frames_cache_.has_value(); }
void GenerateCache() const {
// TODO(mdan): We don't really need random access; this can be removed.
if (stack_frames_cache_) {
return;
}
// Grabbing the GIL solves two purposes: 1) makes the class thread-safe, and
// 2) ToStackFrames and LineContents actually need it.
PyGILState_STATE state = PyGILState_Ensure();
stack_frames_cache_ = captured_.ToStackFrames(
[&](std::pair<const char*, int> p) { return StackTraceMapping(p); },
[&](const char* f) { return StackTraceFiltering(f); });
stack_frames_cache_->pop_back(); // Drop last stack frame.
PyGILState_Release(state);
}
void GenerateLastFrameCache() const {
if (last_stack_frame_cache_) {
return;
}
PyGILState_STATE state = PyGILState_Ensure();
auto f = [&](const char* file_name) -> bool {
return StackTraceFiltering(file_name) ||
IsInternalFrameForFilename(file_name);
};
std::vector<StackFrame> last_frame = captured_.ToStackFrames(
[&](std::pair<const char*, int> p) { return StackTraceMapping(p); }, f,
/*reverse_traversal=*/true,
/*limit=*/1);
if (last_frame.empty()) {
last_stack_frame_cache_ = StackFrame{};
} else {
DCHECK(last_frame.size() == 1);
last_stack_frame_cache_ = last_frame[0];
}
PyGILState_Release(state);
}
StackTraceWrapper(StackTraceWrapper&&) = default;
~StackTraceWrapper() override {
PyGILState_STATE state = PyGILState_Ensure();
@ -242,7 +229,8 @@ class StackTraceWrapper : public AbstractStackTrace {
static bool IsInternalFrameForFilename(absl::string_view file_name) {
// Use a simple heuristic for now.
// TODO(cheshire): Build a more sophisticated mechanism, rely on @tf.export.
return absl::StrContains(file_name, "tensorflow/python") &&
return (absl::StrContains(file_name, "tensorflow/python") ||
absl::StrContains(file_name, "tensorflow\\python")) &&
!absl::StrContains(file_name, "keras") &&
!absl::StrContains(file_name, "test.py");
}
@ -392,12 +380,10 @@ PYBIND11_MODULE(_tf_stack, m) {
})
.def("__hash__",
[](const StackTraceWrapper& self) {
self.GenerateCache();
return py::hash(py::str(self.ToString({})));
})
.def("__repr__",
[](const StackTraceWrapper& self) {
self.GenerateCache();
return py::str(self.ToString({}));
})
.def("last_user_frame",