@@ -16,7 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_
 #define TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define EIGEN_USE_GPU
 
@@ -25,7 +25,9 @@ limitations under the License.
 #include <limits>
 #include <utility>
 
+#if GOOGLE_CUDA
 #include "third_party/gpus/cuda/include/cuda.h"
+#endif
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/kernels/conv_2d.h"
 #include "tensorflow/core/lib/math/math_util.h"
@@ -49,7 +51,7 @@ struct maybe_conj {
   }
 };
 
-// Partial specializations for Cuda types used to store complex numbers.
+// Partial specializations for Gpu types used to store complex numbers.
 template <bool conjugate>
 struct maybe_conj<float2, conjugate> {
   __device__ static __inline__ float2 run(float2 c) {
@@ -191,7 +193,7 @@ __global__ void ShuffleInTensor3Simple(int nthreads, const T* input,
   // performance. Iterating over output will generate sequential writes and
   // random reads that performs better compared to sequential reads and random
   // writes.
-  CUDA_1D_KERNEL_LOOP(output_index, nthreads) {
+  GPU_1D_KERNEL_LOOP(output_index, nthreads) {
     Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
 
     Index<3> input_tensor_index;
@@ -232,11 +234,15 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(
   // One extra line in the inner dimension to avoid share memory bank conflict.
   // This is to mimic the following, but no constructor of T can be invoked.
   // __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
+#if GOOGLE_CUDA
   __shared__ __align__(
       alignof(T)) char shared_mem_raw[TileSizeI * (TileSizeJ + 1) * sizeof(T)];
   typedef T(*SharedMemoryTile)[TileSizeJ + 1];
   SharedMemoryTile shared_memory_tile =
       reinterpret_cast<SharedMemoryTile>(shared_mem_raw);
+#elif TENSORFLOW_USE_ROCM
+  __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
+#endif
 
   int x = threadIdx.x;
 
@@ -357,14 +363,14 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(
   }
 }
 
-// A Cuda custom kernel that convert input to output, given proper padding on
+// A Gpu custom kernel that convert input to output, given proper padding on
 // the left and the top. The padded value is zero.
 template <typename T, int NDIMS>
 __global__ void PadInputCustomKernelNHWC(int nthreads, const T* input,
                                          Dimension<NDIMS> input_dims, T* output,
                                          Dimension<NDIMS> output_dims,
                                          Dimension<NDIMS - 2> padding_left) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+  GPU_1D_KERNEL_LOOP(index, nthreads) {
     int output_index = index;
     Index<NDIMS> output_tensor_index =
         FlatToTensorIndex(output_index, output_dims);
@@ -393,7 +399,7 @@ __global__ void PadInputCustomKernelNCHW(int nthreads, const T* input,
                                          Dimension<NDIMS> input_dims, T* output,
                                          Dimension<NDIMS> output_dims,
                                          Dimension<NDIMS - 2> padding_left) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+  GPU_1D_KERNEL_LOOP(index, nthreads) {
     int output_index = index;
     Index<NDIMS> output_tensor_index =
         FlatToTensorIndex(output_index, output_dims);
@@ -432,19 +438,19 @@ struct TransformFilter<GPUDevice, T, int, NDIMS> {
     }
     combined_dims[1] = in.dimension(NDIMS - 2);  // input filters
     combined_dims[2] = in.dimension(NDIMS - 1);  // output filters
-    GpuLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
+    GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
 
     if (dst_filter_format == FORMAT_OIHW) {
-      TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
-                                   config.block_count, config.thread_per_block,
-                                   0, d.stream(), config.virtual_thread_count,
-                                   in.data(), combined_dims, out.data()));
+      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
+                                  config.block_count, config.thread_per_block,
+                                  0, d.stream(), config.virtual_thread_count,
+                                  in.data(), combined_dims, out.data()));
 
     } else if (dst_filter_format == FORMAT_OHWI) {
-      TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 1, 2, 0>,
-                                   config.block_count, config.thread_per_block,
-                                   0, d.stream(), config.virtual_thread_count,
-                                   in.data(), combined_dims, out.data()));
+      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 1, 2, 0>,
+                                  config.block_count, config.thread_per_block,
+                                  0, d.stream(), config.virtual_thread_count,
+                                  in.data(), combined_dims, out.data()));
 
     } else {
       LOG(ERROR) << "Unsupported filter format: "
@@ -471,11 +477,11 @@ struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
         combined_dims[2] *= in.dimension(i);
       }
 
-      GpuLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
-      TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
-                                   config.block_count, config.thread_per_block,
-                                   0, d.stream(), config.virtual_thread_count,
-                                   in.data(), combined_dims, out.data()));
+      GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
+      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
+                                  config.block_count, config.thread_per_block,
+                                  0, d.stream(), config.virtual_thread_count,
+                                  in.data(), combined_dims, out.data()));
 
     } else if (src_filter_format == FORMAT_OHWI) {
       combined_dims[0] = in.dimension(0);  // output filters
@@ -485,11 +491,11 @@ struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
       }
       combined_dims[2] = in.dimension(NDIMS - 1);  // input filters
 
-      GpuLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
-      TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 2, 0, 1>,
-                                   config.block_count, config.thread_per_block,
-                                   0, d.stream(), config.virtual_thread_count,
-                                   in.data(), combined_dims, out.data()));
+      GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
+      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 0, 1>,
+                                  config.block_count, config.thread_per_block,
+                                  0, d.stream(), config.virtual_thread_count,
+                                  in.data(), combined_dims, out.data()));
 
     } else {
       // TODO(ezhulenev): Set error status in OpKernelContext instead.
@@ -510,7 +516,7 @@ struct PadInput<GPUDevice, T, int, NDIMS> {
                   const std::array<int, NDIMS - 2>& padding_right,
                   typename TTypes<T, NDIMS, int>::Tensor out,
                   TensorFormat format) {
-    GpuLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
+    GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
     Dimension<NDIMS> input_dims;
     for (int i = 0; i < NDIMS; ++i) {
       input_dims[i] = in.dimension(i);
@@ -523,12 +529,12 @@ struct PadInput<GPUDevice, T, int, NDIMS> {
     const Dimension<NDIMS - 2> padding_left_dim(padding_left);
 
     if (format == FORMAT_NHWC) {
-      TF_CHECK_OK(CudaLaunchKernel(
+      TF_CHECK_OK(GpuLaunchKernel(
           PadInputCustomKernelNHWC<T, NDIMS>, config.block_count,
           config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
           in.data(), input_dims, out.data(), output_dims, padding_left_dim));
     } else if (format == FORMAT_NCHW) {
-      TF_CHECK_OK(CudaLaunchKernel(
+      TF_CHECK_OK(GpuLaunchKernel(
           PadInputCustomKernelNCHW<T, NDIMS>, config.block_count,
           config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
           in.data(), input_dims, out.data(), output_dims, padding_left_dim));
@@ -640,13 +646,13 @@ void LaunchBatchNarrowMatrixTransposeKernel(
     const T* input, const Dimension<3>& input_dims, T* output) {
   constexpr int NumThreads = TileLongSide;
   if (tile_size_i <= TileLongSide && tile_size_j <= TileShortSide) {
-    TF_CHECK_OK(CudaLaunchKernel(
+    TF_CHECK_OK(GpuLaunchKernel(
         SwapDimension1And2InTensor3UsingTiles<T, NumThreads, TileLongSide,
                                               TileShortSide>,
         total_tiles_count, NumThreads, 0, d.stream(), input, input_dims,
        output));
   } else {
-    TF_CHECK_OK(CudaLaunchKernel(
+    TF_CHECK_OK(GpuLaunchKernel(
         SwapDimension1And2InTensor3UsingTiles<T, NumThreads, TileShortSide,
                                               TileLongSide>,
         total_tiles_count, NumThreads, 0, d.stream(), input, input_dims,
@@ -951,8 +957,7 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
 
   int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
                           input_dims_in_tiles[2];
 
-  TF_CHECK_OK(CudaLaunchKernel(
+  TF_CHECK_OK(GpuLaunchKernel(
       SwapDimension1And2InTensor3UsingTiles<T, kNumThreads, kTileSize,
                                             kTileSize, conjugate>,
       total_tiles_count, kNumThreads, 0, d.stream(), input, input_dims,
@@ -963,11 +968,11 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
         d, input, input_dims, output, kMinDimensionToUseTiles);
   } else {
     int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
-    GpuLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
-    TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 0, 2, 1, conjugate>,
-                                 config.block_count, config.thread_per_block, 0,
-                                 d.stream(), config.virtual_thread_count, input,
-                                 input_dims, output));
+    GpuLaunchConfig config = GetGpuLaunchConfig(total_element_count, d);
+    TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 0, 2, 1, conjugate>,
+                                config.block_count, config.thread_per_block, 0,
+                                d.stream(), config.virtual_thread_count, input,
+                                input_dims, output));
   }
 }
 
@@ -996,11 +1001,11 @@ struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> {
                               static_cast<int>(combined_dims[1]),
                               static_cast<int>(combined_dims[2])};
     size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
-    GpuLaunchConfig config = GetCudaLaunchConfig(total_size, d);
-    TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>,
-                                 config.block_count, config.thread_per_block, 0,
-                                 d.stream(), config.virtual_thread_count, in,
-                                 input_dims, out));
+    GpuLaunchConfig config = GetGpuLaunchConfig(total_size, d);
+    TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>,
+                                config.block_count, config.thread_per_block, 0,
+                                d.stream(), config.virtual_thread_count, in,
+                                input_dims, out));
   }
 };
 
@@ -1043,6 +1048,6 @@ struct NCHWToNHWC<GPUDevice, T, NDIMS> {
 }  // namespace functor
 }  // namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #endif  // TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_