diff --git a/tensorflow/core/kernels/image/resize_bilinear_op.cc b/tensorflow/core/kernels/image/resize_bilinear_op.cc
index e01b1776419..320d2ff5fb9 100644
--- a/tensorflow/core/kernels/image/resize_bilinear_op.cc
+++ b/tensorflow/core/kernels/image/resize_bilinear_op.cc
@@ -132,41 +132,25 @@ inline __m128 compute_lerp_v(const __m128 top_left, const __m128 top_right,
 #endif
 
 template <typename T>
-void ResizeLine3Channels(const T* const ys_input_lower_ptr,
-                         const T* const ys_input_upper_ptr,
-                         const CachedInterpolation* const xs,
-                         const float ys_lerp, const int64 out_width,
-                         float* out_y) {
+void ResizeLineChannels(const T* const ys_input_lower_ptr,
+                        const T* const ys_input_upper_ptr,
+                        const CachedInterpolation* const xs,
+                        const float ys_lerp, const int64 out_width,
+                        float* out_y, const int channels) {
   for (int64 x = 0; x < out_width; ++x) {
     const int64 xs_lower = xs[x].lower;
     const int64 xs_upper = xs[x].upper;
     const float xs_lerp = xs[x].lerp;
 
-    // Read channel 0.
-    const float top_left0(ys_input_lower_ptr[xs_lower + 0]);
-    const float top_right0(ys_input_lower_ptr[xs_upper + 0]);
-    const float bottom_left0(ys_input_upper_ptr[xs_lower + 0]);
-    const float bottom_right0(ys_input_upper_ptr[xs_upper + 0]);
+    for (int c = 0; c < channels; ++c) {
+      const float top_left(ys_input_lower_ptr[xs_lower + c]);
+      const float top_right(ys_input_lower_ptr[xs_upper + c]);
+      const float bottom_left(ys_input_upper_ptr[xs_lower + c]);
+      const float bottom_right(ys_input_upper_ptr[xs_upper + c]);
 
-    // Read channel 1.
-    const float top_left1(ys_input_lower_ptr[xs_lower + 1]);
-    const float top_right1(ys_input_lower_ptr[xs_upper + 1]);
-    const float bottom_left1(ys_input_upper_ptr[xs_lower + 1]);
-    const float bottom_right1(ys_input_upper_ptr[xs_upper + 1]);
-
-    // Read channel 2.
-    const float top_left2(ys_input_lower_ptr[xs_lower + 2]);
-    const float top_right2(ys_input_lower_ptr[xs_upper + 2]);
-    const float bottom_left2(ys_input_upper_ptr[xs_lower + 2]);
-    const float bottom_right2(ys_input_upper_ptr[xs_upper + 2]);
-
-    // Compute output.
-    out_y[x * 3 + 0] = compute_lerp(top_left0, top_right0, bottom_left0,
-                                    bottom_right0, xs_lerp, ys_lerp);
-    out_y[x * 3 + 1] = compute_lerp(top_left1, top_right1, bottom_left1,
-                                    bottom_right1, xs_lerp, ys_lerp);
-    out_y[x * 3 + 2] = compute_lerp(top_left2, top_right2, bottom_left2,
-                                    bottom_right2, xs_lerp, ys_lerp);
+      out_y[x * channels + c] = compute_lerp(top_left, top_right, bottom_left,
+                                             bottom_right, xs_lerp, ys_lerp);
+    }
   }
 }
 
@@ -212,9 +196,8 @@ void ResizeLine3ChannelsVector(const T* const ys_input_lower_ptr,
   }
   // The last pixel of each row must be done in a non-vectorized way
   // because we cannot overflow.
-  ResizeLine3Channels(ys_input_lower_ptr, ys_input_upper_ptr,
-                      xs + out_width - 1, ys_lerp, 1,
-                      out_y + (out_width - 1) * 3);
+  ResizeLineChannels(ys_input_lower_ptr, ys_input_upper_ptr, xs + out_width - 1,
+                     ys_lerp, 1, out_y + (out_width - 1) * 3, 3);
 }
 #endif
 
@@ -251,8 +234,8 @@ void resize_image(typename TTypes<T, 4>::ConstTensor images,
         ResizeLine3ChannelsVector(ys_input_lower_ptr, ys_input_upper_ptr, xs,
                                   ys[y].lerp, out_width, output_y_ptr);
 #else
-        ResizeLine3Channels(ys_input_lower_ptr, ys_input_upper_ptr, xs,
-                            ys[y].lerp, out_width, output_y_ptr);
+        ResizeLineChannels(ys_input_lower_ptr, ys_input_upper_ptr, xs,
+                           ys[y].lerp, out_width, output_y_ptr, 3);
 #endif
         output_y_ptr += out_row_size;
       }
@@ -264,21 +247,10 @@ void resize_image(typename TTypes<T, 4>::ConstTensor images,
       for (int64 y = 0; y < out_height; ++y) {
         const T* ys_input_lower_ptr = input_b_ptr + ys[y].lower * in_row_size;
         const T* ys_input_upper_ptr = input_b_ptr + ys[y].upper * in_row_size;
-        const float ys_lerp = ys[y].lerp;
-        for (int64 x = 0; x < out_width; ++x) {
-          auto xs_lower = xs[x].lower;
-          auto xs_upper = xs[x].upper;
-          auto xs_lerp = xs[x].lerp;
-          for (int c = 0; c < channels; ++c) {
-            const float top_left(ys_input_lower_ptr[xs_lower + c]);
-            const float top_right(ys_input_lower_ptr[xs_upper + c]);
-            const float bottom_left(ys_input_upper_ptr[xs_lower + c]);
-            const float bottom_right(ys_input_upper_ptr[xs_upper + c]);
-            output_y_ptr[x * channels + c] =
-                compute_lerp(top_left, top_right, bottom_left, bottom_right,
-                             xs_lerp, ys_lerp);
-          }
-        }
+
+        ResizeLineChannels(ys_input_lower_ptr, ys_input_upper_ptr, xs,
+                           ys[y].lerp, out_width, output_y_ptr, channels);
+
         output_y_ptr += out_row_size;
       }
       input_b_ptr += in_batch_num_values;