Simplify TransferToDevice and TransferFromDevice API for TPU Driver

PiperOrigin-RevId: 280319592
Change-Id: Ia63099fdea24489bc0e0a7cef55bffeb32543b46
This commit is contained in:
Frank Chen 2019-11-13 17:40:07 -08:00 committed by TensorFlower Gardener
parent f95bd6ec17
commit 9b94c27ef6
5 changed files with 52 additions and 167 deletions

View File

@ -165,8 +165,8 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::FromLiterals(
tuple_shape,
[driver, &leaves, &tuple_shape,
leaves_references](tpu_driver::BufferHandle* handle) {
auto event = driver->TransferToDevice(
leaves[0].untyped_data(), handle, tuple_shape.ToProto(), {});
auto event =
driver->TransferToDevice(leaves[0].untyped_data(), handle, {});
event->AddCallback([leaves_references](Status) {});
return event;
},
@ -188,9 +188,7 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::FromLiterals(
CreateBuffer(
indexed_shape.shape,
[driver, &leaf, &indexed_shape](tpu_driver::BufferHandle* handle) {
return driver->TransferToDevice(leaf.untyped_data(), handle,
indexed_shape.shape.ToProto(),
{});
return driver->TransferToDevice(leaf.untyped_data(), handle, {});
},
client, device_ordinal));
child_buffer_ptrs.push_back(child_buffer.get());
@ -290,14 +288,13 @@ Status PyTpuBuffer::CopyToHostAsync() {
CHECK(child_buffers_.empty());
transfer_events.push_back(client_->driver()->TransferFromDevice(
device_buffer_->handle.get(), host_value->value->untyped_data(),
host_value->value->shape().ToProto(), events));
events));
} else {
for (int i = 0; i < child_buffers_.size(); ++i) {
auto& c = child_buffers_[i];
transfer_events.push_back(client_->driver()->TransferFromDevice(
c->handle.get(),
host_value->value->untyped_data(xla::ShapeIndex({i})),
host_value->value->shape().tuple_shapes(i).ToProto(), events));
host_value->value->untyped_data(xla::ShapeIndex({i})), events));
}
}
}

View File

@ -66,22 +66,29 @@ class GrpcEvent : public Event {
class GrpcBufferHandle : public BufferHandle {
public:
explicit GrpcBufferHandle(EventId id, std::shared_ptr<GrpcEvent> event)
: id_(id), stream_(event->stream()), event_(std::move(event)) {}
explicit GrpcBufferHandle(EventId id, std::shared_ptr<GrpcEvent> event,
int64_t bytes,
std::optional<xla::ShapeProto> shape = std::nullopt)
: id_(id),
stream_(event->stream()),
event_(std::move(event)),
bytes_(bytes),
shape_(shape) {}
std::shared_ptr<Event> OnReady() override { return event_; }
int64_t size_in_bytes() override {
LOG(FATAL) << "Unimplemented.";
return 0;
}
int64_t size_in_bytes() override { return bytes_; }
EventId id() const { return id_; }
GrpcTpuStream* stream() const { return stream_; }
std::optional<xla::ShapeProto> shape() override { return shape_; }
private:
const EventId id_;
GrpcTpuStream* stream_;
std::shared_ptr<GrpcEvent> event_;
int64_t bytes_;
std::optional<xla::ShapeProto> shape_;
};
class GrpcCompiledProgramHandle : public CompiledProgramHandle {
@ -160,18 +167,9 @@ class GrpcTpuStream {
std::unique_ptr<Event> Deallocate(std::unique_ptr<BufferHandle> handle,
absl::Span<Event* const> wait_for);
std::unique_ptr<Event> TransferToDevice(const void* src, int64_t num_bytes,
BufferHandle* dst,
absl::Span<Event* const> wait_for);
std::unique_ptr<Event> TransferFromDevice(const BufferHandle* src, void* dst,
int64_t num_bytes,
absl::Span<Event* const> wait_for);
std::unique_ptr<Event> TransferToDevice(const void* src, BufferHandle* dst,
const xla::ShapeProto& shape,
absl::Span<Event* const> wait_for);
std::unique_ptr<Event> TransferFromDevice(const BufferHandle* src, void* dst,
const xla::ShapeProto& shape,
absl::Span<Event* const> wait_for);
std::unique_ptr<Event> TransferFromDeviceToDevice(
@ -353,29 +351,16 @@ class GrpcTpuDriver : public TpuDriver {
}
std::unique_ptr<Event> TransferToDevice(
const void* src, int64_t num_bytes, BufferHandle* dst,
const void* src, BufferHandle* dst,
absl::Span<Event* const> wait_for) override {
auto* stream = static_cast<GrpcBufferHandle*>(dst)->stream();
return stream->TransferToDevice(src, num_bytes, dst, wait_for);
return stream->TransferToDevice(src, dst, wait_for);
}
std::unique_ptr<Event> TransferFromDevice(
const BufferHandle* src, void* dst, int64_t num_bytes,
const BufferHandle* src, void* dst,
absl::Span<Event* const> wait_for) override {
auto* stream = static_cast<const GrpcBufferHandle*>(src)->stream();
return stream->TransferFromDevice(src, dst, num_bytes, wait_for);
}
std::unique_ptr<Event> TransferToDevice(
const void* src, BufferHandle* dst, const xla::ShapeProto& shape,
absl::Span<Event* const> wait_for) override {
auto* stream = static_cast<GrpcBufferHandle*>(dst)->stream();
return stream->TransferToDevice(src, dst, shape, wait_for);
}
std::unique_ptr<Event> TransferFromDevice(
const BufferHandle* src, void* dst, const xla::ShapeProto& shape,
absl::Span<Event* const> wait_for) override {
auto* stream = static_cast<const GrpcBufferHandle*>(src)->stream();
return stream->TransferFromDevice(src, dst, shape, wait_for);
return stream->TransferFromDevice(src, dst, wait_for);
}
std::unique_ptr<Event> TransferFromDeviceToDevice(
@ -685,7 +670,8 @@ std::unique_ptr<BufferHandle> GrpcTpuStream::Allocate(
auto event =
absl::make_unique<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
AddWriteRequest(std::move(req));
return absl::make_unique<GrpcBufferHandle>(event->id(), std::move(event));
return absl::make_unique<GrpcBufferHandle>(event->id(), std::move(event),
num_bytes);
}
std::unique_ptr<BufferHandle> GrpcTpuStream::Allocate(
@ -700,7 +686,8 @@ std::unique_ptr<BufferHandle> GrpcTpuStream::Allocate(
auto event =
absl::make_unique<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
AddWriteRequest(std::move(req));
return absl::make_unique<GrpcBufferHandle>(event->id(), std::move(event));
return absl::make_unique<GrpcBufferHandle>(
event->id(), std::move(event), ComputeBytesFromShape(shape), shape);
}
std::unique_ptr<BufferHandle> GrpcTpuStream::AllocateTuple(
@ -719,7 +706,7 @@ std::unique_ptr<BufferHandle> GrpcTpuStream::AllocateTuple(
auto event =
absl::make_unique<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
AddWriteRequest(std::move(req));
return absl::make_unique<GrpcBufferHandle>(event->id(), std::move(event));
return absl::make_unique<GrpcBufferHandle>(event->id(), std::move(event), 0);
}
std::unique_ptr<Event> GrpcTpuStream::Deallocate(
@ -736,13 +723,12 @@ std::unique_ptr<Event> GrpcTpuStream::Deallocate(
}
std::unique_ptr<Event> GrpcTpuStream::TransferToDevice(
const void* src, int64_t num_bytes, BufferHandle* dst,
absl::Span<Event* const> wait_for) {
const void* src, BufferHandle* dst, absl::Span<Event* const> wait_for) {
auto req = absl::make_unique<StreamRequest::Entry>();
InitializeRequest(req.get(), wait_for);
TraceMe activity(absl::StrCat("GrpcTpuStream::TransferToDevice"));
req->mutable_transfer_to()->mutable_data()->assign(
static_cast<const char*>(src), num_bytes);
static_cast<const char*>(src), dst->size_in_bytes());
req->mutable_transfer_to()->set_target_handle(
static_cast<GrpcBufferHandle*>(dst)->id().AsInt());
auto event =
@ -752,61 +738,16 @@ std::unique_ptr<Event> GrpcTpuStream::TransferToDevice(
}
std::unique_ptr<Event> GrpcTpuStream::TransferFromDevice(
const BufferHandle* src, void* dst, int64_t num_bytes,
absl::Span<Event* const> wait_for) {
const BufferHandle* src, void* dst, absl::Span<Event* const> wait_for) {
auto req = absl::make_unique<StreamRequest::Entry>();
InitializeRequest(req.get(), wait_for);
TraceMe activity(absl::StrCat("GrpcTpuStream::TransferFromDevice"));
req->mutable_transfer_from()->set_source_handle(
static_cast<const GrpcBufferHandle*>(src)->id().AsInt());
req->mutable_transfer_from()->set_length(num_bytes);
EventId event_id = EventId::FromInt(req->operation_id());
{
absl::MutexLock lock(&transfers_mutex_);
TransferInfo info(dst, num_bytes);
transfers_.insert(std::make_pair(event_id, info));
}
auto event = absl::make_unique<GrpcEvent>(event_id, this);
AddWriteRequest(std::move(req));
return event;
}
std::unique_ptr<Event> GrpcTpuStream::TransferToDevice(
const void* src, BufferHandle* dst, const xla::ShapeProto& shape,
absl::Span<Event* const> wait_for) {
auto req = absl::make_unique<StreamRequest::Entry>();
InitializeRequest(req.get(), wait_for);
TraceMe activity(absl::StrCat("GrpcTpuStream::TransferToDevice(shape)",
req->operation_id()));
req->mutable_transfer_to()->mutable_data()->assign(
static_cast<const char*>(src), ComputeBytesFromShape(shape));
req->mutable_transfer_to()->set_target_handle(
static_cast<GrpcBufferHandle*>(dst)->id().AsInt());
*req->mutable_transfer_to()->mutable_linearize_shape() = shape;
auto event =
absl::make_unique<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
AddWriteRequest(std::move(req));
return event;
}
std::unique_ptr<Event> GrpcTpuStream::TransferFromDevice(
const BufferHandle* src, void* dst, const xla::ShapeProto& shape,
absl::Span<Event* const> wait_for) {
auto req = absl::make_unique<StreamRequest::Entry>();
InitializeRequest(req.get(), wait_for);
TraceMe activity(absl::StrCat("GrpcTpuStream::TransferFromDevice(shape)",
req->operation_id()));
int bytes_expected = ComputeBytesFromShape(shape);
req->mutable_transfer_from()->set_source_handle(
static_cast<const GrpcBufferHandle*>(src)->id().AsInt());
req->mutable_transfer_from()->set_length(bytes_expected);
*req->mutable_transfer_from()->mutable_delinearize_shape() = shape;
EventId event_id = EventId::FromInt(req->operation_id());
{
absl::MutexLock lock(&transfers_mutex_);
TransferInfo info(dst, bytes_expected);
TransferInfo info(dst, const_cast<BufferHandle*>(src)->size_in_bytes());
transfers_.insert(std::make_pair(event_id, info));
}
auto event = absl::make_unique<GrpcEvent>(event_id, this);

View File

@ -77,6 +77,7 @@ class RecordingBufferHandle : public BufferHandle {
event_(std::make_shared<RecordingEvent>(handle_->OnReady(), id_)) {}
std::shared_ptr<Event> OnReady() override { return event_; }
int64_t size_in_bytes() override { return handle_->size_in_bytes(); }
std::optional<xla::ShapeProto> shape() override { return handle_->shape(); }
private:
std::unique_ptr<BufferHandle> handle_;
@ -257,24 +258,29 @@ class RecordingTpuDriver : public TpuDriver {
}
std::unique_ptr<Event> TransferToDevice(
const void* src, int64_t num_bytes, BufferHandle* dst,
const void* src, BufferHandle* dst,
absl::Span<Event* const> wait_for) override {
int64_t num_bytes = dst->size_in_bytes();
auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
auto thread_id = GetCurrentThreadId();
int64_t recording_handle_id = static_cast<RecordingBufferHandle*>(dst)->id_;
auto recording_handle = static_cast<RecordingBufferHandle*>(dst);
int64_t recording_handle_id = recording_handle->id_;
auto recording_event =
std::make_unique<RecordingEvent>(driver_->TransferToDevice(
src, num_bytes,
static_cast<RecordingBufferHandle*>(dst)->handle_.get(),
src, static_cast<RecordingBufferHandle*>(dst)->handle_.get(),
unwrapped_wait_for));
int64_t event_id = recording_event->id_;
{
StreamRequest::Entry r;
r.mutable_transfer_to()->set_target_handle(recording_handle_id);
r.mutable_transfer_to()->mutable_data()->assign(
static_cast<const char*>(src), num_bytes);
if (num_bytes > 0) {
r.mutable_transfer_to()->mutable_data()->assign(
static_cast<const char*>(src), num_bytes);
} else {
*r.mutable_transfer_to()->mutable_data() = "";
}
PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
}
@ -282,7 +288,7 @@ class RecordingTpuDriver : public TpuDriver {
}
std::unique_ptr<Event> TransferFromDevice(
const BufferHandle* src, void* dst, int64_t num_bytes,
const BufferHandle* src, void* dst,
absl::Span<Event* const> wait_for) override {
auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
@ -291,64 +297,12 @@ class RecordingTpuDriver : public TpuDriver {
auto recording_event =
std::make_unique<RecordingEvent>(driver_->TransferFromDevice(
static_cast<const RecordingBufferHandle*>(src)->handle_.get(), dst,
num_bytes, unwrapped_wait_for));
auto event_id = recording_event->id_;
{
StreamRequest::Entry r;
r.mutable_transfer_from()->set_source_handle(src_handle_id);
r.mutable_transfer_from()->set_length(num_bytes);
PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
}
return recording_event;
}
std::unique_ptr<Event> TransferToDevice(
const void* src, BufferHandle* dst, const xla::ShapeProto& shape,
absl::Span<Event* const> wait_for) override {
auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
int64_t shape_num_bytes = ComputeBytesFromShape(shape);
auto thread_id = GetCurrentThreadId();
int64_t recording_handle_id = static_cast<RecordingBufferHandle*>(dst)->id_;
auto recording_event =
std::make_unique<RecordingEvent>(driver_->TransferToDevice(
src, static_cast<RecordingBufferHandle*>(dst)->handle_.get(), shape,
unwrapped_wait_for));
int64_t handle_id = recording_event->id_;
{
StreamRequest::Entry r;
r.mutable_transfer_to()->set_target_handle(recording_handle_id);
*r.mutable_transfer_to()->mutable_linearize_shape() = shape;
r.mutable_transfer_to()->mutable_data()->assign(
static_cast<const char*>(src), shape_num_bytes);
PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
}
return recording_event;
}
std::unique_ptr<Event> TransferFromDevice(
const BufferHandle* src, void* dst, const xla::ShapeProto& shape,
absl::Span<Event* const> wait_for) override {
auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
auto thread_id = GetCurrentThreadId();
auto src_handle_id = static_cast<const RecordingBufferHandle*>(src)->id_;
auto recording_event =
std::make_unique<RecordingEvent>(driver_->TransferFromDevice(
static_cast<const RecordingBufferHandle*>(src)->handle_.get(), dst,
shape, unwrapped_wait_for));
auto event_id = recording_event->id_;
{
StreamRequest::Entry r;
r.mutable_transfer_from()->set_source_handle(src_handle_id);
r.mutable_transfer_from()->set_length(-1);
*r.mutable_transfer_from()->mutable_delinearize_shape() = shape;
PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
}

View File

@ -20,6 +20,7 @@
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <vector>
@ -73,6 +74,7 @@ class BufferHandle {
// automatically add this event as a dependency.
virtual std::shared_ptr<Event> OnReady() = 0;
virtual int64_t size_in_bytes() = 0;
virtual std::optional<xla::ShapeProto> shape() = 0;
};
// Represents a compiled program on the host.
@ -166,6 +168,7 @@ class TpuDriver {
virtual std::unique_ptr<BufferHandle> Allocate(
int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
absl::Span<Event* const> wait_for) = 0;
// Allocate a buffer representing a tuple of `children` buffers.
//
// The returned tuple buffer handle does not manage the memory of `children`:
@ -183,13 +186,8 @@ class TpuDriver {
std::unique_ptr<BufferHandle> handle,
absl::Span<Event* const> wait_for) = 0;
virtual std::unique_ptr<Event> TransferToDevice(
const void* src, int64_t num_bytes, BufferHandle* dst,
absl::Span<Event* const> wait_for) = 0;
virtual std::unique_ptr<Event> TransferFromDevice(
const BufferHandle* src, void* dst, int64_t num_bytes,
absl::Span<Event* const> wait_for) = 0;
/* `src` must be laid out in consecutive row-major format for ingestion, and
/* For buffers declared with an xla::ShapeProto rather than a raw size,
* `src` must be laid out in consecutive row-major format for ingestion, and
* each element must take up the number of bytes specified by the type.
*
* For example, if you have a [3,3,3] tensor with a Float32 type, then the
@ -207,10 +205,10 @@ class TpuDriver {
* `TransferFromDevice` will write out the shape back in this order as well.
*/
virtual std::unique_ptr<Event> TransferToDevice(
const void* src, BufferHandle* dst, const xla::ShapeProto& shape,
const void* src, BufferHandle* dst,
absl::Span<Event* const> wait_for) = 0;
virtual std::unique_ptr<Event> TransferFromDevice(
const BufferHandle* src, void* dst, const xla::ShapeProto& shape,
const BufferHandle* src, void* dst,
absl::Span<Event* const> wait_for) = 0;
virtual std::unique_ptr<Event> TransferFromDeviceToDevice(

View File

@ -47,16 +47,11 @@ message DeallocateRequest {
message TransferToDeviceRequest {
required int64 target_handle = 1;
required bytes data = 9;
optional xla.ShapeProto linearize_shape = 10;
required bytes data = 2;
}
message TransferFromDeviceRequest {
required int64 source_handle = 1;
required int64 length = 3;
optional xla.ShapeProto delinearize_shape = 4;
}
message TransferFromDeviceResponse {