diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index f56f8ad0a4b..0f728f1ebc3 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -308,6 +308,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/util:abstract_stack_trace", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/c/eager/immediate_execution_operation.h b/tensorflow/c/eager/immediate_execution_operation.h index 31a75c5b8c7..8e22fb2d8b5 100644 --- a/tensorflow/c/eager/immediate_execution_operation.h +++ b/tensorflow/c/eager/immediate_execution_operation.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/status.h" +#include "tensorflow/core/util/abstract_stack_trace.h" struct TFE_Op; @@ -44,6 +45,12 @@ class ImmediateExecutionOperation : public AbstractOperation { // Experimental virtual Status SetUseXla(bool enable) = 0; + // Set stack trace to be used for potential async error reporting. + virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0; + + // Returns the stack trace set by `SetStackTrace` if exists. + virtual absl::optional GetStackTrace() = 0; + // For LLVM style RTTI. static bool classof(const AbstractOperation* ptr) { return ptr->getKind() == kEager || ptr->getKind() == kTfrt; diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index 003a4e5996f..b4c905f220e 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -163,6 +163,7 @@ tf_cuda_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core:lib", "//tensorflow/core/platform:platform_port", + "//tensorflow/core/util:abstract_stack_trace", ] + select({ "//tensorflow:android": [ "//tensorflow/core:portable_tensorflow_lib_lite", diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index 47610629479..6dbc342c1bd 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -306,6 +306,7 @@ Status EagerOperation::Reset( } attrs_.Reset(op); use_xla_ = false; + stack_trace_.reset(); is_function_ = is_function; cancellation_manager_ = nullptr; executor_ = executor ? executor : &ctx_.Executor(); diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index dad578ba9f0..9fc35a18a7f 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/util/abstract_stack_trace.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { @@ -120,6 +121,14 @@ class EagerOperation : public ImmediateExecutionOperation { Status SetUseXla(bool enable) override; + void SetStackTrace(AbstractStackTrace stack_trace) override { + stack_trace_ = stack_trace; + } + + absl::optional GetStackTrace() override { + return stack_trace_; + } + Status Reset(const char* op, const char* device_name, bool remote, EagerExecutor* executor, const absl::optional @@ -218,6 +227,7 @@ class EagerOperation : public ImmediateExecutionOperation { VariantDevice device_; bool use_xla_ = false; + absl::optional stack_trace_; bool is_function_; // Conceptually const, but can't be because of Reset bool colocation_exempt_; CancellationManager* cancellation_manager_ = nullptr; // Not owned. diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 71d781e5d3d..fec31da703e 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -634,7 +634,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, auto node = absl::make_unique( &ctx, op->Inputs(), op->remote_func_params(), std::move(kernel), graph_collector, op->GetCancellationManager(), - absl::Span(retvals, num_outputs)); + absl::Span(retvals, num_outputs), op->GetStackTrace()); // Release the inputs from the eager operation since the AsyncExecuteNode // would have taken ownership. This allows the inputs to be forwarded if // possible. diff --git a/tensorflow/core/common_runtime/eager/execute_node.h b/tensorflow/core/common_runtime/eager/execute_node.h index 7924471066e..6d11ecbf7a4 100644 --- a/tensorflow/core/common_runtime/eager/execute_node.h +++ b/tensorflow/core/common_runtime/eager/execute_node.h @@ -150,14 +150,16 @@ class AsyncExecuteNode : public EagerNode { core::RefCountPtr kernel, GraphCollector* graph_collector, CancellationManager* cancellation_manager, - absl::Span retvals) + absl::Span retvals, + absl::optional stack_trace) : EagerNode(), ctx_(ctx), inputs_(inputs), remote_func_params_(remote_func_params), kernel_(std::move(kernel)), graph_collector_(graph_collector), - cancellation_manager_(cancellation_manager) { + cancellation_manager_(cancellation_manager), + stack_trace_(stack_trace) { // Copy the output handles, since the container for them might get // destroyed. for (auto handle : retvals) { @@ -194,10 +196,14 @@ class AsyncExecuteNode : public EagerNode { } ++i; } - const Status status = EagerKernelExecute( + Status status = EagerKernelExecute( ctx_, inputs_, remote_func_params_, kernel_, graph_collector_, cancellation_manager_, absl::MakeSpan(retvals_)); if (!status.ok()) { + if (stack_trace_.has_value()) { + status = Status(status.code(), status.error_message(), + stack_trace_->ToStackFrames()); + } Abort(status); return status; } @@ -227,6 +233,7 @@ class AsyncExecuteNode : public EagerNode { core::RefCountPtr kernel_; GraphCollector* graph_collector_; CancellationManager* const cancellation_manager_; + absl::optional stack_trace_; absl::InlinedVector retvals_; }; diff --git a/tensorflow/core/platform/errors.h b/tensorflow/core/platform/errors.h index 3f1ff477655..55af45a4c24 100644 --- a/tensorflow/core/platform/errors.h +++ b/tensorflow/core/platform/errors.h @@ -62,9 +62,11 @@ inline const strings::AlphaNum& PrepareForStrCat(const strings::AlphaNum& a) { // to be several layers of additional context. template void AppendToMessage(::tensorflow::Status* status, Args... args) { + std::vector stack_trace = status->stack_trace(); *status = ::tensorflow::Status( status->code(), - ::tensorflow::strings::StrCat(status->error_message(), "\n\t", args...)); + ::tensorflow::strings::StrCat(status->error_message(), "\n\t", args...), + std::move(stack_trace)); } // For propagating errors when calling a function. diff --git a/tensorflow/core/platform/status.cc b/tensorflow/core/platform/status.cc index c85527f27ad..04f74d024ca 100644 --- a/tensorflow/core/platform/status.cc +++ b/tensorflow/core/platform/status.cc @@ -89,11 +89,13 @@ class StatusLogSink : public TFLogSink { } // namespace -Status::Status(tensorflow::error::Code code, StringPiece msg) { +Status::Status(tensorflow::error::Code code, tensorflow::StringPiece msg, + std::vector&& stack_trace) { assert(code != tensorflow::error::OK); state_ = std::unique_ptr(new State); state_->code = code; state_->msg = string(msg); + state_->stack_trace = std::move(stack_trace); VLOG(5) << "Generated non-OK status: \"" << *this << "\". " << CurrentStackTrace(); } @@ -117,6 +119,11 @@ const string& Status::empty_string() { return *empty; } +const std::vector& Status::empty_stack_trace() { + static std::vector* empty = new std::vector(); + return *empty; +} + string error_name(error::Code code) { switch (code) { case tensorflow::error::OK: diff --git a/tensorflow/core/platform/status.h b/tensorflow/core/platform/status.h index 5ee93a179db..fc570caf6b1 100644 --- a/tensorflow/core/platform/status.h +++ b/tensorflow/core/platform/status.h @@ -29,6 +29,13 @@ limitations under the License. namespace tensorflow { +// A struct representing a frame in a stack trace. +struct StackFrame { + std::string file_name; + int line_number; + std::string function_name; +}; + #if defined(__clang__) // Only clang supports warn_unused_result as a type annotation. class TF_MUST_USE_RESULT Status; @@ -43,7 +50,15 @@ class Status { /// \brief Create a status with the specified error code and msg as a /// human-readable string containing more detailed information. - Status(tensorflow::error::Code code, tensorflow::StringPiece msg); + Status(tensorflow::error::Code code, tensorflow::StringPiece msg) + : Status(code, msg, {}) {} + + /// \brief Create a status with the specified error code, msg, and stack trace + /// as a human-readable string containing more detailed information. +#ifndef SWIG + Status(tensorflow::error::Code code, tensorflow::StringPiece msg, + std::vector&& stack_trace); +#endif /// Copy the specified status. Status(const Status& s); @@ -66,6 +81,10 @@ class Status { return ok() ? empty_string() : state_->msg; } + const std::vector& stack_trace() const { + return ok() ? empty_stack_trace() : state_->stack_trace; + } + bool operator==(const Status& x) const; bool operator!=(const Status& x) const; @@ -91,9 +110,11 @@ class Status { private: static const std::string& empty_string(); + static const std::vector& empty_stack_trace(); struct State { tensorflow::error::Code code; std::string msg; + std::vector stack_trace; }; // OK status has a `NULL` state_. Otherwise, `state_` points to // a `State` structure containing the error code and message(s) diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index 962beb55e05..78757bed13e 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -61,6 +61,7 @@ filegroup( filegroup( name = "mobile_srcs_only_runtime", srcs = [ + "abstract_stack_trace.h", "batch_util.cc", "batch_util.h", "bcast.cc", @@ -313,6 +314,7 @@ filegroup( filegroup( name = "framework_srcs", srcs = [ + "abstract_stack_trace.h", "activation_mode.h", "batch_util.h", "bcast.h", @@ -437,6 +439,22 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "abstract_stack_trace", + hdrs = ["abstract_stack_trace.h"], + visibility = [ + "//tensorflow/c/eager:__pkg__", + "//tensorflow/core:__pkg__", + "//tensorflow/core/common_runtime/eager:__pkg__", + "//tensorflow/core/platform:__pkg__", + "//tensorflow/python:__pkg__", + "//tensorflow/python/eager:__pkg__", + ], + deps = [ + "//tensorflow/core/platform:status", + ], +) + tf_cuda_library( name = "gpu_cuda_alias", hdrs = ["gpu_cuda_alias.h"], diff --git a/tensorflow/core/util/abstract_stack_trace.h b/tensorflow/core/util/abstract_stack_trace.h new file mode 100644 index 00000000000..442adc6f380 --- /dev/null +++ b/tensorflow/core/util/abstract_stack_trace.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_ABSTRACT_STACK_TRACE_H_ +#define TENSORFLOW_CORE_UTIL_ABSTRACT_STACK_TRACE_H_ + +#include + +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Language agnostic stack trace class. It only saves an id, and language +// clients are responsible for managing the actual stack trace objects. +class AbstractStackTrace { + public: + AbstractStackTrace(int id, std::vector (*to_stack_frames)(int)) + : id_(id), to_stack_frames_(to_stack_frames) {} + + // Returns stack trace as a vector of `StackFrame`s. + std::vector ToStackFrames() const { + return to_stack_frames_(id_); + } + + private: + int id_; + std::vector (*to_stack_frames_)(int); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_ABSTRACT_STACK_TRACE_H_ diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 7f40b0dac95..09d22aa203d 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -5664,9 +5664,11 @@ cc_library( hdrs = ["util/stack_trace.h"], deps = [ ":py_util", + "//tensorflow/core/platform:str_util", + "//tensorflow/core/platform:stringpiece", + "//tensorflow/core/util:abstract_stack_trace", "//third_party/python_runtime:headers", # buildcleaner: keep "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 65c99b8c6e5..a96d2322b88 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -56,13 +56,16 @@ cc_library( "//tensorflow/core/platform:logging", "//tensorflow/core/platform:types", "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/util:abstract_stack_trace", "//tensorflow/python:cpp_python_util", "//tensorflow/python:ndarray_tensor", "//tensorflow/python:ndarray_tensor_bridge", "//tensorflow/python:numpy_lib", "//tensorflow/python:py_exception_registry", "//tensorflow/python:py_seq_tensor", + "//tensorflow/python:py_util", "//tensorflow/python:safe_ptr", + "//tensorflow/python:stack_trace", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", "@com_google_absl//absl/container:flat_hash_map", diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py index b996d0dd0c4..a859f4edf01 100644 --- a/tensorflow/python/eager/ops_test.py +++ b/tensorflow/python/eager/ops_test.py @@ -30,6 +30,7 @@ from tensorflow.python.eager import test from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util @@ -480,6 +481,24 @@ class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): self.assertIs(weak_x(), None) self.assertIs(weak_y(), None) + def testAsyncExceptionStackTrace(self): + config.set_synchronous_execution(False) + + def exception_originated_from_here(): + # Invalid shapes for matmul. + return math_ops.matmul([[1]], [[2], [3]]) + + # In sync mode, an exception would have been raised here but since this is + # in async, the exception will be raised next. + x = exception_originated_from_here() + + with self.assertRaisesRegex(errors_impl.InvalidArgumentError, + 'in exception_originated_from_here'): + x.numpy() + + context.async_clear_error() + config.set_synchronous_execution(True) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index a4c06f8e72f..dcaaafeda5c 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -41,10 +41,13 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/util/abstract_stack_trace.h" #include "tensorflow/python/eager/pywrap_gradient_exclusions.h" #include "tensorflow/python/eager/pywrap_tensor.h" #include "tensorflow/python/eager/pywrap_tfe.h" +#include "tensorflow/python/lib/core/py_util.h" #include "tensorflow/python/lib/core/safe_ptr.h" +#include "tensorflow/python/util/stack_trace.h" #include "tensorflow/python/util/util.h" using tensorflow::string; @@ -854,10 +857,14 @@ void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name, TF_Status* out_status) { tensorflow::profiler::TraceMe activity( "TFE_Py_ExecuteCancelable", tensorflow::profiler::TraceMeLevel::kInfo); + TFE_Op* op = GetOp(ctx, op_name, device_name, out_status); + auto cleaner = tensorflow::gtl::MakeCleanup([ctx, op] { ReturnOp(ctx, op); }); if (!out_status->status.ok()) return; + tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace()); + for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) { TFE_OpAddInput(op, inputs->at(i), out_status); } @@ -970,14 +977,54 @@ void RaiseFallbackException(const char* message) { .data()); } +// Format and return `status`' error message with the attached stack trace if +// available. `status` must have an error. +std::string FormatErrorStatusStackTrace(const tensorflow::Status& status) { + tensorflow::DCheckPyGilState(); + DCHECK(!status.ok()); + + if (status.stack_trace().empty()) return status.error_message(); + + const std::vector& stack_trace = status.stack_trace(); + + PyObject* linecache = PyImport_ImportModule("linecache"); + PyObject* getline = + PyObject_GetAttr(linecache, PyUnicode_FromString("getline")); + DCHECK(getline); + + std::ostringstream result; + result << "Exception originated from\n\n"; + + for (const tensorflow::StackFrame& stack_frame : stack_trace) { + PyObject* line_str_obj = PyObject_CallFunction( + getline, const_cast("si"), stack_frame.file_name.c_str(), + stack_frame.line_number); + tensorflow::StringPiece line_str = TFE_GetPythonString(line_str_obj); + tensorflow::str_util::RemoveWhitespaceContext(&line_str); + result << " File \"" << stack_frame.file_name << "\", line " + << stack_frame.line_number << ", in " << stack_frame.function_name + << '\n'; + + if (!line_str.empty()) result << " " << line_str << '\n'; + Py_XDECREF(line_str_obj); + } + + Py_DecRef(getline); + Py_DecRef(linecache); + + result << '\n' << status.error_message(); + return result.str(); +} + int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) { if (status->status.ok()) return 0; const char* msg = TF_Message(status); if (exception == nullptr) { tensorflow::mutex_lock l(exception_class_mutex); if (exception_class != nullptr) { - tensorflow::Safe_PyObjectPtr val( - Py_BuildValue("si", msg, TF_GetCode(status))); + tensorflow::Safe_PyObjectPtr val(Py_BuildValue( + "si", FormatErrorStatusStackTrace(status->status).c_str(), + TF_GetCode(status))); if (PyErr_Occurred()) { // NOTE: This hides the actual error (i.e. the reason `status` was not // TF_OK), but there is nothing we can do at this point since we can't @@ -1003,7 +1050,8 @@ int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status, if (exception == nullptr) { tensorflow::mutex_lock l(exception_class_mutex); if (exception_class != nullptr) { - tensorflow::Safe_PyObjectPtr val(Py_BuildValue("si", msg, status.code())); + tensorflow::Safe_PyObjectPtr val(Py_BuildValue( + "si", FormatErrorStatusStackTrace(status).c_str(), status.code())); PyErr_SetObject(exception_class, val.get()); return -1; } else { @@ -3527,6 +3575,8 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) { } TFE_Op* op = GetOp(ctx, op_name, op_exec_info.device_name, status); + tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace()); + auto cleaner = tensorflow::gtl::MakeCleanup([status, ctx, op] { ReturnStatus(status); ReturnOp(ctx, op); @@ -3746,11 +3796,14 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) { if (!status->status.ok()) { // Augment the status with the op_name for easier debugging similar to // TFE_Py_Execute. - TF_SetStatus(status, TF_GetCode(status), - tensorflow::strings::StrCat( - TF_Message(status), - " [Op:", TFE_GetPythonString(op_exec_info.op_name), "]") - .c_str()); + std::vector stack_trace = + status->status.stack_trace(); + status->status = tensorflow::Status( + status->status.code(), + tensorflow::strings::StrCat( + TF_Message(status), + " [Op:", TFE_GetPythonString(op_exec_info.op_name), "]"), + std::move(stack_trace)); MaybeRaiseExceptionFromTFStatus(status, nullptr); return nullptr; diff --git a/tensorflow/python/util/stack_trace.cc b/tensorflow/python/util/stack_trace.cc index cf574f6f292..04b427fd67b 100644 --- a/tensorflow/python/util/stack_trace.cc +++ b/tensorflow/python/util/stack_trace.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/python/util/stack_trace.h" +#include "tensorflow/core/platform/str_util.h" +#include "tensorflow/core/platform/stringpiece.h" + namespace { // Returns C string from a Python string object. Handles Python2/3 strings. @@ -31,22 +34,33 @@ const char* GetPythonString(PyObject* o) { return PyBytes_AsString(o); #endif } + } // namespace namespace tensorflow { -std::string StackTrace::ToString() const { - DCheckPyGilState(); - std::ostringstream result; +std::vector StackTrace::ToStackFrames() const { + std::vector result; + result.reserve(size_); + for (int i = size_ - 1; i >= 0; --i) { - result << " File \"" << PyUnicode_AsUTF8(code_objs_[i]->co_filename) - << "\", line " - << PyCode_Addr2Line(code_objs_[i], last_instructions_[i]) << ", in " - << GetPythonString(code_objs_[i]->co_name) - << "\n \n"; - // TODO(kkb): Add source code line. See tf_stack.cc's - // FrameSummary::line() function. + const char* file_name = GetPythonString(code_objs_[i]->co_filename); + const int line_number = + PyCode_Addr2Line(code_objs_[i], last_instructions_[i]); + result.emplace_back(StackFrame{file_name, line_number, + GetPythonString(code_objs_[i]->co_name)}); } - return result.str(); + + return result; } + +StackTrace* StackTraceManager::Get(int id) { + DCheckPyGilState(); + if (next_id_ - id > kStackTraceCircularBufferSize) return nullptr; + + return &stack_traces_[id & (kStackTraceCircularBufferSize - 1)]; +} + +StackTraceManager* const stack_trace_manager = new StackTraceManager(); + } // namespace tensorflow diff --git a/tensorflow/python/util/stack_trace.h b/tensorflow/python/util/stack_trace.h index 0b9a737bf7e..732d40c92d2 100644 --- a/tensorflow/python/util/stack_trace.h +++ b/tensorflow/python/util/stack_trace.h @@ -25,6 +25,8 @@ limitations under the License. #include "absl/base/attributes.h" #include "absl/base/optimization.h" +#include "absl/types/optional.h" +#include "tensorflow/core/util/abstract_stack_trace.h" #include "tensorflow/python/lib/core/py_util.h" namespace tensorflow { @@ -82,10 +84,8 @@ class StackTrace final { return *this; } - // Returns string representation of the captured stack trace. - std::string ToString() const; - - // TODO(kkb): Implement structured stack trace object getter. + // Returns a structured representation of the captured stack trace. + std::vector ToStackFrames() const; private: std::array code_objs_; @@ -103,6 +103,53 @@ class StackTrace final { StackTrace& operator=(const StackTrace&) = delete; }; +// A class that manages Python stack traces in a circular buffer. Users can +// insert stack trace entries and retrive them by ids. +class StackTraceManager { + public: + static constexpr int kStackTraceCircularBufferSize = 1024; + + // Captures the current Python stack trace and returns an id. + // Python GIL must be acquired beforehand. + ABSL_MUST_USE_RESULT + ABSL_ATTRIBUTE_HOT + int Capture() { + DCheckPyGilState(); + const int id = next_id_++; + const int index = id & (kStackTraceCircularBufferSize - 1); + stack_traces_[index] = StackTrace::Capture(); + return id; + } + + // Retrieve captured Python stack trace by id. Returns `nullptr` if the + // requested stack trace is evicted from the circular buffer. + // Python GIL must be acquired beforehand. + ABSL_MUST_USE_RESULT + StackTrace* Get(int id); + + private: + int next_id_ = 0; + std::array stack_traces_; +}; + +// Singleton StackTraceManager. +extern StackTraceManager* const stack_trace_manager; + +// Returns Python stack trace object that can be converted to string. +// Note that the actual stack trace is kept in a circular buffer for string +// conversion could fail if it's evicted before. +// Python GIL must be acquired beforehand. +inline AbstractStackTrace GetStackTrace() { + DCheckPyGilState(); + return AbstractStackTrace(stack_trace_manager->Capture(), [](int id) { + PyGILState_STATE gstate = PyGILState_Ensure(); + std::vector result = + stack_trace_manager->Get(id)->ToStackFrames(); + PyGILState_Release(gstate); + return result; + }); +} + } // namespace tensorflow #endif // TENSORFLOW_PYTHON_UTIL_STACK_TRACE_H_