Simplify TransferToDevice and TransferFromDevice API for TPU Driver
PiperOrigin-RevId: 280319592 Change-Id: Ia63099fdea24489bc0e0a7cef55bffeb32543b46
This commit is contained in:
parent
f95bd6ec17
commit
9b94c27ef6
@ -165,8 +165,8 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::FromLiterals(
|
||||
tuple_shape,
|
||||
[driver, &leaves, &tuple_shape,
|
||||
leaves_references](tpu_driver::BufferHandle* handle) {
|
||||
auto event = driver->TransferToDevice(
|
||||
leaves[0].untyped_data(), handle, tuple_shape.ToProto(), {});
|
||||
auto event =
|
||||
driver->TransferToDevice(leaves[0].untyped_data(), handle, {});
|
||||
event->AddCallback([leaves_references](Status) {});
|
||||
return event;
|
||||
},
|
||||
@ -188,9 +188,7 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::FromLiterals(
|
||||
CreateBuffer(
|
||||
indexed_shape.shape,
|
||||
[driver, &leaf, &indexed_shape](tpu_driver::BufferHandle* handle) {
|
||||
return driver->TransferToDevice(leaf.untyped_data(), handle,
|
||||
indexed_shape.shape.ToProto(),
|
||||
{});
|
||||
return driver->TransferToDevice(leaf.untyped_data(), handle, {});
|
||||
},
|
||||
client, device_ordinal));
|
||||
child_buffer_ptrs.push_back(child_buffer.get());
|
||||
@ -290,14 +288,13 @@ Status PyTpuBuffer::CopyToHostAsync() {
|
||||
CHECK(child_buffers_.empty());
|
||||
transfer_events.push_back(client_->driver()->TransferFromDevice(
|
||||
device_buffer_->handle.get(), host_value->value->untyped_data(),
|
||||
host_value->value->shape().ToProto(), events));
|
||||
events));
|
||||
} else {
|
||||
for (int i = 0; i < child_buffers_.size(); ++i) {
|
||||
auto& c = child_buffers_[i];
|
||||
transfer_events.push_back(client_->driver()->TransferFromDevice(
|
||||
c->handle.get(),
|
||||
host_value->value->untyped_data(xla::ShapeIndex({i})),
|
||||
host_value->value->shape().tuple_shapes(i).ToProto(), events));
|
||||
host_value->value->untyped_data(xla::ShapeIndex({i})), events));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -66,22 +66,29 @@ class GrpcEvent : public Event {
|
||||
|
||||
class GrpcBufferHandle : public BufferHandle {
|
||||
public:
|
||||
explicit GrpcBufferHandle(EventId id, std::shared_ptr<GrpcEvent> event)
|
||||
: id_(id), stream_(event->stream()), event_(std::move(event)) {}
|
||||
explicit GrpcBufferHandle(EventId id, std::shared_ptr<GrpcEvent> event,
|
||||
int64_t bytes,
|
||||
std::optional<xla::ShapeProto> shape = std::nullopt)
|
||||
: id_(id),
|
||||
stream_(event->stream()),
|
||||
event_(std::move(event)),
|
||||
bytes_(bytes),
|
||||
shape_(shape) {}
|
||||
|
||||
std::shared_ptr<Event> OnReady() override { return event_; }
|
||||
int64_t size_in_bytes() override {
|
||||
LOG(FATAL) << "Unimplemented.";
|
||||
return 0;
|
||||
}
|
||||
int64_t size_in_bytes() override { return bytes_; }
|
||||
|
||||
EventId id() const { return id_; }
|
||||
GrpcTpuStream* stream() const { return stream_; }
|
||||
|
||||
std::optional<xla::ShapeProto> shape() override { return shape_; }
|
||||
|
||||
private:
|
||||
const EventId id_;
|
||||
GrpcTpuStream* stream_;
|
||||
std::shared_ptr<GrpcEvent> event_;
|
||||
int64_t bytes_;
|
||||
std::optional<xla::ShapeProto> shape_;
|
||||
};
|
||||
|
||||
class GrpcCompiledProgramHandle : public CompiledProgramHandle {
|
||||
@ -160,18 +167,9 @@ class GrpcTpuStream {
|
||||
std::unique_ptr<Event> Deallocate(std::unique_ptr<BufferHandle> handle,
|
||||
absl::Span<Event* const> wait_for);
|
||||
|
||||
std::unique_ptr<Event> TransferToDevice(const void* src, int64_t num_bytes,
|
||||
BufferHandle* dst,
|
||||
absl::Span<Event* const> wait_for);
|
||||
std::unique_ptr<Event> TransferFromDevice(const BufferHandle* src, void* dst,
|
||||
int64_t num_bytes,
|
||||
absl::Span<Event* const> wait_for);
|
||||
|
||||
std::unique_ptr<Event> TransferToDevice(const void* src, BufferHandle* dst,
|
||||
const xla::ShapeProto& shape,
|
||||
absl::Span<Event* const> wait_for);
|
||||
std::unique_ptr<Event> TransferFromDevice(const BufferHandle* src, void* dst,
|
||||
const xla::ShapeProto& shape,
|
||||
absl::Span<Event* const> wait_for);
|
||||
|
||||
std::unique_ptr<Event> TransferFromDeviceToDevice(
|
||||
@ -353,29 +351,16 @@ class GrpcTpuDriver : public TpuDriver {
|
||||
}
|
||||
|
||||
std::unique_ptr<Event> TransferToDevice(
|
||||
const void* src, int64_t num_bytes, BufferHandle* dst,
|
||||
const void* src, BufferHandle* dst,
|
||||
absl::Span<Event* const> wait_for) override {
|
||||
auto* stream = static_cast<GrpcBufferHandle*>(dst)->stream();
|
||||
return stream->TransferToDevice(src, num_bytes, dst, wait_for);
|
||||
return stream->TransferToDevice(src, dst, wait_for);
|
||||
}
|
||||
std::unique_ptr<Event> TransferFromDevice(
|
||||
const BufferHandle* src, void* dst, int64_t num_bytes,
|
||||
const BufferHandle* src, void* dst,
|
||||
absl::Span<Event* const> wait_for) override {
|
||||
auto* stream = static_cast<const GrpcBufferHandle*>(src)->stream();
|
||||
return stream->TransferFromDevice(src, dst, num_bytes, wait_for);
|
||||
}
|
||||
|
||||
std::unique_ptr<Event> TransferToDevice(
|
||||
const void* src, BufferHandle* dst, const xla::ShapeProto& shape,
|
||||
absl::Span<Event* const> wait_for) override {
|
||||
auto* stream = static_cast<GrpcBufferHandle*>(dst)->stream();
|
||||
return stream->TransferToDevice(src, dst, shape, wait_for);
|
||||
}
|
||||
std::unique_ptr<Event> TransferFromDevice(
|
||||
const BufferHandle* src, void* dst, const xla::ShapeProto& shape,
|
||||
absl::Span<Event* const> wait_for) override {
|
||||
auto* stream = static_cast<const GrpcBufferHandle*>(src)->stream();
|
||||
return stream->TransferFromDevice(src, dst, shape, wait_for);
|
||||
return stream->TransferFromDevice(src, dst, wait_for);
|
||||
}
|
||||
|
||||
std::unique_ptr<Event> TransferFromDeviceToDevice(
|
||||
@ -685,7 +670,8 @@ std::unique_ptr<BufferHandle> GrpcTpuStream::Allocate(
|
||||
auto event =
|
||||
absl::make_unique<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
|
||||
AddWriteRequest(std::move(req));
|
||||
return absl::make_unique<GrpcBufferHandle>(event->id(), std::move(event));
|
||||
return absl::make_unique<GrpcBufferHandle>(event->id(), std::move(event),
|
||||
num_bytes);
|
||||
}
|
||||
|
||||
std::unique_ptr<BufferHandle> GrpcTpuStream::Allocate(
|
||||
@ -700,7 +686,8 @@ std::unique_ptr<BufferHandle> GrpcTpuStream::Allocate(
|
||||
auto event =
|
||||
absl::make_unique<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
|
||||
AddWriteRequest(std::move(req));
|
||||
return absl::make_unique<GrpcBufferHandle>(event->id(), std::move(event));
|
||||
return absl::make_unique<GrpcBufferHandle>(
|
||||
event->id(), std::move(event), ComputeBytesFromShape(shape), shape);
|
||||
}
|
||||
|
||||
std::unique_ptr<BufferHandle> GrpcTpuStream::AllocateTuple(
|
||||
@ -719,7 +706,7 @@ std::unique_ptr<BufferHandle> GrpcTpuStream::AllocateTuple(
|
||||
auto event =
|
||||
absl::make_unique<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
|
||||
AddWriteRequest(std::move(req));
|
||||
return absl::make_unique<GrpcBufferHandle>(event->id(), std::move(event));
|
||||
return absl::make_unique<GrpcBufferHandle>(event->id(), std::move(event), 0);
|
||||
}
|
||||
|
||||
std::unique_ptr<Event> GrpcTpuStream::Deallocate(
|
||||
@ -736,13 +723,12 @@ std::unique_ptr<Event> GrpcTpuStream::Deallocate(
|
||||
}
|
||||
|
||||
std::unique_ptr<Event> GrpcTpuStream::TransferToDevice(
|
||||
const void* src, int64_t num_bytes, BufferHandle* dst,
|
||||
absl::Span<Event* const> wait_for) {
|
||||
const void* src, BufferHandle* dst, absl::Span<Event* const> wait_for) {
|
||||
auto req = absl::make_unique<StreamRequest::Entry>();
|
||||
InitializeRequest(req.get(), wait_for);
|
||||
TraceMe activity(absl::StrCat("GrpcTpuStream::TransferToDevice"));
|
||||
req->mutable_transfer_to()->mutable_data()->assign(
|
||||
static_cast<const char*>(src), num_bytes);
|
||||
static_cast<const char*>(src), dst->size_in_bytes());
|
||||
req->mutable_transfer_to()->set_target_handle(
|
||||
static_cast<GrpcBufferHandle*>(dst)->id().AsInt());
|
||||
auto event =
|
||||
@ -752,61 +738,16 @@ std::unique_ptr<Event> GrpcTpuStream::TransferToDevice(
|
||||
}
|
||||
|
||||
std::unique_ptr<Event> GrpcTpuStream::TransferFromDevice(
|
||||
const BufferHandle* src, void* dst, int64_t num_bytes,
|
||||
absl::Span<Event* const> wait_for) {
|
||||
const BufferHandle* src, void* dst, absl::Span<Event* const> wait_for) {
|
||||
auto req = absl::make_unique<StreamRequest::Entry>();
|
||||
InitializeRequest(req.get(), wait_for);
|
||||
TraceMe activity(absl::StrCat("GrpcTpuStream::TransferFromDevice"));
|
||||
req->mutable_transfer_from()->set_source_handle(
|
||||
static_cast<const GrpcBufferHandle*>(src)->id().AsInt());
|
||||
req->mutable_transfer_from()->set_length(num_bytes);
|
||||
EventId event_id = EventId::FromInt(req->operation_id());
|
||||
{
|
||||
absl::MutexLock lock(&transfers_mutex_);
|
||||
TransferInfo info(dst, num_bytes);
|
||||
transfers_.insert(std::make_pair(event_id, info));
|
||||
}
|
||||
auto event = absl::make_unique<GrpcEvent>(event_id, this);
|
||||
AddWriteRequest(std::move(req));
|
||||
return event;
|
||||
}
|
||||
|
||||
std::unique_ptr<Event> GrpcTpuStream::TransferToDevice(
|
||||
const void* src, BufferHandle* dst, const xla::ShapeProto& shape,
|
||||
absl::Span<Event* const> wait_for) {
|
||||
auto req = absl::make_unique<StreamRequest::Entry>();
|
||||
InitializeRequest(req.get(), wait_for);
|
||||
|
||||
TraceMe activity(absl::StrCat("GrpcTpuStream::TransferToDevice(shape)",
|
||||
req->operation_id()));
|
||||
req->mutable_transfer_to()->mutable_data()->assign(
|
||||
static_cast<const char*>(src), ComputeBytesFromShape(shape));
|
||||
req->mutable_transfer_to()->set_target_handle(
|
||||
static_cast<GrpcBufferHandle*>(dst)->id().AsInt());
|
||||
*req->mutable_transfer_to()->mutable_linearize_shape() = shape;
|
||||
auto event =
|
||||
absl::make_unique<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
|
||||
AddWriteRequest(std::move(req));
|
||||
return event;
|
||||
}
|
||||
|
||||
std::unique_ptr<Event> GrpcTpuStream::TransferFromDevice(
|
||||
const BufferHandle* src, void* dst, const xla::ShapeProto& shape,
|
||||
absl::Span<Event* const> wait_for) {
|
||||
auto req = absl::make_unique<StreamRequest::Entry>();
|
||||
InitializeRequest(req.get(), wait_for);
|
||||
TraceMe activity(absl::StrCat("GrpcTpuStream::TransferFromDevice(shape)",
|
||||
req->operation_id()));
|
||||
|
||||
int bytes_expected = ComputeBytesFromShape(shape);
|
||||
req->mutable_transfer_from()->set_source_handle(
|
||||
static_cast<const GrpcBufferHandle*>(src)->id().AsInt());
|
||||
req->mutable_transfer_from()->set_length(bytes_expected);
|
||||
*req->mutable_transfer_from()->mutable_delinearize_shape() = shape;
|
||||
EventId event_id = EventId::FromInt(req->operation_id());
|
||||
{
|
||||
absl::MutexLock lock(&transfers_mutex_);
|
||||
TransferInfo info(dst, bytes_expected);
|
||||
TransferInfo info(dst, const_cast<BufferHandle*>(src)->size_in_bytes());
|
||||
transfers_.insert(std::make_pair(event_id, info));
|
||||
}
|
||||
auto event = absl::make_unique<GrpcEvent>(event_id, this);
|
||||
|
@ -77,6 +77,7 @@ class RecordingBufferHandle : public BufferHandle {
|
||||
event_(std::make_shared<RecordingEvent>(handle_->OnReady(), id_)) {}
|
||||
std::shared_ptr<Event> OnReady() override { return event_; }
|
||||
int64_t size_in_bytes() override { return handle_->size_in_bytes(); }
|
||||
std::optional<xla::ShapeProto> shape() override { return handle_->shape(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<BufferHandle> handle_;
|
||||
@ -257,24 +258,29 @@ class RecordingTpuDriver : public TpuDriver {
|
||||
}
|
||||
|
||||
std::unique_ptr<Event> TransferToDevice(
|
||||
const void* src, int64_t num_bytes, BufferHandle* dst,
|
||||
const void* src, BufferHandle* dst,
|
||||
absl::Span<Event* const> wait_for) override {
|
||||
int64_t num_bytes = dst->size_in_bytes();
|
||||
auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
|
||||
|
||||
auto thread_id = GetCurrentThreadId();
|
||||
int64_t recording_handle_id = static_cast<RecordingBufferHandle*>(dst)->id_;
|
||||
auto recording_handle = static_cast<RecordingBufferHandle*>(dst);
|
||||
int64_t recording_handle_id = recording_handle->id_;
|
||||
auto recording_event =
|
||||
std::make_unique<RecordingEvent>(driver_->TransferToDevice(
|
||||
src, num_bytes,
|
||||
static_cast<RecordingBufferHandle*>(dst)->handle_.get(),
|
||||
src, static_cast<RecordingBufferHandle*>(dst)->handle_.get(),
|
||||
unwrapped_wait_for));
|
||||
int64_t event_id = recording_event->id_;
|
||||
|
||||
{
|
||||
StreamRequest::Entry r;
|
||||
r.mutable_transfer_to()->set_target_handle(recording_handle_id);
|
||||
r.mutable_transfer_to()->mutable_data()->assign(
|
||||
static_cast<const char*>(src), num_bytes);
|
||||
if (num_bytes > 0) {
|
||||
r.mutable_transfer_to()->mutable_data()->assign(
|
||||
static_cast<const char*>(src), num_bytes);
|
||||
} else {
|
||||
*r.mutable_transfer_to()->mutable_data() = "";
|
||||
}
|
||||
PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
|
||||
}
|
||||
|
||||
@ -282,7 +288,7 @@ class RecordingTpuDriver : public TpuDriver {
|
||||
}
|
||||
|
||||
std::unique_ptr<Event> TransferFromDevice(
|
||||
const BufferHandle* src, void* dst, int64_t num_bytes,
|
||||
const BufferHandle* src, void* dst,
|
||||
absl::Span<Event* const> wait_for) override {
|
||||
auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
|
||||
|
||||
@ -291,64 +297,12 @@ class RecordingTpuDriver : public TpuDriver {
|
||||
auto recording_event =
|
||||
std::make_unique<RecordingEvent>(driver_->TransferFromDevice(
|
||||
static_cast<const RecordingBufferHandle*>(src)->handle_.get(), dst,
|
||||
num_bytes, unwrapped_wait_for));
|
||||
auto event_id = recording_event->id_;
|
||||
|
||||
{
|
||||
StreamRequest::Entry r;
|
||||
r.mutable_transfer_from()->set_source_handle(src_handle_id);
|
||||
r.mutable_transfer_from()->set_length(num_bytes);
|
||||
PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
|
||||
}
|
||||
|
||||
return recording_event;
|
||||
}
|
||||
|
||||
std::unique_ptr<Event> TransferToDevice(
|
||||
const void* src, BufferHandle* dst, const xla::ShapeProto& shape,
|
||||
absl::Span<Event* const> wait_for) override {
|
||||
auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
|
||||
|
||||
int64_t shape_num_bytes = ComputeBytesFromShape(shape);
|
||||
|
||||
auto thread_id = GetCurrentThreadId();
|
||||
int64_t recording_handle_id = static_cast<RecordingBufferHandle*>(dst)->id_;
|
||||
auto recording_event =
|
||||
std::make_unique<RecordingEvent>(driver_->TransferToDevice(
|
||||
src, static_cast<RecordingBufferHandle*>(dst)->handle_.get(), shape,
|
||||
unwrapped_wait_for));
|
||||
int64_t handle_id = recording_event->id_;
|
||||
|
||||
{
|
||||
StreamRequest::Entry r;
|
||||
r.mutable_transfer_to()->set_target_handle(recording_handle_id);
|
||||
*r.mutable_transfer_to()->mutable_linearize_shape() = shape;
|
||||
r.mutable_transfer_to()->mutable_data()->assign(
|
||||
static_cast<const char*>(src), shape_num_bytes);
|
||||
PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
|
||||
}
|
||||
|
||||
return recording_event;
|
||||
}
|
||||
|
||||
std::unique_ptr<Event> TransferFromDevice(
|
||||
const BufferHandle* src, void* dst, const xla::ShapeProto& shape,
|
||||
absl::Span<Event* const> wait_for) override {
|
||||
auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
|
||||
|
||||
auto thread_id = GetCurrentThreadId();
|
||||
auto src_handle_id = static_cast<const RecordingBufferHandle*>(src)->id_;
|
||||
auto recording_event =
|
||||
std::make_unique<RecordingEvent>(driver_->TransferFromDevice(
|
||||
static_cast<const RecordingBufferHandle*>(src)->handle_.get(), dst,
|
||||
shape, unwrapped_wait_for));
|
||||
auto event_id = recording_event->id_;
|
||||
|
||||
{
|
||||
StreamRequest::Entry r;
|
||||
r.mutable_transfer_from()->set_source_handle(src_handle_id);
|
||||
r.mutable_transfer_from()->set_length(-1);
|
||||
*r.mutable_transfer_from()->mutable_delinearize_shape() = shape;
|
||||
PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
|
||||
}
|
||||
|
||||
|
@ -20,6 +20,7 @@
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -73,6 +74,7 @@ class BufferHandle {
|
||||
// automatically add this event as a dependency.
|
||||
virtual std::shared_ptr<Event> OnReady() = 0;
|
||||
virtual int64_t size_in_bytes() = 0;
|
||||
virtual std::optional<xla::ShapeProto> shape() = 0;
|
||||
};
|
||||
|
||||
// Represents a compiled program on the host.
|
||||
@ -166,6 +168,7 @@ class TpuDriver {
|
||||
virtual std::unique_ptr<BufferHandle> Allocate(
|
||||
int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
|
||||
absl::Span<Event* const> wait_for) = 0;
|
||||
|
||||
// Allocate a buffer representing a tuple of `children` buffers.
|
||||
//
|
||||
// The returned tuple buffer handle does not manage the memory of `children`:
|
||||
@ -183,13 +186,8 @@ class TpuDriver {
|
||||
std::unique_ptr<BufferHandle> handle,
|
||||
absl::Span<Event* const> wait_for) = 0;
|
||||
|
||||
virtual std::unique_ptr<Event> TransferToDevice(
|
||||
const void* src, int64_t num_bytes, BufferHandle* dst,
|
||||
absl::Span<Event* const> wait_for) = 0;
|
||||
virtual std::unique_ptr<Event> TransferFromDevice(
|
||||
const BufferHandle* src, void* dst, int64_t num_bytes,
|
||||
absl::Span<Event* const> wait_for) = 0;
|
||||
/* `src` must be laid out in consecutive row-major format for ingestion, and
|
||||
/* For buffers declared with an xla::ShapeProto rather than a raw size,
|
||||
* `src` must be laid out in consecutive row-major format for ingestion, and
|
||||
* each element must take up the number of bytes specified by the type.
|
||||
*
|
||||
* For example, if you have a [3,3,3] tensor with a Float32 type, then the
|
||||
@ -207,10 +205,10 @@ class TpuDriver {
|
||||
* `TransferFromDevice` will write out the shape back in this order as well.
|
||||
*/
|
||||
virtual std::unique_ptr<Event> TransferToDevice(
|
||||
const void* src, BufferHandle* dst, const xla::ShapeProto& shape,
|
||||
const void* src, BufferHandle* dst,
|
||||
absl::Span<Event* const> wait_for) = 0;
|
||||
virtual std::unique_ptr<Event> TransferFromDevice(
|
||||
const BufferHandle* src, void* dst, const xla::ShapeProto& shape,
|
||||
const BufferHandle* src, void* dst,
|
||||
absl::Span<Event* const> wait_for) = 0;
|
||||
|
||||
virtual std::unique_ptr<Event> TransferFromDeviceToDevice(
|
||||
|
@ -47,16 +47,11 @@ message DeallocateRequest {
|
||||
|
||||
message TransferToDeviceRequest {
|
||||
required int64 target_handle = 1;
|
||||
required bytes data = 9;
|
||||
|
||||
optional xla.ShapeProto linearize_shape = 10;
|
||||
required bytes data = 2;
|
||||
}
|
||||
|
||||
message TransferFromDeviceRequest {
|
||||
required int64 source_handle = 1;
|
||||
required int64 length = 3;
|
||||
|
||||
optional xla.ShapeProto delinearize_shape = 4;
|
||||
}
|
||||
|
||||
message TransferFromDeviceResponse {
|
||||
|
Loading…
Reference in New Issue
Block a user