[XLA:Python] Add support for collecting Python tracebacks.
Adds a new xla_client.Traceback API that can collect Python tracebacks cheaply (~2us). This is several orders of magnitude cheaper than using the Python `inspect.stack` API. Add a facility to attach a traceback to every buffer and executable object describing its creation context. To avoid paying a runtime cost when not debugging, tracebacks collection is optional and disabled by default. PiperOrigin-RevId: 314793876 Change-Id: Ie5f509364065739c1da4d1a8a729c4c6f56e2d03
This commit is contained in:
parent
84c796966b
commit
7df1f4ffc3
@ -123,6 +123,26 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "traceback_manager",
|
||||
srcs = ["traceback_manager.cc"],
|
||||
hdrs = ["traceback_manager.h"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":python_ref_manager",
|
||||
"@com_google_absl//absl/container:node_hash_map",
|
||||
"@com_google_absl//absl/hash",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "bfloat16",
|
||||
srcs = ["bfloat16.cc"],
|
||||
@ -170,10 +190,12 @@ cc_library(
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":python_ref_manager",
|
||||
":traceback_manager",
|
||||
":types",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
@ -188,11 +210,13 @@ cc_library(
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":py_buffer",
|
||||
":traceback_manager",
|
||||
":types",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
@ -208,6 +232,7 @@ cc_library(
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":py_buffer",
|
||||
":traceback_manager",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
@ -333,6 +358,7 @@ pybind_extension(
|
||||
":py_executable",
|
||||
":python_ref_manager",
|
||||
":outfeed_receiver_py",
|
||||
":traceback_manager",
|
||||
":types",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/hash",
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "include/dlpack/dlpack.h" // from @dlpack
|
||||
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
|
||||
#include "tensorflow/compiler/xla/python/traceback_manager.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
|
||||
@ -349,7 +350,8 @@ StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
|
||||
PyCapsule_SetDestructor(tensor.ptr(), nullptr);
|
||||
auto pjrt_buffer = std::make_unique<PjRtBuffer>(
|
||||
shape, shape, std::move(device_buffer), client.get(), device);
|
||||
return std::make_unique<PyBuffer>(std::move(client), std::move(pjrt_buffer));
|
||||
return std::make_unique<PyBuffer>(std::move(client), std::move(pjrt_buffer),
|
||||
TracebackManager::Get()->GetTraceback());
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -22,8 +22,11 @@ namespace xla {
|
||||
namespace py = pybind11;
|
||||
|
||||
PyBuffer::PyBuffer(std::shared_ptr<PjRtClient> client,
|
||||
std::unique_ptr<PjRtBuffer> buffer)
|
||||
: client_(std::move(client)), buffer_(std::move(buffer)) {}
|
||||
std::unique_ptr<PjRtBuffer> buffer,
|
||||
absl::optional<TracebackManager::Traceback> traceback)
|
||||
: client_(std::move(client)),
|
||||
buffer_(std::move(buffer)),
|
||||
traceback_(std::move(traceback)) {}
|
||||
|
||||
ClientAndPtr<Device> PyBuffer::device() const {
|
||||
return WrapWithClient(client_, buffer_->device());
|
||||
@ -33,10 +36,12 @@ StatusOr<std::unique_ptr<PyBuffer>> PyBuffer::CopyToDevice(
|
||||
const ClientAndPtr<Device>& dst_device) const {
|
||||
CHECK(dst_device.get() != nullptr);
|
||||
GlobalPyRefManager()->CollectGarbage();
|
||||
auto traceback = TracebackManager::Get()->GetTraceback();
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> out,
|
||||
buffer_->CopyToDevice(dst_device.get()));
|
||||
return std::make_unique<PyBuffer>(dst_device.client, std::move(out));
|
||||
return std::make_unique<PyBuffer>(dst_device.client, std::move(out),
|
||||
traceback);
|
||||
}
|
||||
|
||||
Status PyBuffer::BlockHostUntilReady() {
|
||||
|
@ -19,7 +19,9 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/python/traceback_manager.h"
|
||||
#include "tensorflow/compiler/xla/python/types.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
@ -32,7 +34,8 @@ namespace xla {
|
||||
class PyBuffer {
|
||||
public:
|
||||
PyBuffer(std::shared_ptr<PjRtClient> client,
|
||||
std::unique_ptr<PjRtBuffer> buffer);
|
||||
std::unique_ptr<PjRtBuffer> buffer,
|
||||
absl::optional<TracebackManager::Traceback> traceback);
|
||||
|
||||
std::shared_ptr<PjRtClient> client() const { return client_; }
|
||||
PjRtBuffer* buffer() const { return buffer_.get(); }
|
||||
@ -60,9 +63,14 @@ class PyBuffer {
|
||||
// PEP 3118 Python buffer protocol implementation.
|
||||
static PyBufferProcs* BufferProtocol();
|
||||
|
||||
const absl::optional<TracebackManager::Traceback>& traceback() {
|
||||
return traceback_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<PjRtClient> client_;
|
||||
std::unique_ptr<PjRtBuffer> buffer_;
|
||||
absl::optional<TracebackManager::Traceback> traceback_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -21,9 +21,13 @@ namespace xla {
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
PyExecutable::PyExecutable(std::shared_ptr<PjRtClient> client,
|
||||
std::unique_ptr<PjRtExecutable> executable)
|
||||
: client_(std::move(client)), executable_(std::move(executable)) {}
|
||||
PyExecutable::PyExecutable(
|
||||
std::shared_ptr<PjRtClient> client,
|
||||
std::unique_ptr<PjRtExecutable> executable,
|
||||
absl::optional<TracebackManager::Traceback> traceback)
|
||||
: client_(std::move(client)),
|
||||
executable_(std::move(executable)),
|
||||
traceback_(std::move(traceback)) {}
|
||||
|
||||
std::vector<ClientAndPtr<Device>> PyExecutable::LocalDevices() const {
|
||||
std::vector<ClientAndPtr<Device>> devices;
|
||||
@ -36,6 +40,7 @@ std::vector<ClientAndPtr<Device>> PyExecutable::LocalDevices() const {
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
|
||||
absl::Span<PyBuffer* const> args) {
|
||||
auto traceback = TracebackManager::Get()->GetTraceback();
|
||||
py::gil_scoped_release gil_release;
|
||||
ExecuteOptions options;
|
||||
options.untuple_result = true;
|
||||
@ -47,7 +52,8 @@ StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
|
||||
std::vector<std::unique_ptr<PyBuffer>> outputs;
|
||||
outputs.reserve(output_buffers.size());
|
||||
for (auto& buffer : output_buffers) {
|
||||
outputs.push_back(std::make_unique<PyBuffer>(client_, std::move(buffer)));
|
||||
outputs.push_back(
|
||||
std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
@ -55,6 +61,7 @@ StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
|
||||
StatusOr<std::vector<std::vector<std::unique_ptr<PyBuffer>>>>
|
||||
PyExecutable::ExecuteOnLocalDevices(
|
||||
absl::Span<const std::vector<PyBuffer*>> args) {
|
||||
auto traceback = TracebackManager::Get()->GetTraceback();
|
||||
py::gil_scoped_release gil_release;
|
||||
ExecuteOptions options;
|
||||
options.untuple_result = true;
|
||||
@ -73,7 +80,7 @@ PyExecutable::ExecuteOnLocalDevices(
|
||||
++computation) {
|
||||
for (auto& buffer : output_buffers[computation]) {
|
||||
outputs[computation].push_back(
|
||||
std::make_unique<PyBuffer>(client_, std::move(buffer)));
|
||||
std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
|
||||
}
|
||||
}
|
||||
return outputs;
|
||||
|
@ -20,9 +20,11 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/python/py_buffer.h"
|
||||
#include "tensorflow/compiler/xla/python/traceback_manager.h"
|
||||
#include "tensorflow/compiler/xla/python/types.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
@ -35,7 +37,8 @@ namespace xla {
|
||||
class PyExecutable {
|
||||
public:
|
||||
PyExecutable(std::shared_ptr<PjRtClient> client,
|
||||
std::unique_ptr<PjRtExecutable> executable);
|
||||
std::unique_ptr<PjRtExecutable> executable,
|
||||
absl::optional<TracebackManager::Traceback> traceback);
|
||||
|
||||
std::shared_ptr<PjRtClient> client() const { return client_; }
|
||||
|
||||
@ -59,9 +62,14 @@ class PyExecutable {
|
||||
|
||||
StatusOr<std::vector<std::shared_ptr<HloModule>>> HloModules() const;
|
||||
|
||||
const absl::optional<TracebackManager::Traceback>& traceback() {
|
||||
return traceback_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<PjRtClient> client_;
|
||||
std::unique_ptr<PjRtExecutable> executable_;
|
||||
absl::optional<TracebackManager::Traceback> traceback_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -50,6 +50,13 @@ PythonRefManager::ManageReferences(absl::Span<py::object> objects) {
|
||||
return std::make_shared<ManagedPyObjects>(this, objects);
|
||||
}
|
||||
|
||||
void PythonRefManager::AddGarbage(absl::Span<py::object> garbage) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
for (py::object& o : garbage) {
|
||||
python_garbage_.push_back(std::move(o));
|
||||
}
|
||||
}
|
||||
|
||||
void PythonRefManager::CollectGarbage() {
|
||||
// TODO(phawkins): we should CHECK(PyGILState_Check());
|
||||
std::deque<pybind11::object> garbage;
|
||||
|
@ -66,6 +66,9 @@ class PythonRefManager {
|
||||
std::shared_ptr<ManagedPyObjects> ManageReferences(
|
||||
absl::Span<pybind11::object> objects);
|
||||
|
||||
// Adds garbage objects to the manager.
|
||||
void AddGarbage(absl::Span<pybind11::object> garbage);
|
||||
|
||||
// Releases the contents of python_garbage_. Requires that the GIL is held.
|
||||
// The client calls this method during API entry points where the GIL is held
|
||||
// to free any garbage that has accumulated.
|
||||
|
@ -173,9 +173,13 @@ PYBIND11_MODULE(tpu_client_extension, m) {
|
||||
.def("shape", &PyTpuBuffer::on_host_shape)
|
||||
.def("device", &PyTpuBuffer::device)
|
||||
.def("platform", &PyTpuBuffer::platform_name)
|
||||
.def("is_deleted", [](const PyTpuBuffer& buffer) {
|
||||
return buffer.DeviceBuffer() == nullptr;
|
||||
});
|
||||
.def("is_deleted",
|
||||
[](const PyTpuBuffer& buffer) {
|
||||
return buffer.DeviceBuffer() == nullptr;
|
||||
})
|
||||
// TODO(phawkins): implement traceback support.
|
||||
.def_property_readonly("traceback",
|
||||
[](PyTpuBuffer*) { return py::none(); });
|
||||
|
||||
py::class_<PyTpuExecutable>(m, "TpuExecutable")
|
||||
.def("local_logical_device_ids",
|
||||
@ -193,7 +197,10 @@ PYBIND11_MODULE(tpu_client_extension, m) {
|
||||
.def("execute", &PyTpuExecutable::Execute,
|
||||
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
|
||||
.def("execute_on_local_devices", &PyTpuExecutable::ExecuteOnLocalDevices,
|
||||
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"));
|
||||
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
|
||||
// TODO(phawkins): implement traceback support.
|
||||
.def_property_readonly("traceback",
|
||||
[](PyTpuExecutable*) { return py::none(); });
|
||||
|
||||
py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
|
||||
.def_property_readonly("coords", &TpuDevice::coords)
|
||||
|
196
tensorflow/compiler/xla/python/traceback_manager.cc
Normal file
196
tensorflow/compiler/xla/python/traceback_manager.cc
Normal file
@ -0,0 +1,196 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/python/traceback_manager.h"
|
||||
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
} // namespace
|
||||
|
||||
struct TracebackImpl {
|
||||
~TracebackImpl();
|
||||
|
||||
std::vector<TracebackManager::Frame> frames;
|
||||
|
||||
// Computes a traceback for the current Python thread. Requires the GIL.
|
||||
bool GetTracebackForCurrentThread();
|
||||
bool operator==(const TracebackImpl& other) const;
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
// We want Traceback objects to be safe to destroy without holding the GIL, so
|
||||
// we defer destruction of the strings.
|
||||
TracebackImpl::~TracebackImpl() {
|
||||
std::vector<py::object> objects;
|
||||
objects.reserve(2 * frames.size());
|
||||
for (TracebackManager::Frame& frame : frames) {
|
||||
objects.push_back(std::move(frame.file_name));
|
||||
objects.push_back(std::move(frame.function_name));
|
||||
}
|
||||
GlobalPyRefManager()->AddGarbage(absl::MakeSpan(objects));
|
||||
}
|
||||
|
||||
bool TracebackImpl::operator==(const TracebackImpl& other) const {
|
||||
if (frames.size() != other.frames.size()) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < frames.size(); ++i) {
|
||||
// Python strings are compared using pointer equality. This is cheap and
|
||||
// does not require calling back into the Python interpreter, but may mean
|
||||
// we miss some opportunities for deduplication of TracebackImpl objects.
|
||||
// However, we expect that function and file names are drawn from a fixed
|
||||
// pool of constants.
|
||||
if (frames[i].file_name.ptr() != other.frames[i].file_name.ptr() ||
|
||||
frames[i].function_name.ptr() != other.frames[i].function_name.ptr() ||
|
||||
frames[i].line_num != other.frames[i].line_num ||
|
||||
frames[i].function_start_line != other.frames[i].function_start_line) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename H>
|
||||
H AbslHashValue(H h, const TracebackImpl& tb) {
|
||||
for (const TracebackManager::Frame& frame : tb.frames) {
|
||||
h = H::combine(std::move(h), frame.file_name.ptr(),
|
||||
frame.function_name.ptr(), frame.line_num);
|
||||
}
|
||||
return h;
|
||||
}
|
||||
bool TracebackImpl::GetTracebackForCurrentThread() {
|
||||
PyThreadState* thread_state = PyGILState_GetThisThreadState();
|
||||
if (!thread_state) {
|
||||
return false;
|
||||
}
|
||||
frames.reserve(32);
|
||||
for (PyFrameObject* py_frame = thread_state->frame; py_frame != nullptr;
|
||||
py_frame = py_frame->f_back) {
|
||||
frames.resize(frames.size() + 1);
|
||||
TracebackManager::Frame& frame = frames.back();
|
||||
PyCodeObject* code = py_frame->f_code;
|
||||
if (!code) {
|
||||
return false;
|
||||
}
|
||||
frame.line_num = PyFrame_GetLineNumber(py_frame);
|
||||
frame.file_name = py::str(code->co_filename);
|
||||
frame.function_name = py::str(code->co_name);
|
||||
frame.function_start_line = code->co_firstlineno;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string TracebackImpl::ToString() const {
|
||||
std::vector<std::string> frame_strs;
|
||||
frame_strs.reserve(frames.size());
|
||||
for (const TracebackManager::Frame& frame : frames) {
|
||||
frame_strs.push_back(absl::StrFormat("%s:%d (%s)", frame.file_name,
|
||||
frame.line_num, frame.function_name));
|
||||
}
|
||||
return absl::StrJoin(frame_strs, "\n");
|
||||
}
|
||||
|
||||
TracebackManager::Traceback::Traceback(
|
||||
TracebackManager* manager, std::pair<TracebackImpl const, int>* impl)
|
||||
: manager_(manager), impl_(impl) {
|
||||
DCHECK(manager_);
|
||||
++impl->second;
|
||||
}
|
||||
|
||||
TracebackManager::Traceback::~Traceback() {
|
||||
if (manager_) {
|
||||
--impl_->second;
|
||||
if (impl_->second == 0) {
|
||||
manager_->tracebacks_.erase(impl_->first);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TracebackManager::Traceback::Traceback(const Traceback& other)
|
||||
: manager_(other.manager_), impl_(other.impl_) {
|
||||
if (manager_) {
|
||||
++impl_->second;
|
||||
}
|
||||
}
|
||||
|
||||
TracebackManager::Traceback::Traceback(Traceback&& other)
|
||||
: manager_(other.manager_), impl_(other.impl_) {
|
||||
other.manager_ = nullptr;
|
||||
other.impl_ = nullptr;
|
||||
}
|
||||
|
||||
TracebackManager::Traceback& TracebackManager::Traceback::operator=(
|
||||
const TracebackManager::Traceback& other) {
|
||||
manager_ = other.manager_;
|
||||
impl_ = other.impl_;
|
||||
if (manager_) {
|
||||
++impl_->second;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
TracebackManager::Traceback& TracebackManager::Traceback::operator=(
|
||||
TracebackManager::Traceback&& other) {
|
||||
std::swap(manager_, other.manager_);
|
||||
std::swap(impl_, other.impl_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::string TracebackManager::Traceback::ToString() const {
|
||||
// We require the GIL because we manipulate Python strings.
|
||||
CHECK(PyGILState_Check());
|
||||
if (!manager_) {
|
||||
// Don't crash if called on a default-constructed Traceback.
|
||||
return "<unknown>";
|
||||
}
|
||||
return impl_->first.ToString();
|
||||
}
|
||||
|
||||
const std::vector<TracebackManager::Frame>*
|
||||
TracebackManager::Traceback::Frames() const {
|
||||
return &impl_->first.frames;
|
||||
}
|
||||
|
||||
/*static*/ TracebackManager* TracebackManager::Get() {
|
||||
static TracebackManager* manager = new TracebackManager;
|
||||
return manager;
|
||||
}
|
||||
|
||||
TracebackManager::TracebackManager() = default;
|
||||
TracebackManager::~TracebackManager() = default;
|
||||
|
||||
absl::optional<TracebackManager::Traceback> TracebackManager::GetTraceback() {
|
||||
if (!enabled_) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
CHECK(PyGILState_Check());
|
||||
TracebackImpl impl;
|
||||
if (!impl.GetTracebackForCurrentThread()) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
auto it = tracebacks_.emplace(impl, 0);
|
||||
return Traceback(this, &*it.first);
|
||||
}
|
||||
|
||||
void TracebackManager::SetEnabled(bool enabled) { enabled_ = enabled; }
|
||||
|
||||
} // namespace xla
|
102
tensorflow/compiler/xla/python/traceback_manager.h
Normal file
102
tensorflow/compiler/xla/python/traceback_manager.h
Normal file
@ -0,0 +1,102 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TRACEBACK_MANAGER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_TRACEBACK_MANAGER_H_
|
||||
|
||||
#include "absl/container/node_hash_map.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
struct TracebackImpl;
|
||||
|
||||
template <typename H>
|
||||
H AbslHashValue(H h, const TracebackImpl& tb);
|
||||
|
||||
// Traceback manager class that deduplicates traceback objects to save memory.
|
||||
// It probably does not save time to deduplicate tracebacks, but we expect to
|
||||
// see many copies of the same tracebacks and hence we deduplicate in an attempt
|
||||
// to save memory.
|
||||
class TracebackManager {
|
||||
public:
|
||||
static TracebackManager* Get();
|
||||
|
||||
~TracebackManager();
|
||||
|
||||
TracebackManager(const TracebackManager&) = delete;
|
||||
TracebackManager(TracebackManager&&) = delete;
|
||||
TracebackManager& operator=(const TracebackManager&) = delete;
|
||||
TracebackManager& operator=(TracebackManager&&) = delete;
|
||||
|
||||
struct Frame {
|
||||
pybind11::str file_name;
|
||||
pybind11::str function_name;
|
||||
unsigned int line_num;
|
||||
int function_start_line;
|
||||
};
|
||||
|
||||
// RAII class that holds a reference to a traceback.
|
||||
class Traceback {
|
||||
public:
|
||||
Traceback() = default;
|
||||
~Traceback();
|
||||
|
||||
Traceback(const Traceback&);
|
||||
Traceback(Traceback&&);
|
||||
Traceback& operator=(const Traceback&);
|
||||
Traceback& operator=(Traceback&&);
|
||||
|
||||
// Requires the GIL be held.
|
||||
std::string ToString() const;
|
||||
|
||||
// Returns the stack frame objects, in order from innermost to outermost.
|
||||
const std::vector<Frame>* Frames() const;
|
||||
|
||||
private:
|
||||
friend class TracebackManager;
|
||||
|
||||
Traceback(TracebackManager* manager,
|
||||
std::pair<TracebackImpl const, int>* impl);
|
||||
|
||||
// nullptr for a default-constructed Traceback, non-null otherwise.
|
||||
TracebackManager* manager_ = nullptr;
|
||||
// Points to an entry in tracebacks_. Not owned.
|
||||
std::pair<TracebackImpl const, int>* impl_ = nullptr;
|
||||
};
|
||||
|
||||
// Returns a Traceback for the current thread. Returns nullopt if tracebacks
|
||||
// aren't enabled,
|
||||
absl::optional<Traceback> GetTraceback();
|
||||
|
||||
// Enables or disables traceback collection.
|
||||
void SetEnabled(bool enabled);
|
||||
bool enabled() const { return enabled_; }
|
||||
|
||||
private:
|
||||
TracebackManager();
|
||||
|
||||
bool enabled_ = false;
|
||||
|
||||
// Deduplicated tracebacks. Map from traceback to reference count.
|
||||
// The map and its contents are protected by the GIL, which is why we do not
|
||||
// need an atomic integer for the reference count.
|
||||
absl::node_hash_map<TracebackImpl, int> tracebacks_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_TRACEBACK_MANAGER_H_
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "pybind11/numpy.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
@ -48,6 +49,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/python/py_buffer.h"
|
||||
#include "tensorflow/compiler/xla/python/py_executable.h"
|
||||
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
|
||||
#include "tensorflow/compiler/xla/python/traceback_manager.h"
|
||||
#include "tensorflow/compiler/xla/python/types.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
||||
@ -598,13 +600,16 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
|
||||
GlobalPyRefManager()->ManageReference(std::move(c->array));
|
||||
|
||||
auto traceback = TracebackManager::Get()->GetTraceback();
|
||||
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<PjRtBuffer> buffer,
|
||||
PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy,
|
||||
std::move(py_buffer_ref), client.get(),
|
||||
device));
|
||||
return std::make_unique<PyBuffer>(std::move(client), std::move(buffer));
|
||||
return std::make_unique<PyBuffer>(std::move(client), std::move(buffer),
|
||||
traceback);
|
||||
},
|
||||
py::arg("argument"), py::arg("device") = nullptr,
|
||||
py::arg("force_copy") = false);
|
||||
@ -612,12 +617,13 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
"compile",
|
||||
[](std::shared_ptr<PjRtClient> client, const XlaComputation& computation,
|
||||
CompileOptions options) -> StatusOr<std::unique_ptr<PyExecutable>> {
|
||||
auto traceback = TracebackManager::Get()->GetTraceback();
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
|
||||
PjRtExecutable::Compile(computation, client.get(),
|
||||
std::move(options)));
|
||||
return std::make_unique<PyExecutable>(std::move(client),
|
||||
std::move(executable));
|
||||
return std::make_unique<PyExecutable>(
|
||||
std::move(client), std::move(executable), std::move(traceback));
|
||||
},
|
||||
py::arg("computation"), py::arg("compile_options") = CompileOptions());
|
||||
|
||||
@ -628,6 +634,40 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
py::arg("allocator_config") = GpuAllocatorConfig(),
|
||||
py::arg("distributed_client") = nullptr, py::arg("node_id") = 0);
|
||||
|
||||
py::class_<TracebackManager::Frame>(m, "Frame")
|
||||
.def_readonly("file_name", &TracebackManager::Frame::file_name)
|
||||
.def_readonly("function_name", &TracebackManager::Frame::function_name)
|
||||
.def_readonly("function_start_line",
|
||||
&TracebackManager::Frame::function_start_line)
|
||||
.def_readonly("line_num", &TracebackManager::Frame::line_num)
|
||||
.def("__repr__", [](const TracebackManager::Frame& frame) {
|
||||
return absl::StrFormat("%s;%s:%d", frame.function_name, frame.file_name,
|
||||
frame.line_num);
|
||||
});
|
||||
py::bind_vector<std::vector<TracebackManager::Frame>>(m, "FrameVector");
|
||||
|
||||
py::class_<TracebackManager::Traceback> traceback(
|
||||
m, "Traceback", "Represents a Python stack trace.");
|
||||
traceback.def_property_static(
|
||||
"enabled",
|
||||
[](py::object /* cls */) { return TracebackManager::Get()->enabled(); },
|
||||
[](py::object /* cls */, bool enabled) {
|
||||
return TracebackManager::Get()->SetEnabled(enabled);
|
||||
});
|
||||
traceback.def_static(
|
||||
"get_traceback", []() { return TracebackManager::Get()->GetTraceback(); },
|
||||
R"doc(
|
||||
Returns a :class:`Traceback` for the current thread.
|
||||
|
||||
If ``Traceback.enabled`` is ``True``, returns a :class:`Traceback` object
|
||||
that describes the Python stack of the calling thread. Stack trace
|
||||
collection has a small overhead, so it is disabled by default. If traceback
|
||||
collection is disabled, returns ``None``.
|
||||
)doc");
|
||||
traceback.def_property_readonly("frames",
|
||||
&TracebackManager::Traceback::Frames);
|
||||
traceback.def("__str__", &TracebackManager::Traceback::ToString);
|
||||
|
||||
py::class_<PyBuffer, std::unique_ptr<PyBuffer>> buffer(m, "Buffer");
|
||||
// TODO(phawkins): alias for backward compatibility. Remove after JAX no
|
||||
// longer uses this name.
|
||||
@ -668,7 +708,8 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
.def("is_deleted", &PyBuffer::is_deleted)
|
||||
.def("unsafe_buffer_pointer", &PyBuffer::UnsafeBufferPointer)
|
||||
.def_property_readonly("__cuda_array_interface__",
|
||||
&PyBuffer::CudaArrayInterface);
|
||||
&PyBuffer::CudaArrayInterface)
|
||||
.def_property_readonly("traceback", &PyBuffer::traceback);
|
||||
|
||||
// pybind11's implementation of the buffer protocol doesn't allow for correct
|
||||
// error handling. We bypass it and implement the buffer protocol ourselves.
|
||||
@ -686,7 +727,8 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
.def("execute", &PyExecutable::Execute, py::arg("arguments"))
|
||||
.def("execute_on_local_devices", &PyExecutable::ExecuteOnLocalDevices,
|
||||
py::arg("arguments"))
|
||||
.def("hlo_modules", &PyExecutable::HloModules);
|
||||
.def("hlo_modules", &PyExecutable::HloModules)
|
||||
.def_property_readonly("traceback", &PyExecutable::traceback);
|
||||
|
||||
py::class_<DebugOptions>(m, "DebugOptions")
|
||||
.def("__repr__", &DebugOptions::DebugString)
|
||||
@ -927,6 +969,8 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
|
||||
m.def("get_distributed_runtime_service", &GetDistributedRuntimeService);
|
||||
m.def("get_distributed_runtime_client", &GetDistributedRuntimeClient);
|
||||
|
||||
m.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); });
|
||||
} // NOLINT(readability/fn_size)
|
||||
|
||||
} // namespace xla
|
||||
|
@ -19,7 +19,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import atexit
|
||||
import collections
|
||||
import contextlib
|
||||
import enum # pylint: disable=g-bad-import-order
|
||||
import inspect
|
||||
import os
|
||||
@ -406,6 +408,7 @@ def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
|
||||
XlaBuilder = _xla.XlaBuilder
|
||||
XlaComputation = _xla.XlaComputation
|
||||
FftType = _xla.FftType
|
||||
Client = _xla.LocalClient
|
||||
Buffer = _xla.Buffer
|
||||
Executable = _xla.Executable
|
||||
|
||||
@ -663,3 +666,22 @@ def make_replica_groups(replica_groups):
|
||||
_make_replica_group_proto(group) for group in replica_groups
|
||||
]
|
||||
return replica_groups_protos
|
||||
|
||||
|
||||
Traceback = _xla.Traceback
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def tracebacks(enabled=True):
|
||||
"""Context manager that enables or disables traceback collection."""
|
||||
saved = Traceback.enabled
|
||||
Traceback.enabled = enabled
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
Traceback.enabled = saved
|
||||
|
||||
|
||||
# Perform one last garbage collection of deferred Python references. This is
|
||||
# mostly to keep ASAN happy.
|
||||
atexit.register(_xla.collect_garbage)
|
||||
|
@ -2038,6 +2038,67 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
del server
|
||||
|
||||
tests.append(ProfilerTest)
|
||||
|
||||
class TracebackTest(absltest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TracebackTest, self).setUp()
|
||||
self.backend = xla_backend()
|
||||
|
||||
def testNoTracebacksIfDisabled(self):
|
||||
with xla_client.tracebacks(enabled=False):
|
||||
self.assertEqual(None, xla_client.Traceback.get_traceback())
|
||||
buffer = self.backend.buffer_from_pyval(np.array(7, np.int32))
|
||||
self.assertEqual(None, buffer.traceback)
|
||||
|
||||
b = xla_client.XlaBuilder("computation")
|
||||
ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
|
||||
e = self.backend.compile(b.build())
|
||||
self.assertEqual(None, e.traceback)
|
||||
|
||||
def assertIsTracebackContaining(self, tb, function):
|
||||
self.assertIsInstance(tb, xla_client.Traceback)
|
||||
self.assertIn(function, str(tb))
|
||||
self.assertTrue(any(f.function_name == function for f in tb.frames))
|
||||
|
||||
def testTracebacks(self):
|
||||
with xla_client.tracebacks(enabled=True):
|
||||
tb = xla_client.Traceback.get_traceback()
|
||||
self.assertIsTracebackContaining(tb, "testTracebacks")
|
||||
|
||||
# Tracebacks are not implemented on the TPU driver extension's variant
|
||||
# of buffers and executables.
|
||||
if not isinstance(self.backend, xla_client.Client):
|
||||
return
|
||||
|
||||
buffer = self.backend.buffer_from_pyval(np.array(7, np.int32))
|
||||
self.assertIsTracebackContaining(buffer.traceback, "testTracebacks")
|
||||
|
||||
b = xla_client.XlaBuilder("computation")
|
||||
ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
|
||||
e = self.backend.compile(b.build())
|
||||
self.assertIsTracebackContaining(e.traceback, "testTracebacks")
|
||||
|
||||
def testNestedFunction(self):
|
||||
|
||||
def AFunction():
|
||||
|
||||
def AnotherFunction():
|
||||
return xla_client.Traceback.get_traceback()
|
||||
|
||||
return AnotherFunction()
|
||||
|
||||
with xla_client.tracebacks(enabled=True):
|
||||
tb = AFunction()
|
||||
self.assertIsInstance(tb, xla_client.Traceback)
|
||||
frames = tb.frames
|
||||
i = next(
|
||||
i for (i, f) in enumerate(frames) if f.function_name == "AFunction")
|
||||
self.assertEqual(frames[i - 1].function_name, "AnotherFunction")
|
||||
self.assertEqual(frames[i + 1].function_name, "testNestedFunction")
|
||||
|
||||
tests.append(TracebackTest)
|
||||
|
||||
return tests
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user