Rollback of rollback of "Move the ownership of Python stack traces to Graph object, make them accessible from C++ API"

Move the ownership of Python stack traces to Graph object, make them accessible from C++ API

Expose stack printing options, implement common prefix filtering.

PiperOrigin-RevId: 345579757
Change-Id: I88673891e893b1f71a5b039e44f0bc30f190c18a
This commit is contained in:
George Karpenkov 2020-12-03 18:33:32 -08:00 committed by TensorFlower Gardener
parent 7cde7c6939
commit 1943e58d29
15 changed files with 310 additions and 121 deletions

View File

@ -155,7 +155,7 @@ TF_Function* TF_GraphToFunctionWithControlOutputs(
int ncontrol_outputs, const TF_Operation* const* control_outputs,
const char* const* control_output_names, const TF_FunctionOptions* opts,
const char* description, TF_Status* status) {
tensorflow::mutex_lock l(*const_cast<tensorflow::mutex*>(&fn_body->mu));
tensorflow::mutex_lock l(fn_body->mu);
// Process inputs.
std::vector<tensorflow::OutputTensor> input_tensors;
@ -213,6 +213,7 @@ TF_Function* TF_GraphToFunctionWithControlOutputs(
TF_DeleteFunction(tf_function);
return nullptr;
}
tf_function->graph_with_debug_info = &fn_body->graph;
return tf_function;
}

View File

@ -70,7 +70,7 @@ struct TF_Library {
struct TF_Graph {
TF_Graph();
tensorflow::mutex mu;
mutable tensorflow::mutex mu;
tensorflow::Graph graph TF_GUARDED_BY(mu);
// Runs shape inference.
@ -157,6 +157,9 @@ struct TF_DeviceList {
struct TF_Function {
tensorflow::FunctionDef fdef;
// Graph with nodes with debug stack traces.
const tensorflow::Graph* graph_with_debug_info = nullptr;
};
struct TF_ApiDefMap {

View File

@ -749,7 +749,8 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) {
status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function->fdef);
status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithDebugInfo(
function->fdef, function->graph_with_debug_info);
}
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor.h"
@ -110,6 +111,12 @@ class ImmediateExecutionContext : public AbstractContext {
// already exists.
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
// Same as `AddFunctionDef`, and additionally saves a pointer to the Graph
// which has nodes containing stack traces for the nodes in `fdef`. Assumes
// `graph` is alive while the function is alive.
virtual Status AddFunctionDefWithDebugInfo(const FunctionDef& fdef,
const Graph* graph) = 0;
// Find and return an added function by its name.
virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;

View File

@ -705,6 +705,12 @@ Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers(
return Status::OK();
}
// Registers `fdef` while also recording `graph_with_debug_info`, whose nodes
// carry the stack traces for the nodes in `fdef`. Forwards to the extended
// AddFunctionDef overload with an empty library; add_to_local_only=false, so
// the function may also be registered remotely.
Status EagerContext::AddFunctionDefWithDebugInfo(
    const FunctionDef& fdef, const Graph* graph_with_debug_info) {
  return AddFunctionDef(fdef, FunctionDefLibrary(),
                        /* add_to_local_only=*/false, graph_with_debug_info);
}
Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
return AddFunctionDef(fdef, FunctionDefLibrary(),
/* add_to_local_only=*/false);
@ -712,7 +718,8 @@ Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
const FunctionDefLibrary& library,
const bool add_to_local_only) {
const bool add_to_local_only,
const Graph* graph_with_debug_info) {
bool is_first_ref = false;
{
mutex_lock l(cache_mu_);
@ -746,7 +753,8 @@ Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
is_first_ref = registered_function->RefCountIsOne();
}
if (is_first_ref) {
TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef));
TF_RETURN_IF_ERROR(
func_lib_def_.AddFunctionDef(fdef, graph_with_debug_info));
TF_RETURN_IF_ERROR(func_lib_def_.AddLibrary(library));
if (!add_to_local_only) {
return MaybeRegisterFunctionRemotely(fdef);

View File

@ -233,13 +233,18 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
// Add the given `fdef` to the local FunctionLibraryDefinition. And add an
// entry to the KernelAndDevice cache for it if it doesn't exist.
Status AddFunctionDef(const FunctionDef& fdef) override;
Status AddFunctionDefWithDebugInfo(
const FunctionDef& fdef, const Graph* graph_with_debug_info) override;
// `library` contains all FunctionDefs and GradientDefs to expand `fdef`. Add
// it to the local FunctionLibraryDefinition as well, but no need to add it
// to the KernelAndDevice cache since they won't be executed as
// KernelAndDevices.
Status AddFunctionDef(const FunctionDef& fdef,
const FunctionDefLibrary& library,
const bool add_to_local_only = false);
const bool add_to_local_only = false,
const Graph* graph_with_debug_info = nullptr);
const FunctionDef* GetFunctionDef(const string& function_name);

View File

@ -1173,12 +1173,14 @@ Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
}
FunctionLibraryDefinition::FunctionDefAndOpRegistration::
FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
FunctionDefAndOpRegistration(const FunctionDef& fdef_in,
const Graph* graph_with_debug_info)
: fdef(fdef_in),
// Exact shape inference for functions is handled by ShapeRefiner.
// Here we pass a dummy shape inference function for legacy code paths.
op_registration_data(fdef.signature(), shape_inference::UnknownShape,
true /* is_function */) {}
true /* is_function */),
graph_with_debug_info(graph_with_debug_info) {}
FunctionLibraryDefinition::FunctionLibraryDefinition(
const FunctionLibraryDefinition& other)
@ -1230,14 +1232,15 @@ FunctionLibraryDefinition::FindHelper(const string& func) const {
}
}
Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
Status FunctionLibraryDefinition::AddFunctionDef(
const FunctionDef& fdef, const Graph* graph_with_debug_info) {
mutex_lock l(mu_);
bool added;
return AddFunctionDefHelper(fdef, &added);
return AddFunctionDefHelper(fdef, graph_with_debug_info, &added);
}
Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
bool* added) {
Status FunctionLibraryDefinition::AddFunctionDefHelper(
const FunctionDef& fdef, const Graph* graph_with_debug_info, bool* added) {
*added = false;
std::shared_ptr<FunctionDefAndOpRegistration>& entry =
function_defs_[fdef.signature().name()];
@ -1257,7 +1260,8 @@ Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
"Cannot add function '", fdef.signature().name(),
"' because an op with the same name already exists.");
}
entry = std::make_shared<FunctionDefAndOpRegistration>(fdef);
entry = std::make_shared<FunctionDefAndOpRegistration>(fdef,
graph_with_debug_info);
*added = true;
return Status::OK();
}
@ -1399,7 +1403,7 @@ Status FunctionLibraryDefinition::AddLibrary(
Status s;
bool added;
for (const FunctionDef& fdef : lib_def.function()) {
s = AddFunctionDefHelper(fdef, &added);
s = AddFunctionDefHelper(fdef, /*graph_with_debug_info=*/nullptr, &added);
if (!s.ok()) {
Remove(funcs, funcs_with_grads);
return s;
@ -1426,7 +1430,8 @@ Status FunctionLibraryDefinition::ReplaceFunction(const string& func,
mutex_lock l(mu_);
bool added;
TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
TF_RETURN_IF_ERROR(AddFunctionDefHelper(fdef, &added));
TF_RETURN_IF_ERROR(
AddFunctionDefHelper(fdef, /*graph_with_debug_info=*/nullptr, &added));
return Status::OK();
}

View File

@ -330,6 +330,24 @@ class FunctionCallFrame : public CallFrameInterface {
TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame);
};
// Language agnostic stack traces.
//
// Abstract interface for a captured stack trace; implemented outside of core
// (e.g. by the Python StackTraceWrapper) so that C++ code can inspect and
// print traces captured in another language.
class AbstractStackTrace {
 public:
  // Options controlling how ToString() renders the trace.
  struct TracePrintingOptions {
    // Show inline the contents of each stack line.
    bool show_line_contents = false;

    // Drop the common largest prefix of all filenames in stack frames.
    bool filter_common_prefix = false;
  };

  virtual ~AbstractStackTrace() {}

  // The returned span is alive as long as the AbstractStackTrace is alive.
  virtual absl::Span<StackFrame const> ToFrames() const = 0;

  // Renders the captured trace as a human-readable string, formatted
  // according to `opts`.
  virtual std::string ToString(const TracePrintingOptions& opts) const = 0;
};
// Helper to maintain a map between function names in a given
// FunctionDefLibrary and function definitions.
//
@ -375,7 +393,12 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
// If 'fdef' is successfully added to the library, it will be accessible
// from 'LookUp' and included in the proto returned by 'ToProto'.
// This operation is atomic.
Status AddFunctionDef(const FunctionDef& fdef) TF_LOCKS_EXCLUDED(mu_);
//
// Associates `graph` with a function `func_name`. Lifetime assumption:
// `graph` has to outlive all instantiated graphs.
Status AddFunctionDef(const FunctionDef& fdef,
const Graph* graph_with_debug_info = nullptr)
TF_LOCKS_EXCLUDED(mu_);
// Adds gradient definition 'grad' to this function library.
// This is a no-op if 'grad' already exists in this function library.
@ -484,14 +507,25 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
const FunctionLibraryDefinition& other)
TF_LOCKS_EXCLUDED(mu_);
// Returns graph with debug stack traces for the given function, or `nullptr`
// if none found.
const Graph* GetGraphWithDebugInfo(const std::string& func_name) const {
  tf_shared_lock l(mu_);
  // Look up the registration entry; functions registered without a debug
  // graph store nullptr, which we also return for unknown functions.
  auto entry = FindHelper(func_name);
  if (!entry) return nullptr;
  return entry->graph_with_debug_info;
}
private:
// Shape inference for functions is handled separately by ShapeRefiner.
struct FunctionDefAndOpRegistration {
explicit FunctionDefAndOpRegistration(const FunctionDef& fdef_in);
explicit FunctionDefAndOpRegistration(
const FunctionDef& fdef_in,
const Graph* graph_with_debug_info = nullptr);
const FunctionDef fdef;
const OpRegistrationData op_registration_data;
const Graph* graph_with_debug_info;
};
std::shared_ptr<FunctionDefAndOpRegistration> FindHelper(
@ -504,7 +538,8 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
// Same as AddFunctionDef/AddGradientDef except these methods set
// `added` to true if the `fdef`/`grad` were actually added to this.
Status AddFunctionDefHelper(const FunctionDef& fdef, bool* added)
Status AddFunctionDefHelper(const FunctionDef& fdef,
const Graph* graph_with_debug_info, bool* added)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status AddGradientDefHelper(const GradientDef& grad, bool* added)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

View File

@ -240,10 +240,25 @@ class Node {
std::shared_ptr<NodeProperties> properties() const { return props_; }
// Sets the stack trace for the node. Assumes that getting and setting the
// stack trace for a given node will not race. The trace is held through a
// shared_ptr, so the same trace object can be shared by multiple nodes
// (e.g. when inlining).
void SetStackTrace(const std::shared_ptr<AbstractStackTrace>& stack_trace) {
  stack_trace_ = stack_trace;
}
// Get the stack trace for when the node was instantiated. May be null if no
// trace was recorded via SetStackTrace().
const std::shared_ptr<AbstractStackTrace>& GetStackTrace() const {
  return stack_trace_;
}
private:
friend class Graph;
Node();
// Stack trace for the user code for node instantiation. Can be shared across
// multiple nodes (e.g. when inlining).
std::shared_ptr<AbstractStackTrace> stack_trace_;
// Releases memory from props_, in addition to restoring *this to its
// uninitialized state.
void Clear();

View File

@ -17,9 +17,6 @@ load("//tensorflow:tensorflow.bzl", "tf_monitoring_python_deps")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "pybind_extension")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "pywrap_tensorflow_macro")
@ -5780,17 +5777,28 @@ py_strict_library(
],
)
pybind_extension(
tf_python_pybind_extension(
name = "_tf_stack",
srcs = ["util/tf_stack.cc"],
hdrs = [
"//tensorflow/c:headers",
"//tensorflow/c/eager:headers",
],
# TODO(b/138203821): change to "util._tf_stack" once the bug is fixed.
module_name = "_tf_stack",
deps = [
":stack_trace",
"//tensorflow/c:pywrap_required_hdrs",
"//tensorflow/core/common_runtime:core_cpu_headers_lib",
"//tensorflow/core/framework:pywrap_required_hdrs",
"//tensorflow/core/platform:path",
"//third_party/python_runtime:headers", # buildcleaner: keep
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@pybind11",
],
@ -5817,6 +5825,8 @@ cc_library(
"//tensorflow/core/util:abstract_stack_trace",
"//third_party/python_runtime:headers", # buildcleaner: keep
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/types:optional",
],

View File

@ -1991,7 +1991,6 @@ class Operation(object):
# pylint: disable=protected-access
self._original_op = original_op
self._traceback = tf_stack.extract_stack()
# List of _UserDevSpecs holding code location of device context manager
# invocations and the users original argument to them.
@ -2019,6 +2018,9 @@ class Operation(object):
self._c_op = _create_c_op(self._graph, node_def, inputs,
control_input_ops, op_def)
name = compat.as_str(node_def.name)
self._traceback = tf_stack.extract_stack_for_node(self._c_op)
# pylint: enable=protected-access
self._is_stateful = op_def.is_stateful

View File

@ -40,7 +40,8 @@ const char* GetPythonString(PyObject* o) {
namespace tensorflow {
std::vector<StackFrame> StackTrace::ToStackFrames(
const StackTraceMapper& mapper, const StackTraceFilter& filtered) const {
const StackTraceMap& mapper, const StackTraceFilter& filtered) const {
DCheckPyGilStateForStackTrace();
std::vector<StackFrame> result;
result.reserve(code_objs_.size());
@ -49,13 +50,14 @@ std::vector<StackFrame> StackTrace::ToStackFrames(
const int line_number =
PyCode_Addr2Line(code_objs_[i], last_instructions_[i]);
if (!result.empty() && filtered && filtered(file_name)) {
if (!result.empty() && filtered.count(file_name)) {
continue; // Never filter the innermost frame.
}
if (absl::optional<StackFrame> mapped =
mapper ? mapper(file_name, line_number) : absl::nullopt) {
result.push_back(*mapped);
auto it = mapper.find(std::make_pair(file_name, line_number));
if (it != mapper.end()) {
result.push_back(it->second);
} else {
result.emplace_back(StackFrame{file_name, line_number,
GetPythonString(code_objs_[i]->co_name)});

View File

@ -26,6 +26,8 @@ limitations under the License.
#include "absl/base/attributes.h"
#include "absl/base/optimization.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "tensorflow/core/util/managed_stack_trace.h"
@ -41,11 +43,11 @@ inline void DCheckPyGilStateForStackTrace() {
}
// Maps filename/line_no combination into a stack frame.
using StackTraceMapper =
std::function<absl::optional<StackFrame>(std::string, int)>;
using StackTraceMap =
absl::flat_hash_map<std::pair<std::string, int>, StackFrame>;
// Returns "true" for filenames which should be skipped.
using StackTraceFilter = std::function<bool(std::string)>;
// Contains filenames which should be skipped.
using StackTraceFilter = absl::flat_hash_set<std::string>;
// A class for capturing Python stack trace.
class StackTrace final {
@ -92,10 +94,8 @@ class StackTrace final {
ABSL_ATTRIBUTE_HOT
StackTrace& operator=(StackTrace&& other) {
Clear();
code_objs_ = other.code_objs_;
last_instructions_ = other.last_instructions_;
other.code_objs_ = {};
std::swap(code_objs_, other.code_objs_);
std::swap(last_instructions_, other.last_instructions_);
return *this;
}
@ -104,20 +104,22 @@ class StackTrace final {
// returns `true` for the stack frames which should be omitted, and if
// `drop_last` is set, the last stack frame is dropped.
std::vector<StackFrame> ToStackFrames(
const StackTraceMapper& mapper = {},
const StackTraceMap& mapper = {},
const StackTraceFilter& filtered = {}) const;
private:
absl::InlinedVector<PyCodeObject*, kStackTraceInitialSize> code_objs_;
absl::InlinedVector<int, kStackTraceInitialSize> last_instructions_;
// Python GIL must be acquired beforehand.
ABSL_ATTRIBUTE_HOT
void Clear() {
DCheckPyGilStateForStackTrace();
if (!code_objs_.empty()) DCheckPyGilStateForStackTrace();
for (PyCodeObject* obj : code_objs_) Py_DECREF(obj);
code_objs_.clear();
last_instructions_.clear();
}
private:
absl::InlinedVector<PyCodeObject*, kStackTraceInitialSize> code_objs_;
absl::InlinedVector<int, kStackTraceInitialSize> last_instructions_;
StackTrace(const StackTrace&) = delete;
StackTrace& operator=(const StackTrace&) = delete;
};

View File

@ -13,6 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// We extract stack traces in Python using the logic in tf_stack.cc, which
// stores a list of PyCodeObject*. Such stack trace extraction is really fast.
//
// We store the retrieved stack trace within the Node object directly. Then
// whenever the graph is instantiated/copied, we copy the stack trace with it.
// Since the graph instantiation goes through the protobuf roundtrip, we store
// the original Graph with stack traces attached in FunctionLibraryDefinition.
#include <Python.h>
#include <frameobject.h>
@ -20,12 +28,18 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/hash/hash.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/python/util/stack_trace.h"
struct StackFrame; // Forward declaration.
@ -40,76 +54,155 @@ namespace {
namespace py = pybind11;
py::object LineContents(const StackFrame& frame) {
// Returns contents of the line corresponding to the given frame.
//
// Precondition: must be holding Python GIL.
py::str LineContents(const StackFrame& frame) {
DCheckPyGilStateForStackTrace();
static const auto* linecache =
new py::module(py::module::import("linecache"));
const auto& checkcache = linecache->attr("checkcache");
const auto& getline = linecache->attr("getline");
checkcache(py::str(frame.file_name));
const auto& code = py::cast<py::str>(
return py::cast<py::str>(
getline(py::str(frame.file_name), py::int_(frame.line_number))
.attr("strip")());
ssize_t size = 0;
#if PY_MAJOR_VERSION == 3
if (PyUnicode_AsUTF8AndSize(code.ptr(), &size) == nullptr) {
throw py::error_already_set();
}
// Ignores the frames containing this substring for common prefix calculation.
static const char* kFilenameToIgnorePrefix = "<embedded";
// Converts the given stack frame to string, according to options defined in
// `opts`.
std::string StackFrameToString(
const StackFrame& frame,
const AbstractStackTrace::TracePrintingOptions& opts,
int shared_prefix_size = 0) {
std::string out = absl::StrFormat(
"File \"%s\", line %d, in %s",
absl::StrContains(frame.file_name, kFilenameToIgnorePrefix)
? frame.file_name
: frame.file_name.substr(shared_prefix_size),
frame.line_number, frame.function_name);
if (opts.show_line_contents) {
PyGILState_STATE state = PyGILState_Ensure();
std::string line_contents = std::string(LineContents(frame));
PyGILState_Release(state);
if (!line_contents.empty()) {
absl::StrAppend(&out, "\n ", line_contents);
}
#else
size = PyString_Size(code.ptr());
#endif
return size > 0 ? static_cast<py::object>(code) : py::none();
}
return out;
}
std::string StackFrameToString(const StackFrame& frame) {
return py::str("<FrameSummary file {}, line {} in {}>")
.format(py::str(frame.file_name), py::int_(frame.line_number),
py::str(frame.function_name));
}
class StackTraceWrapper {
class StackTraceWrapper : public AbstractStackTrace {
public:
StackTraceWrapper(StackTrace&& captured,
const StackTraceMapper& stack_trace_mapper,
const StackTraceFilter& stack_trace_filter)
StackTraceWrapper(StackTrace&& captured, const py::dict& source_map,
const py::set& filtered_filenames)
: captured_(std::move(captured)),
stack_trace_mapper_(stack_trace_mapper),
stack_trace_filter_(stack_trace_filter) {}
source_map_(source_map),
filtered_filenames_(filtered_filenames) {}
explicit StackTraceWrapper(absl::Span<StackFrame const> stack_frames)
: stack_frames_cache_(std::vector<StackFrame>(stack_frames.begin(),
stack_frames.end())) {}
absl::Span<StackFrame const> ToFrames() const {
// Captures the current Python stack trace, at most `limit` frames deep
// (where None or -1 means "unlimited"), together with the currently active
// source mapper and filename filter.
//
// Precondition: must be called with the Python GIL held (StackTrace::Capture
// walks the Python frame stack).
static StackTraceWrapper ExtractStack(const py::object& limit,
                                      const py::list& mappers,
                                      const py::list& filters) {
  // In Python 3.X ``traceback.extract_stack`` allows ``limit`` to
  // either be None or -1.
  int casted_limit = limit.is_none() ? -1 : py::cast<ssize_t>(limit);

  // Raise limit by one since we are dropping the last frame.
  if (casted_limit != -1) casted_limit++;

  // Only the innermost (last-pushed) mapper/filter on each stack is
  // consulted; an empty stack yields an empty map/set.
  const py::dict& source_map =
      mappers.empty()
          ? py::dict()
          : mappers[mappers.size() - 1].attr("get_effective_source_map")();
  const py::set& filtered_filenames =
      filters.empty()
          ? py::set()
          : filters[filters.size() - 1].attr("get_filtered_filenames")();
  return StackTraceWrapper{StackTrace::Capture(casted_limit), source_map,
                           filtered_filenames};
}
absl::Span<StackFrame const> ToFrames() const override {
GenerateCache();
return *stack_frames_cache_;
}
std::string ToString() const {
std::string ToString(const TracePrintingOptions& opts) const override {
GenerateCache();
return absl::StrJoin(*stack_frames_cache_, "\n",
[&](std::string* out, const StackFrame& frame) {
absl::StrAppend(out, StackFrameToString(frame));
});
std::vector<std::string> files_to_find_prefix;
for (const StackFrame& frame : *stack_frames_cache_) {
if (!absl::StrContains(frame.file_name, kFilenameToIgnorePrefix)) {
files_to_find_prefix.push_back(frame.file_name);
}
}
int shared_prefix_size =
opts.filter_common_prefix
? io::CommonPathPrefix(files_to_find_prefix).size()
: 0;
return absl::StrJoin(
*stack_frames_cache_, "\n",
[&](std::string* out, const StackFrame& frame) {
absl::StrAppend(out,
StackFrameToString(frame, opts, shared_prefix_size));
});
}
bool IsCacheGenerated() const { return stack_frames_cache_.has_value(); }
void GenerateCache() const {
// Grabbing the GIL solves two purposes: 1) makes the class thread-safe, and
// 2) ToStackFrames and LineContents actually need it.
PyGILState_STATE state = PyGILState_Ensure();
if (stack_frames_cache_) {
return;
}
stack_frames_cache_ =
captured_.ToStackFrames(stack_trace_mapper_, stack_trace_filter_);
absl::flat_hash_map<std::pair<std::string, int>, StackFrame> m;
absl::flat_hash_set<std::string> f;
for (const std::pair<py::handle, py::handle>& p : *source_map_) {
const py::tuple& key = py::cast<py::tuple>(p.first);
const py::tuple& value = py::cast<py::tuple>(p.second);
m.emplace(std::make_pair(std::string(py::cast<py::str>(key[0])),
py::cast<ssize_t>(key[1])),
StackFrame{std::string(py::cast<py::str>(value[0])),
py::cast<py::int_>(value[1]),
std::string(py::cast<py::str>(value[2]))});
}
for (const py::handle& h : *filtered_filenames_) {
f.emplace(py::cast<py::str>(h));
}
stack_frames_cache_ = captured_.ToStackFrames(m, f);
stack_frames_cache_->pop_back(); // Drop last stack frame.
PyGILState_Release(state);
}
StackTraceWrapper(StackTraceWrapper&&) = default;
// Destruction must happen under the GIL: Clear() DECREFs the captured
// PyCodeObject references, and resetting the optional-wrapped pybind11
// containers (source_map_/filtered_filenames_) destroys Python objects.
~StackTraceWrapper() override {
  PyGILState_STATE state = PyGILState_Ensure();
  captured_.Clear();
  source_map_.reset();
  filtered_filenames_.reset();
  PyGILState_Release(state);
}
private:
mutable absl::optional<std::vector<StackFrame>> stack_frames_cache_;
StackTrace captured_;
// TODO(cheshire): store those as C++ datastructures instead.
StackTraceMapper stack_trace_mapper_;
StackTraceFilter stack_trace_filter_;
// Using optional to force destruction while we hold a GIL.
absl::optional<py::dict> source_map_;
absl::optional<py::set> filtered_filenames_;
};
} // namespace
@ -126,8 +219,7 @@ PYBIND11_MODULE(_tf_stack, m) {
"name",
[](const StackFrame& self) { return py::str(self.function_name); })
.def_property_readonly(
"line",
[](const StackFrame& self) { return py::str(LineContents(self)); })
"line", [](const StackFrame& self) { return LineContents(self); })
// For compatibility with the traceback module.
.def("__eq__", &StackFrame::operator==)
@ -153,7 +245,7 @@ PYBIND11_MODULE(_tf_stack, m) {
);
})
.def("__repr__",
[](const StackFrame& self) { return StackFrameToString(self); })
[](const StackFrame& self) { return StackFrameToString(self, {}); })
.def("__len__", [](const StackFrame&) { return 4; });
py::class_<StackTraceWrapper>(m, "StackTraceWrapper", py::module_local(true))
@ -201,55 +293,32 @@ PYBIND11_MODULE(_tf_stack, m) {
})
.def("__hash__",
[](const StackTraceWrapper& self) {
return py::hash(py::str(self.ToString()));
self.GenerateCache();
return py::hash(py::str(self.ToString({})));
})
.def("__repr__", [](const StackTraceWrapper& self) {
self.GenerateCache();
return py::str(self.ToString());
return py::str(self.ToString({}));
});
m.def(
"extract_stack_for_node",
[](const py::object& limit, const py::list& mappers,
const py::list& filters,
TF_Operation* op) -> const AbstractStackTrace& {
Node* node = reinterpret_cast<Node*>(op);
DCHECK(!node->GetStackTrace()) << "Should not reset the stack trace";
node->SetStackTrace(std::make_shared<StackTraceWrapper>(
StackTraceWrapper::ExtractStack(limit, mappers, filters)));
return *node->GetStackTrace();
},
py::return_value_policy::reference);
m.def(
"extract_stack",
[](const py::object& limit, const py::list& mappers,
const py::list& filters) {
// In Python 3.X ``traceback.extract_stack`` allows ``limit`` to
// either be None or -1.
int casted_limit = limit.is_none() ? -1 : py::cast<ssize_t>(limit);
// Raise limit by one since we are dropping the last frame.
if (casted_limit != -1) casted_limit++;
const py::dict& source_map = mappers.empty()
? py::dict()
: mappers[mappers.size() - 1].attr(
"get_effective_source_map")();
const py::set& filtered_filenames =
filters.empty()
? py::set()
: filters[filters.size() - 1].attr("get_filtered_filenames")();
auto mapper = [=](std::string filename,
int line_no) -> absl::optional<StackFrame> {
if (source_map.empty()) {
return absl::nullopt;
}
const auto& key =
py::make_tuple(py::str(filename), py::int_(line_no));
if (source_map.contains(key)) {
const py::tuple& mapped = source_map[key];
return StackFrame{std::string(py::cast<py::str>(mapped[0])),
py::cast<py::int_>(mapped[1]),
std::string(py::cast<py::str>(mapped[2]))};
}
return absl::nullopt;
};
auto filter = [=](std::string filename) -> bool {
return filtered_filenames.contains(py::str(filename));
};
return StackTraceWrapper{StackTrace::Capture(casted_limit), mapper,
filter};
return StackTraceWrapper::ExtractStack(limit, mappers, filters);
},
py::return_value_policy::move);
}

View File

@ -141,15 +141,39 @@ def extract_stack(limit=-1):
limit: A limit on the number of frames to return.
Returns:
A sequence of StackFrame objects (filename, lineno, name, line)
corresponding to the call stack of the current thread.
An object wrapping the sequence of StackFrame objects (filename, lineno,
name, line) corresponding to the call stack of the current thread. The
returned object can be indexed as a Python list.
"""
# N.B ExtractStack in tf_stack.cc will drop this frame prior to
# traversing the stack.
# TODO(cheshire): Remove this function, use extract_stack_for_node or Python
# traceback module.
thread_key = _get_thread_key()
return _tf_stack.extract_stack(limit, _source_mapper_stacks[thread_key],
_source_filter_stacks[thread_key])
def extract_stack_for_node(node, limit=-1):
  """Same as extract_stack, but also saves the retrieved stack in `node`.

  Args:
    node: Pointer to the Node object.
    limit: A limit on the number of frames to return.

  Returns:
    An object wrapping the sequence of StackFrame objects (filename, lineno,
    name, line) corresponding to the call stack of the current thread. The
    returned object can be indexed as a Python list.
  """
  # N.B. ExtractStack in tf_stack.cc will drop this frame prior to
  # traversing the stack.
  key = _get_thread_key()
  mappers = _source_mapper_stacks[key]
  filters = _source_filter_stacks[key]
  return _tf_stack.extract_stack_for_node(limit, mappers, filters, node)
StackSummary = _tf_stack.StackTraceWrapper
FrameSummary = _tf_stack.StackFrame