[PJRT] Change PjRtClient::BufferFromHostBuffer to take an on_done_with_host_buffer callback rather than a std::shared_ptr<void> that must be kept alive.

This API is clearer (several users were passing dummy shared_ptr<> values) and probably a little bit faster (for example, when the value to be kept alive is not already a std::shared_ptr<>).

PiperOrigin-RevId: 354089254
Change-Id: Ib812efa3e4952562d1eddbbb339cdcdd0d2f2844
Author: Peter Hawkins <2021-01-27 07:42:47 -08:00>
Committed by: TensorFlower Gardener
Parent commit: 7d5328c5df
This commit: a75ad4bc12
6 changed files with 50 additions and 26 deletions

View File

@@ -74,19 +74,19 @@ TEST(GpuMultiStream, Basics) {
client->BufferFromHostBuffer( client->BufferFromHostBuffer(
dummy_inputs.data(), dummy_shape, dummy_inputs.data(), dummy_shape,
PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes, PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, device)); /*on_done_with_host_buffer=*/nullptr, device));
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
auto in_buffer0, auto in_buffer0,
client->BufferFromHostBuffer( client->BufferFromHostBuffer(
inputs.data(), shape, inputs.data(), shape,
PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes, PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, device)); /*on_done_with_host_buffer=*/nullptr, device));
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
auto in_buffer1, auto in_buffer1,
client->BufferFromHostBuffer( client->BufferFromHostBuffer(
inputs.data(), shape, inputs.data(), shape,
PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes, PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, device)); /*on_done_with_host_buffer=*/nullptr, device));
// The execution may be enqueued before the transfers complete, requiring // The execution may be enqueued before the transfers complete, requiring
// adequate device-side synchronization. // adequate device-side synchronization.
ExecuteOptions options; ExecuteOptions options;

View File

@@ -107,6 +107,7 @@ class LocalDeviceState {
se::Stream* host_to_device_stream() const { se::Stream* host_to_device_stream() const {
return host_to_device_stream_.get(); return host_to_device_stream_.get();
} }
se::Stream* callback_stream() const { return callback_stream_.get(); }
// Returns a device to host stream. Allocates streams in a round-robin fashion // Returns a device to host stream. Allocates streams in a round-robin fashion
// amongst the available streams. // amongst the available streams.

View File

@@ -199,14 +199,14 @@ class PjRtClient {
// The runtime may not hold references to `data` after the call to // The runtime may not hold references to `data` after the call to
// `BufferFromHostBuffer` completes. The caller promises that `data` is // `BufferFromHostBuffer` completes. The caller promises that `data` is
// immutable and will not be freed only for the duration of the // immutable and will not be freed only for the duration of the
// BufferFromHostBuffer call. `buffer_reference` will be freed by the time // BufferFromHostBuffer call. `on_done_with_host_buffer` will be called
// `BufferFromHostBuffer` returns. // before `BufferFromHostBuffer` returns.
kImmutableOnlyDuringCall, kImmutableOnlyDuringCall,
// The runtime may hold onto `data` after the call to `BufferFromHostBuffer` // The runtime may hold onto `data` after the call to `BufferFromHostBuffer`
// returns while the runtime completes a transfer to the device. The caller // returns while the runtime completes a transfer to the device. The caller
// promises not to mutate or free `data` until the transfer completes, at // promises not to mutate or free `data` until the transfer completes, at
// which point the runtime will release `buffer_reference`. It is also // which point the runtime will call `on_done_with_host_buffer`. It is also
// correct to wait on the host (directly or indirectly) for the buffer's // correct to wait on the host (directly or indirectly) for the buffer's
// definition event to complete. // definition event to complete.
kImmutableUntilTransferCompletes, kImmutableUntilTransferCompletes,
@@ -215,15 +215,17 @@ class PjRtClient {
// `data` contents as long as the buffer is alive. The caller promises to // `data` contents as long as the buffer is alive. The caller promises to
// keep `data` alive and not to mutate its contents as long as the buffer is // keep `data` alive and not to mutate its contents as long as the buffer is
// alive; to notify the caller that the buffer may be freed, the runtime // alive; to notify the caller that the buffer may be freed, the runtime
// will release its `buffer_reference` when the PjRtBuffer is freed. On // will call `on_done_with_host_buffer` when the PjRtBuffer is freed. On
// non-CPU platforms this acts identically to // non-CPU platforms this acts identically to
// kImmutableUntilTransferCompletes. // kImmutableUntilTransferCompletes.
kZeroCopy, kZeroCopy,
}; };
// on_done_with_host_buffer is optional and may be null.
// on_done_with_host_buffer will be called iff an OK status is returned.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer( virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, const Shape& shape, const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics, HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device) = 0; std::function<void()> on_done_with_host_buffer, PjRtDevice* device) = 0;
// Note that literal must remain in scope until the transfer has completed, so // Note that literal must remain in scope until the transfer has completed, so
// the caller should, for example, wait for BlockHostUntilReady() completes on // the caller should, for example, wait for BlockHostUntilReady() completes on

View File

@@ -604,7 +604,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::BufferFromHostBuffer( PjRtStreamExecutorClient::BufferFromHostBuffer(
const void* data, const Shape& shape, const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics, HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device) { std::function<void()> on_done_with_host_buffer, PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme( tensorflow::profiler::TraceMe traceme(
"PjRtStreamExecutorClient::BufferFromHostBuffer"); "PjRtStreamExecutorClient::BufferFromHostBuffer");
VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: " VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
@@ -647,19 +647,20 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
// further copies. At the time of writing we require a 16-byte alignment // further copies. At the time of writing we require a 16-byte alignment
// because XLA may generate code which requires it. // because XLA may generate code which requires it.
if (can_use_zero_copy) { if (can_use_zero_copy) {
on_delete_callback = [buffer_reference{std::move(buffer_reference)}]() { on_delete_callback = std::move(on_done_with_host_buffer);
// Frees buffer_reference.
};
buffer = se::DeviceMemoryBase(const_cast<void*>(data), size); buffer = se::DeviceMemoryBase(const_cast<void*>(data), size);
} else { } else {
void* staging_buffer = host_memory_allocator()->AllocateRaw( void* staging_buffer = host_memory_allocator()->AllocateRaw(
cpu_function_runtime::kMinAlign, size); cpu_function_runtime::kMinAlign, size);
buffer = se::DeviceMemoryBase(staging_buffer, size);
std::memcpy(staging_buffer, data, size);
if (on_done_with_host_buffer) {
on_done_with_host_buffer();
}
on_delete_callback = [staging_buffer, host_memory_allocator = on_delete_callback = [staging_buffer, host_memory_allocator =
host_memory_allocator()]() { host_memory_allocator()]() {
host_memory_allocator->DeallocateRaw(staging_buffer); host_memory_allocator->DeallocateRaw(staging_buffer);
}; };
buffer = se::DeviceMemoryBase(staging_buffer, size);
std::memcpy(staging_buffer, data, size);
} }
absl::Span<const std::shared_ptr<BufferSequencingEvent>> absl::Span<const std::shared_ptr<BufferSequencingEvent>>
definition_events; definition_events;
@@ -702,7 +703,10 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
// thread. // thread.
if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall) { if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall) {
std::memcpy(staging_buffer.get(), data, size); std::memcpy(staging_buffer.get(), data, size);
buffer_reference.reset(); if (on_done_with_host_buffer) {
on_done_with_host_buffer();
on_done_with_host_buffer = nullptr;
}
data = nullptr; data = nullptr;
} }
@@ -718,7 +722,8 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
py_buffer{py_buffer.get()}, compact_shape, py_buffer{py_buffer.get()}, compact_shape,
on_device_shape{py_buffer->on_device_shape()}, on_device_shape{py_buffer->on_device_shape()},
staging_buffer{std::move(staging_buffer)}, staging_buffer{std::move(staging_buffer)},
buffer_reference{std::move(buffer_reference)}, on_done_with_host_buffer{
std::move(on_done_with_host_buffer)},
host_buffer_semantics]() { host_buffer_semantics]() {
PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer); PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer);
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
@@ -756,9 +761,16 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
local_device, std::move(device_buffer), event, local_device, std::move(device_buffer), event,
local_device->host_to_device_stream())); local_device->host_to_device_stream()));
local_device->ThenRelease( local_device->callback_stream()->ThenWaitFor(
local_device->host_to_device_stream(), local_device->host_to_device_stream());
std::make_pair(std::move(buffer_reference), std::move(staging_buffer))); local_device->ThenExecuteOnCallbackThread(
local_device->callback_stream(),
[staging_buffer{std::move(staging_buffer)},
on_done_with_host_buffer{std::move(on_done_with_host_buffer)}]() {
if (on_done_with_host_buffer) {
on_done_with_host_buffer();
}
});
}; };
if (is_cpu_platform) { if (is_cpu_platform) {
// Using the thread_pool would be a double thread hop; the code // Using the thread_pool would be a double thread hop; the code
@@ -1306,7 +1318,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtStreamExecutorBuffer::CopyToDevice(
return dst_device->client()->BufferFromHostBuffer( return dst_device->client()->BufferFromHostBuffer(
literal_pointer->untyped_data(), literal_pointer->shape(), literal_pointer->untyped_data(), literal_pointer->shape(),
PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy, PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy,
std::move(literal), dst_device); [literal{std::move(literal)}]() { /* frees literal */ }, dst_device);
} }
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(

View File

@@ -187,7 +187,8 @@ class PjRtStreamExecutorClient : public PjRtClient {
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer( StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, const Shape& shape, const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics, HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device) override; std::function<void()> on_done_with_host_buffer,
PjRtDevice* device) override;
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral( StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
const LiteralSlice& literal, PjRtDevice* device) override; const LiteralSlice& literal, PjRtDevice* device) override;

View File

@@ -123,15 +123,23 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval(
py::cast<std::string>(py::repr(argument))); py::cast<std::string>(py::repr(argument)));
} }
std::function<void()> on_done_with_host_buffer;
if (host_buffer_semantics !=
PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall) {
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref = std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
GlobalPyRefManager()->ManageReference(std::move(c->array)); GlobalPyRefManager()->ManageReference(std::move(c->array));
on_done_with_host_buffer =
[py_buffer_ref{
std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ };
}
std::unique_ptr<PjRtBuffer> buffer; std::unique_ptr<PjRtBuffer> buffer;
{ {
py::gil_scoped_release gil_release; py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(buffer, pjrt_client_->BufferFromHostBuffer( TF_ASSIGN_OR_RETURN(buffer,
pjrt_client_->BufferFromHostBuffer(
c->buf_ptr, c->shape, host_buffer_semantics, c->buf_ptr, c->shape, host_buffer_semantics,
std::move(py_buffer_ref), device)); std::move(on_done_with_host_buffer), device));
} }
return buffer; return buffer;
} }