From 7df1f4ffc371ef5730035375761b60a0d1006a1f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 4 Jun 2020 13:34:34 -0700 Subject: [PATCH] [XLA:Python] Add support for collecting Python tracebacks. Adds a new xla_client.Traceback API that can collect Python tracebacks cheaply (~2us). This is several orders of magnitude cheaper than using the Python `inspect.stack` API. Add a facility to attach a traceback to every buffer and executable object describing its creation context. To avoid paying a runtime cost when not debugging, tracebacks collection is optional and disabled by default. PiperOrigin-RevId: 314793876 Change-Id: Ie5f509364065739c1da4d1a8a729c4c6f56e2d03 --- tensorflow/compiler/xla/python/BUILD | 26 +++ tensorflow/compiler/xla/python/dlpack.cc | 4 +- tensorflow/compiler/xla/python/py_buffer.cc | 11 +- tensorflow/compiler/xla/python/py_buffer.h | 10 +- .../compiler/xla/python/py_executable.cc | 17 +- .../compiler/xla/python/py_executable.h | 10 +- .../compiler/xla/python/python_ref_manager.cc | 7 + .../compiler/xla/python/python_ref_manager.h | 3 + .../tpu_driver/client/tpu_client_extension.cc | 15 +- .../compiler/xla/python/traceback_manager.cc | 196 ++++++++++++++++++ .../compiler/xla/python/traceback_manager.h | 102 +++++++++ tensorflow/compiler/xla/python/xla.cc | 54 ++++- tensorflow/compiler/xla/python/xla_client.py | 22 ++ .../compiler/xla/python/xla_client_test.py | 61 ++++++ 14 files changed, 518 insertions(+), 20 deletions(-) create mode 100644 tensorflow/compiler/xla/python/traceback_manager.cc create mode 100644 tensorflow/compiler/xla/python/traceback_manager.h diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index a195d0eaaea..0fc96cad121 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -123,6 +123,26 @@ cc_library( ], ) +cc_library( + name = "traceback_manager", + srcs = ["traceback_manager.cc"], + hdrs = ["traceback_manager.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":python_ref_manager", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", + "@pybind11", + ], +) + cc_library( name = "bfloat16", srcs = ["bfloat16.cc"], @@ -170,10 +190,12 @@ cc_library( features = ["-use_header_modules"], deps = [ ":python_ref_manager", + ":traceback_manager", ":types", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/pjrt:pjrt_client", + "@com_google_absl//absl/types:optional", ], ) @@ -188,11 +210,13 @@ cc_library( features = ["-use_header_modules"], deps = [ ":py_buffer", + ":traceback_manager", ":types", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/pjrt:pjrt_client", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) @@ -208,6 +232,7 @@ cc_library( features = ["-use_header_modules"], deps = [ ":py_buffer", + ":traceback_manager", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/pjrt:pjrt_client", @@ -333,6 +358,7 @@ pybind_extension( ":py_executable", ":python_ref_manager", ":outfeed_receiver_py", + ":traceback_manager", ":types", "@com_google_absl//absl/base", "@com_google_absl//absl/hash", diff --git 
a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc index 044710dcd3e..372935b9893 100644 --- a/tensorflow/compiler/xla/python/dlpack.cc +++ b/tensorflow/compiler/xla/python/dlpack.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/types/span.h" #include "include/dlpack/dlpack.h" // from @dlpack #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" +#include "tensorflow/compiler/xla/python/traceback_manager.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/stream_executor/cuda/cuda_platform_id.h" @@ -349,7 +350,8 @@ StatusOr> DLPackManagedTensorToBuffer( PyCapsule_SetDestructor(tensor.ptr(), nullptr); auto pjrt_buffer = std::make_unique( shape, shape, std::move(device_buffer), client.get(), device); - return std::make_unique(std::move(client), std::move(pjrt_buffer)); + return std::make_unique(std::move(client), std::move(pjrt_buffer), + TracebackManager::Get()->GetTraceback()); } } // namespace xla diff --git a/tensorflow/compiler/xla/python/py_buffer.cc b/tensorflow/compiler/xla/python/py_buffer.cc index 1bb20d52b90..09e2fd4ff87 100644 --- a/tensorflow/compiler/xla/python/py_buffer.cc +++ b/tensorflow/compiler/xla/python/py_buffer.cc @@ -22,8 +22,11 @@ namespace xla { namespace py = pybind11; PyBuffer::PyBuffer(std::shared_ptr client, - std::unique_ptr buffer) - : client_(std::move(client)), buffer_(std::move(buffer)) {} + std::unique_ptr buffer, + absl::optional traceback) + : client_(std::move(client)), + buffer_(std::move(buffer)), + traceback_(std::move(traceback)) {} ClientAndPtr PyBuffer::device() const { return WrapWithClient(client_, buffer_->device()); @@ -33,10 +36,12 @@ StatusOr> PyBuffer::CopyToDevice( const ClientAndPtr& dst_device) const { CHECK(dst_device.get() != nullptr); GlobalPyRefManager()->CollectGarbage(); + auto traceback = TracebackManager::Get()->GetTraceback(); py::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN(std::unique_ptr out, buffer_->CopyToDevice(dst_device.get())); - return std::make_unique(dst_device.client, std::move(out)); + return std::make_unique(dst_device.client, std::move(out), + traceback); } Status PyBuffer::BlockHostUntilReady() { diff --git a/tensorflow/compiler/xla/python/py_buffer.h b/tensorflow/compiler/xla/python/py_buffer.h index 56767b3888f..1803c0b3e77 100644 --- a/tensorflow/compiler/xla/python/py_buffer.h +++ b/tensorflow/compiler/xla/python/py_buffer.h @@ -19,7 +19,9 @@ limitations under the License. #include #include +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/python/traceback_manager.h" #include "tensorflow/compiler/xla/python/types.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -32,7 +34,8 @@ namespace xla { class PyBuffer { public: PyBuffer(std::shared_ptr client, - std::unique_ptr buffer); + std::unique_ptr buffer, + absl::optional traceback); std::shared_ptr client() const { return client_; } PjRtBuffer* buffer() const { return buffer_.get(); } @@ -60,9 +63,14 @@ class PyBuffer { // PEP 3118 Python buffer protocol implementation. 
static PyBufferProcs* BufferProtocol(); + const absl::optional& traceback() { + return traceback_; + } + private: std::shared_ptr client_; std::unique_ptr buffer_; + absl::optional traceback_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/python/py_executable.cc b/tensorflow/compiler/xla/python/py_executable.cc index bf28970bd92..9c09714cdbe 100644 --- a/tensorflow/compiler/xla/python/py_executable.cc +++ b/tensorflow/compiler/xla/python/py_executable.cc @@ -21,9 +21,13 @@ namespace xla { namespace py = pybind11; -PyExecutable::PyExecutable(std::shared_ptr client, - std::unique_ptr executable) - : client_(std::move(client)), executable_(std::move(executable)) {} +PyExecutable::PyExecutable( + std::shared_ptr client, + std::unique_ptr executable, + absl::optional traceback) + : client_(std::move(client)), + executable_(std::move(executable)), + traceback_(std::move(traceback)) {} std::vector> PyExecutable::LocalDevices() const { std::vector> devices; @@ -36,6 +40,7 @@ std::vector> PyExecutable::LocalDevices() const { StatusOr>> PyExecutable::Execute( absl::Span args) { + auto traceback = TracebackManager::Get()->GetTraceback(); py::gil_scoped_release gil_release; ExecuteOptions options; options.untuple_result = true; @@ -47,7 +52,8 @@ StatusOr>> PyExecutable::Execute( std::vector> outputs; outputs.reserve(output_buffers.size()); for (auto& buffer : output_buffers) { - outputs.push_back(std::make_unique(client_, std::move(buffer))); + outputs.push_back( + std::make_unique(client_, std::move(buffer), traceback)); } return outputs; } @@ -55,6 +61,7 @@ StatusOr>> PyExecutable::Execute( StatusOr>>> PyExecutable::ExecuteOnLocalDevices( absl::Span> args) { + auto traceback = TracebackManager::Get()->GetTraceback(); py::gil_scoped_release gil_release; ExecuteOptions options; options.untuple_result = true; @@ -73,7 +80,7 @@ PyExecutable::ExecuteOnLocalDevices( ++computation) { for (auto& buffer : output_buffers[computation]) { outputs[computation].push_back( - std::make_unique(client_, std::move(buffer))); + std::make_unique(client_, std::move(buffer), traceback)); } } return outputs; diff --git a/tensorflow/compiler/xla/python/py_executable.h b/tensorflow/compiler/xla/python/py_executable.h index 6dbc4bf089b..12b5325b8fb 100644 --- a/tensorflow/compiler/xla/python/py_executable.h +++ b/tensorflow/compiler/xla/python/py_executable.h @@ -20,9 +20,11 @@ limitations under the License. 
#include #include +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/python/py_buffer.h" +#include "tensorflow/compiler/xla/python/traceback_manager.h" #include "tensorflow/compiler/xla/python/types.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -35,7 +37,8 @@ namespace xla { class PyExecutable { public: PyExecutable(std::shared_ptr client, - std::unique_ptr executable); + std::unique_ptr executable, + absl::optional traceback); std::shared_ptr client() const { return client_; } @@ -59,9 +62,14 @@ class PyExecutable { StatusOr>> HloModules() const; + const absl::optional& traceback() { + return traceback_; + } + private: std::shared_ptr client_; std::unique_ptr executable_; + absl::optional traceback_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/python/python_ref_manager.cc b/tensorflow/compiler/xla/python/python_ref_manager.cc index cf449801205..a815fa8ee0d 100644 --- a/tensorflow/compiler/xla/python/python_ref_manager.cc +++ b/tensorflow/compiler/xla/python/python_ref_manager.cc @@ -50,6 +50,13 @@ PythonRefManager::ManageReferences(absl::Span objects) { return std::make_shared(this, objects); } +void PythonRefManager::AddGarbage(absl::Span garbage) { + absl::MutexLock lock(&mu_); + for (py::object& o : garbage) { + python_garbage_.push_back(std::move(o)); + } +} + void PythonRefManager::CollectGarbage() { // TODO(phawkins): we should CHECK(PyGILState_Check()); std::deque garbage; diff --git a/tensorflow/compiler/xla/python/python_ref_manager.h b/tensorflow/compiler/xla/python/python_ref_manager.h index 0ad533c695f..d9228118e20 100644 --- a/tensorflow/compiler/xla/python/python_ref_manager.h +++ b/tensorflow/compiler/xla/python/python_ref_manager.h @@ -66,6 +66,9 @@ class PythonRefManager { std::shared_ptr ManageReferences( absl::Span objects); + // Adds garbage objects to the manager. + void AddGarbage(absl::Span garbage); + // Releases the contents of python_garbage_. Requires that the GIL is held. // The client calls this method during API entry points where the GIL is held // to free any garbage that has accumulated. diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc index f44d69656e6..9a794b79c5c 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc @@ -173,9 +173,13 @@ PYBIND11_MODULE(tpu_client_extension, m) { .def("shape", &PyTpuBuffer::on_host_shape) .def("device", &PyTpuBuffer::device) .def("platform", &PyTpuBuffer::platform_name) - .def("is_deleted", [](const PyTpuBuffer& buffer) { - return buffer.DeviceBuffer() == nullptr; - }); + .def("is_deleted", + [](const PyTpuBuffer& buffer) { + return buffer.DeviceBuffer() == nullptr; + }) + // TODO(phawkins): implement traceback support. + .def_property_readonly("traceback", + [](PyTpuBuffer*) { return py::none(); }); py::class_(m, "TpuExecutable") .def("local_logical_device_ids", @@ -193,7 +197,10 @@ PYBIND11_MODULE(tpu_client_extension, m) { .def("execute", &PyTpuExecutable::Execute, py::call_guard(), py::arg("arguments")) .def("execute_on_local_devices", &PyTpuExecutable::ExecuteOnLocalDevices, - py::call_guard(), py::arg("arguments")); + py::call_guard(), py::arg("arguments")) + // TODO(phawkins): implement traceback support. 
+ .def_property_readonly("traceback", + [](PyTpuExecutable*) { return py::none(); }); py::class_>(m, "TpuDevice") .def_property_readonly("coords", &TpuDevice::coords) diff --git a/tensorflow/compiler/xla/python/traceback_manager.cc b/tensorflow/compiler/xla/python/traceback_manager.cc new file mode 100644 index 00000000000..daf5c693e25 --- /dev/null +++ b/tensorflow/compiler/xla/python/traceback_manager.cc @@ -0,0 +1,196 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/traceback_manager.h" + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/python/python_ref_manager.h" + +namespace xla { +namespace { + +namespace py = pybind11; + +} // namespace + +struct TracebackImpl { + ~TracebackImpl(); + + std::vector frames; + + // Computes a traceback for the current Python thread. Requires the GIL. + bool GetTracebackForCurrentThread(); + bool operator==(const TracebackImpl& other) const; + std::string ToString() const; +}; + +// We want Traceback objects to be safe to destroy without holding the GIL, so +// we defer destruction of the strings. +TracebackImpl::~TracebackImpl() { + std::vector objects; + objects.reserve(2 * frames.size()); + for (TracebackManager::Frame& frame : frames) { + objects.push_back(std::move(frame.file_name)); + objects.push_back(std::move(frame.function_name)); + } + GlobalPyRefManager()->AddGarbage(absl::MakeSpan(objects)); +} + +bool TracebackImpl::operator==(const TracebackImpl& other) const { + if (frames.size() != other.frames.size()) { + return false; + } + for (int i = 0; i < frames.size(); ++i) { + // Python strings are compared using pointer equality. This is cheap and + // does not require calling back into the Python interpreter, but may mean + // we miss some opportunities for deduplication of TracebackImpl objects. + // However, we expect that function and file names are drawn from a fixed + // pool of constants. 
+ if (frames[i].file_name.ptr() != other.frames[i].file_name.ptr() || + frames[i].function_name.ptr() != other.frames[i].function_name.ptr() || + frames[i].line_num != other.frames[i].line_num || + frames[i].function_start_line != other.frames[i].function_start_line) { + return false; + } + } + return true; +} + +template +H AbslHashValue(H h, const TracebackImpl& tb) { + for (const TracebackManager::Frame& frame : tb.frames) { + h = H::combine(std::move(h), frame.file_name.ptr(), + frame.function_name.ptr(), frame.line_num); + } + return h; +} +bool TracebackImpl::GetTracebackForCurrentThread() { + PyThreadState* thread_state = PyGILState_GetThisThreadState(); + if (!thread_state) { + return false; + } + frames.reserve(32); + for (PyFrameObject* py_frame = thread_state->frame; py_frame != nullptr; + py_frame = py_frame->f_back) { + frames.resize(frames.size() + 1); + TracebackManager::Frame& frame = frames.back(); + PyCodeObject* code = py_frame->f_code; + if (!code) { + return false; + } + frame.line_num = PyFrame_GetLineNumber(py_frame); + frame.file_name = py::str(code->co_filename); + frame.function_name = py::str(code->co_name); + frame.function_start_line = code->co_firstlineno; + } + return true; +} + +std::string TracebackImpl::ToString() const { + std::vector frame_strs; + frame_strs.reserve(frames.size()); + for (const TracebackManager::Frame& frame : frames) { + frame_strs.push_back(absl::StrFormat("%s:%d (%s)", frame.file_name, + frame.line_num, frame.function_name)); + } + return absl::StrJoin(frame_strs, "\n"); +} + +TracebackManager::Traceback::Traceback( + TracebackManager* manager, std::pair* impl) + : manager_(manager), impl_(impl) { + DCHECK(manager_); + ++impl->second; +} + +TracebackManager::Traceback::~Traceback() { + if (manager_) { + --impl_->second; + if (impl_->second == 0) { + manager_->tracebacks_.erase(impl_->first); + } + } +} + +TracebackManager::Traceback::Traceback(const Traceback& other) + : manager_(other.manager_), impl_(other.impl_) { + if (manager_) { + ++impl_->second; + } +} + +TracebackManager::Traceback::Traceback(Traceback&& other) + : manager_(other.manager_), impl_(other.impl_) { + other.manager_ = nullptr; + other.impl_ = nullptr; +} + +TracebackManager::Traceback& TracebackManager::Traceback::operator=( + const TracebackManager::Traceback& other) { + manager_ = other.manager_; + impl_ = other.impl_; + if (manager_) { + ++impl_->second; + } + return *this; +} + +TracebackManager::Traceback& TracebackManager::Traceback::operator=( + TracebackManager::Traceback&& other) { + std::swap(manager_, other.manager_); + std::swap(impl_, other.impl_); + return *this; +} + +std::string TracebackManager::Traceback::ToString() const { + // We require the GIL because we manipulate Python strings. + CHECK(PyGILState_Check()); + if (!manager_) { + // Don't crash if called on a default-constructed Traceback. 
+ return ""; + } + return impl_->first.ToString(); +} + +const std::vector* +TracebackManager::Traceback::Frames() const { + return &impl_->first.frames; +} + +/*static*/ TracebackManager* TracebackManager::Get() { + static TracebackManager* manager = new TracebackManager; + return manager; +} + +TracebackManager::TracebackManager() = default; +TracebackManager::~TracebackManager() = default; + +absl::optional TracebackManager::GetTraceback() { + if (!enabled_) { + return absl::nullopt; + } + CHECK(PyGILState_Check()); + TracebackImpl impl; + if (!impl.GetTracebackForCurrentThread()) { + return absl::nullopt; + } + auto it = tracebacks_.emplace(impl, 0); + return Traceback(this, &*it.first); +} + +void TracebackManager::SetEnabled(bool enabled) { enabled_ = enabled; } + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/traceback_manager.h b/tensorflow/compiler/xla/python/traceback_manager.h new file mode 100644 index 00000000000..274a22b0d87 --- /dev/null +++ b/tensorflow/compiler/xla/python/traceback_manager.h @@ -0,0 +1,102 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TRACEBACK_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_TRACEBACK_MANAGER_H_ + +#include "absl/container/node_hash_map.h" +#include "absl/types/optional.h" +#include "pybind11/pybind11.h" + +namespace xla { + +struct TracebackImpl; + +template +H AbslHashValue(H h, const TracebackImpl& tb); + +// Traceback manager class that deduplicates traceback objects to save memory. +// It probably does not save time to deduplicate tracebacks, but we expect to +// see many copies of the same tracebacks and hence we deduplicate in an attempt +// to save memory. +class TracebackManager { + public: + static TracebackManager* Get(); + + ~TracebackManager(); + + TracebackManager(const TracebackManager&) = delete; + TracebackManager(TracebackManager&&) = delete; + TracebackManager& operator=(const TracebackManager&) = delete; + TracebackManager& operator=(TracebackManager&&) = delete; + + struct Frame { + pybind11::str file_name; + pybind11::str function_name; + unsigned int line_num; + int function_start_line; + }; + + // RAII class that holds a reference to a traceback. + class Traceback { + public: + Traceback() = default; + ~Traceback(); + + Traceback(const Traceback&); + Traceback(Traceback&&); + Traceback& operator=(const Traceback&); + Traceback& operator=(Traceback&&); + + // Requires the GIL be held. + std::string ToString() const; + + // Returns the stack frame objects, in order from innermost to outermost. + const std::vector* Frames() const; + + private: + friend class TracebackManager; + + Traceback(TracebackManager* manager, + std::pair* impl); + + // nullptr for a default-constructed Traceback, non-null otherwise. + TracebackManager* manager_ = nullptr; + // Points to an entry in tracebacks_. Not owned. 
+ std::pair* impl_ = nullptr; + }; + + // Returns a Traceback for the current thread. Returns nullopt if tracebacks + // aren't enabled, + absl::optional GetTraceback(); + + // Enables or disables traceback collection. + void SetEnabled(bool enabled); + bool enabled() const { return enabled_; } + + private: + TracebackManager(); + + bool enabled_ = false; + + // Deduplicated tracebacks. Map from traceback to reference count. + // The map and its contents are protected by the GIL, which is why we do not + // need an atomic integer for the reference count. + absl::node_hash_map tracebacks_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_TRACEBACK_MANAGER_H_ diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 7fa583eb993..d1f4972aace 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -29,6 +29,7 @@ limitations under the License. #include "pybind11/numpy.h" #include "pybind11/pybind11.h" #include "pybind11/pytypes.h" +#include "pybind11/stl_bind.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -48,6 +49,7 @@ limitations under the License. #include "tensorflow/compiler/xla/python/py_buffer.h" #include "tensorflow/compiler/xla/python/py_executable.h" #include "tensorflow/compiler/xla/python/python_ref_manager.h" +#include "tensorflow/compiler/xla/python/traceback_manager.h" #include "tensorflow/compiler/xla/python/types.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" @@ -598,13 +600,16 @@ PYBIND11_MODULE(xla_extension, m) { std::shared_ptr py_buffer_ref = GlobalPyRefManager()->ManageReference(std::move(c->array)); + auto traceback = TracebackManager::Get()->GetTraceback(); + py::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN( std::unique_ptr buffer, PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy, std::move(py_buffer_ref), client.get(), device)); - return std::make_unique(std::move(client), std::move(buffer)); + return std::make_unique(std::move(client), std::move(buffer), + traceback); }, py::arg("argument"), py::arg("device") = nullptr, py::arg("force_copy") = false); @@ -612,12 +617,13 @@ PYBIND11_MODULE(xla_extension, m) { "compile", [](std::shared_ptr client, const XlaComputation& computation, CompileOptions options) -> StatusOr> { + auto traceback = TracebackManager::Get()->GetTraceback(); py::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN(std::unique_ptr executable, PjRtExecutable::Compile(computation, client.get(), std::move(options))); - return std::make_unique(std::move(client), - std::move(executable)); + return std::make_unique( + std::move(client), std::move(executable), std::move(traceback)); }, py::arg("computation"), py::arg("compile_options") = CompileOptions()); @@ -628,6 +634,40 @@ PYBIND11_MODULE(xla_extension, m) { py::arg("allocator_config") = GpuAllocatorConfig(), py::arg("distributed_client") = nullptr, py::arg("node_id") = 0); + py::class_(m, "Frame") + .def_readonly("file_name", &TracebackManager::Frame::file_name) + .def_readonly("function_name", &TracebackManager::Frame::function_name) + .def_readonly("function_start_line", + &TracebackManager::Frame::function_start_line) + .def_readonly("line_num", &TracebackManager::Frame::line_num) + .def("__repr__", [](const TracebackManager::Frame& frame) { + return 
absl::StrFormat("%s;%s:%d", frame.function_name, frame.file_name, + frame.line_num); + }); + py::bind_vector>(m, "FrameVector"); + + py::class_ traceback( + m, "Traceback", "Represents a Python stack trace."); + traceback.def_property_static( + "enabled", + [](py::object /* cls */) { return TracebackManager::Get()->enabled(); }, + [](py::object /* cls */, bool enabled) { + return TracebackManager::Get()->SetEnabled(enabled); + }); + traceback.def_static( + "get_traceback", []() { return TracebackManager::Get()->GetTraceback(); }, + R"doc( + Returns a :class:`Traceback` for the current thread. + + If ``Traceback.enabled`` is ``True``, returns a :class:`Traceback` object + that describes the Python stack of the calling thread. Stack trace + collection has a small overhead, so it is disabled by default. If traceback + collection is disabled, returns ``None``. + )doc"); + traceback.def_property_readonly("frames", + &TracebackManager::Traceback::Frames); + traceback.def("__str__", &TracebackManager::Traceback::ToString); + py::class_> buffer(m, "Buffer"); // TODO(phawkins): alias for backward compatibility. Remove after JAX no // longer uses this name. @@ -668,7 +708,8 @@ PYBIND11_MODULE(xla_extension, m) { .def("is_deleted", &PyBuffer::is_deleted) .def("unsafe_buffer_pointer", &PyBuffer::UnsafeBufferPointer) .def_property_readonly("__cuda_array_interface__", - &PyBuffer::CudaArrayInterface); + &PyBuffer::CudaArrayInterface) + .def_property_readonly("traceback", &PyBuffer::traceback); // pybind11's implementation of the buffer protocol doesn't allow for correct // error handling. We bypass it and implement the buffer protocol ourselves. @@ -686,7 +727,8 @@ PYBIND11_MODULE(xla_extension, m) { .def("execute", &PyExecutable::Execute, py::arg("arguments")) .def("execute_on_local_devices", &PyExecutable::ExecuteOnLocalDevices, py::arg("arguments")) - .def("hlo_modules", &PyExecutable::HloModules); + .def("hlo_modules", &PyExecutable::HloModules) + .def_property_readonly("traceback", &PyExecutable::traceback); py::class_(m, "DebugOptions") .def("__repr__", &DebugOptions::DebugString) @@ -927,6 +969,8 @@ PYBIND11_MODULE(xla_extension, m) { m.def("get_distributed_runtime_service", &GetDistributedRuntimeService); m.def("get_distributed_runtime_client", &GetDistributedRuntimeClient); + + m.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); }); } // NOLINT(readability/fn_size) } // namespace xla diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index de1f8a2591a..f4be78680c5 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -19,7 +19,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import atexit import collections +import contextlib import enum # pylint: disable=g-bad-import-order import inspect import os @@ -406,6 +408,7 @@ def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, XlaBuilder = _xla.XlaBuilder XlaComputation = _xla.XlaComputation FftType = _xla.FftType +Client = _xla.LocalClient Buffer = _xla.Buffer Executable = _xla.Executable @@ -663,3 +666,22 @@ def make_replica_groups(replica_groups): _make_replica_group_proto(group) for group in replica_groups ] return replica_groups_protos + + +Traceback = _xla.Traceback + + +@contextlib.contextmanager +def tracebacks(enabled=True): + """Context manager that enables or disables traceback collection.""" + saved = Traceback.enabled + 
Traceback.enabled = enabled + try: + yield + finally: + Traceback.enabled = saved + + +# Perform one last garbage collection of deferred Python references. This is +# mostly to keep ASAN happy. +atexit.register(_xla.collect_garbage) diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 000db2cb16b..0fc0bcae954 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -2038,6 +2038,67 @@ def TestFactory(xla_backend, cloud_tpu=False): del server tests.append(ProfilerTest) + + class TracebackTest(absltest.TestCase): + + def setUp(self): + super(TracebackTest, self).setUp() + self.backend = xla_backend() + + def testNoTracebacksIfDisabled(self): + with xla_client.tracebacks(enabled=False): + self.assertEqual(None, xla_client.Traceback.get_traceback()) + buffer = self.backend.buffer_from_pyval(np.array(7, np.int32)) + self.assertEqual(None, buffer.traceback) + + b = xla_client.XlaBuilder("computation") + ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) + e = self.backend.compile(b.build()) + self.assertEqual(None, e.traceback) + + def assertIsTracebackContaining(self, tb, function): + self.assertIsInstance(tb, xla_client.Traceback) + self.assertIn(function, str(tb)) + self.assertTrue(any(f.function_name == function for f in tb.frames)) + + def testTracebacks(self): + with xla_client.tracebacks(enabled=True): + tb = xla_client.Traceback.get_traceback() + self.assertIsTracebackContaining(tb, "testTracebacks") + + # Tracebacks are not implemented on the TPU driver extension's variant + # of buffers and executables. + if not isinstance(self.backend, xla_client.Client): + return + + buffer = self.backend.buffer_from_pyval(np.array(7, np.int32)) + self.assertIsTracebackContaining(buffer.traceback, "testTracebacks") + + b = xla_client.XlaBuilder("computation") + ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) + e = self.backend.compile(b.build()) + self.assertIsTracebackContaining(e.traceback, "testTracebacks") + + def testNestedFunction(self): + + def AFunction(): + + def AnotherFunction(): + return xla_client.Traceback.get_traceback() + + return AnotherFunction() + + with xla_client.tracebacks(enabled=True): + tb = AFunction() + self.assertIsInstance(tb, xla_client.Traceback) + frames = tb.frames + i = next( + i for (i, f) in enumerate(frames) if f.function_name == "AFunction") + self.assertEqual(frames[i - 1].function_name, "AnotherFunction") + self.assertEqual(frames[i + 1].function_name, "testNestedFunction") + + tests.append(TracebackTest) + return tests
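A short usage sketch of the API added by this change, distilled from the xla_client_test.py additions above. It is illustrative only and not part of the patch: obtaining a backend via xla_client.get_local_backend() is an assumption here, and any LocalClient instance would work the same way.

import numpy as np
from tensorflow.compiler.xla.python import xla_client

backend = xla_client.get_local_backend()  # assumed helper; any LocalClient will do

with xla_client.tracebacks(enabled=True):
    # Buffers and executables created while collection is enabled record the
    # Python stack of their creation site, at a cost of roughly 2us each.
    buf = backend.buffer_from_pyval(np.array(7, np.int32))
    tb = buf.traceback
    print(str(tb))              # one "file.py:line (function)" entry per frame
    for frame in tb.frames:     # innermost frame first
        print(frame.file_name, frame.line_num, frame.function_name)

# Outside the context manager, collection reverts to its previous (disabled)
# state, so no per-call overhead is paid and get_traceback() returns None.
assert xla_client.Traceback.get_traceback() is None

Because TracebackManager deduplicates identical stacks and defers freeing the frame strings to the PythonRefManager, keeping a traceback on every buffer and executable stays cheap in both memory and destruction cost.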