diff --git a/tensorflow/core/kernels/image/resize_nearest_neighbor_op_gpu.cu.cc b/tensorflow/core/kernels/image/resize_nearest_neighbor_op_gpu.cu.cc index 50066d5b653..93fde9131f2 100644 --- a/tensorflow/core/kernels/image/resize_nearest_neighbor_op_gpu.cu.cc +++ b/tensorflow/core/kernels/image/resize_nearest_neighbor_op_gpu.cu.cc @@ -173,20 +173,18 @@ struct ResizeNearestNeighbor { if (output_size == 0) return true; GpuLaunchConfig config = GetGpuLaunchConfig(output_size, d); - if (half_pixel_centers) { - TF_CHECK_OK(GpuLaunchKernel( - ResizeNearestNeighborNHWC, config.block_count, - config.thread_per_block, 0, d.stream(), output_size, input.data(), - in_height, in_width, channels, out_height, out_width, height_scale, - width_scale, output.data())); - return d.ok(); - } else { - TF_CHECK_OK(GpuLaunchKernel( - LegacyResizeNearestNeighborNHWC, config.block_count, - config.thread_per_block, 0, d.stream(), output_size, input.data(), - in_height, in_width, channels, out_height, out_width, height_scale, - width_scale, output.data())); - } + void (*kernel)(const int nthreads, const T* __restrict__ bottom_data, + const int in_height, const int in_width, const int channels, + const int out_height, const int out_width, + const float height_scale, const float width_scale, + T* top_data) = + half_pixel_centers ? ResizeNearestNeighborNHWC + : LegacyResizeNearestNeighborNHWC; + TF_CHECK_OK( + GpuLaunchKernel(kernel, config.block_count, config.thread_per_block, 0, + d.stream(), config.virtual_thread_count, input.data(), + in_height, in_width, channels, out_height, out_width, + height_scale, width_scale, output.data())); return d.ok(); } }; @@ -228,23 +226,20 @@ struct ResizeNearestNeighborGrad, input_config.block_count, - input_config.thread_per_block, 0, d.stream(), - input_config.virtual_thread_count, input.data(), in_height, in_width, - channels, out_height, out_width, height_scale, width_scale, - output.data())); - return d.ok(); - } else { - TF_CHECK_OK(GpuLaunchKernel( - LegacyResizeNearestNeighborBackwardNHWC, - input_config.block_count, input_config.thread_per_block, 0, - d.stream(), input_config.virtual_thread_count, input.data(), - in_height, in_width, channels, out_height, out_width, height_scale, - width_scale, output.data())); - return d.ok(); - } + void (*kernel)(const int nthreads, const T* __restrict__ top_diff, + const int in_height, const int in_width, const int channels, + const int out_height, const int out_width, + const float height_scale, const float width_scale, + T* __restrict__ bottom_diff) = + half_pixel_centers + ? ResizeNearestNeighborBackwardNHWC + : LegacyResizeNearestNeighborBackwardNHWC; + TF_CHECK_OK(GpuLaunchKernel( + kernel, input_config.block_count, input_config.thread_per_block, 0, + d.stream(), input_config.virtual_thread_count, input.data(), in_height, + in_width, channels, out_height, out_width, height_scale, width_scale, + output.data())); + return d.ok(); } };