Refactor PJRT.
- Make static methods of PjRtBuffer and PjRtExecutable instance methods on
  PjRtClient to allow us to extract a set of interfaces out of PJRT.

PiperOrigin-RevId: 338101552
Change-Id: I8c10295948ea73d7d4157760a1cd8991384a01dc
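At call sites, the practical effect is that operations which used to be static methods taking an explicit PjRtClient* now live on the client itself; for example, PjRtExecutable::Compile(computation, client, options) becomes client->Compile(computation, options). Below is a minimal, illustrative C++ sketch of the buffer-upload migration; the helper name UploadToDevice and its parameters are placeholders, not part of this change, while the PjRt calls mirror those in the diff.

#include <memory>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

// Illustrative only: shows the before/after call-site shape for this refactor.
xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> UploadToDevice(
    xla::PjRtClient* client, xla::PjRtDevice* device, const void* data,
    const xla::Shape& shape) {
  // Before this change (static method, client passed as an argument):
  //   return xla::PjRtBuffer::FromHostBuffer(
  //       data, shape, xla::PjRtBuffer::HostBufferSemantics::kZeroCopy,
  //       /*buffer_reference=*/nullptr, client, device);
  // After this change (instance method on PjRtClient):
  return client->BufferFromHostBuffer(
      data, shape, xla::PjRtClient::HostBufferSemantics::kZeroCopy,
      /*buffer_reference=*/nullptr, device);
}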
parent b737cff5fd
commit f187f93d7b

tensorflow/compiler/xla
@@ -54,9 +54,9 @@ TEST(GpuMultiStream, Basics) {
   device_assignment(0, 0) = device->id();
   compile_options.executable_build_options.set_device_assignment(
       device_assignment);
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtExecutable> executable,
-                          PjRtExecutable::Compile(computation, client.get(),
-                                                  std::move(compile_options)));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<PjRtExecutable> executable,
+      client->Compile(computation, std::move(compile_options)));
 
   int64 dummy_size = 1 << 20;
   std::vector<int32> dummy_inputs(dummy_size);
@@ -71,22 +71,22 @@ TEST(GpuMultiStream, Basics) {
   // must wait.
   TF_ASSERT_OK_AND_ASSIGN(
       auto dummy_buffer,
-      PjRtBuffer::FromHostBuffer(
+      client->BufferFromHostBuffer(
           dummy_inputs.data(), dummy_shape,
-          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
-          /*buffer_reference=*/nullptr, client.get(), device));
+          PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
+          /*buffer_reference=*/nullptr, device));
   TF_ASSERT_OK_AND_ASSIGN(
       auto in_buffer0,
-      PjRtBuffer::FromHostBuffer(
+      client->BufferFromHostBuffer(
           inputs.data(), shape,
-          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
-          /*buffer_reference=*/nullptr, client.get(), device));
+          PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
+          /*buffer_reference=*/nullptr, device));
   TF_ASSERT_OK_AND_ASSIGN(
       auto in_buffer1,
-      PjRtBuffer::FromHostBuffer(
+      client->BufferFromHostBuffer(
          inputs.data(), shape,
-          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
-          /*buffer_reference=*/nullptr, client.get(), device));
+          PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
+          /*buffer_reference=*/nullptr, device));
   // The execution may be enqueued before the transfers complete, requiring
   // adequate device-side synchronization.
   ExecuteOptions options;
@@ -576,24 +576,21 @@ void PjRtBuffer::ScopedHold::AddToInput(
   }
 }
 
-/* static */
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
+StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
     const void* data, const Shape& shape,
     HostBufferSemantics host_buffer_semantics,
-    std::shared_ptr<void> buffer_reference, PjRtClient* client,
-    PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostBuffer");
-  VLOG(2) << "PjRtBuffer::FromHostBuffer: shape: " << shape.ToString()
+    std::shared_ptr<void> buffer_reference, PjRtDevice* device) {
+  tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostBuffer");
+  VLOG(2) << "PjRtClient::BufferFromHostBuffer: shape: " << shape.ToString()
           << " device: " << device->DebugString();
   if (shape.IsTuple()) {
-    return InvalidArgument("Use FromHostLiteral to transfer a tuple");
+    return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
   }
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       device->GetLocalDeviceState());
   int64 size = ShapeUtil::ByteSizeOf(shape);
 
-  TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
+  TransferManager* transfer_manager = client()->backend().transfer_manager();
   TF_ASSIGN_OR_RETURN(Shape compact_shape,
                       transfer_manager->ChooseCompactLayoutForShape(shape));
 
@@ -628,10 +625,11 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
     };
     buffer = se::DeviceMemoryBase(const_cast<void*>(data), size);
   } else {
-    void* staging_buffer = client->host_memory_allocator()->AllocateRaw(
+    void* staging_buffer = host_memory_allocator()->AllocateRaw(
         cpu_function_runtime::kMinAlign, size);
-    on_delete_callback = [staging_buffer, client]() {
-      client->host_memory_allocator()->DeallocateRaw(staging_buffer);
+    on_delete_callback = [staging_buffer, host_memory_allocator =
+                                              host_memory_allocator()]() {
+      host_memory_allocator->DeallocateRaw(staging_buffer);
     };
     buffer = se::DeviceMemoryBase(staging_buffer, size);
     std::memcpy(staging_buffer, data, size);
@@ -643,7 +641,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
         std::initializer_list<se::DeviceMemoryBase>{buffer},
         definition_events, std::move(on_delete_callback));
     return absl::make_unique<PjRtBuffer>(
-        shape, shape, std::move(device_buffer), client, device);
+        shape, shape, std::move(device_buffer), this, device);
   }
 }
 
@@ -651,21 +649,22 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
       std::unique_ptr<PjRtBuffer> py_buffer,
       AllocateDestinationBuffer(compact_shape, device, local_device,
                                 local_device->host_to_device_stream(),
-                                /*is_uninitialized_create=*/false, client));
+                                /*is_uninitialized_create=*/false, this));
 
-  ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
+  PjRtBuffer::ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
   CHECK(device_buffer.ok());
 
   // If necessary, allocate a host-side buffer for staging host-to-device
   // transfers. On GPU this is a buffer in pinned memory.
   std::shared_ptr<void> staging_buffer;
   if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall ||
-      client->should_stage_host_to_device_transfers()) {
-    void* ptr = client->host_memory_allocator()->AllocateRaw(
+      should_stage_host_to_device_transfers()) {
+    void* ptr = host_memory_allocator()->AllocateRaw(
         tensorflow::Allocator::kAllocatorAlignment, size);
-    staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
-      client->host_memory_allocator()->DeallocateRaw(ptr);
-    });
+    staging_buffer = std::shared_ptr<void>(
+        ptr, [host_memory_allocator = host_memory_allocator()](void* ptr) {
+          host_memory_allocator->DeallocateRaw(ptr);
+        });
   }
 
   // Copy the buffer into a staging buffer before returning control to the
@@ -684,14 +683,15 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
   // usage holds have gone away.
   // TODO(misard) assess if it would be preferable to introduce a heuristic to
   // put the transfer into the calling thread for small literals.
-  auto transfer_h2d = [client, transfer_manager, local_device, data, size,
+  auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
+                       data, size,
                        movable_device_buffer{device_buffer.ToClosure()}, shape,
                        py_buffer{py_buffer.get()}, compact_shape,
                        on_device_shape{py_buffer->on_device_shape()},
                        staging_buffer{std::move(staging_buffer)},
                        buffer_reference{std::move(buffer_reference)},
                        host_buffer_semantics]() {
-    ScopedHold device_buffer(movable_device_buffer);
+    PjRtBuffer::ScopedHold device_buffer(movable_device_buffer);
    // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
    // to report failures from a callback. However, the operations here are
    // unlikely to fail and not recoverable even if we were to fail: DMAs to
@@ -699,7 +699,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
     // allocation.
 
     ShapedBuffer buffer = device_buffer->AsShapedBuffer(
-        compact_shape, on_device_shape, client->client()->platform());
+        compact_shape, on_device_shape, local_client->platform());
     // If applicable on the backend, stage the transfer via host memory
     // allocated via the host_memory_allocator. On GPU, this is pinned
     // memory.
@@ -736,41 +736,38 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
   // already defers its work onto a stream (= thread on CPU).
     transfer_h2d();
   } else {
-    client->h2d_transfer_pool()->Schedule(transfer_h2d);
+    h2d_transfer_pool()->Schedule(transfer_h2d);
   }
   return py_buffer;
 }
 
-/* static */
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CreateUninitialized(
-    const Shape& shape, PjRtClient* client, PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtBuffer::CreateUninitialized");
-  VLOG(2) << "PjRtBuffer::CreateUninitialized: shape: " << shape.ToString()
-          << " device: " << device->DebugString();
+StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
+    const Shape& shape, PjRtDevice* device) {
+  tensorflow::profiler::TraceMe traceme(
+      "PjRtClient::CreateUninitializedBuffer");
+  VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
+          << shape.ToString() << " device: " << device->DebugString();
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       device->GetLocalDeviceState());
 
-  TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
+  TransferManager* transfer_manager = client()->backend().transfer_manager();
   TF_ASSIGN_OR_RETURN(Shape compact_shape,
                       transfer_manager->ChooseCompactLayoutForShape(shape));
 
   return AllocateDestinationBuffer(compact_shape, device, local_device,
                                    /*copy_stream=*/nullptr,
-                                   /*is_uninitialized_create=*/true, client);
+                                   /*is_uninitialized_create=*/true, this);
 }
 
-/* static */
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
-    const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostLiteral");
-  VLOG(2) << "PjRtBuffer::FromHostLiteral: shape: "
+StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
+    const LiteralSlice& literal, PjRtDevice* device) {
+  tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostLiteral");
+  VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: "
           << literal.shape().ToString() << " device: " << device->DebugString();
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       device->GetLocalDeviceState());
 
-  TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
+  TransferManager* transfer_manager = client()->backend().transfer_manager();
   TF_ASSIGN_OR_RETURN(
       Shape compact_shape,
       transfer_manager->ChooseCompactLayoutForShape(literal.shape()));
@@ -778,9 +775,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
       std::unique_ptr<PjRtBuffer> py_buffer,
       AllocateDestinationBuffer(compact_shape, device, local_device,
                                 local_device->host_to_device_stream(),
-                                /*is_uninitialized_create=*/false, client));
+                                /*is_uninitialized_create=*/false, this));
 
-  ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
+  PjRtBuffer::ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
   CHECK(device_buffer.ok());
 
   // The host to device transfer is performed on a thread pool, mostly because
@@ -789,11 +786,11 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
   // usage holds have gone away.
   // TODO(misard) assess if it would be preferable to introduce a heuristic to
   // put the transfer into the calling thread for small literals.
-  auto transfer_h2d = [client, transfer_manager, local_device,
+  auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
                        movable_device_buffer{device_buffer.ToClosure()},
                        literal, py_buffer{py_buffer.get()}, compact_shape,
                        on_device_shape{py_buffer->on_device_shape()}]() {
-    ScopedHold device_buffer(movable_device_buffer);
+    PjRtBuffer::ScopedHold device_buffer(movable_device_buffer);
    // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
    // to report failures from a callback. However, the operations here are
    // unlikely to fail and not recoverable even if we were to fail: DMAs to
@@ -802,7 +799,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
 
     se::Stream* h2d_stream = local_device->host_to_device_stream();
     ShapedBuffer buffer = device_buffer->AsShapedBuffer(
-        compact_shape, on_device_shape, client->client()->platform());
+        compact_shape, on_device_shape, local_client->platform());
     TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
         h2d_stream, literal, buffer));
 
@@ -817,12 +814,12 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
         .IgnoreError();  // Can return error::Unimplemented
     QCHECK(h2d_stream->ok());
   };
-  client->h2d_transfer_pool()->Schedule(transfer_h2d);
+  h2d_transfer_pool()->Schedule(transfer_h2d);
   return py_buffer;
 }
 
-/*static*/ void PjRtBuffer::MakeCrossHostReceiveBuffers(
-    absl::Span<const Shape> shapes, PjRtClient* client, PjRtDevice* device,
+void PjRtClient::MakeCrossHostReceiveBuffers(
+    absl::Span<const Shape> shapes, PjRtDevice* device,
     PjRtCrossHostRecvNotifier&& notifier) {
   if (shapes.empty()) {
     notifier(InvalidArgument(
@@ -843,7 +840,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
     StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or =
         AllocateDestinationBuffer(shape, device, local_device,
                                   /*copy_stream=*/nullptr,
-                                  /*is_uninitialized_create=*/false, client);
+                                  /*is_uninitialized_create=*/false, this);
     if (!buffer_or.ok()) {
       notifier(buffer_or.status());
       return;
@@ -851,7 +848,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
     buffers.push_back(buffer_or.ConsumeValueOrDie());
   }
 
-  client->EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
+  EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
 }
 
 PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
@@ -1159,7 +1156,7 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
 
 StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
     const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
-  tensorflow::profiler::TraceMe traceme("PjRtBuffer::ToLiteral");
+  tensorflow::profiler::TraceMe traceme("PjRtClient::ToLiteral");
   TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
                       CopyToHostAsyncInternal(discard_cached_copy, layout));
   if (host_value == nullptr) {
@@ -1267,9 +1264,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
   // Copying across PjRtClients involves a copy through the host.
   if (dst_device->client() != client_) {
     TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
-    return FromHostBuffer(literal->untyped_data(), literal->shape(),
-                          HostBufferSemantics::kZeroCopy, nullptr,
-                          dst_device->client(), dst_device);
+    return dst_device->client()->BufferFromHostBuffer(
+        literal->untyped_data(), literal->shape(),
+        PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device);
   }
 
   TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device,
@@ -2061,14 +2058,13 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
 
 }  // namespace
 
-/*static*/ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtExecutable::Compile(
-    const XlaComputation& computation, PjRtClient* client,
-    CompileOptions options) {
-  tensorflow::profiler::TraceMe traceme("LocalExecutable::Compile");
+StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
+    const XlaComputation& computation, CompileOptions options) {
+  tensorflow::profiler::TraceMe traceme("PjRtClient::Compile");
 
   ExecutableBuildOptions& build_options = options.executable_build_options;
   if (!build_options.device_allocator()) {
-    build_options.set_device_allocator(client->allocator());
+    build_options.set_device_allocator(allocator());
   }
 
   int num_replicas;
@@ -2084,14 +2080,14 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
     num_partitions = 1;
   } else {
     if (!build_options.has_device_assignment()) {
-      VLOG(2) << "PjRtExecutable::Compile using default device_assignment.";
+      VLOG(2) << "PjRtClient::Compile using default device_assignment.";
       TF_ASSIGN_OR_RETURN(
           DeviceAssignment device_assignment,
-          client->GetDefaultDeviceAssignment(build_options.num_replicas(),
-                                             build_options.num_partitions()));
+          GetDefaultDeviceAssignment(build_options.num_replicas(),
+                                     build_options.num_partitions()));
       build_options.set_device_assignment(device_assignment);
     }
-    VLOG(2) << "PjRtExecutable::Compile device_assignment:\n"
+    VLOG(2) << "PjRtClient::Compile device_assignment:\n"
             << build_options.device_assignment().ToString();
     num_replicas = build_options.device_assignment().replica_count();
     num_partitions = build_options.device_assignment().computation_count();
@@ -2118,7 +2114,8 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
 
   // Assign a default layout based on `sharded_shape` to any array subshapes in
   // `dst_shape` that are missing layouts.
-  auto assign_layouts = [client](const Shape& sharded_shape, Shape* dst_shape) {
+  auto assign_layouts = [local_client = client()](const Shape& sharded_shape,
+                                                  Shape* dst_shape) {
     return ShapeUtil::ForEachMutableSubshapeWithStatus(
         dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
           if (subshape->IsArray() && !subshape->has_layout()) {
@@ -2126,8 +2123,7 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
             const Shape& sharded_subshape =
                 ShapeUtil::GetSubshape(sharded_shape, idx);
             LayoutUtil::SetToDefaultLayout(subshape);
-            TF_ASSIGN_OR_RETURN(Shape layout, client->client()
-                                                  ->backend()
+            TF_ASSIGN_OR_RETURN(Shape layout, local_client->backend()
                                                   .transfer_manager()
                                                   ->ChooseCompactLayoutForShape(
                                                       sharded_subshape));
@@ -2162,8 +2158,8 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
   for (int replica = 0; replica < num_replicas; ++replica) {
     for (int partition = 0; partition < num_partitions; ++partition) {
       int device_id = (*device_assignment)(replica, partition);
-      PjRtDevice* device = LookupDevice(*client, device_id);
-      if (device->host_id() != client->host_id()) {
+      PjRtDevice* device = LookupDevice(*this, device_id);
+      if (device->host_id() != host_id()) {
        VLOG(3) << "Non-local device: " << device_id;
        continue;
      }
@@ -2185,15 +2181,14 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
 
   TF_ASSIGN_OR_RETURN(
       std::vector<std::unique_ptr<LocalExecutable>> local_executables,
-      client->client()->Compile(computation, argument_layout_pointers,
-                                build_options));
+      client()->Compile(computation, argument_layout_pointers, build_options));
 
   auto executable = absl::make_unique<PjRtExecutable>(
       std::move(local_executables), options.parameter_is_tupled_arguments,
       std::move(device_assignment), std::move(local_logical_device_ids),
-      std::move(local_devices), client);
+      std::move(local_devices), this);
   TF_RETURN_IF_ERROR(
-      executable->SetUpDonation(client, options.parameter_is_tupled_arguments));
+      executable->SetUpDonation(this, options.parameter_is_tupled_arguments));
   return executable;
 }
 
@@ -120,6 +120,24 @@ struct PjRtCrossHostRecvBuffer {
 using PjRtCrossHostRecvNotifier =
     std::function<void(StatusOr<std::vector<PjRtCrossHostRecvBuffer>>&&)>;
 
+struct CompileOptions {
+  // The layouts of the arguments that the computation should expect.
+  absl::optional<std::vector<Shape>> argument_layouts;
+
+  // If true, the supplied computation expects its arguments to be wrapped in a
+  // tuple and passed as a single parameter.
+  bool parameter_is_tupled_arguments = false;
+
+  // XLA's compilation time options.
+  ExecutableBuildOptions executable_build_options;
+
+  // If true, the executable can be run on any device. May only be true if
+  // !executable_build_options.has_device_assignment(), so only applies to
+  // single-device executables. Beware: on GPUs, sometimes an executable
+  // compiled for one device doesn't run on another.
+  bool compile_portable_executable = false;
+};
+
 class PjRtExecutable;
 
 // Encapsulates the state of Python session with XLA.
@@ -198,6 +216,63 @@ class PjRtClient {
   // Returns a backend-specific HLO cost analysis visitor.
   virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis();
 
+  virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+      const XlaComputation& computation, CompileOptions options);
+
+  virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
+      const Shape& shape, PjRtDevice* device);
+
+  // Describes the semantics the caller to BufferFromHostBuffer expects from the
+  // runtime, in a total order from most restrictive to least restrictive.
+  enum class HostBufferSemantics {
+    // The runtime may not hold references to `data` after the call to
+    // `BufferFromHostBuffer` completes. The caller promises that `data` is
+    // immutable and will not be freed only for the duration of the
+    // BufferFromHostBuffer call. `buffer_reference` will be freed by the time
+    // `BufferFromHostBuffer` returns.
+    kImmutableOnlyDuringCall,
+
+    // The runtime may hold onto `data` after the call to `BufferFromHostBuffer`
+    // returns while the runtime completes a transfer to the device. The caller
+    // promises not to mutate or free `data` until the transfer completes, at
+    // which point the runtime will release `buffer_reference`. It is also
+    // correct to wait on the host (directly or indirectly) for the buffer's
+    // definition event to complete.
+    kImmutableUntilTransferCompletes,
+
+    // The PjRtBuffer may alias `data` internally and the runtime may use the
+    // `data` contents as long as the buffer is alive. The caller promises to
+    // keep `data` alive and not to mutate its contents as long as the buffer is
+    // alive; to notify the caller that the buffer may be freed, the runtime
+    // will release its `buffer_reference` when the PjRtBuffer is freed. On
+    // non-CPU platforms this acts identically to
+    // kImmutableUntilTransferCompletes.
+    kZeroCopy,
+  };
+  virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
+      const void* data, const Shape& shape,
+      HostBufferSemantics host_buffer_semantics,
+      std::shared_ptr<void> buffer_reference, PjRtDevice* device);
+
+  // Note that literal must remain in scope until the transfer has completed, so
+  // the caller should, for example, wait for BlockHostUntilReady() completes on
+  // the return value before letting literal go out of scope.
+  virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
+      const LiteralSlice& literal, PjRtDevice* device);
+
+  // Asynchronously makes a vector of PjRtBuffers that can be used to receive
+  // cross host transfers using `client` on `device'. `shapes` must be the exact
+  // shapes, with identical layouts, corresponding to the buffers that will be
+  // sent. When resources for the transfer are available, notifier will be
+  // called with a vector of PjRtCrossHostRecvBuffer structs, one for each
+  // shape in `shapes`. Each struct contains a buffer that will contain the
+  // received value, and an opaque string that should be transmitted to the
+  // sending host and used in a call to CopyToRemoteDevice. None of the recv
+  // buffers will become ready until *all* of the sends have completed.
+  virtual void MakeCrossHostReceiveBuffers(
+      absl::Span<const Shape> shapes, PjRtDevice* device,
+      PjRtCrossHostRecvNotifier&& notifier);
+
  protected:
   friend class PjRtBuffer;
   virtual void EnqueueCrossHostReceive(
@@ -385,6 +460,7 @@ class PjRtBuffer {
 
    private:
     friend class PjRtBuffer;
+    friend class PjRtClient;
 
     // Helper struct that makes it possible to move a ScopedHold through a
     // closure.
@@ -423,62 +499,6 @@ class PjRtBuffer {
     StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
   };
 
-  // Returns a buffer with uninitialized contents.
-  static StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitialized(
-      const Shape& shape, PjRtClient* client, PjRtDevice* device);
-
-  // Describes the semantics the caller to FromHostBuffer expects from the
-  // runtime, in a total order from most restrictive to least restrictive.
-  enum class HostBufferSemantics {
-    // The runtime may not hold references to `data` after the call to
-    // `FromHostBuffer` completes. The caller promises that `data` is immutable
-    // and will not be freed only for the duration of the FromHostBuffer call.
-    // `buffer_reference` will be freed by the time `FromHostBuffer` returns.
-    kImmutableOnlyDuringCall,
-
-    // The runtime may hold onto `data` after the call to `FromHostBuffer`
-    // returns while the runtime completes a transfer to the device. The caller
-    // promises not to mutate or free `data` until the transfer completes, at
-    // which point the runtime will release `buffer_reference`. It is also
-    // correct to wait on the host (directly or indirectly) for the buffer's
-    // definition event to complete.
-    kImmutableUntilTransferCompletes,
-
-    // The PjRtBuffer may alias `data` internally and the runtime may use the
-    // `data` contents as long as the buffer is alive.
-    // The caller promises to keep `data` alive and not to mutate its contents
-    // as long as the buffer is alive; to notify the caller that the buffer may
-    // be freed, the runtime will release its `buffer_reference` when the
-    // PjRtBuffer is freed. On non-CPU platforms this acts identically to
-    // kImmutableUntilTransferCompletes.
-    kZeroCopy,
-  };
-  static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostBuffer(
-      const void* data, const Shape& shape,
-      HostBufferSemantics host_buffer_semantics,
-      std::shared_ptr<void> buffer_reference, PjRtClient* client,
-      PjRtDevice* device);
-
-  // Note that literal must remain in scope until the transfer has completed, so
-  // the caller should, for example, wait for BlockHostUntilReady() completes on
-  // the return value before letting literal go out of scope.
-  static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostLiteral(
-      const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device);
-
-  // Asynchronously makes a vector of PjRtBuffers that can be used to receive
-  // cross host transfers using `client` on `device'. `shapes` must be the exact
-  // shapes, with identical layouts, corresponding to the buffers that will be
-  // sent. When resources for the transfer are available, notifier will be
-  // called with a vector of PjRtCrossHostRecvBuffer structs, one for each
-  // shape in `shapes`. Each struct contains a buffer that will contain the
-  // received value, and an opaque string that should be transmitted to the
-  // sending host and used in a call to CopyToRemoteDevice. None of the recv
-  // buffers will become ready until *all* of the sends have completed.
-  static void MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
-                                          PjRtClient* client,
-                                          PjRtDevice* device,
-                                          PjRtCrossHostRecvNotifier&& notifier);
-
   PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
              std::shared_ptr<TrackedDeviceBuffer> device_buffer,
              PjRtClient* client, PjRtDevice* device);
@@ -661,24 +681,6 @@ class PjRtBuffer {
   Semaphore donation_semaphore_;
 };
 
-struct CompileOptions {
-  // The layouts of the arguments that the computation should expect.
-  absl::optional<std::vector<Shape>> argument_layouts;
-
-  // If true, the supplied computation expects its arguments to be wrapped in a
-  // tuple and passed as a single parameter.
-  bool parameter_is_tupled_arguments = false;
-
-  // XLA's compilation time options.
-  ExecutableBuildOptions executable_build_options;
-
-  // If true, the executable can be run on any device. May only be true if
-  // !executable_build_options.has_device_assignment(), so only applies to
-  // single-device executables. Beware: on GPUs, sometimes an executable
-  // compiled for one device doesn't run on another.
-  bool compile_portable_executable = false;
-};
-
 class ExecuteContext {
  public:
   virtual ~ExecuteContext() = default;
@@ -710,10 +712,6 @@ struct ExecuteOptions {
 // buffer will be donated when passed to the execution.
 class PjRtExecutable {
  public:
-  static StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
-      const XlaComputation& computation, PjRtClient* client,
-      CompileOptions options);
-
   PjRtExecutable(std::vector<std::unique_ptr<LocalExecutable>> executables,
                  bool parameter_is_tupled_arguments,
                  std::shared_ptr<DeviceAssignment> device_assignment,
@@ -783,6 +781,7 @@ class PjRtExecutable {
   }
 
  private:
+  friend class PjRtClient;
   // Initializes information about which arguments to which executables must be
   // donated due to aliases that were specified by the computation.
   Status SetUpDonation(PjRtClient* client, bool tuple_inputs);
@@ -465,10 +465,10 @@ std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(
     xla::PjRtDevice* device) {
   CppType data = py::cast<Pybind11Type>(scalar);
   xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<CppType>({});
-  return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer(
+  return ValueOrThrow(client->BufferFromHostBuffer(
       &data, shape,
-      xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr,
-      client, device));
+      xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr,
+      device));
 }
 
 // Convert a scalar to the associated PjRtBuffer or raises an error if it is
@@ -502,17 +502,17 @@ StatusOr<std::unique_ptr<xla::PjRtBuffer>> ScalarToBuffer(
     if (jax_enable_x64) {
       xla::complex128 data(result.real, result.imag);
       xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex128>({});
-      return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer(
+      return ValueOrThrow(client->BufferFromHostBuffer(
          &data, shape,
-          xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall,
-          nullptr, client, device));
+          xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
+          nullptr, device));
     } else {
       xla::complex64 data(result.real, result.imag);
       xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex64>({});
-      return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer(
+      return ValueOrThrow(client->BufferFromHostBuffer(
          &data, shape,
-          xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall,
-          nullptr, client, device));
+          xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
+          nullptr, device));
     }
   }
   return InvalidArgument(
@@ -678,7 +678,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
         ValueOrThrow(pyclient.BufferFromPyval(
             numpy_array, data_device,
             /*force_copy=*/false, /*host_buffer_semantics=*/
-            xla::PjRtBuffer::HostBufferSemantics::kZeroCopy));
+            xla::PjRtClient::HostBufferSemantics::kZeroCopy));
     arg_buffers.push_back(buffer->buffer());
 
     ArgSignature sig;
@@ -409,10 +409,9 @@ Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
   compile_options.executable_build_options.set_device_assignment(
       device_assignment);
 
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<PjRtExecutable> executable,
-      PjRtExecutable::Compile(computation, devices_[device_idx]->client(),
-                              std::move(compile_options)));
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
+                      devices_[device_idx]->client()->Compile(
+                          computation, std::move(compile_options)));
   ExecuteOptions execute_options;
   TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
                       executable->Execute({}, execute_options));
@@ -40,9 +40,8 @@ Status CompileAndExecute(XlaBuilder* builder, XlaOp root, int device_id,
   compile_options.executable_build_options.set_device_assignment(
       device_assignment);
 
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<PjRtExecutable> executable,
-      PjRtExecutable::Compile(computation, client, std::move(compile_options)));
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
+                      client->Compile(computation, std::move(compile_options)));
   ExecuteOptions execute_options;
   TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
                       executable->Execute({}, execute_options));
@@ -89,7 +89,7 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
 
 StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
     const pybind11::object& argument, PjRtDevice* device, bool force_copy,
-    PjRtBuffer::HostBufferSemantics host_buffer_semantics) {
+    PjRtClient::HostBufferSemantics host_buffer_semantics) {
   if (device == nullptr) {
     TF_RET_CHECK(!pjrt_client_->local_devices().empty());
     device = pjrt_client_->local_devices().front();
@@ -114,10 +114,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
   std::unique_ptr<PjRtBuffer> buffer;
   {
     py::gil_scoped_release gil_release;
-    TF_ASSIGN_OR_RETURN(
-        buffer, PjRtBuffer::FromHostBuffer(
-                    c->buf_ptr, c->shape, host_buffer_semantics,
-                    std::move(py_buffer_ref), pjrt_client_.get(), device));
+    TF_ASSIGN_OR_RETURN(buffer, pjrt_client_->BufferFromHostBuffer(
+                                    c->buf_ptr, c->shape, host_buffer_semantics,
+                                    std::move(py_buffer_ref), device));
   }
   auto traceback = Traceback::Get();
   return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),
@@ -131,8 +130,7 @@ StatusOr<std::shared_ptr<PyExecutable>> PyClient::Compile(
   {
     py::gil_scoped_release gil_release;
     TF_ASSIGN_OR_RETURN(executable,
-                        PjRtExecutable::Compile(computation, pjrt_client_.get(),
-                                                std::move(options)));
+                        pjrt_client_->Compile(computation, std::move(options)));
     TF_ASSIGN_OR_RETURN(fingerprint,
                         pjrt_client_->ExecutableFingerprint(*executable));
   }
@@ -123,7 +123,7 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
 
   StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
       const pybind11::object& argument, PjRtDevice* device, bool force_copy,
-      PjRtBuffer::HostBufferSemantics host_buffer_semantics);
+      PjRtClient::HostBufferSemantics host_buffer_semantics);
 
   StatusOr<std::shared_ptr<PyExecutable>> Compile(
       const XlaComputation& computation, CompileOptions options);
@@ -535,12 +535,12 @@ PYBIND11_MODULE(xla_extension, m) {
       .value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
       .value("BFC", GpuAllocatorConfig::Kind::kBFC);
 
-  py::enum_<PjRtBuffer::HostBufferSemantics>(m, "HostBufferSemantics")
+  py::enum_<PjRtClient::HostBufferSemantics>(m, "HostBufferSemantics")
       .value("IMMUTABLE_ONLY_DURING_CALL",
-             PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall)
+             PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall)
       .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES",
-             PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes)
-      .value("ZERO_COPY", PjRtBuffer::HostBufferSemantics::kZeroCopy);
+             PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes)
+      .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kZeroCopy);
 
   py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
   py_local_client.def_property_readonly("platform", &PyClient::platform_name)
@@ -562,7 +562,7 @@ PYBIND11_MODULE(xla_extension, m) {
       .def("buffer_from_pyval", &PyClient::BufferFromPyval, py::arg("argument"),
            py::arg("device") = nullptr, py::arg("force_copy") = false,
           py::arg("host_buffer_semantics") =
-               PjRtBuffer::HostBufferSemantics::kZeroCopy)
+               PjRtClient::HostBufferSemantics::kZeroCopy)
       .def("compile", &PyClient::Compile, py::arg("computation"),
           py::arg("compile_options") = CompileOptions())
       .def("heap_profile", &PyClient::HeapProfile);