[XLA:Python] Implement on-device heap profiling support for JAX.
Adds an `xla_client.heap_profile(client)` API which produces a gzip-compressed profile.proto protocol buffer containing an on-device heap profile, suitable for visualization via the pprof tool (https://github.com/google/pprof). The heap profile includes buffers and executables allocated by JAX. PiperOrigin-RevId: 316138726 Change-Id: I1b263f8033c6fc07466a362ff3d6f65a55334def
This commit is contained in:
parent
ad4363bdcf
commit
76fa5e8a4a
@ -695,6 +695,13 @@ PjRtBuffer::~PjRtBuffer() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64 PjRtBuffer::OnDeviceSizeInBytes() const {
|
||||||
|
return client_->client()
|
||||||
|
->backend()
|
||||||
|
.transfer_manager()
|
||||||
|
->GetByteSizeRequirement(on_device_shape_);
|
||||||
|
}
|
||||||
|
|
||||||
void PjRtBuffer::WaitForOutstandingUsageHolds() {
|
void PjRtBuffer::WaitForOutstandingUsageHolds() {
|
||||||
auto not_in_usage_hold = [&]() {
|
auto not_in_usage_hold = [&]() {
|
||||||
mu_.AssertHeld();
|
mu_.AssertHeld();
|
||||||
|
@ -409,6 +409,9 @@ class PjRtBuffer {
|
|||||||
return on_host_shape_.IsTuple() && on_host_shape_.tuple_shapes_size() == 0;
|
return on_host_shape_.IsTuple() && on_host_shape_.tuple_shapes_size() == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the size of the on-device representation of this buffer in bytes.
|
||||||
|
int64 OnDeviceSizeInBytes() const;
|
||||||
|
|
||||||
// Returns the buffer's value as an XLA Literal. If the value has previously
|
// Returns the buffer's value as an XLA Literal. If the value has previously
|
||||||
// been prefetched to the host, then returns the prefetched version, otherwise
|
// been prefetched to the host, then returns the prefetched version, otherwise
|
||||||
// copies the buffer to the host. Blocks until the value is ready.
|
// copies the buffer to the host. Blocks until the value is ready.
|
||||||
|
@ -202,7 +202,9 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||||
|
"//tensorflow/core/profiler:protos_all_cc",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
"@pybind11",
|
"@pybind11",
|
||||||
|
@ -25,10 +25,31 @@ namespace py = pybind11;
|
|||||||
|
|
||||||
PyBuffer::PyBuffer(std::shared_ptr<PyClient> client,
|
PyBuffer::PyBuffer(std::shared_ptr<PyClient> client,
|
||||||
std::unique_ptr<PjRtBuffer> buffer,
|
std::unique_ptr<PjRtBuffer> buffer,
|
||||||
std::unique_ptr<Traceback> traceback)
|
std::shared_ptr<Traceback> traceback)
|
||||||
: client_(std::move(client)),
|
: client_(std::move(client)),
|
||||||
buffer_(std::move(buffer)),
|
buffer_(std::move(buffer)),
|
||||||
traceback_(std::move(traceback)) {}
|
traceback_(std::move(traceback)) {
|
||||||
|
CHECK(PyGILState_Check());
|
||||||
|
next_ = client_->buffers_;
|
||||||
|
client_->buffers_ = this;
|
||||||
|
prev_ = nullptr;
|
||||||
|
if (next_) {
|
||||||
|
next_->prev_ = this;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PyBuffer::~PyBuffer() {
|
||||||
|
CHECK(PyGILState_Check());
|
||||||
|
if (client_->buffers_ == this) {
|
||||||
|
client_->buffers_ = next_;
|
||||||
|
}
|
||||||
|
if (prev_) {
|
||||||
|
prev_->next_ = next_;
|
||||||
|
}
|
||||||
|
if (next_) {
|
||||||
|
next_->prev_ = prev_;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ClientAndPtr<Device> PyBuffer::device() const {
|
ClientAndPtr<Device> PyBuffer::device() const {
|
||||||
return WrapWithClient(client_, buffer_->device());
|
return WrapWithClient(client_, buffer_->device());
|
||||||
@ -38,10 +59,12 @@ StatusOr<std::unique_ptr<PyBuffer>> PyBuffer::CopyToDevice(
|
|||||||
const ClientAndPtr<Device>& dst_device) const {
|
const ClientAndPtr<Device>& dst_device) const {
|
||||||
CHECK(dst_device.get() != nullptr);
|
CHECK(dst_device.get() != nullptr);
|
||||||
GlobalPyRefManager()->CollectGarbage();
|
GlobalPyRefManager()->CollectGarbage();
|
||||||
|
std::unique_ptr<PjRtBuffer> out;
|
||||||
|
{
|
||||||
|
py::gil_scoped_release gil_release;
|
||||||
|
TF_ASSIGN_OR_RETURN(out, buffer_->CopyToDevice(dst_device.get()));
|
||||||
|
}
|
||||||
auto traceback = Traceback::Get();
|
auto traceback = Traceback::Get();
|
||||||
py::gil_scoped_release gil_release;
|
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> out,
|
|
||||||
buffer_->CopyToDevice(dst_device.get()));
|
|
||||||
return std::make_unique<PyBuffer>(dst_device.client, std::move(out),
|
return std::make_unique<PyBuffer>(dst_device.client, std::move(out),
|
||||||
std::move(traceback));
|
std::move(traceback));
|
||||||
}
|
}
|
||||||
|
@ -32,7 +32,8 @@ namespace xla {
|
|||||||
class PyBuffer {
|
class PyBuffer {
|
||||||
public:
|
public:
|
||||||
PyBuffer(std::shared_ptr<PyClient> client, std::unique_ptr<PjRtBuffer> buffer,
|
PyBuffer(std::shared_ptr<PyClient> client, std::unique_ptr<PjRtBuffer> buffer,
|
||||||
std::unique_ptr<Traceback> traceback);
|
std::shared_ptr<Traceback> traceback);
|
||||||
|
~PyBuffer();
|
||||||
|
|
||||||
std::shared_ptr<PyClient> client() const { return client_; }
|
std::shared_ptr<PyClient> client() const { return client_; }
|
||||||
PjRtBuffer* buffer() const { return buffer_.get(); }
|
PjRtBuffer* buffer() const { return buffer_.get(); }
|
||||||
@ -63,9 +64,16 @@ class PyBuffer {
|
|||||||
Traceback* traceback() { return traceback_.get(); }
|
Traceback* traceback() { return traceback_.get(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
friend class PyClient;
|
||||||
|
|
||||||
std::shared_ptr<PyClient> client_;
|
std::shared_ptr<PyClient> client_;
|
||||||
std::unique_ptr<PjRtBuffer> buffer_;
|
std::unique_ptr<PjRtBuffer> buffer_;
|
||||||
std::unique_ptr<Traceback> traceback_;
|
std::shared_ptr<Traceback> traceback_;
|
||||||
|
|
||||||
|
// Doubly-linked list of all buffers known to the client. Protected by the
|
||||||
|
// GIL.
|
||||||
|
PyBuffer* next_;
|
||||||
|
PyBuffer* prev_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -15,15 +15,18 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/python/py_client.h"
|
#include "tensorflow/compiler/xla/python/py_client.h"
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/xla/python/py_buffer.h"
|
#include "tensorflow/compiler/xla/python/py_buffer.h"
|
||||||
#include "tensorflow/compiler/xla/python/py_executable.h"
|
#include "tensorflow/compiler/xla/python/py_executable.h"
|
||||||
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
|
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
|
||||||
#include "tensorflow/compiler/xla/python/traceback.h"
|
#include "tensorflow/compiler/xla/python/traceback.h"
|
||||||
#include "tensorflow/compiler/xla/python/types.h"
|
#include "tensorflow/compiler/xla/python/types.h"
|
||||||
|
#include "tensorflow/core/profiler/profile.pb.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
namespace pprof = tensorflow::tfprof::pprof;
|
||||||
|
|
||||||
PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
|
PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
|
||||||
: pjrt_client_(std::move(pjrt_client)) {}
|
: pjrt_client_(std::move(pjrt_client)) {}
|
||||||
@ -104,27 +107,179 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyal(
|
|||||||
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
|
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
|
||||||
GlobalPyRefManager()->ManageReference(std::move(c->array));
|
GlobalPyRefManager()->ManageReference(std::move(c->array));
|
||||||
|
|
||||||
|
std::unique_ptr<PjRtBuffer> buffer;
|
||||||
|
{
|
||||||
|
py::gil_scoped_release gil_release;
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
buffer, PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy,
|
||||||
|
std::move(py_buffer_ref),
|
||||||
|
pjrt_client_.get(), device));
|
||||||
|
}
|
||||||
auto traceback = Traceback::Get();
|
auto traceback = Traceback::Get();
|
||||||
|
|
||||||
py::gil_scoped_release gil_release;
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
|
||||||
std::unique_ptr<PjRtBuffer> buffer,
|
|
||||||
PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy,
|
|
||||||
std::move(py_buffer_ref), pjrt_client_.get(),
|
|
||||||
device));
|
|
||||||
return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),
|
return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),
|
||||||
std::move(traceback));
|
std::move(traceback));
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<PyExecutable>> PyClient::Compile(
|
StatusOr<std::unique_ptr<PyExecutable>> PyClient::Compile(
|
||||||
const XlaComputation& computation, CompileOptions options) {
|
const XlaComputation& computation, CompileOptions options) {
|
||||||
|
std::unique_ptr<PjRtExecutable> executable;
|
||||||
|
{
|
||||||
|
py::gil_scoped_release gil_release;
|
||||||
|
TF_ASSIGN_OR_RETURN(executable,
|
||||||
|
PjRtExecutable::Compile(computation, pjrt_client_.get(),
|
||||||
|
std::move(options)));
|
||||||
|
}
|
||||||
auto traceback = Traceback::Get();
|
auto traceback = Traceback::Get();
|
||||||
py::gil_scoped_release gil_release;
|
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
|
|
||||||
PjRtExecutable::Compile(computation, pjrt_client_.get(),
|
|
||||||
std::move(options)));
|
|
||||||
return std::make_unique<PyExecutable>(
|
return std::make_unique<PyExecutable>(
|
||||||
shared_from_this(), std::move(executable), std::move(traceback));
|
shared_from_this(), std::move(executable), std::move(traceback));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class ProfileBuilder {
|
||||||
|
public:
|
||||||
|
ProfileBuilder();
|
||||||
|
pprof::Profile& profile() { return profile_; }
|
||||||
|
|
||||||
|
// Adds or returns the ID of `s` in the table.
|
||||||
|
int StringId(const std::string& s);
|
||||||
|
|
||||||
|
// Adds or returns the ID of a function.
|
||||||
|
int FunctionId(PyCodeObject* code);
|
||||||
|
|
||||||
|
// Adds or returns the ID of a code location.
|
||||||
|
int LocationId(PyCodeObject* code, int instruction);
|
||||||
|
|
||||||
|
private:
|
||||||
|
pprof::Profile profile_;
|
||||||
|
|
||||||
|
absl::flat_hash_map<std::string, int> strings_;
|
||||||
|
absl::flat_hash_map<PyCodeObject*, int> functions_;
|
||||||
|
absl::flat_hash_map<std::pair<PyCodeObject*, int>, int> locations_;
|
||||||
|
};
|
||||||
|
|
||||||
|
ProfileBuilder::ProfileBuilder() { CHECK_EQ(0, StringId("")); }
|
||||||
|
|
||||||
|
int ProfileBuilder::StringId(const std::string& s) {
|
||||||
|
auto ret = strings_.emplace(s, profile_.string_table_size());
|
||||||
|
if (ret.second) {
|
||||||
|
profile_.add_string_table(s);
|
||||||
|
}
|
||||||
|
return ret.first->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ProfileBuilder::FunctionId(PyCodeObject* code) {
|
||||||
|
// +1 because id 0 is reserved.
|
||||||
|
auto ret = functions_.emplace(code, profile_.function_size() + 1);
|
||||||
|
if (ret.second) {
|
||||||
|
auto* function = profile_.add_function();
|
||||||
|
function->set_id(ret.first->second);
|
||||||
|
int name = StringId(py::str(code->co_name));
|
||||||
|
function->set_name(name);
|
||||||
|
function->set_system_name(name);
|
||||||
|
function->set_filename(StringId(py::str(code->co_filename)));
|
||||||
|
function->set_start_line(code->co_firstlineno);
|
||||||
|
}
|
||||||
|
return ret.first->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ProfileBuilder::LocationId(PyCodeObject* code, int instruction) {
|
||||||
|
// +1 because id 0 is reserved.
|
||||||
|
auto ret = locations_.emplace(std::make_pair(code, instruction),
|
||||||
|
profile_.location_size() + 1);
|
||||||
|
if (ret.second) {
|
||||||
|
auto* location = profile_.add_location();
|
||||||
|
location->set_id(ret.first->second);
|
||||||
|
auto* line = location->add_line();
|
||||||
|
line->set_function_id(FunctionId(code));
|
||||||
|
line->set_line(PyCode_Addr2Line(code, instruction));
|
||||||
|
}
|
||||||
|
return ret.first->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct HeapProfileKey {
|
||||||
|
Traceback* traceback;
|
||||||
|
int64 size;
|
||||||
|
Device* device;
|
||||||
|
bool operator==(const HeapProfileKey& other) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool HeapProfileKey::operator==(const HeapProfileKey& other) const {
|
||||||
|
if (size != other.size || device != other.device) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if ((traceback == nullptr) != (other.traceback == nullptr)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (traceback && traceback->raw_frames() != other.traceback->raw_frames()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename H>
|
||||||
|
H AbslHashValue(H h, const HeapProfileKey& key) {
|
||||||
|
if (key.traceback) {
|
||||||
|
h = H::combine_contiguous(std::move(h), key.traceback->raw_frames().begin(),
|
||||||
|
key.traceback->raw_frames().size());
|
||||||
|
}
|
||||||
|
h = H::combine(std::move(h), key.size, key.device);
|
||||||
|
return h;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
py::bytes PyClient::HeapProfile() {
|
||||||
|
CHECK(PyGILState_Check());
|
||||||
|
absl::flat_hash_map<HeapProfileKey, int64> entries;
|
||||||
|
for (PyBuffer* buffer = buffers_; buffer; buffer = buffer->next_) {
|
||||||
|
HeapProfileKey key{buffer->traceback(),
|
||||||
|
buffer->buffer()->OnDeviceSizeInBytes(),
|
||||||
|
buffer->buffer()->device()};
|
||||||
|
++entries[key];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (PyExecutable* executable = executables_; executable;
|
||||||
|
executable = executable->next_) {
|
||||||
|
HeapProfileKey key{executable->traceback(),
|
||||||
|
executable->SizeOfGeneratedCodeInBytes(), nullptr};
|
||||||
|
++entries[key];
|
||||||
|
}
|
||||||
|
|
||||||
|
ProfileBuilder builder;
|
||||||
|
auto* allocations = builder.profile().add_sample_type();
|
||||||
|
allocations->set_type(builder.StringId("allocations"));
|
||||||
|
allocations->set_unit(builder.StringId("count"));
|
||||||
|
auto* space = builder.profile().add_sample_type();
|
||||||
|
space->set_type(builder.StringId("space"));
|
||||||
|
space->set_unit(builder.StringId("bytes"));
|
||||||
|
|
||||||
|
const int kind_string_id = builder.StringId("kind");
|
||||||
|
const int buffer_string_id = builder.StringId("buffer");
|
||||||
|
const int executable_string_id = builder.StringId("executable");
|
||||||
|
const int device_string_id = builder.StringId("device");
|
||||||
|
for (const auto& entry : entries) {
|
||||||
|
auto* sample = builder.profile().add_sample();
|
||||||
|
if (entry.first.traceback) {
|
||||||
|
for (const auto& frame : entry.first.traceback->raw_frames()) {
|
||||||
|
sample->add_location_id(builder.LocationId(frame.first, frame.second));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sample->add_value(entry.second);
|
||||||
|
sample->add_value(entry.first.size * entry.second);
|
||||||
|
|
||||||
|
auto* kind_label = sample->add_label();
|
||||||
|
kind_label->set_key(kind_string_id);
|
||||||
|
if (entry.first.device) {
|
||||||
|
kind_label->set_str(buffer_string_id);
|
||||||
|
auto* device_label = sample->add_label();
|
||||||
|
device_label->set_key(device_string_id);
|
||||||
|
device_label->set_num(entry.first.device->id());
|
||||||
|
} else {
|
||||||
|
kind_label->set_str(executable_string_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return builder.profile().SerializeAsString();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -125,8 +125,19 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
|
|||||||
StatusOr<std::unique_ptr<PyExecutable>> Compile(
|
StatusOr<std::unique_ptr<PyExecutable>> Compile(
|
||||||
const XlaComputation& computation, CompileOptions options);
|
const XlaComputation& computation, CompileOptions options);
|
||||||
|
|
||||||
|
pybind11::bytes HeapProfile();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
friend class PyBuffer;
|
||||||
|
friend class PyExecutable;
|
||||||
|
|
||||||
std::shared_ptr<PjRtClient> pjrt_client_;
|
std::shared_ptr<PjRtClient> pjrt_client_;
|
||||||
|
|
||||||
|
// Pointers to intrusive doubly-linked lists of buffers and executables, used
|
||||||
|
// to iterate over all known objects when heap profiling. The list structure
|
||||||
|
// is protected by the GIL.
|
||||||
|
PyBuffer* buffers_ = nullptr;
|
||||||
|
PyExecutable* executables_ = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -23,10 +23,31 @@ namespace py = pybind11;
|
|||||||
|
|
||||||
PyExecutable::PyExecutable(std::shared_ptr<PyClient> client,
|
PyExecutable::PyExecutable(std::shared_ptr<PyClient> client,
|
||||||
std::unique_ptr<PjRtExecutable> executable,
|
std::unique_ptr<PjRtExecutable> executable,
|
||||||
std::unique_ptr<Traceback> traceback)
|
std::shared_ptr<Traceback> traceback)
|
||||||
: client_(std::move(client)),
|
: client_(std::move(client)),
|
||||||
executable_(std::move(executable)),
|
executable_(std::move(executable)),
|
||||||
traceback_(std::move(traceback)) {}
|
traceback_(std::move(traceback)) {
|
||||||
|
CHECK(PyGILState_Check());
|
||||||
|
next_ = client_->executables_;
|
||||||
|
client_->executables_ = this;
|
||||||
|
prev_ = nullptr;
|
||||||
|
if (next_) {
|
||||||
|
next_->prev_ = this;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PyExecutable::~PyExecutable() {
|
||||||
|
CHECK(PyGILState_Check());
|
||||||
|
if (client_->executables_ == this) {
|
||||||
|
client_->executables_ = next_;
|
||||||
|
}
|
||||||
|
if (prev_) {
|
||||||
|
prev_->next_ = next_;
|
||||||
|
}
|
||||||
|
if (next_) {
|
||||||
|
next_->prev_ = prev_;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<ClientAndPtr<Device>> PyExecutable::LocalDevices() const {
|
std::vector<ClientAndPtr<Device>> PyExecutable::LocalDevices() const {
|
||||||
std::vector<ClientAndPtr<Device>> devices;
|
std::vector<ClientAndPtr<Device>> devices;
|
||||||
@ -39,20 +60,23 @@ std::vector<ClientAndPtr<Device>> PyExecutable::LocalDevices() const {
|
|||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
|
StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
|
||||||
absl::Span<PyBuffer* const> args) {
|
absl::Span<PyBuffer* const> args) {
|
||||||
|
std::vector<std::unique_ptr<PjRtBuffer>> output_buffers;
|
||||||
|
{
|
||||||
|
py::gil_scoped_release gil_release;
|
||||||
|
ExecuteOptions options;
|
||||||
|
options.untuple_result = true;
|
||||||
|
std::vector<PjRtBuffer*> arg_buffers(args.size());
|
||||||
|
absl::c_transform(args, arg_buffers.begin(),
|
||||||
|
[](PyBuffer* buf) { return buf->buffer(); });
|
||||||
|
TF_ASSIGN_OR_RETURN(output_buffers,
|
||||||
|
executable_->Execute(arg_buffers, options));
|
||||||
|
}
|
||||||
auto traceback = Traceback::Get();
|
auto traceback = Traceback::Get();
|
||||||
py::gil_scoped_release gil_release;
|
|
||||||
ExecuteOptions options;
|
|
||||||
options.untuple_result = true;
|
|
||||||
std::vector<PjRtBuffer*> arg_buffers(args.size());
|
|
||||||
absl::c_transform(args, arg_buffers.begin(),
|
|
||||||
[](PyBuffer* buf) { return buf->buffer(); });
|
|
||||||
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
|
|
||||||
executable_->Execute(arg_buffers, options));
|
|
||||||
std::vector<std::unique_ptr<PyBuffer>> outputs;
|
std::vector<std::unique_ptr<PyBuffer>> outputs;
|
||||||
outputs.reserve(output_buffers.size());
|
outputs.reserve(output_buffers.size());
|
||||||
for (auto& buffer : output_buffers) {
|
for (auto& buffer : output_buffers) {
|
||||||
outputs.push_back(std::make_unique<PyBuffer>(client_, std::move(buffer),
|
outputs.push_back(
|
||||||
std::move(traceback)));
|
std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
|
||||||
}
|
}
|
||||||
return outputs;
|
return outputs;
|
||||||
}
|
}
|
||||||
@ -60,26 +84,28 @@ StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
|
|||||||
StatusOr<std::vector<std::vector<std::unique_ptr<PyBuffer>>>>
|
StatusOr<std::vector<std::vector<std::unique_ptr<PyBuffer>>>>
|
||||||
PyExecutable::ExecuteOnLocalDevices(
|
PyExecutable::ExecuteOnLocalDevices(
|
||||||
absl::Span<const std::vector<PyBuffer*>> args) {
|
absl::Span<const std::vector<PyBuffer*>> args) {
|
||||||
auto traceback = Traceback::Get();
|
std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
|
||||||
py::gil_scoped_release gil_release;
|
{
|
||||||
ExecuteOptions options;
|
py::gil_scoped_release gil_release;
|
||||||
options.untuple_result = true;
|
ExecuteOptions options;
|
||||||
std::vector<std::vector<PjRtBuffer*>> arg_buffers(args.size());
|
options.untuple_result = true;
|
||||||
for (int computation = 0; computation < args.size(); ++computation) {
|
std::vector<std::vector<PjRtBuffer*>> arg_buffers(args.size());
|
||||||
arg_buffers[computation].resize(args[computation].size());
|
for (int computation = 0; computation < args.size(); ++computation) {
|
||||||
absl::c_transform(args[computation], arg_buffers[computation].begin(),
|
arg_buffers[computation].resize(args[computation].size());
|
||||||
[](PyBuffer* buf) { return buf->buffer(); });
|
absl::c_transform(args[computation], arg_buffers[computation].begin(),
|
||||||
|
[](PyBuffer* buf) { return buf->buffer(); });
|
||||||
|
}
|
||||||
|
TF_ASSIGN_OR_RETURN(output_buffers, executable_->ExecuteOnLocalDevices(
|
||||||
|
arg_buffers, options));
|
||||||
}
|
}
|
||||||
TF_ASSIGN_OR_RETURN(
|
auto traceback = Traceback::Get();
|
||||||
std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers,
|
|
||||||
executable_->ExecuteOnLocalDevices(arg_buffers, options));
|
|
||||||
std::vector<std::vector<std::unique_ptr<PyBuffer>>> outputs;
|
std::vector<std::vector<std::unique_ptr<PyBuffer>>> outputs;
|
||||||
outputs.resize(output_buffers.size());
|
outputs.resize(output_buffers.size());
|
||||||
for (int computation = 0; computation < output_buffers.size();
|
for (int computation = 0; computation < output_buffers.size();
|
||||||
++computation) {
|
++computation) {
|
||||||
for (auto& buffer : output_buffers[computation]) {
|
for (auto& buffer : output_buffers[computation]) {
|
||||||
outputs[computation].push_back(std::make_unique<PyBuffer>(
|
outputs[computation].push_back(
|
||||||
client_, std::move(buffer), std::move(traceback)));
|
std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return outputs;
|
return outputs;
|
||||||
|
@ -37,7 +37,8 @@ class PyExecutable {
|
|||||||
public:
|
public:
|
||||||
PyExecutable(std::shared_ptr<PyClient> client,
|
PyExecutable(std::shared_ptr<PyClient> client,
|
||||||
std::unique_ptr<PjRtExecutable> executable,
|
std::unique_ptr<PjRtExecutable> executable,
|
||||||
std::unique_ptr<Traceback> traceback);
|
std::shared_ptr<Traceback> traceback);
|
||||||
|
~PyExecutable();
|
||||||
|
|
||||||
std::shared_ptr<PyClient> client() const { return client_; }
|
std::shared_ptr<PyClient> client() const { return client_; }
|
||||||
|
|
||||||
@ -64,9 +65,16 @@ class PyExecutable {
|
|||||||
Traceback* traceback() { return traceback_.get(); }
|
Traceback* traceback() { return traceback_.get(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
friend class PyClient;
|
||||||
|
|
||||||
std::shared_ptr<PyClient> client_;
|
std::shared_ptr<PyClient> client_;
|
||||||
std::unique_ptr<PjRtExecutable> executable_;
|
std::unique_ptr<PjRtExecutable> executable_;
|
||||||
std::unique_ptr<Traceback> traceback_;
|
std::shared_ptr<Traceback> traceback_;
|
||||||
|
|
||||||
|
// Doubly-linked list of all executables known to the client. Protected by the
|
||||||
|
// GIL.
|
||||||
|
PyExecutable* next_;
|
||||||
|
PyExecutable* prev_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -25,7 +25,7 @@ namespace xla {
|
|||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
||||||
bool Traceback::enabled_ = false;
|
bool Traceback::enabled_ = true;
|
||||||
|
|
||||||
Traceback::~Traceback() {
|
Traceback::~Traceback() {
|
||||||
// We want Traceback objects to be safe to destroy without holding the
|
// We want Traceback objects to be safe to destroy without holding the
|
||||||
@ -34,7 +34,7 @@ Traceback::~Traceback() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::string Traceback::Frame::ToString() const {
|
std::string Traceback::Frame::ToString() const {
|
||||||
return absl::StrFormat("%s;%s:%d", function_name, file_name, line_num);
|
return absl::StrFormat("%s:%d (%s)", file_name, line_num, function_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string Traceback::ToString() const {
|
std::string Traceback::ToString() const {
|
||||||
@ -61,12 +61,12 @@ std::vector<Traceback::Frame> Traceback::Frames() const {
|
|||||||
return frames;
|
return frames;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<Traceback> Traceback::Get() {
|
std::shared_ptr<Traceback> Traceback::Get() {
|
||||||
DCHECK(PyGILState_Check());
|
DCHECK(PyGILState_Check());
|
||||||
if (!enabled_) {
|
if (!enabled_) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto tb = std::make_unique<Traceback>();
|
auto tb = std::make_shared<Traceback>();
|
||||||
const PyThreadState* thread_state = PyThreadState_GET();
|
const PyThreadState* thread_state = PyThreadState_GET();
|
||||||
for (PyFrameObject* py_frame = thread_state->frame; py_frame != nullptr;
|
for (PyFrameObject* py_frame = thread_state->frame; py_frame != nullptr;
|
||||||
py_frame = py_frame->f_back) {
|
py_frame = py_frame->f_back) {
|
||||||
|
@ -29,7 +29,7 @@ namespace xla {
|
|||||||
class Traceback {
|
class Traceback {
|
||||||
public:
|
public:
|
||||||
// Require GIL.
|
// Require GIL.
|
||||||
static std::unique_ptr<Traceback> Get();
|
static std::shared_ptr<Traceback> Get();
|
||||||
|
|
||||||
// Require GIL.
|
// Require GIL.
|
||||||
static bool enabled() { return enabled_; }
|
static bool enabled() { return enabled_; }
|
||||||
@ -57,6 +57,11 @@ class Traceback {
|
|||||||
};
|
};
|
||||||
std::vector<Frame> Frames() const;
|
std::vector<Frame> Frames() const;
|
||||||
|
|
||||||
|
const absl::InlinedVector<std::pair<PyCodeObject*, int>, 32>& raw_frames()
|
||||||
|
const {
|
||||||
|
return frames_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
absl::InlinedVector<std::pair<PyCodeObject*, int>, 32> frames_;
|
absl::InlinedVector<std::pair<PyCodeObject*, int>, 32> frames_;
|
||||||
|
|
||||||
|
@ -529,7 +529,8 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
.def("buffer_from_pyval", &PyClient::BufferFromPyal, py::arg("argument"),
|
.def("buffer_from_pyval", &PyClient::BufferFromPyal, py::arg("argument"),
|
||||||
py::arg("device") = nullptr, py::arg("force_copy") = false)
|
py::arg("device") = nullptr, py::arg("force_copy") = false)
|
||||||
.def("compile", &PyClient::Compile, py::arg("computation"),
|
.def("compile", &PyClient::Compile, py::arg("computation"),
|
||||||
py::arg("compile_options") = CompileOptions());
|
py::arg("compile_options") = CompileOptions())
|
||||||
|
.def("heap_profile", &PyClient::HeapProfile);
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"get_cpu_client",
|
"get_cpu_client",
|
||||||
@ -570,8 +571,8 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
frame.line_num);
|
frame.line_num);
|
||||||
});
|
});
|
||||||
|
|
||||||
py::class_<Traceback> traceback(m, "Traceback",
|
py::class_<Traceback, std::shared_ptr<Traceback>> traceback(
|
||||||
"Represents a Python stack trace.");
|
m, "Traceback", "Represents a Python stack trace.");
|
||||||
traceback.def_property_static(
|
traceback.def_property_static(
|
||||||
"enabled", [](py::object /* cls */) { return Traceback::enabled(); },
|
"enabled", [](py::object /* cls */) { return Traceback::enabled(); },
|
||||||
[](py::object /* cls */, bool enabled) {
|
[](py::object /* cls */, bool enabled) {
|
||||||
|
@ -23,6 +23,7 @@ import atexit
|
|||||||
import collections
|
import collections
|
||||||
import contextlib
|
import contextlib
|
||||||
import enum # pylint: disable=g-bad-import-order
|
import enum # pylint: disable=g-bad-import-order
|
||||||
|
import gzip
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
from typing import List, Sequence, Tuple, Union
|
from typing import List, Sequence, Tuple, Union
|
||||||
@ -682,6 +683,11 @@ def tracebacks(enabled=True):
|
|||||||
Traceback.enabled = saved
|
Traceback.enabled = saved
|
||||||
|
|
||||||
|
|
||||||
|
def heap_profile(client: Client) -> str:
|
||||||
|
"""Returns a gzipped pprof protocol buffer containing a heap profile."""
|
||||||
|
return gzip.compress(client.heap_profile())
|
||||||
|
|
||||||
|
|
||||||
# Perform one last garbage collection of deferred Python references. This is
|
# Perform one last garbage collection of deferred Python references. This is
|
||||||
# mostly to keep ASAN happy.
|
# mostly to keep ASAN happy.
|
||||||
atexit.register(_xla.collect_garbage)
|
atexit.register(_xla.collect_garbage)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user