do some refactor

This commit is contained in:
fsx950223 2020-08-14 13:35:28 +00:00
parent c217663ded
commit a57c998135

View File

@ -173,20 +173,18 @@ struct ResizeNearestNeighbor<GPUDevice, T, half_pixel_centers, align_corners> {
if (output_size == 0) return true;
GpuLaunchConfig config = GetGpuLaunchConfig(output_size, d);
if (half_pixel_centers) {
TF_CHECK_OK(GpuLaunchKernel(
ResizeNearestNeighborNHWC<T>, config.block_count,
config.thread_per_block, 0, d.stream(), output_size, input.data(),
in_height, in_width, channels, out_height, out_width, height_scale,
width_scale, output.data()));
return d.ok();
} else {
TF_CHECK_OK(GpuLaunchKernel(
LegacyResizeNearestNeighborNHWC<T, align_corners>, config.block_count,
config.thread_per_block, 0, d.stream(), output_size, input.data(),
in_height, in_width, channels, out_height, out_width, height_scale,
width_scale, output.data()));
}
void (*kernel)(const int nthreads, const T* __restrict__ bottom_data,
const int in_height, const int in_width, const int channels,
const int out_height, const int out_width,
const float height_scale, const float width_scale,
T* top_data) =
half_pixel_centers ? ResizeNearestNeighborNHWC<T>
: LegacyResizeNearestNeighborNHWC<T, align_corners>;
TF_CHECK_OK(
GpuLaunchKernel(kernel, config.block_count, config.thread_per_block, 0,
d.stream(), config.virtual_thread_count, input.data(),
in_height, in_width, channels, out_height, out_width,
height_scale, width_scale, output.data()));
return d.ok();
}
};
@ -228,23 +226,20 @@ struct ResizeNearestNeighborGrad<GPUDevice, T, half_pixel_centers,
if (input_size == 0) return true;
GpuLaunchConfig input_config = GetGpuLaunchConfig(input_size, d);
if (half_pixel_centers) {
TF_CHECK_OK(GpuLaunchKernel(
ResizeNearestNeighborBackwardNHWC<T>, input_config.block_count,
input_config.thread_per_block, 0, d.stream(),
input_config.virtual_thread_count, input.data(), in_height, in_width,
channels, out_height, out_width, height_scale, width_scale,
output.data()));
return d.ok();
} else {
TF_CHECK_OK(GpuLaunchKernel(
LegacyResizeNearestNeighborBackwardNHWC<T, align_corners>,
input_config.block_count, input_config.thread_per_block, 0,
d.stream(), input_config.virtual_thread_count, input.data(),
in_height, in_width, channels, out_height, out_width, height_scale,
width_scale, output.data()));
return d.ok();
}
void (*kernel)(const int nthreads, const T* __restrict__ top_diff,
const int in_height, const int in_width, const int channels,
const int out_height, const int out_width,
const float height_scale, const float width_scale,
T* __restrict__ bottom_diff) =
half_pixel_centers
? ResizeNearestNeighborBackwardNHWC<T>
: LegacyResizeNearestNeighborBackwardNHWC<T, align_corners>;
TF_CHECK_OK(GpuLaunchKernel(
kernel, input_config.block_count, input_config.thread_per_block, 0,
d.stream(), input_config.virtual_thread_count, input.data(), in_height,
in_width, channels, out_height, out_width, height_scale, width_scale,
output.data()));
return d.ok();
}
};