Add TpuStream::EnqueueTransferHostToDevice and TpuStream::EnqueueTransferDeviceToHost
Also puts TpuStream into the tensorflow::tpu namespace. PiperOrigin-RevId: 330606385 Change-Id: Ib59b879b71192732ceafc0d1b47a269c2676889f
This commit is contained in:
parent
1832b255e0
commit
297286d734
@ -54,6 +54,8 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
|
|||||||
TFTPU_SET_FN(executor_fn, TpuStream_Stream);
|
TFTPU_SET_FN(executor_fn, TpuStream_Stream);
|
||||||
TFTPU_SET_FN(executor_fn, TpuStream_Status);
|
TFTPU_SET_FN(executor_fn, TpuStream_Status);
|
||||||
TFTPU_SET_FN(executor_fn, TpuStream_IsSameSharedMemoryLocation);
|
TFTPU_SET_FN(executor_fn, TpuStream_IsSameSharedMemoryLocation);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuStream_EnqueueTransferHostToDevice);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuStream_EnqueueTransferDeviceToHost);
|
||||||
TFTPU_SET_FN(executor_fn, TpuStream_TpuEnqueueOnDeviceSendRecvLocal);
|
TFTPU_SET_FN(executor_fn, TpuStream_TpuEnqueueOnDeviceSendRecvLocal);
|
||||||
|
|
||||||
TFTPU_SET_FN(executor_fn, TpuEvent_New);
|
TFTPU_SET_FN(executor_fn, TpuEvent_New);
|
||||||
|
|||||||
@ -42,7 +42,7 @@ static SE_ExecutableRunOptions ToC(
|
|||||||
se_options.device_ordinal = options.run_options().device_ordinal();
|
se_options.device_ordinal = options.run_options().device_ordinal();
|
||||||
if (options.run_options().host_to_device_stream() != nullptr) {
|
if (options.run_options().host_to_device_stream() != nullptr) {
|
||||||
se_options.host_to_device_stream =
|
se_options.host_to_device_stream =
|
||||||
static_cast<TpuStream*>(
|
static_cast<tensorflow::tpu::TpuStream*>(
|
||||||
options.run_options().host_to_device_stream()->implementation())
|
options.run_options().host_to_device_stream()->implementation())
|
||||||
->se_stream();
|
->se_stream();
|
||||||
} else {
|
} else {
|
||||||
@ -71,7 +71,8 @@ static SE_ExecutableRunOptions ToC(
|
|||||||
|
|
||||||
auto impl =
|
auto impl =
|
||||||
const_cast<stream_executor::Stream*>(options.stream())->implementation();
|
const_cast<stream_executor::Stream*>(options.stream())->implementation();
|
||||||
se_options.stream = static_cast<TpuStream*>(impl)->se_stream();
|
se_options.stream =
|
||||||
|
static_cast<tensorflow::tpu::TpuStream*>(impl)->se_stream();
|
||||||
return se_options;
|
return se_options;
|
||||||
}
|
}
|
||||||
} // namespace ApiConverter
|
} // namespace ApiConverter
|
||||||
|
|||||||
@ -173,7 +173,7 @@ TpuExecutor::GetTimerImplementation() {
|
|||||||
std::unique_ptr<::stream_executor::internal::StreamInterface>
|
std::unique_ptr<::stream_executor::internal::StreamInterface>
|
||||||
TpuExecutor::GetStreamImplementation() {
|
TpuExecutor::GetStreamImplementation() {
|
||||||
SE_Stream* tpu_stream = tpu::ExecutorApiFn()->TpuStream_NewFn(executor_);
|
SE_Stream* tpu_stream = tpu::ExecutorApiFn()->TpuStream_NewFn(executor_);
|
||||||
auto ptr = absl::make_unique<TpuStream>(tpu_stream);
|
auto ptr = absl::make_unique<tpu::TpuStream>(tpu_stream);
|
||||||
tpu_platform().mutex().lock();
|
tpu_platform().mutex().lock();
|
||||||
stream_map()[ptr.get()] = tpu_stream;
|
stream_map()[ptr.get()] = tpu_stream;
|
||||||
tpu_platform().mutex().unlock();
|
tpu_platform().mutex().unlock();
|
||||||
|
|||||||
@ -127,6 +127,14 @@ void TpuStream_Free(SE_Stream*);
|
|||||||
void* TpuStream_Stream(SE_Stream*);
|
void* TpuStream_Stream(SE_Stream*);
|
||||||
bool TpuStream_Status(SE_Stream*);
|
bool TpuStream_Status(SE_Stream*);
|
||||||
bool TpuStream_IsSameSharedMemoryLocation(SE_Stream*, SE_Stream*);
|
bool TpuStream_IsSameSharedMemoryLocation(SE_Stream*, SE_Stream*);
|
||||||
|
void TpuStream_EnqueueTransferHostToDevice(SE_Stream* stream,
|
||||||
|
SE_DeviceMemoryBase device_dst,
|
||||||
|
void* host_src, uint64_t size,
|
||||||
|
SE_Status* status);
|
||||||
|
void TpuStream_EnqueueTransferDeviceToHost(SE_Stream* stream,
|
||||||
|
SE_DeviceMemoryBase device_src,
|
||||||
|
void* host_dst, uint64_t size,
|
||||||
|
SE_Status* status);
|
||||||
void TpuStream_TpuEnqueueOnDeviceSendRecvLocal(SE_Stream* stream,
|
void TpuStream_TpuEnqueueOnDeviceSendRecvLocal(SE_Stream* stream,
|
||||||
SE_DeviceMemoryBase send_buffer,
|
SE_DeviceMemoryBase send_buffer,
|
||||||
SE_DeviceMemoryBase recv_buffer,
|
SE_DeviceMemoryBase recv_buffer,
|
||||||
@ -355,6 +363,8 @@ struct TfTpu_ExecutorApiFn {
|
|||||||
TFTPU_ADD_FN_IN_STRUCT(TpuStream_Stream);
|
TFTPU_ADD_FN_IN_STRUCT(TpuStream_Stream);
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuStream_Status);
|
TFTPU_ADD_FN_IN_STRUCT(TpuStream_Status);
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuStream_IsSameSharedMemoryLocation);
|
TFTPU_ADD_FN_IN_STRUCT(TpuStream_IsSameSharedMemoryLocation);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuStream_EnqueueTransferHostToDevice);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuStream_EnqueueTransferDeviceToHost);
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuStream_TpuEnqueueOnDeviceSendRecvLocal);
|
TFTPU_ADD_FN_IN_STRUCT(TpuStream_TpuEnqueueOnDeviceSendRecvLocal);
|
||||||
|
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuEvent_New);
|
TFTPU_ADD_FN_IN_STRUCT(TpuEvent_New);
|
||||||
|
|||||||
@ -23,6 +23,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_stream_interface.h"
|
#include "tensorflow/stream_executor/tpu/tpu_stream_interface.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace tpu {
|
||||||
|
|
||||||
class TpuStream : public tensorflow::tpu::TpuStreamInterface {
|
class TpuStream : public tensorflow::tpu::TpuStreamInterface {
|
||||||
public:
|
public:
|
||||||
using Status = stream_executor::port::Status;
|
using Status = stream_executor::port::Status;
|
||||||
@ -39,6 +42,26 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface {
|
|||||||
stream_, static_cast<TpuStream*>(other)->stream_);
|
stream_, static_cast<TpuStream*>(other)->stream_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status EnqueueTransferHostToDevice(
|
||||||
|
stream_executor::DeviceMemoryBase device_dst, const void* host_src,
|
||||||
|
uint64 size) {
|
||||||
|
StatusHelper status;
|
||||||
|
tensorflow::tpu::ExecutorApiFn()->TpuStream_EnqueueTransferHostToDeviceFn(
|
||||||
|
stream_, ApiConverter::ToC(device_dst), const_cast<void*>(host_src),
|
||||||
|
size, status.c_status);
|
||||||
|
return status.status();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status EnqueueTransferDeviceToHost(
|
||||||
|
stream_executor::DeviceMemoryBase device_src, void* host_dst,
|
||||||
|
uint64 size) {
|
||||||
|
StatusHelper status;
|
||||||
|
tensorflow::tpu::ExecutorApiFn()->TpuStream_EnqueueTransferDeviceToHostFn(
|
||||||
|
stream_, ApiConverter::ToC(device_src), host_dst, size,
|
||||||
|
status.c_status);
|
||||||
|
return status.status();
|
||||||
|
}
|
||||||
|
|
||||||
Status EnqueueOnTpuDeviceSendRecvLocal(
|
Status EnqueueOnTpuDeviceSendRecvLocal(
|
||||||
stream_executor::DeviceMemoryBase send_buffer,
|
stream_executor::DeviceMemoryBase send_buffer,
|
||||||
stream_executor::DeviceMemoryBase recv_buffer) override {
|
stream_executor::DeviceMemoryBase recv_buffer) override {
|
||||||
@ -56,4 +79,7 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface {
|
|||||||
mutable SE_Stream* stream_;
|
mutable SE_Stream* stream_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
} // namespace tpu
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
|
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user