Merge pull request #39166 from kaixih:pr_vectorize_transpose

PiperOrigin-RevId: 313308844
Change-Id: I715c70e7255e1f1d5930905702fe4c265b38541a
This commit is contained in:
TensorFlower Gardener 2020-05-26 18:50:26 -07:00
commit 7cfcd3aaf4

View File

@ -210,6 +210,57 @@ __global__ void ShuffleInTensor3Simple(int nthreads,
}
}
static constexpr int kUnroll = 4;
template <typename T, int sp0, int sp1, int sp2, bool conjugate = false>
__global__ void ShuffleInTensor3SimpleVector(int nthreads,
const T* __restrict__ input,
Dimension<3> input_dims,
T* __restrict__ output) {
Dimension<3> output_dims;
output_dims[sp0] = input_dims[0];
output_dims[sp1] = input_dims[1];
output_dims[sp2] = input_dims[2];
const int stride = blockDim.x * gridDim.x * kUnroll;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
T buf[kUnroll];
int output_index;
for (output_index = tid * kUnroll; output_index + kUnroll - 1 < nthreads;
output_index += stride) {
#pragma unroll
for (int i = 0; i < kUnroll; i++) {
int output_index_i = output_index + i;
Index<3> output_tensor_index =
FlatToTensorIndex(output_index_i, output_dims);
Index<3> input_tensor_index;
input_tensor_index[0] = output_tensor_index[sp0];
input_tensor_index[1] = output_tensor_index[sp1];
input_tensor_index[2] = output_tensor_index[sp2];
int input_index_i = TensorIndexToFlat(input_tensor_index, input_dims);
buf[i] = maybe_conj<T, conjugate>::run(ldg(input + input_index_i));
}
float2* out = reinterpret_cast<float2*>(output + output_index);
*out = *reinterpret_cast<float2*>(buf);
}
for (; output_index < nthreads; ++output_index) {
Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
Index<3> input_tensor_index;
input_tensor_index[0] = output_tensor_index[sp0];
input_tensor_index[1] = output_tensor_index[sp1];
input_tensor_index[2] = output_tensor_index[sp2];
int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
output[output_index] =
maybe_conj<T, conjugate>::run(ldg(input + input_index));
}
}
// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor,
// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
//
@ -1008,10 +1059,40 @@ struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> {
static_cast<int>(combined_dims[2])};
size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
GpuLaunchConfig config = GetGpuLaunchConfig(total_size, d);
TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>,
config.block_count, config.thread_per_block, 0,
d.stream(), config.virtual_thread_count, in,
input_dims, out));
auto out_ptr = reinterpret_cast<uintptr_t>(out);
bool aligned = out_ptr % 16 == 0;
bool use_vector = false;
bool use_custom_config = false;
if ((input_dims[0] <= 128 && input_dims[2] <= 128) ||
input_dims[0] * input_dims[1] <= 128 ||
input_dims[1] * input_dims[2] <= 8) {
use_vector = true;
use_custom_config = true;
} else if (input_dims[1] * input_dims[2] <= 16384) {
use_vector = true;
}
if (sizeof(T) == 2 && aligned && use_vector) {
int block_count;
if (use_custom_config) {
block_count = (total_size + config.thread_per_block - 1) /
config.thread_per_block;
} else {
block_count = config.block_count;
}
TF_CHECK_OK(
GpuLaunchKernel(ShuffleInTensor3SimpleVector<T, 2, 1, 0, conjugate>,
block_count, config.thread_per_block / kUnroll, 0,
d.stream(), total_size, in, input_dims, out));
} else {
TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>,
config.block_count, config.thread_per_block,
0, d.stream(), config.virtual_thread_count,
in, input_dims, out));
}
}
};