Optimize bilinear resizing using vectorization.
```
name       old time/op  new time/op  delta
BM_Resize  4.04ms ± 2%  2.84ms ± 2%  -29.71%  (p=0.000 n=10+10)
```

PiperOrigin-RevId: 315219756
Change-Id: Idcdc00c31199060c67665aeb52b24f495664dbdf
This commit is contained in:
parent
cef0c8cf71
commit
f65efd739c
@ -18,6 +18,10 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/kernels/resize_bilinear_op.h"
|
||||
|
||||
#ifdef __SSE4_1__
|
||||
#include <xmmintrin.h>
|
||||
#endif
|
||||
|
||||
#include <memory>
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
@ -107,6 +111,107 @@ inline float compute_lerp(const float top_left, const float top_right,
|
||||
return top + (bottom - top) * y_lerp;
|
||||
}
|
||||
|
||||
#ifdef __SSE4_1__
|
||||
// SSE counterpart of compute_lerp(): performs the same bilinear blend on four
// float lanes at once. Each lane is first interpolated horizontally along the
// top and bottom edges using x_lerp, and the two edge results are then
// interpolated vertically using y_lerp.
inline __m128 compute_lerp_v(const __m128 top_left, const __m128 top_right,
                             const __m128 bottom_left,
                             const __m128 bottom_right, const __m128 x_lerp,
                             const __m128 y_lerp) {
  // Horizontal differences along the upper and lower edges.
  const __m128 top_delta = _mm_sub_ps(top_right, top_left);
  const __m128 bottom_delta = _mm_sub_ps(bottom_right, bottom_left);

  // Horizontal blend: left + (right - left) * x_lerp.
  const __m128 top = _mm_add_ps(top_left, _mm_mul_ps(top_delta, x_lerp));
  const __m128 bottom =
      _mm_add_ps(bottom_left, _mm_mul_ps(bottom_delta, x_lerp));

  // Vertical blend: top + (bottom - top) * y_lerp.
  const __m128 vertical_delta = _mm_sub_ps(bottom, top);
  return _mm_add_ps(top, _mm_mul_ps(vertical_delta, y_lerp));
}
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
void ResizeLine3Channels(const T* const ys_input_lower_ptr,
|
||||
const T* const ys_input_upper_ptr,
|
||||
const CachedInterpolation* const xs,
|
||||
const float ys_lerp, const int64 out_width,
|
||||
float* out_y) {
|
||||
for (int64 x = 0; x < out_width; ++x) {
|
||||
const int64 xs_lower = xs[x].lower;
|
||||
const int64 xs_upper = xs[x].upper;
|
||||
const float xs_lerp = xs[x].lerp;
|
||||
|
||||
// Read channel 0.
|
||||
const float top_left0(ys_input_lower_ptr[xs_lower + 0]);
|
||||
const float top_right0(ys_input_lower_ptr[xs_upper + 0]);
|
||||
const float bottom_left0(ys_input_upper_ptr[xs_lower + 0]);
|
||||
const float bottom_right0(ys_input_upper_ptr[xs_upper + 0]);
|
||||
|
||||
// Read channel 1.
|
||||
const float top_left1(ys_input_lower_ptr[xs_lower + 1]);
|
||||
const float top_right1(ys_input_lower_ptr[xs_upper + 1]);
|
||||
const float bottom_left1(ys_input_upper_ptr[xs_lower + 1]);
|
||||
const float bottom_right1(ys_input_upper_ptr[xs_upper + 1]);
|
||||
|
||||
// Read channel 2.
|
||||
const float top_left2(ys_input_lower_ptr[xs_lower + 2]);
|
||||
const float top_right2(ys_input_lower_ptr[xs_upper + 2]);
|
||||
const float bottom_left2(ys_input_upper_ptr[xs_lower + 2]);
|
||||
const float bottom_right2(ys_input_upper_ptr[xs_upper + 2]);
|
||||
|
||||
// Compute output.
|
||||
out_y[x * 3 + 0] = compute_lerp(top_left0, top_right0, bottom_left0,
|
||||
bottom_right0, xs_lerp, ys_lerp);
|
||||
out_y[x * 3 + 1] = compute_lerp(top_left1, top_right1, bottom_left1,
|
||||
bottom_right1, xs_lerp, ys_lerp);
|
||||
out_y[x * 3 + 2] = compute_lerp(top_left2, top_right2, bottom_left2,
|
||||
bottom_right2, xs_lerp, ys_lerp);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __SSE4_1__
|
||||
|
||||
// Packs three consecutive elements starting at `values` into the three low
// lanes of an SSE register, converting each element to float. The highest
// lane is set to zero. This generic version reads values[0..2] only.
template <typename T>
inline __m128 load_3xfloat_v(T* values) {
  const float lane0 = static_cast<float>(values[0]);
  const float lane1 = static_cast<float>(values[1]);
  const float lane2 = static_cast<float>(values[2]);
  // _mm_setr_ps takes its arguments lowest lane first.
  return _mm_setr_ps(lane0, lane1, lane2, 0.0f);
}

// Fast path for float buffers: one unaligned load of four floats. The highest
// lane therefore holds values[3] rather than zero, and the buffer must have at
// least four readable floats.
// NOTE: template argument deduction yields T = const float for a
// `const float*` argument, so this specialization is selected only for
// non-const `float*` pointers.
template <>
inline __m128 load_3xfloat_v(float* values) {
  const __m128 packed = _mm_loadu_ps(values);
  return packed;
}
|
||||
|
||||
template <typename T>
|
||||
void ResizeLine3ChannelsVector(const T* const ys_input_lower_ptr,
|
||||
const T* const ys_input_upper_ptr,
|
||||
const CachedInterpolation* const xs,
|
||||
const float ys_lerp, const int64 out_width,
|
||||
float* out_y) {
|
||||
const __m128 ys_lerp_v = _mm_set1_ps(ys_lerp);
|
||||
// All pixels but the last one can overflow, vectorize the inside of the
|
||||
// row.
|
||||
int64 x = 0;
|
||||
for (x = 0; x < out_width - 1; ++x) {
|
||||
const int64 xs_lower = xs[x].lower;
|
||||
const int64 xs_upper = xs[x].upper;
|
||||
const __m128 xs_lerp_v = _mm_set1_ps(xs[x].lerp);
|
||||
|
||||
const __m128 top_left_v = load_3xfloat_v(ys_input_lower_ptr + xs_lower);
|
||||
const __m128 top_right_v = load_3xfloat_v(ys_input_lower_ptr + xs_upper);
|
||||
const __m128 bottom_left_v = load_3xfloat_v(ys_input_upper_ptr + xs_lower);
|
||||
const __m128 bottom_right_v = load_3xfloat_v(ys_input_upper_ptr + xs_upper);
|
||||
|
||||
_mm_storeu_ps(out_y + x * 3,
|
||||
compute_lerp_v(top_left_v, top_right_v, bottom_left_v,
|
||||
bottom_right_v, xs_lerp_v, ys_lerp_v));
|
||||
}
|
||||
// The last pixel of each row must be done in a non-vectorized way
|
||||
// because we cannot overflow.
|
||||
ResizeLine3Channels(ys_input_lower_ptr, ys_input_upper_ptr,
|
||||
xs + out_width - 1, ys_lerp, 1,
|
||||
out_y + (out_width - 1) * 3);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
void resize_image(
|
||||
typename TTypes<T, 4>::ConstTensor images, const int batch_size,
|
||||
@ -136,41 +241,13 @@ void resize_image(typename TTypes<T, 4>::ConstTensor images,
|
||||
for (int64 y = 0; y < out_height; ++y) {
|
||||
const T* ys_input_lower_ptr = input_b_ptr + ys[y].lower * in_row_size;
|
||||
const T* ys_input_upper_ptr = input_b_ptr + ys[y].upper * in_row_size;
|
||||
const float ys_lerp = ys[y].lerp;
|
||||
for (int64 x = 0; x < out_width; ++x) {
|
||||
const int64 xs_lower = xs[x].lower;
|
||||
const int64 xs_upper = xs[x].upper;
|
||||
const float xs_lerp = xs[x].lerp;
|
||||
|
||||
// Read channel 0.
|
||||
const float top_left0(ys_input_lower_ptr[xs_lower + 0]);
|
||||
const float top_right0(ys_input_lower_ptr[xs_upper + 0]);
|
||||
const float bottom_left0(ys_input_upper_ptr[xs_lower + 0]);
|
||||
const float bottom_right0(ys_input_upper_ptr[xs_upper + 0]);
|
||||
|
||||
// Read channel 1.
|
||||
const float top_left1(ys_input_lower_ptr[xs_lower + 1]);
|
||||
const float top_right1(ys_input_lower_ptr[xs_upper + 1]);
|
||||
const float bottom_left1(ys_input_upper_ptr[xs_lower + 1]);
|
||||
const float bottom_right1(ys_input_upper_ptr[xs_upper + 1]);
|
||||
|
||||
// Read channel 2.
|
||||
const float top_left2(ys_input_lower_ptr[xs_lower + 2]);
|
||||
const float top_right2(ys_input_lower_ptr[xs_upper + 2]);
|
||||
const float bottom_left2(ys_input_upper_ptr[xs_lower + 2]);
|
||||
const float bottom_right2(ys_input_upper_ptr[xs_upper + 2]);
|
||||
|
||||
// Compute output.
|
||||
output_y_ptr[x * channels + 0] =
|
||||
compute_lerp(top_left0, top_right0, bottom_left0, bottom_right0,
|
||||
xs_lerp, ys_lerp);
|
||||
output_y_ptr[x * channels + 1] =
|
||||
compute_lerp(top_left1, top_right1, bottom_left1, bottom_right1,
|
||||
xs_lerp, ys_lerp);
|
||||
output_y_ptr[x * channels + 2] =
|
||||
compute_lerp(top_left2, top_right2, bottom_left2, bottom_right2,
|
||||
xs_lerp, ys_lerp);
|
||||
}
|
||||
#ifdef __SSE4_1__
|
||||
ResizeLine3ChannelsVector(ys_input_lower_ptr, ys_input_upper_ptr, xs,
|
||||
ys[y].lerp, out_width, output_y_ptr);
|
||||
#else
|
||||
ResizeLine3Channels(ys_input_lower_ptr, ys_input_upper_ptr, xs,
|
||||
ys[y].lerp, out_width, output_y_ptr);
|
||||
#endif
|
||||
output_y_ptr += out_row_size;
|
||||
}
|
||||
input_b_ptr += in_batch_num_values;
|
||||
@ -338,6 +415,7 @@ struct ResizeBilinearGrad<CPUDevice, T> {
|
||||
static_cast<Eigen::Index>(ceilf(in_x)), original_width - 1);
|
||||
const float x_lerp = in_x - floorf(in_x);
|
||||
const float inverse_x_lerp = (1.0f - x_lerp);
|
||||
// TODO(b/158287314): Look into vectorizing this.
|
||||
for (Eigen::Index c = 0; c < channels; ++c) {
|
||||
output_grad(b, top_y_index, left_x_index, c) +=
|
||||
T(input_grad(b, y, x, c) * inverse_y_lerp * inverse_x_lerp);
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -543,4 +544,39 @@ INSTANTIATE_TEST_SUITE_P(ResizeBilinearOpAlignCornersTestGpu,
|
||||
ResizeBilinearOpAlignCornersTest,
|
||||
::testing::Values(TestDevice::GPU));
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
class ResizeBM : public ResizeBilinearOpTest {
|
||||
public:
|
||||
void TestBody() override {}
|
||||
void SetUpBenchmark(int input_width, int input_height, int num_channels,
|
||||
int output_width, int output_height) {
|
||||
TF_EXPECT_OK(NodeDefBuilder("resize_bilinear_op", "ResizeBilinear")
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_INT32))
|
||||
.Attr("align_corners", align_corners_)
|
||||
.Attr("half_pixel_centers", half_pixel_centers_)
|
||||
.Finalize(node_def()));
|
||||
TF_EXPECT_OK(InitOp());
|
||||
const TensorShape shape(
|
||||
{/*batch_size*/ 1, input_width, input_height, num_channels});
|
||||
SetRandomImageInput(shape);
|
||||
AddInputFromArray<int32>(TensorShape({2}), {output_width, output_height});
|
||||
}
|
||||
|
||||
using ResizeBilinearOpTest::RunOpKernel;
|
||||
};
|
||||
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
|
||||
// Benchmarks the ResizeBilinear kernel on a single 3-channel image, resizing
// from 640x480 to 1024x768 (width x height, as passed to SetUpBenchmark).
// The op is set up once; every benchmark iteration runs the kernel and aborts
// if it reports an error.
void BM_Resize(benchmark::State& state) {
  constexpr int kInputWidth = 640;
  constexpr int kInputHeight = 480;
  constexpr int kNumChannels = 3;
  constexpr int kOutputWidth = 1024;
  constexpr int kOutputHeight = 768;

  ResizeBM fixture;
  fixture.SetUpBenchmark(kInputWidth, kInputHeight, kNumChannels, kOutputWidth,
                         kOutputHeight);

  for (const auto _ : state) {
    const auto run_status = fixture.RunOpKernel();
    CHECK(run_status.ok());
  }
}
|
||||
BENCHMARK(BM_Resize);
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user