[XLA:Python] Implement on-device heap profiling support for JAX.
Adds an `xla_client.heap_profile(client)` API which produces a gzip-compressed profile.proto protocol buffer containing an on-device heap profile, suitable for visualization via the pprof tool (https://github.com/google/pprof). The heap profile includes buffers and executables allocated by JAX. PiperOrigin-RevId: 316089755 Change-Id: I4cf3a17da7d9370b9fd08224d3577b7f860980ef
This commit is contained in:
parent
8bc26ba68e
commit
bd361bfaf9
@ -695,13 +695,6 @@ PjRtBuffer::~PjRtBuffer() {
|
||||
}
|
||||
}
|
||||
|
||||
// Reports the byte size that the backend's transfer manager requires for this
// buffer's `on_device_shape_`.
int64 PjRtBuffer::OnDeviceSizeInBytes() const {
  auto&& backend = client_->client()->backend();
  auto* transfer_manager = backend.transfer_manager();
  return transfer_manager->GetByteSizeRequirement(on_device_shape_);
}
|
||||
|
||||
void PjRtBuffer::WaitForOutstandingUsageHolds() {
|
||||
auto not_in_usage_hold = [&]() {
|
||||
mu_.AssertHeld();
|
||||
|
@ -409,9 +409,6 @@ class PjRtBuffer {
|
||||
return on_host_shape_.IsTuple() && on_host_shape_.tuple_shapes_size() == 0;
|
||||
}
|
||||
|
||||
// Returns the size of the on-device representation of this buffer in bytes.
|
||||
int64 OnDeviceSizeInBytes() const;
|
||||
|
||||
// Returns the buffer's value as an XLA Literal. If the value has previously
|
||||
// been prefetched to the host, then returns the prefetched version, otherwise
|
||||
// copies the buffer to the host. Blocks until the value is ready.
|
||||
|
@ -202,9 +202,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"//tensorflow/core/profiler:protos_all_cc",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@pybind11",
|
||||
|
@ -25,29 +25,10 @@ namespace py = pybind11;
|
||||
|
||||
PyBuffer::PyBuffer(std::shared_ptr<PyClient> client,
|
||||
std::unique_ptr<PjRtBuffer> buffer,
|
||||
std::shared_ptr<Traceback> traceback)
|
||||
std::unique_ptr<Traceback> traceback)
|
||||
: client_(std::move(client)),
|
||||
buffer_(std::move(buffer)),
|
||||
traceback_(std::move(traceback)) {
|
||||
next_ = client_->buffers_;
|
||||
client_->buffers_ = this;
|
||||
prev_ = nullptr;
|
||||
if (next_) {
|
||||
next_->prev_ = this;
|
||||
}
|
||||
}
|
||||
|
||||
// Unlinks this buffer from the client's intrusive doubly-linked list of
// buffers: the neighbours are stitched together and, if this buffer was the
// list head, the head is advanced to the successor.
PyBuffer::~PyBuffer() {
  if (next_ != nullptr) {
    next_->prev_ = prev_;
  }
  if (prev_ != nullptr) {
    prev_->next_ = next_;
  }
  if (client_->buffers_ == this) {
    client_->buffers_ = next_;
  }
}
|
||||
traceback_(std::move(traceback)) {}
|
||||
|
||||
ClientAndPtr<Device> PyBuffer::device() const {
|
||||
return WrapWithClient(client_, buffer_->device());
|
||||
|
@ -32,8 +32,7 @@ namespace xla {
|
||||
class PyBuffer {
|
||||
public:
|
||||
PyBuffer(std::shared_ptr<PyClient> client, std::unique_ptr<PjRtBuffer> buffer,
|
||||
std::shared_ptr<Traceback> traceback);
|
||||
~PyBuffer();
|
||||
std::unique_ptr<Traceback> traceback);
|
||||
|
||||
std::shared_ptr<PyClient> client() const { return client_; }
|
||||
PjRtBuffer* buffer() const { return buffer_.get(); }
|
||||
@ -64,16 +63,9 @@ class PyBuffer {
|
||||
Traceback* traceback() { return traceback_.get(); }
|
||||
|
||||
private:
|
||||
friend class PyClient;
|
||||
|
||||
std::shared_ptr<PyClient> client_;
|
||||
std::unique_ptr<PjRtBuffer> buffer_;
|
||||
std::shared_ptr<Traceback> traceback_;
|
||||
|
||||
// Doubly-linked list of all buffers known to the client. Protected by the
|
||||
// GIL.
|
||||
PyBuffer* next_;
|
||||
PyBuffer* prev_;
|
||||
std::unique_ptr<Traceback> traceback_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -15,18 +15,15 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/python/py_client.h"
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/compiler/xla/python/py_buffer.h"
|
||||
#include "tensorflow/compiler/xla/python/py_executable.h"
|
||||
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
|
||||
#include "tensorflow/compiler/xla/python/traceback.h"
|
||||
#include "tensorflow/compiler/xla/python/types.h"
|
||||
#include "tensorflow/core/profiler/profile.pb.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace pprof = tensorflow::tfprof::pprof;
|
||||
|
||||
PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
|
||||
: pjrt_client_(std::move(pjrt_client)) {}
|
||||
@ -130,151 +127,4 @@ StatusOr<std::unique_ptr<PyExecutable>> PyClient::Compile(
|
||||
shared_from_this(), std::move(executable), std::move(traceback));
|
||||
}
|
||||
|
||||
class ProfileBuilder {
|
||||
public:
|
||||
ProfileBuilder();
|
||||
pprof::Profile& profile() { return profile_; }
|
||||
|
||||
// Adds or returns the ID of `s` in the table.
|
||||
int StringId(const std::string& s);
|
||||
|
||||
// Adds or returns the ID of a function.
|
||||
int FunctionId(PyCodeObject* code);
|
||||
|
||||
// Adds or returns the ID of a code location.
|
||||
int LocationId(PyCodeObject* code, int instruction);
|
||||
|
||||
private:
|
||||
pprof::Profile profile_;
|
||||
|
||||
absl::flat_hash_map<std::string, int> strings_;
|
||||
absl::flat_hash_map<PyCodeObject*, int> functions_;
|
||||
absl::flat_hash_map<std::pair<PyCodeObject*, int>, int> locations_;
|
||||
};
|
||||
|
||||
// Interns the empty string before anything else and checks that it received
// index 0 of the string table.
ProfileBuilder::ProfileBuilder() {
  CHECK_EQ(0, StringId(""));
}
|
||||
|
||||
int ProfileBuilder::StringId(const std::string& s) {
|
||||
auto ret = strings_.emplace(s, profile_.string_table_size());
|
||||
if (ret.second) {
|
||||
profile_.add_string_table(s);
|
||||
}
|
||||
return ret.first->second;
|
||||
}
|
||||
|
||||
// Returns the ID of the profile's function entry for `code`. On first use a
// new entry is added, carrying the code object's name (also used as the system
// name), its filename, and its first line number.
int ProfileBuilder::FunctionId(PyCodeObject* code) {
  // IDs start at 1 because ID 0 is reserved.
  const int candidate_id = profile_.function_size() + 1;
  const auto insertion = functions_.emplace(code, candidate_id);
  const int function_id = insertion.first->second;
  if (!insertion.second) {
    // Already known: nothing to add.
    return function_id;
  }

  auto* entry = profile_.add_function();
  entry->set_id(function_id);
  const int name_id = StringId(py::str(code->co_name));
  entry->set_name(name_id);
  entry->set_system_name(name_id);
  entry->set_filename(StringId(py::str(code->co_filename)));
  entry->set_start_line(code->co_firstlineno);
  return function_id;
}
|
||||
|
||||
// Returns the ID of the profile's location entry for the pair
// (`code`, `instruction`). On first use a new entry is added with a single
// line record that points at the function for `code` and at the source line
// that `PyCode_Addr2Line` reports for `instruction`.
int ProfileBuilder::LocationId(PyCodeObject* code, int instruction) {
  // IDs start at 1 because ID 0 is reserved.
  const int candidate_id = profile_.location_size() + 1;
  const auto insertion =
      locations_.emplace(std::make_pair(code, instruction), candidate_id);
  const int location_id = insertion.first->second;
  if (!insertion.second) {
    // Already known: nothing to add.
    return location_id;
  }

  auto* entry = profile_.add_location();
  entry->set_id(location_id);
  auto* line_record = entry->add_line();
  line_record->set_function_id(FunctionId(code));
  line_record->set_line(PyCode_Addr2Line(code, instruction));
  return location_id;
}
|
||||
|
||||
namespace {
|
||||
|
||||
struct HeapProfileKey {
|
||||
Traceback* traceback;
|
||||
int64 size;
|
||||
Device* device;
|
||||
bool operator==(const HeapProfileKey& other) const;
|
||||
};
|
||||
|
||||
bool HeapProfileKey::operator==(const HeapProfileKey& other) const {
|
||||
if (size != other.size || device != other.device) {
|
||||
return false;
|
||||
}
|
||||
if ((traceback == nullptr) != (other.traceback == nullptr)) {
|
||||
return false;
|
||||
}
|
||||
if (traceback && traceback->raw_frames() != other.traceback->raw_frames()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename H>
|
||||
H AbslHashValue(H h, const HeapProfileKey& key) {
|
||||
if (key.traceback) {
|
||||
h = H::combine_contiguous(std::move(h), key.traceback->raw_frames().begin(),
|
||||
key.traceback->raw_frames().size());
|
||||
}
|
||||
h = H::combine(std::move(h), key.size, key.device);
|
||||
return h;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Builds a pprof heap profile of the buffers and executables currently linked
// into this client's intrusive lists, and returns it as the serialized
// (uncompressed) `Profile` message.
//
// Objects are grouped by HeapProfileKey. Each group becomes one sample with
// two values: the number of objects ("allocations"/"count") and their total
// size ("space"/"bytes"). Every sample carries a "kind" label ("buffer" or
// "executable"); samples whose key has a device also carry a numeric "device"
// label.
py::bytes PyClient::HeapProfile() {
  // Number of live objects observed for each key.
  absl::flat_hash_map<HeapProfileKey, int64> counts;

  PyBuffer* buffer = buffers_;
  while (buffer != nullptr) {
    HeapProfileKey key{buffer->traceback(),
                       buffer->buffer()->OnDeviceSizeInBytes(),
                       buffer->buffer()->device()};
    counts[key] += 1;
    buffer = buffer->next_;
  }

  PyExecutable* executable = executables_;
  while (executable != nullptr) {
    // Executables are keyed with a null device.
    HeapProfileKey key{executable->traceback(),
                       executable->SizeOfGeneratedCodeInBytes(), nullptr};
    counts[key] += 1;
    executable = executable->next_;
  }

  ProfileBuilder builder;
  pprof::Profile& profile = builder.profile();

  // NOTE: string IDs are assigned in interning order, so the sequence of
  // StringId() calls below is significant.
  auto* allocations_type = profile.add_sample_type();
  allocations_type->set_type(builder.StringId("allocations"));
  allocations_type->set_unit(builder.StringId("count"));
  auto* space_type = profile.add_sample_type();
  space_type->set_type(builder.StringId("space"));
  space_type->set_unit(builder.StringId("bytes"));

  const int kind_string_id = builder.StringId("kind");
  const int buffer_string_id = builder.StringId("buffer");
  const int executable_string_id = builder.StringId("executable");
  const int device_string_id = builder.StringId("device");

  for (const auto& entry : counts) {
    const HeapProfileKey& key = entry.first;
    const int64 count = entry.second;

    auto* sample = profile.add_sample();
    if (key.traceback != nullptr) {
      for (const auto& frame : key.traceback->raw_frames()) {
        sample->add_location_id(builder.LocationId(frame.first, frame.second));
      }
    }
    sample->add_value(count);
    sample->add_value(key.size * count);

    // A non-null device marks the entry as a buffer; otherwise it is an
    // executable.
    const bool is_buffer = key.device != nullptr;
    auto* kind_label = sample->add_label();
    kind_label->set_key(kind_string_id);
    kind_label->set_str(is_buffer ? buffer_string_id : executable_string_id);
    if (is_buffer) {
      auto* device_label = sample->add_label();
      device_label->set_key(device_string_id);
      device_label->set_num(key.device->id());
    }
  }
  return profile.SerializeAsString();
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -125,19 +125,8 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
|
||||
StatusOr<std::unique_ptr<PyExecutable>> Compile(
|
||||
const XlaComputation& computation, CompileOptions options);
|
||||
|
||||
pybind11::bytes HeapProfile();
|
||||
|
||||
private:
|
||||
friend class PyBuffer;
|
||||
friend class PyExecutable;
|
||||
|
||||
std::shared_ptr<PjRtClient> pjrt_client_;
|
||||
|
||||
// Pointers to intrusive doubly-linked lists of buffers and executables, used
|
||||
// to iterate over all known objects when heap profiling. The list structure
|
||||
// is protected by the GIL.
|
||||
PyBuffer* buffers_ = nullptr;
|
||||
PyExecutable* executables_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -23,29 +23,10 @@ namespace py = pybind11;
|
||||
|
||||
PyExecutable::PyExecutable(std::shared_ptr<PyClient> client,
|
||||
std::unique_ptr<PjRtExecutable> executable,
|
||||
std::shared_ptr<Traceback> traceback)
|
||||
std::unique_ptr<Traceback> traceback)
|
||||
: client_(std::move(client)),
|
||||
executable_(std::move(executable)),
|
||||
traceback_(std::move(traceback)) {
|
||||
next_ = client_->executables_;
|
||||
client_->executables_ = this;
|
||||
prev_ = nullptr;
|
||||
if (next_) {
|
||||
next_->prev_ = this;
|
||||
}
|
||||
}
|
||||
|
||||
// Unlinks this executable from the client's intrusive doubly-linked list of
// executables: the neighbours are stitched together and, if this executable
// was the list head, the head is advanced to the successor.
PyExecutable::~PyExecutable() {
  if (next_ != nullptr) {
    next_->prev_ = prev_;
  }
  if (prev_ != nullptr) {
    prev_->next_ = next_;
  }
  if (client_->executables_ == this) {
    client_->executables_ = next_;
  }
}
|
||||
traceback_(std::move(traceback)) {}
|
||||
|
||||
std::vector<ClientAndPtr<Device>> PyExecutable::LocalDevices() const {
|
||||
std::vector<ClientAndPtr<Device>> devices;
|
||||
@ -70,8 +51,8 @@ StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
|
||||
std::vector<std::unique_ptr<PyBuffer>> outputs;
|
||||
outputs.reserve(output_buffers.size());
|
||||
for (auto& buffer : output_buffers) {
|
||||
outputs.push_back(
|
||||
std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
|
||||
outputs.push_back(std::make_unique<PyBuffer>(client_, std::move(buffer),
|
||||
std::move(traceback)));
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
@ -97,8 +78,8 @@ PyExecutable::ExecuteOnLocalDevices(
|
||||
for (int computation = 0; computation < output_buffers.size();
|
||||
++computation) {
|
||||
for (auto& buffer : output_buffers[computation]) {
|
||||
outputs[computation].push_back(
|
||||
std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
|
||||
outputs[computation].push_back(std::make_unique<PyBuffer>(
|
||||
client_, std::move(buffer), std::move(traceback)));
|
||||
}
|
||||
}
|
||||
return outputs;
|
||||
|
@ -37,8 +37,7 @@ class PyExecutable {
|
||||
public:
|
||||
PyExecutable(std::shared_ptr<PyClient> client,
|
||||
std::unique_ptr<PjRtExecutable> executable,
|
||||
std::shared_ptr<Traceback> traceback);
|
||||
~PyExecutable();
|
||||
std::unique_ptr<Traceback> traceback);
|
||||
|
||||
std::shared_ptr<PyClient> client() const { return client_; }
|
||||
|
||||
@ -65,16 +64,9 @@ class PyExecutable {
|
||||
Traceback* traceback() { return traceback_.get(); }
|
||||
|
||||
private:
|
||||
friend class PyClient;
|
||||
|
||||
std::shared_ptr<PyClient> client_;
|
||||
std::unique_ptr<PjRtExecutable> executable_;
|
||||
std::shared_ptr<Traceback> traceback_;
|
||||
|
||||
// Doubly-linked list of all executables known to the client. Protected by the
|
||||
// GIL.
|
||||
PyExecutable* next_;
|
||||
PyExecutable* prev_;
|
||||
std::unique_ptr<Traceback> traceback_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -25,7 +25,7 @@ namespace xla {
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
bool Traceback::enabled_ = true;
|
||||
bool Traceback::enabled_ = false;
|
||||
|
||||
Traceback::~Traceback() {
|
||||
// We want Traceback objects to be safe to destroy without holding the
|
||||
@ -34,7 +34,7 @@ Traceback::~Traceback() {
|
||||
}
|
||||
|
||||
std::string Traceback::Frame::ToString() const {
|
||||
return absl::StrFormat("%s:%d (%s)", file_name, line_num, function_name);
|
||||
return absl::StrFormat("%s;%s:%d", function_name, file_name, line_num);
|
||||
}
|
||||
|
||||
std::string Traceback::ToString() const {
|
||||
@ -61,12 +61,12 @@ std::vector<Traceback::Frame> Traceback::Frames() const {
|
||||
return frames;
|
||||
}
|
||||
|
||||
std::shared_ptr<Traceback> Traceback::Get() {
|
||||
std::unique_ptr<Traceback> Traceback::Get() {
|
||||
DCHECK(PyGILState_Check());
|
||||
if (!enabled_) {
|
||||
return nullptr;
|
||||
}
|
||||
auto tb = std::make_shared<Traceback>();
|
||||
auto tb = std::make_unique<Traceback>();
|
||||
const PyThreadState* thread_state = PyThreadState_GET();
|
||||
for (PyFrameObject* py_frame = thread_state->frame; py_frame != nullptr;
|
||||
py_frame = py_frame->f_back) {
|
||||
|
@ -29,7 +29,7 @@ namespace xla {
|
||||
class Traceback {
|
||||
public:
|
||||
// Require GIL.
|
||||
static std::shared_ptr<Traceback> Get();
|
||||
static std::unique_ptr<Traceback> Get();
|
||||
|
||||
// Require GIL.
|
||||
static bool enabled() { return enabled_; }
|
||||
@ -57,11 +57,6 @@ class Traceback {
|
||||
};
|
||||
std::vector<Frame> Frames() const;
|
||||
|
||||
const absl::InlinedVector<std::pair<PyCodeObject*, int>, 32>& raw_frames()
|
||||
const {
|
||||
return frames_;
|
||||
}
|
||||
|
||||
private:
|
||||
absl::InlinedVector<std::pair<PyCodeObject*, int>, 32> frames_;
|
||||
|
||||
|
@ -529,8 +529,7 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
.def("buffer_from_pyval", &PyClient::BufferFromPyal, py::arg("argument"),
|
||||
py::arg("device") = nullptr, py::arg("force_copy") = false)
|
||||
.def("compile", &PyClient::Compile, py::arg("computation"),
|
||||
py::arg("compile_options") = CompileOptions())
|
||||
.def("heap_profile", &PyClient::HeapProfile);
|
||||
py::arg("compile_options") = CompileOptions());
|
||||
|
||||
m.def(
|
||||
"get_cpu_client",
|
||||
@ -571,8 +570,8 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
frame.line_num);
|
||||
});
|
||||
|
||||
py::class_<Traceback, std::shared_ptr<Traceback>> traceback(
|
||||
m, "Traceback", "Represents a Python stack trace.");
|
||||
py::class_<Traceback> traceback(m, "Traceback",
|
||||
"Represents a Python stack trace.");
|
||||
traceback.def_property_static(
|
||||
"enabled", [](py::object /* cls */) { return Traceback::enabled(); },
|
||||
[](py::object /* cls */, bool enabled) {
|
||||
|
@ -23,7 +23,6 @@ import atexit
|
||||
import collections
|
||||
import contextlib
|
||||
import enum # pylint: disable=g-bad-import-order
|
||||
import gzip
|
||||
import inspect
|
||||
import os
|
||||
from typing import List, Sequence, Tuple, Union
|
||||
@ -683,11 +682,6 @@ def tracebacks(enabled=True):
|
||||
Traceback.enabled = saved
|
||||
|
||||
|
||||
def heap_profile(client: Client) -> bytes:
  """Returns a gzipped pprof protocol buffer containing a heap profile.

  Args:
    client: the client whose `heap_profile()` method supplies the serialized
      profile to compress.

  Returns:
    The gzip-compressed profile as `bytes` (`gzip.compress` returns bytes, not
    `str`).
  """
  return gzip.compress(client.heap_profile())
|
||||
|
||||
|
||||
# Perform one last garbage collection of deferred Python references. This is
|
||||
# mostly to keep ASAN happy.
|
||||
atexit.register(_xla.collect_garbage)
|
||||
|
Loading…
Reference in New Issue
Block a user