Add ROCm support for launching 2D GPU convolutions

Wen-Heng (Jack) Chung 2019-05-30 16:01:31 +00:00
parent ef0b1eff8d
commit 82ccb9a50d
8 changed files with 62 additions and 57 deletions
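The port is mechanical for most launch sites: #if GOOGLE_CUDA guards widen to #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM, and the CUDA-specific helpers (CUDA_1D_KERNEL_LOOP, GetCudaLaunchConfig, CudaLaunchKernel) are swapped for their backend-neutral counterparts (GPU_1D_KERNEL_LOOP, GetGpuLaunchConfig, GpuLaunchKernel). A minimal sketch of that pattern follows, using a hypothetical element-wise kernel and functor; ScaleKernel/ScaleFunctor are illustrative names, not part of this commit, and the header path is an assumption.

// Sketch only: a hypothetical element-wise functor written against the
// backend-neutral launch helpers this commit switches to.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU

#include "tensorflow/core/util/gpu_kernel_helper.h"  // assumed header location

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

template <typename T>
__global__ void ScaleKernel(int nthreads, const T* in, T alpha, T* out) {
  // Grid-stride loop; expands to the CUDA or ROCm flavor as appropriate.
  GPU_1D_KERNEL_LOOP(index, nthreads) { out[index] = alpha * in[index]; }
}

template <typename T>
struct ScaleFunctor {
  void operator()(const GPUDevice& d, const T* in, T alpha, T* out, int n) {
    // Pick a launch geometry for n elements, then launch through the
    // backend-neutral wrapper rather than CudaLaunchKernel or <<<...>>>.
    GpuLaunchConfig config = GetGpuLaunchConfig(n, d);
    TF_CHECK_OK(GpuLaunchKernel(ScaleKernel<T>, config.block_count,
                                config.thread_per_block, 0, d.stream(),
                                config.virtual_thread_count, in, alpha, out));
  }
};

}  // namespace tensorflow
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM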

View File

@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_
#define TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@@ -25,7 +25,9 @@ limitations under the License.
#include <limits>
#include <utility>
+#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
+#endif
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/lib/math/math_util.h"
@@ -49,7 +51,7 @@ struct maybe_conj {
}
};
-// Partial specializations for Cuda types used to store complex numbers.
+// Partial specializations for Gpu types used to store complex numbers.
template <bool conjugate>
struct maybe_conj<float2, conjugate> {
__device__ static __inline__ float2 run(float2 c) {
@@ -191,7 +193,7 @@ __global__ void ShuffleInTensor3Simple(int nthreads, const T* input,
// performance. Iterating over output will generate sequential writes and
// random reads that performs better compared to sequential reads and random
// writes.
-CUDA_1D_KERNEL_LOOP(output_index, nthreads) {
+GPU_1D_KERNEL_LOOP(output_index, nthreads) {
Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
Index<3> input_tensor_index;
@@ -232,11 +234,15 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(
// One extra line in the inner dimension to avoid share memory bank conflict.
// This is to mimic the following, but no constructor of T can be invoked.
// __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
+#if GOOGLE_CUDA
__shared__ __align__(
alignof(T)) char shared_mem_raw[TileSizeI * (TileSizeJ + 1) * sizeof(T)];
typedef T(*SharedMemoryTile)[TileSizeJ + 1];
SharedMemoryTile shared_memory_tile =
reinterpret_cast<SharedMemoryTile>(shared_mem_raw);
+#elif TENSORFLOW_USE_ROCM
+__shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
+#endif
int x = threadIdx.x;
@@ -357,14 +363,14 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(
}
}
-// A Cuda custom kernel that convert input to output, given proper padding on
+// A Gpu custom kernel that convert input to output, given proper padding on
// the left and the top. The padded value is zero.
template <typename T, int NDIMS>
__global__ void PadInputCustomKernelNHWC(int nthreads, const T* input,
Dimension<NDIMS> input_dims, T* output,
Dimension<NDIMS> output_dims,
Dimension<NDIMS - 2> padding_left) {
-CUDA_1D_KERNEL_LOOP(index, nthreads) {
+GPU_1D_KERNEL_LOOP(index, nthreads) {
int output_index = index;
Index<NDIMS> output_tensor_index =
FlatToTensorIndex(output_index, output_dims);
@@ -393,7 +399,7 @@ __global__ void PadInputCustomKernelNCHW(int nthreads, const T* input,
Dimension<NDIMS> input_dims, T* output,
Dimension<NDIMS> output_dims,
Dimension<NDIMS - 2> padding_left) {
-CUDA_1D_KERNEL_LOOP(index, nthreads) {
+GPU_1D_KERNEL_LOOP(index, nthreads) {
int output_index = index;
Index<NDIMS> output_tensor_index =
FlatToTensorIndex(output_index, output_dims);
@@ -432,19 +438,19 @@ struct TransformFilter<GPUDevice, T, int, NDIMS> {
}
combined_dims[1] = in.dimension(NDIMS - 2); // input filters
combined_dims[2] = in.dimension(NDIMS - 1); // output filters
-GpuLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
+GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
if (dst_filter_format == FORMAT_OIHW) {
-TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
-config.block_count, config.thread_per_block,
-0, d.stream(), config.virtual_thread_count,
-in.data(), combined_dims, out.data()));
+TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
+config.block_count, config.thread_per_block,
+0, d.stream(), config.virtual_thread_count,
+in.data(), combined_dims, out.data()));
} else if (dst_filter_format == FORMAT_OHWI) {
-TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 1, 2, 0>,
-config.block_count, config.thread_per_block,
-0, d.stream(), config.virtual_thread_count,
-in.data(), combined_dims, out.data()));
+TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 1, 2, 0>,
+config.block_count, config.thread_per_block,
+0, d.stream(), config.virtual_thread_count,
+in.data(), combined_dims, out.data()));
} else {
LOG(ERROR) << "Unsupported filter format: "
@@ -471,11 +477,11 @@ struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
combined_dims[2] *= in.dimension(i);
}
-GpuLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
-TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
-config.block_count, config.thread_per_block,
-0, d.stream(), config.virtual_thread_count,
-in.data(), combined_dims, out.data()));
+GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
+TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
+config.block_count, config.thread_per_block,
+0, d.stream(), config.virtual_thread_count,
+in.data(), combined_dims, out.data()));
} else if (src_filter_format == FORMAT_OHWI) {
combined_dims[0] = in.dimension(0); // output filters
@@ -485,11 +491,11 @@ struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
}
combined_dims[2] = in.dimension(NDIMS - 1); // input filters
-GpuLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
-TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 2, 0, 1>,
-config.block_count, config.thread_per_block,
-0, d.stream(), config.virtual_thread_count,
-in.data(), combined_dims, out.data()));
+GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
+TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 0, 1>,
+config.block_count, config.thread_per_block,
+0, d.stream(), config.virtual_thread_count,
+in.data(), combined_dims, out.data()));
} else {
// TODO(ezhulenev): Set error status in OpKernelContext instead.
@@ -510,7 +516,7 @@ struct PadInput<GPUDevice, T, int, NDIMS> {
const std::array<int, NDIMS - 2>& padding_right,
typename TTypes<T, NDIMS, int>::Tensor out,
TensorFormat format) {
-GpuLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
+GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
Dimension<NDIMS> input_dims;
for (int i = 0; i < NDIMS; ++i) {
input_dims[i] = in.dimension(i);
@@ -523,12 +529,12 @@ struct PadInput<GPUDevice, T, int, NDIMS> {
const Dimension<NDIMS - 2> padding_left_dim(padding_left);
if (format == FORMAT_NHWC) {
-TF_CHECK_OK(CudaLaunchKernel(
+TF_CHECK_OK(GpuLaunchKernel(
PadInputCustomKernelNHWC<T, NDIMS>, config.block_count,
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
in.data(), input_dims, out.data(), output_dims, padding_left_dim));
} else if (format == FORMAT_NCHW) {
-TF_CHECK_OK(CudaLaunchKernel(
+TF_CHECK_OK(GpuLaunchKernel(
PadInputCustomKernelNCHW<T, NDIMS>, config.block_count,
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
in.data(), input_dims, out.data(), output_dims, padding_left_dim));
@@ -640,13 +646,13 @@ void LaunchBatchNarrowMatrixTransposeKernel(
const T* input, const Dimension<3>& input_dims, T* output) {
constexpr int NumThreads = TileLongSide;
if (tile_size_i <= TileLongSide && tile_size_j <= TileShortSide) {
-TF_CHECK_OK(CudaLaunchKernel(
+TF_CHECK_OK(GpuLaunchKernel(
SwapDimension1And2InTensor3UsingTiles<T, NumThreads, TileLongSide,
TileShortSide>,
total_tiles_count, NumThreads, 0, d.stream(), input, input_dims,
output));
} else {
-TF_CHECK_OK(CudaLaunchKernel(
+TF_CHECK_OK(GpuLaunchKernel(
SwapDimension1And2InTensor3UsingTiles<T, NumThreads, TileShortSide,
TileLongSide>,
total_tiles_count, NumThreads, 0, d.stream(), input, input_dims,
@@ -951,8 +957,7 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
input_dims_in_tiles[2];
-TF_CHECK_OK(CudaLaunchKernel(
+TF_CHECK_OK(GpuLaunchKernel(
SwapDimension1And2InTensor3UsingTiles<T, kNumThreads, kTileSize,
kTileSize, conjugate>,
total_tiles_count, kNumThreads, 0, d.stream(), input, input_dims,
@@ -963,11 +968,11 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
d, input, input_dims, output, kMinDimensionToUseTiles);
} else {
int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
-GpuLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
-TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 0, 2, 1, conjugate>,
-config.block_count, config.thread_per_block, 0,
-d.stream(), config.virtual_thread_count, input,
-input_dims, output));
+GpuLaunchConfig config = GetGpuLaunchConfig(total_element_count, d);
+TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 0, 2, 1, conjugate>,
+config.block_count, config.thread_per_block, 0,
+d.stream(), config.virtual_thread_count, input,
+input_dims, output));
}
}
@@ -996,11 +1001,11 @@ struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> {
static_cast<int>(combined_dims[1]),
static_cast<int>(combined_dims[2])};
size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
-GpuLaunchConfig config = GetCudaLaunchConfig(total_size, d);
-TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>,
-config.block_count, config.thread_per_block, 0,
-d.stream(), config.virtual_thread_count, in,
-input_dims, out));
+GpuLaunchConfig config = GetGpuLaunchConfig(total_size, d);
+TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>,
+config.block_count, config.thread_per_block, 0,
+d.stream(), config.virtual_thread_count, in,
+input_dims, out));
}
};
@@ -1043,6 +1048,6 @@ struct NCHWToNHWC<GPUDevice, T, NDIMS> {
} // namespace functor
} // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_
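The one place in this header that is not a pure rename is the shared-memory tile in SwapDimension1And2InTensor3UsingTiles: under CUDA the tile stays a raw, suitably aligned char buffer reinterpreted as a 2D array (so no constructor of T is ever invoked), while under ROCm it is declared directly as a __shared__ T array. A reduced sketch of that split, with a hypothetical DeclareTile helper standing in for the kernel body:

// Reduced sketch; DeclareTile is illustrative, and TileSizeI/TileSizeJ stand
// in for the kernel's template parameters.
template <typename T, int TileSizeI, int TileSizeJ>
__device__ void DeclareTile() {
#if GOOGLE_CUDA
  // Raw byte storage reinterpreted as a 2D tile, so T's constructor never runs.
  __shared__ __align__(alignof(T))
      char shared_mem_raw[TileSizeI * (TileSizeJ + 1) * sizeof(T)];
  typedef T (*SharedMemoryTile)[TileSizeJ + 1];
  SharedMemoryTile shared_memory_tile =
      reinterpret_cast<SharedMemoryTile>(shared_mem_raw);
#elif TENSORFLOW_USE_ROCM
  // HIP takes the direct declaration; the extra +1 column in the inner
  // dimension still serves to avoid shared-memory bank conflicts.
  __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
#endif
  (void)shared_memory_tile;  // The real kernel fills and reads the tile here.
}

Either way, the code that follows indexes shared_memory_tile[i][j] identically, which is why the rest of the transpose kernel needs no per-backend changes.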

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@@ -47,4 +47,4 @@ template struct PadInput<Eigen::GpuDevice, double, int, 5>;
} // namespace functor
} // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@@ -60,4 +60,4 @@ template struct PadInput<Eigen::GpuDevice, float, int, 5>;
} // namespace functor
} // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@@ -54,4 +54,4 @@ template struct PadInput<Eigen::GpuDevice, Eigen::half, int, 5>;
} // namespace functor
} // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@@ -35,4 +35,4 @@ template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, uint16>;
} // namespace functor
} // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@@ -35,4 +35,4 @@ template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, uint32>;
} // namespace functor
} // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@@ -35,4 +35,4 @@ template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, uint64>;
} // namespace functor
} // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@@ -35,4 +35,4 @@ template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, uint8>;
} // namespace functor
} // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
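The remaining seven files are the per-type explicit-instantiation translation units (double, float, Eigen::half, uint8, uint16, uint32, uint64); each only needs its preprocessor guard widened so the same instantiations are also compiled for ROCm. The rough shape of each file after the change, assuming the conv_2d_gpu.h include and showing only the uint8 instantiation visible in the last hunk:

// Approximate shape of each per-type .cu.cc after this commit; only the
// guard changes, the instantiation list itself is untouched.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/conv_2d_gpu.h"

namespace tensorflow {
namespace functor {

template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, uint8>;

}  // namespace functor
}  // namespace tensorflow
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM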