diff --git a/tensorflow/core/kernels/adjust_hsv_gpu.cu.h b/tensorflow/core/kernels/adjust_hsv_gpu.cu.h
index 4a43dc55d6d..13b6848e49f 100644
--- a/tensorflow/core/kernels/adjust_hsv_gpu.cu.h
+++ b/tensorflow/core/kernels/adjust_hsv_gpu.cu.h
@@ -94,9 +94,9 @@ inline __device__ RgbTuple hsv2rgb_cuda(const float h, const float s,
 template <bool AdjustHue, bool AdjustSaturation, bool AdjustV, typename T>
 __global__ void adjust_hsv_nhwc(const int64 number_elements,
                                 const T* const __restrict__ input,
-                                T* const output, const float* const hue_delta,
-                                const float* const saturation_scale,
-                                const float* const value_scale) {
+                                T* const __restrict__ output, const float* const __restrict__ hue_delta,
+                                const float* const __restrict__ saturation_scale,
+                                const float* const __restrict__ value_scale) {
   // multiply by 3 since we're dealing with contiguous RGB bytes for each pixel
   // (NHWC)
   for (int64 idx = (blockDim.x * blockIdx.x + threadIdx.x) * 3;
diff --git a/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc b/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc
index 17a615fd9bb..ac142806029 100644
--- a/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc
@@ -41,13 +41,13 @@ DEFINE_GPU_KERNELS(double)
 
 template <typename dtype>
 __global__ void AvePoolBackwardNHWC(const int nthreads,
-                                    const dtype* const top_diff, const int num,
+                                    const dtype* const __restrict__ top_diff, const int num,
                                     const int height, const int width,
                                     const int channels, const int pooled_height,
                                     const int pooled_width, const int kernel_h,
                                     const int kernel_w, const int stride_h,
                                     const int stride_w, const int pad_t,
-                                    const int pad_l, dtype* const bottom_diff) {
+                                    const int pad_l, dtype* const __restrict__ bottom_diff) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
     // find out the local index
     // find out the local offset
diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc
index 843aedc3e0f..06c5ce7a0da 100644
--- a/tensorflow/core/kernels/bias_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc
@@ -51,8 +51,8 @@ struct AccumulatorType<Eigen::half> {
 // Definition of the GPU implementations declared in bias_op.cc.
 
 template <typename T>
-__global__ void BiasNHWCKernel(int32 nthreads, const T* input, const T* bias,
-                               T* output, int32 bias_size) {
+__global__ void BiasNHWCKernel(int32 nthreads, const T* __restrict__ input, const T* __restrict__ bias,
+                               T* __restrict__ output, int32 bias_size) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
     int32 bias_offset = index % bias_size;
     output[index] = ldg(input + index) + ldg(bias + bias_offset);
@@ -60,8 +60,8 @@ __global__ void BiasNHWCKernel(int32 nthreads, const T* input, const T* bias,
 }
 
 template <typename T>
-__global__ void BiasNCHWKernel(int32 nthreads, const T* input, const T* bias,
-                               T* output, int32 bias_size, int32 image_size) {
+__global__ void BiasNCHWKernel(int32 nthreads, const T* __restrict__ input, const T* __restrict__ bias,
+                               T* __restrict__ output, int32 bias_size, int32 image_size) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
     int32 index2 = index / image_size;
     int32 bias_offset = index2 % bias_size;
@@ -97,8 +97,8 @@ void BiasGPU<T>::compute(const GPUDevice& d, const T* input, const T* bias,
 
 // A naive implementation that is functional on all cases.
 template <typename T>
-__global__ void BiasGradNHWC_Naive(int32 nthreads, const T* output_backprop,
-                                   T* bias_backprop, int32 bias_size) {
+__global__ void BiasGradNHWC_Naive(int32 nthreads, const T* __restrict__ output_backprop,
+                                   T* __restrict__ bias_backprop, int32 bias_size) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
     int32 bias_offset = index % bias_size;
     GpuAtomicAdd(bias_backprop + bias_offset, ldg(output_backprop + index));
@@ -107,8 +107,8 @@ __global__ void BiasGradNHWC_Naive(int32 nthreads, const T* output_backprop,
 
 // A naive implementation that is functional on all cases.
 template <typename T>
-__global__ void BiasGradNCHW_Naive(int32 nthreads, const T* output_backprop,
-                                   T* bias_backprop, int32 bias_size,
+__global__ void BiasGradNCHW_Naive(int32 nthreads, const T* __restrict__ output_backprop,
+                                   T* __restrict__ bias_backprop, int32 bias_size,
                                    int32 image_size) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
     int32 index2 = index / image_size;
@@ -120,8 +120,8 @@ __global__ void BiasGradNCHW_Naive(int32 nthreads, const T* output_backprop,
 
 template <typename T>
 __global__ void BiasGradNHWC_SharedAtomics(int32 nthreads,
-                                           const T* output_backprop,
-                                           T* bias_backprop, int32 bias_size) {
+                                           const T* __restrict__ output_backprop,
+                                           T* __restrict__ bias_backprop, int32 bias_size) {
   typedef typename AccumulatorType<T>::type AccT;
   GPU_DYNAMIC_SHARED_MEM_DECL(8, char, s_buf);
   AccT* s_data = reinterpret_cast<AccT*>(s_buf);
@@ -143,8 +143,8 @@ __global__ void BiasGradNHWC_SharedAtomics(int32 nthreads,
 }
 
 template <typename T>
-__global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
-                                           T* bias_backprop, int32 batch,
+__global__ void BiasGradNCHW_SharedAtomics(const T* __restrict__ output_backprop,
+                                           T* __restrict__ bias_backprop, int32 batch,
                                            int32 bias_size, int32 image_size,
                                            int group_size) {
   // Initialize the shared memory.
diff --git a/tensorflow/core/kernels/bucketize_op_gpu.cu.cc b/tensorflow/core/kernels/bucketize_op_gpu.cu.cc
index 91ccf357436..998a2721a93 100644
--- a/tensorflow/core/kernels/bucketize_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bucketize_op_gpu.cu.cc
@@ -34,8 +34,8 @@ typedef Eigen::GpuDevice GPUDevice;
 
 template <typename T, bool useSharedMem>
 __global__ void BucketizeCustomKernel(
-    const int32 size_in, const T* in, const int32 size_boundaries,
-    GpuDeviceArrayStruct<float> boundaries_array, int32* out) {
+    const int32 size_in, const T* __restrict__ in, const int32 size_boundaries,
+    GpuDeviceArrayStruct<float> boundaries_array, int32* __restrict__ out) {
   const float* boundaries = GetGpuDeviceArrayOnDevice(&boundaries_array);
 
   GPU_DYNAMIC_SHARED_MEM_DECL(sizeof(float), unsigned char, shared_mem);
diff --git a/tensorflow/core/kernels/check_numerics_op_gpu.cu.cc b/tensorflow/core/kernels/check_numerics_op_gpu.cu.cc
index a698960114c..2060b647359 100644
--- a/tensorflow/core/kernels/check_numerics_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/check_numerics_op_gpu.cu.cc
@@ -36,7 +36,7 @@ typedef Eigen::GpuDevice GPUDevice;
 // A Cuda kernel to check if each element is Inf or Nan. If any exists, the
 // relevant elements in abnormal_detected will be set
 template <typename T>
-__global__ void CheckNumericsKernel(const T* data, int size,
+__global__ void CheckNumericsKernel(const T* __restrict__ data, int size,
                                     int abnormal_detected[2]) {
   const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
   const int32 total_thread_count = gridDim.x * blockDim.x;
diff --git a/tensorflow/core/kernels/compare_and_bitpack_op_gpu.cu.cc b/tensorflow/core/kernels/compare_and_bitpack_op_gpu.cu.cc
index 0877a853625..710df5a8119 100644
--- a/tensorflow/core/kernels/compare_and_bitpack_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/compare_and_bitpack_op_gpu.cu.cc
@@ -33,7 +33,7 @@ namespace functor {
 
 template <typename T>
 __global__ void CompareAndBitpackKernel(const int size, const T* threshold,
-                                        const T* input, uint8* output) {
+                                        const T* __restrict__ input, uint8* __restrict__ output) {
   // TODO(ebrevdo): Erich said: to get a better memory access pattern
   // you could have 8 threads load this data and do a comparison, then
   // use the ballot instruction to combine the values from each thread
@@ -55,9 +55,9 @@ __global__ void CompareAndBitpackKernel(const int size, const T* threshold,
 
 template <>
 __global__ void CompareAndBitpackKernel<bool>(const int size,
-                                              const bool* threshold,
-                                              const bool* input,
-                                              uint8* output) {
+                                              const bool* __restrict__ threshold,
+                                              const bool* __restrict__ input,
+                                              uint8* __restrict__ output) {
   // TODO(ebrevdo): Erich said: I think you could again have multiple
   // threads work on one block and use the ballot instruction to the
   // bit packing in one instruction.
@@ -77,9 +77,9 @@ __global__ void CompareAndBitpackKernel<bool>(const int size,
 
 template <>
 __global__ void CompareAndBitpackKernel<float>(const int size,
-                                               const float* threshold,
-                                               const float* input,
-                                               uint8* output) {
+                                               const float* __restrict__ threshold,
+                                               const float* __restrict__ input,
+                                               uint8* __restrict__ output) {
   const float thresh = ldg(threshold);
   GPU_1D_KERNEL_LOOP(i, size) {
     const float4 block0 = ldg(reinterpret_cast<const float4*>(input + 8 * i));
@@ -94,9 +94,9 @@ __global__ void CompareAndBitpackKernel<float>(const int size,
 
 template <>
 __global__ void CompareAndBitpackKernel<double>(const int size,
-                                                const double* threshold,
-                                                const double* input,
-                                                uint8* output) {
+                                                const double* __restrict__ threshold,
+                                                const double* __restrict__ input,
+                                                uint8* __restrict__ output) {
   const double thresh = ldg(threshold);
   GPU_1D_KERNEL_LOOP(i, size) {
     const double2 block0 = ldg(reinterpret_cast<const double2*>(input + 8 * i));
diff --git a/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc b/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc
index 3ddd071fe4a..9b2d8dc5e0e 100644
--- a/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc
+++ b/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc
@@ -35,7 +35,7 @@ namespace {
 
 template <typename T, typename IntType>
 __global__ void concat_fixed_kernel(
     GpuDeviceArrayStruct<const T*> input_ptr_data, int split_size,
-    int total_rows, int total_cols, T* output) {
+    int total_rows, int total_cols, T* __restrict__ output) {
   const T** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptr_data);
   IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -59,7 +59,7 @@ __global__ void concat_fixed_kernel(
 // cannot be in anonymous namespace due to extern shared memory
 template <typename T, typename IntType, bool useSmem>
 __global__ void concat_variable_kernel(
     GpuDeviceArrayStruct<const T*> input_ptr_data,
     GpuDeviceArrayStruct<IntType> output_scan, IntType total_rows,
-    IntType total_cols, T* __restrict__ output) {
+    IntType total_cols, T* __restrict__ output) {
   const T** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptr_data);
diff --git a/tensorflow/core/kernels/conv_2d_gpu.h b/tensorflow/core/kernels/conv_2d_gpu.h
index f47f2a1527e..6c238fe74c1 100644
--- a/tensorflow/core/kernels/conv_2d_gpu.h
+++ b/tensorflow/core/kernels/conv_2d_gpu.h
@@ -182,8 +182,8 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index<IndexCount> FlatToTensorIndex(
 // Requires that nthreads is equal to the total number of elements in the input
 // tensor.
 template <typename T, int sp0, int sp1, int sp2, bool conjugate = false>
-__global__ void ShuffleInTensor3Simple(int nthreads, const T* input,
-                                       Dimension<3> input_dims, T* output) {
+__global__ void ShuffleInTensor3Simple(int nthreads, const T* __restrict__ input,
+                                       Dimension<3> input_dims, T* __restrict__ output) {
   Dimension<3> output_dims;
   output_dims[sp0] = input_dims[0];
   output_dims[sp1] = input_dims[1];
@@ -366,8 +366,8 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(
 // A Gpu custom kernel that convert input to output, given proper padding on
 // the left and the top. The padded value is zero.
 template <typename T, int NDIMS>
-__global__ void PadInputCustomKernelNHWC(int nthreads, const T* input,
-                                         Dimension<NDIMS> input_dims, T* output,
+__global__ void PadInputCustomKernelNHWC(int nthreads, const T* __restrict__ input,
+                                         Dimension<NDIMS> input_dims, T* __restrict__ output,
                                          Dimension<NDIMS> output_dims,
                                          Dimension<NDIMS - 2> padding_left) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
@@ -395,8 +395,8 @@ __global__ void PadInputCustomKernelNHWC(int nthreads, const T* input,
 }
 
 template <typename T, int NDIMS>
-__global__ void PadInputCustomKernelNCHW(int nthreads, const T* input,
-                                         Dimension<NDIMS> input_dims, T* output,
+__global__ void PadInputCustomKernelNCHW(int nthreads, const T* __restrict__ input,
+                                         Dimension<NDIMS> input_dims, T* __restrict__ output,
                                          Dimension<NDIMS> output_dims,
                                          Dimension<NDIMS - 2> padding_left) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
diff --git a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
index 3202516be7e..7a213c356a4 100644
--- a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
@@ -38,10 +38,10 @@ enum InterpolationMethod {
 
 template <typename T>
 __global__ void CropAndResizeKernel(
-    const int32 nthreads, const T* image_ptr, const float* boxes_ptr,
-    const int32* box_ind_ptr, int num_boxes, int batch, int image_height,
+    const int32 nthreads, const T* __restrict__ image_ptr, const float* __restrict__ boxes_ptr,
+    const int32* __restrict__ box_ind_ptr, int num_boxes, int batch, int image_height,
     int image_width, int crop_height, int crop_width, int depth, int method_id,
-    float extrapolation_value, float* crops_ptr) {
+    float extrapolation_value, float* __restrict__ crops_ptr) {
   GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
     // out_idx = d + depth * (w + crop_width * (h + crop_height * b))
     int idx = out_idx;
@@ -130,10 +130,10 @@ __global__ void CropAndResizeKernel(
 
 template <typename T>
 __global__ void CropAndResizeBackpropImageKernel(
-    const int32 nthreads, const float* grads_ptr, const float* boxes_ptr,
-    const int32* box_ind_ptr, int num_boxes, int batch, int image_height,
+    const int32 nthreads, const float* __restrict__ grads_ptr, const float* __restrict__ boxes_ptr,
+    const int32* __restrict__ box_ind_ptr, int num_boxes, int batch, int image_height,
     int image_width, int crop_height, int crop_width, int depth,
-    T* grads_image_ptr, int method_id) {
+    T* __restrict__ grads_image_ptr, int method_id) {
   GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
     // out_idx = d + depth * (w + crop_width * (h + crop_height * b))
     int idx = out_idx;
@@ -225,10 +225,10 @@ __global__ void CropAndResizeBackpropImageKernel(
 
 template <typename T>
 __global__ void CropAndResizeBackpropBoxesKernel(
-    const int32 nthreads, const float* grads_ptr, const T* image_ptr,
-    const float* boxes_ptr, const int32* box_ind_ptr, int num_boxes, int batch,
+    const int32 nthreads, const float* __restrict__ grads_ptr, const T* __restrict__ image_ptr,
+    const float* __restrict__ boxes_ptr, const int32* __restrict__ box_ind_ptr, int num_boxes, int batch,
     int image_height, int image_width, int crop_height, int crop_width,
-    int depth, float* grads_boxes_ptr) {
+    int depth, float* __restrict__ grads_boxes_ptr) {
   GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
     // out_idx = d + depth * (w + crop_width * (h + crop_height * b))
     int idx = out_idx;
diff --git a/tensorflow/core/kernels/cwise_op_clip_gpu.cu.cc b/tensorflow/core/kernels/cwise_op_clip_gpu.cu.cc
index 55d3033483f..dba3a87dcc0 100644
--- a/tensorflow/core/kernels/cwise_op_clip_gpu.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_clip_gpu.cu.cc
@@ -24,8 +24,8 @@ limitations under the License.
 namespace tensorflow {
 
 template <typename T>
-__global__ void UnaryClipCustomKernel(const int32 size_in, const T *in0,
-                                      const T *in1, const T *in2, T *out) {
+__global__ void UnaryClipCustomKernel(const int32 size_in, const T* __restrict__ in0,
+                                      const T* __restrict__ in1, const T* __restrict__ in2, T* __restrict__ out) {
   GPU_1D_KERNEL_LOOP(i, size_in) {
     T value = in2[0] < in0[i] ? in2[0] : in0[i];
     out[i] = value < in1[0] ? in1[0] : value;
@@ -33,9 +33,9 @@ __global__ void UnaryClipCustomKernel(const int32 size_in, const T *in0,
 }
 
 template <typename T>
-__global__ void BinaryRightClipCustomKernel(const int32 size_in, const T *in0,
-                                            const T *in1, const T *in2,
-                                            T *out) {
+__global__ void BinaryRightClipCustomKernel(const int32 size_in, const T* __restrict__ in0,
+                                            const T* __restrict__ in1, const T* __restrict__ in2,
+                                            T* __restrict__ out) {
   GPU_1D_KERNEL_LOOP(i, size_in) {
     T value = in2[i] < in0[i] ? in2[i] : in0[i];
     out[i] = value < in1[0] ? in1[0] : value;
@@ -43,8 +43,8 @@ __global__ void BinaryRightClipCustomKernel(const int32 size_in, const T *in0,
 }
 
 template <typename T>
-__global__ void BinaryLeftClipCustomKernel(const int32 size_in, const T *in0,
-                                           const T *in1, const T *in2, T *out) {
+__global__ void BinaryLeftClipCustomKernel(const int32 size_in, const T* __restrict__ in0,
+                                           const T* __restrict__ in1, const T* __restrict__ in2, T* __restrict__ out) {
   GPU_1D_KERNEL_LOOP(i, size_in) {
     T value = in2[0] < in0[i] ? in2[0] : in0[i];
     out[i] = value < in1[i] ? in1[i] : value;
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.h b/tensorflow/core/kernels/depthwise_conv_op_gpu.h
index 73606a80273..6ec39eece5e 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.h
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.h
@@ -83,8 +83,8 @@ enum DepthwiseConv2dDirection { DIRECTION_FORWARD, DIRECTION_BACKWARD };
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
 __global__ void __launch_bounds__(1024, 2)
-    DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args, const T* input,
-                                 const T* filter, T* output, int num_outputs) {
+    DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args, const T* __restrict__ input,
+                                 const T* __restrict__ filter, T* __restrict__ output, int num_outputs) {
   typedef typename detail::PseudoHalfType<T>::Type S;
   const int in_height = args.in_rows;
   const int in_width = args.in_cols;
@@ -187,7 +187,7 @@ template <typename T, DepthwiseConv2dDirection kDirection,
           int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth,
           bool kKnownEvenHeight>
 __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
-    const DepthwiseArgs args, const T* input, const T* filter, T* output) {
+    const DepthwiseArgs args, const T* __restrict__ input, const T* __restrict__ filter, T* __restrict__ output) {
   typedef typename detail::PseudoHalfType<T>::Type S;
   assert(CanLaunchDepthwiseConv2dGPUSmall(args));
   // Holds block plus halo and filter data for blockDim.x depths.
@@ -327,8 +327,8 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
 __global__ void __launch_bounds__(1024, 2)
-    DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args, const T* input,
-                                 const T* filter, T* output, int num_outputs) {
+    DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args, const T* __restrict__ input,
+                                 const T* __restrict__ filter, T* __restrict__ output, int num_outputs) {
   typedef typename detail::PseudoHalfType<T>::Type S;
   const int in_height = args.in_rows;
   const int in_width = args.in_cols;
@@ -475,7 +475,7 @@ template <typename T, DepthwiseConv2dDirection kDirection,
           int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth,
           bool kKnownEvenHeight>
 __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
-    const DepthwiseArgs args, const T* input, const T* filter, T* output) {
+    const DepthwiseArgs args, const T* __restrict__ input, const T* __restrict__ filter, T* __restrict__ output) {
   typedef typename detail::PseudoHalfType<T>::Type S;
   assert(CanLaunchDepthwiseConv2dGPUSmall(args));
   // Holds block plus halo and filter data for blockDim.z depths.
@@ -799,8 +799,8 @@ template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
 __global__ void __launch_bounds__(640, 2)
     DepthwiseConv2dBackpropInputGPUKernelNHWC(const DepthwiseArgs args,
-                                              const T* out_backprop,
-                                              const T* filter, T* in_backprop,
+                                              const T* __restrict__ out_backprop,
+                                              const T* __restrict__ filter, T* __restrict__ in_backprop,
                                               int num_in_backprop) {
   const int in_height = args.in_rows;
   const int in_width = args.in_cols;
@@ -869,8 +869,8 @@ template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
 __global__ void __launch_bounds__(640, 2)
     DepthwiseConv2dBackpropInputGPUKernelNCHW(const DepthwiseArgs args,
-                                              const T* out_backprop,
-                                              const T* filter, T* in_backprop,
+                                              const T* __restrict__ out_backprop,
+                                              const T* __restrict__ filter, T* __restrict__ in_backprop,
                                               int num_in_backprop) {
   const int in_height = args.in_rows;
   const int in_width = args.in_cols;
@@ -1020,9 +1020,9 @@ template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
 __global__ void __launch_bounds__(640, 2)
     DepthwiseConv2dBackpropFilterGPUKernelNHWC(const DepthwiseArgs args,
-                                               const T* out_backprop,
-                                               const T* input,
-                                               T* filter_backprop,
+                                               const T* __restrict__ out_backprop,
+                                               const T* __restrict__ input,
+                                               T* __restrict__ filter_backprop,
                                                int num_out_backprop) {
   const int in_height = args.in_rows;
   const int in_width = args.in_cols;
@@ -1153,7 +1153,7 @@ template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kBlockDepth, int kAccumPixels>
 __global__
 __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
-    const DepthwiseArgs args, const T* output, const T* input, T* filter) {
+    const DepthwiseArgs args, const T* __restrict__ output, const T* __restrict__ input, T* __restrict__ filter) {
   typedef typename detail::PseudoHalfType<T>::Type S;
   assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.z));
   // Holds block plus halo and filter data for blockDim.x depths.
@@ -1305,9 +1305,9 @@ template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
 __global__ void __launch_bounds__(640, 2)
     DepthwiseConv2dBackpropFilterGPUKernelNCHW(const DepthwiseArgs args,
-                                               const T* out_backprop,
-                                               const T* input,
-                                               T* filter_backprop,
+                                               const T* __restrict__ out_backprop,
+                                               const T* __restrict__ input,
+                                               T* __restrict__ filter_backprop,
                                                int num_out_backprop) {
   const int in_height = args.in_rows;
   const int in_width = args.in_cols;
@@ -1426,7 +1426,7 @@ template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kBlockDepth, int kAccumPixels>
 __global__
 __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
-    const DepthwiseArgs args, const T* output, const T* input, T* filter) {
+    const DepthwiseArgs args, const T* __restrict__ output, const T* __restrict__ input, T* __restrict__ filter) {
   typedef typename detail::PseudoHalfType<T>::Type S;
   assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.x));
   // Holds block plus halo and filter data for blockDim.z depths.
diff --git a/tensorflow/core/kernels/determinant_op_gpu.cu.cc b/tensorflow/core/kernels/determinant_op_gpu.cu.cc
index 23081c9650f..2e07c020085 100644
--- a/tensorflow/core/kernels/determinant_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/determinant_op_gpu.cu.cc
@@ -84,10 +84,10 @@ __device__ inline complex128 operator/(const complex128& a, const double& b) {
 // the sign argument is ignored.
 template <typename Scalar, bool compute_log_abs_det = true>
 __global__ void DeterminantFromPivotedLUKernel(int nthreads, int n,
-                                               const Scalar* lu_factor,
-                                               const int* all_pivots,
-                                               Scalar* sign,
-                                               Scalar* log_abs_det) {
+                                               const Scalar* __restrict__ lu_factor,
+                                               const int* __restrict__ all_pivots,
+                                               Scalar* __restrict__ sign,
+                                               Scalar* __restrict__ log_abs_det) {
   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
   const int matrix_size = n * n;
   const int stride = n + 1;
diff --git a/tensorflow/core/kernels/diag_op_gpu.cu.cc b/tensorflow/core/kernels/diag_op_gpu.cu.cc
index d716d9cd45c..2e505a84c99 100644
--- a/tensorflow/core/kernels/diag_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/diag_op_gpu.cu.cc
@@ -30,7 +30,7 @@ typedef Eigen::GpuDevice GPUDevice;
 
 template <typename T>
 __global__ void DiagGpuKernel(const int num_threads, const int64 size,
-                              const T* in, T* out) {
+                              const T* __restrict__ in, T* __restrict__ out) {
   GPU_1D_KERNEL_LOOP(index, num_threads) {
     // Fill the diagonal elements or set to zero in other place.
     if (index % (1 + size) == 0) {
diff --git a/tensorflow/core/kernels/dilation_ops_gpu.cu.cc b/tensorflow/core/kernels/dilation_ops_gpu.cu.cc
index 5cb88243abe..6cd5724f0a1 100644
--- a/tensorflow/core/kernels/dilation_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/dilation_ops_gpu.cu.cc
@@ -35,13 +35,13 @@ typedef Eigen::GpuDevice GPUDevice;
 namespace {
 
 template <typename T>
-__global__ void DilationKernel(const int32 nthreads, const T* input_ptr,
-                               const T* filter_ptr, int batch, int input_rows,
+__global__ void DilationKernel(const int32 nthreads, const T* __restrict__ input_ptr,
+                               const T* __restrict__ filter_ptr, int batch, int input_rows,
                                int input_cols, int depth, int filter_rows,
                                int filter_cols, int output_rows,
                                int output_cols, int stride_rows,
                                int stride_cols, int rate_rows, int rate_cols,
-                               int pad_top, int pad_left, T* output_ptr) {
+                               int pad_top, int pad_left, T* __restrict__ output_ptr) {
   GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
     // out_idx = d + depth * (w_out + output_cols * (h_out + output_rows * b))
     const int d = out_idx % depth;
@@ -76,11 +76,11 @@ __global__ void DilationKernel(const int32 nthreads, const T* input_ptr,
 
 template <typename T>
 __global__ void DilationBackpropInputKernel(
-    const int32 nthreads, const T* input_ptr, const T* filter_ptr,
-    const T* out_backprop_ptr, int batch, int input_rows, int input_cols,
+    const int32 nthreads, const T* __restrict__ input_ptr, const T* __restrict__ filter_ptr,
+    const T* __restrict__ out_backprop_ptr, int batch, int input_rows, int input_cols,
     int depth, int filter_rows, int filter_cols, int output_rows,
     int output_cols, int stride_rows, int stride_cols, int rate_rows,
-    int rate_cols, int pad_top, int pad_left, T* in_backprop_ptr) {
+    int rate_cols, int pad_top, int pad_left, T* __restrict__ in_backprop_ptr) {
   GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
     // out_idx = d + depth * (w_out + output_cols * (h_out + output_rows * b))
     const int d = out_idx % depth;
@@ -125,11 +125,11 @@ __global__ void DilationBackpropInputKernel(
 
 template <typename T>
 __global__ void DilationBackpropFilterKernel(
-    const int32 nthreads, const T* input_ptr, const T* filter_ptr,
-    const T* out_backprop_ptr, int batch, int input_rows, int input_cols,
+    const int32 nthreads, const T* __restrict__ input_ptr, const T* __restrict__ filter_ptr,
+    const T* __restrict__ out_backprop_ptr, int batch, int input_rows, int input_cols,
     int depth, int filter_rows, int filter_cols, int output_rows,
     int output_cols, int stride_rows, int stride_cols, int rate_rows,
-    int rate_cols, int pad_top, int pad_left, T* filter_backprop_ptr) {
+    int rate_cols, int pad_top, int pad_left, T* __restrict__ filter_backprop_ptr) {
   GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
     // out_idx = d + depth * (w_out + output_cols * (h_out + output_rows * b))
     const int d = out_idx % depth;
diff --git a/tensorflow/core/kernels/eye_functor_gpu.cu.cc b/tensorflow/core/kernels/eye_functor_gpu.cu.cc
index 5756299d02d..90df538dd2c 100644
--- a/tensorflow/core/kernels/eye_functor_gpu.cu.cc
+++ b/tensorflow/core/kernels/eye_functor_gpu.cu.cc
@@ -30,7 +30,7 @@ typedef Eigen::GpuDevice GPUDevice;
 
 template <typename Scalar>
 __global__ void EyeKernel(int num_threads, int batch_size, int m, int n,
-                          Scalar* output_ptr) {
+                          Scalar* __restrict__ output_ptr) {
   const Scalar one = Scalar(1);
   const Scalar zero = Scalar(0);
   GPU_1D_KERNEL_LOOP(index, num_threads) {
diff --git a/tensorflow/core/kernels/gather_functor_gpu.cu.h b/tensorflow/core/kernels/gather_functor_gpu.cu.h
index 6507846e3f0..7549a923d1e 100644
--- a/tensorflow/core/kernels/gather_functor_gpu.cu.h
+++ b/tensorflow/core/kernels/gather_functor_gpu.cu.h
@@ -30,7 +30,7 @@ namespace tensorflow {
 typedef Eigen::GpuDevice GPUDevice;
 
 template <typename T, typename Index, bool is_axis_zero>
-__global__ void GatherOpKernel(const T* params, const Index* indices, T* out,
+__global__ void GatherOpKernel(const T* __restrict__ params, const Index* __restrict__ indices, T* __restrict__ out,
                                int64 gather_dim_size, int64 indices_size,
                                int64 slice_size, int64 out_size) {
   GPU_1D_KERNEL_LOOP(i, out_size) {
diff --git a/tensorflow/core/kernels/gather_nd_op_gpu.cu.cc b/tensorflow/core/kernels/gather_nd_op_gpu.cu.cc
index f746105e00f..9cc8548128d 100644
--- a/tensorflow/core/kernels/gather_nd_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/gather_nd_op_gpu.cu.cc
@@ -28,7 +28,7 @@ typedef Eigen::GpuDevice GPUDevice;
 
 template <typename T, typename Index, int IXDIM>
 __global__ void GatherSliceOpKernel(
-    const T* params, const Index* indices, T* out,
+    const T* __restrict__ params, const Index* __restrict__ indices, T* __restrict__ out,
     const Eigen::array<int64, IXDIM> batch_strides,
     const Eigen::array<int64, IXDIM> batch_indices, const int64 indices_size,
     const int64 slice_size, const int64 out_size) {
diff --git a/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc b/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc
index dbcb0ad8a70..2ea41981d30 100644
--- a/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc
+++ b/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc
@@ -29,7 +29,7 @@ typedef Eigen::GpuDevice Device;
 template <typename T>
 __global__ void DoParallelConcatOpKernel(int nthreads, const int64 rows,
                                          const int64 cols, int32 loc,
-                                         const T* src, T* dst) {
+                                         const T* __restrict__ src, T* __restrict__ dst) {
   GPU_1D_KERNEL_LOOP(idx, nthreads) {
     int64 c = idx % cols;
     int64 r = (loc % rows + rows) % rows;  // Guard index range.
@@ -80,8 +80,8 @@ Status DoParallelConcat(const Device& d, const Tensor& value, int32 loc,
 
 template <typename T, InplaceOpType op>
 __global__ void DoInplaceOpKernel(int nthreads, const int64 rows,
-                                  const int64 cols, const int64 n, const T* src,
-                                  const int32* rowids, T* dst) {
+                                  const int64 cols, const int64 n, const T* __restrict__ src,
+                                  const int32* __restrict__ rowids, T* __restrict__ dst) {
   GPU_1D_KERNEL_LOOP(idx, nthreads) {
     int64 r = idx / cols;
     int64 c = idx % cols;
diff --git a/tensorflow/core/kernels/lu_op_gpu.cu.cc b/tensorflow/core/kernels/lu_op_gpu.cu.cc
index 8ae38013341..a2b8b9ab894 100644
--- a/tensorflow/core/kernels/lu_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/lu_op_gpu.cu.cc
@@ -61,8 +61,8 @@ __device__ void ComputePermutationFromTranspositions(
 // transpositions.
 template <typename Scalar>
 __global__ void ComputePermutationFromTranspositionsKernel(
-    GpuLaunchConfig config, const int64 num_rows, const int* all_pivots,
-    Scalar* all_permutation_indices) {
+    GpuLaunchConfig config, const int64 num_rows, const int* __restrict__ all_pivots,
+    Scalar* __restrict__ all_permutation_indices) {
   // We only parallelize over batches here. Performance is not critical,
   // since this cheap O(num_rows) kernel always follows an O(num_rows^3)
   // LU factorization.
diff --git a/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc b/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc
index 5c90963ea72..4a94c51e878 100644
--- a/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc
@@ -33,8 +33,8 @@ __global__ void MatrixBandPartKernel(const int num_threads,
                                      const int batch_size, const int m,
                                      const int n, const int num_lower_diags,
                                      const int num_upper_diags,
-                                     const Scalar* input_ptr,
-                                     Scalar* output_ptr) {
+                                     const Scalar* __restrict__ input_ptr,
+                                     Scalar* __restrict__ output_ptr) {
   GPU_1D_KERNEL_LOOP(index, num_threads) {
     const int col = index % n;
     const int row = (index / n) % m;
diff --git a/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc b/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc
index d26978fe49c..f6aee377e06 100644
--- a/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc
@@ -31,8 +31,8 @@ __global__ void MatrixSetDiagKernel(const int num_threads, const int m,
                                     const int n, const int num_diags,
                                     const int max_diag_len,
                                     const int upper_diag_index,
-                                    const Scalar* diag_ptr,
-                                    Scalar* output_ptr) {
+                                    const Scalar* __restrict__ diag_ptr,
+                                    Scalar* __restrict__ output_ptr) {
   GPU_1D_KERNEL_LOOP(index, num_threads) {
     const int batch_and_diag_index = index / max_diag_len;
     const int index_in_the_diagonal =
@@ -56,8 +56,8 @@ template <typename Scalar>
 __global__ void MatrixCopyInputAndSetDiagKernel(
     const int num_threads, const int m, const int n, const int num_diags,
     const int max_diag_len, const int lower_diag_index,
-    const int upper_diag_index, const Scalar* input_ptr, const Scalar* diag_ptr,
-    Scalar* output_ptr) {
+    const int upper_diag_index, const Scalar* __restrict__ input_ptr, const Scalar* __restrict__ diag_ptr,
+    Scalar* __restrict__ output_ptr) {
   GPU_1D_KERNEL_LOOP(index, num_threads) {
     const int batch_and_row_index = index / n;
     const int col = index - batch_and_row_index * n;
diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
index f3b2e516be2..ecedc62704d 100644
--- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
@@ -65,11 +65,11 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool IsGreaterThan(dtype a, dtype b) {
 //                      kThreadsPerBlock, 0, cuda_stream>>>(...);
 template <bool propagate_nans, typename dtype>
 __global__ void MaxPoolForwardNCHW(
-    const int nthreads, const dtype* bottom_data, const int channels,
+    const int nthreads, const dtype* __restrict__ bottom_data, const int channels,
     const int height, const int width, const int pooled_height,
     const int pooled_width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w, const int pad_t, const int pad_l,
-    dtype* top_data, int64* mask, const bool include_batch_in_index) {
+    dtype* __restrict__ top_data, int64* __restrict__ mask, const bool include_batch_in_index) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
     int pw = index % pooled_width;
     int ph = (index / pooled_width) % pooled_height;
@@ -108,11 +108,11 @@ __global__ void MaxPoolForwardNCHW(
 // the same X, y coordinate.
 // (so channels = outer_channels, output_size = real output size / 4).
 __global__ void MaxPoolForwardNoMaskKernel_NCHW_VECT_C(
-    const int nthreads, const int32* bottom_data, const int height,
+    const int nthreads, const int32* __restrict__ bottom_data, const int height,
     const int width, const int channels, const int pooled_height,
     const int pooled_width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w, const int pad_t, const int pad_l,
-    int32* top_data) {
+    int32* __restrict__ top_data) {
   // TODO(pauldonnelly): Implement a better optimized version of this kernel.
   const int32 kMinINT8X4 = 0x80808080;
   GPU_1D_KERNEL_LOOP(index, nthreads) {
@@ -141,11 +141,11 @@ __global__ void MaxPoolForwardNoMaskKernel_NCHW_VECT_C(
 
 template <bool propagate_nans, typename dtype>
 __global__ void MaxPoolForwardNHWC(
-    const int nthreads, const dtype* bottom_data, const int height,
+    const int nthreads, const dtype* __restrict__ bottom_data, const int height,
     const int width, const int channels, const int pooled_height,
     const int pooled_width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w, const int pad_t, const int pad_l,
-    dtype* top_data, int64* mask, const bool include_batch_in_index) {
+    dtype* __restrict__ top_data, int64* __restrict__ mask, const bool include_batch_in_index) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
     int n = index;
     int c = n % channels;
@@ -180,11 +180,11 @@ __global__ void MaxPoolForwardNHWC(
 
 template <typename dtype>
 __global__ void MaxPoolBackwardNoMaskNHWC(
-    const int nthreads, const dtype* bottom_data, const int height,
+    const int nthreads, const dtype* __restrict__ bottom_data, const int height,
     const int width, const int channels, const int pooled_height,
     const int pooled_width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w, const int pad_t, const int pad_l,
-    const dtype* top_diff, dtype* bottom_diff) {
+    const dtype* __restrict__ top_diff, dtype* __restrict__ bottom_diff) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
     // First find out the index to the maximum, since we have no mask.
     int n = index;
@@ -240,9 +240,9 @@ __global__ void MaxPoolBackwardNoMaskNHWC(
 // the kernel is run, you will need to make sure that bottom_diff is filled with
 // zero first.
 template <typename dtype>
-__global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff,
-                                const int64* mask, const int top_offset,
-                                const int bottom_offset, dtype* bottom_diff,
+__global__ void MaxPoolBackward(const int nthreads, const dtype* __restrict__ top_diff,
+                                const int64* __restrict__ mask, const int top_offset,
+                                const int bottom_offset, dtype* __restrict__ bottom_diff,
                                 const bool include_batch_in_index) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
     const int offset =
@@ -267,11 +267,11 @@ __global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff,
 //     bottom_diff: the gradient of the gradient w.r.t. output.
 template <typename dtype>
 __global__ void MaxPoolGradBackwardNoMaskNCHW(
-    const int nthreads, const dtype* bottom_data, const dtype* output_data,
+    const int nthreads, const dtype* __restrict__ bottom_data, const dtype* __restrict__ output_data,
     const int pooled_height, const int pooled_width, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w, const int pad_t, const int pad_l,
-    const dtype* top_diff, dtype* bottom_diff) {
+    const dtype* __restrict__ top_diff, dtype* __restrict__ bottom_diff) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
     // First find out the index to the maximum, since we have no mask.
     int pw = index % pooled_width;
@@ -307,11 +307,11 @@ __global__ void MaxPoolGradBackwardNoMaskNCHW(
 
 template <typename dtype>
 __global__ void MaxPoolGradBackwardNoMaskNHWC(
-    const int nthreads, const dtype* bottom_data, const dtype* output_data,
+    const int nthreads, const dtype* __restrict__ bottom_data, const dtype* __restrict__ output_data,
     const int pooled_height, const int pooled_width, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w, const int pad_t, const int pad_l,
-    const dtype* top_diff, dtype* bottom_diff) {
+    const dtype* __restrict__ top_diff, dtype* __restrict__ bottom_diff) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
     // First find out the index to the maximum, since we have no mask.
     int n = index;
@@ -367,9 +367,9 @@ __global__ void MaxPoolGradBackwardNoMaskNHWC(
 //     include_batch_in_index: whether to include batch dimension in flattened
 //         index of `argmax`.
 template <typename dtype>
-__global__ void MaxPoolGradBackward(const int nthreads, const dtype* top_diff,
-                                    const int64* mask, const int top_offset,
-                                    const int bottom_offset, dtype* bottom_diff,
+__global__ void MaxPoolGradBackward(const int nthreads, const dtype* __restrict__ top_diff,
+                                    const int64* __restrict__ mask, const int top_offset,
+                                    const int bottom_offset, dtype* __restrict__ bottom_diff,
                                     const bool include_batch_in_index) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
     const int offset =
diff --git a/tensorflow/core/kernels/multinomial_op_gpu.cu.cc b/tensorflow/core/kernels/multinomial_op_gpu.cu.cc
index 97603833182..7e11e078b75 100644
--- a/tensorflow/core/kernels/multinomial_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/multinomial_op_gpu.cu.cc
@@ -45,8 +45,8 @@ using GPUDevice = Eigen::GpuDevice;
 //   scores: [B, S, C];  maxima: [B, S];  output: [B, S].
 template <typename OutputType>
 __global__ void MultinomialKernel(int32 nthreads, const int32 num_classes,
-                                  const int32 num_samples, const float* scores,
-                                  const float* maxima, OutputType* output) {
+                                  const int32 num_samples, const float* __restrict__ scores,
+                                  const float* __restrict__ maxima, OutputType* __restrict__ output) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
     const int maxima_idx = index / num_classes;
     if (ldg(maxima + maxima_idx) == ldg(scores + index)) {
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cu.cc b/tensorflow/core/kernels/non_max_suppression_op.cu.cc
index 9b8526a75c3..a7a169c5f5c 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cu.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cu.cc
@@ -130,9 +130,9 @@ __device__ EIGEN_STRONG_INLINE void Flipped<true>(Box& box) {
 // x1<x2 and y1<y2.
 template <bool flip_box>
 __launch_bounds__(kNmsBlockDim* kNmsBlockDim, 4) __global__
-    void NMSKernel(const Box* d_desc_sorted_boxes, const int num_boxes,
+    void NMSKernel(const Box* __restrict__ d_desc_sorted_boxes, const int num_boxes,
                    const float iou_threshold, const int bit_mask_len,
-                   int* d_delete_mask) {
+                   int* __restrict__ d_delete_mask) {
   // Storing boxes used by this CUDA block in the shared memory.
   __shared__ Box shared_i_boxes[kNmsBlockDim];
   // Same thing with areas
@@ -208,15 +208,15 @@ __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected,
 // IndexMultiSelect(num_elements, indices, original1 ,selected1, original2,
 // selected2).
 template <typename Index, typename T, typename... Args>
-__global__ void IndexMultiSelect(const int num_elements, const Index* indices,
-                                 const T* original, T* selected, Args... args) {
+__global__ void IndexMultiSelect(const int num_elements, const Index* __restrict__ indices,
+                                 const T* __restrict__ original, T* __restrict__ selected, Args... args) {
   for (const int idx : CudaGridRangeX(num_elements)) {
     SelectHelper(idx, indices[idx], original, selected, args...);
   }
 }
 
 template <typename T>
-__global__ void Iota(const int num_elements, const T offset, T* to_fill) {
+__global__ void Iota(const int num_elements, const T offset, T* __restrict__ to_fill) {
   for (int idx : CudaGridRangeX(num_elements)) {
     to_fill[idx] = static_cast<T>(idx) + offset;
   }
diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc
index 5c26126371f..5a05c7de147 100644
--- a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc
@@ -53,9 +53,9 @@ template <typename T>
 __global__ void __launch_bounds__(1024)
     TruncatedNormalKernel(random::PhiloxRandom gen, T* data, int64 num_batches,
                           int64 samples_per_batch, int64 num_elements,
-                          const T* means, bool single_mean, const T* stddevs,
-                          bool single_stddev, const T* minvals,
-                          bool single_minval, const T* maxvals,
+                          const T* __restrict__ means, bool single_mean, const T* __restrict__ stddevs,
+                          bool single_stddev, const T* __restrict__ minvals,
+                          bool single_minval, const T* __restrict__ maxvals,
                           bool single_maxval, int64 kMaxIterations) {
   const int32 max_samples_per_item = 2 * kMaxIterations;
   // Initial offset as given by GPU_1D_KERNEL_LOOP.
diff --git a/tensorflow/core/kernels/pooling_ops_3d_gpu.cu.cc b/tensorflow/core/kernels/pooling_ops_3d_gpu.cu.cc
index a38e7de1d2d..36653cdb7e1 100644
--- a/tensorflow/core/kernels/pooling_ops_3d_gpu.cu.cc
+++ b/tensorflow/core/kernels/pooling_ops_3d_gpu.cu.cc
@@ -28,13 +28,13 @@ namespace {
 
 template <typename dtype>
 __global__ void MaxPoolGradBackwardNoMaskNCDHW(
-    const int nthreads, const dtype* bottom_data, const dtype* output_data,
+    const int nthreads, const dtype* __restrict__ bottom_data, const dtype* __restrict__ output_data,
     const int pooled_plane, const int pooled_height, const int pooled_width,
     const int channels, const int plane, const int height, const int width,
     const int kernel_p, const int kernel_h, const int kernel_w,
     const int stride_p, const int stride_h, const int stride_w, const int pad_p,
-    const int pad_t, const int pad_l, const dtype* top_diff,
-    dtype* bottom_diff) {
+    const int pad_t, const int pad_l, const dtype* __restrict__ top_diff,
+    dtype* __restrict__ bottom_diff) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
     // First find out the index to the maximum, since we have no mask.
     int pw = index % pooled_width;
@@ -78,13 +78,13 @@ __global__ void MaxPoolGradBackwardNoMaskNCDHW(
 
 template <typename dtype>
 __global__ void MaxPoolGradBackwardNoMaskNDHWC(
-    const int nthreads, const dtype* bottom_data, const dtype* output_data,
+    const int nthreads, const dtype* __restrict__ bottom_data, const dtype* __restrict__ output_data,
     const int pooled_plane, const int pooled_height, const int pooled_width,
     const int channels, const int plane, const int height, const int width,
     const int kernel_p, const int kernel_h, const int kernel_w,
     const int stride_p, const int stride_h, const int stride_w, const int pad_p,
-    const int pad_t, const int pad_l, const dtype* top_diff,
-    dtype* bottom_diff) {
+    const int pad_t, const int pad_l, const dtype* __restrict__ top_diff,
+    dtype* __restrict__ bottom_diff) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
     // First find out the index to the maximum, since we have no mask.
     int n = index;
diff --git a/tensorflow/core/kernels/population_count_op_gpu.cu.cc b/tensorflow/core/kernels/population_count_op_gpu.cu.cc
index c04975fa0e7..2eed370d991 100644
--- a/tensorflow/core/kernels/population_count_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/population_count_op_gpu.cu.cc
@@ -33,14 +33,14 @@ typedef Eigen::GpuDevice GPUDevice;
 namespace functor {
 
 template <typename T>
-__global__ void PopulationCountKernel(const int size, const T* input,
-                                      uint8* output) {
+__global__ void PopulationCountKernel(const int size, const T* __restrict__ input,
+                                      uint8* __restrict__ output) {
   GPU_1D_KERNEL_LOOP(i, size) { output[i] = __popc(ldg(input + i)); }
 }
 
 template <>
-__global__ void PopulationCountKernel(const int size, const int8* input,
-                                      uint8* output) {
+__global__ void PopulationCountKernel(const int size, const int8* __restrict__ input,
+                                      uint8* __restrict__ output) {
   // For some reason, __popc on a negative int8 gets confused.
   GPU_1D_KERNEL_LOOP(i, size) {
     output[i] = __popc(ldg(reinterpret_cast<const uint8*>(input + i)));
@@ -48,8 +48,8 @@ __global__ void PopulationCountKernel(const int size, const int8* input,
 }
 
 template <>
-__global__ void PopulationCountKernel(const int size, const int16* input,
-                                      uint8* output) {
+__global__ void PopulationCountKernel(const int size, const int16* __restrict__ input,
+                                      uint8* __restrict__ output) {
   // For some reason, __popc on a negative int16 gets confused.
   GPU_1D_KERNEL_LOOP(i, size) {
     output[i] = __popc(ldg(reinterpret_cast<const uint16*>(input + i)));
@@ -57,8 +57,8 @@ __global__ void PopulationCountKernel(const int size, const int16* input,
 }
 
 template <>
-__global__ void PopulationCountKernel<int64>(const int size, const int64* input,
-                                             uint8* output) {
+__global__ void PopulationCountKernel<int64>(const int size, const int64* __restrict__ input,
+                                             uint8* __restrict__ output) {
   GPU_1D_KERNEL_LOOP(i, size) { output[i] = __popcll(ldg(input + i)); }
 }
 
diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc
index 385565b28d4..7cacafacb31 100644
--- a/tensorflow/core/kernels/relu_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc
@@ -34,9 +34,9 @@ namespace functor {
 // This kernel computes ReluGrad by processing one half2, two fp16, at a time.
 // It effectively does: backdrops = (feature > 0) ? gradient : 0
 // It also tries to use native half2 primitives as much as possible.
-__global__ void ReluGradHalfKernel(const Eigen::half* gradient,
-                                   const Eigen::half* feature,
-                                   Eigen::half* backprop, int32 count) {
+__global__ void ReluGradHalfKernel(const Eigen::half* __restrict__ gradient,
+                                   const Eigen::half* __restrict__ feature,
+                                   Eigen::half* __restrict__ backprop, int32 count) {
   int32 half2_count = count >> 1;
   int32 index = blockIdx.x * blockDim.x + threadIdx.x;
   const int32 total_device_threads = gridDim.x * blockDim.x;
@@ -112,8 +112,8 @@ struct ReluGrad<Device, Eigen::half> {
   }
 };
 
-__global__ void Relu_int8x4_kernel(int vect_count, const int32* input,
-                                   int32* output) {
+__global__ void Relu_int8x4_kernel(int vect_count, const int32* __restrict__ input,
+                                   int32* __restrict__ output) {
   CUDA_1D_KERNEL_LOOP(index, vect_count) {
     output[index] = __vmaxs4(input[index], 0);
   }
diff --git a/tensorflow/core/kernels/resize_bilinear_op_gpu.cu.cc b/tensorflow/core/kernels/resize_bilinear_op_gpu.cu.cc
index d2afe4a89ab..0baa7a05b47 100644
--- a/tensorflow/core/kernels/resize_bilinear_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/resize_bilinear_op_gpu.cu.cc
@@ -113,11 +113,11 @@ __global__ void ResizeBilinearKernel_faster(
 }
 
 template <typename T>
-__global__ void ResizeBilinearKernel(const int32 nthreads, const T* images,
+__global__ void ResizeBilinearKernel(const int32 nthreads, const T* __restrict__ images,
                                      float height_scale, float width_scale,
                                      int batch, int in_height, int in_width,
                                      int channels, int out_height,
-                                     int out_width, float* output) {
+                                     int out_width, float* __restrict__ output) {
   GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
     // out_idx = c + channels * (x + out_width * (y + out_height * b))
     int idx = out_idx;
@@ -166,9 +166,9 @@ __global__ void ResizeBilinearKernel(const int32 nthreads, const T* images,
 
 template <typename T>
 __global__ void ResizeBilinearGradKernel(
-    const int32 nthreads, const float* input_grad, float height_scale,
+    const int32 nthreads, const float* __restrict__ input_grad, float height_scale,
     float width_scale, int batch, int original_height, int original_width,
-    int channels, int resized_height, int resized_width, T* output_grad) {
+    int channels, int resized_height, int resized_width, T* __restrict__ output_grad) {
   GPU_1D_KERNEL_LOOP(in_idx, nthreads) {
     // in_idx = c + channels * (x + resized_width * (y + resized_height * b))
     int idx = in_idx;
@@ -228,11 +228,11 @@ __global__ void ResizeBilinearGradKernel(
 
 template <typename T>
 __global__ void LegacyResizeBilinearKernel(const int32 nthreads,
-                                           const T* images, float height_scale,
+                                           const T* __restrict__ images, float height_scale,
                                            float width_scale, int batch,
                                            int in_height, int in_width,
                                            int channels, int out_height,
-                                           int out_width, float* output) {
+                                           int out_width, float* __restrict__ output) {
   GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
     // out_idx = c + channels * (x + out_width * (y + out_height * b))
     int idx = out_idx;
@@ -280,9 +280,9 @@ __global__ void LegacyResizeBilinearKernel(const int32 nthreads,
 
 template <typename T>
 __global__ void LegacyResizeBilinearGradKernel(
-    const int32 nthreads, const float* input_grad, float height_scale,
+    const int32 nthreads, const float* __restrict__ input_grad, float height_scale,
     float width_scale, int batch, int original_height, int original_width,
-    int channels, int resized_height, int resized_width, T* output_grad) {
+    int channels, int resized_height, int resized_width, T* __restrict__ output_grad) {
   GPU_1D_KERNEL_LOOP(in_idx, nthreads) {
     // in_idx = c + channels * (x + resized_width * (y + resized_height * b))
     int idx = in_idx;
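
Aside (illustrative sketch, not part of the patch): the resize kernels above all invert the flattening described in the comment `out_idx = c + channels * (x + out_width * (y + out_height * b))` with a chain of modulo/divide steps. Spelled out as a host-side helper (hypothetical names):

    // Recover (b, y, x, c) from a flat NHWC index
    //   out_idx = c + channels * (x + out_width * (y + out_height * b)).
    struct Nhwc { int b, y, x, c; };

    inline Nhwc DecomposeNhwc(int out_idx, int out_height, int out_width,
                              int channels) {
      Nhwc r;
      r.c = out_idx % channels;   out_idx /= channels;
      r.x = out_idx % out_width;  out_idx /= out_width;
      r.y = out_idx % out_height;
      r.b = out_idx / out_height;
      return r;
    }
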
diff --git a/tensorflow/core/kernels/resize_nearest_neighbor_op_gpu.cu.cc b/tensorflow/core/kernels/resize_nearest_neighbor_op_gpu.cu.cc
index 73a27c5817a..b6a9c77ba13 100644
--- a/tensorflow/core/kernels/resize_nearest_neighbor_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/resize_nearest_neighbor_op_gpu.cu.cc
@@ -33,7 +33,7 @@ namespace {
 
 template <typename T>
 __global__ void ResizeNearestNeighborNHWC(
-    const int nthreads, const T* bottom_data, const int in_height,
+    const int nthreads, const T* __restrict__ bottom_data, const int in_height,
     const int in_width, const int channels, const int out_height,
     const int out_width, const float height_scale, const float width_scale,
-    T* top_data) {
+    T* __restrict__ top_data) {
@@ -64,10 +64,10 @@ __global__ void ResizeNearestNeighborNHWC(
 
 template <typename T, bool align_corners>
 __global__ void LegacyResizeNearestNeighborNHWC(
-    const int nthreads, const T* bottom_data, const int in_height,
+    const int nthreads, const T* __restrict__ bottom_data, const int in_height,
     const int in_width, const int channels, const int out_height,
     const int out_width, const float height_scale, const float width_scale,
-    T* top_data) {
+    T* __restrict__ top_data) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
     int n = index;
     int c = n % channels;
@@ -93,10 +93,10 @@ __global__ void LegacyResizeNearestNeighborNHWC(
 
 template <typename T>
 __global__ void ResizeNearestNeighborBackwardNHWC(
-    const int nthreads, const T* top_diff, const int in_height,
+    const int nthreads, const T* __restrict__ top_diff, const int in_height,
     const int in_width, const int channels, const int out_height,
     const int out_width, const float height_scale, const float width_scale,
-    T* bottom_diff) {
+    T* __restrict__ bottom_diff) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
     int n = index;
     int c = n % channels;
@@ -124,10 +124,10 @@ __global__ void ResizeNearestNeighborBackwardNHWC(
 
 template <typename T, bool align_corners>
 __global__ void LegacyResizeNearestNeighborBackwardNHWC(
-    const int nthreads, const T* top_diff, const int in_height,
+    const int nthreads, const T* __restrict__ top_diff, const int in_height,
     const int in_width, const int channels, const int out_height,
     const int out_width, const float height_scale, const float width_scale,
-    T* bottom_diff) {
+    T* __restrict__ bottom_diff) {
   GPU_1D_KERNEL_LOOP(index, nthreads) {
     int n = index;
     int c = n % channels;
diff --git a/tensorflow/core/kernels/roll_op_gpu.cu.cc b/tensorflow/core/kernels/roll_op_gpu.cu.cc
index c5ef02d84a6..d70d4b719b5 100644
--- a/tensorflow/core/kernels/roll_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/roll_op_gpu.cu.cc
@@ -31,8 +31,8 @@ namespace {
 
 template <typename T>
 __global__ void RollKernel(const int32 nthreads, const int32 num_dims,
-                           const T* input, T* output, const int32* dim_size,
-                           const int32* threshold, const int64* dim_range) {
+                           const T* __restrict__ input, T* __restrict__ output, const int32* __restrict__ dim_size,
+                           const int32* __restrict__ threshold, const int64* __restrict__ dim_range) {
   CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
     int64 offset = 0;
     for (int i = 0; i < num_dims; i++) {
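
Aside (illustrative sketch, not part of the patch): RollKernel applies a cyclic shift to every dimension at once; dim_size, threshold and dim_range are precomputed pieces of that shift so the inner loop avoids extra modulo work. The single-axis semantics it implements, as a plain host reference:

    #include <vector>

    // Cyclic roll of a 1-D array by `shift` positions: input element i lands at
    // output position (i + shift) mod n; negative shifts are handled too.
    std::vector<float> Roll1D(const std::vector<float>& input, int shift) {
      const int n = static_cast<int>(input.size());
      std::vector<float> output(n);
      for (int i = 0; i < n; ++i) {
        output[((i + shift) % n + n) % n] = input[i];
      }
      return output;
    }
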
diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.h b/tensorflow/core/kernels/scatter_functor_gpu.cu.h
index 11baad0d585..fa742fac1fb 100644
--- a/tensorflow/core/kernels/scatter_functor_gpu.cu.h
+++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.h
@@ -70,8 +70,8 @@ struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MAX> {
 };
 
 template <typename T, typename Index, scatter_op::UpdateOp op>
-__global__ void ScatterOpCustomKernel(T* params, const T* updates,
-                                      const Index* indices,
+__global__ void ScatterOpCustomKernel(T* __restrict__ params, const T* __restrict__ updates,
+                                      const Index* __restrict__ indices,
                                       Index first_dim_size, Index updates_size,
                                       Index indices_size) {
   Index update_block = updates_size / indices_size;
@@ -90,8 +90,8 @@ __global__ void ScatterOpCustomKernel(T* params, const T* updates,
 }
 
 template <typename T, typename Index, scatter_op::UpdateOp op>
-__global__ void ScatterScalarOpCustomKernel(T* params, const T* update,
-                                            const Index* indices,
+__global__ void ScatterScalarOpCustomKernel(T* __restrict__ params, const T* __restrict__ update,
+                                            const Index* __restrict__ indices,
                                             Index first_dim_size,
                                             Index indices_size,
                                             Index synthesized_updates_size) {
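
Aside (illustrative sketch, not part of the patch): ScatterOpCustomKernel treats `updates` as indices_size rows of `update_block = updates_size / indices_size` elements and applies the templated op to the row of `params` selected by `indices`. A CPU reference of the UpdateOp::MAX case named in the hunk header above (bounds checks against first_dim_size omitted):

    #include <algorithm>

    // params[indices[j], :] = max(params[indices[j], :], updates[j, :]) for all j.
    void ScatterMaxReference(float* params, const float* updates,
                             const int* indices, int indices_size,
                             int updates_size) {
      const int update_block = updates_size / indices_size;  // row width
      for (int j = 0; j < indices_size; ++j) {
        float* dst = params + indices[j] * update_block;
        const float* src = updates + j * update_block;
        for (int k = 0; k < update_block; ++k) {
          dst[k] = std::max(dst[k], src[k]);
        }
      }
    }
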
diff --git a/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc b/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc
index 58a78ddceef..5fc6d63c53c 100644
--- a/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc
@@ -31,9 +31,9 @@ typedef Eigen::GpuDevice GPUDevice;
 
 namespace {
 template <typename T, typename OutType>
-__global__ void UpperBoundKernel(const T* sorted_inputs, int batch_size,
+__global__ void UpperBoundKernel(const T* __restrict__ sorted_inputs, int batch_size,
                                  int sorted_inputs_size, int values_size,
-                                 const T* values, OutType* outputs) {
+                                 const T* __restrict__ values, OutType* __restrict__ outputs) {
   GPU_1D_KERNEL_LOOP(work_unit_id, values_size * batch_size) {
     int bid = work_unit_id / values_size;
     T value = values[work_unit_id];
@@ -43,9 +43,9 @@ __global__ void UpperBoundKernel(const T* sorted_inputs, int batch_size,
 }
 
 template <typename T, typename OutType>
-__global__ void LowerBoundKernel(const T* sorted_inputs, int batch_size,
+__global__ void LowerBoundKernel(const T* __restrict__ sorted_inputs, int batch_size,
                                  int sorted_inputs_size, int values_size,
-                                 const T* values, OutType* outputs) {
+                                 const T* __restrict__ values, OutType* __restrict__ outputs) {
   GPU_1D_KERNEL_LOOP(work_unit_id, values_size * batch_size) {
     int bid = work_unit_id / values_size;
     T value = values[work_unit_id];
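
Aside (illustrative sketch, not part of the patch): each work unit above binary-searches one row of `sorted_inputs` for its value; UpperBoundKernel returns the first position whose entry is strictly greater than the value, LowerBoundKernel the first position whose entry is not less. The upper-bound search, written as a standalone device helper (hypothetical, not the TF implementation):

    // First index in sorted[0, count) whose element is > value (bisection).
    template <typename T>
    __device__ int UpperBoundSketch(const T* sorted, int count, T value) {
      int lo = 0, hi = count;
      while (lo < hi) {
        const int mid = lo + (hi - lo) / 2;
        if (sorted[mid] <= value) {
          lo = mid + 1;  // everything up to mid is <= value, skip past it
        } else {
          hi = mid;      // mid is still a candidate answer
        }
      }
      return lo;
    }
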
diff --git a/tensorflow/core/kernels/spacetodepth_op_gpu.cu.cc b/tensorflow/core/kernels/spacetodepth_op_gpu.cu.cc
index b80599e93de..595fb9fe49c 100644
--- a/tensorflow/core/kernels/spacetodepth_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/spacetodepth_op_gpu.cu.cc
@@ -29,12 +29,12 @@ typedef Eigen::GpuDevice GPUDevice;
 // Space2Depth kernel for FORMAT_NHWC.
 // See 'spacetodepth_op.h' for a more detailed description.
 template <typename dtype>
-__global__ void S2D_NHWC(const int32 nthreads, const dtype* input_ptr,
+__global__ void S2D_NHWC(const int32 nthreads, const dtype* __restrict__ input_ptr,
                          const int block_size, const int batch_size,
                          const int input_height, const int input_width,
                          const int input_depth, const int output_height,
                          const int output_width, const int output_depth,
-                         dtype* output_ptr) {
+                         dtype* __restrict__ output_ptr) {
   GPU_1D_KERNEL_LOOP(inp_idx, nthreads) {
     // inp_idx = d + input_depth * (w + input_width * (h + input_height * b))
     const int d = inp_idx % input_depth;
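
Aside (illustrative sketch, not part of the patch): after decomposing inp_idx as in the comment above, S2D_NHWC moves each block_size x block_size spatial tile into the channel dimension, with the in-tile offset becoming the high-order part of the output channel. My reading of that coordinate mapping, as a host-side helper:

    // Map an NHWC input coordinate (b, h, w, d) to its SpaceToDepth output
    // coordinate; output_depth = input_depth * block_size * block_size.
    struct Coord { int b, h, w, d; };

    inline Coord SpaceToDepthNhwc(Coord in, int block_size, int input_depth) {
      Coord out;
      out.b = in.b;
      out.h = in.h / block_size;
      out.w = in.w / block_size;
      const int offset = (in.h % block_size) * block_size + (in.w % block_size);
      out.d = offset * input_depth + in.d;  // tile offset in the high-order bits
      return out;
    }
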
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
index f1651455190..26e406f7608 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
@@ -29,9 +29,9 @@ typedef Eigen::GpuDevice GPUDevice;
 template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
 __global__ void SparseTensorDenseMatMulKernel(int nnz, int m, int b_rows,
                                               int b_cols, int p,
-                                              const Tindices* a_indices,
-                                              const T* a_values, const T* b,
-                                              T* out) {
+                                              const Tindices* __restrict__ a_indices,
+                                              const T* __restrict__ a_values, const T* __restrict__ b,
+                                              T* __restrict__ out) {
   // out_{ij} = sum_k {a_ik b_kj}
   // out = A * B', out_{ij} = sum_k {a_ik (b')_kj}; b'_{kj} = b_{jk}
   const int n = (ADJ_B) ? b_cols : b_rows;
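
Aside (illustrative sketch, not part of the patch): each nonzero a_{ik} contributes a_values[nz] * b[k, j] to out[i, j] for every column j, and distinct nonzeros can target the same output row, so the accumulation must be atomic. A stripped-down float-only sketch of the non-adjoint case; the real kernel presumably goes through TF's GPU atomic-add wrapper, while plain atomicAdd keeps this standalone:

    // a_indices holds (row, col) pairs; p is the number of columns of b and out.
    __global__ void SparseDenseMatMulSketch(int nnz, int p,
                                            const int* __restrict__ a_indices,
                                            const float* __restrict__ a_values,
                                            const float* __restrict__ b,
                                            float* __restrict__ out) {
      const int nz = blockIdx.x * blockDim.x + threadIdx.x;
      if (nz < nnz) {
        const int i = a_indices[2 * nz + 0];
        const int k = a_indices[2 * nz + 1];
        for (int j = 0; j < p; ++j) {
          // Atomic because another nonzero may add into out[i, j] concurrently.
          atomicAdd(&out[i * p + j], a_values[nz] * b[k * p + j]);
        }
      }
    }
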
diff --git a/tensorflow/core/kernels/split_lib_gpu.cu.cc b/tensorflow/core/kernels/split_lib_gpu.cu.cc
index c2cc7d61006..50fc933eb28 100644
--- a/tensorflow/core/kernels/split_lib_gpu.cu.cc
+++ b/tensorflow/core/kernels/split_lib_gpu.cu.cc
@@ -74,7 +74,7 @@ TF_CALL_bfloat16(DEFINE_GPU_KERNELS);
 namespace {
 
 template <typename T>
-__global__ void SplitOpKernel(const T* input, int32 prefix_dim_size,
+__global__ void SplitOpKernel(const T* __restrict__ input, int32 prefix_dim_size,
                               int32 split_dim_size, int32 suffix_dim_size,
                               GpuDeviceArrayStruct<T*> output_ptr_data) {
   const int32 num_split = output_ptr_data.size;
@@ -112,7 +112,7 @@ __global__ void SplitOpKernel(const T* input, int32 prefix_dim_size,
 // very similar to the concat kernel except the input/output logic
 // is reversed
 template <typename T, typename IntType, bool useSmem>
-__global__ void split_v_kernel(const T* input_ptr,
+__global__ void split_v_kernel(const T* __restrict__ input_ptr,
                                GpuDeviceArrayStruct<IntType> output_scan,
                                IntType total_rows, IntType total_cols,
                                GpuDeviceArrayStruct<T*> output_ptr_data) {
@@ -169,7 +169,7 @@ __global__ void split_v_kernel(const T* input_ptr,
 // different from the original split implementation due to 2D vs 3D
 // dimensions.  This version is likely faster due to less integer math.
 template <typename T>
-__global__ void SplitVOpKernel_fixed(const T* input, int32 prefix_dim_size,
+__global__ void SplitVOpKernel_fixed(const T* __restrict__ input, int32 prefix_dim_size,
                                      int32 suffix_dim_size,
                                      GpuDeviceArrayStruct<T*> output_ptr_data) {
   const int32 num_split = output_ptr_data.size;
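
Aside (illustrative sketch, not part of the patch): SplitOpKernel views the input as a contiguous [prefix_dim_size, split_dim_size, suffix_dim_size] block and hands an equal share of the middle dimension to each of the num_split outputs. The slicing it performs, as a CPU reference (assumes the split dimension divides evenly, as the even-split op requires):

    // outputs[which] has shape [prefix, split_dim / num_split, suffix].
    void SplitReference(const float* input, int prefix, int split_dim, int suffix,
                        int num_split, float** outputs) {
      const int piece = split_dim / num_split;
      for (int p = 0; p < prefix; ++p) {
        for (int s = 0; s < split_dim; ++s) {
          const int which = s / piece;  // which output receives this slice
          const int local = s % piece;  // where it lands inside that output
          const float* src = input + (p * split_dim + s) * suffix;
          float* dst = outputs[which] + (p * piece + local) * suffix;
          for (int k = 0; k < suffix; ++k) dst[k] = src[k];
        }
      }
    }
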
diff --git a/tensorflow/core/kernels/stateful_random_ops_gpu.cu.cc b/tensorflow/core/kernels/stateful_random_ops_gpu.cu.cc
index deb58275897..e5e49640b51 100644
--- a/tensorflow/core/kernels/stateful_random_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/stateful_random_ops_gpu.cu.cc
@@ -35,8 +35,8 @@ __device__ int thread_counter;
 template <typename Distribution>
 __global__ void FillKernel(
     Distribution dist, int64 state_size, int64 output_size,
-    StateElementType* state_data,
-    typename Distribution::ResultElementType* output_data) {
+    StateElementType* __restrict__ state_data,
+    typename Distribution::ResultElementType* __restrict__ output_data) {
   // Threads in this block share `philox`. Thread 0 is responsible for
   // initializing it.
   __shared__ char philox_raw[sizeof(PhiloxRandom)];
@@ -90,7 +90,7 @@ void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
 }
 
 // Precondition: there is only 1 block and 1 thread.
-__global__ void SkipKernel(int64 delta, StateElementType* state_data) {
+__global__ void SkipKernel(int64 delta, StateElementType* __restrict__ state_data) {
   auto philox = GetPhiloxRandomFromMem(state_data);
   UpdateMemWithPhiloxRandom(philox, delta, state_data);
 }
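
Aside (illustrative sketch, not part of the patch): the FillKernel comment describes a standard pattern: one object lives in shared memory, thread 0 initializes it, and a barrier publishes it to the rest of the block. FillKernel itself uses a raw char buffer because PhiloxRandom is not trivially constructible; the sketch below uses a POD struct and hypothetical names to keep the pattern visible:

    struct Params { unsigned seed; int count; };

    __global__ void SharedInitSketch(unsigned seed, int count, int* out) {
      __shared__ Params params;  // one copy per block
      if (threadIdx.x == 0) {
        params.seed = seed;      // only thread 0 writes it...
        params.count = count;
      }
      __syncthreads();           // ...and the barrier makes it visible to all
      const int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < params.count) {
        out[i] = static_cast<int>(params.seed) + i;
      }
    }
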
diff --git a/tensorflow/core/kernels/svd_op_gpu.cu.cc b/tensorflow/core/kernels/svd_op_gpu.cu.cc
index aa95df6dfe6..30116121c87 100644
--- a/tensorflow/core/kernels/svd_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/svd_op_gpu.cu.cc
@@ -60,9 +60,9 @@ namespace {
 // real value of V (which should be computed)
 template <class Scalar>
 __global__ void ComputeValueOfVKernel(Gpu2DLaunchConfig config, int64 m,
-                                      int64 ldu, const Scalar* M,
-                                      const Scalar* U, const Scalar* S,
-                                      Scalar* V) {
+                                      int64 ldu, const Scalar* __restrict__ M,
+                                      const Scalar* __restrict__ U, const Scalar* __restrict__ S,
+                                      Scalar* __restrict__ V) {
   GPU_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count.x, X) {
     GPU_AXIS_KERNEL_LOOP(i, config.virtual_thread_count.y, Y) {
       Scalar v = M[i + m * batch] * U[ldu * (i + m * batch)] * S[batch];
@@ -74,7 +74,7 @@ __global__ void ComputeValueOfVKernel(Gpu2DLaunchConfig config, int64 m,
 // Extracts the sign of V
 // V[i] = V[i] >= 0 ? 1 : -1
 template <class Scalar>
-__global__ void ExtractSignOfVKernel(GpuLaunchConfig config, Scalar* V) {
+__global__ void ExtractSignOfVKernel(GpuLaunchConfig config, Scalar* __restrict__ V) {
   GPU_1D_KERNEL_LOOP(i, config.virtual_thread_count) {
     V[i] = V[i] >= 0 ? Scalar(1) : Scalar(-1);
   }
diff --git a/tensorflow/core/kernels/tile_functor_gpu.h b/tensorflow/core/kernels/tile_functor_gpu.h
index b013d68c7fe..8b92039af1f 100644
--- a/tensorflow/core/kernels/tile_functor_gpu.h
+++ b/tensorflow/core/kernels/tile_functor_gpu.h
@@ -30,8 +30,8 @@ namespace tensorflow {
 namespace internal {
 
 template <typename T>
-__global__ void TileKernel(int nthreads, const T* src, const int32* buf,
-                           const int32 ndims, T* dst) {
+__global__ void TileKernel(int nthreads, const T* __restrict__ src, const int32* __restrict__ buf,
+                           const int32 ndims, T* __restrict__ dst) {
   const int32* in_strides = buf;
   const int32* out_strides = buf + ndims;
   const int32* in_dim_sizes = buf + ndims * 2;
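
Aside (illustrative sketch, not part of the patch): TileKernel receives one packed int32 buffer holding three ndims-long arrays, as the three lines above show. A destination index is peeled into per-dimension coordinates with the output strides, wrapped by the input dimension sizes (that wrap is what produces the tiling), and reassembled with the input strides. As a host-side helper:

    // buf = [in_strides | out_strides | in_dim_sizes], each of length ndims.
    inline int TileSrcIndex(int dst_idx, const int* buf, int ndims) {
      const int* in_strides = buf;
      const int* out_strides = buf + ndims;
      const int* in_dim_sizes = buf + ndims * 2;
      int src_idx = 0;
      for (int d = 0; d < ndims; ++d) {
        const int out_coord = dst_idx / out_strides[d];  // coordinate in dim d
        dst_idx %= out_strides[d];
        src_idx += (out_coord % in_dim_sizes[d]) * in_strides[d];  // wrap = tile
      }
      return src_idx;
    }
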
diff --git a/tensorflow/core/kernels/topk_op_gpu.h b/tensorflow/core/kernels/topk_op_gpu.h
index 12717ca11fe..a8d534462c0 100644
--- a/tensorflow/core/kernels/topk_op_gpu.h
+++ b/tensorflow/core/kernels/topk_op_gpu.h
@@ -341,8 +341,8 @@ __device__ void mergeShards(int num_shards, int k,
 extern __shared__ char shared_memory[];
 
 template <typename T>
-__global__ void TopKKernel(const T* input, int length, int k, bool sorted,
-                           T* output, int* indices) {
+__global__ void TopKKernel(const T* __restrict__ input, int length, int k, bool sorted,
+                           T* __restrict__ output, int* __restrict__ indices) {
   const int batch_index = blockIdx.x;
   const T* batch_input = input + batch_index * length;
 
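Aside (illustrative sketch, not part of the patch): the `extern __shared__ char shared_memory[];` declaration above means TopKKernel's per-block scratch size is fixed at launch time through the third launch-configuration argument rather than in the kernel text. An unrelated toy kernel showing the same mechanism (block size assumed to be a power of two):

    extern __shared__ float scratch[];  // sized by the launch, not the kernel

    __global__ void BlockSumSketch(const float* __restrict__ in,
                                   float* __restrict__ out, int n) {
      const int i = blockIdx.x * blockDim.x + threadIdx.x;
      scratch[threadIdx.x] = (i < n) ? in[i] : 0.f;
      __syncthreads();
      for (int stride = blockDim.x / 2; stride > 0; stride /= 2) {
        if (threadIdx.x < stride) {
          scratch[threadIdx.x] += scratch[threadIdx.x + stride];
        }
        __syncthreads();
      }
      if (threadIdx.x == 0) out[blockIdx.x] = scratch[0];
    }

    // Launch: the third <<<>>> argument is the dynamic shared-memory byte count,
    //   BlockSumSketch<<<num_blocks, 256, 256 * sizeof(float)>>>(in, out, n);
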
diff --git a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
index 4f54d7340aa..7511025e426 100644
--- a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
+++ b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
@@ -32,8 +32,8 @@ namespace tensorflow {
 namespace internal {
 
 template <typename T, bool conjugate>
-__global__ void TransposeKernel(int nthreads, const T* src, const int32* buf,
-                                const int32 ndims, T* dst) {
+__global__ void TransposeKernel(int nthreads, const T* __restrict__ src, const int32* __restrict__ buf,
+                                const int32 ndims, T* __restrict__ dst) {
   const int32* in_strides = buf;
   const int32* out_strides = buf + ndims;
   const int32* perm = buf + ndims * 2;
diff --git a/tensorflow/core/kernels/tridiagonal_matmul_op_gpu.cu.cc b/tensorflow/core/kernels/tridiagonal_matmul_op_gpu.cu.cc
index 8c829be4785..14a47c711dc 100644
--- a/tensorflow/core/kernels/tridiagonal_matmul_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/tridiagonal_matmul_op_gpu.cu.cc
@@ -35,10 +35,10 @@ namespace tensorflow {
 
 template <typename Scalar>
 __global__ void TridiagonalMatMulKernel(int batch_size, int m, int n,
-                                        const Scalar* superdiag,
-                                        const Scalar* maindiag,
-                                        const Scalar* subdiag,
-                                        const Scalar* rhs, Scalar* product) {
+                                        const Scalar* __restrict__ superdiag,
+                                        const Scalar* __restrict__ maindiag,
+                                        const Scalar* __restrict__ subdiag,
+                                        const Scalar* __restrict__ rhs, Scalar* __restrict__ product) {
   for (int i : CudaGridRangeX(batch_size * m * n)) {
     int row_id = i / n;
     Scalar result = maindiag[row_id] * rhs[i];
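
Aside (illustrative sketch, not part of the patch): the loop body starts from the main-diagonal term shown above and, except on the first and last row of each matrix, adds the sub- and super-diagonal contributions from the neighbouring rows of rhs (at offsets of n elements, since rhs has n columns). The per-element formula as a single-matrix CPU reference:

    // product[r, c] = subdiag[r] * rhs[r-1, c] + maindiag[r] * rhs[r, c]
    //               + superdiag[r] * rhs[r+1, c],
    // with the out-of-range terms dropped on the first and last rows.
    void TridiagMatMulReference(int m, int n, const float* superdiag,
                                const float* maindiag, const float* subdiag,
                                const float* rhs, float* product) {
      for (int r = 0; r < m; ++r) {
        for (int c = 0; c < n; ++c) {
          float v = maindiag[r] * rhs[r * n + c];
          if (r > 0) v += subdiag[r] * rhs[(r - 1) * n + c];
          if (r + 1 < m) v += superdiag[r] * rhs[(r + 1) * n + c];
          product[r * n + c] = v;
        }
      }
    }
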
diff --git a/tensorflow/core/kernels/tridiagonal_solve_op_gpu.cu.cc b/tensorflow/core/kernels/tridiagonal_solve_op_gpu.cu.cc
index 88a3f2d1ca9..9f11aaf4c43 100644
--- a/tensorflow/core/kernels/tridiagonal_solve_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/tridiagonal_solve_op_gpu.cu.cc
@@ -40,9 +40,9 @@ static const char kNotInvertibleScalarMsg[] =
     "The matrix is not invertible: it is a scalar with value zero.";
 
 template <typename Scalar>
-__global__ void SolveForSizeOneOrTwoKernel(const int m, const Scalar* diags,
-                                           const Scalar* rhs, const int num_rhs,
-                                           Scalar* x, bool* not_invertible) {
+__global__ void SolveForSizeOneOrTwoKernel(const int m, const Scalar* __restrict__ diags,
+                                           const Scalar* __restrict__ rhs, const int num_rhs,
+                                           Scalar* __restrict__ x, bool* __restrict__ not_invertible) {
   if (m == 1) {
     if (diags[1] == Scalar(0)) {
       *not_invertible = true;
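
Aside (illustrative sketch, not part of the patch): with the diagonals packed as three length-m rows (superdiag, maindiag, subdiag), diags[1] is the single matrix entry when m == 1, which is why the zero check above means "not invertible". The m == 1 branch then reduces to one scalar division per right-hand side:

    // m == 1: the system is diags[1] * x[j] = rhs[j] for each of num_rhs columns.
    void SolveSizeOneReference(const float* diags, const float* rhs, int num_rhs,
                               float* x, bool* not_invertible) {
      if (diags[1] == 0.0f) {
        *not_invertible = true;
        return;
      }
      for (int j = 0; j < num_rhs; ++j) {
        x[j] = rhs[j] / diags[1];
      }
    }
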
diff --git a/tensorflow/core/kernels/where_op_gpu.cu.h b/tensorflow/core/kernels/where_op_gpu.cu.h
index 8c228d60ebb..bbed89629ef 100644
--- a/tensorflow/core/kernels/where_op_gpu.cu.h
+++ b/tensorflow/core/kernels/where_op_gpu.cu.h
@@ -52,7 +52,7 @@ namespace functor {
 template <int NDIM, typename TIndex>
 __global__ void PropagateWhereIndicesKernel(
     const TIndex output_rows, const typename Eigen::array<TIndex, NDIM> strides,
-    int64* output) {
+    int64* __restrict__ output) {
   // TODO(ebrevdo): Use a multi-dimensional loop, increasing the
   // dimensions of individual indices manually, instead of relying on
   // a scalar loop variable and using integer division.
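
Aside (illustrative sketch, not part of the patch): the TODO above refers to expanding a flat row-major index into NDIM coordinates with a scalar loop of integer divisions over the strides: divide by a dimension's stride to get its coordinate, keep the remainder for the remaining dimensions. That scheme, as a standalone device helper with a hypothetical name:

    // strides[d] = product of the dimension sizes after d (row-major layout).
    template <int NDIM, typename TIndex>
    __device__ void FlatToCoords(TIndex flat, const TIndex (&strides)[NDIM],
                                 TIndex (&coords)[NDIM]) {
    #pragma unroll
      for (int d = 0; d < NDIM; ++d) {
        coords[d] = flat / strides[d];  // coordinate along dimension d
        flat %= strides[d];             // remainder indexes the rest
      }
    }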