Optimize bilinear resizing using vectorization.
```
name       old time/op  new time/op  delta
BM_Resize  4.04ms ± 2%  2.84ms ± 2%  -29.71%  (p=0.000 n=10+10)
```

PiperOrigin-RevId: 315219756
Change-Id: Idcdc00c31199060c67665aeb52b24f495664dbdf
This commit is contained in:
		
							parent
							
								
									cef0c8cf71
								
							
						
					
					
						commit
						f65efd739c
					
				@ -18,6 +18,10 @@ limitations under the License.
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/kernels/resize_bilinear_op.h"
 | 
			
		||||
 | 
			
		||||
#ifdef __SSE4_1__
 | 
			
		||||
#include <xmmintrin.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 | 
			
		||||
#include "tensorflow/core/framework/op_kernel.h"
 | 
			
		||||
@ -107,6 +111,107 @@ inline float compute_lerp(const float top_left, const float top_right,
 | 
			
		||||
  return top + (bottom - top) * y_lerp;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#ifdef __SSE4_1__
 | 
			
		||||
/* Four-lane SSE counterpart of the scalar compute_lerp: each lane is
 * interpolated independently.
 *
 * Per lane the result is
 *   top    = top_left    + (top_right    - top_left)    * x_lerp
 *   bottom = bottom_left + (bottom_right - bottom_left) * x_lerp
 *   result = top         + (bottom       - top)         * y_lerp
 */
inline __m128 compute_lerp_v(const __m128 top_left, const __m128 top_right,
                             const __m128 bottom_left,
                             const __m128 bottom_right, const __m128 x_lerp,
                             const __m128 y_lerp) {
  // Horizontal interpolation along the upper row.
  const __m128 top_span = _mm_sub_ps(top_right, top_left);
  const __m128 top = _mm_add_ps(top_left, _mm_mul_ps(top_span, x_lerp));

  // Horizontal interpolation along the lower row.
  const __m128 bottom_span = _mm_sub_ps(bottom_right, bottom_left);
  const __m128 bottom =
      _mm_add_ps(bottom_left, _mm_mul_ps(bottom_span, x_lerp));

  // Vertical interpolation between the two intermediate rows.
  const __m128 vertical_span = _mm_sub_ps(bottom, top);
  return _mm_add_ps(top, _mm_mul_ps(vertical_span, y_lerp));
}
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
void ResizeLine3Channels(const T* const ys_input_lower_ptr,
 | 
			
		||||
                         const T* const ys_input_upper_ptr,
 | 
			
		||||
                         const CachedInterpolation* const xs,
 | 
			
		||||
                         const float ys_lerp, const int64 out_width,
 | 
			
		||||
                         float* out_y) {
 | 
			
		||||
  for (int64 x = 0; x < out_width; ++x) {
 | 
			
		||||
    const int64 xs_lower = xs[x].lower;
 | 
			
		||||
    const int64 xs_upper = xs[x].upper;
 | 
			
		||||
    const float xs_lerp = xs[x].lerp;
 | 
			
		||||
 | 
			
		||||
    // Read channel 0.
 | 
			
		||||
    const float top_left0(ys_input_lower_ptr[xs_lower + 0]);
 | 
			
		||||
    const float top_right0(ys_input_lower_ptr[xs_upper + 0]);
 | 
			
		||||
    const float bottom_left0(ys_input_upper_ptr[xs_lower + 0]);
 | 
			
		||||
    const float bottom_right0(ys_input_upper_ptr[xs_upper + 0]);
 | 
			
		||||
 | 
			
		||||
    // Read channel 1.
 | 
			
		||||
    const float top_left1(ys_input_lower_ptr[xs_lower + 1]);
 | 
			
		||||
    const float top_right1(ys_input_lower_ptr[xs_upper + 1]);
 | 
			
		||||
    const float bottom_left1(ys_input_upper_ptr[xs_lower + 1]);
 | 
			
		||||
    const float bottom_right1(ys_input_upper_ptr[xs_upper + 1]);
 | 
			
		||||
 | 
			
		||||
    // Read channel 2.
 | 
			
		||||
    const float top_left2(ys_input_lower_ptr[xs_lower + 2]);
 | 
			
		||||
    const float top_right2(ys_input_lower_ptr[xs_upper + 2]);
 | 
			
		||||
    const float bottom_left2(ys_input_upper_ptr[xs_lower + 2]);
 | 
			
		||||
    const float bottom_right2(ys_input_upper_ptr[xs_upper + 2]);
 | 
			
		||||
 | 
			
		||||
    // Compute output.
 | 
			
		||||
    out_y[x * 3 + 0] = compute_lerp(top_left0, top_right0, bottom_left0,
 | 
			
		||||
                                    bottom_right0, xs_lerp, ys_lerp);
 | 
			
		||||
    out_y[x * 3 + 1] = compute_lerp(top_left1, top_right1, bottom_left1,
 | 
			
		||||
                                    bottom_right1, xs_lerp, ys_lerp);
 | 
			
		||||
    out_y[x * 3 + 2] = compute_lerp(top_left2, top_right2, bottom_left2,
 | 
			
		||||
                                    bottom_right2, xs_lerp, ys_lerp);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#ifdef __SSE4_1__
 | 
			
		||||
 | 
			
		||||
// Builds a vector from three consecutive elements of `values`, each converted
// to float. Lanes 0..2 hold values[0..2]; lane 3 is set to zero. This generic
// version reads exactly three elements.
template <typename T>
inline __m128 load_3xfloat_v(T* values) {
  const float lane0 = static_cast<float>(values[0]);
  const float lane1 = static_cast<float>(values[1]);
  const float lane2 = static_cast<float>(values[2]);
  // _mm_setr_ps takes its arguments in memory (lane 0 first) order.
  return _mm_setr_ps(lane0, lane1, lane2, 0.0f);
}
 | 
			
		||||
 | 
			
		||||
// Specialize cases that can be done more efficiently.
 | 
			
		||||
template <>
 | 
			
		||||
inline __m128 load_3xfloat_v(float* values) {
 | 
			
		||||
  return _mm_loadu_ps(values);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
void ResizeLine3ChannelsVector(const T* const ys_input_lower_ptr,
 | 
			
		||||
                               const T* const ys_input_upper_ptr,
 | 
			
		||||
                               const CachedInterpolation* const xs,
 | 
			
		||||
                               const float ys_lerp, const int64 out_width,
 | 
			
		||||
                               float* out_y) {
 | 
			
		||||
  const __m128 ys_lerp_v = _mm_set1_ps(ys_lerp);
 | 
			
		||||
  // All pixels but the last one can overflow, vectorize the inside of the
 | 
			
		||||
  // row.
 | 
			
		||||
  int64 x = 0;
 | 
			
		||||
  for (x = 0; x < out_width - 1; ++x) {
 | 
			
		||||
    const int64 xs_lower = xs[x].lower;
 | 
			
		||||
    const int64 xs_upper = xs[x].upper;
 | 
			
		||||
    const __m128 xs_lerp_v = _mm_set1_ps(xs[x].lerp);
 | 
			
		||||
 | 
			
		||||
    const __m128 top_left_v = load_3xfloat_v(ys_input_lower_ptr + xs_lower);
 | 
			
		||||
    const __m128 top_right_v = load_3xfloat_v(ys_input_lower_ptr + xs_upper);
 | 
			
		||||
    const __m128 bottom_left_v = load_3xfloat_v(ys_input_upper_ptr + xs_lower);
 | 
			
		||||
    const __m128 bottom_right_v = load_3xfloat_v(ys_input_upper_ptr + xs_upper);
 | 
			
		||||
 | 
			
		||||
    _mm_storeu_ps(out_y + x * 3,
 | 
			
		||||
                  compute_lerp_v(top_left_v, top_right_v, bottom_left_v,
 | 
			
		||||
                                 bottom_right_v, xs_lerp_v, ys_lerp_v));
 | 
			
		||||
  }
 | 
			
		||||
  // The last pixel of each row must be done in a non-vectorized way
 | 
			
		||||
  // because we cannot overflow.
 | 
			
		||||
  ResizeLine3Channels(ys_input_lower_ptr, ys_input_upper_ptr,
 | 
			
		||||
                      xs + out_width - 1, ys_lerp, 1,
 | 
			
		||||
                      out_y + (out_width - 1) * 3);
 | 
			
		||||
}
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
void resize_image(
 | 
			
		||||
    typename TTypes<T, 4>::ConstTensor images, const int batch_size,
 | 
			
		||||
@ -136,41 +241,13 @@ void resize_image(typename TTypes<T, 4>::ConstTensor images,
 | 
			
		||||
      for (int64 y = 0; y < out_height; ++y) {
 | 
			
		||||
        const T* ys_input_lower_ptr = input_b_ptr + ys[y].lower * in_row_size;
 | 
			
		||||
        const T* ys_input_upper_ptr = input_b_ptr + ys[y].upper * in_row_size;
 | 
			
		||||
        const float ys_lerp = ys[y].lerp;
 | 
			
		||||
        for (int64 x = 0; x < out_width; ++x) {
 | 
			
		||||
          const int64 xs_lower = xs[x].lower;
 | 
			
		||||
          const int64 xs_upper = xs[x].upper;
 | 
			
		||||
          const float xs_lerp = xs[x].lerp;
 | 
			
		||||
 | 
			
		||||
          // Read channel 0.
 | 
			
		||||
          const float top_left0(ys_input_lower_ptr[xs_lower + 0]);
 | 
			
		||||
          const float top_right0(ys_input_lower_ptr[xs_upper + 0]);
 | 
			
		||||
          const float bottom_left0(ys_input_upper_ptr[xs_lower + 0]);
 | 
			
		||||
          const float bottom_right0(ys_input_upper_ptr[xs_upper + 0]);
 | 
			
		||||
 | 
			
		||||
          // Read channel 1.
 | 
			
		||||
          const float top_left1(ys_input_lower_ptr[xs_lower + 1]);
 | 
			
		||||
          const float top_right1(ys_input_lower_ptr[xs_upper + 1]);
 | 
			
		||||
          const float bottom_left1(ys_input_upper_ptr[xs_lower + 1]);
 | 
			
		||||
          const float bottom_right1(ys_input_upper_ptr[xs_upper + 1]);
 | 
			
		||||
 | 
			
		||||
          // Read channel 2.
 | 
			
		||||
          const float top_left2(ys_input_lower_ptr[xs_lower + 2]);
 | 
			
		||||
          const float top_right2(ys_input_lower_ptr[xs_upper + 2]);
 | 
			
		||||
          const float bottom_left2(ys_input_upper_ptr[xs_lower + 2]);
 | 
			
		||||
          const float bottom_right2(ys_input_upper_ptr[xs_upper + 2]);
 | 
			
		||||
 | 
			
		||||
          // Compute output.
 | 
			
		||||
          output_y_ptr[x * channels + 0] =
 | 
			
		||||
              compute_lerp(top_left0, top_right0, bottom_left0, bottom_right0,
 | 
			
		||||
                           xs_lerp, ys_lerp);
 | 
			
		||||
          output_y_ptr[x * channels + 1] =
 | 
			
		||||
              compute_lerp(top_left1, top_right1, bottom_left1, bottom_right1,
 | 
			
		||||
                           xs_lerp, ys_lerp);
 | 
			
		||||
          output_y_ptr[x * channels + 2] =
 | 
			
		||||
              compute_lerp(top_left2, top_right2, bottom_left2, bottom_right2,
 | 
			
		||||
                           xs_lerp, ys_lerp);
 | 
			
		||||
        }
 | 
			
		||||
#ifdef __SSE4_1__
 | 
			
		||||
        ResizeLine3ChannelsVector(ys_input_lower_ptr, ys_input_upper_ptr, xs,
 | 
			
		||||
                                  ys[y].lerp, out_width, output_y_ptr);
 | 
			
		||||
#else
 | 
			
		||||
        ResizeLine3Channels(ys_input_lower_ptr, ys_input_upper_ptr, xs,
 | 
			
		||||
                            ys[y].lerp, out_width, output_y_ptr);
 | 
			
		||||
#endif
 | 
			
		||||
        output_y_ptr += out_row_size;
 | 
			
		||||
      }
 | 
			
		||||
      input_b_ptr += in_batch_num_values;
 | 
			
		||||
@ -338,6 +415,7 @@ struct ResizeBilinearGrad<CPUDevice, T> {
 | 
			
		||||
              static_cast<Eigen::Index>(ceilf(in_x)), original_width - 1);
 | 
			
		||||
          const float x_lerp = in_x - floorf(in_x);
 | 
			
		||||
          const float inverse_x_lerp = (1.0f - x_lerp);
 | 
			
		||||
          // TODO(b/158287314): Look into vectorizing this.
 | 
			
		||||
          for (Eigen::Index c = 0; c < channels; ++c) {
 | 
			
		||||
            output_grad(b, top_y_index, left_x_index, c) +=
 | 
			
		||||
                T(input_grad(b, y, x, c) * inverse_y_lerp * inverse_x_lerp);
 | 
			
		||||
 | 
			
		||||
@ -28,6 +28,7 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/core/lib/random/random.h"
 | 
			
		||||
#include "tensorflow/core/lib/strings/str_util.h"
 | 
			
		||||
#include "tensorflow/core/platform/test.h"
 | 
			
		||||
#include "tensorflow/core/platform/test_benchmark.h"
 | 
			
		||||
#include "tensorflow/core/public/session_options.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
@ -543,4 +544,39 @@ INSTANTIATE_TEST_SUITE_P(ResizeBilinearOpAlignCornersTestGpu,
 | 
			
		||||
                         ResizeBilinearOpAlignCornersTest,
 | 
			
		||||
                         ::testing::Values(TestDevice::GPU));
 | 
			
		||||
#endif  // GOOGLE_CUDA
 | 
			
		||||
 | 
			
		||||
class ResizeBM : public ResizeBilinearOpTest {
 | 
			
		||||
 public:
 | 
			
		||||
  void TestBody() override {}
 | 
			
		||||
  void SetUpBenchmark(int input_width, int input_height, int num_channels,
 | 
			
		||||
                      int output_width, int output_height) {
 | 
			
		||||
    TF_EXPECT_OK(NodeDefBuilder("resize_bilinear_op", "ResizeBilinear")
 | 
			
		||||
                     .Input(FakeInput(DT_FLOAT))
 | 
			
		||||
                     .Input(FakeInput(DT_INT32))
 | 
			
		||||
                     .Attr("align_corners", align_corners_)
 | 
			
		||||
                     .Attr("half_pixel_centers", half_pixel_centers_)
 | 
			
		||||
                     .Finalize(node_def()));
 | 
			
		||||
    TF_EXPECT_OK(InitOp());
 | 
			
		||||
    const TensorShape shape(
 | 
			
		||||
        {/*batch_size*/ 1, input_width, input_height, num_channels});
 | 
			
		||||
    SetRandomImageInput(shape);
 | 
			
		||||
    AddInputFromArray<int32>(TensorShape({2}), {output_width, output_height});
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  using ResizeBilinearOpTest::RunOpKernel;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#ifdef PLATFORM_GOOGLE
 | 
			
		||||
 | 
			
		||||
void BM_Resize(benchmark::State& state) {
 | 
			
		||||
  ResizeBM bench;
 | 
			
		||||
  bench.SetUpBenchmark(640, 480, 3, 1024, 768);
 | 
			
		||||
  for (const auto _ : state) {
 | 
			
		||||
    CHECK(bench.RunOpKernel().ok());
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
BENCHMARK(BM_Resize);
 | 
			
		||||
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user