Add TpuStream::EnqueueTransferHostToDevice and TpuStream::EnqueueTransferDeviceToHost
Also puts TpuStream into the tensorflow::tpu namespace. PiperOrigin-RevId: 330606385 Change-Id: Ib59b879b71192732ceafc0d1b47a269c2676889f
This commit is contained in:
parent
1832b255e0
commit
297286d734
tensorflow
core/tpu
stream_executor/tpu
@ -54,6 +54,8 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
|
||||
TFTPU_SET_FN(executor_fn, TpuStream_Stream);
|
||||
TFTPU_SET_FN(executor_fn, TpuStream_Status);
|
||||
TFTPU_SET_FN(executor_fn, TpuStream_IsSameSharedMemoryLocation);
|
||||
TFTPU_SET_FN(executor_fn, TpuStream_EnqueueTransferHostToDevice);
|
||||
TFTPU_SET_FN(executor_fn, TpuStream_EnqueueTransferDeviceToHost);
|
||||
TFTPU_SET_FN(executor_fn, TpuStream_TpuEnqueueOnDeviceSendRecvLocal);
|
||||
|
||||
TFTPU_SET_FN(executor_fn, TpuEvent_New);
|
||||
|
@ -42,7 +42,7 @@ static SE_ExecutableRunOptions ToC(
|
||||
se_options.device_ordinal = options.run_options().device_ordinal();
|
||||
if (options.run_options().host_to_device_stream() != nullptr) {
|
||||
se_options.host_to_device_stream =
|
||||
static_cast<TpuStream*>(
|
||||
static_cast<tensorflow::tpu::TpuStream*>(
|
||||
options.run_options().host_to_device_stream()->implementation())
|
||||
->se_stream();
|
||||
} else {
|
||||
@ -71,7 +71,8 @@ static SE_ExecutableRunOptions ToC(
|
||||
|
||||
auto impl =
|
||||
const_cast<stream_executor::Stream*>(options.stream())->implementation();
|
||||
se_options.stream = static_cast<TpuStream*>(impl)->se_stream();
|
||||
se_options.stream =
|
||||
static_cast<tensorflow::tpu::TpuStream*>(impl)->se_stream();
|
||||
return se_options;
|
||||
}
|
||||
} // namespace ApiConverter
|
||||
|
@ -173,7 +173,7 @@ TpuExecutor::GetTimerImplementation() {
|
||||
std::unique_ptr<::stream_executor::internal::StreamInterface>
|
||||
TpuExecutor::GetStreamImplementation() {
|
||||
SE_Stream* tpu_stream = tpu::ExecutorApiFn()->TpuStream_NewFn(executor_);
|
||||
auto ptr = absl::make_unique<TpuStream>(tpu_stream);
|
||||
auto ptr = absl::make_unique<tpu::TpuStream>(tpu_stream);
|
||||
tpu_platform().mutex().lock();
|
||||
stream_map()[ptr.get()] = tpu_stream;
|
||||
tpu_platform().mutex().unlock();
|
||||
|
@ -127,6 +127,14 @@ void TpuStream_Free(SE_Stream*);
|
||||
void* TpuStream_Stream(SE_Stream*);
|
||||
bool TpuStream_Status(SE_Stream*);
|
||||
bool TpuStream_IsSameSharedMemoryLocation(SE_Stream*, SE_Stream*);
|
||||
// C API entry point for a host-to-device transfer on `stream`.
//   stream:     the SE_Stream the transfer is issued on.
//   device_dst: destination device memory.
//   host_src:   source host pointer (declared non-const `void*` here).
//   size:       transfer size (presumably in bytes -- TODO confirm).
//   status:     SE_Status handle supplied by the caller; presumably receives
//               the outcome of the call -- confirm against the implementation.
void TpuStream_EnqueueTransferHostToDevice(SE_Stream* stream,
|
||||
SE_DeviceMemoryBase device_dst,
|
||||
void* host_src, uint64_t size,
|
||||
SE_Status* status);
|
||||
// C API entry point for a device-to-host transfer on `stream`.
//   stream:     the SE_Stream the transfer is issued on.
//   device_src: source device memory.
//   host_dst:   destination host pointer.
//   size:       transfer size (presumably in bytes -- TODO confirm).
//   status:     SE_Status handle supplied by the caller; presumably receives
//               the outcome of the call -- confirm against the implementation.
void TpuStream_EnqueueTransferDeviceToHost(SE_Stream* stream,
|
||||
SE_DeviceMemoryBase device_src,
|
||||
void* host_dst, uint64_t size,
|
||||
SE_Status* status);
|
||||
void TpuStream_TpuEnqueueOnDeviceSendRecvLocal(SE_Stream* stream,
|
||||
SE_DeviceMemoryBase send_buffer,
|
||||
SE_DeviceMemoryBase recv_buffer,
|
||||
@ -355,6 +363,8 @@ struct TfTpu_ExecutorApiFn {
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuStream_Stream);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuStream_Status);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuStream_IsSameSharedMemoryLocation);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuStream_EnqueueTransferHostToDevice);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuStream_EnqueueTransferDeviceToHost);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuStream_TpuEnqueueOnDeviceSendRecvLocal);
|
||||
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuEvent_New);
|
||||
|
@ -23,6 +23,9 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_stream_interface.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
class TpuStream : public tensorflow::tpu::TpuStreamInterface {
|
||||
public:
|
||||
using Status = stream_executor::port::Status;
|
||||
@ -39,6 +42,26 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface {
|
||||
stream_, static_cast<TpuStream*>(other)->stream_);
|
||||
}
|
||||
|
||||
// Forwards a host-to-device transfer request for this stream to the TPU C API
// (TpuStream_EnqueueTransferHostToDeviceFn).
//
// `device_dst` is converted to its C representation with ApiConverter::ToC.
// `host_src` has its constness cast away only because the C API declares the
// host pointer as a plain `void*`. `size` is passed through unchanged
// (presumably a byte count -- TODO confirm).
//
// Returns the Status produced by the StatusHelper that was handed to the C
// call. NOTE(review): the name suggests the transfer is only enqueued, not
// completed, when this returns -- confirm against the C implementation before
// relying on `host_src` lifetime assumptions.
Status EnqueueTransferHostToDevice(
|
||||
stream_executor::DeviceMemoryBase device_dst, const void* host_src,
|
||||
uint64 size) {
|
||||
StatusHelper status;
|
||||
tensorflow::tpu::ExecutorApiFn()->TpuStream_EnqueueTransferHostToDeviceFn(
|
||||
stream_, ApiConverter::ToC(device_dst), const_cast<void*>(host_src),
|
||||
size, status.c_status);
|
||||
return status.status();
|
||||
}
|
||||
|
||||
// Forwards a device-to-host transfer request for this stream to the TPU C API
// (TpuStream_EnqueueTransferDeviceToHostFn).
//
// `device_src` is converted to its C representation with ApiConverter::ToC;
// `host_dst` and `size` are passed through unchanged (`size` is presumably a
// byte count -- TODO confirm).
//
// Returns the Status produced by the StatusHelper that was handed to the C
// call. NOTE(review): the name suggests the transfer is only enqueued, so
// `host_dst` may not be populated yet when this returns -- confirm against
// the C implementation.
Status EnqueueTransferDeviceToHost(
|
||||
stream_executor::DeviceMemoryBase device_src, void* host_dst,
|
||||
uint64 size) {
|
||||
StatusHelper status;
|
||||
tensorflow::tpu::ExecutorApiFn()->TpuStream_EnqueueTransferDeviceToHostFn(
|
||||
stream_, ApiConverter::ToC(device_src), host_dst, size,
|
||||
status.c_status);
|
||||
return status.status();
|
||||
}
|
||||
|
||||
Status EnqueueOnTpuDeviceSendRecvLocal(
|
||||
stream_executor::DeviceMemoryBase send_buffer,
|
||||
stream_executor::DeviceMemoryBase recv_buffer) override {
|
||||
@ -56,4 +79,7 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface {
|
||||
mutable SE_Stream* stream_;
|
||||
};
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
|
||||
|
Loading…
Reference in New Issue
Block a user