diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index 135c8fcf41a..cab99a57f66 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -482,6 +482,7 @@ cc_library( ], deps = [ "//tensorflow/core/platform:stack_frame", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], diff --git a/tensorflow/core/util/managed_stack_trace.h b/tensorflow/core/util/managed_stack_trace.h index d7b4fe2f520..170c3d5640e 100644 --- a/tensorflow/core/util/managed_stack_trace.h +++ b/tensorflow/core/util/managed_stack_trace.h @@ -19,20 +19,22 @@ limitations under the License. #include <string> #include <vector> +#include "absl/container/flat_hash_map.h" #include "absl/strings/match.h" #include "absl/types/optional.h" #include "tensorflow/core/platform/stack_frame.h" namespace tensorflow { +using SourceLoc = std::tuple<std::string, int>; + // Maps filename/line_no combination into a stack frame. -using StackTraceMap = - std::function<absl::optional<StackFrame>(std::pair<const char*, int>)>; +using SourceMap = absl::flat_hash_map<SourceLoc, StackFrame>; // Returns "true" on filenames which should be skipped. using StackTraceFilter = std::function<bool(const char*)>; -using ToStackFramesFunctor = std::vector<StackFrame>(int, const StackTraceMap&, +using ToStackFramesFunctor = std::vector<StackFrame>(int, const SourceMap&, const StackTraceFilter&, bool, int); @@ -54,11 +56,12 @@ class ManagedStackTrace { : id_(id), to_stack_frames_(to_stack_frames) {} // Returns stack trace as a vector of `StackFrame`s. - std::vector<StackFrame> ToStackFrames(const StackTraceMap& mapper = {}, + std::vector<StackFrame> ToStackFrames(const SourceMap& source_map = {}, const StackTraceFilter& filtered = {}, bool reverse_traversal = false, int limit = -1) const { - return to_stack_frames_(id_, mapper, filtered, reverse_traversal, limit); + return to_stack_frames_(id_, source_map, filtered, reverse_traversal, + limit); } private: diff --git a/tensorflow/python/util/stack_trace.cc b/tensorflow/python/util/stack_trace.cc index 8aed6691902..36c340d0913 100644 --- a/tensorflow/python/util/stack_trace.cc +++ b/tensorflow/python/util/stack_trace.cc @@ -42,7 +42,7 @@ const char* GetPythonString(PyObject* o) { namespace tensorflow { std::vector<StackFrame> StackTrace::ToStackFrames( - const StackTraceMap& mapper, const StackTraceFilter& filtered, + const SourceMap& source_map, const StackTraceFilter& filtered, bool reverse_traversal, int limit) const { DCheckPyGilStateForStackTrace(); std::vector<StackFrame> result; @@ -61,11 +61,9 @@ std::vector<StackFrame> StackTrace::ToStackFrames( continue; } - absl::optional<StackFrame> mapped = - mapper ? mapper(std::make_pair(file_name, line_number)) : absl::nullopt; - - if (mapped) { - result.push_back(*mapped); + const auto it = source_map.find(std::make_tuple(file_name, line_number)); + if (it != source_map.end()) { + result.push_back(it->second); } else { result.emplace_back(StackFrame{file_name, line_number, GetPythonString(code_obj.first->co_name)}); diff --git a/tensorflow/python/util/stack_trace.h b/tensorflow/python/util/stack_trace.h index 118b5130ab5..14320d4f2f6 100644 --- a/tensorflow/python/util/stack_trace.h +++ b/tensorflow/python/util/stack_trace.h @@ -88,12 +88,12 @@ class StackTrace final { } // Returns a structured representation of the captured stack trace. - // `mapper` provides a custom mapping for translating stack frames, `filter` - // returns `true` for the stack frames which should be omitted. + // `source_map` provides a custom mapping for translating stack frames, + // `filter` returns `true` for the stack frames which should be omitted. // // `reverse_traversal` changes the traversal order of the stack trace, and // `limit` bounds the number of returned frames (after filtering). - std::vector<StackFrame> ToStackFrames(const StackTraceMap& mapper = {}, + std::vector<StackFrame> ToStackFrames(const SourceMap& source_map = {}, const StackTraceFilter& filtered = {}, bool reverse_traversal = false, int limit = -1) const; @@ -149,11 +149,11 @@ extern StackTraceManager* const stack_trace_manager; // Converts the ManagedStackTrace (identified by ID) to a vector of stack // frames. inline std::vector<StackFrame> ManagedStackTraceToStackFrames( - int id, const StackTraceMap& mapper, const StackTraceFilter& filtered, + int id, const SourceMap& source_map, const StackTraceFilter& filtered, bool reverse_traversal, int limit) { PyGILState_STATE gstate = PyGILState_Ensure(); std::vector<StackFrame> result = stack_trace_manager->Get(id)->ToStackFrames( - mapper, filtered, reverse_traversal, limit); + source_map, filtered, reverse_traversal, limit); PyGILState_Release(gstate); return result; } diff --git a/tensorflow/python/util/tf_stack.cc b/tensorflow/python/util/tf_stack.cc index 54b83794876..58cf973f54b 100644 --- a/tensorflow/python/util/tf_stack.cc +++ b/tensorflow/python/util/tf_stack.cc @@ -54,10 +54,6 @@ namespace { namespace py = pybind11; -using SourceLoc = std::tuple<std::string, int>; - -using SourceMap = absl::flat_hash_map<SourceLoc, StackFrame>; - using StringSet = absl::flat_hash_set<std::string>; // Python wrapper for a SourceMap. @@ -149,8 +145,7 @@ class StackTraceWrapper : public AbstractStackTrace { PyGILState_STATE state = PyGILState_Ensure(); stack_frames_cache_ = captured_.ToStackFrames( - [&](std::pair<const char*, int> p) { return StackTraceMapping(p); }, - [&](const char* f) { return StackTraceFiltering(f); }); + *source_map_, [&](const char* f) { return StackTraceFiltering(f); }); stack_frames_cache_->pop_back(); // Drop last stack frame. PyGILState_Release(state); return *stack_frames_cache_; @@ -163,7 +158,7 @@ class StackTraceWrapper : public AbstractStackTrace { PyGILState_STATE state = PyGILState_Ensure(); std::vector<StackFrame> last_frame = captured_.ToStackFrames( - [&](std::pair<const char*, int> p) { return StackTraceMapping(p); }, + *source_map_, [&](const char* file_name) { return StackTraceFiltering(file_name) || IsInternalFrameForFilename(file_name); @@ -226,14 +221,6 @@ class StackTraceWrapper : public AbstractStackTrace { }); } - absl::optional<StackFrame> StackTraceMapping(SourceLoc loc) const { - if (source_map_->contains(loc)) { - return source_map_->at(loc); - } - - return absl::nullopt; - } - bool StackTraceFiltering(const char* file_name) const { return filter_->contains(file_name); }