Merge pull request #39166 from kaixih:pr_vectorize_transpose
PiperOrigin-RevId: 313308844 Change-Id: I715c70e7255e1f1d5930905702fe4c265b38541a
This commit is contained in:
commit
7cfcd3aaf4
@ -210,6 +210,57 @@ __global__ void ShuffleInTensor3Simple(int nthreads,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static constexpr int kUnroll = 4;
|
||||||
|
|
||||||
|
template <typename T, int sp0, int sp1, int sp2, bool conjugate = false>
|
||||||
|
__global__ void ShuffleInTensor3SimpleVector(int nthreads,
|
||||||
|
const T* __restrict__ input,
|
||||||
|
Dimension<3> input_dims,
|
||||||
|
T* __restrict__ output) {
|
||||||
|
Dimension<3> output_dims;
|
||||||
|
output_dims[sp0] = input_dims[0];
|
||||||
|
output_dims[sp1] = input_dims[1];
|
||||||
|
output_dims[sp2] = input_dims[2];
|
||||||
|
|
||||||
|
const int stride = blockDim.x * gridDim.x * kUnroll;
|
||||||
|
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
T buf[kUnroll];
|
||||||
|
|
||||||
|
int output_index;
|
||||||
|
for (output_index = tid * kUnroll; output_index + kUnroll - 1 < nthreads;
|
||||||
|
output_index += stride) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < kUnroll; i++) {
|
||||||
|
int output_index_i = output_index + i;
|
||||||
|
Index<3> output_tensor_index =
|
||||||
|
FlatToTensorIndex(output_index_i, output_dims);
|
||||||
|
Index<3> input_tensor_index;
|
||||||
|
input_tensor_index[0] = output_tensor_index[sp0];
|
||||||
|
input_tensor_index[1] = output_tensor_index[sp1];
|
||||||
|
input_tensor_index[2] = output_tensor_index[sp2];
|
||||||
|
|
||||||
|
int input_index_i = TensorIndexToFlat(input_tensor_index, input_dims);
|
||||||
|
buf[i] = maybe_conj<T, conjugate>::run(ldg(input + input_index_i));
|
||||||
|
}
|
||||||
|
float2* out = reinterpret_cast<float2*>(output + output_index);
|
||||||
|
*out = *reinterpret_cast<float2*>(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (; output_index < nthreads; ++output_index) {
|
||||||
|
Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
|
||||||
|
|
||||||
|
Index<3> input_tensor_index;
|
||||||
|
input_tensor_index[0] = output_tensor_index[sp0];
|
||||||
|
input_tensor_index[1] = output_tensor_index[sp1];
|
||||||
|
input_tensor_index[2] = output_tensor_index[sp2];
|
||||||
|
|
||||||
|
int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
|
||||||
|
|
||||||
|
output[output_index] =
|
||||||
|
maybe_conj<T, conjugate>::run(ldg(input + input_index));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor,
|
// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor,
|
||||||
// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
|
// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
|
||||||
//
|
//
|
||||||
@ -1008,10 +1059,40 @@ struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> {
|
|||||||
static_cast<int>(combined_dims[2])};
|
static_cast<int>(combined_dims[2])};
|
||||||
size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
|
size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
|
||||||
GpuLaunchConfig config = GetGpuLaunchConfig(total_size, d);
|
GpuLaunchConfig config = GetGpuLaunchConfig(total_size, d);
|
||||||
TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>,
|
|
||||||
config.block_count, config.thread_per_block, 0,
|
auto out_ptr = reinterpret_cast<uintptr_t>(out);
|
||||||
d.stream(), config.virtual_thread_count, in,
|
bool aligned = out_ptr % 16 == 0;
|
||||||
input_dims, out));
|
|
||||||
|
bool use_vector = false;
|
||||||
|
bool use_custom_config = false;
|
||||||
|
if ((input_dims[0] <= 128 && input_dims[2] <= 128) ||
|
||||||
|
input_dims[0] * input_dims[1] <= 128 ||
|
||||||
|
input_dims[1] * input_dims[2] <= 8) {
|
||||||
|
use_vector = true;
|
||||||
|
use_custom_config = true;
|
||||||
|
} else if (input_dims[1] * input_dims[2] <= 16384) {
|
||||||
|
use_vector = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sizeof(T) == 2 && aligned && use_vector) {
|
||||||
|
int block_count;
|
||||||
|
if (use_custom_config) {
|
||||||
|
block_count = (total_size + config.thread_per_block - 1) /
|
||||||
|
config.thread_per_block;
|
||||||
|
} else {
|
||||||
|
block_count = config.block_count;
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_CHECK_OK(
|
||||||
|
GpuLaunchKernel(ShuffleInTensor3SimpleVector<T, 2, 1, 0, conjugate>,
|
||||||
|
block_count, config.thread_per_block / kUnroll, 0,
|
||||||
|
d.stream(), total_size, in, input_dims, out));
|
||||||
|
} else {
|
||||||
|
TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>,
|
||||||
|
config.block_count, config.thread_per_block,
|
||||||
|
0, d.stream(), config.virtual_thread_count,
|
||||||
|
in, input_dims, out));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user