Merge pull request #47567 from around-star:patch-1

PiperOrigin-RevId: 361160717
Change-Id: I5fe0a8ba102193d834c385741288390a638b26c7
This commit is contained in:
TensorFlower Gardener 2021-03-05 09:26:23 -08:00
commit 017ec03f9b

View File

@ -132,41 +132,25 @@ inline __m128 compute_lerp_v(const __m128 top_left, const __m128 top_right,
#endif #endif
template <typename T> template <typename T>
void ResizeLine3Channels(const T* const ys_input_lower_ptr, void ResizeLineChannels(const T* const ys_input_lower_ptr,
const T* const ys_input_upper_ptr, const T* const ys_input_upper_ptr,
const CachedInterpolation* const xs, const CachedInterpolation* const xs,
const float ys_lerp, const int64 out_width, const float ys_lerp, const int64 out_width,
float* out_y) { float* out_y, const int channels) {
for (int64 x = 0; x < out_width; ++x) { for (int64 x = 0; x < out_width; ++x) {
const int64 xs_lower = xs[x].lower; const int64 xs_lower = xs[x].lower;
const int64 xs_upper = xs[x].upper; const int64 xs_upper = xs[x].upper;
const float xs_lerp = xs[x].lerp; const float xs_lerp = xs[x].lerp;
// Read channel 0. for (int c = 0; c < channels; ++c) {
const float top_left0(ys_input_lower_ptr[xs_lower + 0]); const float top_left(ys_input_lower_ptr[xs_lower + c]);
const float top_right0(ys_input_lower_ptr[xs_upper + 0]); const float top_right(ys_input_lower_ptr[xs_upper + c]);
const float bottom_left0(ys_input_upper_ptr[xs_lower + 0]); const float bottom_left(ys_input_upper_ptr[xs_lower + c]);
const float bottom_right0(ys_input_upper_ptr[xs_upper + 0]); const float bottom_right(ys_input_upper_ptr[xs_upper + c]);
// Read channel 1. out_y[x * channels + c] = compute_lerp(top_left, top_right, bottom_left,
const float top_left1(ys_input_lower_ptr[xs_lower + 1]); bottom_right, xs_lerp, ys_lerp);
const float top_right1(ys_input_lower_ptr[xs_upper + 1]); }
const float bottom_left1(ys_input_upper_ptr[xs_lower + 1]);
const float bottom_right1(ys_input_upper_ptr[xs_upper + 1]);
// Read channel 2.
const float top_left2(ys_input_lower_ptr[xs_lower + 2]);
const float top_right2(ys_input_lower_ptr[xs_upper + 2]);
const float bottom_left2(ys_input_upper_ptr[xs_lower + 2]);
const float bottom_right2(ys_input_upper_ptr[xs_upper + 2]);
// Compute output.
out_y[x * 3 + 0] = compute_lerp(top_left0, top_right0, bottom_left0,
bottom_right0, xs_lerp, ys_lerp);
out_y[x * 3 + 1] = compute_lerp(top_left1, top_right1, bottom_left1,
bottom_right1, xs_lerp, ys_lerp);
out_y[x * 3 + 2] = compute_lerp(top_left2, top_right2, bottom_left2,
bottom_right2, xs_lerp, ys_lerp);
} }
} }
@ -212,9 +196,8 @@ void ResizeLine3ChannelsVector(const T* const ys_input_lower_ptr,
} }
// The last pixel of each row must be done in a non-vectorized way // The last pixel of each row must be done in a non-vectorized way
// because we cannot overflow. // because we cannot overflow.
ResizeLine3Channels(ys_input_lower_ptr, ys_input_upper_ptr, ResizeLineChannels(ys_input_lower_ptr, ys_input_upper_ptr, xs + out_width - 1,
xs + out_width - 1, ys_lerp, 1, ys_lerp, 1, out_y + (out_width - 1) * 3, 3);
out_y + (out_width - 1) * 3);
} }
#endif #endif
@ -251,8 +234,8 @@ void resize_image(typename TTypes<T, 4>::ConstTensor images,
ResizeLine3ChannelsVector(ys_input_lower_ptr, ys_input_upper_ptr, xs, ResizeLine3ChannelsVector(ys_input_lower_ptr, ys_input_upper_ptr, xs,
ys[y].lerp, out_width, output_y_ptr); ys[y].lerp, out_width, output_y_ptr);
#else #else
ResizeLine3Channels(ys_input_lower_ptr, ys_input_upper_ptr, xs, ResizeLineChannels(ys_input_lower_ptr, ys_input_upper_ptr, xs,
ys[y].lerp, out_width, output_y_ptr); ys[y].lerp, out_width, output_y_ptr, 3);
#endif #endif
output_y_ptr += out_row_size; output_y_ptr += out_row_size;
} }
@ -264,21 +247,10 @@ void resize_image(typename TTypes<T, 4>::ConstTensor images,
for (int64 y = 0; y < out_height; ++y) { for (int64 y = 0; y < out_height; ++y) {
const T* ys_input_lower_ptr = input_b_ptr + ys[y].lower * in_row_size; const T* ys_input_lower_ptr = input_b_ptr + ys[y].lower * in_row_size;
const T* ys_input_upper_ptr = input_b_ptr + ys[y].upper * in_row_size; const T* ys_input_upper_ptr = input_b_ptr + ys[y].upper * in_row_size;
const float ys_lerp = ys[y].lerp;
for (int64 x = 0; x < out_width; ++x) { ResizeLineChannels(ys_input_lower_ptr, ys_input_upper_ptr, xs,
auto xs_lower = xs[x].lower; ys[y].lerp, out_width, output_y_ptr, channels);
auto xs_upper = xs[x].upper;
auto xs_lerp = xs[x].lerp;
for (int c = 0; c < channels; ++c) {
const float top_left(ys_input_lower_ptr[xs_lower + c]);
const float top_right(ys_input_lower_ptr[xs_upper + c]);
const float bottom_left(ys_input_upper_ptr[xs_lower + c]);
const float bottom_right(ys_input_upper_ptr[xs_upper + c]);
output_y_ptr[x * channels + c] =
compute_lerp(top_left, top_right, bottom_left, bottom_right,
xs_lerp, ys_lerp);
}
}
output_y_ptr += out_row_size; output_y_ptr += out_row_size;
} }
input_b_ptr += in_batch_num_values; input_b_ptr += in_batch_num_values;