[PJRT] Remove unneeded special case in TPU client.
PiperOrigin-RevId: 351452108 Change-Id: I648790059321b8226758a086c30e26ca00940598
This commit is contained in:
parent
4c7e5af62a
commit
eaf998c965
@ -68,29 +68,8 @@ Status TpuDeviceState::ThenMemcpyDeviceToDevice(
|
||||
se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
|
||||
auto* transfer_tpu_stream = tensorflow::down_cast<tf_tpu::TpuStream*>(
|
||||
transfer_stream->implementation());
|
||||
tf_tpu::TpuTopologyExternal topology =
|
||||
tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();
|
||||
// TODO(b/157179600): use device-to-device transfers when implemented instead
|
||||
// of copying via host.
|
||||
if (topology.version() == kTpuV4) {
|
||||
LOG(WARNING)
|
||||
<< "device-to-device transfers not yet implemented, copying via host";
|
||||
auto* dst_tpu_stream =
|
||||
tensorflow::down_cast<tf_tpu::TpuStream*>(dst_stream->implementation());
|
||||
TF_RET_CHECK(src_buffer.size() == dst_buffer.size());
|
||||
auto host_tmp = std::make_unique<char[]>(src_buffer.size());
|
||||
TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueTransferDeviceToHost(
|
||||
src_buffer, host_tmp.get(), src_buffer.size()));
|
||||
dst_stream->ThenWaitFor(transfer_stream);
|
||||
TF_RETURN_IF_ERROR(dst_tpu_stream->EnqueueTransferHostToDevice(
|
||||
dst_buffer, host_tmp.get(), dst_buffer.size()));
|
||||
transfer_stream->ThenWaitFor(dst_stream);
|
||||
char* tmp = host_tmp.release();
|
||||
dst_stream->ThenDoHostCallback([tmp] { delete[] tmp; });
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueOnTpuDeviceSendRecvLocal(
|
||||
src_buffer, dst_buffer));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueOnTpuDeviceSendRecvLocal(
|
||||
src_buffer, dst_buffer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user