Refactor PJRT.

- Make static methods of PjRtBuffer and PjRtExecutable instance methods on PjRtClient to allow us to extract a set of interfaces out of PJRT.

PiperOrigin-RevId: 338101552
Change-Id: I8c10295948ea73d7d4157760a1cd8991384a01dc
This commit is contained in:
Qiao Zhang 2020-10-20 11:39:31 -07:00 committed by TensorFlower Gardener
parent b737cff5fd
commit f187f93d7b
9 changed files with 182 additions and 192 deletions

View File

@ -54,9 +54,9 @@ TEST(GpuMultiStream, Basics) {
device_assignment(0, 0) = device->id(); device_assignment(0, 0) = device->id();
compile_options.executable_build_options.set_device_assignment( compile_options.executable_build_options.set_device_assignment(
device_assignment); device_assignment);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtExecutable> executable, TF_ASSERT_OK_AND_ASSIGN(
PjRtExecutable::Compile(computation, client.get(), std::unique_ptr<PjRtExecutable> executable,
std::move(compile_options))); client->Compile(computation, std::move(compile_options)));
int64 dummy_size = 1 << 20; int64 dummy_size = 1 << 20;
std::vector<int32> dummy_inputs(dummy_size); std::vector<int32> dummy_inputs(dummy_size);
@ -71,22 +71,22 @@ TEST(GpuMultiStream, Basics) {
// must wait. // must wait.
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
auto dummy_buffer, auto dummy_buffer,
PjRtBuffer::FromHostBuffer( client->BufferFromHostBuffer(
dummy_inputs.data(), dummy_shape, dummy_inputs.data(), dummy_shape,
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes, PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, client.get(), device)); /*buffer_reference=*/nullptr, device));
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
auto in_buffer0, auto in_buffer0,
PjRtBuffer::FromHostBuffer( client->BufferFromHostBuffer(
inputs.data(), shape, inputs.data(), shape,
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes, PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, client.get(), device)); /*buffer_reference=*/nullptr, device));
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
auto in_buffer1, auto in_buffer1,
PjRtBuffer::FromHostBuffer( client->BufferFromHostBuffer(
inputs.data(), shape, inputs.data(), shape,
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes, PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, client.get(), device)); /*buffer_reference=*/nullptr, device));
// The execution may be enqueued before the transfers complete, requiring // The execution may be enqueued before the transfers complete, requiring
// adequate device-side synchronization. // adequate device-side synchronization.
ExecuteOptions options; ExecuteOptions options;

View File

@ -576,24 +576,21 @@ void PjRtBuffer::ScopedHold::AddToInput(
} }
} }
/* static */ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
const void* data, const Shape& shape, const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics, HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtClient* client, std::shared_ptr<void> buffer_reference, PjRtDevice* device) {
PjRtDevice* device) { tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostBuffer");
tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostBuffer"); VLOG(2) << "PjRtClient::BufferFromHostBuffer: shape: " << shape.ToString()
VLOG(2) << "PjRtBuffer::FromHostBuffer: shape: " << shape.ToString()
<< " device: " << device->DebugString(); << " device: " << device->DebugString();
if (shape.IsTuple()) { if (shape.IsTuple()) {
return InvalidArgument("Use FromHostLiteral to transfer a tuple"); return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
} }
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device->GetLocalDeviceState()); device->GetLocalDeviceState());
int64 size = ShapeUtil::ByteSizeOf(shape); int64 size = ShapeUtil::ByteSizeOf(shape);
TransferManager* transfer_manager = TransferManager* transfer_manager = client()->backend().transfer_manager();
client->client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(Shape compact_shape, TF_ASSIGN_OR_RETURN(Shape compact_shape,
transfer_manager->ChooseCompactLayoutForShape(shape)); transfer_manager->ChooseCompactLayoutForShape(shape));
@ -628,10 +625,11 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
}; };
buffer = se::DeviceMemoryBase(const_cast<void*>(data), size); buffer = se::DeviceMemoryBase(const_cast<void*>(data), size);
} else { } else {
void* staging_buffer = client->host_memory_allocator()->AllocateRaw( void* staging_buffer = host_memory_allocator()->AllocateRaw(
cpu_function_runtime::kMinAlign, size); cpu_function_runtime::kMinAlign, size);
on_delete_callback = [staging_buffer, client]() { on_delete_callback = [staging_buffer, host_memory_allocator =
client->host_memory_allocator()->DeallocateRaw(staging_buffer); host_memory_allocator()]() {
host_memory_allocator->DeallocateRaw(staging_buffer);
}; };
buffer = se::DeviceMemoryBase(staging_buffer, size); buffer = se::DeviceMemoryBase(staging_buffer, size);
std::memcpy(staging_buffer, data, size); std::memcpy(staging_buffer, data, size);
@ -643,7 +641,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
std::initializer_list<se::DeviceMemoryBase>{buffer}, std::initializer_list<se::DeviceMemoryBase>{buffer},
definition_events, std::move(on_delete_callback)); definition_events, std::move(on_delete_callback));
return absl::make_unique<PjRtBuffer>( return absl::make_unique<PjRtBuffer>(
shape, shape, std::move(device_buffer), client, device); shape, shape, std::move(device_buffer), this, device);
} }
} }
@ -651,21 +649,22 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
std::unique_ptr<PjRtBuffer> py_buffer, std::unique_ptr<PjRtBuffer> py_buffer,
AllocateDestinationBuffer(compact_shape, device, local_device, AllocateDestinationBuffer(compact_shape, device, local_device,
local_device->host_to_device_stream(), local_device->host_to_device_stream(),
/*is_uninitialized_create=*/false, client)); /*is_uninitialized_create=*/false, this));
ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold()); PjRtBuffer::ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
CHECK(device_buffer.ok()); CHECK(device_buffer.ok());
// If necessary, allocate a host-side buffer for staging host-to-device // If necessary, allocate a host-side buffer for staging host-to-device
// transfers. On GPU this is a buffer in pinned memory. // transfers. On GPU this is a buffer in pinned memory.
std::shared_ptr<void> staging_buffer; std::shared_ptr<void> staging_buffer;
if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall || if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall ||
client->should_stage_host_to_device_transfers()) { should_stage_host_to_device_transfers()) {
void* ptr = client->host_memory_allocator()->AllocateRaw( void* ptr = host_memory_allocator()->AllocateRaw(
tensorflow::Allocator::kAllocatorAlignment, size); tensorflow::Allocator::kAllocatorAlignment, size);
staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) { staging_buffer = std::shared_ptr<void>(
client->host_memory_allocator()->DeallocateRaw(ptr); ptr, [host_memory_allocator = host_memory_allocator()](void* ptr) {
}); host_memory_allocator->DeallocateRaw(ptr);
});
} }
// Copy the buffer into a staging buffer before returning control to the // Copy the buffer into a staging buffer before returning control to the
@ -684,14 +683,15 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
// usage holds have gone away. // usage holds have gone away.
// TODO(misard) assess if it would be preferable to introduce a heuristic to // TODO(misard) assess if it would be preferable to introduce a heuristic to
// put the transfer into the calling thread for small literals. // put the transfer into the calling thread for small literals.
auto transfer_h2d = [client, transfer_manager, local_device, data, size, auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
data, size,
movable_device_buffer{device_buffer.ToClosure()}, shape, movable_device_buffer{device_buffer.ToClosure()}, shape,
py_buffer{py_buffer.get()}, compact_shape, py_buffer{py_buffer.get()}, compact_shape,
on_device_shape{py_buffer->on_device_shape()}, on_device_shape{py_buffer->on_device_shape()},
staging_buffer{std::move(staging_buffer)}, staging_buffer{std::move(staging_buffer)},
buffer_reference{std::move(buffer_reference)}, buffer_reference{std::move(buffer_reference)},
host_buffer_semantics]() { host_buffer_semantics]() {
ScopedHold device_buffer(movable_device_buffer); PjRtBuffer::ScopedHold device_buffer(movable_device_buffer);
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
// to report failures from a callback. However, the operations here are // to report failures from a callback. However, the operations here are
// unlikely to fail and not recoverable even if we were to fail: DMAs to // unlikely to fail and not recoverable even if we were to fail: DMAs to
@ -699,7 +699,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
// allocation. // allocation.
ShapedBuffer buffer = device_buffer->AsShapedBuffer( ShapedBuffer buffer = device_buffer->AsShapedBuffer(
compact_shape, on_device_shape, client->client()->platform()); compact_shape, on_device_shape, local_client->platform());
// If applicable on the backend, stage the transfer via host memory // If applicable on the backend, stage the transfer via host memory
// allocated via the host_memory_allocator. On GPU, this is pinned // allocated via the host_memory_allocator. On GPU, this is pinned
// memory. // memory.
@ -736,41 +736,38 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
// already defers its work onto a stream (= thread on CPU). // already defers its work onto a stream (= thread on CPU).
transfer_h2d(); transfer_h2d();
} else { } else {
client->h2d_transfer_pool()->Schedule(transfer_h2d); h2d_transfer_pool()->Schedule(transfer_h2d);
} }
return py_buffer; return py_buffer;
} }
/* static */ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CreateUninitialized( const Shape& shape, PjRtDevice* device) {
const Shape& shape, PjRtClient* client, PjRtDevice* device) { tensorflow::profiler::TraceMe traceme(
tensorflow::profiler::TraceMe traceme("PjRtBuffer::CreateUninitialized"); "PjRtClient::CreateUninitializedBuffer");
VLOG(2) << "PjRtBuffer::CreateUninitialized: shape: " << shape.ToString() VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
<< " device: " << device->DebugString(); << shape.ToString() << " device: " << device->DebugString();
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device->GetLocalDeviceState()); device->GetLocalDeviceState());
TransferManager* transfer_manager = TransferManager* transfer_manager = client()->backend().transfer_manager();
client->client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(Shape compact_shape, TF_ASSIGN_OR_RETURN(Shape compact_shape,
transfer_manager->ChooseCompactLayoutForShape(shape)); transfer_manager->ChooseCompactLayoutForShape(shape));
return AllocateDestinationBuffer(compact_shape, device, local_device, return AllocateDestinationBuffer(compact_shape, device, local_device,
/*copy_stream=*/nullptr, /*copy_stream=*/nullptr,
/*is_uninitialized_create=*/true, client); /*is_uninitialized_create=*/true, this);
} }
/* static */ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral( const LiteralSlice& literal, PjRtDevice* device) {
const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device) { tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostLiteral");
tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostLiteral"); VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: "
VLOG(2) << "PjRtBuffer::FromHostLiteral: shape: "
<< literal.shape().ToString() << " device: " << device->DebugString(); << literal.shape().ToString() << " device: " << device->DebugString();
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device->GetLocalDeviceState()); device->GetLocalDeviceState());
TransferManager* transfer_manager = TransferManager* transfer_manager = client()->backend().transfer_manager();
client->client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
Shape compact_shape, Shape compact_shape,
transfer_manager->ChooseCompactLayoutForShape(literal.shape())); transfer_manager->ChooseCompactLayoutForShape(literal.shape()));
@ -778,9 +775,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
std::unique_ptr<PjRtBuffer> py_buffer, std::unique_ptr<PjRtBuffer> py_buffer,
AllocateDestinationBuffer(compact_shape, device, local_device, AllocateDestinationBuffer(compact_shape, device, local_device,
local_device->host_to_device_stream(), local_device->host_to_device_stream(),
/*is_uninitialized_create=*/false, client)); /*is_uninitialized_create=*/false, this));
ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold()); PjRtBuffer::ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
CHECK(device_buffer.ok()); CHECK(device_buffer.ok());
// The host to device transfer is performed on a thread pool, mostly because // The host to device transfer is performed on a thread pool, mostly because
@ -789,11 +786,11 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
// usage holds have gone away. // usage holds have gone away.
// TODO(misard) assess if it would be preferable to introduce a heuristic to // TODO(misard) assess if it would be preferable to introduce a heuristic to
// put the transfer into the calling thread for small literals. // put the transfer into the calling thread for small literals.
auto transfer_h2d = [client, transfer_manager, local_device, auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
movable_device_buffer{device_buffer.ToClosure()}, movable_device_buffer{device_buffer.ToClosure()},
literal, py_buffer{py_buffer.get()}, compact_shape, literal, py_buffer{py_buffer.get()}, compact_shape,
on_device_shape{py_buffer->on_device_shape()}]() { on_device_shape{py_buffer->on_device_shape()}]() {
ScopedHold device_buffer(movable_device_buffer); PjRtBuffer::ScopedHold device_buffer(movable_device_buffer);
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
// to report failures from a callback. However, the operations here are // to report failures from a callback. However, the operations here are
// unlikely to fail and not recoverable even if we were to fail: DMAs to // unlikely to fail and not recoverable even if we were to fail: DMAs to
@ -802,7 +799,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
se::Stream* h2d_stream = local_device->host_to_device_stream(); se::Stream* h2d_stream = local_device->host_to_device_stream();
ShapedBuffer buffer = device_buffer->AsShapedBuffer( ShapedBuffer buffer = device_buffer->AsShapedBuffer(
compact_shape, on_device_shape, client->client()->platform()); compact_shape, on_device_shape, local_client->platform());
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync( TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
h2d_stream, literal, buffer)); h2d_stream, literal, buffer));
@ -817,12 +814,12 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
.IgnoreError(); // Can return error::Unimplemented .IgnoreError(); // Can return error::Unimplemented
QCHECK(h2d_stream->ok()); QCHECK(h2d_stream->ok());
}; };
client->h2d_transfer_pool()->Schedule(transfer_h2d); h2d_transfer_pool()->Schedule(transfer_h2d);
return py_buffer; return py_buffer;
} }
/*static*/ void PjRtBuffer::MakeCrossHostReceiveBuffers( void PjRtClient::MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtClient* client, PjRtDevice* device, absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier) { PjRtCrossHostRecvNotifier&& notifier) {
if (shapes.empty()) { if (shapes.empty()) {
notifier(InvalidArgument( notifier(InvalidArgument(
@ -843,7 +840,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or = StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or =
AllocateDestinationBuffer(shape, device, local_device, AllocateDestinationBuffer(shape, device, local_device,
/*copy_stream=*/nullptr, /*copy_stream=*/nullptr,
/*is_uninitialized_create=*/false, client); /*is_uninitialized_create=*/false, this);
if (!buffer_or.ok()) { if (!buffer_or.ok()) {
notifier(buffer_or.status()); notifier(buffer_or.status());
return; return;
@ -851,7 +848,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
buffers.push_back(buffer_or.ConsumeValueOrDie()); buffers.push_back(buffer_or.ConsumeValueOrDie());
} }
client->EnqueueCrossHostReceive(std::move(buffers), std::move(notifier)); EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
} }
PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape, PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
@ -1159,7 +1156,7 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral( StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
const bool discard_cached_copy, absl::optional<xla::Layout> layout) { const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
tensorflow::profiler::TraceMe traceme("PjRtBuffer::ToLiteral"); tensorflow::profiler::TraceMe traceme("PjRtClient::ToLiteral");
TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value, TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
CopyToHostAsyncInternal(discard_cached_copy, layout)); CopyToHostAsyncInternal(discard_cached_copy, layout));
if (host_value == nullptr) { if (host_value == nullptr) {
@ -1267,9 +1264,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
// Copying across PjRtClients involves a copy through the host. // Copying across PjRtClients involves a copy through the host.
if (dst_device->client() != client_) { if (dst_device->client() != client_) {
TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral()); TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
return FromHostBuffer(literal->untyped_data(), literal->shape(), return dst_device->client()->BufferFromHostBuffer(
HostBufferSemantics::kZeroCopy, nullptr, literal->untyped_data(), literal->shape(),
dst_device->client(), dst_device); PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device);
} }
TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device, TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device,
@ -2061,14 +2058,13 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
} // namespace } // namespace
/*static*/ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtExecutable::Compile( StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
const XlaComputation& computation, PjRtClient* client, const XlaComputation& computation, CompileOptions options) {
CompileOptions options) { tensorflow::profiler::TraceMe traceme("PjRtClient::Compile");
tensorflow::profiler::TraceMe traceme("LocalExecutable::Compile");
ExecutableBuildOptions& build_options = options.executable_build_options; ExecutableBuildOptions& build_options = options.executable_build_options;
if (!build_options.device_allocator()) { if (!build_options.device_allocator()) {
build_options.set_device_allocator(client->allocator()); build_options.set_device_allocator(allocator());
} }
int num_replicas; int num_replicas;
@ -2084,14 +2080,14 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
num_partitions = 1; num_partitions = 1;
} else { } else {
if (!build_options.has_device_assignment()) { if (!build_options.has_device_assignment()) {
VLOG(2) << "PjRtExecutable::Compile using default device_assignment."; VLOG(2) << "PjRtClient::Compile using default device_assignment.";
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
DeviceAssignment device_assignment, DeviceAssignment device_assignment,
client->GetDefaultDeviceAssignment(build_options.num_replicas(), GetDefaultDeviceAssignment(build_options.num_replicas(),
build_options.num_partitions())); build_options.num_partitions()));
build_options.set_device_assignment(device_assignment); build_options.set_device_assignment(device_assignment);
} }
VLOG(2) << "PjRtExecutable::Compile device_assignment:\n" VLOG(2) << "PjRtClient::Compile device_assignment:\n"
<< build_options.device_assignment().ToString(); << build_options.device_assignment().ToString();
num_replicas = build_options.device_assignment().replica_count(); num_replicas = build_options.device_assignment().replica_count();
num_partitions = build_options.device_assignment().computation_count(); num_partitions = build_options.device_assignment().computation_count();
@ -2118,7 +2114,8 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
// Assign a default layout based on `sharded_shape` to any array subshapes in // Assign a default layout based on `sharded_shape` to any array subshapes in
// `dst_shape` that are missing layouts. // `dst_shape` that are missing layouts.
auto assign_layouts = [client](const Shape& sharded_shape, Shape* dst_shape) { auto assign_layouts = [local_client = client()](const Shape& sharded_shape,
Shape* dst_shape) {
return ShapeUtil::ForEachMutableSubshapeWithStatus( return ShapeUtil::ForEachMutableSubshapeWithStatus(
dst_shape, [&](Shape* subshape, const ShapeIndex& idx) { dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
if (subshape->IsArray() && !subshape->has_layout()) { if (subshape->IsArray() && !subshape->has_layout()) {
@ -2126,8 +2123,7 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
const Shape& sharded_subshape = const Shape& sharded_subshape =
ShapeUtil::GetSubshape(sharded_shape, idx); ShapeUtil::GetSubshape(sharded_shape, idx);
LayoutUtil::SetToDefaultLayout(subshape); LayoutUtil::SetToDefaultLayout(subshape);
TF_ASSIGN_OR_RETURN(Shape layout, client->client() TF_ASSIGN_OR_RETURN(Shape layout, local_client->backend()
->backend()
.transfer_manager() .transfer_manager()
->ChooseCompactLayoutForShape( ->ChooseCompactLayoutForShape(
sharded_subshape)); sharded_subshape));
@ -2162,8 +2158,8 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
for (int replica = 0; replica < num_replicas; ++replica) { for (int replica = 0; replica < num_replicas; ++replica) {
for (int partition = 0; partition < num_partitions; ++partition) { for (int partition = 0; partition < num_partitions; ++partition) {
int device_id = (*device_assignment)(replica, partition); int device_id = (*device_assignment)(replica, partition);
PjRtDevice* device = LookupDevice(*client, device_id); PjRtDevice* device = LookupDevice(*this, device_id);
if (device->host_id() != client->host_id()) { if (device->host_id() != host_id()) {
VLOG(3) << "Non-local device: " << device_id; VLOG(3) << "Non-local device: " << device_id;
continue; continue;
} }
@ -2185,15 +2181,14 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
std::vector<std::unique_ptr<LocalExecutable>> local_executables, std::vector<std::unique_ptr<LocalExecutable>> local_executables,
client->client()->Compile(computation, argument_layout_pointers, client()->Compile(computation, argument_layout_pointers, build_options));
build_options));
auto executable = absl::make_unique<PjRtExecutable>( auto executable = absl::make_unique<PjRtExecutable>(
std::move(local_executables), options.parameter_is_tupled_arguments, std::move(local_executables), options.parameter_is_tupled_arguments,
std::move(device_assignment), std::move(local_logical_device_ids), std::move(device_assignment), std::move(local_logical_device_ids),
std::move(local_devices), client); std::move(local_devices), this);
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
executable->SetUpDonation(client, options.parameter_is_tupled_arguments)); executable->SetUpDonation(this, options.parameter_is_tupled_arguments));
return executable; return executable;
} }

View File

@ -120,6 +120,24 @@ struct PjRtCrossHostRecvBuffer {
using PjRtCrossHostRecvNotifier = using PjRtCrossHostRecvNotifier =
std::function<void(StatusOr<std::vector<PjRtCrossHostRecvBuffer>>&&)>; std::function<void(StatusOr<std::vector<PjRtCrossHostRecvBuffer>>&&)>;
struct CompileOptions {
// The layouts of the arguments that the computation should expect.
absl::optional<std::vector<Shape>> argument_layouts;
// If true, the supplied computation expects its arguments to be wrapped in a
// tuple and passed as a single parameter.
bool parameter_is_tupled_arguments = false;
// XLA's compilation time options.
ExecutableBuildOptions executable_build_options;
// If true, the executable can be run on any device. May only be true if
// !executable_build_options.has_device_assignment(), so only applies to
// single-device executables. Beware: on GPUs, sometimes an executable
// compiled for one device doesn't run on another.
bool compile_portable_executable = false;
};
class PjRtExecutable; class PjRtExecutable;
// Encapsulates the state of Python session with XLA. // Encapsulates the state of Python session with XLA.
@ -198,6 +216,63 @@ class PjRtClient {
// Returns a backend-specific HLO cost analysis visitor. // Returns a backend-specific HLO cost analysis visitor.
virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis(); virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis();
virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
const XlaComputation& computation, CompileOptions options);
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device);
// Describes the semantics the caller to BufferFromHostBuffer expects from the
// runtime, in a total order from most restrictive to least restrictive.
enum class HostBufferSemantics {
// The runtime may not hold references to `data` after the call to
// `BufferFromHostBuffer` completes. The caller promises that `data` is
// immutable and will not be freed only for the duration of the
// BufferFromHostBuffer call. `buffer_reference` will be freed by the time
// `BufferFromHostBuffer` returns.
kImmutableOnlyDuringCall,
// The runtime may hold onto `data` after the call to `BufferFromHostBuffer`
// returns while the runtime completes a transfer to the device. The caller
// promises not to mutate or free `data` until the transfer completes, at
// which point the runtime will release `buffer_reference`. It is also
// correct to wait on the host (directly or indirectly) for the buffer's
// definition event to complete.
kImmutableUntilTransferCompletes,
// The PjRtBuffer may alias `data` internally and the runtime may use the
// `data` contents as long as the buffer is alive. The caller promises to
// keep `data` alive and not to mutate its contents as long as the buffer is
// alive; to notify the caller that the buffer may be freed, the runtime
// will release its `buffer_reference` when the PjRtBuffer is freed. On
// non-CPU platforms this acts identically to
// kImmutableUntilTransferCompletes.
kZeroCopy,
};
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device);
// Note that literal must remain in scope until the transfer has completed, so
// the caller should, for example, wait for BlockHostUntilReady() completes on
// the return value before letting literal go out of scope.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
const LiteralSlice& literal, PjRtDevice* device);
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
// cross host transfers using `client` on `device'. `shapes` must be the exact
// shapes, with identical layouts, corresponding to the buffers that will be
// sent. When resources for the transfer are available, notifier will be
// called with a vector of PjRtCrossHostRecvBuffer structs, one for each
// shape in `shapes`. Each struct contains a buffer that will contain the
// received value, and an opaque string that should be transmitted to the
// sending host and used in a call to CopyToRemoteDevice. None of the recv
// buffers will become ready until *all* of the sends have completed.
virtual void MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier);
protected: protected:
friend class PjRtBuffer; friend class PjRtBuffer;
virtual void EnqueueCrossHostReceive( virtual void EnqueueCrossHostReceive(
@ -385,6 +460,7 @@ class PjRtBuffer {
private: private:
friend class PjRtBuffer; friend class PjRtBuffer;
friend class PjRtClient;
// Helper struct that makes it possible to move a ScopedHold through a // Helper struct that makes it possible to move a ScopedHold through a
// closure. // closure.
@ -423,62 +499,6 @@ class PjRtBuffer {
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_; StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
}; };
// Returns a buffer with uninitialized contents.
static StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitialized(
const Shape& shape, PjRtClient* client, PjRtDevice* device);
// Describes the semantics the caller to FromHostBuffer expects from the
// runtime, in a total order from most restrictive to least restrictive.
enum class HostBufferSemantics {
// The runtime may not hold references to `data` after the call to
// `FromHostBuffer` completes. The caller promises that `data` is immutable
// and will not be freed only for the duration of the FromHostBuffer call.
// `buffer_reference` will be freed by the time `FromHostBuffer` returns.
kImmutableOnlyDuringCall,
// The runtime may hold onto `data` after the call to `FromHostBuffer`
// returns while the runtime completes a transfer to the device. The caller
// promises not to mutate or free `data` until the transfer completes, at
// which point the runtime will release `buffer_reference`. It is also
// correct to wait on the host (directly or indirectly) for the buffer's
// definition event to complete.
kImmutableUntilTransferCompletes,
// The PjRtBuffer may alias `data` internally and the runtime may use the
// `data` contents as long as the buffer is alive.
// The caller promises to keep `data` alive and not to mutate its contents
// as long as the buffer is alive; to notify the caller that the buffer may
// be freed, the runtime will release its `buffer_reference` when the
// PjRtBuffer is freed. On non-CPU platforms this acts identically to
// kImmutableUntilTransferCompletes.
kZeroCopy,
};
static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtClient* client,
PjRtDevice* device);
// Note that literal must remain in scope until the transfer has completed, so
// the caller should, for example, wait for BlockHostUntilReady() completes on
// the return value before letting literal go out of scope.
static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostLiteral(
const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device);
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
// cross host transfers using `client` on `device'. `shapes` must be the exact
// shapes, with identical layouts, corresponding to the buffers that will be
// sent. When resources for the transfer are available, notifier will be
// called with a vector of PjRtCrossHostRecvBuffer structs, one for each
// shape in `shapes`. Each struct contains a buffer that will contain the
// received value, and an opaque string that should be transmitted to the
// sending host and used in a call to CopyToRemoteDevice. None of the recv
// buffers will become ready until *all* of the sends have completed.
static void MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
PjRtClient* client,
PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier);
PjRtBuffer(Shape on_host_shape, Shape on_device_shape, PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
std::shared_ptr<TrackedDeviceBuffer> device_buffer, std::shared_ptr<TrackedDeviceBuffer> device_buffer,
PjRtClient* client, PjRtDevice* device); PjRtClient* client, PjRtDevice* device);
@ -661,24 +681,6 @@ class PjRtBuffer {
Semaphore donation_semaphore_; Semaphore donation_semaphore_;
}; };
struct CompileOptions {
// The layouts of the arguments that the computation should expect.
absl::optional<std::vector<Shape>> argument_layouts;
// If true, the supplied computation expects its arguments to be wrapped in a
// tuple and passed as a single parameter.
bool parameter_is_tupled_arguments = false;
// XLA's compilation time options.
ExecutableBuildOptions executable_build_options;
// If true, the executable can be run on any device. May only be true if
// !executable_build_options.has_device_assignment(), so only applies to
// single-device executables. Beware: on GPUs, sometimes an executable
// compiled for one device doesn't run on another.
bool compile_portable_executable = false;
};
class ExecuteContext { class ExecuteContext {
public: public:
virtual ~ExecuteContext() = default; virtual ~ExecuteContext() = default;
@ -710,10 +712,6 @@ struct ExecuteOptions {
// buffer will be donated when passed to the execution. // buffer will be donated when passed to the execution.
class PjRtExecutable { class PjRtExecutable {
public: public:
static StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
const XlaComputation& computation, PjRtClient* client,
CompileOptions options);
PjRtExecutable(std::vector<std::unique_ptr<LocalExecutable>> executables, PjRtExecutable(std::vector<std::unique_ptr<LocalExecutable>> executables,
bool parameter_is_tupled_arguments, bool parameter_is_tupled_arguments,
std::shared_ptr<DeviceAssignment> device_assignment, std::shared_ptr<DeviceAssignment> device_assignment,
@ -783,6 +781,7 @@ class PjRtExecutable {
} }
private: private:
friend class PjRtClient;
// Initializes information about which arguments to which executables must be // Initializes information about which arguments to which executables must be
// donated due to aliases that were specified by the computation. // donated due to aliases that were specified by the computation.
Status SetUpDonation(PjRtClient* client, bool tuple_inputs); Status SetUpDonation(PjRtClient* client, bool tuple_inputs);

View File

@ -465,10 +465,10 @@ std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(
xla::PjRtDevice* device) { xla::PjRtDevice* device) {
CppType data = py::cast<Pybind11Type>(scalar); CppType data = py::cast<Pybind11Type>(scalar);
xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<CppType>({}); xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<CppType>({});
return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer( return ValueOrThrow(client->BufferFromHostBuffer(
&data, shape, &data, shape,
xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr, xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr,
client, device)); device));
} }
// Convert a scalar to the associated PjRtBuffer or raises an error if it is // Convert a scalar to the associated PjRtBuffer or raises an error if it is
@ -502,17 +502,17 @@ StatusOr<std::unique_ptr<xla::PjRtBuffer>> ScalarToBuffer(
if (jax_enable_x64) { if (jax_enable_x64) {
xla::complex128 data(result.real, result.imag); xla::complex128 data(result.real, result.imag);
xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex128>({}); xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex128>({});
return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer( return ValueOrThrow(client->BufferFromHostBuffer(
&data, shape, &data, shape,
xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall, xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
nullptr, client, device)); nullptr, device));
} else { } else {
xla::complex64 data(result.real, result.imag); xla::complex64 data(result.real, result.imag);
xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex64>({}); xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex64>({});
return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer( return ValueOrThrow(client->BufferFromHostBuffer(
&data, shape, &data, shape,
xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall, xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
nullptr, client, device)); nullptr, device));
} }
} }
return InvalidArgument( return InvalidArgument(
@ -678,7 +678,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
ValueOrThrow(pyclient.BufferFromPyval( ValueOrThrow(pyclient.BufferFromPyval(
numpy_array, data_device, numpy_array, data_device,
/*force_copy=*/false, /*host_buffer_semantics=*/ /*force_copy=*/false, /*host_buffer_semantics=*/
xla::PjRtBuffer::HostBufferSemantics::kZeroCopy)); xla::PjRtClient::HostBufferSemantics::kZeroCopy));
arg_buffers.push_back(buffer->buffer()); arg_buffers.push_back(buffer->buffer());
ArgSignature sig; ArgSignature sig;

View File

@ -409,10 +409,9 @@ Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
compile_options.executable_build_options.set_device_assignment( compile_options.executable_build_options.set_device_assignment(
device_assignment); device_assignment);
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
std::unique_ptr<PjRtExecutable> executable, devices_[device_idx]->client()->Compile(
PjRtExecutable::Compile(computation, devices_[device_idx]->client(), computation, std::move(compile_options)));
std::move(compile_options)));
ExecuteOptions execute_options; ExecuteOptions execute_options;
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers, TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
executable->Execute({}, execute_options)); executable->Execute({}, execute_options));

View File

@ -40,9 +40,8 @@ Status CompileAndExecute(XlaBuilder* builder, XlaOp root, int device_id,
compile_options.executable_build_options.set_device_assignment( compile_options.executable_build_options.set_device_assignment(
device_assignment); device_assignment);
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
std::unique_ptr<PjRtExecutable> executable, client->Compile(computation, std::move(compile_options)));
PjRtExecutable::Compile(computation, client, std::move(compile_options)));
ExecuteOptions execute_options; ExecuteOptions execute_options;
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers, TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
executable->Execute({}, execute_options)); executable->Execute({}, execute_options));

View File

@ -89,7 +89,7 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval( StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
const pybind11::object& argument, PjRtDevice* device, bool force_copy, const pybind11::object& argument, PjRtDevice* device, bool force_copy,
PjRtBuffer::HostBufferSemantics host_buffer_semantics) { PjRtClient::HostBufferSemantics host_buffer_semantics) {
if (device == nullptr) { if (device == nullptr) {
TF_RET_CHECK(!pjrt_client_->local_devices().empty()); TF_RET_CHECK(!pjrt_client_->local_devices().empty());
device = pjrt_client_->local_devices().front(); device = pjrt_client_->local_devices().front();
@ -114,10 +114,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
std::unique_ptr<PjRtBuffer> buffer; std::unique_ptr<PjRtBuffer> buffer;
{ {
py::gil_scoped_release gil_release; py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(buffer, pjrt_client_->BufferFromHostBuffer(
buffer, PjRtBuffer::FromHostBuffer( c->buf_ptr, c->shape, host_buffer_semantics,
c->buf_ptr, c->shape, host_buffer_semantics, std::move(py_buffer_ref), device));
std::move(py_buffer_ref), pjrt_client_.get(), device));
} }
auto traceback = Traceback::Get(); auto traceback = Traceback::Get();
return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer), return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),
@ -131,8 +130,7 @@ StatusOr<std::shared_ptr<PyExecutable>> PyClient::Compile(
{ {
py::gil_scoped_release gil_release; py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(executable, TF_ASSIGN_OR_RETURN(executable,
PjRtExecutable::Compile(computation, pjrt_client_.get(), pjrt_client_->Compile(computation, std::move(options)));
std::move(options)));
TF_ASSIGN_OR_RETURN(fingerprint, TF_ASSIGN_OR_RETURN(fingerprint,
pjrt_client_->ExecutableFingerprint(*executable)); pjrt_client_->ExecutableFingerprint(*executable));
} }

View File

@ -123,7 +123,7 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval( StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
const pybind11::object& argument, PjRtDevice* device, bool force_copy, const pybind11::object& argument, PjRtDevice* device, bool force_copy,
PjRtBuffer::HostBufferSemantics host_buffer_semantics); PjRtClient::HostBufferSemantics host_buffer_semantics);
StatusOr<std::shared_ptr<PyExecutable>> Compile( StatusOr<std::shared_ptr<PyExecutable>> Compile(
const XlaComputation& computation, CompileOptions options); const XlaComputation& computation, CompileOptions options);

View File

@ -535,12 +535,12 @@ PYBIND11_MODULE(xla_extension, m) {
.value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform) .value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
.value("BFC", GpuAllocatorConfig::Kind::kBFC); .value("BFC", GpuAllocatorConfig::Kind::kBFC);
py::enum_<PjRtBuffer::HostBufferSemantics>(m, "HostBufferSemantics") py::enum_<PjRtClient::HostBufferSemantics>(m, "HostBufferSemantics")
.value("IMMUTABLE_ONLY_DURING_CALL", .value("IMMUTABLE_ONLY_DURING_CALL",
PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall) PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall)
.value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES", .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES",
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes) PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes)
.value("ZERO_COPY", PjRtBuffer::HostBufferSemantics::kZeroCopy); .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kZeroCopy);
py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client"); py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
py_local_client.def_property_readonly("platform", &PyClient::platform_name) py_local_client.def_property_readonly("platform", &PyClient::platform_name)
@ -562,7 +562,7 @@ PYBIND11_MODULE(xla_extension, m) {
.def("buffer_from_pyval", &PyClient::BufferFromPyval, py::arg("argument"), .def("buffer_from_pyval", &PyClient::BufferFromPyval, py::arg("argument"),
py::arg("device") = nullptr, py::arg("force_copy") = false, py::arg("device") = nullptr, py::arg("force_copy") = false,
py::arg("host_buffer_semantics") = py::arg("host_buffer_semantics") =
PjRtBuffer::HostBufferSemantics::kZeroCopy) PjRtClient::HostBufferSemantics::kZeroCopy)
.def("compile", &PyClient::Compile, py::arg("computation"), .def("compile", &PyClient::Compile, py::arg("computation"),
py::arg("compile_options") = CompileOptions()) py::arg("compile_options") = CompileOptions())
.def("heap_profile", &PyClient::HeapProfile); .def("heap_profile", &PyClient::HeapProfile);