TFLite resize bilinear opt: General (non-Neon) x8 optimization.

PiperOrigin-RevId: 358936877
Change-Id: Id513b6e9ffbabd008d7858cfc55962d95aead550
Author: Alex Stark, 2021-02-22 17:08:22 -08:00 (committed by TensorFlower Gardener)
Parent: 04a88b52d3
Commit: f6cb18c80a
2 changed files with 501 additions and 5 deletions
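
The kernel added in this change performs an exact x8 bilinear upscale with half-pixel centers, in 8.8 fixed point. For orientation, here is a minimal scalar sketch of the per-row arithmetic it implements; the function name, the <cstdint> types, and the single-channel layout are illustrative assumptions, not part of the commit. Each output sample p in [0, 8) between input columns j and j + 1 carries the weight (2p + 1) / 16, accumulated with a +128 (i.e. +0.5) rounding bias.

#include <cstdint>

// Hypothetical reference: upscale one row of single-channel uint8 data by 8x
// with half-pixel centers, mirroring the fixed-point scheme used below.
inline void ResizeRow8xHalfPixelSketch(const uint8_t* in, int width,
                                       uint8_t* out) {
  for (int j = 0; j + 1 < width; ++j) {
    // 8.8 fixed point; +128 pre-adds 0.5 so that >> 8 rounds to nearest.
    uint16_t accum = (in[j] << 8) + 128;
    // An input delta shifted left by 4 represents a 1/16 step in 8.8;
    // signed deltas wrap correctly in unsigned modular arithmetic.
    const uint16_t wdelta = static_cast<uint16_t>(in[j + 1] - in[j]) << 4;
    accum += wdelta;  // First output sample sits at weight 1/16.
    for (int p = 0; p < 8; ++p) {
      out[j * 8 + 4 + p] = accum >> 8;
      accum += wdelta << 1;  // Each subsequent sample adds 2/16.
    }
  }
  // The 4 outermost outputs on each side clamp to the border samples.
  for (int p = 0; p < 4; ++p) {
    out[p] = in[0];
    out[width * 8 - 1 - p] = in[width - 1];
  }
}

The committed kernel applies the same stepping jointly across 8 output rows and 8 depth channels at a time, and carries each accumulator across column spans instead of re-initializing it.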

tensorflow/lite/kernels/internal/optimized/optimized_ops.h

@@ -35,6 +35,472 @@ limitations under the License.
namespace tflite {
namespace optimized_ops {
namespace resize_bilinear {
// Optimized resize-bilinear for the special case where the scaling is x8 in
// width and height, and where we can operate on depth-8 blocks at a time. So
// the output blocks are 8x8x8 in width-height-depth.
//
// This optimization is for the half_pixel_centers == true version, for uint8,
// for non-NEON compilations.
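//
// With half_pixel_centers == true and a scale of 8, output sample p in
// [0, 8) lying between input samples j and j + 1 carries the weight
// (2p + 1) / 16. The outermost 4 output rows and columns on each edge fall
// outside the input sample grid and clamp to the border values.
// Accumulation is in 8.8 fixed point with a +128 (= 0.5) rounding bias, so
// an input delta shifted left by 4 represents a 1/16 step.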
inline void ResizeBilinear888Uint8(int32 batches, int32 input_height,
int32 input_width, int32 depth,
const uint8* input_data,
uint8* output_data) {
TFLITE_DCHECK_GE(input_height, 1);
TFLITE_DCHECK_GE(input_width, 1);
TFLITE_DCHECK_EQ(depth % 8, 0);
const int32 input_row_stride = input_width * depth;
const int32 output_row_stride = input_row_stride * 8;
for (int b = 0; b < batches; ++b) {
const uint8* input_base_ptr =
input_data + b * input_row_stride * input_height;
uint8* output_base_ptr =
output_data + b * output_row_stride * input_height * 8;
for (int c_block = 0; c_block < depth; c_block += 8) {
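// Local scratch for one depth-8 block; note that it shadows the
// output_data function argument.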
uint8 output_data[8];
uint16 accum[8];
// Top-left margin corner.
for (int c = 0; c < 8; ++c) {
output_data[c] = input_base_ptr[c_block + c];
output_base_ptr[c_block + c] = output_data[c];
output_base_ptr[c_block + c + depth] = output_data[c];
output_base_ptr[c_block + c + depth * 2] = output_data[c];
output_base_ptr[c_block + c + depth * 3] = output_data[c];
// Accumulate in 8.8 representation, pre-adding 0.5 for later rounding.
accum[c] =
(output_data[c] << 8) + 128; // 128 = 0.5 in 8.8 representation.
}
// Top-centre margin.
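// Each span between input columns j and j + 1 produces 8 output samples.
// wdelta is the horizontal delta scaled to a 1/16 step in 8.8: the first
// sample adds it once (weight 1/16), the next seven add wdelta_twice
// (2/16 each), and a final wdelta lands the accumulator exactly on input
// sample j + 1, ready for the next span.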
uint16 wdelta[8];
uint16 wdelta_twice[8];
for (int j = 0; j < (input_width - 1); ++j) {
for (int c = 0; c < 8; ++c) {
wdelta[c] = static_cast<uint16>(
input_base_ptr[c_block + c + depth * (j + 1)] -
input_base_ptr[c_block + c + depth * j])
<< 4;
wdelta_twice[c] = wdelta[c] << 1;
accum[c] += wdelta[c];
output_base_ptr[c_block + c + depth * j * 8 + depth * 4] =
accum[c] >> 8;
for (int p = 1; p < 8; ++p) {
accum[c] += wdelta_twice[c];
output_base_ptr[c_block + c + depth * j * 8 + depth * p +
depth * 4] = accum[c] >> 8;
}
accum[c] += wdelta[c];
}
}
// Top-right margin corner.
for (int c = 0; c < 8; ++c) {
// The accumulator has 0.5 pre-added for rounding; the shift discards it
// here, and reusing the accumulator simply avoids re-loading the input.
output_data[c] = accum[c] >> 8;
TFLITE_DCHECK_EQ(
output_data[c],
input_base_ptr[c_block + c + depth * (input_width - 1)]);
output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
depth * 4] = output_data[c];
output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
depth * 4 + depth] = output_data[c];
output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
depth * 4 + depth * 2] = output_data[c];
output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
depth * 4 + depth * 3] = output_data[c];
}
}
// Fill out remainder of top margin.
std::memcpy(output_base_ptr + output_row_stride, output_base_ptr,
output_row_stride * sizeof(uint8));
std::memcpy(output_base_ptr + output_row_stride * 2, output_base_ptr,
output_row_stride * sizeof(uint8));
std::memcpy(output_base_ptr + output_row_stride * 3, output_base_ptr,
output_row_stride * sizeof(uint8));
output_base_ptr += output_row_stride * 4;
// Main rows.
for (int k = 0; k < (input_height - 1); ++k) {
for (int c_block = 0; c_block < depth; c_block += 8) {
uint8* output_base_ptr_0 = output_base_ptr;
uint8* output_base_ptr_1;
uint8* output_base_ptr_2;
uint8* output_base_ptr_3;
uint8* output_base_ptr_4;
uint8* output_base_ptr_5;
uint8* output_base_ptr_6;
uint8* output_base_ptr_7;
uint16 accum_0[8];
uint16 accum_1[8];
uint16 accum_2[8];
uint16 accum_3[8];
uint16 accum_4[8];
uint16 accum_5[8];
uint16 accum_6[8];
uint16 accum_7[8];
// We would prefer to treat accum_0[c], etc, as packed-data arrays mapping
// onto registers. However, the compiler will not reliably keep an array in
// registers, and so we do most of the work in plain scalar variables.
uint16 accum_0_c;
uint16 accum_1_c;
uint16 accum_2_c;
uint16 accum_3_c;
uint16 accum_4_c;
uint16 accum_5_c;
uint16 accum_6_c;
uint16 accum_7_c;
int16 hdelta_c;
int16 hdelta_twice_c;
// Left margin for 8 rows.
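// hdelta is the vertical delta scaled to a 1/16 step: output row 0 of
// this 8-row block sits at vertical weight 1/16, and each later row adds
// hdelta_twice (2/16).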
for (int c = 0; c < 8; ++c) {
hdelta_c = static_cast<uint16>(
input_base_ptr[c_block + c + input_row_stride] -
input_base_ptr[c_block + c])
<< 4;
// Accumulate in 8.8 representation, pre-adding 0.5 for later
// rounding.
accum_0_c = (input_base_ptr[c_block + c] << 8) + 128;
accum_0_c += hdelta_c;
output_base_ptr_0[c_block + c] = accum_0_c >> 8;
output_base_ptr_0[c_block + c + depth] = accum_0_c >> 8;
output_base_ptr_0[c_block + c + depth * 2] = accum_0_c >> 8;
output_base_ptr_0[c_block + c + depth * 3] = accum_0_c >> 8;
hdelta_twice_c = hdelta_c << 1;
output_base_ptr_1 = output_base_ptr_0 + output_row_stride;
accum_1_c = accum_0_c + hdelta_twice_c;
output_base_ptr_1[c_block + c] = accum_1_c >> 8;
output_base_ptr_1[c_block + c + depth] = accum_1_c >> 8;
output_base_ptr_1[c_block + c + depth * 2] = accum_1_c >> 8;
output_base_ptr_1[c_block + c + depth * 3] = accum_1_c >> 8;
output_base_ptr_2 = output_base_ptr_1 + output_row_stride;
accum_2_c = accum_1_c + hdelta_twice_c;
output_base_ptr_2[c_block + c] = accum_2_c >> 8;
output_base_ptr_2[c_block + c + depth] = accum_2_c >> 8;
output_base_ptr_2[c_block + c + depth * 2] = accum_2_c >> 8;
output_base_ptr_2[c_block + c + depth * 3] = accum_2_c >> 8;
output_base_ptr_3 = output_base_ptr_2 + output_row_stride;
accum_3_c = accum_2_c + hdelta_twice_c;
output_base_ptr_3[c_block + c] = accum_3_c >> 8;
output_base_ptr_3[c_block + c + depth] = accum_3_c >> 8;
output_base_ptr_3[c_block + c + depth * 2] = accum_3_c >> 8;
output_base_ptr_3[c_block + c + depth * 3] = accum_3_c >> 8;
output_base_ptr_4 = output_base_ptr_3 + output_row_stride;
accum_4_c = accum_3_c + hdelta_twice_c;
output_base_ptr_4[c_block + c] = accum_4_c >> 8;
output_base_ptr_4[c_block + c + depth] = accum_4_c >> 8;
output_base_ptr_4[c_block + c + depth * 2] = accum_4_c >> 8;
output_base_ptr_4[c_block + c + depth * 3] = accum_4_c >> 8;
output_base_ptr_5 = output_base_ptr_4 + output_row_stride;
accum_5_c = accum_4_c + hdelta_twice_c;
output_base_ptr_5[c_block + c] = accum_5_c >> 8;
output_base_ptr_5[c_block + c + depth] = accum_5_c >> 8;
output_base_ptr_5[c_block + c + depth * 2] = accum_5_c >> 8;
output_base_ptr_5[c_block + c + depth * 3] = accum_5_c >> 8;
output_base_ptr_6 = output_base_ptr_5 + output_row_stride;
accum_6_c = accum_5_c + hdelta_twice_c;
output_base_ptr_6[c_block + c] = accum_6_c >> 8;
output_base_ptr_6[c_block + c + depth] = accum_6_c >> 8;
output_base_ptr_6[c_block + c + depth * 2] = accum_6_c >> 8;
output_base_ptr_6[c_block + c + depth * 3] = accum_6_c >> 8;
output_base_ptr_7 = output_base_ptr_6 + output_row_stride;
accum_7_c = accum_6_c + hdelta_twice_c;
output_base_ptr_7[c_block + c] = accum_7_c >> 8;
output_base_ptr_7[c_block + c + depth] = accum_7_c >> 8;
output_base_ptr_7[c_block + c + depth * 2] = accum_7_c >> 8;
output_base_ptr_7[c_block + c + depth * 3] = accum_7_c >> 8;
accum_0[c] = accum_0_c;
accum_1[c] = accum_1_c;
accum_2[c] = accum_2_c;
accum_3[c] = accum_3_c;
accum_4[c] = accum_4_c;
accum_5[c] = accum_5_c;
accum_6[c] = accum_6_c;
accum_7[c] = accum_7_c;
}
// Main central body.
int16 wdelta_c;
int16 wdelta_twice_c;
int16 hwdelta_c;
int16 hwdelta_twice_c;
int16 incr_0_c;
int16 incr_1_c;
int16 incr_2_c;
int16 incr_3_c;
int16 incr_4_c;
int16 incr_5_c;
int16 incr_6_c;
int16 incr_7_c;
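// Interior samples follow the bilinear surface
//   out = in + w * wdelta + h * hdelta + w * h * hwdelta,
// so the horizontal step for output row r is wdelta + (2r + 1) * hwdelta:
// incr_base starts at wdelta + hwdelta for row 0 and grows by
// hwdelta_twice per row, while each incr_r_c caches the doubled
// per-sample step for that row.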
for (int j = 0; j < (input_width - 1); ++j) {
for (int c = 0; c < 8; ++c) {
accum_0_c = accum_0[c];
accum_1_c = accum_1[c];
accum_2_c = accum_2[c];
accum_3_c = accum_3[c];
accum_4_c = accum_4[c];
accum_5_c = accum_5[c];
accum_6_c = accum_6[c];
accum_7_c = accum_7[c];
wdelta_c = static_cast<uint16>(
input_base_ptr[c_block + c + depth * (j + 1)] -
input_base_ptr[c_block + c + depth * j])
<< 4;
wdelta_twice_c = wdelta_c << 1;
hwdelta_c = static_cast<uint16>(
input_base_ptr[c_block + c + depth * (j + 1) +
input_row_stride] -
input_base_ptr[c_block + c + depth * (j + 1)] -
input_base_ptr[c_block + c + depth * j + input_row_stride] +
input_base_ptr[c_block + c + depth * j]);
hwdelta_twice_c = hwdelta_c << 1;
uint16 incr_base = wdelta_c + hwdelta_c;
accum_0_c += incr_base;
output_base_ptr_0[c_block + c + depth * j * 8 + depth * 4] =
accum_0_c >> 8;
incr_0_c = incr_base << 1;
incr_base += hwdelta_twice_c;
accum_1_c += incr_base;
output_base_ptr_1[c_block + c + depth * j * 8 + depth * 4] =
accum_1_c >> 8;
incr_1_c = incr_base << 1;
incr_base += hwdelta_twice_c;
accum_2_c += incr_base;
output_base_ptr_2[c_block + c + depth * j * 8 + depth * 4] =
accum_2_c >> 8;
incr_2_c = incr_base << 1;
incr_base += hwdelta_twice_c;
accum_3_c += incr_base;
output_base_ptr_3[c_block + c + depth * j * 8 + depth * 4] =
accum_3_c >> 8;
incr_3_c = incr_base << 1;
incr_base += hwdelta_twice_c;
accum_4_c += incr_base;
output_base_ptr_4[c_block + c + depth * j * 8 + depth * 4] =
accum_4_c >> 8;
incr_4_c = incr_base << 1;
incr_base += hwdelta_twice_c;
accum_5_c += incr_base;
output_base_ptr_5[c_block + c + depth * j * 8 + depth * 4] =
accum_5_c >> 8;
incr_5_c = incr_base << 1;
incr_base += hwdelta_twice_c;
accum_6_c += incr_base;
output_base_ptr_6[c_block + c + depth * j * 8 + depth * 4] =
accum_6_c >> 8;
incr_6_c = incr_base << 1;
incr_base += hwdelta_twice_c;
accum_7_c += incr_base;
output_base_ptr_7[c_block + c + depth * j * 8 + depth * 4] =
accum_7_c >> 8;
incr_7_c = incr_base << 1;
for (int p = 1; p < 8; ++p) {
accum_0_c += incr_0_c;
output_base_ptr_0[c_block + c + depth * j * 8 + depth * p +
depth * 4] = accum_0_c >> 8;
accum_1_c += incr_1_c;
output_base_ptr_1[c_block + c + depth * j * 8 + depth * p +
depth * 4] = accum_1_c >> 8;
accum_2_c += incr_2_c;
output_base_ptr_2[c_block + c + depth * j * 8 + depth * p +
depth * 4] = accum_2_c >> 8;
accum_3_c += incr_3_c;
output_base_ptr_3[c_block + c + depth * j * 8 + depth * p +
depth * 4] = accum_3_c >> 8;
accum_4_c += incr_4_c;
output_base_ptr_4[c_block + c + depth * j * 8 + depth * p +
depth * 4] = accum_4_c >> 8;
accum_5_c += incr_5_c;
output_base_ptr_5[c_block + c + depth * j * 8 + depth * p +
depth * 4] = accum_5_c >> 8;
accum_6_c += incr_6_c;
output_base_ptr_6[c_block + c + depth * j * 8 + depth * p +
depth * 4] = accum_6_c >> 8;
accum_7_c += incr_7_c;
output_base_ptr_7[c_block + c + depth * j * 8 + depth * p +
depth * 4] = accum_7_c >> 8;
}
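// A final half step lands each accumulator exactly on the value at
// input column j + 1 for its row, ready for the next span.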
accum_0_c += incr_0_c / 2;
accum_1_c += incr_1_c / 2;
accum_2_c += incr_2_c / 2;
accum_3_c += incr_3_c / 2;
accum_4_c += incr_4_c / 2;
accum_5_c += incr_5_c / 2;
accum_6_c += incr_6_c / 2;
accum_7_c += incr_7_c / 2;
accum_0[c] = accum_0_c;
accum_1[c] = accum_1_c;
accum_2[c] = accum_2_c;
accum_3[c] = accum_3_c;
accum_4[c] = accum_4_c;
accum_5[c] = accum_5_c;
accum_6[c] = accum_6_c;
accum_7[c] = accum_7_c;
}
}
// Right margin.
uint8 output_data_0_c;
uint8 output_data_1_c;
uint8 output_data_2_c;
uint8 output_data_3_c;
uint8 output_data_4_c;
uint8 output_data_5_c;
uint8 output_data_6_c;
uint8 output_data_7_c;
for (int c = 0; c < 8; ++c) {
accum_0_c = accum_0[c];
accum_1_c = accum_1[c];
accum_2_c = accum_2[c];
accum_3_c = accum_3[c];
accum_4_c = accum_4[c];
accum_5_c = accum_5[c];
accum_6_c = accum_6[c];
accum_7_c = accum_7[c];
// The accumulators have 0.5 pre-added for rounding; the shift discards it
// here, and reusing the accumulators simply avoids re-loading the input.
output_data_0_c = accum_0_c >> 8;
output_data_1_c = accum_1_c >> 8;
output_data_2_c = accum_2_c >> 8;
output_data_3_c = accum_3_c >> 8;
output_data_4_c = accum_4_c >> 8;
output_data_5_c = accum_5_c >> 8;
output_data_6_c = accum_6_c >> 8;
output_data_7_c = accum_7_c >> 8;
for (int p = 0; p < 4; ++p) {
output_base_ptr_0[c_block + c + depth * (input_width - 1) * 8 +
depth * 4 + depth * p] = output_data_0_c;
output_base_ptr_1[c_block + c + depth * (input_width - 1) * 8 +
depth * 4 + depth * p] = output_data_1_c;
output_base_ptr_2[c_block + c + depth * (input_width - 1) * 8 +
depth * 4 + depth * p] = output_data_2_c;
output_base_ptr_3[c_block + c + depth * (input_width - 1) * 8 +
depth * 4 + depth * p] = output_data_3_c;
output_base_ptr_4[c_block + c + depth * (input_width - 1) * 8 +
depth * 4 + depth * p] = output_data_4_c;
output_base_ptr_5[c_block + c + depth * (input_width - 1) * 8 +
depth * 4 + depth * p] = output_data_5_c;
output_base_ptr_6[c_block + c + depth * (input_width - 1) * 8 +
depth * 4 + depth * p] = output_data_6_c;
output_base_ptr_7[c_block + c + depth * (input_width - 1) * 8 +
depth * 4 + depth * p] = output_data_7_c;
}
accum_0[c] = accum_0_c;
accum_1[c] = accum_1_c;
accum_2[c] = accum_2_c;
accum_3[c] = accum_3_c;
accum_4[c] = accum_4_c;
accum_5[c] = accum_5_c;
accum_6[c] = accum_6_c;
accum_7[c] = accum_7_c;
}
}
output_base_ptr += output_row_stride * 8;
input_base_ptr += input_row_stride;
}
for (int c_block = 0; c_block < depth; c_block += 8) {
uint8 output_data[8];
uint16 accum[8];
// Bottom-left margin corner.
for (int c = 0; c < 8; ++c) {
output_data[c] = input_base_ptr[c_block + c];
output_base_ptr[c_block + c] = output_data[c];
output_base_ptr[c_block + c + depth] = output_data[c];
output_base_ptr[c_block + c + depth * 2] = output_data[c];
output_base_ptr[c_block + c + depth * 3] = output_data[c];
// Accumulate in 8.8 representation, pre-adding 0.5 for later rounding.
accum[c] =
(output_data[c] << 8) + 128; // 128 = 0.5 in 8.8 representation.
}
// Bottom-centre margin.
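// Same 1/16 stepping scheme as the top-centre margin above.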
uint16 wdelta[8];
uint16 wdelta_twice[8];
for (int j = 0; j < (input_width - 1); ++j) {
for (int c = 0; c < 8; ++c) {
wdelta[c] = static_cast<uint16>(
input_base_ptr[c_block + c + depth * (j + 1)] -
input_base_ptr[c_block + c + depth * j])
<< 4;
wdelta_twice[c] = wdelta[c] << 1;
accum[c] += wdelta[c];
output_base_ptr[c_block + c + depth * j * 8 + depth * 4] =
accum[c] >> 8;
for (int p = 1; p < 8; ++p) {
accum[c] += wdelta_twice[c];
output_base_ptr[c_block + c + depth * j * 8 + depth * p +
depth * 4] = accum[c] >> 8;
}
accum[c] += wdelta[c];
}
}
// Bottom-right margin corner.
for (int c = 0; c < 8; ++c) {
// The accumulators have 0.5 pre-added for rounding; the shift discards it
// here, and reusing the accumulators simply avoids re-loading the input.
output_data[c] = accum[c] >> 8;
TFLITE_DCHECK_EQ(
output_data[c],
input_base_ptr[c_block + c + depth * (input_width - 1)]);
output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
depth * 4] = output_data[c];
output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
depth * 4 + depth] = output_data[c];
output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
depth * 4 + depth * 2] = output_data[c];
output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
depth * 4 + depth * 3] = output_data[c];
}
}
// Fill out remainder of bottom margin.
std::memcpy(output_base_ptr + output_row_stride, output_base_ptr,
output_row_stride * sizeof(uint8));
std::memcpy(output_base_ptr + output_row_stride * 2, output_base_ptr,
output_row_stride * sizeof(uint8));
std::memcpy(output_base_ptr + output_row_stride * 3, output_base_ptr,
output_row_stride * sizeof(uint8));
}
} // NOLINT(readability/fn_size)
} // namespace resize_bilinear
#ifdef USE_NEON
inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
@@ -474,7 +940,7 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
const int32* output_size_data,
const RuntimeShape& unextended_output_shape,
uint8* output_data) {
ruy::profiler::ScopeLabel label("ResizeBilinear");
ruy::profiler::ScopeLabel label("ResizeBilinearUint8");
// If half_pixel_centers is true, align_corners must be false.
TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners);
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
@@ -493,6 +959,22 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
int32 output_height = output_size_data[0];
int32 output_width = output_size_data[1];
if (!op_params.align_corners && op_params.half_pixel_centers &&
((depth % 8) == 0)) {
const int32 scale = output_height / input_height;
// Restricting the minimum output dimensions may not be necessary, but
// ensures that kernels can use unrolling with minimal code size.
if ((output_height >= 8) && (output_width >= 8) &&
((input_height * scale) == output_height) &&
((input_width * scale) == output_width)) {
if (scale == 8) {
resize_bilinear::ResizeBilinear888Uint8(
batches, input_height, input_width, depth, input_data, output_data);
return;
}
}
}
float height_scale =
(op_params.align_corners && output_height > 1)
? (static_cast<float>(input_height - 1) / (output_height - 1))
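
For illustration of the dispatch added above: with half_pixel_centers == true, align_corners == false, depth a multiple of 8, and both output dimensions exactly 8x the input, the call below would take the ResizeBilinear888Uint8 fast path. The buffer names input_buf and output_buf are invented for the example; the parameter order follows the ResizeBilinear signature shown above.

// Hypothetical caller exercising the x8 fast path (1x10x10x16 -> 1x80x80x16).
void RunFastPathExample(const uint8_t* input_buf, uint8_t* output_buf) {
  tflite::ResizeBilinearParams params;
  params.align_corners = false;
  params.half_pixel_centers = true;
  const tflite::RuntimeShape input_shape({1, 10, 10, 16});
  const tflite::RuntimeShape output_size_shape({1, 1, 1, 2});
  const int32_t output_size_data[] = {80, 80};  // {height, width}
  const tflite::RuntimeShape output_shape({1, 80, 80, 16});
  tflite::optimized_ops::ResizeBilinear(params, input_shape, input_buf,
                                        output_size_shape, output_size_data,
                                        output_shape, output_buf);
}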

tensorflow/lite/kernels/internal/resize_bilinear_test.cc

@@ -17,6 +17,7 @@ limitations under the License.
#include <algorithm>
#include <cmath>
#include <list>
#include <type_traits>
#include <typeinfo>
#include <vector>
@@ -64,6 +65,13 @@ void TestOneResizeBilinear(const tflite::ResizeBilinearParams& op_params,
op_params, input_dims_inference, input_data.data(), output_size_dims,
output_size_data.data(), output_dims_inference, output_data.data());
bool strict_match = false;
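// For an exact x8 upscale of uint8 data with depth a multiple of 8, the
// optimized kernel should match the reference implementation bit-exactly,
// so any nonzero absolute-difference sum is a failure.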
if (std::is_same<T, uint8>::value && ((depth % 8) == 0) &&
((input_width * 8) == output_width) &&
((input_height * 8) == output_height)) {
strict_match = true;
}
double sum_diff = 0;
float max_abs_val = 0;
for (int i = 0; i < output_buffer_size; i++) {
@@ -73,10 +81,16 @@ void TestOneResizeBilinear(const tflite::ResizeBilinearParams& op_params,
max_abs_val, std::abs(static_cast<float>(reference_output_data[i])));
}
-  if (sum_diff != 0.f) {
-    const float mean_diff = static_cast<float>(sum_diff / output_buffer_size);
-    const float relative_error = std::abs(mean_diff) / max_abs_val;
-    ASSERT_LT(relative_error, error_threshold);
-  }
+  if (strict_match) {
+    if (sum_diff > 0) {
+      ASSERT_EQ(sum_diff, 0);
+    }
+  } else {
+    if (sum_diff != 0.f) {
+      const float mean_diff = static_cast<float>(sum_diff / output_buffer_size);
+      const float relative_error = std::abs(mean_diff) / max_abs_val;
+      ASSERT_LT(relative_error, error_threshold);
+    }
+  }
}