diff --git a/tensorflow/core/tpu/tpu_executor_init_fns.inc b/tensorflow/core/tpu/tpu_executor_init_fns.inc index 4970292c499..8cc1a3c4d18 100644 --- a/tensorflow/core/tpu/tpu_executor_init_fns.inc +++ b/tensorflow/core/tpu/tpu_executor_init_fns.inc @@ -54,6 +54,8 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) { TFTPU_SET_FN(executor_fn, TpuStream_Stream); TFTPU_SET_FN(executor_fn, TpuStream_Status); TFTPU_SET_FN(executor_fn, TpuStream_IsSameSharedMemoryLocation); + TFTPU_SET_FN(executor_fn, TpuStream_EnqueueTransferHostToDevice); + TFTPU_SET_FN(executor_fn, TpuStream_EnqueueTransferDeviceToHost); TFTPU_SET_FN(executor_fn, TpuStream_TpuEnqueueOnDeviceSendRecvLocal); TFTPU_SET_FN(executor_fn, TpuEvent_New); diff --git a/tensorflow/core/tpu/tpu_on_demand_compiler.cc b/tensorflow/core/tpu/tpu_on_demand_compiler.cc index f918edfcffb..89dcdcaa0a8 100644 --- a/tensorflow/core/tpu/tpu_on_demand_compiler.cc +++ b/tensorflow/core/tpu/tpu_on_demand_compiler.cc @@ -42,7 +42,7 @@ static SE_ExecutableRunOptions ToC( se_options.device_ordinal = options.run_options().device_ordinal(); if (options.run_options().host_to_device_stream() != nullptr) { se_options.host_to_device_stream = - static_cast( + static_cast( options.run_options().host_to_device_stream()->implementation()) ->se_stream(); } else { @@ -71,7 +71,8 @@ static SE_ExecutableRunOptions ToC( auto impl = const_cast(options.stream())->implementation(); - se_options.stream = static_cast(impl)->se_stream(); + se_options.stream = + static_cast(impl)->se_stream(); return se_options; } } // namespace ApiConverter diff --git a/tensorflow/stream_executor/tpu/tpu_executor.cc b/tensorflow/stream_executor/tpu/tpu_executor.cc index 166deb716ca..841f16ebe0e 100644 --- a/tensorflow/stream_executor/tpu/tpu_executor.cc +++ b/tensorflow/stream_executor/tpu/tpu_executor.cc @@ -173,7 +173,7 @@ TpuExecutor::GetTimerImplementation() { std::unique_ptr<::stream_executor::internal::StreamInterface> TpuExecutor::GetStreamImplementation() { SE_Stream* tpu_stream = tpu::ExecutorApiFn()->TpuStream_NewFn(executor_); - auto ptr = absl::make_unique(tpu_stream); + auto ptr = absl::make_unique(tpu_stream); tpu_platform().mutex().lock(); stream_map()[ptr.get()] = tpu_stream; tpu_platform().mutex().unlock(); diff --git a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h index b59a8f2ad08..217d4fb5738 100644 --- a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h +++ b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h @@ -127,6 +127,14 @@ void TpuStream_Free(SE_Stream*); void* TpuStream_Stream(SE_Stream*); bool TpuStream_Status(SE_Stream*); bool TpuStream_IsSameSharedMemoryLocation(SE_Stream*, SE_Stream*); +void TpuStream_EnqueueTransferHostToDevice(SE_Stream* stream, + SE_DeviceMemoryBase device_dst, + void* host_src, uint64_t size, + SE_Status* status); +void TpuStream_EnqueueTransferDeviceToHost(SE_Stream* stream, + SE_DeviceMemoryBase device_src, + void* host_dst, uint64_t size, + SE_Status* status); void TpuStream_TpuEnqueueOnDeviceSendRecvLocal(SE_Stream* stream, SE_DeviceMemoryBase send_buffer, SE_DeviceMemoryBase recv_buffer, @@ -355,6 +363,8 @@ struct TfTpu_ExecutorApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuStream_Stream); TFTPU_ADD_FN_IN_STRUCT(TpuStream_Status); TFTPU_ADD_FN_IN_STRUCT(TpuStream_IsSameSharedMemoryLocation); + TFTPU_ADD_FN_IN_STRUCT(TpuStream_EnqueueTransferHostToDevice); + TFTPU_ADD_FN_IN_STRUCT(TpuStream_EnqueueTransferDeviceToHost); TFTPU_ADD_FN_IN_STRUCT(TpuStream_TpuEnqueueOnDeviceSendRecvLocal); TFTPU_ADD_FN_IN_STRUCT(TpuEvent_New); diff --git a/tensorflow/stream_executor/tpu/tpu_stream.h b/tensorflow/stream_executor/tpu/tpu_stream.h index ab84005c718..cd7637d12c6 100644 --- a/tensorflow/stream_executor/tpu/tpu_stream.h +++ b/tensorflow/stream_executor/tpu/tpu_stream.h @@ -23,6 +23,9 @@ limitations under the License. #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" #include "tensorflow/stream_executor/tpu/tpu_stream_interface.h" +namespace tensorflow { +namespace tpu { + class TpuStream : public tensorflow::tpu::TpuStreamInterface { public: using Status = stream_executor::port::Status; @@ -39,6 +42,26 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface { stream_, static_cast(other)->stream_); } + Status EnqueueTransferHostToDevice( + stream_executor::DeviceMemoryBase device_dst, const void* host_src, + uint64 size) { + StatusHelper status; + tensorflow::tpu::ExecutorApiFn()->TpuStream_EnqueueTransferHostToDeviceFn( + stream_, ApiConverter::ToC(device_dst), const_cast(host_src), + size, status.c_status); + return status.status(); + } + + Status EnqueueTransferDeviceToHost( + stream_executor::DeviceMemoryBase device_src, void* host_dst, + uint64 size) { + StatusHelper status; + tensorflow::tpu::ExecutorApiFn()->TpuStream_EnqueueTransferDeviceToHostFn( + stream_, ApiConverter::ToC(device_src), host_dst, size, + status.c_status); + return status.status(); + } + Status EnqueueOnTpuDeviceSendRecvLocal( stream_executor::DeviceMemoryBase send_buffer, stream_executor::DeviceMemoryBase recv_buffer) override { @@ -56,4 +79,7 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface { mutable SE_Stream* stream_; }; +} // namespace tpu +} // namespace tensorflow + #endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_