[XLA:Python] Add support for collecting Python tracebacks.

Adds a new xla_client.Traceback API that can collect Python tracebacks cheaply (~2us). This is several orders of magnitude cheaper than using the Python `inspect.stack` API.

Adds a facility to attach a traceback to every buffer and executable object describing its creation context. To avoid paying a runtime cost when not debugging, traceback collection is optional and disabled by default.

PiperOrigin-RevId: 314793876
Change-Id: Ie5f509364065739c1da4d1a8a729c4c6f56e2d03
This commit is contained in:
Peter Hawkins 2020-06-04 13:34:34 -07:00 committed by TensorFlower Gardener
parent 84c796966b
commit 7df1f4ffc3
14 changed files with 518 additions and 20 deletions

View File

@ -123,6 +123,26 @@ cc_library(
],
)
# Library providing TracebackManager (traceback_manager.h / traceback_manager.cc).
# NOTE(review): the copts mirror the other pybind11-dependent rules in this
# BUILD file; presumably exceptions are enabled because pybind11 needs them —
# confirm before changing.
cc_library(
name = "traceback_manager",
srcs = ["traceback_manager.cc"],
hdrs = ["traceback_manager.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
":python_ref_manager",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:optional",
"@pybind11",
],
)
cc_library(
name = "bfloat16",
srcs = ["bfloat16.cc"],
@ -170,10 +190,12 @@ cc_library(
features = ["-use_header_modules"],
deps = [
":python_ref_manager",
":traceback_manager",
":types",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"@com_google_absl//absl/types:optional",
],
)
@ -188,11 +210,13 @@ cc_library(
features = ["-use_header_modules"],
deps = [
":py_buffer",
":traceback_manager",
":types",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
@ -208,6 +232,7 @@ cc_library(
features = ["-use_header_modules"],
deps = [
":py_buffer",
":traceback_manager",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
@ -333,6 +358,7 @@ pybind_extension(
":py_executable",
":python_ref_manager",
":outfeed_receiver_py",
":traceback_manager",
":types",
"@com_google_absl//absl/base",
"@com_google_absl//absl/hash",

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "include/dlpack/dlpack.h" // from @dlpack
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/python/traceback_manager.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
@ -349,7 +350,8 @@ StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
PyCapsule_SetDestructor(tensor.ptr(), nullptr);
auto pjrt_buffer = std::make_unique<PjRtBuffer>(
shape, shape, std::move(device_buffer), client.get(), device);
return std::make_unique<PyBuffer>(std::move(client), std::move(pjrt_buffer));
return std::make_unique<PyBuffer>(std::move(client), std::move(pjrt_buffer),
TracebackManager::Get()->GetTraceback());
}
} // namespace xla

View File

@ -22,8 +22,11 @@ namespace xla {
namespace py = pybind11;
PyBuffer::PyBuffer(std::shared_ptr<PjRtClient> client,
std::unique_ptr<PjRtBuffer> buffer)
: client_(std::move(client)), buffer_(std::move(buffer)) {}
std::unique_ptr<PjRtBuffer> buffer,
absl::optional<TracebackManager::Traceback> traceback)
: client_(std::move(client)),
buffer_(std::move(buffer)),
traceback_(std::move(traceback)) {}
ClientAndPtr<Device> PyBuffer::device() const {
return WrapWithClient(client_, buffer_->device());
@ -33,10 +36,12 @@ StatusOr<std::unique_ptr<PyBuffer>> PyBuffer::CopyToDevice(
const ClientAndPtr<Device>& dst_device) const {
CHECK(dst_device.get() != nullptr);
GlobalPyRefManager()->CollectGarbage();
auto traceback = TracebackManager::Get()->GetTraceback();
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> out,
buffer_->CopyToDevice(dst_device.get()));
return std::make_unique<PyBuffer>(dst_device.client, std::move(out));
return std::make_unique<PyBuffer>(dst_device.client, std::move(out),
traceback);
}
Status PyBuffer::BlockHostUntilReady() {

View File

@ -19,7 +19,9 @@ limitations under the License.
#include <memory>
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/traceback_manager.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@ -32,7 +34,8 @@ namespace xla {
class PyBuffer {
public:
PyBuffer(std::shared_ptr<PjRtClient> client,
std::unique_ptr<PjRtBuffer> buffer);
std::unique_ptr<PjRtBuffer> buffer,
absl::optional<TracebackManager::Traceback> traceback);
std::shared_ptr<PjRtClient> client() const { return client_; }
PjRtBuffer* buffer() const { return buffer_.get(); }
@ -60,9 +63,14 @@ class PyBuffer {
// PEP 3118 Python buffer protocol implementation.
static PyBufferProcs* BufferProtocol();
const absl::optional<TracebackManager::Traceback>& traceback() {
return traceback_;
}
private:
std::shared_ptr<PjRtClient> client_;
std::unique_ptr<PjRtBuffer> buffer_;
absl::optional<TracebackManager::Traceback> traceback_;
};
} // namespace xla

View File

@ -21,9 +21,13 @@ namespace xla {
namespace py = pybind11;
PyExecutable::PyExecutable(std::shared_ptr<PjRtClient> client,
std::unique_ptr<PjRtExecutable> executable)
: client_(std::move(client)), executable_(std::move(executable)) {}
PyExecutable::PyExecutable(
std::shared_ptr<PjRtClient> client,
std::unique_ptr<PjRtExecutable> executable,
absl::optional<TracebackManager::Traceback> traceback)
: client_(std::move(client)),
executable_(std::move(executable)),
traceback_(std::move(traceback)) {}
std::vector<ClientAndPtr<Device>> PyExecutable::LocalDevices() const {
std::vector<ClientAndPtr<Device>> devices;
@ -36,6 +40,7 @@ std::vector<ClientAndPtr<Device>> PyExecutable::LocalDevices() const {
StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
absl::Span<PyBuffer* const> args) {
auto traceback = TracebackManager::Get()->GetTraceback();
py::gil_scoped_release gil_release;
ExecuteOptions options;
options.untuple_result = true;
@ -47,7 +52,8 @@ StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
std::vector<std::unique_ptr<PyBuffer>> outputs;
outputs.reserve(output_buffers.size());
for (auto& buffer : output_buffers) {
outputs.push_back(std::make_unique<PyBuffer>(client_, std::move(buffer)));
outputs.push_back(
std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
}
return outputs;
}
@ -55,6 +61,7 @@ StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
StatusOr<std::vector<std::vector<std::unique_ptr<PyBuffer>>>>
PyExecutable::ExecuteOnLocalDevices(
absl::Span<const std::vector<PyBuffer*>> args) {
auto traceback = TracebackManager::Get()->GetTraceback();
py::gil_scoped_release gil_release;
ExecuteOptions options;
options.untuple_result = true;
@ -73,7 +80,7 @@ PyExecutable::ExecuteOnLocalDevices(
++computation) {
for (auto& buffer : output_buffers[computation]) {
outputs[computation].push_back(
std::make_unique<PyBuffer>(client_, std::move(buffer)));
std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
}
}
return outputs;

View File

@ -20,9 +20,11 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/traceback_manager.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@ -35,7 +37,8 @@ namespace xla {
class PyExecutable {
public:
PyExecutable(std::shared_ptr<PjRtClient> client,
std::unique_ptr<PjRtExecutable> executable);
std::unique_ptr<PjRtExecutable> executable,
absl::optional<TracebackManager::Traceback> traceback);
std::shared_ptr<PjRtClient> client() const { return client_; }
@ -59,9 +62,14 @@ class PyExecutable {
StatusOr<std::vector<std::shared_ptr<HloModule>>> HloModules() const;
const absl::optional<TracebackManager::Traceback>& traceback() {
return traceback_;
}
private:
std::shared_ptr<PjRtClient> client_;
std::unique_ptr<PjRtExecutable> executable_;
absl::optional<TracebackManager::Traceback> traceback_;
};
} // namespace xla

View File

@ -50,6 +50,13 @@ PythonRefManager::ManageReferences(absl::Span<py::object> objects) {
return std::make_shared<ManagedPyObjects>(this, objects);
}
// Queues every object in `garbage` for later release by moving it onto
// python_garbage_ while holding mu_. The elements of `garbage` are left in a
// moved-from state.
void PythonRefManager::AddGarbage(absl::Span<py::object> garbage) {
  absl::MutexLock lock(&mu_);
  for (auto it = garbage.begin(); it != garbage.end(); ++it) {
    python_garbage_.push_back(std::move(*it));
  }
}
void PythonRefManager::CollectGarbage() {
// TODO(phawkins): we should CHECK(PyGILState_Check());
std::deque<pybind11::object> garbage;

View File

@ -66,6 +66,9 @@ class PythonRefManager {
std::shared_ptr<ManagedPyObjects> ManageReferences(
absl::Span<pybind11::object> objects);
// Adds garbage objects to the manager.
void AddGarbage(absl::Span<pybind11::object> garbage);
// Releases the contents of python_garbage_. Requires that the GIL is held.
// The client calls this method during API entry points where the GIL is held
// to free any garbage that has accumulated.

View File

@ -173,9 +173,13 @@ PYBIND11_MODULE(tpu_client_extension, m) {
.def("shape", &PyTpuBuffer::on_host_shape)
.def("device", &PyTpuBuffer::device)
.def("platform", &PyTpuBuffer::platform_name)
.def("is_deleted", [](const PyTpuBuffer& buffer) {
return buffer.DeviceBuffer() == nullptr;
});
.def("is_deleted",
[](const PyTpuBuffer& buffer) {
return buffer.DeviceBuffer() == nullptr;
})
// TODO(phawkins): implement traceback support.
.def_property_readonly("traceback",
[](PyTpuBuffer*) { return py::none(); });
py::class_<PyTpuExecutable>(m, "TpuExecutable")
.def("local_logical_device_ids",
@ -193,7 +197,10 @@ PYBIND11_MODULE(tpu_client_extension, m) {
.def("execute", &PyTpuExecutable::Execute,
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
.def("execute_on_local_devices", &PyTpuExecutable::ExecuteOnLocalDevices,
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"));
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
// TODO(phawkins): implement traceback support.
.def_property_readonly("traceback",
[](PyTpuExecutable*) { return py::none(); });
py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
.def_property_readonly("coords", &TpuDevice::coords)

View File

@ -0,0 +1,196 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/traceback_manager.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
namespace xla {
namespace {
namespace py = pybind11;
} // namespace
// Value type holding the frames of one captured Python traceback. Instances
// are used as keys of TracebackManager's deduplication map, hence the equality
// operator here and the AbslHashValue overload below.
struct TracebackImpl {
  // Captured stack frames, in the order they were walked starting from the
  // thread's current frame.
  std::vector<TracebackManager::Frame> frames;

  // Hands the Python strings held by `frames` to the global reference manager
  // rather than releasing them directly.
  ~TracebackImpl();

  // Computes a traceback for the current Python thread. Requires the GIL.
  // Returns false if the traceback could not be captured.
  bool GetTracebackForCurrentThread();

  // Renders one "file:line (function)" line per frame.
  std::string ToString() const;

  // Frame-wise comparison; Python strings are compared by identity.
  bool operator==(const TracebackImpl& other) const;
};
// Traceback objects should be safe to destroy without holding the GIL, so the
// Python strings are not released here. Instead they are moved out of the
// frames and handed to the global reference manager as deferred garbage.
TracebackImpl::~TracebackImpl() {
  std::vector<py::object> deferred;
  deferred.reserve(frames.size() * 2);
  for (auto it = frames.begin(); it != frames.end(); ++it) {
    deferred.push_back(std::move(it->file_name));
    deferred.push_back(std::move(it->function_name));
  }
  GlobalPyRefManager()->AddGarbage(absl::MakeSpan(deferred));
}
// Returns true iff both tracebacks have the same number of frames and every
// pair of corresponding frames matches in file name, function name, line
// number and function start line.
bool TracebackImpl::operator==(const TracebackImpl& other) const {
  const size_t num_frames = frames.size();
  if (num_frames != other.frames.size()) {
    return false;
  }
  // The index is a size_t to match frames.size(); the previous `int` index
  // mixed signed and unsigned operands in the loop condition.
  for (size_t i = 0; i < num_frames; ++i) {
    const TracebackManager::Frame& lhs = frames[i];
    const TracebackManager::Frame& rhs = other.frames[i];
    // Python strings are compared using pointer equality. This is cheap and
    // does not require calling back into the Python interpreter, but may mean
    // we miss some opportunities for deduplication of TracebackImpl objects.
    // However, we expect that function and file names are drawn from a fixed
    // pool of constants.
    if (lhs.file_name.ptr() != rhs.file_name.ptr() ||
        lhs.function_name.ptr() != rhs.function_name.ptr() ||
        lhs.line_num != rhs.line_num ||
        lhs.function_start_line != rhs.function_start_line) {
      return false;
    }
  }
  return true;
}
template <typename H>
H AbslHashValue(H h, const TracebackImpl& tb) {
for (const TracebackManager::Frame& frame : tb.frames) {
h = H::combine(std::move(h), frame.file_name.ptr(),
frame.function_name.ptr(), frame.line_num);
}
return h;
}
bool TracebackImpl::GetTracebackForCurrentThread() {
PyThreadState* thread_state = PyGILState_GetThisThreadState();
if (!thread_state) {
return false;
}
frames.reserve(32);
for (PyFrameObject* py_frame = thread_state->frame; py_frame != nullptr;
py_frame = py_frame->f_back) {
frames.resize(frames.size() + 1);
TracebackManager::Frame& frame = frames.back();
PyCodeObject* code = py_frame->f_code;
if (!code) {
return false;
}
frame.line_num = PyFrame_GetLineNumber(py_frame);
frame.file_name = py::str(code->co_filename);
frame.function_name = py::str(code->co_name);
frame.function_start_line = code->co_firstlineno;
}
return true;
}
// Formats the traceback as newline-separated "file:line (function)" entries,
// one per frame, in the order stored in `frames`.
std::string TracebackImpl::ToString() const {
  std::vector<std::string> lines;
  lines.reserve(frames.size());
  for (auto it = frames.begin(); it != frames.end(); ++it) {
    lines.push_back(absl::StrFormat("%s:%d (%s)", it->file_name, it->line_num,
                                    it->function_name));
  }
  return absl::StrJoin(lines, "\n");
}
// Creates a handle to the map entry `impl` owned by `manager` and takes one
// reference on it.
TracebackManager::Traceback::Traceback(
    TracebackManager* manager, std::pair<TracebackImpl const, int>* impl)
    : manager_(manager), impl_(impl) {
  DCHECK(manager_);
  impl_->second += 1;
}
// Drops this handle's reference. When the count reaches zero the entry is
// erased from the manager's map. Handles without a manager (default-constructed
// or moved-from) hold no reference and do nothing.
TracebackManager::Traceback::~Traceback() {
  if (manager_ == nullptr) {
    return;
  }
  int& refcount = impl_->second;
  refcount -= 1;
  if (refcount == 0) {
    manager_->tracebacks_.erase(impl_->first);
  }
}
// Copy constructor: shares `other`'s map entry and takes an additional
// reference on it. Copying a handle with no manager yields another empty
// handle.
TracebackManager::Traceback::Traceback(const Traceback& other)
    : manager_(other.manager_), impl_(other.impl_) {
  if (manager_ == nullptr) {
    return;
  }
  impl_->second += 1;
}
// Move constructor: takes over `other`'s reference without touching the
// reference count and leaves `other` empty (null manager and entry).
TracebackManager::Traceback::Traceback(Traceback&& other)
    : manager_(nullptr), impl_(nullptr) {
  std::swap(manager_, other.manager_);
  std::swap(impl_, other.impl_);
}
// Copy assignment: rebinds this handle to the traceback referenced by `other`.
//
// The reference previously held by this handle is released (erasing its map
// entry if it was the last one) and a new reference on `other`'s entry is
// taken. The previous implementation overwrote manager_/impl_ without
// releasing the old reference, which leaked the old entry, and it bumped the
// count on self-assignment with no matching decrement.
//
// Implemented as copy-and-swap, so self-assignment is safe and leaves the
// reference count unchanged. Like every operation that touches the reference
// count, this relies on the map being protected by the GIL.
TracebackManager::Traceback& TracebackManager::Traceback::operator=(
    const TracebackManager::Traceback& other) {
  // Take the new reference first so an entry shared by both handles can never
  // transiently drop to zero.
  Traceback copy(other);
  std::swap(manager_, copy.manager_);
  std::swap(impl_, copy.impl_);
  // `copy` now carries the old reference and releases it on scope exit.
  return *this;
}
// Move assignment: exchanges the contents of the two handles. The reference
// previously held by this handle ends up in `other` and is released when
// `other` is destroyed; no reference count changes here.
TracebackManager::Traceback& TracebackManager::Traceback::operator=(
    TracebackManager::Traceback&& other) {
  TracebackManager* const previous_manager = manager_;
  manager_ = other.manager_;
  other.manager_ = previous_manager;

  std::pair<TracebackImpl const, int>* const previous_impl = impl_;
  impl_ = other.impl_;
  other.impl_ = previous_impl;
  return *this;
}
// Returns the formatted traceback, or "<unknown>" for a handle that is not
// bound to a manager. The GIL must be held (checked) because formatting reads
// Python strings.
std::string TracebackManager::Traceback::ToString() const {
  CHECK(PyGILState_Check());
  // A default-constructed Traceback has no entry to format; do not crash.
  const bool bound = manager_ != nullptr;
  return bound ? impl_->first.ToString() : std::string("<unknown>");
}
const std::vector<TracebackManager::Frame>*
TracebackManager::Traceback::Frames() const {
return &impl_->first.frames;
}
// Returns the manager instance, creating it on first use. The instance is
// heap-allocated and never deleted.
/*static*/ TracebackManager* TracebackManager::Get() {
  static TracebackManager* const instance = new TracebackManager();
  return instance;
}
// Defaulted out of line; members are initialised by the default member
// initialisers in the class definition.
TracebackManager::TracebackManager() = default;
TracebackManager::~TracebackManager() = default;
// Captures the calling thread's Python stack and returns a handle to its
// deduplicated map entry. Returns nullopt when collection is disabled or the
// stack could not be captured. When collection is enabled the GIL must be held
// (checked).
absl::optional<TracebackManager::Traceback> TracebackManager::GetTraceback() {
  if (!enabled_) {
    return absl::nullopt;
  }
  CHECK(PyGILState_Check());
  TracebackImpl captured;
  const bool ok = captured.GetTracebackForCurrentThread();
  if (!ok) {
    return absl::nullopt;
  }
  // New entries start at a count of zero; the Traceback constructor takes the
  // first reference. If an equal traceback already exists, emplace returns the
  // existing entry.
  auto inserted = tracebacks_.emplace(captured, 0);
  std::pair<TracebackImpl const, int>& entry = *inserted.first;
  return Traceback(this, &entry);
}
void TracebackManager::SetEnabled(bool enabled) { enabled_ = enabled; }
} // namespace xla

View File

@ -0,0 +1,102 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TRACEBACK_MANAGER_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_TRACEBACK_MANAGER_H_
#include "absl/container/node_hash_map.h"
#include "absl/types/optional.h"
#include "pybind11/pybind11.h"
namespace xla {
// Forward declaration; the type is defined in traceback_manager.cc.
struct TracebackImpl;
// Abseil hashing hook for TracebackImpl, declared here because TracebackImpl is
// the key type of the absl::node_hash_map member of TracebackManager.
template <typename H>
H AbslHashValue(H h, const TracebackImpl& tb);
// Registry of Python tracebacks that stores each distinct traceback once.
//
// Deduplication is probably not a time saving, but many copies of the same
// traceback are expected, so identical tracebacks share one reference-counted
// entry in an attempt to save memory.
class TracebackManager {
 public:
  // Returns the manager instance.
  static TracebackManager* Get();

  ~TracebackManager();

  // Neither copyable nor movable.
  TracebackManager(const TracebackManager&) = delete;
  TracebackManager(TracebackManager&&) = delete;
  TracebackManager& operator=(const TracebackManager&) = delete;
  TracebackManager& operator=(TracebackManager&&) = delete;

  // One entry of a Python stack trace.
  struct Frame {
    pybind11::str file_name;
    pybind11::str function_name;
    unsigned int line_num;
    int function_start_line;
  };

  // RAII class that holds a reference to a traceback stored in the manager.
  class Traceback {
   public:
    // A default-constructed Traceback refers to no traceback.
    Traceback() = default;
    ~Traceback();

    Traceback(const Traceback&);
    Traceback(Traceback&&);
    Traceback& operator=(const Traceback&);
    Traceback& operator=(Traceback&&);

    // Returns a human-readable rendering. Requires the GIL be held.
    std::string ToString() const;

    // Returns the stack frame objects, in order from innermost to outermost.
    const std::vector<Frame>* Frames() const;

   private:
    friend class TracebackManager;

    // Used by TracebackManager to wrap an entry of tracebacks_.
    Traceback(TracebackManager* manager,
              std::pair<TracebackImpl const, int>* impl);

    // nullptr for a default-constructed Traceback, non-null otherwise.
    TracebackManager* manager_ = nullptr;

    // Points to an entry in tracebacks_. Not owned.
    std::pair<TracebackImpl const, int>* impl_ = nullptr;
  };

  // Returns a Traceback for the current thread. Returns nullopt if tracebacks
  // aren't enabled.
  absl::optional<Traceback> GetTraceback();

  // Enables or disables traceback collection.
  void SetEnabled(bool enabled);

  // Reports whether traceback collection is currently enabled.
  bool enabled() const { return enabled_; }

 private:
  // Private: instances are obtained through Get().
  TracebackManager();

  // Collection is off unless explicitly enabled.
  bool enabled_ = false;

  // Deduplicated tracebacks. Map from traceback to reference count.
  // The map and its contents are protected by the GIL, which is why we do not
  // need an atomic integer for the reference count.
  absl::node_hash_map<TracebackImpl, int> tracebacks_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_TRACEBACK_MANAGER_H_

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl_bind.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@ -48,6 +49,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/traceback_manager.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
@ -598,13 +600,16 @@ PYBIND11_MODULE(xla_extension, m) {
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
GlobalPyRefManager()->ManageReference(std::move(c->array));
auto traceback = TracebackManager::Get()->GetTraceback();
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtBuffer> buffer,
PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy,
std::move(py_buffer_ref), client.get(),
device));
return std::make_unique<PyBuffer>(std::move(client), std::move(buffer));
return std::make_unique<PyBuffer>(std::move(client), std::move(buffer),
traceback);
},
py::arg("argument"), py::arg("device") = nullptr,
py::arg("force_copy") = false);
@ -612,12 +617,13 @@ PYBIND11_MODULE(xla_extension, m) {
"compile",
[](std::shared_ptr<PjRtClient> client, const XlaComputation& computation,
CompileOptions options) -> StatusOr<std::unique_ptr<PyExecutable>> {
auto traceback = TracebackManager::Get()->GetTraceback();
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
PjRtExecutable::Compile(computation, client.get(),
std::move(options)));
return std::make_unique<PyExecutable>(std::move(client),
std::move(executable));
return std::make_unique<PyExecutable>(
std::move(client), std::move(executable), std::move(traceback));
},
py::arg("computation"), py::arg("compile_options") = CompileOptions());
@ -628,6 +634,40 @@ PYBIND11_MODULE(xla_extension, m) {
py::arg("allocator_config") = GpuAllocatorConfig(),
py::arg("distributed_client") = nullptr, py::arg("node_id") = 0);
// Python class `Frame`: read-only view of one TracebackManager::Frame.
// Exposes file_name, function_name, function_start_line and line_num.
py::class_<TracebackManager::Frame>(m, "Frame")
.def_readonly("file_name", &TracebackManager::Frame::file_name)
.def_readonly("function_name", &TracebackManager::Frame::function_name)
.def_readonly("function_start_line",
&TracebackManager::Frame::function_start_line)
.def_readonly("line_num", &TracebackManager::Frame::line_num)
// repr is formatted as "<function_name>;<file_name>:<line_num>".
.def("__repr__", [](const TracebackManager::Frame& frame) {
return absl::StrFormat("%s;%s:%d", frame.function_name, frame.file_name,
frame.line_num);
});
// Python class `FrameVector`: binding for std::vector<Frame>, the type that
// Traceback::Frames() points to.
py::bind_vector<std::vector<TracebackManager::Frame>>(m, "FrameVector");
// Python class `Traceback`: handle to a traceback held by TracebackManager.
py::class_<TracebackManager::Traceback> traceback(
m, "Traceback", "Represents a Python stack trace.");
// Class-level `enabled` property; reads and writes the TracebackManager's
// enabled flag.
traceback.def_property_static(
"enabled",
[](py::object /* cls */) { return TracebackManager::Get()->enabled(); },
[](py::object /* cls */, bool enabled) {
return TracebackManager::Get()->SetEnabled(enabled);
});
// Static `get_traceback()`; forwards to TracebackManager::GetTraceback(),
// which returns an absl::optional<Traceback>.
traceback.def_static(
"get_traceback", []() { return TracebackManager::Get()->GetTraceback(); },
R"doc(
Returns a :class:`Traceback` for the current thread.
If ``Traceback.enabled`` is ``True``, returns a :class:`Traceback` object
that describes the Python stack of the calling thread. Stack trace
collection has a small overhead, so it is disabled by default. If traceback
collection is disabled, returns ``None``.
)doc");
// `frames` exposes Traceback::Frames(); `str(tb)` uses Traceback::ToString().
traceback.def_property_readonly("frames",
&TracebackManager::Traceback::Frames);
traceback.def("__str__", &TracebackManager::Traceback::ToString);
py::class_<PyBuffer, std::unique_ptr<PyBuffer>> buffer(m, "Buffer");
// TODO(phawkins): alias for backward compatibility. Remove after JAX no
// longer uses this name.
@ -668,7 +708,8 @@ PYBIND11_MODULE(xla_extension, m) {
.def("is_deleted", &PyBuffer::is_deleted)
.def("unsafe_buffer_pointer", &PyBuffer::UnsafeBufferPointer)
.def_property_readonly("__cuda_array_interface__",
&PyBuffer::CudaArrayInterface);
&PyBuffer::CudaArrayInterface)
.def_property_readonly("traceback", &PyBuffer::traceback);
// pybind11's implementation of the buffer protocol doesn't allow for correct
// error handling. We bypass it and implement the buffer protocol ourselves.
@ -686,7 +727,8 @@ PYBIND11_MODULE(xla_extension, m) {
.def("execute", &PyExecutable::Execute, py::arg("arguments"))
.def("execute_on_local_devices", &PyExecutable::ExecuteOnLocalDevices,
py::arg("arguments"))
.def("hlo_modules", &PyExecutable::HloModules);
.def("hlo_modules", &PyExecutable::HloModules)
.def_property_readonly("traceback", &PyExecutable::traceback);
py::class_<DebugOptions>(m, "DebugOptions")
.def("__repr__", &DebugOptions::DebugString)
@ -927,6 +969,8 @@ PYBIND11_MODULE(xla_extension, m) {
m.def("get_distributed_runtime_service", &GetDistributedRuntimeService);
m.def("get_distributed_runtime_client", &GetDistributedRuntimeClient);
m.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); });
} // NOLINT(readability/fn_size)
} // namespace xla

View File

@ -19,7 +19,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import atexit
import collections
import contextlib
import enum # pylint: disable=g-bad-import-order
import inspect
import os
@ -406,6 +408,7 @@ def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
XlaBuilder = _xla.XlaBuilder
XlaComputation = _xla.XlaComputation
FftType = _xla.FftType
Client = _xla.LocalClient
Buffer = _xla.Buffer
Executable = _xla.Executable
@ -663,3 +666,22 @@ def make_replica_groups(replica_groups):
_make_replica_group_proto(group) for group in replica_groups
]
return replica_groups_protos
Traceback = _xla.Traceback
@contextlib.contextmanager
def tracebacks(enabled=True):
  """Temporarily sets `Traceback.enabled` for the duration of a `with` block.

  Args:
    enabled: value assigned to `Traceback.enabled` while the block runs.

  Yields:
    Nothing. The value `Traceback.enabled` had on entry is restored on exit,
    including when the block raises.
  """
  previous = Traceback.enabled
  Traceback.enabled = enabled
  try:
    yield
  finally:
    Traceback.enabled = previous
# Perform one last garbage collection of deferred Python references. This is
# mostly to keep ASAN happy.
atexit.register(_xla.collect_garbage)

View File

@ -2038,6 +2038,67 @@ def TestFactory(xla_backend, cloud_tpu=False):
del server
tests.append(ProfilerTest)
class TracebackTest(absltest.TestCase):
  """Tests for traceback collection on buffers and executables."""

  def setUp(self):
    super(TracebackTest, self).setUp()
    self.backend = xla_backend()

  def _BuildAddComputation(self):
    """Builds the small `1 + 2` computation used by the tests below."""
    builder = xla_client.XlaBuilder("computation")
    ops.Add(
        ops.Constant(builder, np.int32(1)), ops.Constant(builder, np.int32(2)))
    return builder.build()

  def testNoTracebacksIfDisabled(self):
    # With collection disabled, every traceback accessor reports None.
    with xla_client.tracebacks(enabled=False):
      self.assertEqual(None, xla_client.Traceback.get_traceback())
      device_buffer = self.backend.buffer_from_pyval(np.array(7, np.int32))
      self.assertEqual(None, device_buffer.traceback)

      executable = self.backend.compile(self._BuildAddComputation())
      self.assertEqual(None, executable.traceback)

  def assertIsTracebackContaining(self, tb, function):
    """Asserts `tb` is a Traceback that mentions a frame named `function`."""
    self.assertIsInstance(tb, xla_client.Traceback)
    self.assertIn(function, str(tb))
    self.assertTrue(any(f.function_name == function for f in tb.frames))

  def testTracebacks(self):
    with xla_client.tracebacks(enabled=True):
      captured = xla_client.Traceback.get_traceback()
      self.assertIsTracebackContaining(captured, "testTracebacks")

      # Tracebacks are not implemented on the TPU driver extension's variant
      # of buffers and executables.
      if not isinstance(self.backend, xla_client.Client):
        return

      device_buffer = self.backend.buffer_from_pyval(np.array(7, np.int32))
      self.assertIsTracebackContaining(device_buffer.traceback,
                                       "testTracebacks")

      executable = self.backend.compile(self._BuildAddComputation())
      self.assertIsTracebackContaining(executable.traceback, "testTracebacks")

  def testNestedFunction(self):

    def AFunction():

      def AnotherFunction():
        return xla_client.Traceback.get_traceback()

      return AnotherFunction()

    with xla_client.tracebacks(enabled=True):
      tb = AFunction()
      self.assertIsInstance(tb, xla_client.Traceback)
      frames = tb.frames
      # Frames run from innermost to outermost, so the callee precedes
      # AFunction and the caller follows it.
      pos = next(
          idx for (idx, frame) in enumerate(frames)
          if frame.function_name == "AFunction")
      self.assertEqual(frames[pos - 1].function_name, "AnotherFunction")
      self.assertEqual(frames[pos + 1].function_name, "testNestedFunction")
tests.append(TracebackTest)
return tests