From f11e226a347371d00f64e08f9ea0a0e5c2e82b28 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Tue, 9 Feb 2021 10:07:32 -0800 Subject: [PATCH] [NFC][Stack Traces] Do not build a functor where we can use the translation map directly PiperOrigin-RevId: 356528885 Change-Id: I4482fd75dd10fee71b72b0537cbf96a4762dde92 --- tensorflow/core/util/BUILD | 1 - tensorflow/core/util/managed_stack_trace.h | 13 +++++-------- tensorflow/python/util/stack_trace.cc | 10 ++++++---- tensorflow/python/util/stack_trace.h | 10 +++++----- tensorflow/python/util/tf_stack.cc | 17 +++++++++++++++-- 5 files changed, 31 insertions(+), 20 deletions(-) diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index cab99a57f66..135c8fcf41a 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -482,7 +482,6 @@ cc_library( ], deps = [ "//tensorflow/core/platform:stack_frame", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], diff --git a/tensorflow/core/util/managed_stack_trace.h b/tensorflow/core/util/managed_stack_trace.h index 170c3d5640e..d7b4fe2f520 100644 --- a/tensorflow/core/util/managed_stack_trace.h +++ b/tensorflow/core/util/managed_stack_trace.h @@ -19,22 +19,20 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_map.h" #include "absl/strings/match.h" #include "absl/types/optional.h" #include "tensorflow/core/platform/stack_frame.h" namespace tensorflow { -using SourceLoc = std::tuple; - // Maps filename/line_no combination into a stack frame. -using SourceMap = absl::flat_hash_map; +using StackTraceMap = + std::function(std::pair)>; // Returns "true" on filenames which should be skipped. using StackTraceFilter = std::function; -using ToStackFramesFunctor = std::vector(int, const SourceMap&, +using ToStackFramesFunctor = std::vector(int, const StackTraceMap&, const StackTraceFilter&, bool, int); @@ -56,12 +54,11 @@ class ManagedStackTrace { : id_(id), to_stack_frames_(to_stack_frames) {} // Returns stack trace as a vector of `StackFrame`s. - std::vector ToStackFrames(const SourceMap& source_map = {}, + std::vector ToStackFrames(const StackTraceMap& mapper = {}, const StackTraceFilter& filtered = {}, bool reverse_traversal = false, int limit = -1) const { - return to_stack_frames_(id_, source_map, filtered, reverse_traversal, - limit); + return to_stack_frames_(id_, mapper, filtered, reverse_traversal, limit); } private: diff --git a/tensorflow/python/util/stack_trace.cc b/tensorflow/python/util/stack_trace.cc index 36c340d0913..8aed6691902 100644 --- a/tensorflow/python/util/stack_trace.cc +++ b/tensorflow/python/util/stack_trace.cc @@ -42,7 +42,7 @@ const char* GetPythonString(PyObject* o) { namespace tensorflow { std::vector StackTrace::ToStackFrames( - const SourceMap& source_map, const StackTraceFilter& filtered, + const StackTraceMap& mapper, const StackTraceFilter& filtered, bool reverse_traversal, int limit) const { DCheckPyGilStateForStackTrace(); std::vector result; @@ -61,9 +61,11 @@ std::vector StackTrace::ToStackFrames( continue; } - const auto it = source_map.find(std::make_tuple(file_name, line_number)); - if (it != source_map.end()) { - result.push_back(it->second); + absl::optional mapped = + mapper ? mapper(std::make_pair(file_name, line_number)) : absl::nullopt; + + if (mapped) { + result.push_back(*mapped); } else { result.emplace_back(StackFrame{file_name, line_number, GetPythonString(code_obj.first->co_name)}); diff --git a/tensorflow/python/util/stack_trace.h b/tensorflow/python/util/stack_trace.h index 14320d4f2f6..118b5130ab5 100644 --- a/tensorflow/python/util/stack_trace.h +++ b/tensorflow/python/util/stack_trace.h @@ -88,12 +88,12 @@ class StackTrace final { } // Returns a structured representation of the captured stack trace. - // `source_map` provides a custom mapping for translating stack frames, - // `filter` returns `true` for the stack frames which should be omitted. + // `mapper` provides a custom mapping for translating stack frames, `filter` + // returns `true` for the stack frames which should be omitted. // // `reverse_traversal` changes the traversal order of the stack trace, and // `limit` bounds the number of returned frames (after filtering). - std::vector ToStackFrames(const SourceMap& source_map = {}, + std::vector ToStackFrames(const StackTraceMap& mapper = {}, const StackTraceFilter& filtered = {}, bool reverse_traversal = false, int limit = -1) const; @@ -149,11 +149,11 @@ extern StackTraceManager* const stack_trace_manager; // Converts the ManagedStackTrace (identified by ID) to a vector of stack // frames. inline std::vector ManagedStackTraceToStackFrames( - int id, const SourceMap& source_map, const StackTraceFilter& filtered, + int id, const StackTraceMap& mapper, const StackTraceFilter& filtered, bool reverse_traversal, int limit) { PyGILState_STATE gstate = PyGILState_Ensure(); std::vector result = stack_trace_manager->Get(id)->ToStackFrames( - source_map, filtered, reverse_traversal, limit); + mapper, filtered, reverse_traversal, limit); PyGILState_Release(gstate); return result; } diff --git a/tensorflow/python/util/tf_stack.cc b/tensorflow/python/util/tf_stack.cc index 58cf973f54b..54b83794876 100644 --- a/tensorflow/python/util/tf_stack.cc +++ b/tensorflow/python/util/tf_stack.cc @@ -54,6 +54,10 @@ namespace { namespace py = pybind11; +using SourceLoc = std::tuple; + +using SourceMap = absl::flat_hash_map; + using StringSet = absl::flat_hash_set; // Python wrapper for a SourceMap. @@ -145,7 +149,8 @@ class StackTraceWrapper : public AbstractStackTrace { PyGILState_STATE state = PyGILState_Ensure(); stack_frames_cache_ = captured_.ToStackFrames( - *source_map_, [&](const char* f) { return StackTraceFiltering(f); }); + [&](std::pair p) { return StackTraceMapping(p); }, + [&](const char* f) { return StackTraceFiltering(f); }); stack_frames_cache_->pop_back(); // Drop last stack frame. PyGILState_Release(state); return *stack_frames_cache_; @@ -158,7 +163,7 @@ class StackTraceWrapper : public AbstractStackTrace { PyGILState_STATE state = PyGILState_Ensure(); std::vector last_frame = captured_.ToStackFrames( - *source_map_, + [&](std::pair p) { return StackTraceMapping(p); }, [&](const char* file_name) { return StackTraceFiltering(file_name) || IsInternalFrameForFilename(file_name); @@ -221,6 +226,14 @@ class StackTraceWrapper : public AbstractStackTrace { }); } + absl::optional StackTraceMapping(SourceLoc loc) const { + if (source_map_->contains(loc)) { + return source_map_->at(loc); + } + + return absl::nullopt; + } + bool StackTraceFiltering(const char* file_name) const { return filter_->contains(file_name); }