[PJRT] Change PjRtClient::BufferFromHostBuffer to take an on_done_with_host_buffer callback rather than a std::shared_ptr<void> that must be kept alive.
This API is clearer (several users were passing dummy shared_ptr<> values) and probably a little bit faster (for example, when the value to be kept alive is not already a std::shared_ptr<>).

PiperOrigin-RevId: 354089254
Change-Id: Ib812efa3e4952562d1eddbbb339cdcdd0d2f2844
commit a75ad4bc12
parent 7d5328c5df
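For migrating callers the change is mechanical: whatever was previously kept alive through `buffer_reference` can be captured by the callback instead, and `nullptr` works wherever a dummy shared_ptr was passed before. A minimal caller-side sketch, assuming the usual PJRT setup for `client`, `shape`, and `device` (the old call is kept as a comment for contrast; this is illustrative, not part of the commit):

    #include <functional>
    #include <memory>
    #include <vector>

    StatusOr<std::unique_ptr<PjRtBuffer>> UploadExample(
        PjRtClient* client, const Shape& shape, PjRtDevice* device) {
      auto data = std::make_shared<std::vector<float>>(1024, 0.0f);
      // Old API: the owner had to be (or be wrapped in) a std::shared_ptr<void>:
      //   client->BufferFromHostBuffer(
      //       data->data(), shape,
      //       PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
      //       /*buffer_reference=*/data, device);
      // New API: capture the owner in a callback; the runtime invokes it once
      // it no longer needs the host memory, and the capture is released when
      // the callback object is destroyed.
      return client->BufferFromHostBuffer(
          data->data(), shape,
          PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
          /*on_done_with_host_buffer=*/[data]() { /* drops the reference */ },
          device);
    }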
@@ -74,19 +74,19 @@ TEST(GpuMultiStream, Basics) {
       client->BufferFromHostBuffer(
           dummy_inputs.data(), dummy_shape,
           PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
-          /*buffer_reference=*/nullptr, device));
+          /*on_done_with_host_buffer=*/nullptr, device));
   TF_ASSERT_OK_AND_ASSIGN(
       auto in_buffer0,
       client->BufferFromHostBuffer(
           inputs.data(), shape,
           PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
-          /*buffer_reference=*/nullptr, device));
+          /*on_done_with_host_buffer=*/nullptr, device));
   TF_ASSERT_OK_AND_ASSIGN(
       auto in_buffer1,
       client->BufferFromHostBuffer(
           inputs.data(), shape,
           PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
-          /*buffer_reference=*/nullptr, device));
+          /*on_done_with_host_buffer=*/nullptr, device));
   // The execution may be enqueued before the transfers complete, requiring
   // adequate device-side synchronization.
   ExecuteOptions options;
@@ -107,6 +107,7 @@ class LocalDeviceState {
   se::Stream* host_to_device_stream() const {
     return host_to_device_stream_.get();
   }
+  se::Stream* callback_stream() const { return callback_stream_.get(); }

   // Returns a device to host stream. Allocates streams in a round-robin fashion
   // amongst the available streams.
@@ -199,14 +199,14 @@ class PjRtClient {
     // The runtime may not hold references to `data` after the call to
     // `BufferFromHostBuffer` completes. The caller promises that `data` is
     // immutable and will not be freed only for the duration of the
-    // BufferFromHostBuffer call. `buffer_reference` will be freed by the time
-    // `BufferFromHostBuffer` returns.
+    // BufferFromHostBuffer call. `on_done_with_host_buffer` will be called
+    // before `BufferFromHostBuffer` returns.
     kImmutableOnlyDuringCall,

     // The runtime may hold onto `data` after the call to `BufferFromHostBuffer`
     // returns while the runtime completes a transfer to the device. The caller
     // promises not to mutate or free `data` until the transfer completes, at
-    // which point the runtime will release `buffer_reference`. It is also
+    // which point the runtime will call `on_done_with_host_buffer`. It is also
     // correct to wait on the host (directly or indirectly) for the buffer's
     // definition event to complete.
     kImmutableUntilTransferCompletes,
@@ -215,15 +215,17 @@ class PjRtClient {
     // `data` contents as long as the buffer is alive. The caller promises to
     // keep `data` alive and not to mutate its contents as long as the buffer is
     // alive; to notify the caller that the buffer may be freed, the runtime
-    // will release its `buffer_reference` when the PjRtBuffer is freed. On
+    // will call `on_done_with_host_buffer` when the PjRtBuffer is freed. On
     // non-CPU platforms this acts identically to
     // kImmutableUntilTransferCompletes.
     kZeroCopy,
   };
+  // on_done_with_host_buffer is optional and may be null.
+  // on_done_with_host_buffer will be called iff an OK status is returned.
   virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
       const void* data, const Shape& shape,
       HostBufferSemantics host_buffer_semantics,
-      std::shared_ptr<void> buffer_reference, PjRtDevice* device) = 0;
+      std::function<void()> on_done_with_host_buffer, PjRtDevice* device) = 0;

   // Note that literal must remain in scope until the transfer has completed, so
   // the caller should, for example, wait for BlockHostUntilReady() completes on
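The new contract documented above is worth spelling out: the callback fires iff `BufferFromHostBuffer` returns an OK status, and under `kImmutableUntilTransferCompletes` it is the caller's signal that the host memory may be reused. A hedged caller-side sketch of that flow, again assuming the surrounding PJRT setup and the xla namespace conventions used in this diff:

    #include <memory>
    #include <vector>
    #include "absl/synchronization/notification.h"

    Status UploadThenReuse(PjRtClient* client, const Shape& shape,
                           PjRtDevice* device, std::vector<float>* host_data) {
      absl::Notification transfer_done;
      TF_ASSIGN_OR_RETURN(
          std::unique_ptr<PjRtBuffer> buffer,
          client->BufferFromHostBuffer(
              host_data->data(), shape,
              PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
              /*on_done_with_host_buffer=*/
              [&transfer_done]() { transfer_done.Notify(); }, device));
      // The callback is invoked iff an OK status was returned, so on the error
      // path TF_ASSIGN_OR_RETURN exits before the wait and the caller still
      // owns *host_data.
      transfer_done.WaitForNotification();
      // The runtime is done with the host memory; it may now be mutated or freed.
      host_data->clear();
      return Status::OK();
    }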
@@ -604,7 +604,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>>
 PjRtStreamExecutorClient::BufferFromHostBuffer(
     const void* data, const Shape& shape,
     HostBufferSemantics host_buffer_semantics,
-    std::shared_ptr<void> buffer_reference, PjRtDevice* device) {
+    std::function<void()> on_done_with_host_buffer, PjRtDevice* device) {
   tensorflow::profiler::TraceMe traceme(
       "PjRtStreamExecutorClient::BufferFromHostBuffer");
   VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
@@ -647,19 +647,20 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
   // further copies. At the time of writing we require a 16-byte alignment
   // because XLA may generate code which requires it.
   if (can_use_zero_copy) {
-    on_delete_callback = [buffer_reference{std::move(buffer_reference)}]() {
-      // Frees buffer_reference.
-    };
+    on_delete_callback = std::move(on_done_with_host_buffer);
     buffer = se::DeviceMemoryBase(const_cast<void*>(data), size);
   } else {
     void* staging_buffer = host_memory_allocator()->AllocateRaw(
         cpu_function_runtime::kMinAlign, size);
+    buffer = se::DeviceMemoryBase(staging_buffer, size);
+    std::memcpy(staging_buffer, data, size);
+    if (on_done_with_host_buffer) {
+      on_done_with_host_buffer();
+    }
     on_delete_callback = [staging_buffer, host_memory_allocator =
                                               host_memory_allocator()]() {
       host_memory_allocator->DeallocateRaw(staging_buffer);
     };
-    buffer = se::DeviceMemoryBase(staging_buffer, size);
-    std::memcpy(staging_buffer, data, size);
   }
   absl::Span<const std::shared_ptr<BufferSequencingEvent>>
       definition_events;
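Note the ordering this hunk establishes on the CPU staging path: the host data is copied and the callback invoked before `BufferFromHostBuffer` returns, whereas the zero-copy path defers the callback to buffer teardown. Callbacks should therefore tolerate both inline and deferred invocation. A minimal sketch of the observable difference, with setup assumed as in the earlier examples:

    #include <atomic>
    #include <memory>

    // `done` is heap-allocated and captured by value so the callback stays
    // valid whether it runs inline or later from a callback thread.
    auto done = std::make_shared<std::atomic<bool>>(false);
    auto buffer_or = client->BufferFromHostBuffer(
        host_data.data(), shape,
        PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
        /*on_done_with_host_buffer=*/[done]() { done->store(true); }, device);
    // On the CPU staging path (this hunk) `*done` is already true here: the
    // data was copied and the callback ran inline. On other platforms it flips
    // later, from a callback thread, once the transfer completes (see the
    // ThenExecuteOnCallbackThread hunk below).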
@@ -702,7 +703,10 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
   // thread.
   if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall) {
     std::memcpy(staging_buffer.get(), data, size);
-    buffer_reference.reset();
+    if (on_done_with_host_buffer) {
+      on_done_with_host_buffer();
+      on_done_with_host_buffer = nullptr;
+    }
     data = nullptr;
   }
@@ -718,7 +722,8 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
       py_buffer{py_buffer.get()}, compact_shape,
       on_device_shape{py_buffer->on_device_shape()},
       staging_buffer{std::move(staging_buffer)},
-      buffer_reference{std::move(buffer_reference)},
+      on_done_with_host_buffer{
+          std::move(on_done_with_host_buffer)},
       host_buffer_semantics]() {
     PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer);
     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
@@ -756,9 +761,16 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
         local_device, std::move(device_buffer), event,
         local_device->host_to_device_stream()));

-    local_device->ThenRelease(
-        local_device->host_to_device_stream(),
-        std::make_pair(std::move(buffer_reference), std::move(staging_buffer)));
+    local_device->callback_stream()->ThenWaitFor(
+        local_device->host_to_device_stream());
+    local_device->ThenExecuteOnCallbackThread(
+        local_device->callback_stream(),
+        [staging_buffer{std::move(staging_buffer)},
+         on_done_with_host_buffer{std::move(on_done_with_host_buffer)}]() {
+          if (on_done_with_host_buffer) {
+            on_done_with_host_buffer();
+          }
+        });
   };
   if (is_cpu_platform) {
     // Using the thread_pool would be a double thread hop; the code
@@ -1306,7 +1318,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtStreamExecutorBuffer::CopyToDevice(
     return dst_device->client()->BufferFromHostBuffer(
         literal_pointer->untyped_data(), literal_pointer->shape(),
         PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy,
-        std::move(literal), dst_device);
+        [literal{std::move(literal)}]() { /* frees literal */ }, dst_device);
   }

   TF_ASSIGN_OR_RETURN(
@@ -187,7 +187,8 @@ class PjRtStreamExecutorClient : public PjRtClient {
   StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
       const void* data, const Shape& shape,
       HostBufferSemantics host_buffer_semantics,
-      std::shared_ptr<void> buffer_reference, PjRtDevice* device) override;
+      std::function<void()> on_done_with_host_buffer,
+      PjRtDevice* device) override;

   StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
       const LiteralSlice& literal, PjRtDevice* device) override;
@@ -123,15 +123,23 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval(
         py::cast<std::string>(py::repr(argument)));
   }

-  std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
-      GlobalPyRefManager()->ManageReference(std::move(c->array));
+  std::function<void()> on_done_with_host_buffer;
+  if (host_buffer_semantics !=
+      PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall) {
+    std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
+        GlobalPyRefManager()->ManageReference(std::move(c->array));
+    on_done_with_host_buffer =
+        [py_buffer_ref{
+            std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ };
+  }

   std::unique_ptr<PjRtBuffer> buffer;
   {
     py::gil_scoped_release gil_release;
-    TF_ASSIGN_OR_RETURN(buffer, pjrt_client_->BufferFromHostBuffer(
-                                    c->buf_ptr, c->shape, host_buffer_semantics,
-                                    std::move(py_buffer_ref), device));
+    TF_ASSIGN_OR_RETURN(buffer,
+                        pjrt_client_->BufferFromHostBuffer(
+                            c->buf_ptr, c->shape, host_buffer_semantics,
+                            std::move(on_done_with_host_buffer), device));
   }
   return buffer;
 }