Add TpuStream::EnqueueTransferHostToDevice and TpuStream::EnqueueTransferDeviceToHost

Also puts TpuStream into the tensorflow::tpu namespace.

PiperOrigin-RevId: 330606385
Change-Id: Ib59b879b71192732ceafc0d1b47a269c2676889f
This commit is contained in:
Skye Wanderman-Milne 2020-09-08 16:03:23 -07:00 committed by TensorFlower Gardener
parent 1832b255e0
commit 297286d734
5 changed files with 42 additions and 3 deletions

View File

@ -54,6 +54,8 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
TFTPU_SET_FN(executor_fn, TpuStream_Stream);
TFTPU_SET_FN(executor_fn, TpuStream_Status);
TFTPU_SET_FN(executor_fn, TpuStream_IsSameSharedMemoryLocation);
TFTPU_SET_FN(executor_fn, TpuStream_EnqueueTransferHostToDevice);
TFTPU_SET_FN(executor_fn, TpuStream_EnqueueTransferDeviceToHost);
TFTPU_SET_FN(executor_fn, TpuStream_TpuEnqueueOnDeviceSendRecvLocal);
TFTPU_SET_FN(executor_fn, TpuEvent_New);

View File

@ -42,7 +42,7 @@ static SE_ExecutableRunOptions ToC(
se_options.device_ordinal = options.run_options().device_ordinal();
if (options.run_options().host_to_device_stream() != nullptr) {
se_options.host_to_device_stream =
static_cast<TpuStream*>(
static_cast<tensorflow::tpu::TpuStream*>(
options.run_options().host_to_device_stream()->implementation())
->se_stream();
} else {
@ -71,7 +71,8 @@ static SE_ExecutableRunOptions ToC(
auto impl =
const_cast<stream_executor::Stream*>(options.stream())->implementation();
se_options.stream = static_cast<TpuStream*>(impl)->se_stream();
se_options.stream =
static_cast<tensorflow::tpu::TpuStream*>(impl)->se_stream();
return se_options;
}
} // namespace ApiConverter

View File

@ -173,7 +173,7 @@ TpuExecutor::GetTimerImplementation() {
std::unique_ptr<::stream_executor::internal::StreamInterface>
TpuExecutor::GetStreamImplementation() {
SE_Stream* tpu_stream = tpu::ExecutorApiFn()->TpuStream_NewFn(executor_);
auto ptr = absl::make_unique<TpuStream>(tpu_stream);
auto ptr = absl::make_unique<tpu::TpuStream>(tpu_stream);
tpu_platform().mutex().lock();
stream_map()[ptr.get()] = tpu_stream;
tpu_platform().mutex().unlock();

View File

@ -127,6 +127,14 @@ void TpuStream_Free(SE_Stream*);
void* TpuStream_Stream(SE_Stream*);
bool TpuStream_Status(SE_Stream*);
bool TpuStream_IsSameSharedMemoryLocation(SE_Stream*, SE_Stream*);
void TpuStream_EnqueueTransferHostToDevice(SE_Stream* stream,
SE_DeviceMemoryBase device_dst,
void* host_src, uint64_t size,
SE_Status* status);
void TpuStream_EnqueueTransferDeviceToHost(SE_Stream* stream,
SE_DeviceMemoryBase device_src,
void* host_dst, uint64_t size,
SE_Status* status);
void TpuStream_TpuEnqueueOnDeviceSendRecvLocal(SE_Stream* stream,
SE_DeviceMemoryBase send_buffer,
SE_DeviceMemoryBase recv_buffer,
@ -355,6 +363,8 @@ struct TfTpu_ExecutorApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuStream_Stream);
TFTPU_ADD_FN_IN_STRUCT(TpuStream_Status);
TFTPU_ADD_FN_IN_STRUCT(TpuStream_IsSameSharedMemoryLocation);
TFTPU_ADD_FN_IN_STRUCT(TpuStream_EnqueueTransferHostToDevice);
TFTPU_ADD_FN_IN_STRUCT(TpuStream_EnqueueTransferDeviceToHost);
TFTPU_ADD_FN_IN_STRUCT(TpuStream_TpuEnqueueOnDeviceSendRecvLocal);
TFTPU_ADD_FN_IN_STRUCT(TpuEvent_New);

View File

@ -23,6 +23,9 @@ limitations under the License.
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_stream_interface.h"
namespace tensorflow {
namespace tpu {
class TpuStream : public tensorflow::tpu::TpuStreamInterface {
public:
using Status = stream_executor::port::Status;
@ -39,6 +42,26 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface {
stream_, static_cast<TpuStream*>(other)->stream_);
}
Status EnqueueTransferHostToDevice(
stream_executor::DeviceMemoryBase device_dst, const void* host_src,
uint64 size) {
StatusHelper status;
tensorflow::tpu::ExecutorApiFn()->TpuStream_EnqueueTransferHostToDeviceFn(
stream_, ApiConverter::ToC(device_dst), const_cast<void*>(host_src),
size, status.c_status);
return status.status();
}
Status EnqueueTransferDeviceToHost(
stream_executor::DeviceMemoryBase device_src, void* host_dst,
uint64 size) {
StatusHelper status;
tensorflow::tpu::ExecutorApiFn()->TpuStream_EnqueueTransferDeviceToHostFn(
stream_, ApiConverter::ToC(device_src), host_dst, size,
status.c_status);
return status.status();
}
Status EnqueueOnTpuDeviceSendRecvLocal(
stream_executor::DeviceMemoryBase send_buffer,
stream_executor::DeviceMemoryBase recv_buffer) override {
@ -56,4 +79,7 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface {
mutable SE_Stream* stream_;
};
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_