diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.cc b/tensorflow/compiler/xla/pjrt/tpu_client.cc
index 830f7c66f2a..95f581479e5 100644
--- a/tensorflow/compiler/xla/pjrt/tpu_client.cc
+++ b/tensorflow/compiler/xla/pjrt/tpu_client.cc
@@ -68,29 +68,8 @@ Status TpuDeviceState::ThenMemcpyDeviceToDevice(
     se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
   auto* transfer_tpu_stream = tensorflow::down_cast<tf_tpu::TpuStream*>(
       transfer_stream->implementation());
-  tf_tpu::TpuTopologyExternal topology =
-      tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();
-  // TODO(b/157179600): use device-to-device transfers when implemented instead
-  // of copying via host.
-  if (topology.version() == kTpuV4) {
-    LOG(WARNING)
-        << "device-to-device transfers not yet implemented, copying via host";
-    auto* dst_tpu_stream =
-        tensorflow::down_cast<tf_tpu::TpuStream*>(dst_stream->implementation());
-    TF_RET_CHECK(src_buffer.size() == dst_buffer.size());
-    auto host_tmp = std::make_unique<char[]>(src_buffer.size());
-    TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueTransferDeviceToHost(
-        src_buffer, host_tmp.get(), src_buffer.size()));
-    dst_stream->ThenWaitFor(transfer_stream);
-    TF_RETURN_IF_ERROR(dst_tpu_stream->EnqueueTransferHostToDevice(
-        dst_buffer, host_tmp.get(), dst_buffer.size()));
-    transfer_stream->ThenWaitFor(dst_stream);
-    char* tmp = host_tmp.release();
-    dst_stream->ThenDoHostCallback([tmp] { delete[] tmp; });
-  } else {
-    TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueOnTpuDeviceSendRecvLocal(
-        src_buffer, dst_buffer));
-  }
+  TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueOnTpuDeviceSendRecvLocal(
+      src_buffer, dst_buffer));
   return Status::OK();
 }
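
For reference, a minimal sketch of what ThenMemcpyDeviceToDevice reduces to after this
change, assembled from the hunk's context and added lines. The se::Stream* parameter
line is not shown in the hunk and is assumed from how transfer_stream and dst_stream
are used; the enclosing TpuDeviceState class, tf_tpu::TpuStream, and the
Status / TF_RETURN_IF_ERROR machinery come from the surrounding tpu_client.cc.

    // Device-to-device copy: every TPU version now enqueues an on-device
    // send/recv pair on the transfer stream, instead of the removed TPU v4
    // fallback that staged the buffer through host memory.
    Status TpuDeviceState::ThenMemcpyDeviceToDevice(
        se::Stream* transfer_stream, se::Stream* dst_stream,  // assumed parameter line
        se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
      auto* transfer_tpu_stream = tensorflow::down_cast<tf_tpu::TpuStream*>(
          transfer_stream->implementation());
      // dst_stream is no longer needed here; the send/recv is local to the
      // transfer stream, so no cross-stream waits or host callbacks remain.
      TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueOnTpuDeviceSendRecvLocal(
          src_buffer, dst_buffer));
      return Status::OK();
    }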