do some refactor
This commit is contained in:
parent
c217663ded
commit
a57c998135
@ -173,20 +173,18 @@ struct ResizeNearestNeighbor<GPUDevice, T, half_pixel_centers, align_corners> {
|
||||
if (output_size == 0) return true;
|
||||
|
||||
GpuLaunchConfig config = GetGpuLaunchConfig(output_size, d);
|
||||
if (half_pixel_centers) {
|
||||
TF_CHECK_OK(GpuLaunchKernel(
|
||||
ResizeNearestNeighborNHWC<T>, config.block_count,
|
||||
config.thread_per_block, 0, d.stream(), output_size, input.data(),
|
||||
in_height, in_width, channels, out_height, out_width, height_scale,
|
||||
width_scale, output.data()));
|
||||
return d.ok();
|
||||
} else {
|
||||
TF_CHECK_OK(GpuLaunchKernel(
|
||||
LegacyResizeNearestNeighborNHWC<T, align_corners>, config.block_count,
|
||||
config.thread_per_block, 0, d.stream(), output_size, input.data(),
|
||||
in_height, in_width, channels, out_height, out_width, height_scale,
|
||||
width_scale, output.data()));
|
||||
}
|
||||
void (*kernel)(const int nthreads, const T* __restrict__ bottom_data,
|
||||
const int in_height, const int in_width, const int channels,
|
||||
const int out_height, const int out_width,
|
||||
const float height_scale, const float width_scale,
|
||||
T* top_data) =
|
||||
half_pixel_centers ? ResizeNearestNeighborNHWC<T>
|
||||
: LegacyResizeNearestNeighborNHWC<T, align_corners>;
|
||||
TF_CHECK_OK(
|
||||
GpuLaunchKernel(kernel, config.block_count, config.thread_per_block, 0,
|
||||
d.stream(), config.virtual_thread_count, input.data(),
|
||||
in_height, in_width, channels, out_height, out_width,
|
||||
height_scale, width_scale, output.data()));
|
||||
return d.ok();
|
||||
}
|
||||
};
|
||||
@ -228,23 +226,20 @@ struct ResizeNearestNeighborGrad<GPUDevice, T, half_pixel_centers,
|
||||
if (input_size == 0) return true;
|
||||
|
||||
GpuLaunchConfig input_config = GetGpuLaunchConfig(input_size, d);
|
||||
if (half_pixel_centers) {
|
||||
TF_CHECK_OK(GpuLaunchKernel(
|
||||
ResizeNearestNeighborBackwardNHWC<T>, input_config.block_count,
|
||||
input_config.thread_per_block, 0, d.stream(),
|
||||
input_config.virtual_thread_count, input.data(), in_height, in_width,
|
||||
channels, out_height, out_width, height_scale, width_scale,
|
||||
output.data()));
|
||||
return d.ok();
|
||||
} else {
|
||||
TF_CHECK_OK(GpuLaunchKernel(
|
||||
LegacyResizeNearestNeighborBackwardNHWC<T, align_corners>,
|
||||
input_config.block_count, input_config.thread_per_block, 0,
|
||||
d.stream(), input_config.virtual_thread_count, input.data(),
|
||||
in_height, in_width, channels, out_height, out_width, height_scale,
|
||||
width_scale, output.data()));
|
||||
return d.ok();
|
||||
}
|
||||
void (*kernel)(const int nthreads, const T* __restrict__ top_diff,
|
||||
const int in_height, const int in_width, const int channels,
|
||||
const int out_height, const int out_width,
|
||||
const float height_scale, const float width_scale,
|
||||
T* __restrict__ bottom_diff) =
|
||||
half_pixel_centers
|
||||
? ResizeNearestNeighborBackwardNHWC<T>
|
||||
: LegacyResizeNearestNeighborBackwardNHWC<T, align_corners>;
|
||||
TF_CHECK_OK(GpuLaunchKernel(
|
||||
kernel, input_config.block_count, input_config.thread_per_block, 0,
|
||||
d.stream(), input_config.virtual_thread_count, input.data(), in_height,
|
||||
in_width, channels, out_height, out_width, height_scale, width_scale,
|
||||
output.data()));
|
||||
return d.ok();
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user