Added __restrict__ to pointer parameters of all eligible __global__ kernels

ThisIsIsaac 2019-08-01 15:25:07 +09:00
parent 415771767b
commit 6387cf7d1d
44 changed files with 201 additions and 201 deletions
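For context, a minimal, self-contained sketch of the pattern this commit applies across the kernels below (not taken from the commit itself; ScaleKernel and its buffer names are illustrative): each pointer parameter of a __global__ kernel is qualified with __restrict__, telling the compiler that the buffers do not alias.

// scale_example.cu -- illustrative only; mirrors the parameter style added in
// this commit, with __restrict__ on every pointer argument of the kernel.
#include <cstdio>
#include <cuda_runtime.h>

// The caller guarantees that input, scale, and output never overlap, so the
// compiler may keep loaded values in registers and route the const loads
// through the read-only data cache.
__global__ void ScaleKernel(const int n, const float* __restrict__ input,
                            const float* __restrict__ scale,
                            float* __restrict__ output) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    output[i] = input[i] * scale[0];
  }
}

int main() {
  const int n = 1 << 20;
  float *in = nullptr, *sc = nullptr, *out = nullptr;
  cudaMalloc(&in, n * sizeof(float));
  cudaMalloc(&sc, sizeof(float));
  cudaMalloc(&out, n * sizeof(float));
  cudaMemset(in, 0, n * sizeof(float));
  cudaMemset(sc, 0, sizeof(float));
  // Three distinct allocations, so the no-aliasing promise made by
  // __restrict__ holds for this launch.
  ScaleKernel<<<256, 256>>>(n, in, sc, out);
  cudaDeviceSynchronize();
  printf("last CUDA error: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(in);
  cudaFree(sc);
  cudaFree(out);
  return 0;
}

Note that __restrict__ is an unchecked promise: if two arguments do alias at run time, the kernel's behavior is undefined. When the promise holds, the compiler can cache reads more aggressively and reorder loads and stores, which is the motivation for the parameter changes in the diff below.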

View File

@ -94,9 +94,9 @@ inline __device__ RgbTuple hsv2rgb_cuda(const float h, const float s,
template <bool AdjustHue, bool AdjustSaturation, bool AdjustV, typename T>
__global__ void adjust_hsv_nhwc(const int64 number_elements,
const T* const __restrict__ input,
T* const output, const float* const hue_delta,
const float* const saturation_scale,
const float* const value_scale) {
T* const __restrict__ output, const float* const __restrict__ hue_delta,
const float* const __restrict__ saturation_scale,
const float* const __restrict__ value_scale) {
// multiply by 3 since we're dealing with contiguous RGB bytes for each pixel
// (NHWC)
for (int64 idx = (blockDim.x * blockIdx.x + threadIdx.x) * 3;

View File

@ -41,13 +41,13 @@ DEFINE_GPU_KERNELS(double)
template <typename dtype>
__global__ void AvePoolBackwardNHWC(const int nthreads,
const dtype* const top_diff, const int num,
const dtype* const __restrict__ top_diff, const int num,
const int height, const int width,
const int channels, const int pooled_height,
const int pooled_width, const int kernel_h,
const int kernel_w, const int stride_h,
const int stride_w, const int pad_t,
const int pad_l, dtype* const bottom_diff) {
const int pad_l, dtype* const __restrict__ bottom_diff) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
// find out the local index
// find out the local offset

View File

@ -51,8 +51,8 @@ struct AccumulatorType<Eigen::half> {
// Definition of the GPU implementations declared in bias_op.cc.
template <typename T>
__global__ void BiasNHWCKernel(int32 nthreads, const T* input, const T* bias,
T* output, int32 bias_size) {
__global__ void BiasNHWCKernel(int32 nthreads, const T* __restrict__ input, const T* __restrict__ bias,
T* __restrict__ output, int32 bias_size) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
int32 bias_offset = index % bias_size;
output[index] = ldg(input + index) + ldg(bias + bias_offset);
@ -60,8 +60,8 @@ __global__ void BiasNHWCKernel(int32 nthreads, const T* input, const T* bias,
}
template <typename T>
__global__ void BiasNCHWKernel(int32 nthreads, const T* input, const T* bias,
T* output, int32 bias_size, int32 image_size) {
__global__ void BiasNCHWKernel(int32 nthreads, const T* __restrict__ input, const T* __restrict__ bias,
T* __restrict__ output, int32 bias_size, int32 image_size) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
int32 index2 = index / image_size;
int32 bias_offset = index2 % bias_size;
@ -97,8 +97,8 @@ void BiasGPU<T>::compute(const GPUDevice& d, const T* input, const T* bias,
// A naive implementation that is functional on all cases.
template <typename T>
__global__ void BiasGradNHWC_Naive(int32 nthreads, const T* output_backprop,
T* bias_backprop, int32 bias_size) {
__global__ void BiasGradNHWC_Naive(int32 nthreads, const T* __restrict__ output_backprop,
T* __restrict__ bias_backprop, int32 bias_size) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
int32 bias_offset = index % bias_size;
GpuAtomicAdd(bias_backprop + bias_offset, ldg(output_backprop + index));
@ -107,8 +107,8 @@ __global__ void BiasGradNHWC_Naive(int32 nthreads, const T* output_backprop,
// A naive implementation that is functional on all cases.
template <typename T>
__global__ void BiasGradNCHW_Naive(int32 nthreads, const T* output_backprop,
T* bias_backprop, int32 bias_size,
__global__ void BiasGradNCHW_Naive(int32 nthreads, const T* __restrict__ output_backprop,
T* __restrict__ bias_backprop, int32 bias_size,
int32 image_size) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
int32 index2 = index / image_size;
@ -120,8 +120,8 @@ __global__ void BiasGradNCHW_Naive(int32 nthreads, const T* output_backprop,
template <typename T>
__global__ void BiasGradNHWC_SharedAtomics(int32 nthreads,
const T* output_backprop,
T* bias_backprop, int32 bias_size) {
const T* __restrict__ output_backprop,
T* __restrict__ bias_backprop, int32 bias_size) {
typedef typename AccumulatorType<T>::type AccT;
GPU_DYNAMIC_SHARED_MEM_DECL(8, char, s_buf);
AccT* s_data = reinterpret_cast<AccT*>(s_buf);
@ -143,8 +143,8 @@ __global__ void BiasGradNHWC_SharedAtomics(int32 nthreads,
}
template <typename T>
__global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
T* bias_backprop, int32 batch,
__global__ void BiasGradNCHW_SharedAtomics(const T* __restrict__ output_backprop,
T* __restrict__ bias_backprop, int32 batch,
int32 bias_size, int32 image_size,
int group_size) {
// Initialize the shared memory.

View File

@ -34,8 +34,8 @@ typedef Eigen::GpuDevice GPUDevice;
template <typename T, bool useSharedMem>
__global__ void BucketizeCustomKernel(
const int32 size_in, const T* in, const int32 size_boundaries,
GpuDeviceArrayStruct<float> boundaries_array, int32* out) {
const int32 size_in, const T* __restrict__ in, const int32 size_boundaries,
GpuDeviceArrayStruct<float> boundaries_array, int32* __restrict__ out) {
const float* boundaries = GetGpuDeviceArrayOnDevice(&boundaries_array);
GPU_DYNAMIC_SHARED_MEM_DECL(sizeof(float), unsigned char, shared_mem);

View File

@ -36,7 +36,7 @@ typedef Eigen::GpuDevice GPUDevice;
// A Cuda kernel to check if each element is Inf or Nan. If any exists, the
// relevant elements in abnormal_detected will be set
template <typename T>
__global__ void CheckNumericsKernel(const T* data, int size,
__global__ void CheckNumericsKernel(const T* __restrict__ data, int size,
int abnormal_detected[2]) {
const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const int32 total_thread_count = gridDim.x * blockDim.x;

View File

@ -33,7 +33,7 @@ namespace functor {
template <typename T>
__global__ void CompareAndBitpackKernel(const int size, const T* threshold,
const T* input, uint8* output) {
const T* __restrict__ input, uint8* __restrict__ output) {
// TODO(ebrevdo): Erich said: to get a better memory access pattern
// you could have 8 threads load this data and do a comparison, then
// use the ballot instruction to combine the values from each thread
@ -55,9 +55,9 @@ __global__ void CompareAndBitpackKernel(const int size, const T* threshold,
template <>
__global__ void CompareAndBitpackKernel<bool>(const int size,
const bool* threshold,
const bool* input,
uint8* output) {
const bool* __restrict__ threshold,
const bool* __restrict__ input,
uint8* __restrict__ output) {
// TODO(ebrevdo): Erich said: I think you could again have multiple
// threads work on one block and use the ballot instruction to the
// bit packing in one instruction.
@ -77,9 +77,9 @@ __global__ void CompareAndBitpackKernel<bool>(const int size,
template <>
__global__ void CompareAndBitpackKernel<float>(const int size,
const float* threshold,
const float* input,
uint8* output) {
const float* __restrict__ threshold,
const float* __restrict__ input,
uint8* __restrict__ output) {
const float thresh = ldg(threshold);
GPU_1D_KERNEL_LOOP(i, size) {
const float4 block0 = ldg(reinterpret_cast<const float4*>(input + 8 * i));
@ -94,9 +94,9 @@ __global__ void CompareAndBitpackKernel<float>(const int size,
template <>
__global__ void CompareAndBitpackKernel<double>(const int size,
const double* threshold,
const double* input,
uint8* output) {
const double* __restrict__ threshold,
const double* __restrict__ input,
uint8* __restrict__ output) {
const double thresh = ldg(threshold);
GPU_1D_KERNEL_LOOP(i, size) {
const double2 block0 = ldg(reinterpret_cast<const double2*>(input + 8 * i));

View File

@ -35,7 +35,7 @@ namespace {
template <typename T, typename IntType>
__global__ void concat_fixed_kernel(
GpuDeviceArrayStruct<const T*> input_ptr_data, int split_size,
GpuDeviceArrayStruct<const T*> __restrict__ input_ptr_data, int split_size,
int total_rows, int total_cols, T* output) {
const T** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptr_data);
IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;
@ -59,7 +59,7 @@ __global__ void concat_fixed_kernel(
// cannot be in anonymous namespace due to extern shared memory
template <typename T, typename IntType, bool useSmem>
__global__ void concat_variable_kernel(
GpuDeviceArrayStruct<const T*> input_ptr_data,
GpuDeviceArrayStruct<const T*> __restrict__ input_ptr_data,
GpuDeviceArrayStruct<IntType> output_scan, IntType total_rows,
IntType total_cols, T* output) {
const T** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptr_data);

View File

@ -182,8 +182,8 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index<IndexCount> FlatToTensorIndex(
// Requires that nthreads is equal to the total number of elements in the input
// tensor.
template <typename T, int sp0, int sp1, int sp2, bool conjugate = false>
__global__ void ShuffleInTensor3Simple(int nthreads, const T* input,
Dimension<3> input_dims, T* output) {
__global__ void ShuffleInTensor3Simple(int nthreads, const T* __restrict__ input,
Dimension<3> input_dims, T* __restrict__ output) {
Dimension<3> output_dims;
output_dims[sp0] = input_dims[0];
output_dims[sp1] = input_dims[1];
@ -366,8 +366,8 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(
// A Gpu custom kernel that convert input to output, given proper padding on
// the left and the top. The padded value is zero.
template <typename T, int NDIMS>
__global__ void PadInputCustomKernelNHWC(int nthreads, const T* input,
Dimension<NDIMS> input_dims, T* output,
__global__ void PadInputCustomKernelNHWC(int nthreads, const T* __restrict__ input,
Dimension<NDIMS> input_dims, T* __restrict__ output,
Dimension<NDIMS> output_dims,
Dimension<NDIMS - 2> padding_left) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
@ -395,8 +395,8 @@ __global__ void PadInputCustomKernelNHWC(int nthreads, const T* input,
}
template <typename T, int NDIMS>
__global__ void PadInputCustomKernelNCHW(int nthreads, const T* input,
Dimension<NDIMS> input_dims, T* output,
__global__ void PadInputCustomKernelNCHW(int nthreads, const T* __restrict__ input,
Dimension<NDIMS> input_dims, T* __restrict__ output,
Dimension<NDIMS> output_dims,
Dimension<NDIMS - 2> padding_left) {
GPU_1D_KERNEL_LOOP(index, nthreads) {

View File

@ -38,10 +38,10 @@ enum InterpolationMethod {
template <typename T>
__global__ void CropAndResizeKernel(
const int32 nthreads, const T* image_ptr, const float* boxes_ptr,
const int32* box_ind_ptr, int num_boxes, int batch, int image_height,
const int32 nthreads, const T* __restrict__ image_ptr, const float* __restrict__ boxes_ptr,
const int32* __restrict__ box_ind_ptr, int num_boxes, int batch, int image_height,
int image_width, int crop_height, int crop_width, int depth, int method_id,
float extrapolation_value, float* crops_ptr) {
float extrapolation_value, float* __restrict__ crops_ptr) {
GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = d + depth * (w + crop_width * (h + crop_height * b))
int idx = out_idx;
@ -130,10 +130,10 @@ __global__ void CropAndResizeKernel(
template <typename T>
__global__ void CropAndResizeBackpropImageKernel(
const int32 nthreads, const float* grads_ptr, const float* boxes_ptr,
const int32* box_ind_ptr, int num_boxes, int batch, int image_height,
const int32 nthreads, const float* __restrict__ grads_ptr, const float* __restrict__ boxes_ptr,
const int32* __restrict__ box_ind_ptr, int num_boxes, int batch, int image_height,
int image_width, int crop_height, int crop_width, int depth,
T* grads_image_ptr, int method_id) {
T* __restrict__ grads_image_ptr, int method_id) {
GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = d + depth * (w + crop_width * (h + crop_height * b))
int idx = out_idx;
@ -225,10 +225,10 @@ __global__ void CropAndResizeBackpropImageKernel(
template <typename T>
__global__ void CropAndResizeBackpropBoxesKernel(
const int32 nthreads, const float* grads_ptr, const T* image_ptr,
const float* boxes_ptr, const int32* box_ind_ptr, int num_boxes, int batch,
const int32 nthreads, const float* __restrict__ grads_ptr, const T* __restrict__ image_ptr,
const float* __restrict__ boxes_ptr, const int32* __restrict__ box_ind_ptr, int num_boxes, int batch,
int image_height, int image_width, int crop_height, int crop_width,
int depth, float* grads_boxes_ptr) {
int depth, float* __restrict__ grads_boxes_ptr) {
GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = d + depth * (w + crop_width * (h + crop_height * b))
int idx = out_idx;

View File

@ -24,8 +24,8 @@ limitations under the License.
namespace tensorflow {
template <typename T>
__global__ void UnaryClipCustomKernel(const int32 size_in, const T *in0,
const T *in1, const T *in2, T *out) {
__global__ void UnaryClipCustomKernel(const int32 size_in, const T* __restrict__ in0,
const T* __restrict__ in1, const T* __restrict__ in2, T* __restrict__ out) {
GPU_1D_KERNEL_LOOP(i, size_in) {
T value = in2[0] < in0[i] ? in2[0] : in0[i];
out[i] = value < in1[0] ? in1[0] : value;
@ -33,9 +33,9 @@ __global__ void UnaryClipCustomKernel(const int32 size_in, const T *in0,
}
template <typename T>
__global__ void BinaryRightClipCustomKernel(const int32 size_in, const T *in0,
const T *in1, const T *in2,
T *out) {
__global__ void BinaryRightClipCustomKernel(const int32 size_in, const T* __restrict__ in0,
const T* __restrict__ in1, const T* __restrict__ in2,
T* __restrict__ out) {
GPU_1D_KERNEL_LOOP(i, size_in) {
T value = in2[i] < in0[i] ? in2[i] : in0[i];
out[i] = value < in1[0] ? in1[0] : value;
@ -43,8 +43,8 @@ __global__ void BinaryRightClipCustomKernel(const int32 size_in, const T *in0,
}
template <typename T>
__global__ void BinaryLeftClipCustomKernel(const int32 size_in, const T *in0,
const T *in1, const T *in2, T *out) {
__global__ void BinaryLeftClipCustomKernel(const int32 size_in, const T* __restrict__ in0,
const T* __restrict__ in1, const T* __restrict__ in2, T* __restrict__ out) {
GPU_1D_KERNEL_LOOP(i, size_in) {
T value = in2[0] < in0[i] ? in2[0] : in0[i];
out[i] = value < in1[i] ? in1[i] : value;

View File

@ -83,8 +83,8 @@ enum DepthwiseConv2dDirection { DIRECTION_FORWARD, DIRECTION_BACKWARD };
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
__global__ void __launch_bounds__(1024, 2)
DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args, const T* input,
const T* filter, T* output, int num_outputs) {
DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args, const T* __restrict__ input,
const T* __restrict__ filter, T* __restrict__ output, int num_outputs) {
typedef typename detail::PseudoHalfType<T>::Type S;
const int in_height = args.in_rows;
const int in_width = args.in_cols;
@ -187,7 +187,7 @@ template <typename T, DepthwiseConv2dDirection kDirection,
int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth,
bool kKnownEvenHeight>
__global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
const DepthwiseArgs args, const T* input, const T* filter, T* output) {
const DepthwiseArgs args, const T* __restrict__ input, const T* __restrict__ filter, T* __restrict__ output) {
typedef typename detail::PseudoHalfType<T>::Type S;
assert(CanLaunchDepthwiseConv2dGPUSmall(args));
// Holds block plus halo and filter data for blockDim.x depths.
@ -327,8 +327,8 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
__global__ void __launch_bounds__(1024, 2)
DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args, const T* input,
const T* filter, T* output, int num_outputs) {
DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args, const T* __restrict__ input,
const T* __restrict__ filter, T* __restrict__ output, int num_outputs) {
typedef typename detail::PseudoHalfType<T>::Type S;
const int in_height = args.in_rows;
const int in_width = args.in_cols;
@ -475,7 +475,7 @@ template <typename T, DepthwiseConv2dDirection kDirection,
int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth,
bool kKnownEvenHeight>
__global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
const DepthwiseArgs args, const T* input, const T* filter, T* output) {
const DepthwiseArgs args, const T* __restrict__ input, const T* __restrict__ filter, T* __restrict__ output) {
typedef typename detail::PseudoHalfType<T>::Type S;
assert(CanLaunchDepthwiseConv2dGPUSmall(args));
// Holds block plus halo and filter data for blockDim.z depths.
@ -799,8 +799,8 @@ template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
__global__ void __launch_bounds__(640, 2)
DepthwiseConv2dBackpropInputGPUKernelNHWC(const DepthwiseArgs args,
const T* out_backprop,
const T* filter, T* in_backprop,
const T* __restrict__ out_backprop,
const T* __restrict__ filter, T* __restrict__ in_backprop,
int num_in_backprop) {
const int in_height = args.in_rows;
const int in_width = args.in_cols;
@ -869,8 +869,8 @@ template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
__global__ void __launch_bounds__(640, 2)
DepthwiseConv2dBackpropInputGPUKernelNCHW(const DepthwiseArgs args,
const T* out_backprop,
const T* filter, T* in_backprop,
const T* __restrict__ out_backprop,
const T* __restrict__ filter, T* __restrict__ in_backprop,
int num_in_backprop) {
const int in_height = args.in_rows;
const int in_width = args.in_cols;
@ -1020,9 +1020,9 @@ template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
__global__ void __launch_bounds__(640, 2)
DepthwiseConv2dBackpropFilterGPUKernelNHWC(const DepthwiseArgs args,
const T* out_backprop,
const T* input,
T* filter_backprop,
const T* __restrict__ out_backprop,
const T* __restrict__ input,
T* __restrict__ filter_backprop,
int num_out_backprop) {
const int in_height = args.in_rows;
const int in_width = args.in_cols;
@ -1153,7 +1153,7 @@ template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kBlockDepth, int kAccumPixels>
__global__
__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
const DepthwiseArgs args, const T* output, const T* input, T* filter) {
const DepthwiseArgs args, const T* __restrict__ output, const T* __restrict__ input, T* __restrict__ filter) {
typedef typename detail::PseudoHalfType<T>::Type S;
assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.z));
// Holds block plus halo and filter data for blockDim.x depths.
@ -1305,9 +1305,9 @@ template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
__global__ void __launch_bounds__(640, 2)
DepthwiseConv2dBackpropFilterGPUKernelNCHW(const DepthwiseArgs args,
const T* out_backprop,
const T* input,
T* filter_backprop,
const T* __restrict__ out_backprop,
const T* __restrict__ input,
T* __restrict__ filter_backprop,
int num_out_backprop) {
const int in_height = args.in_rows;
const int in_width = args.in_cols;
@ -1426,7 +1426,7 @@ template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kBlockDepth, int kAccumPixels>
__global__
__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
const DepthwiseArgs args, const T* output, const T* input, T* filter) {
const DepthwiseArgs args, const T* __restrict__ output, const T* __restrict__ input, T* __restrict__ filter) {
typedef typename detail::PseudoHalfType<T>::Type S;
assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.x));
// Holds block plus halo and filter data for blockDim.z depths.

View File

@ -84,10 +84,10 @@ __device__ inline complex128 operator/(const complex128& a, const double& b) {
// the sign argument is ignored.
template <typename Scalar, bool compute_log_abs_det = true>
__global__ void DeterminantFromPivotedLUKernel(int nthreads, int n,
const Scalar* lu_factor,
const int* all_pivots,
Scalar* sign,
Scalar* log_abs_det) {
const Scalar* __restrict__ lu_factor,
const int* __restrict__ all_pivots,
Scalar* __restrict__ sign,
Scalar* __restrict__ log_abs_det) {
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
const int matrix_size = n * n;
const int stride = n + 1;

View File

@ -30,7 +30,7 @@ typedef Eigen::GpuDevice GPUDevice;
template <typename T>
__global__ void DiagGpuKernel(const int num_threads, const int64 size,
const T* in, T* out) {
const T* __restrict__ in, T* __restrict__ out) {
GPU_1D_KERNEL_LOOP(index, num_threads) {
// Fill the diagonal elements or set to zero in other place.
if (index % (1 + size) == 0) {

View File

@ -35,13 +35,13 @@ typedef Eigen::GpuDevice GPUDevice;
namespace {
template <typename T>
__global__ void DilationKernel(const int32 nthreads, const T* input_ptr,
const T* filter_ptr, int batch, int input_rows,
__global__ void DilationKernel(const int32 nthreads, const T* __restrict__ input_ptr,
const T* __restrict__ filter_ptr, int batch, int input_rows,
int input_cols, int depth, int filter_rows,
int filter_cols, int output_rows,
int output_cols, int stride_rows,
int stride_cols, int rate_rows, int rate_cols,
int pad_top, int pad_left, T* output_ptr) {
int pad_top, int pad_left, T* __restrict__ output_ptr) {
GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = d + depth * (w_out + output_cols * (h_out + output_rows * b))
const int d = out_idx % depth;
@ -76,11 +76,11 @@ __global__ void DilationKernel(const int32 nthreads, const T* input_ptr,
template <typename T>
__global__ void DilationBackpropInputKernel(
const int32 nthreads, const T* input_ptr, const T* filter_ptr,
const T* out_backprop_ptr, int batch, int input_rows, int input_cols,
const int32 nthreads, const T* __restrict__ input_ptr, const T* __restrict__ filter_ptr,
const T* __restrict__ out_backprop_ptr, int batch, int input_rows, int input_cols,
int depth, int filter_rows, int filter_cols, int output_rows,
int output_cols, int stride_rows, int stride_cols, int rate_rows,
int rate_cols, int pad_top, int pad_left, T* in_backprop_ptr) {
int rate_cols, int pad_top, int pad_left, T* __restrict__ in_backprop_ptr) {
GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = d + depth * (w_out + output_cols * (h_out + output_rows * b))
const int d = out_idx % depth;
@ -125,11 +125,11 @@ __global__ void DilationBackpropInputKernel(
template <typename T>
__global__ void DilationBackpropFilterKernel(
const int32 nthreads, const T* input_ptr, const T* filter_ptr,
const T* out_backprop_ptr, int batch, int input_rows, int input_cols,
const int32 nthreads, const T* __restrict__ input_ptr, const T* __restrict__ filter_ptr,
const T* __restrict__ out_backprop_ptr, int batch, int input_rows, int input_cols,
int depth, int filter_rows, int filter_cols, int output_rows,
int output_cols, int stride_rows, int stride_cols, int rate_rows,
int rate_cols, int pad_top, int pad_left, T* filter_backprop_ptr) {
int rate_cols, int pad_top, int pad_left, T* __restrict__ filter_backprop_ptr) {
GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = d + depth * (w_out + output_cols * (h_out + output_rows * b))
const int d = out_idx % depth;

View File

@ -30,7 +30,7 @@ typedef Eigen::GpuDevice GPUDevice;
template <typename Scalar>
__global__ void EyeKernel(int num_threads, int batch_size, int m, int n,
Scalar* output_ptr) {
Scalar* __restrict__ output_ptr) {
const Scalar one = Scalar(1);
const Scalar zero = Scalar(0);
GPU_1D_KERNEL_LOOP(index, num_threads) {

View File

@ -30,7 +30,7 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
template <typename T, typename Index, bool is_axis_zero>
__global__ void GatherOpKernel(const T* params, const Index* indices, T* out,
__global__ void GatherOpKernel(const T* __restrict__ params, const Index* __restrict__ indices, T* __restrict__ out,
int64 gather_dim_size, int64 indices_size,
int64 slice_size, int64 out_size) {
GPU_1D_KERNEL_LOOP(i, out_size) {

View File

@ -28,7 +28,7 @@ typedef Eigen::GpuDevice GPUDevice;
template <typename T, typename Index, int IXDIM>
__global__ void GatherSliceOpKernel(
const T* params, const Index* indices, T* out,
const T* __restrict__ params, const Index* __restrict__ indices, T* __restrict__ out,
const Eigen::array<int64, IXDIM> batch_strides,
const Eigen::array<int64, IXDIM> batch_indices, const int64 indices_size,
const int64 slice_size, const int64 out_size) {

View File

@ -29,7 +29,7 @@ typedef Eigen::GpuDevice Device;
template <typename T>
__global__ void DoParallelConcatOpKernel(int nthreads, const int64 rows,
const int64 cols, int32 loc,
const T* src, T* dst) {
const T* __restrict__ src, T* __restrict__ dst) {
GPU_1D_KERNEL_LOOP(idx, nthreads) {
int64 c = idx % cols;
int64 r = (loc % rows + rows) % rows; // Guard index range.
@ -80,8 +80,8 @@ Status DoParallelConcat(const Device& d, const Tensor& value, int32 loc,
template <typename T, InplaceOpType op>
__global__ void DoInplaceOpKernel(int nthreads, const int64 rows,
const int64 cols, const int64 n, const T* src,
const int32* rowids, T* dst) {
const int64 cols, const int64 n, const T* __restrict__ src,
const int32* __restrict__ rowids, T* __restrict__ dst) {
GPU_1D_KERNEL_LOOP(idx, nthreads) {
int64 r = idx / cols;
int64 c = idx % cols;

View File

@ -61,8 +61,8 @@ __device__ void ComputePermutationFromTranspositions(
// transpositions.
template <typename Scalar>
__global__ void ComputePermutationFromTranspositionsKernel(
GpuLaunchConfig config, const int64 num_rows, const int* all_pivots,
Scalar* all_permutation_indices) {
GpuLaunchConfig config, const int64 num_rows, const int* __restrict__ all_pivots,
Scalar* __restrict__ all_permutation_indices) {
// We only parallelize over batches here. Performance is not critical,
// since this cheap O(num_rows) kernel always follows an O(num_rows^3)
// LU factorization.

View File

@ -33,8 +33,8 @@ __global__ void MatrixBandPartKernel(const int num_threads,
const int batch_size, const int m,
const int n, const int num_lower_diags,
const int num_upper_diags,
const Scalar* input_ptr,
Scalar* output_ptr) {
const Scalar* __restrict__ input_ptr,
Scalar* __restrict__ output_ptr) {
GPU_1D_KERNEL_LOOP(index, num_threads) {
const int col = index % n;
const int row = (index / n) % m;

View File

@ -31,8 +31,8 @@ __global__ void MatrixSetDiagKernel(const int num_threads, const int m,
const int n, const int num_diags,
const int max_diag_len,
const int upper_diag_index,
const Scalar* diag_ptr,
Scalar* output_ptr) {
const Scalar* __restrict__ diag_ptr,
Scalar* __restrict__ output_ptr) {
GPU_1D_KERNEL_LOOP(index, num_threads) {
const int batch_and_diag_index = index / max_diag_len;
const int index_in_the_diagonal =
@ -56,8 +56,8 @@ template <typename Scalar>
__global__ void MatrixCopyInputAndSetDiagKernel(
const int num_threads, const int m, const int n, const int num_diags,
const int max_diag_len, const int lower_diag_index,
const int upper_diag_index, const Scalar* input_ptr, const Scalar* diag_ptr,
Scalar* output_ptr) {
const int upper_diag_index, const Scalar* __restrict__ input_ptr, const Scalar* __restrict__ diag_ptr,
Scalar* __restrict__ output_ptr) {
GPU_1D_KERNEL_LOOP(index, num_threads) {
const int batch_and_row_index = index / n;
const int col = index - batch_and_row_index * n;

View File

@ -65,11 +65,11 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool IsGreaterThan(dtype a, dtype b) {
// kThreadsPerBlock, 0, cuda_stream>>>(...);
template <bool propagate_nans, typename dtype>
__global__ void MaxPoolForwardNCHW(
const int nthreads, const dtype* bottom_data, const int channels,
const int nthreads, const dtype* __restrict__ bottom_data, const int channels,
const int height, const int width, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
dtype* top_data, int64* mask, const bool include_batch_in_index) {
dtype* __restrict__ top_data, int64* __restrict__ mask, const bool include_batch_in_index) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
@ -108,11 +108,11 @@ __global__ void MaxPoolForwardNCHW(
// the same X, y coordinate.
// (so channels = outer_channels, output_size = real output size / 4).
__global__ void MaxPoolForwardNoMaskKernel_NCHW_VECT_C(
const int nthreads, const int32* bottom_data, const int height,
const int nthreads, const int32* __restrict__ bottom_data, const int height,
const int width, const int channels, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
int32* top_data) {
int32* __restrict__ top_data) {
// TODO(pauldonnelly): Implement a better optimized version of this kernel.
const int32 kMinINT8X4 = 0x80808080;
GPU_1D_KERNEL_LOOP(index, nthreads) {
@ -141,11 +141,11 @@ __global__ void MaxPoolForwardNoMaskKernel_NCHW_VECT_C(
template <bool propagate_nans, typename dtype>
__global__ void MaxPoolForwardNHWC(
const int nthreads, const dtype* bottom_data, const int height,
const int nthreads, const dtype* __restrict__ bottom_data, const int height,
const int width, const int channels, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
dtype* top_data, int64* mask, const bool include_batch_in_index) {
dtype* __restrict__ top_data, int64* __restrict__ mask, const bool include_batch_in_index) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
int n = index;
int c = n % channels;
@ -180,11 +180,11 @@ __global__ void MaxPoolForwardNHWC(
template <typename dtype>
__global__ void MaxPoolBackwardNoMaskNHWC(
const int nthreads, const dtype* bottom_data, const int height,
const int nthreads, const dtype* __restrict__ bottom_data, const int height,
const int width, const int channels, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
const dtype* top_diff, dtype* bottom_diff) {
const dtype* __restrict__ top_diff, dtype* __restrict__ bottom_diff) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
// First find out the index to the maximum, since we have no mask.
int n = index;
@ -240,9 +240,9 @@ __global__ void MaxPoolBackwardNoMaskNHWC(
// the kernel is run, you will need to make sure that bottom_diff is filled with
// zero first.
template <typename dtype>
__global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff,
const int64* mask, const int top_offset,
const int bottom_offset, dtype* bottom_diff,
__global__ void MaxPoolBackward(const int nthreads, const dtype* __restrict__ top_diff,
const int64* __restrict__ mask, const int top_offset,
const int bottom_offset, dtype* __restrict__ bottom_diff,
const bool include_batch_in_index) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
const int offset =
@ -267,11 +267,11 @@ __global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff,
// bottom_diff: the gradient of the gradient w.r.t. output.
template <typename dtype>
__global__ void MaxPoolGradBackwardNoMaskNCHW(
const int nthreads, const dtype* bottom_data, const dtype* output_data,
const int nthreads, const dtype* __restrict__ bottom_data, const dtype* __restrict__ output_data,
const int pooled_height, const int pooled_width, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
const dtype* top_diff, dtype* bottom_diff) {
const dtype* __restrict__ top_diff, dtype* __restrict__ bottom_diff) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
// First find out the index to the maximum, since we have no mask.
int pw = index % pooled_width;
@ -307,11 +307,11 @@ __global__ void MaxPoolGradBackwardNoMaskNCHW(
template <typename dtype>
__global__ void MaxPoolGradBackwardNoMaskNHWC(
const int nthreads, const dtype* bottom_data, const dtype* output_data,
const int nthreads, const dtype* __restrict__ bottom_data, const dtype* __restrict__ output_data,
const int pooled_height, const int pooled_width, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
const dtype* top_diff, dtype* bottom_diff) {
const dtype* __restrict__ top_diff, dtype* __restrict__ bottom_diff) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
// First find out the index to the maximum, since we have no mask.
int n = index;
@ -367,9 +367,9 @@ __global__ void MaxPoolGradBackwardNoMaskNHWC(
// include_batch_in_index: whether to include batch dimension in flattened
// index of `argmax`.
template <typename dtype>
__global__ void MaxPoolGradBackward(const int nthreads, const dtype* top_diff,
const int64* mask, const int top_offset,
const int bottom_offset, dtype* bottom_diff,
__global__ void MaxPoolGradBackward(const int nthreads, const dtype* __restrict__ top_diff,
const int64* __restrict__ mask, const int top_offset,
const int bottom_offset, dtype* __restrict__ bottom_diff,
const bool include_batch_in_index) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
const int offset =

View File

@ -45,8 +45,8 @@ using GPUDevice = Eigen::GpuDevice;
// scores: [B, S, C]; maxima: [B, S]; output: [B, S].
template <typename OutputType>
__global__ void MultinomialKernel(int32 nthreads, const int32 num_classes,
const int32 num_samples, const float* scores,
const float* maxima, OutputType* output) {
const int32 num_samples, const float* __restrict__ scores,
const float* __restrict__ maxima, OutputType* __restrict__ output) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
const int maxima_idx = index / num_classes;
if (ldg(maxima + maxima_idx) == ldg(scores + index)) {

View File

@ -130,9 +130,9 @@ __device__ EIGEN_STRONG_INLINE void Flipped<true>(Box& box) {
// x1<x2 and y1<y2.
template <bool flip_box>
__launch_bounds__(kNmsBlockDim* kNmsBlockDim, 4) __global__
void NMSKernel(const Box* d_desc_sorted_boxes, const int num_boxes,
void NMSKernel(const Box* __restrict__ d_desc_sorted_boxes, const int num_boxes,
const float iou_threshold, const int bit_mask_len,
int* d_delete_mask) {
int* __restrict__ d_delete_mask) {
// Storing boxes used by this CUDA block in the shared memory.
__shared__ Box shared_i_boxes[kNmsBlockDim];
// Same thing with areas
@ -208,15 +208,15 @@ __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected,
// IndexMultiSelect(num_elements, indices, original1 ,selected1, original2,
// selected2).
template <typename Index, typename T, typename... Args>
__global__ void IndexMultiSelect(const int num_elements, const Index* indices,
const T* original, T* selected, Args... args) {
__global__ void IndexMultiSelect(const int num_elements, const Index* __restrict__ indices,
const T* __restrict__ original, T* __restrict__ selected, Args... args) {
for (const int idx : CudaGridRangeX(num_elements)) {
SelectHelper(idx, indices[idx], original, selected, args...);
}
}
template <typename T>
__global__ void Iota(const int num_elements, const T offset, T* to_fill) {
__global__ void Iota(const int num_elements, const T offset, T* __restrict__ to_fill) {
for (int idx : CudaGridRangeX(num_elements)) {
to_fill[idx] = static_cast<T>(idx) + offset;
}

View File

@ -53,9 +53,9 @@ template <typename T>
__global__ void __launch_bounds__(1024)
TruncatedNormalKernel(random::PhiloxRandom gen, T* data, int64 num_batches,
int64 samples_per_batch, int64 num_elements,
const T* means, bool single_mean, const T* stddevs,
bool single_stddev, const T* minvals,
bool single_minval, const T* maxvals,
const T* __restrict__ means, bool single_mean, const T* stddevs,
bool single_stddev, const T* __restrict__ minvals,
bool single_minval, const T* __restrict__ maxvals,
bool single_maxval, int64 kMaxIterations) {
const int32 max_samples_per_item = 2 * kMaxIterations;
// Initial offset as given by GPU_1D_KERNEL_LOOP.

View File

@ -28,13 +28,13 @@ namespace {
template <typename dtype>
__global__ void MaxPoolGradBackwardNoMaskNCDHW(
const int nthreads, const dtype* bottom_data, const dtype* output_data,
const int nthreads, const dtype* __restrict__ bottom_data, const dtype* __restrict__ output_data,
const int pooled_plane, const int pooled_height, const int pooled_width,
const int channels, const int plane, const int height, const int width,
const int kernel_p, const int kernel_h, const int kernel_w,
const int stride_p, const int stride_h, const int stride_w, const int pad_p,
const int pad_t, const int pad_l, const dtype* top_diff,
dtype* bottom_diff) {
const int pad_t, const int pad_l, const dtype* __restrict__ top_diff,
dtype* __restrict__ bottom_diff) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
// First find out the index to the maximum, since we have no mask.
int pw = index % pooled_width;
@ -78,13 +78,13 @@ __global__ void MaxPoolGradBackwardNoMaskNCDHW(
template <typename dtype>
__global__ void MaxPoolGradBackwardNoMaskNDHWC(
const int nthreads, const dtype* bottom_data, const dtype* output_data,
const int nthreads, const dtype* __restrict__ bottom_data, const dtype* __restrict__ output_data,
const int pooled_plane, const int pooled_height, const int pooled_width,
const int channels, const int plane, const int height, const int width,
const int kernel_p, const int kernel_h, const int kernel_w,
const int stride_p, const int stride_h, const int stride_w, const int pad_p,
const int pad_t, const int pad_l, const dtype* top_diff,
dtype* bottom_diff) {
const int pad_t, const int pad_l, const dtype* __restrict__ top_diff,
dtype* __restrict__ bottom_diff) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
// First find out the index to the maximum, since we have no mask.
int n = index;

View File

@ -33,14 +33,14 @@ typedef Eigen::GpuDevice GPUDevice;
namespace functor {
template <typename T>
__global__ void PopulationCountKernel(const int size, const T* input,
uint8* output) {
__global__ void PopulationCountKernel(const int size, const T* __restrict__ input,
uint8* __restrict__ output) {
GPU_1D_KERNEL_LOOP(i, size) { output[i] = __popc(ldg(input + i)); }
}
template <>
__global__ void PopulationCountKernel(const int size, const int8* input,
uint8* output) {
__global__ void PopulationCountKernel(const int size, const int8* __restrict__ input,
uint8* __restrict__ output) {
// For some reason, __popc on a negative int8 gets confused.
GPU_1D_KERNEL_LOOP(i, size) {
output[i] = __popc(ldg(reinterpret_cast<const uint8*>(input + i)));
@ -48,8 +48,8 @@ __global__ void PopulationCountKernel(const int size, const int8* input,
}
template <>
__global__ void PopulationCountKernel(const int size, const int16* input,
uint8* output) {
__global__ void PopulationCountKernel(const int size, const int16* __restrict__ input,
uint8* __restrict__ output) {
// For some reason, __popc on a negative int16 gets confused.
GPU_1D_KERNEL_LOOP(i, size) {
output[i] = __popc(ldg(reinterpret_cast<const uint16*>(input + i)));
@ -57,8 +57,8 @@ __global__ void PopulationCountKernel(const int size, const int16* input,
}
template <>
__global__ void PopulationCountKernel<int64>(const int size, const int64* input,
uint8* output) {
__global__ void PopulationCountKernel<int64>(const int size, const int64* __restrict__ input,
uint8* __restrict__ output) {
GPU_1D_KERNEL_LOOP(i, size) { output[i] = __popcll(ldg(input + i)); }
}

View File

@ -34,9 +34,9 @@ namespace functor {
// This kernel computes ReluGrad by processing one half2, two fp16, at a time.
// It effectively does: backdrops = (feature > 0) ? gradient : 0
// It also tries to use native half2 primitives as much as possible.
__global__ void ReluGradHalfKernel(const Eigen::half* gradient,
const Eigen::half* feature,
Eigen::half* backprop, int32 count) {
__global__ void ReluGradHalfKernel(const Eigen::half* __restrict__ gradient,
const Eigen::half* __restrict__ feature,
Eigen::half* __restrict__ backprop, int32 count) {
int32 half2_count = count >> 1;
int32 index = blockIdx.x * blockDim.x + threadIdx.x;
const int32 total_device_threads = gridDim.x * blockDim.x;
@ -112,8 +112,8 @@ struct ReluGrad<Device, Eigen::half> {
}
};
__global__ void Relu_int8x4_kernel(int vect_count, const int32* input,
int32* output) {
__global__ void Relu_int8x4_kernel(int vect_count, const int32* __restrict__ input,
int32* __restrict__ output) {
CUDA_1D_KERNEL_LOOP(index, vect_count) {
output[index] = __vmaxs4(input[index], 0);
}

View File

@ -113,11 +113,11 @@ __global__ void ResizeBilinearKernel_faster(
}
template <typename T>
__global__ void ResizeBilinearKernel(const int32 nthreads, const T* images,
__global__ void ResizeBilinearKernel(const int32 nthreads, const T* __restrict__ images,
float height_scale, float width_scale,
int batch, int in_height, int in_width,
int channels, int out_height,
int out_width, float* output) {
int out_width, float* __restrict__ output) {
GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = c + channels * (x + out_width * (y + out_height * b))
int idx = out_idx;
@ -166,9 +166,9 @@ __global__ void ResizeBilinearKernel(const int32 nthreads, const T* images,
template <typename T>
__global__ void ResizeBilinearGradKernel(
const int32 nthreads, const float* input_grad, float height_scale,
const int32 nthreads, const float* __restrict__ input_grad, float height_scale,
float width_scale, int batch, int original_height, int original_width,
int channels, int resized_height, int resized_width, T* output_grad) {
int channels, int resized_height, int resized_width, T* __restrict__ output_grad) {
GPU_1D_KERNEL_LOOP(in_idx, nthreads) {
// in_idx = c + channels * (x + resized_width * (y + resized_height * b))
int idx = in_idx;
@ -228,11 +228,11 @@ __global__ void ResizeBilinearGradKernel(
template <typename T>
__global__ void LegacyResizeBilinearKernel(const int32 nthreads,
const T* images, float height_scale,
const T* __restrict__ images, float height_scale,
float width_scale, int batch,
int in_height, int in_width,
int channels, int out_height,
int out_width, float* output) {
int out_width, float* __restrict__ output) {
GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = c + channels * (x + out_width * (y + out_height * b))
int idx = out_idx;
@ -280,9 +280,9 @@ __global__ void LegacyResizeBilinearKernel(const int32 nthreads,
template <typename T>
__global__ void LegacyResizeBilinearGradKernel(
const int32 nthreads, const float* input_grad, float height_scale,
const int32 nthreads, const float* __restrict__ input_grad, float height_scale,
float width_scale, int batch, int original_height, int original_width,
int channels, int resized_height, int resized_width, T* output_grad) {
int channels, int resized_height, int resized_width, T* __restrict__ output_grad) {
GPU_1D_KERNEL_LOOP(in_idx, nthreads) {
// in_idx = c + channels * (x + resized_width * (y + resized_height * b))
int idx = in_idx;

View File

@ -33,7 +33,7 @@ namespace {
template <typename T>
__global__ void ResizeNearestNeighborNHWC(
const int nthreads, const T* bottom_data, const int in_height,
const int nthreads, const T* __restrict__ bottom_data, const int in_height,
const int in_width, const int channels, const int out_height,
const int out_width, const float height_scale, const float width_scale,
T* top_data) {
@ -64,10 +64,10 @@ __global__ void ResizeNearestNeighborNHWC(
template <typename T, bool align_corners>
__global__ void LegacyResizeNearestNeighborNHWC(
const int nthreads, const T* bottom_data, const int in_height,
const int nthreads, const T* __restrict__ bottom_data, const int in_height,
const int in_width, const int channels, const int out_height,
const int out_width, const float height_scale, const float width_scale,
T* top_data) {
T* __restrict__ top_data) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
int n = index;
int c = n % channels;
@ -93,10 +93,10 @@ __global__ void LegacyResizeNearestNeighborNHWC(
template <typename T>
__global__ void ResizeNearestNeighborBackwardNHWC(
const int nthreads, const T* top_diff, const int in_height,
const int nthreads, const T* __restrict__ top_diff, const int in_height,
const int in_width, const int channels, const int out_height,
const int out_width, const float height_scale, const float width_scale,
T* bottom_diff) {
T* __restrict__ bottom_diff) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
int n = index;
int c = n % channels;
@ -124,10 +124,10 @@ __global__ void ResizeNearestNeighborBackwardNHWC(
template <typename T, bool align_corners>
__global__ void LegacyResizeNearestNeighborBackwardNHWC(
const int nthreads, const T* top_diff, const int in_height,
const int nthreads, const T* __restrict__ top_diff, const int in_height,
const int in_width, const int channels, const int out_height,
const int out_width, const float height_scale, const float width_scale,
T* bottom_diff) {
T* __restrict__ bottom_diff) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
int n = index;
int c = n % channels;

View File

@ -31,8 +31,8 @@ namespace {
template <typename T>
__global__ void RollKernel(const int32 nthreads, const int32 num_dims,
const T* input, T* output, const int32* dim_size,
const int32* threshold, const int64* dim_range) {
const T* __restrict__ input, T* __restrict__ output, const int32* __restrict__ dim_size,
const int32* __restrict__ threshold, const int64* __restrict__ dim_range) {
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
int64 offset = 0;
for (int i = 0; i < num_dims; i++) {

View File

@ -70,8 +70,8 @@ struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MAX> {
};
template <typename T, typename Index, scatter_op::UpdateOp op>
__global__ void ScatterOpCustomKernel(T* params, const T* updates,
const Index* indices,
__global__ void ScatterOpCustomKernel(T* __restrict__ params, const T* __restrict__ updates,
const Index* __restrict__ indices,
Index first_dim_size, Index updates_size,
Index indices_size) {
Index update_block = updates_size / indices_size;
@ -90,8 +90,8 @@ __global__ void ScatterOpCustomKernel(T* params, const T* updates,
}
template <typename T, typename Index, scatter_op::UpdateOp op>
__global__ void ScatterScalarOpCustomKernel(T* params, const T* update,
const Index* indices,
__global__ void ScatterScalarOpCustomKernel(T* __restrict__ params, const T* __restrict__ update,
const Index* __restrict__ indices,
Index first_dim_size,
Index indices_size,
Index synthesized_updates_size) {

View File

@ -31,9 +31,9 @@ typedef Eigen::GpuDevice GPUDevice;
namespace {
template <typename T, typename OutType>
__global__ void UpperBoundKernel(const T* sorted_inputs, int batch_size,
__global__ void UpperBoundKernel(const T* __restrict__ sorted_inputs, int batch_size,
int sorted_inputs_size, int values_size,
const T* values, OutType* outputs) {
const T* __restrict__ values, OutType* __restrict__ outputs) {
GPU_1D_KERNEL_LOOP(work_unit_id, values_size * batch_size) {
int bid = work_unit_id / values_size;
T value = values[work_unit_id];
@ -43,9 +43,9 @@ __global__ void UpperBoundKernel(const T* sorted_inputs, int batch_size,
}
template <typename T, typename OutType>
__global__ void LowerBoundKernel(const T* sorted_inputs, int batch_size,
__global__ void LowerBoundKernel(const T* __restrict__ sorted_inputs, int batch_size,
int sorted_inputs_size, int values_size,
const T* values, OutType* outputs) {
const T* __restrict__ values, OutType* __restrict__ outputs) {
GPU_1D_KERNEL_LOOP(work_unit_id, values_size * batch_size) {
int bid = work_unit_id / values_size;
T value = values[work_unit_id];

View File

@ -29,12 +29,12 @@ typedef Eigen::GpuDevice GPUDevice;
// Space2Depth kernel for FORMAT_NHWC.
// See 'spacetodepth_op.h' for a more detailed description.
template <typename dtype>
__global__ void S2D_NHWC(const int32 nthreads, const dtype* input_ptr,
__global__ void S2D_NHWC(const int32 nthreads, const dtype* __restrict__ input_ptr,
const int block_size, const int batch_size,
const int input_height, const int input_width,
const int input_depth, const int output_height,
const int output_width, const int output_depth,
dtype* output_ptr) {
dtype* __restrict__ output_ptr) {
GPU_1D_KERNEL_LOOP(inp_idx, nthreads) {
// inp_idx = d + input_depth * (w + input_width * (h + input_height * b))
const int d = inp_idx % input_depth;

View File

@ -29,9 +29,9 @@ typedef Eigen::GpuDevice GPUDevice;
template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
__global__ void SparseTensorDenseMatMulKernel(int nnz, int m, int b_rows,
int b_cols, int p,
const Tindices* a_indices,
const T* a_values, const T* b,
T* out) {
const Tindices* __restrict__ a_indices,
const T* __restrict__ a_values, const T* __restrict__ b,
T* __restrict__ out) {
// out_{ij} = sum_k {a_ik b_kj}
// out = A * B', out_{ij} = sum_k {a_ik (b')_kj}; b'_{kj} = b_{jk}
const int n = (ADJ_B) ? b_cols : b_rows;

View File

@ -74,7 +74,7 @@ TF_CALL_bfloat16(DEFINE_GPU_KERNELS);
namespace {
template <typename T>
__global__ void SplitOpKernel(const T* input, int32 prefix_dim_size,
__global__ void SplitOpKernel(const T* __restrict__ input, int32 prefix_dim_size,
int32 split_dim_size, int32 suffix_dim_size,
GpuDeviceArrayStruct<T*> output_ptr_data) {
const int32 num_split = output_ptr_data.size;
@ -112,7 +112,7 @@ __global__ void SplitOpKernel(const T* input, int32 prefix_dim_size,
// very similar to the concat kernel except the input/output logic
// is reversed
template <typename T, typename IntType, bool useSmem>
__global__ void split_v_kernel(const T* input_ptr,
__global__ void split_v_kernel(const T* __restrict__ input_ptr,
GpuDeviceArrayStruct<IntType> output_scan,
IntType total_rows, IntType total_cols,
GpuDeviceArrayStruct<T*> output_ptr_data) {
@ -169,7 +169,7 @@ __global__ void split_v_kernel(const T* input_ptr,
// different from the original split implementation due to 2D vs 3D
// dimensions. This version is likely faster due to less integer math.
template <typename T>
__global__ void SplitVOpKernel_fixed(const T* input, int32 prefix_dim_size,
__global__ void SplitVOpKernel_fixed(const T* __restrict__ input, int32 prefix_dim_size,
int32 suffix_dim_size,
GpuDeviceArrayStruct<T*> output_ptr_data) {
const int32 num_split = output_ptr_data.size;

View File

@ -35,8 +35,8 @@ __device__ int thread_counter;
template <typename Distribution>
__global__ void FillKernel(
Distribution dist, int64 state_size, int64 output_size,
StateElementType* state_data,
typename Distribution::ResultElementType* output_data) {
StateElementType* __restrict__ state_data,
typename Distribution::ResultElementType* __restrict__ output_data) {
// Threads in this block share `philox`. Thread 0 is responsible for
// initializing it.
__shared__ char philox_raw[sizeof(PhiloxRandom)];
@ -90,7 +90,7 @@ void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
}
// Precondition: there is only 1 block and 1 thread.
__global__ void SkipKernel(int64 delta, StateElementType* state_data) {
__global__ void SkipKernel(int64 delta, StateElementType* __restrict__ state_data) {
auto philox = GetPhiloxRandomFromMem(state_data);
UpdateMemWithPhiloxRandom(philox, delta, state_data);
}

View File

@ -60,9 +60,9 @@ namespace {
// real value of V (which should be computed)
template <class Scalar>
__global__ void ComputeValueOfVKernel(Gpu2DLaunchConfig config, int64 m,
int64 ldu, const Scalar* M,
const Scalar* U, const Scalar* S,
Scalar* V) {
int64 ldu, const Scalar* __restrict__ M,
const Scalar* __restrict__ U, const Scalar* __restrict__ S,
Scalar* __restrict__ V) {
GPU_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count.x, X) {
GPU_AXIS_KERNEL_LOOP(i, config.virtual_thread_count.y, Y) {
Scalar v = M[i + m * batch] * U[ldu * (i + m * batch)] * S[batch];
@ -74,7 +74,7 @@ __global__ void ComputeValueOfVKernel(Gpu2DLaunchConfig config, int64 m,
// Extracts the sign of V
// V[i] = V[i]>=0 ? 1 : 0
template <class Scalar>
__global__ void ExtractSignOfVKernel(GpuLaunchConfig config, Scalar* V) {
__global__ void ExtractSignOfVKernel(GpuLaunchConfig config, Scalar* __restrict__ V) {
GPU_1D_KERNEL_LOOP(i, config.virtual_thread_count) {
V[i] = V[i] >= 0 ? Scalar(1) : Scalar(-1);
}

View File

@ -30,8 +30,8 @@ namespace tensorflow {
namespace internal {
template <typename T>
__global__ void TileKernel(int nthreads, const T* src, const int32* buf,
const int32 ndims, T* dst) {
__global__ void TileKernel(int nthreads, const T* __restrict__ src, const int32* __restrict__ buf,
const int32 ndims, T* __restrict__ dst) {
const int32* in_strides = buf;
const int32* out_strides = buf + ndims;
const int32* in_dim_sizes = buf + ndims * 2;

View File

@ -341,8 +341,8 @@ __device__ void mergeShards(int num_shards, int k,
extern __shared__ char shared_memory[];
template <typename T>
__global__ void TopKKernel(const T* input, int length, int k, bool sorted,
T* output, int* indices) {
__global__ void TopKKernel(const T* __restrict__ input, int length, int k, bool sorted,
T* __restrict__ output, int* __restrict__ indices) {
const int batch_index = blockIdx.x;
const T* batch_input = input + batch_index * length;

View File

@ -32,8 +32,8 @@ namespace tensorflow {
namespace internal {
template <typename T, bool conjugate>
__global__ void TransposeKernel(int nthreads, const T* src, const int32* buf,
const int32 ndims, T* dst) {
__global__ void TransposeKernel(int nthreads, const T* __restrict__ src, const int32* __restrict__ buf,
const int32 ndims, T* __restrict__ dst) {
const int32* in_strides = buf;
const int32* out_strides = buf + ndims;
const int32* perm = buf + ndims * 2;

View File

@ -35,10 +35,10 @@ namespace tensorflow {
template <typename Scalar>
__global__ void TridiagonalMatMulKernel(int batch_size, int m, int n,
const Scalar* superdiag,
const Scalar* maindiag,
const Scalar* subdiag,
const Scalar* rhs, Scalar* product) {
const Scalar* __restrict__ superdiag,
const Scalar* __restrict__ maindiag,
const Scalar* __restrict__ subdiag,
const Scalar* __restrict__ rhs, Scalar* __restrict__ product) {
for (int i : CudaGridRangeX(batch_size * m * n)) {
int row_id = i / n;
Scalar result = maindiag[row_id] * rhs[i];

View File

@ -40,9 +40,9 @@ static const char kNotInvertibleScalarMsg[] =
"The matrix is not invertible: it is a scalar with value zero.";
template <typename Scalar>
__global__ void SolveForSizeOneOrTwoKernel(const int m, const Scalar* diags,
const Scalar* rhs, const int num_rhs,
Scalar* x, bool* not_invertible) {
__global__ void SolveForSizeOneOrTwoKernel(const int m, const Scalar* __restrict__ diags,
const Scalar* __restrict__ rhs, const int num_rhs,
Scalar* __restrict__ x, bool* __restrict__ not_invertible) {
if (m == 1) {
if (diags[1] == Scalar(0)) {
*not_invertible = true;

View File

@ -52,7 +52,7 @@ namespace functor {
template <int NDIM, typename TIndex>
__global__ void PropagateWhereIndicesKernel(
const TIndex output_rows, const typename Eigen::array<TIndex, NDIM> strides,
int64* output) {
int64* __restrict__ output) {
// TODO(ebrevdo): Use a multi-dimensional loop, increasing the
// dimensions of individual indices manually, instead of relying on
// a scalar loop variable and using integer division.