Merge pull request #39166 from kaixih:pr_vectorize_transpose
PiperOrigin-RevId: 313308844 Change-Id: I715c70e7255e1f1d5930905702fe4c265b38541a
This commit is contained in:
commit
7cfcd3aaf4
@ -210,6 +210,57 @@ __global__ void ShuffleInTensor3Simple(int nthreads,
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr int kUnroll = 4;
|
||||
|
||||
template <typename T, int sp0, int sp1, int sp2, bool conjugate = false>
|
||||
__global__ void ShuffleInTensor3SimpleVector(int nthreads,
|
||||
const T* __restrict__ input,
|
||||
Dimension<3> input_dims,
|
||||
T* __restrict__ output) {
|
||||
Dimension<3> output_dims;
|
||||
output_dims[sp0] = input_dims[0];
|
||||
output_dims[sp1] = input_dims[1];
|
||||
output_dims[sp2] = input_dims[2];
|
||||
|
||||
const int stride = blockDim.x * gridDim.x * kUnroll;
|
||||
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
T buf[kUnroll];
|
||||
|
||||
int output_index;
|
||||
for (output_index = tid * kUnroll; output_index + kUnroll - 1 < nthreads;
|
||||
output_index += stride) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kUnroll; i++) {
|
||||
int output_index_i = output_index + i;
|
||||
Index<3> output_tensor_index =
|
||||
FlatToTensorIndex(output_index_i, output_dims);
|
||||
Index<3> input_tensor_index;
|
||||
input_tensor_index[0] = output_tensor_index[sp0];
|
||||
input_tensor_index[1] = output_tensor_index[sp1];
|
||||
input_tensor_index[2] = output_tensor_index[sp2];
|
||||
|
||||
int input_index_i = TensorIndexToFlat(input_tensor_index, input_dims);
|
||||
buf[i] = maybe_conj<T, conjugate>::run(ldg(input + input_index_i));
|
||||
}
|
||||
float2* out = reinterpret_cast<float2*>(output + output_index);
|
||||
*out = *reinterpret_cast<float2*>(buf);
|
||||
}
|
||||
|
||||
for (; output_index < nthreads; ++output_index) {
|
||||
Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
|
||||
|
||||
Index<3> input_tensor_index;
|
||||
input_tensor_index[0] = output_tensor_index[sp0];
|
||||
input_tensor_index[1] = output_tensor_index[sp1];
|
||||
input_tensor_index[2] = output_tensor_index[sp2];
|
||||
|
||||
int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
|
||||
|
||||
output[output_index] =
|
||||
maybe_conj<T, conjugate>::run(ldg(input + input_index));
|
||||
}
|
||||
}
|
||||
|
||||
// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor,
|
||||
// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
|
||||
//
|
||||
@ -1008,10 +1059,40 @@ struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> {
|
||||
static_cast<int>(combined_dims[2])};
|
||||
size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
|
||||
GpuLaunchConfig config = GetGpuLaunchConfig(total_size, d);
|
||||
TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>,
|
||||
config.block_count, config.thread_per_block, 0,
|
||||
d.stream(), config.virtual_thread_count, in,
|
||||
input_dims, out));
|
||||
|
||||
auto out_ptr = reinterpret_cast<uintptr_t>(out);
|
||||
bool aligned = out_ptr % 16 == 0;
|
||||
|
||||
bool use_vector = false;
|
||||
bool use_custom_config = false;
|
||||
if ((input_dims[0] <= 128 && input_dims[2] <= 128) ||
|
||||
input_dims[0] * input_dims[1] <= 128 ||
|
||||
input_dims[1] * input_dims[2] <= 8) {
|
||||
use_vector = true;
|
||||
use_custom_config = true;
|
||||
} else if (input_dims[1] * input_dims[2] <= 16384) {
|
||||
use_vector = true;
|
||||
}
|
||||
|
||||
if (sizeof(T) == 2 && aligned && use_vector) {
|
||||
int block_count;
|
||||
if (use_custom_config) {
|
||||
block_count = (total_size + config.thread_per_block - 1) /
|
||||
config.thread_per_block;
|
||||
} else {
|
||||
block_count = config.block_count;
|
||||
}
|
||||
|
||||
TF_CHECK_OK(
|
||||
GpuLaunchKernel(ShuffleInTensor3SimpleVector<T, 2, 1, 0, conjugate>,
|
||||
block_count, config.thread_per_block / kUnroll, 0,
|
||||
d.stream(), total_size, in, input_dims, out));
|
||||
} else {
|
||||
TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>,
|
||||
config.block_count, config.thread_per_block,
|
||||
0, d.stream(), config.virtual_thread_count,
|
||||
in, input_dims, out));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user