From 297286d7347100975a328fb92c036e40576bf67d Mon Sep 17 00:00:00 2001
From: Skye Wanderman-Milne <skyewm@google.com>
Date: Tue, 8 Sep 2020 16:03:23 -0700
Subject: [PATCH] Add TpuStream::EnqueueTransferHostToDevice and
 TpuStream::EnqueueTransferDeviceToHost

Also puts TpuStream into the tensorflow::tpu namespace.

PiperOrigin-RevId: 330606385
Change-Id: Ib59b879b71192732ceafc0d1b47a269c2676889f
---
 tensorflow/core/tpu/tpu_executor_init_fns.inc |  2 ++
 tensorflow/core/tpu/tpu_on_demand_compiler.cc |  5 ++--
 .../stream_executor/tpu/tpu_executor.cc       |  2 +-
 .../stream_executor/tpu/tpu_executor_c_api.h  | 10 +++++++
 tensorflow/stream_executor/tpu/tpu_stream.h   | 26 +++++++++++++++++++
 5 files changed, 42 insertions(+), 3 deletions(-)

diff --git a/tensorflow/core/tpu/tpu_executor_init_fns.inc b/tensorflow/core/tpu/tpu_executor_init_fns.inc
index 4970292c499..8cc1a3c4d18 100644
--- a/tensorflow/core/tpu/tpu_executor_init_fns.inc
+++ b/tensorflow/core/tpu/tpu_executor_init_fns.inc
@@ -54,6 +54,8 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
   TFTPU_SET_FN(executor_fn, TpuStream_Stream);
   TFTPU_SET_FN(executor_fn, TpuStream_Status);
   TFTPU_SET_FN(executor_fn, TpuStream_IsSameSharedMemoryLocation);
+  TFTPU_SET_FN(executor_fn, TpuStream_EnqueueTransferHostToDevice);
+  TFTPU_SET_FN(executor_fn, TpuStream_EnqueueTransferDeviceToHost);
   TFTPU_SET_FN(executor_fn, TpuStream_TpuEnqueueOnDeviceSendRecvLocal);
 
   TFTPU_SET_FN(executor_fn, TpuEvent_New);
diff --git a/tensorflow/core/tpu/tpu_on_demand_compiler.cc b/tensorflow/core/tpu/tpu_on_demand_compiler.cc
index f918edfcffb..89dcdcaa0a8 100644
--- a/tensorflow/core/tpu/tpu_on_demand_compiler.cc
+++ b/tensorflow/core/tpu/tpu_on_demand_compiler.cc
@@ -42,7 +42,7 @@ static SE_ExecutableRunOptions ToC(
   se_options.device_ordinal = options.run_options().device_ordinal();
   if (options.run_options().host_to_device_stream() != nullptr) {
     se_options.host_to_device_stream =
-        static_cast<TpuStream*>(
+        static_cast<tensorflow::tpu::TpuStream*>(
             options.run_options().host_to_device_stream()->implementation())
             ->se_stream();
   } else {
@@ -71,7 +71,8 @@ static SE_ExecutableRunOptions ToC(
 
   auto impl =
       const_cast<stream_executor::Stream*>(options.stream())->implementation();
-  se_options.stream = static_cast<TpuStream*>(impl)->se_stream();
+  se_options.stream =
+      static_cast<tensorflow::tpu::TpuStream*>(impl)->se_stream();
   return se_options;
 }
 }  // namespace ApiConverter
diff --git a/tensorflow/stream_executor/tpu/tpu_executor.cc b/tensorflow/stream_executor/tpu/tpu_executor.cc
index 166deb716ca..841f16ebe0e 100644
--- a/tensorflow/stream_executor/tpu/tpu_executor.cc
+++ b/tensorflow/stream_executor/tpu/tpu_executor.cc
@@ -173,7 +173,7 @@ TpuExecutor::GetTimerImplementation() {
 std::unique_ptr<::stream_executor::internal::StreamInterface>
 TpuExecutor::GetStreamImplementation() {
   SE_Stream* tpu_stream = tpu::ExecutorApiFn()->TpuStream_NewFn(executor_);
-  auto ptr = absl::make_unique<TpuStream>(tpu_stream);
+  auto ptr = absl::make_unique<tpu::TpuStream>(tpu_stream);
   tpu_platform().mutex().lock();
   stream_map()[ptr.get()] = tpu_stream;
   tpu_platform().mutex().unlock();
diff --git a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h
index b59a8f2ad08..217d4fb5738 100644
--- a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h
+++ b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h
@@ -127,6 +127,14 @@ void TpuStream_Free(SE_Stream*);
 void* TpuStream_Stream(SE_Stream*);
 bool TpuStream_Status(SE_Stream*);
 bool TpuStream_IsSameSharedMemoryLocation(SE_Stream*, SE_Stream*);
+void TpuStream_EnqueueTransferHostToDevice(SE_Stream* stream,
+                                           SE_DeviceMemoryBase device_dst,
+                                           void* host_src, uint64_t size,
+                                           SE_Status* status);
+void TpuStream_EnqueueTransferDeviceToHost(SE_Stream* stream,
+                                           SE_DeviceMemoryBase device_src,
+                                           void* host_dst, uint64_t size,
+                                           SE_Status* status);
 void TpuStream_TpuEnqueueOnDeviceSendRecvLocal(SE_Stream* stream,
                                                SE_DeviceMemoryBase send_buffer,
                                                SE_DeviceMemoryBase recv_buffer,
@@ -355,6 +363,8 @@ struct TfTpu_ExecutorApiFn {
   TFTPU_ADD_FN_IN_STRUCT(TpuStream_Stream);
   TFTPU_ADD_FN_IN_STRUCT(TpuStream_Status);
   TFTPU_ADD_FN_IN_STRUCT(TpuStream_IsSameSharedMemoryLocation);
+  TFTPU_ADD_FN_IN_STRUCT(TpuStream_EnqueueTransferHostToDevice);
+  TFTPU_ADD_FN_IN_STRUCT(TpuStream_EnqueueTransferDeviceToHost);
   TFTPU_ADD_FN_IN_STRUCT(TpuStream_TpuEnqueueOnDeviceSendRecvLocal);
 
   TFTPU_ADD_FN_IN_STRUCT(TpuEvent_New);
diff --git a/tensorflow/stream_executor/tpu/tpu_stream.h b/tensorflow/stream_executor/tpu/tpu_stream.h
index ab84005c718..cd7637d12c6 100644
--- a/tensorflow/stream_executor/tpu/tpu_stream.h
+++ b/tensorflow/stream_executor/tpu/tpu_stream.h
@@ -23,6 +23,9 @@ limitations under the License.
 #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
 #include "tensorflow/stream_executor/tpu/tpu_stream_interface.h"
 
+namespace tensorflow {
+namespace tpu {
+
 class TpuStream : public tensorflow::tpu::TpuStreamInterface {
  public:
   using Status = stream_executor::port::Status;
@@ -39,6 +42,26 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface {
             stream_, static_cast<TpuStream*>(other)->stream_);
   }
 
+  Status EnqueueTransferHostToDevice(
+      stream_executor::DeviceMemoryBase device_dst, const void* host_src,
+      uint64 size) {
+    StatusHelper status;
+    tensorflow::tpu::ExecutorApiFn()->TpuStream_EnqueueTransferHostToDeviceFn(
+        stream_, ApiConverter::ToC(device_dst), const_cast<void*>(host_src),
+        size, status.c_status);
+    return status.status();
+  }
+
+  Status EnqueueTransferDeviceToHost(
+      stream_executor::DeviceMemoryBase device_src, void* host_dst,
+      uint64 size) {
+    StatusHelper status;
+    tensorflow::tpu::ExecutorApiFn()->TpuStream_EnqueueTransferDeviceToHostFn(
+        stream_, ApiConverter::ToC(device_src), host_dst, size,
+        status.c_status);
+    return status.status();
+  }
+
   Status EnqueueOnTpuDeviceSendRecvLocal(
       stream_executor::DeviceMemoryBase send_buffer,
       stream_executor::DeviceMemoryBase recv_buffer) override {
@@ -56,4 +79,7 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface {
   mutable SE_Stream* stream_;
 };
 
+}  // namespace tpu
+}  // namespace tensorflow
+
 #endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_