Add TpuStream::EnqueueTransferHostToDevice and TpuStream::EnqueueTransferDeviceToHost
Also puts TpuStream into the tensorflow::tpu namespace.

PiperOrigin-RevId: 330606385
Change-Id: Ib59b879b71192732ceafc0d1b47a269c2676889f
This commit is contained in:
		
							parent
							
								
									1832b255e0
								
							
						
					
					
						commit
						297286d734
					
				@ -54,6 +54,8 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
 | 
			
		||||
  TFTPU_SET_FN(executor_fn, TpuStream_Stream);
 | 
			
		||||
  TFTPU_SET_FN(executor_fn, TpuStream_Status);
 | 
			
		||||
  TFTPU_SET_FN(executor_fn, TpuStream_IsSameSharedMemoryLocation);
 | 
			
		||||
  TFTPU_SET_FN(executor_fn, TpuStream_EnqueueTransferHostToDevice);
 | 
			
		||||
  TFTPU_SET_FN(executor_fn, TpuStream_EnqueueTransferDeviceToHost);
 | 
			
		||||
  TFTPU_SET_FN(executor_fn, TpuStream_TpuEnqueueOnDeviceSendRecvLocal);
 | 
			
		||||
 | 
			
		||||
  TFTPU_SET_FN(executor_fn, TpuEvent_New);
 | 
			
		||||
 | 
			
		||||
@ -42,7 +42,7 @@ static SE_ExecutableRunOptions ToC(
 | 
			
		||||
  se_options.device_ordinal = options.run_options().device_ordinal();
 | 
			
		||||
  if (options.run_options().host_to_device_stream() != nullptr) {
 | 
			
		||||
    se_options.host_to_device_stream =
 | 
			
		||||
        static_cast<TpuStream*>(
 | 
			
		||||
        static_cast<tensorflow::tpu::TpuStream*>(
 | 
			
		||||
            options.run_options().host_to_device_stream()->implementation())
 | 
			
		||||
            ->se_stream();
 | 
			
		||||
  } else {
 | 
			
		||||
@ -71,7 +71,8 @@ static SE_ExecutableRunOptions ToC(
 | 
			
		||||
 | 
			
		||||
  auto impl =
 | 
			
		||||
      const_cast<stream_executor::Stream*>(options.stream())->implementation();
 | 
			
		||||
  se_options.stream = static_cast<TpuStream*>(impl)->se_stream();
 | 
			
		||||
  se_options.stream =
 | 
			
		||||
      static_cast<tensorflow::tpu::TpuStream*>(impl)->se_stream();
 | 
			
		||||
  return se_options;
 | 
			
		||||
}
 | 
			
		||||
}  // namespace ApiConverter
 | 
			
		||||
 | 
			
		||||
@ -173,7 +173,7 @@ TpuExecutor::GetTimerImplementation() {
 | 
			
		||||
std::unique_ptr<::stream_executor::internal::StreamInterface>
 | 
			
		||||
TpuExecutor::GetStreamImplementation() {
 | 
			
		||||
  SE_Stream* tpu_stream = tpu::ExecutorApiFn()->TpuStream_NewFn(executor_);
 | 
			
		||||
  auto ptr = absl::make_unique<TpuStream>(tpu_stream);
 | 
			
		||||
  auto ptr = absl::make_unique<tpu::TpuStream>(tpu_stream);
 | 
			
		||||
  tpu_platform().mutex().lock();
 | 
			
		||||
  stream_map()[ptr.get()] = tpu_stream;
 | 
			
		||||
  tpu_platform().mutex().unlock();
 | 
			
		||||
 | 
			
		||||
@ -127,6 +127,14 @@ void TpuStream_Free(SE_Stream*);
 | 
			
		||||
void* TpuStream_Stream(SE_Stream*);
 | 
			
		||||
bool TpuStream_Status(SE_Stream*);
 | 
			
		||||
bool TpuStream_IsSameSharedMemoryLocation(SE_Stream*, SE_Stream*);
 | 
			
		||||
void TpuStream_EnqueueTransferHostToDevice(SE_Stream* stream,
 | 
			
		||||
                                           SE_DeviceMemoryBase device_dst,
 | 
			
		||||
                                           void* host_src, uint64_t size,
 | 
			
		||||
                                           SE_Status* status);
 | 
			
		||||
void TpuStream_EnqueueTransferDeviceToHost(SE_Stream* stream,
 | 
			
		||||
                                           SE_DeviceMemoryBase device_src,
 | 
			
		||||
                                           void* host_dst, uint64_t size,
 | 
			
		||||
                                           SE_Status* status);
 | 
			
		||||
void TpuStream_TpuEnqueueOnDeviceSendRecvLocal(SE_Stream* stream,
 | 
			
		||||
                                               SE_DeviceMemoryBase send_buffer,
 | 
			
		||||
                                               SE_DeviceMemoryBase recv_buffer,
 | 
			
		||||
@ -355,6 +363,8 @@ struct TfTpu_ExecutorApiFn {
 | 
			
		||||
  TFTPU_ADD_FN_IN_STRUCT(TpuStream_Stream);
 | 
			
		||||
  TFTPU_ADD_FN_IN_STRUCT(TpuStream_Status);
 | 
			
		||||
  TFTPU_ADD_FN_IN_STRUCT(TpuStream_IsSameSharedMemoryLocation);
 | 
			
		||||
  TFTPU_ADD_FN_IN_STRUCT(TpuStream_EnqueueTransferHostToDevice);
 | 
			
		||||
  TFTPU_ADD_FN_IN_STRUCT(TpuStream_EnqueueTransferDeviceToHost);
 | 
			
		||||
  TFTPU_ADD_FN_IN_STRUCT(TpuStream_TpuEnqueueOnDeviceSendRecvLocal);
 | 
			
		||||
 | 
			
		||||
  TFTPU_ADD_FN_IN_STRUCT(TpuEvent_New);
 | 
			
		||||
 | 
			
		||||
@ -23,6 +23,9 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
 | 
			
		||||
#include "tensorflow/stream_executor/tpu/tpu_stream_interface.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
namespace tpu {
 | 
			
		||||
 | 
			
		||||
class TpuStream : public tensorflow::tpu::TpuStreamInterface {
 | 
			
		||||
 public:
 | 
			
		||||
  using Status = stream_executor::port::Status;
 | 
			
		||||
@ -39,6 +42,26 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface {
 | 
			
		||||
            stream_, static_cast<TpuStream*>(other)->stream_);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Status EnqueueTransferHostToDevice(
 | 
			
		||||
      stream_executor::DeviceMemoryBase device_dst, const void* host_src,
 | 
			
		||||
      uint64 size) {
 | 
			
		||||
    StatusHelper status;
 | 
			
		||||
    tensorflow::tpu::ExecutorApiFn()->TpuStream_EnqueueTransferHostToDeviceFn(
 | 
			
		||||
        stream_, ApiConverter::ToC(device_dst), const_cast<void*>(host_src),
 | 
			
		||||
        size, status.c_status);
 | 
			
		||||
    return status.status();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Status EnqueueTransferDeviceToHost(
 | 
			
		||||
      stream_executor::DeviceMemoryBase device_src, void* host_dst,
 | 
			
		||||
      uint64 size) {
 | 
			
		||||
    StatusHelper status;
 | 
			
		||||
    tensorflow::tpu::ExecutorApiFn()->TpuStream_EnqueueTransferDeviceToHostFn(
 | 
			
		||||
        stream_, ApiConverter::ToC(device_src), host_dst, size,
 | 
			
		||||
        status.c_status);
 | 
			
		||||
    return status.status();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Status EnqueueOnTpuDeviceSendRecvLocal(
 | 
			
		||||
      stream_executor::DeviceMemoryBase send_buffer,
 | 
			
		||||
      stream_executor::DeviceMemoryBase recv_buffer) override {
 | 
			
		||||
@ -56,4 +79,7 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface {
 | 
			
		||||
  mutable SE_Stream* stream_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
}  // namespace tpu
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
 | 
			
		||||
#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user