[XLA:Python] Implement on-device heap profiling support for JAX.

Adds an `xla_client.heap_profile(client)` API which produces a gzip-compressed profile.proto protocol buffer containing an on-device heap profile, suitable for visualization via the pprof tool (https://github.com/google/pprof).

The heap profile includes buffers and executables allocated by JAX.

PiperOrigin-RevId: 316089755
Change-Id: I4cf3a17da7d9370b9fd08224d3577b7f860980ef
This commit is contained in:
Peter Hawkins 2020-06-12 05:42:17 -07:00 committed by TensorFlower Gardener
parent 8bc26ba68e
commit bd361bfaf9
13 changed files with 20 additions and 259 deletions

View File

@ -695,13 +695,6 @@ PjRtBuffer::~PjRtBuffer() {
}
}
// Returns the number of bytes the transfer manager reports as required for
// this buffer's on-device shape.
int64 PjRtBuffer::OnDeviceSizeInBytes() const {
  auto* transfer_manager = client_->client()->backend().transfer_manager();
  return transfer_manager->GetByteSizeRequirement(on_device_shape_);
}
void PjRtBuffer::WaitForOutstandingUsageHolds() {
auto not_in_usage_hold = [&]() {
mu_.AssertHeld();

View File

@ -409,9 +409,6 @@ class PjRtBuffer {
return on_host_shape_.IsTuple() && on_host_shape_.tuple_shapes_size() == 0;
}
// Returns the size of the on-device representation of this buffer in bytes.
int64 OnDeviceSizeInBytes() const;
// Returns the buffer's value as an XLA Literal. If the value has previously
// been prefetched to the host, then returns the prefetched version, otherwise
// copies the buffer to the host. Blocks until the value is ready.

View File

@ -202,9 +202,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/core/profiler:protos_all_cc",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@pybind11",

View File

@ -25,29 +25,10 @@ namespace py = pybind11;
PyBuffer::PyBuffer(std::shared_ptr<PyClient> client,
std::unique_ptr<PjRtBuffer> buffer,
std::shared_ptr<Traceback> traceback)
std::unique_ptr<Traceback> traceback)
: client_(std::move(client)),
buffer_(std::move(buffer)),
traceback_(std::move(traceback)) {
next_ = client_->buffers_;
client_->buffers_ = this;
prev_ = nullptr;
if (next_) {
next_->prev_ = this;
}
}
// Unlinks this buffer from the client's intrusive doubly-linked list of
// buffers, updating the list head if this buffer was first.
PyBuffer::~PyBuffer() {
  PyBuffer* const successor = next_;
  PyBuffer* const predecessor = prev_;
  if (this == client_->buffers_) {
    client_->buffers_ = successor;
  }
  if (predecessor != nullptr) {
    predecessor->next_ = successor;
  }
  if (successor != nullptr) {
    successor->prev_ = predecessor;
  }
}
traceback_(std::move(traceback)) {}
ClientAndPtr<Device> PyBuffer::device() const {
return WrapWithClient(client_, buffer_->device());

View File

@ -32,8 +32,7 @@ namespace xla {
class PyBuffer {
public:
PyBuffer(std::shared_ptr<PyClient> client, std::unique_ptr<PjRtBuffer> buffer,
std::shared_ptr<Traceback> traceback);
~PyBuffer();
std::unique_ptr<Traceback> traceback);
std::shared_ptr<PyClient> client() const { return client_; }
PjRtBuffer* buffer() const { return buffer_.get(); }
@ -64,16 +63,9 @@ class PyBuffer {
Traceback* traceback() { return traceback_.get(); }
private:
friend class PyClient;
std::shared_ptr<PyClient> client_;
std::unique_ptr<PjRtBuffer> buffer_;
std::shared_ptr<Traceback> traceback_;
// Doubly-linked list of all buffers known to the client. Protected by the
// GIL.
PyBuffer* next_;
PyBuffer* prev_;
std::unique_ptr<Traceback> traceback_;
};
} // namespace xla

View File

@ -15,18 +15,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/py_client.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/core/profiler/profile.pb.h"
namespace xla {
namespace py = pybind11;
namespace pprof = tensorflow::tfprof::pprof;
PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
: pjrt_client_(std::move(pjrt_client)) {}
@ -130,151 +127,4 @@ StatusOr<std::unique_ptr<PyExecutable>> PyClient::Compile(
shared_from_this(), std::move(executable), std::move(traceback));
}
// Incrementally builds a pprof::Profile message. Strings, functions and code
// locations are interned: each distinct value is appended to the profile once
// and later requests for the same value return the previously assigned ID.
class ProfileBuilder {
public:
ProfileBuilder();
// Mutable access to the profile being built.
pprof::Profile& profile() { return profile_; }
// Adds or returns the ID of `s` in the table.
int StringId(const std::string& s);
// Adds or returns the ID of a function.
int FunctionId(PyCodeObject* code);
// Adds or returns the ID of a code location.
int LocationId(PyCodeObject* code, int instruction);
private:
pprof::Profile profile_;
// Interning tables mapping each value to the ID it was assigned in profile_.
absl::flat_hash_map<std::string, int> strings_;
absl::flat_hash_map<PyCodeObject*, int> functions_;
// Keyed by (code object, instruction) pair.
absl::flat_hash_map<std::pair<PyCodeObject*, int>, int> locations_;
};
// Interns the empty string first and checks that it received string ID 0.
// NOTE(review): presumably because the pprof format reserves string_table[0]
// for "" -- confirm against profile.proto.
ProfileBuilder::ProfileBuilder() { CHECK_EQ(0, StringId("")); }
// Returns the string-table index of `s`, appending `s` to the profile's
// string table the first time it is seen.
int ProfileBuilder::StringId(const std::string& s) {
  // A new string would land at the current end of the string table.
  const int candidate_id = profile_.string_table_size();
  const auto inserted = strings_.emplace(s, candidate_id);
  if (!inserted.second) {
    // Already interned: hand back the ID recorded earlier.
    return inserted.first->second;
  }
  profile_.add_string_table(s);
  return candidate_id;
}
// Returns the function ID for `code`, adding a function entry (name, system
// name, filename and first line taken from the code object) to the profile
// the first time `code` is seen.
int ProfileBuilder::FunctionId(PyCodeObject* code) {
  // IDs are 1-based because id 0 is reserved.
  const int new_id = profile_.function_size() + 1;
  const auto inserted = functions_.emplace(code, new_id);
  if (!inserted.second) {
    return inserted.first->second;
  }
  auto* fn = profile_.add_function();
  fn->set_id(new_id);
  // The same interned string serves as both the name and the system name.
  const int name_id = StringId(py::str(code->co_name));
  fn->set_name(name_id);
  fn->set_system_name(name_id);
  const int filename_id = StringId(py::str(code->co_filename));
  fn->set_filename(filename_id);
  fn->set_start_line(code->co_firstlineno);
  return new_id;
}
// Returns the location ID for the (`code`, `instruction`) pair, adding a
// location with a single line entry to the profile the first time the pair is
// seen.
int ProfileBuilder::LocationId(PyCodeObject* code, int instruction) {
  // IDs are 1-based because id 0 is reserved.
  const int new_id = profile_.location_size() + 1;
  const auto inserted =
      locations_.emplace(std::make_pair(code, instruction), new_id);
  if (!inserted.second) {
    return inserted.first->second;
  }
  auto* loc = profile_.add_location();
  loc->set_id(new_id);
  auto* line_entry = loc->add_line();
  line_entry->set_function_id(FunctionId(code));
  // Translate the instruction offset into a source line number.
  line_entry->set_line(PyCode_Addr2Line(code, instruction));
  return new_id;
}
namespace {
// Key under which live objects are aggregated in the heap profile. Two keys
// compare equal when size and device match and the tracebacks are either both
// null or have equal raw frames.
struct HeapProfileKey {
// May be null; not owned by the key.
Traceback* traceback;
// Size in bytes of a single object with this key.
int64 size;
// Null for executables, which are recorded without a device.
Device* device;
bool operator==(const HeapProfileKey& other) const;
};
// Keys are equal when size and device match and the tracebacks are either
// both absent or hold identical raw frames.
bool HeapProfileKey::operator==(const HeapProfileKey& other) const {
  const bool same_scalars = size == other.size && device == other.device;
  if (!same_scalars) {
    return false;
  }
  if (traceback == nullptr || other.traceback == nullptr) {
    // At least one side has no traceback: equal only if neither has one.
    return traceback == nullptr && other.traceback == nullptr;
  }
  return traceback->raw_frames() == other.traceback->raw_frames();
}
// Abseil hash extension point for HeapProfileKey: mixes in the traceback's
// raw frames (when a traceback is present), then the size and the device.
template <typename H>
H AbslHashValue(H h, const HeapProfileKey& key) {
  if (key.traceback != nullptr) {
    const auto& frames = key.traceback->raw_frames();
    h = H::combine_contiguous(std::move(h), frames.begin(), frames.size());
  }
  return H::combine(std::move(h), key.size, key.device);
}
} // namespace
// Builds a serialized pprof profile describing the buffers and executables
// currently linked into this client's intrusive lists. Objects sharing the
// same traceback frames, size and device are merged into one sample carrying
// two values: the object count and the total bytes (size * count).
py::bytes PyClient::HeapProfile() {
  // Number of live objects per distinct key.
  absl::flat_hash_map<HeapProfileKey, int64> counts;

  PyBuffer* buf = buffers_;
  while (buf != nullptr) {
    HeapProfileKey key{buf->traceback(), buf->buffer()->OnDeviceSizeInBytes(),
                       buf->buffer()->device()};
    counts[key] += 1;
    buf = buf->next_;
  }

  PyExecutable* exe = executables_;
  while (exe != nullptr) {
    // Executables are keyed with a null device; the null device is what
    // labels the sample as an executable below.
    HeapProfileKey key{exe->traceback(), exe->SizeOfGeneratedCodeInBytes(),
                       nullptr};
    counts[key] += 1;
    exe = exe->next_;
  }

  ProfileBuilder builder;
  pprof::Profile& profile = builder.profile();

  // String IDs are assigned in first-use order, so the order of the StringId
  // calls below determines the layout of the string table.
  auto* count_type = profile.add_sample_type();
  count_type->set_type(builder.StringId("allocations"));
  count_type->set_unit(builder.StringId("count"));

  auto* bytes_type = profile.add_sample_type();
  bytes_type->set_type(builder.StringId("space"));
  bytes_type->set_unit(builder.StringId("bytes"));

  const int kind_key_id = builder.StringId("kind");
  const int buffer_value_id = builder.StringId("buffer");
  const int executable_value_id = builder.StringId("executable");
  const int device_key_id = builder.StringId("device");

  for (const auto& entry : counts) {
    const HeapProfileKey& key = entry.first;
    const int64 count = entry.second;

    auto* sample = profile.add_sample();
    if (key.traceback != nullptr) {
      for (const auto& frame : key.traceback->raw_frames()) {
        sample->add_location_id(builder.LocationId(frame.first, frame.second));
      }
    }
    // Values follow the sample-type order declared above: count, then bytes.
    sample->add_value(count);
    sample->add_value(key.size * count);

    auto* kind_label = sample->add_label();
    kind_label->set_key(kind_key_id);
    if (key.device == nullptr) {
      kind_label->set_str(executable_value_id);
      continue;
    }
    kind_label->set_str(buffer_value_id);
    auto* device_label = sample->add_label();
    device_label->set_key(device_key_id);
    device_label->set_num(key.device->id());
  }
  return profile.SerializeAsString();
}
} // namespace xla

View File

@ -125,19 +125,8 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
StatusOr<std::unique_ptr<PyExecutable>> Compile(
const XlaComputation& computation, CompileOptions options);
pybind11::bytes HeapProfile();
private:
friend class PyBuffer;
friend class PyExecutable;
std::shared_ptr<PjRtClient> pjrt_client_;
// Pointers to intrusive doubly-linked lists of buffers and executables, used
// to iterate over all known objects when heap profiling. The list structure
// is protected by the GIL.
PyBuffer* buffers_ = nullptr;
PyExecutable* executables_ = nullptr;
};
} // namespace xla

View File

@ -23,29 +23,10 @@ namespace py = pybind11;
PyExecutable::PyExecutable(std::shared_ptr<PyClient> client,
std::unique_ptr<PjRtExecutable> executable,
std::shared_ptr<Traceback> traceback)
std::unique_ptr<Traceback> traceback)
: client_(std::move(client)),
executable_(std::move(executable)),
traceback_(std::move(traceback)) {
next_ = client_->executables_;
client_->executables_ = this;
prev_ = nullptr;
if (next_) {
next_->prev_ = this;
}
}
// Unlinks this executable from the client's intrusive doubly-linked list of
// executables, updating the list head if this executable was first.
PyExecutable::~PyExecutable() {
  PyExecutable* const successor = next_;
  PyExecutable* const predecessor = prev_;
  if (this == client_->executables_) {
    client_->executables_ = successor;
  }
  if (predecessor != nullptr) {
    predecessor->next_ = successor;
  }
  if (successor != nullptr) {
    successor->prev_ = predecessor;
  }
}
traceback_(std::move(traceback)) {}
std::vector<ClientAndPtr<Device>> PyExecutable::LocalDevices() const {
std::vector<ClientAndPtr<Device>> devices;
@ -70,8 +51,8 @@ StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
std::vector<std::unique_ptr<PyBuffer>> outputs;
outputs.reserve(output_buffers.size());
for (auto& buffer : output_buffers) {
outputs.push_back(
std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
outputs.push_back(std::make_unique<PyBuffer>(client_, std::move(buffer),
std::move(traceback)));
}
return outputs;
}
@ -97,8 +78,8 @@ PyExecutable::ExecuteOnLocalDevices(
for (int computation = 0; computation < output_buffers.size();
++computation) {
for (auto& buffer : output_buffers[computation]) {
outputs[computation].push_back(
std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
outputs[computation].push_back(std::make_unique<PyBuffer>(
client_, std::move(buffer), std::move(traceback)));
}
}
return outputs;

View File

@ -37,8 +37,7 @@ class PyExecutable {
public:
PyExecutable(std::shared_ptr<PyClient> client,
std::unique_ptr<PjRtExecutable> executable,
std::shared_ptr<Traceback> traceback);
~PyExecutable();
std::unique_ptr<Traceback> traceback);
std::shared_ptr<PyClient> client() const { return client_; }
@ -65,16 +64,9 @@ class PyExecutable {
Traceback* traceback() { return traceback_.get(); }
private:
friend class PyClient;
std::shared_ptr<PyClient> client_;
std::unique_ptr<PjRtExecutable> executable_;
std::shared_ptr<Traceback> traceback_;
// Doubly-linked list of all executables known to the client. Protected by the
// GIL.
PyExecutable* next_;
PyExecutable* prev_;
std::unique_ptr<Traceback> traceback_;
};
} // namespace xla

View File

@ -25,7 +25,7 @@ namespace xla {
namespace py = pybind11;
bool Traceback::enabled_ = true;
bool Traceback::enabled_ = false;
Traceback::~Traceback() {
// We want Traceback objects to be safe to destroy without holding the
@ -34,7 +34,7 @@ Traceback::~Traceback() {
}
std::string Traceback::Frame::ToString() const {
return absl::StrFormat("%s:%d (%s)", file_name, line_num, function_name);
return absl::StrFormat("%s;%s:%d", function_name, file_name, line_num);
}
std::string Traceback::ToString() const {
@ -61,12 +61,12 @@ std::vector<Traceback::Frame> Traceback::Frames() const {
return frames;
}
std::shared_ptr<Traceback> Traceback::Get() {
std::unique_ptr<Traceback> Traceback::Get() {
DCHECK(PyGILState_Check());
if (!enabled_) {
return nullptr;
}
auto tb = std::make_shared<Traceback>();
auto tb = std::make_unique<Traceback>();
const PyThreadState* thread_state = PyThreadState_GET();
for (PyFrameObject* py_frame = thread_state->frame; py_frame != nullptr;
py_frame = py_frame->f_back) {

View File

@ -29,7 +29,7 @@ namespace xla {
class Traceback {
public:
// Require GIL.
static std::shared_ptr<Traceback> Get();
static std::unique_ptr<Traceback> Get();
// Require GIL.
static bool enabled() { return enabled_; }
@ -57,11 +57,6 @@ class Traceback {
};
std::vector<Frame> Frames() const;
// Returns the stored (code object, int) frame pairs by const reference,
// without copying.
const absl::InlinedVector<std::pair<PyCodeObject*, int>, 32>& raw_frames()
const {
return frames_;
}
private:
absl::InlinedVector<std::pair<PyCodeObject*, int>, 32> frames_;

View File

@ -529,8 +529,7 @@ PYBIND11_MODULE(xla_extension, m) {
.def("buffer_from_pyval", &PyClient::BufferFromPyal, py::arg("argument"),
py::arg("device") = nullptr, py::arg("force_copy") = false)
.def("compile", &PyClient::Compile, py::arg("computation"),
py::arg("compile_options") = CompileOptions())
.def("heap_profile", &PyClient::HeapProfile);
py::arg("compile_options") = CompileOptions());
m.def(
"get_cpu_client",
@ -571,8 +570,8 @@ PYBIND11_MODULE(xla_extension, m) {
frame.line_num);
});
py::class_<Traceback, std::shared_ptr<Traceback>> traceback(
m, "Traceback", "Represents a Python stack trace.");
py::class_<Traceback> traceback(m, "Traceback",
"Represents a Python stack trace.");
traceback.def_property_static(
"enabled", [](py::object /* cls */) { return Traceback::enabled(); },
[](py::object /* cls */, bool enabled) {

View File

@ -23,7 +23,6 @@ import atexit
import collections
import contextlib
import enum # pylint: disable=g-bad-import-order
import gzip
import inspect
import os
from typing import List, Sequence, Tuple, Union
@ -683,11 +682,6 @@ def tracebacks(enabled=True):
Traceback.enabled = saved
def heap_profile(client: Client) -> bytes:
  """Returns a gzipped pprof protocol buffer containing a heap profile.

  Args:
    client: the client whose `heap_profile()` method supplies the serialized
      profile.

  Returns:
    The gzip-compressed serialized profile. `gzip.compress` returns `bytes`,
    not `str`.
  """
  return gzip.compress(client.heap_profile())
# Perform one last garbage collection of deferred Python references. This is
# mostly to keep ASAN happy.
atexit.register(_xla.collect_garbage)