diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index e57fc48c11e..911a7b16096 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -165,8 +165,8 @@ StatusOr> PyTpuBuffer::FromLiterals( tuple_shape, [driver, &leaves, &tuple_shape, leaves_references](tpu_driver::BufferHandle* handle) { - auto event = driver->TransferToDevice( - leaves[0].untyped_data(), handle, tuple_shape.ToProto(), {}); + auto event = + driver->TransferToDevice(leaves[0].untyped_data(), handle, {}); event->AddCallback([leaves_references](Status) {}); return event; }, @@ -188,9 +188,7 @@ StatusOr> PyTpuBuffer::FromLiterals( CreateBuffer( indexed_shape.shape, [driver, &leaf, &indexed_shape](tpu_driver::BufferHandle* handle) { - return driver->TransferToDevice(leaf.untyped_data(), handle, - indexed_shape.shape.ToProto(), - {}); + return driver->TransferToDevice(leaf.untyped_data(), handle, {}); }, client, device_ordinal)); child_buffer_ptrs.push_back(child_buffer.get()); @@ -290,14 +288,13 @@ Status PyTpuBuffer::CopyToHostAsync() { CHECK(child_buffers_.empty()); transfer_events.push_back(client_->driver()->TransferFromDevice( device_buffer_->handle.get(), host_value->value->untyped_data(), - host_value->value->shape().ToProto(), events)); + events)); } else { for (int i = 0; i < child_buffers_.size(); ++i) { auto& c = child_buffers_[i]; transfer_events.push_back(client_->driver()->TransferFromDevice( c->handle.get(), - host_value->value->untyped_data(xla::ShapeIndex({i})), - host_value->value->shape().tuple_shapes(i).ToProto(), events)); + host_value->value->untyped_data(xla::ShapeIndex({i})), events)); } } } diff --git a/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc index 0477dbbdbd4..591792974aa 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc @@ -66,22 +66,29 @@ class GrpcEvent : public Event { class GrpcBufferHandle : public BufferHandle { public: - explicit GrpcBufferHandle(EventId id, std::shared_ptr event) - : id_(id), stream_(event->stream()), event_(std::move(event)) {} + explicit GrpcBufferHandle(EventId id, std::shared_ptr event, + int64_t bytes, + std::optional shape = std::nullopt) + : id_(id), + stream_(event->stream()), + event_(std::move(event)), + bytes_(bytes), + shape_(shape) {} std::shared_ptr OnReady() override { return event_; } - int64_t size_in_bytes() override { - LOG(FATAL) << "Unimplemented."; - return 0; - } + int64_t size_in_bytes() override { return bytes_; } EventId id() const { return id_; } GrpcTpuStream* stream() const { return stream_; } + std::optional shape() override { return shape_; } + private: const EventId id_; GrpcTpuStream* stream_; std::shared_ptr event_; + int64_t bytes_; + std::optional shape_; }; class GrpcCompiledProgramHandle : public CompiledProgramHandle { @@ -160,18 +167,9 @@ class GrpcTpuStream { std::unique_ptr Deallocate(std::unique_ptr handle, absl::Span wait_for); - std::unique_ptr TransferToDevice(const void* src, int64_t num_bytes, - BufferHandle* dst, - absl::Span wait_for); - std::unique_ptr TransferFromDevice(const BufferHandle* src, void* dst, - int64_t num_bytes, - absl::Span wait_for); - std::unique_ptr TransferToDevice(const void* src, BufferHandle* dst, - const xla::ShapeProto& shape, absl::Span wait_for); std::unique_ptr TransferFromDevice(const BufferHandle* src, void* dst, - const xla::ShapeProto& shape, absl::Span wait_for); std::unique_ptr TransferFromDeviceToDevice( @@ -353,29 +351,16 @@ class GrpcTpuDriver : public TpuDriver { } std::unique_ptr TransferToDevice( - const void* src, int64_t num_bytes, BufferHandle* dst, + const void* src, BufferHandle* dst, absl::Span wait_for) override { auto* stream = static_cast(dst)->stream(); - return stream->TransferToDevice(src, num_bytes, dst, wait_for); + return stream->TransferToDevice(src, dst, wait_for); } std::unique_ptr TransferFromDevice( - const BufferHandle* src, void* dst, int64_t num_bytes, + const BufferHandle* src, void* dst, absl::Span wait_for) override { auto* stream = static_cast(src)->stream(); - return stream->TransferFromDevice(src, dst, num_bytes, wait_for); - } - - std::unique_ptr TransferToDevice( - const void* src, BufferHandle* dst, const xla::ShapeProto& shape, - absl::Span wait_for) override { - auto* stream = static_cast(dst)->stream(); - return stream->TransferToDevice(src, dst, shape, wait_for); - } - std::unique_ptr TransferFromDevice( - const BufferHandle* src, void* dst, const xla::ShapeProto& shape, - absl::Span wait_for) override { - auto* stream = static_cast(src)->stream(); - return stream->TransferFromDevice(src, dst, shape, wait_for); + return stream->TransferFromDevice(src, dst, wait_for); } std::unique_ptr TransferFromDeviceToDevice( @@ -685,7 +670,8 @@ std::unique_ptr GrpcTpuStream::Allocate( auto event = absl::make_unique(EventId::FromInt(req->operation_id()), this); AddWriteRequest(std::move(req)); - return absl::make_unique(event->id(), std::move(event)); + return absl::make_unique(event->id(), std::move(event), + num_bytes); } std::unique_ptr GrpcTpuStream::Allocate( @@ -700,7 +686,8 @@ std::unique_ptr GrpcTpuStream::Allocate( auto event = absl::make_unique(EventId::FromInt(req->operation_id()), this); AddWriteRequest(std::move(req)); - return absl::make_unique(event->id(), std::move(event)); + return absl::make_unique( + event->id(), std::move(event), ComputeBytesFromShape(shape), shape); } std::unique_ptr GrpcTpuStream::AllocateTuple( @@ -719,7 +706,7 @@ std::unique_ptr GrpcTpuStream::AllocateTuple( auto event = absl::make_unique(EventId::FromInt(req->operation_id()), this); AddWriteRequest(std::move(req)); - return absl::make_unique(event->id(), std::move(event)); + return absl::make_unique(event->id(), std::move(event), 0); } std::unique_ptr GrpcTpuStream::Deallocate( @@ -736,13 +723,12 @@ std::unique_ptr GrpcTpuStream::Deallocate( } std::unique_ptr GrpcTpuStream::TransferToDevice( - const void* src, int64_t num_bytes, BufferHandle* dst, - absl::Span wait_for) { + const void* src, BufferHandle* dst, absl::Span wait_for) { auto req = absl::make_unique(); InitializeRequest(req.get(), wait_for); TraceMe activity(absl::StrCat("GrpcTpuStream::TransferToDevice")); req->mutable_transfer_to()->mutable_data()->assign( - static_cast(src), num_bytes); + static_cast(src), dst->size_in_bytes()); req->mutable_transfer_to()->set_target_handle( static_cast(dst)->id().AsInt()); auto event = @@ -752,61 +738,16 @@ std::unique_ptr GrpcTpuStream::TransferToDevice( } std::unique_ptr GrpcTpuStream::TransferFromDevice( - const BufferHandle* src, void* dst, int64_t num_bytes, - absl::Span wait_for) { + const BufferHandle* src, void* dst, absl::Span wait_for) { auto req = absl::make_unique(); InitializeRequest(req.get(), wait_for); TraceMe activity(absl::StrCat("GrpcTpuStream::TransferFromDevice")); req->mutable_transfer_from()->set_source_handle( static_cast(src)->id().AsInt()); - req->mutable_transfer_from()->set_length(num_bytes); EventId event_id = EventId::FromInt(req->operation_id()); { absl::MutexLock lock(&transfers_mutex_); - TransferInfo info(dst, num_bytes); - transfers_.insert(std::make_pair(event_id, info)); - } - auto event = absl::make_unique(event_id, this); - AddWriteRequest(std::move(req)); - return event; -} - -std::unique_ptr GrpcTpuStream::TransferToDevice( - const void* src, BufferHandle* dst, const xla::ShapeProto& shape, - absl::Span wait_for) { - auto req = absl::make_unique(); - InitializeRequest(req.get(), wait_for); - - TraceMe activity(absl::StrCat("GrpcTpuStream::TransferToDevice(shape)", - req->operation_id())); - req->mutable_transfer_to()->mutable_data()->assign( - static_cast(src), ComputeBytesFromShape(shape)); - req->mutable_transfer_to()->set_target_handle( - static_cast(dst)->id().AsInt()); - *req->mutable_transfer_to()->mutable_linearize_shape() = shape; - auto event = - absl::make_unique(EventId::FromInt(req->operation_id()), this); - AddWriteRequest(std::move(req)); - return event; -} - -std::unique_ptr GrpcTpuStream::TransferFromDevice( - const BufferHandle* src, void* dst, const xla::ShapeProto& shape, - absl::Span wait_for) { - auto req = absl::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity(absl::StrCat("GrpcTpuStream::TransferFromDevice(shape)", - req->operation_id())); - - int bytes_expected = ComputeBytesFromShape(shape); - req->mutable_transfer_from()->set_source_handle( - static_cast(src)->id().AsInt()); - req->mutable_transfer_from()->set_length(bytes_expected); - *req->mutable_transfer_from()->mutable_delinearize_shape() = shape; - EventId event_id = EventId::FromInt(req->operation_id()); - { - absl::MutexLock lock(&transfers_mutex_); - TransferInfo info(dst, bytes_expected); + TransferInfo info(dst, const_cast(src)->size_in_bytes()); transfers_.insert(std::make_pair(event_id, info)); } auto event = absl::make_unique(event_id, this); diff --git a/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc index 0edaf428de8..d0ed5cb2bc1 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc @@ -77,6 +77,7 @@ class RecordingBufferHandle : public BufferHandle { event_(std::make_shared(handle_->OnReady(), id_)) {} std::shared_ptr OnReady() override { return event_; } int64_t size_in_bytes() override { return handle_->size_in_bytes(); } + std::optional shape() override { return handle_->shape(); } private: std::unique_ptr handle_; @@ -257,24 +258,29 @@ class RecordingTpuDriver : public TpuDriver { } std::unique_ptr TransferToDevice( - const void* src, int64_t num_bytes, BufferHandle* dst, + const void* src, BufferHandle* dst, absl::Span wait_for) override { + int64_t num_bytes = dst->size_in_bytes(); auto unwrapped_wait_for = UnwrapWaitFor(wait_for); auto thread_id = GetCurrentThreadId(); - int64_t recording_handle_id = static_cast(dst)->id_; + auto recording_handle = static_cast(dst); + int64_t recording_handle_id = recording_handle->id_; auto recording_event = std::make_unique(driver_->TransferToDevice( - src, num_bytes, - static_cast(dst)->handle_.get(), + src, static_cast(dst)->handle_.get(), unwrapped_wait_for)); int64_t event_id = recording_event->id_; { StreamRequest::Entry r; r.mutable_transfer_to()->set_target_handle(recording_handle_id); - r.mutable_transfer_to()->mutable_data()->assign( - static_cast(src), num_bytes); + if (num_bytes > 0) { + r.mutable_transfer_to()->mutable_data()->assign( + static_cast(src), num_bytes); + } else { + *r.mutable_transfer_to()->mutable_data() = ""; + } PopulateAndSaveEntry(&r, wait_for, event_id, thread_id); } @@ -282,7 +288,7 @@ class RecordingTpuDriver : public TpuDriver { } std::unique_ptr TransferFromDevice( - const BufferHandle* src, void* dst, int64_t num_bytes, + const BufferHandle* src, void* dst, absl::Span wait_for) override { auto unwrapped_wait_for = UnwrapWaitFor(wait_for); @@ -291,64 +297,12 @@ class RecordingTpuDriver : public TpuDriver { auto recording_event = std::make_unique(driver_->TransferFromDevice( static_cast(src)->handle_.get(), dst, - num_bytes, unwrapped_wait_for)); - auto event_id = recording_event->id_; - - { - StreamRequest::Entry r; - r.mutable_transfer_from()->set_source_handle(src_handle_id); - r.mutable_transfer_from()->set_length(num_bytes); - PopulateAndSaveEntry(&r, wait_for, event_id, thread_id); - } - - return recording_event; - } - - std::unique_ptr TransferToDevice( - const void* src, BufferHandle* dst, const xla::ShapeProto& shape, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - int64_t shape_num_bytes = ComputeBytesFromShape(shape); - - auto thread_id = GetCurrentThreadId(); - int64_t recording_handle_id = static_cast(dst)->id_; - auto recording_event = - std::make_unique(driver_->TransferToDevice( - src, static_cast(dst)->handle_.get(), shape, unwrapped_wait_for)); - int64_t handle_id = recording_event->id_; - - { - StreamRequest::Entry r; - r.mutable_transfer_to()->set_target_handle(recording_handle_id); - *r.mutable_transfer_to()->mutable_linearize_shape() = shape; - r.mutable_transfer_to()->mutable_data()->assign( - static_cast(src), shape_num_bytes); - PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id); - } - - return recording_event; - } - - std::unique_ptr TransferFromDevice( - const BufferHandle* src, void* dst, const xla::ShapeProto& shape, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto src_handle_id = static_cast(src)->id_; - auto recording_event = - std::make_unique(driver_->TransferFromDevice( - static_cast(src)->handle_.get(), dst, - shape, unwrapped_wait_for)); auto event_id = recording_event->id_; { StreamRequest::Entry r; r.mutable_transfer_from()->set_source_handle(src_handle_id); - r.mutable_transfer_from()->set_length(-1); - *r.mutable_transfer_from()->mutable_delinearize_shape() = shape; PopulateAndSaveEntry(&r, wait_for, event_id, thread_id); } diff --git a/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h b/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h index 5edd93fa4a9..2a93de8b6e5 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h +++ b/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -73,6 +74,7 @@ class BufferHandle { // automatically add this event as a dependency. virtual std::shared_ptr OnReady() = 0; virtual int64_t size_in_bytes() = 0; + virtual std::optional shape() = 0; }; // Represents a compiled program on the host. @@ -166,6 +168,7 @@ class TpuDriver { virtual std::unique_ptr Allocate( int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape, absl::Span wait_for) = 0; + // Allocate a buffer representing a tuple of `children` buffers. // // The returned tuple buffer handle does not manage the memory of `children`: @@ -183,13 +186,8 @@ class TpuDriver { std::unique_ptr handle, absl::Span wait_for) = 0; - virtual std::unique_ptr TransferToDevice( - const void* src, int64_t num_bytes, BufferHandle* dst, - absl::Span wait_for) = 0; - virtual std::unique_ptr TransferFromDevice( - const BufferHandle* src, void* dst, int64_t num_bytes, - absl::Span wait_for) = 0; - /* `src` must be laid out in consecutive row-major format for ingestion, and + /* For buffers declared with an xla::ShapeProto rather than a raw size, + * `src` must be laid out in consecutive row-major format for ingestion, and * each element must take up the number of bytes specified by the type. * * For example, if you have a [3,3,3] tensor with a Float32 type, then the @@ -207,10 +205,10 @@ class TpuDriver { * `TransferFromDevice` will write out the shape back in this order as well. */ virtual std::unique_ptr TransferToDevice( - const void* src, BufferHandle* dst, const xla::ShapeProto& shape, + const void* src, BufferHandle* dst, absl::Span wait_for) = 0; virtual std::unique_ptr TransferFromDevice( - const BufferHandle* src, void* dst, const xla::ShapeProto& shape, + const BufferHandle* src, void* dst, absl::Span wait_for) = 0; virtual std::unique_ptr TransferFromDeviceToDevice( diff --git a/tensorflow/compiler/xla/python/tpu_driver/tpu_service.proto b/tensorflow/compiler/xla/python/tpu_driver/tpu_service.proto index f46708b5069..3b9b69e7cb4 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/tpu_service.proto +++ b/tensorflow/compiler/xla/python/tpu_driver/tpu_service.proto @@ -47,16 +47,11 @@ message DeallocateRequest { message TransferToDeviceRequest { required int64 target_handle = 1; - required bytes data = 9; - - optional xla.ShapeProto linearize_shape = 10; + required bytes data = 2; } message TransferFromDeviceRequest { required int64 source_handle = 1; - required int64 length = 3; - - optional xla.ShapeProto delinearize_shape = 4; } message TransferFromDeviceResponse {