[PJRT] Remove unneeded special case in TPU client.

PiperOrigin-RevId: 351452108
Change-Id: I648790059321b8226758a086c30e26ca00940598
This commit is contained in:
Skye Wanderman-Milne 2021-01-12 14:37:31 -08:00 committed by TensorFlower Gardener
parent 4c7e5af62a
commit eaf998c965

View File

@ -68,29 +68,8 @@ Status TpuDeviceState::ThenMemcpyDeviceToDevice(
se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) { se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
auto* transfer_tpu_stream = tensorflow::down_cast<tf_tpu::TpuStream*>( auto* transfer_tpu_stream = tensorflow::down_cast<tf_tpu::TpuStream*>(
transfer_stream->implementation()); transfer_stream->implementation());
tf_tpu::TpuTopologyExternal topology = TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueOnTpuDeviceSendRecvLocal(
tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology(); src_buffer, dst_buffer));
// TODO(b/157179600): use device-to-device transfers when implemented instead
// of copying via host.
if (topology.version() == kTpuV4) {
LOG(WARNING)
<< "device-to-device transfers not yet implemented, copying via host";
auto* dst_tpu_stream =
tensorflow::down_cast<tf_tpu::TpuStream*>(dst_stream->implementation());
TF_RET_CHECK(src_buffer.size() == dst_buffer.size());
auto host_tmp = std::make_unique<char[]>(src_buffer.size());
TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueTransferDeviceToHost(
src_buffer, host_tmp.get(), src_buffer.size()));
dst_stream->ThenWaitFor(transfer_stream);
TF_RETURN_IF_ERROR(dst_tpu_stream->EnqueueTransferHostToDevice(
dst_buffer, host_tmp.get(), dst_buffer.size()));
transfer_stream->ThenWaitFor(dst_stream);
char* tmp = host_tmp.release();
dst_stream->ThenDoHostCallback([tmp] { delete[] tmp; });
} else {
TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueOnTpuDeviceSendRecvLocal(
src_buffer, dst_buffer));
}
return Status::OK(); return Status::OK();
} }