[PJRT] Change PjRtClient::BufferFromHostBuffer to take an on_done_with_host_buffer callback rather than a std::shared_ptr<void> that must be kept alive.

This API is clearer (several users were passing dummy shared_ptr<> values) and probably a little bit faster (for example, when the value to be kept alive is not already a std::shared_ptr<>).

PiperOrigin-RevId: 354089254
Change-Id: Ib812efa3e4952562d1eddbbb339cdcdd0d2f2844
This commit is contained in:
Peter Hawkins 2021-01-27 07:42:47 -08:00 committed by TensorFlower Gardener
parent 7d5328c5df
commit a75ad4bc12
6 changed files with 50 additions and 26 deletions

View File

@ -74,19 +74,19 @@ TEST(GpuMultiStream, Basics) {
client->BufferFromHostBuffer(
dummy_inputs.data(), dummy_shape,
PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, device));
/*on_done_with_host_buffer=*/nullptr, device));
TF_ASSERT_OK_AND_ASSIGN(
auto in_buffer0,
client->BufferFromHostBuffer(
inputs.data(), shape,
PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, device));
/*on_done_with_host_buffer=*/nullptr, device));
TF_ASSERT_OK_AND_ASSIGN(
auto in_buffer1,
client->BufferFromHostBuffer(
inputs.data(), shape,
PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, device));
/*on_done_with_host_buffer=*/nullptr, device));
// The execution may be enqueued before the transfers complete, requiring
// adequate device-side synchronization.
ExecuteOptions options;

View File

@ -107,6 +107,7 @@ class LocalDeviceState {
se::Stream* host_to_device_stream() const {
return host_to_device_stream_.get();
}
se::Stream* callback_stream() const { return callback_stream_.get(); }
// Returns a device to host stream. Allocates streams in a round-robin fashion
// amongst the available streams.

View File

@ -199,14 +199,14 @@ class PjRtClient {
// The runtime may not hold references to `data` after the call to
// `BufferFromHostBuffer` completes. The caller promises that `data` is
// immutable and will not be freed only for the duration of the
// BufferFromHostBuffer call. `buffer_reference` will be freed by the time
// `BufferFromHostBuffer` returns.
// BufferFromHostBuffer call. `on_done_with_host_buffer` will be called
// before `BufferFromHostBuffer` returns.
kImmutableOnlyDuringCall,
// The runtime may hold onto `data` after the call to `BufferFromHostBuffer`
// returns while the runtime completes a transfer to the device. The caller
// promises not to mutate or free `data` until the transfer completes, at
// which point the runtime will release `buffer_reference`. It is also
// which point the runtime will call `on_done_with_host_buffer`. It is also
// correct to wait on the host (directly or indirectly) for the buffer's
// definition event to complete.
kImmutableUntilTransferCompletes,
@ -215,15 +215,17 @@ class PjRtClient {
// `data` contents as long as the buffer is alive. The caller promises to
// keep `data` alive and not to mutate its contents as long as the buffer is
// alive; to notify the caller that the buffer may be freed, the runtime
// will release its `buffer_reference` when the PjRtBuffer is freed. On
// will call `on_done_with_host_buffer` when the PjRtBuffer is freed. On
// non-CPU platforms this acts identically to
// kImmutableUntilTransferCompletes.
kZeroCopy,
};
// on_done_with_host_buffer is optional and may be null.
// on_done_with_host_buffer will be called iff an OK status is returned.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device) = 0;
std::function<void()> on_done_with_host_buffer, PjRtDevice* device) = 0;
// Note that literal must remain in scope until the transfer has completed, so
// the caller should, for example, wait for BlockHostUntilReady() completes on

View File

@ -604,7 +604,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device) {
std::function<void()> on_done_with_host_buffer, PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme(
"PjRtStreamExecutorClient::BufferFromHostBuffer");
VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
@ -647,19 +647,20 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
// further copies. At the time of writing we require a 16-byte alignment
// because XLA may generate code which requires it.
if (can_use_zero_copy) {
on_delete_callback = [buffer_reference{std::move(buffer_reference)}]() {
// Frees buffer_reference.
};
on_delete_callback = std::move(on_done_with_host_buffer);
buffer = se::DeviceMemoryBase(const_cast<void*>(data), size);
} else {
void* staging_buffer = host_memory_allocator()->AllocateRaw(
cpu_function_runtime::kMinAlign, size);
buffer = se::DeviceMemoryBase(staging_buffer, size);
std::memcpy(staging_buffer, data, size);
if (on_done_with_host_buffer) {
on_done_with_host_buffer();
}
on_delete_callback = [staging_buffer, host_memory_allocator =
host_memory_allocator()]() {
host_memory_allocator->DeallocateRaw(staging_buffer);
};
buffer = se::DeviceMemoryBase(staging_buffer, size);
std::memcpy(staging_buffer, data, size);
}
absl::Span<const std::shared_ptr<BufferSequencingEvent>>
definition_events;
@ -702,7 +703,10 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
// thread.
if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall) {
std::memcpy(staging_buffer.get(), data, size);
buffer_reference.reset();
if (on_done_with_host_buffer) {
on_done_with_host_buffer();
on_done_with_host_buffer = nullptr;
}
data = nullptr;
}
@ -718,7 +722,8 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
py_buffer{py_buffer.get()}, compact_shape,
on_device_shape{py_buffer->on_device_shape()},
staging_buffer{std::move(staging_buffer)},
buffer_reference{std::move(buffer_reference)},
on_done_with_host_buffer{
std::move(on_done_with_host_buffer)},
host_buffer_semantics]() {
PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer);
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way
@ -756,9 +761,16 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
local_device, std::move(device_buffer), event,
local_device->host_to_device_stream()));
local_device->ThenRelease(
local_device->host_to_device_stream(),
std::make_pair(std::move(buffer_reference), std::move(staging_buffer)));
local_device->callback_stream()->ThenWaitFor(
local_device->host_to_device_stream());
local_device->ThenExecuteOnCallbackThread(
local_device->callback_stream(),
[staging_buffer{std::move(staging_buffer)},
on_done_with_host_buffer{std::move(on_done_with_host_buffer)}]() {
if (on_done_with_host_buffer) {
on_done_with_host_buffer();
}
});
};
if (is_cpu_platform) {
// Using the thread_pool would be a double thread hop; the code
@ -1306,7 +1318,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtStreamExecutorBuffer::CopyToDevice(
return dst_device->client()->BufferFromHostBuffer(
literal_pointer->untyped_data(), literal_pointer->shape(),
PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy,
std::move(literal), dst_device);
[literal{std::move(literal)}]() { /* frees literal */ }, dst_device);
}
TF_ASSIGN_OR_RETURN(

View File

@ -187,7 +187,8 @@ class PjRtStreamExecutorClient : public PjRtClient {
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device) override;
std::function<void()> on_done_with_host_buffer,
PjRtDevice* device) override;
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
const LiteralSlice& literal, PjRtDevice* device) override;

View File

@ -123,15 +123,23 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval(
py::cast<std::string>(py::repr(argument)));
}
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
GlobalPyRefManager()->ManageReference(std::move(c->array));
std::function<void()> on_done_with_host_buffer;
if (host_buffer_semantics !=
PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall) {
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
GlobalPyRefManager()->ManageReference(std::move(c->array));
on_done_with_host_buffer =
[py_buffer_ref{
std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ };
}
std::unique_ptr<PjRtBuffer> buffer;
{
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(buffer, pjrt_client_->BufferFromHostBuffer(
c->buf_ptr, c->shape, host_buffer_semantics,
std::move(py_buffer_ref), device));
TF_ASSIGN_OR_RETURN(buffer,
pjrt_client_->BufferFromHostBuffer(
c->buf_ptr, c->shape, host_buffer_semantics,
std::move(on_done_with_host_buffer), device));
}
return buffer;
}