Optimize the custom GEMV paths a little bit.

PiperOrigin-RevId: 248744138
Author: Benoit Jacob (2019-05-17 10:27:35 -07:00)
Committed by: TensorFlower Gardener
Parent: ed958af041
Commit: 50fd5fb74e
3 changed files with 257 additions and 64 deletions

File: tensorflow/lite/kernels/cpu_backend_gemm.h

@@ -92,6 +92,7 @@ void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
           const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
           const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
           CpuBackendContext* context) {
+  gemmlowp::ScopedProfilingLabel label("cpu_backend_gemm::Gemm");
   ValidateParams(lhs_params, rhs_params, dst_params, params);
   if (dst_params.cols == 1) {
     // GEMV case: try a custom fast GEMV path.
@@ -100,6 +101,7 @@ void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
       return;
     }
   }
+  gemmlowp::ScopedProfilingLabel label2("cpu_backend_gemm::Gemm: general GEMM");
   GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
            quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
                                      dst_params, dst_data, params, context);
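For context on the two labels added above: gemmlowp's ScopedProfilingLabel is an RAII marker, so a label attributes profiler samples to whatever runs while it is in scope. Since `label2` is declared after the possible early return, time spent in the custom GEMV path is attributed only to the outer `cpu_backend_gemm::Gemm` label. A minimal sketch of the mechanism, assuming a per-thread label stack (this is not gemmlowp's actual code):

    #include <vector>

    // Sketch of an RAII profiling label in the spirit of
    // gemmlowp::ScopedProfilingLabel: push on construction, pop on
    // destruction, so nested scopes nest labels.
    class ScopedLabel {
     public:
      explicit ScopedLabel(const char* label) { Stack().push_back(label); }
      ~ScopedLabel() { Stack().pop_back(); }

     private:
      static std::vector<const char*>& Stack() {
        static thread_local std::vector<const char*> stack;
        return stack;
      }
    };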

File: tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h

@@ -144,6 +144,7 @@ bool CustomGemv(
     const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
     const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
     CpuBackendContext* context) {
+  gemmlowp::ScopedProfilingLabel label("cpu_backend_gemm::Gemm: CustomGemv");
   using Impl = CustomGemvImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
                               quantization_flavor>;
   if (lhs_params.rows < Impl::kKernelRows) {
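The contract between this function and the `Gemm` entry point above: `CustomGemv` returns false whenever the fast path does not apply (too few rows here, plus the shape and type checks in `Impl::IsSupportedGivenSufficientlyManyRows` below), and the caller then falls through to the general GEMM. A condensed sketch of that dispatch, with bodies elided (names from this patch, not a verbatim excerpt):

    // In Gemm(): the fast path is opportunistic, so a false return is
    // always safe and simply selects the general implementation.
    if (dst_params.cols == 1) {
      // GEMV case: try a custom fast GEMV path.
      if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
                             dst_params, dst_data, params, context)) {
        return;  // handled by the specialized kernel
      }
    }
    // Otherwise fall through to GemmImpl::Run(...).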
@@ -186,8 +187,8 @@ bool CustomGemv(
 
 // Some NEON helper functions used by CustomGemvImpl specializations below,
 // allowing for some type genericity in them.
-inline int16x8x2_t LoadAndSubtractZeroPoint(const std::uint8_t* src,
-                                            std::uint8_t zero_point) {
+inline int16x8x2_t Load16AndSubtractZeroPoint(const std::uint8_t* src,
+                                              std::uint8_t zero_point) {
   uint8x16_t src_u8 = vld1q_u8(src);
   int16x8_t src_s16_0 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(src_u8)));
   int16x8_t src_s16_1 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(src_u8)));
@@ -198,8 +199,8 @@ inline int16x8x2_t LoadAndSubtractZeroPoint(const std::uint8_t* src,
   return result;
 }
 
-inline int16x8x2_t LoadAndSubtractZeroPoint(const std::int8_t* src,
-                                            std::int8_t zero_point) {
+inline int16x8x2_t Load16AndSubtractZeroPoint(const std::int8_t* src,
+                                              std::int8_t zero_point) {
   int8x16_t src_s8 = vld1q_s8(src);
   int16x8_t src_s16_0 = vmovl_s8(vget_low_s8(src_s8));
   int16x8_t src_s16_1 = vmovl_s8(vget_high_s8(src_s8));
@@ -210,6 +211,22 @@ inline int16x8x2_t LoadAndSubtractZeroPoint(const std::int8_t* src,
   return result;
 }
 
+inline int16x8_t Load8AndSubtractZeroPoint(const std::uint8_t* src,
+                                           std::uint8_t zero_point) {
+  uint8x8_t src_u8 = vld1_u8(src);
+  int16x8_t src_s16 = vreinterpretq_s16_u16(vmovl_u8(src_u8));
+  int16x8_t zero_point_vec = vdupq_n_s16(zero_point);
+  return vsubq_s16(src_s16, zero_point_vec);
+}
+
+inline int16x8_t Load8AndSubtractZeroPoint(const std::int8_t* src,
+                                           std::int8_t zero_point) {
+  int8x8_t src_s8 = vld1_s8(src);
+  int16x8_t src_s16 = vmovl_s8(src_s8);
+  int16x8_t zero_point_vec = vdupq_n_s16(zero_point);
+  return vsubq_s16(src_s16, zero_point_vec);
+}
+
 inline void ClampAndStore(int32x4_t src, std::uint8_t clamp_min,
                           std::uint8_t clamp_max, std::uint8_t* dst) {
   // Narrow values down to 16 bit signed.
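A note on why widening to int16 is safe here (the arithmetic is ours, not a claim from the patch): after zero-point subtraction each operand lies in [-255, 255], so a vmlal_s16 product is at most 65,025, and an int32 lane can absorb more than 33,000 such products before overflowing, far deeper than the GEMVs this kernel targets:

    #include <cstdint>

    // Illustrative headroom check for the s16 x s16 -> s32 accumulation.
    static_assert(INT64_C(255) * 255 * 33000 < INT32_MAX,
                  "33k worst-case products fit in an int32 accumulator");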
@@ -288,11 +305,12 @@ struct CustomGemvImpl<LhsScalar, RhsScalar, std::int32_t, DstScalar,
       const MatrixParams<RhsScalar>& rhs_params,
       const MatrixParams<DstScalar>& dst_params,
       const GemmParams<std::int32_t, DstScalar, quantization_flavor>& params) {
-    // There are no further requirements on the applicability of this kernel,
-    // beyond the left-hand-side matrix having at least kKernelRows rows,
-    // and the type requirements implied in this template partial
-    // specialization.
-    return true;
+    // The kernel processes at least 8 LHS columns at once to fill NEON
+    // registers. The leftovers-handling code at the end works by loading a
+    // partially overlapping final register by walking back by a few (<8) values
+    // to avoid running past the row's end. This relies on there being
+    // at least 8 LHS columns.
+    return lhs_params.cols >= 8;
   }
 
   static void Run(
@@ -311,6 +329,27 @@ struct CustomGemvImpl<LhsScalar, RhsScalar, std::int32_t, DstScalar,
       // `row`.
       row = std::min(row, row_end - kKernelRows);
       const LhsScalar* filter_ptr = lhs_data + row * lhs_params.cols;
+      static constexpr int kCacheLineSize = 64;
+      for (int k = 0; k < rhs_params.rows;
+           k += kCacheLineSize / sizeof(RhsScalar)) {
+        optimized_ops_preload_l1_keep(rhs_data + k);
+      }
+      // kPreloadAhead is empirically determined.
+      // End-to-end latency (ms) on mobilenet_v2_0.35_96_8bit, 1 thread,
+      // on Qualcomm S855:
+      //
+      // kPreloadAhead | big core | little core
+      // --------------+----------+------------
+      // 64            | 1.26     | 5.45
+      // 128           | 1.23     | 5.01
+      // 256           | 1.18     | 4.9
+      // 512           | 1.18     | 5.45
+      // 1024          | 1.18     | 6.5
+      // no prefetch   | 1.25     | 8.1
+      static constexpr int kPreloadAhead = 256;
       // 4 accumulator registers, one for each row being processed.
       // Each has 4 int32 lanes that corresponds to columns modulo 4, and
       // will need to be horizontally reduced at the end.
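The two preload macros express different cache-replacement hints: `_keep` for the RHS vector, which is re-read for every block of 4 rows, and `_stream` for filter rows, each of which is read only once per GEMV. A hedged sketch of plausible portable definitions (TFLite's actual ones live in its internal headers and use inline assembly on some ARM targets):

    // Plausible (assumed) definitions built on GCC/Clang's prefetch builtin.
    inline void preload_l1_keep(const void* ptr) {
    #ifdef __GNUC__
      __builtin_prefetch(ptr, /*rw=*/0, /*locality=*/3);  // keep resident
    #else
      (void)ptr;
    #endif
    }
    inline void preload_l1_stream(const void* ptr) {
    #ifdef __GNUC__
      __builtin_prefetch(ptr, /*rw=*/0, /*locality=*/0);  // evict early
    #else
      (void)ptr;
    #endif
    }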
@@ -322,16 +361,28 @@ struct CustomGemvImpl<LhsScalar, RhsScalar, std::int32_t, DstScalar,
       // As much as possible, handle 16 columns of the left-hand side matrix
       // at a time. This allows for decent NEON implementation.
       for (; in <= lhs_params.cols - 16; in += 16) {
+        const LhsScalar* local_filter_ptr = filter_ptr;
         int16x8x2_t input_val =
-            LoadAndSubtractZeroPoint(rhs_data + in, rhs_params.zero_point);
-        int16x8x2_t filter_val_0 = LoadAndSubtractZeroPoint(
-            filter_ptr + 0 * lhs_params.cols, lhs_params.zero_point);
-        int16x8x2_t filter_val_1 = LoadAndSubtractZeroPoint(
-            filter_ptr + 1 * lhs_params.cols, lhs_params.zero_point);
-        int16x8x2_t filter_val_2 = LoadAndSubtractZeroPoint(
-            filter_ptr + 2 * lhs_params.cols, lhs_params.zero_point);
-        int16x8x2_t filter_val_3 = LoadAndSubtractZeroPoint(
-            filter_ptr + 3 * lhs_params.cols, lhs_params.zero_point);
+            Load16AndSubtractZeroPoint(rhs_data + in, rhs_params.zero_point);
+        int16x8x2_t filter_val_0 =
+            Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
+        optimized_ops_preload_l1_stream(local_filter_ptr +
+                                        kPreloadAhead / sizeof(LhsScalar));
+        local_filter_ptr += lhs_params.cols;
+        int16x8x2_t filter_val_1 =
+            Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
+        optimized_ops_preload_l1_stream(local_filter_ptr +
+                                        kPreloadAhead / sizeof(LhsScalar));
+        local_filter_ptr += lhs_params.cols;
+        int16x8x2_t filter_val_2 =
+            Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
+        optimized_ops_preload_l1_stream(local_filter_ptr +
+                                        kPreloadAhead / sizeof(LhsScalar));
+        local_filter_ptr += lhs_params.cols;
+        int16x8x2_t filter_val_3 =
+            Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
+        optimized_ops_preload_l1_stream(local_filter_ptr +
+                                        kPreloadAhead / sizeof(LhsScalar));
         filter_ptr += 16;
         acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0.val[0]),
                          vget_low_s16(input_val.val[0]));
@@ -366,27 +417,109 @@ struct CustomGemvImpl<LhsScalar, RhsScalar, std::int32_t, DstScalar,
         acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3.val[1]),
                          vget_high_s16(input_val.val[1]));
       }
-      // Leftovers: fewer than 16 columns remain. Very slow code, could be
-      // improved upon if critical in some application.
+      // Less than 16 values remain. Try to handle 8 more.
+      if (in <= lhs_params.cols - 8) {
+        int16x8_t input_val =
+            Load8AndSubtractZeroPoint(rhs_data + in, rhs_params.zero_point);
+        int16x8_t filter_val_0 = Load8AndSubtractZeroPoint(
+            filter_ptr + 0 * lhs_params.cols, lhs_params.zero_point);
+        int16x8_t filter_val_1 = Load8AndSubtractZeroPoint(
+            filter_ptr + 1 * lhs_params.cols, lhs_params.zero_point);
+        int16x8_t filter_val_2 = Load8AndSubtractZeroPoint(
+            filter_ptr + 2 * lhs_params.cols, lhs_params.zero_point);
+        int16x8_t filter_val_3 = Load8AndSubtractZeroPoint(
+            filter_ptr + 3 * lhs_params.cols, lhs_params.zero_point);
+        filter_ptr += 8;
+        acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0),
+                         vget_low_s16(input_val));
+        acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1),
+                         vget_low_s16(input_val));
+        acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2),
+                         vget_low_s16(input_val));
+        acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3),
+                         vget_low_s16(input_val));
+        acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
+                         vget_high_s16(input_val));
+        acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
+                         vget_high_s16(input_val));
+        acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
+                         vget_high_s16(input_val));
+        acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
+                         vget_high_s16(input_val));
+        in += 8;
+      }
+      // Less than 8 values remain. Handle the remaining values
+      // in one more copy of the above code handling 8, where we
+      // walk back a few values to be able to load 8 values without
+      // overrunning the buffer. This is where we make use of the requirement
+      // (see IsSupportedGivenSufficientlyManyRows) that there are at least
+      // 8 LHS columns.
       if (in < lhs_params.cols) {
-        int32 buf[16];
-        vst1q_s32(buf + 0, acc0);
-        vst1q_s32(buf + 4, acc1);
-        vst1q_s32(buf + 8, acc2);
-        vst1q_s32(buf + 12, acc3);
-        for (; in < lhs_params.cols; in++) {
-          int lane = (in + 16 - lhs_params.cols) % 4;
-          const int32 input_val = rhs_data[in] - rhs_params.zero_point;
-          for (int k = 0; k < 4; k++) {
-            int32 filter_val = lhs_data[in + (row + k) * lhs_params.cols] -
-                               lhs_params.zero_point;
-            buf[lane + 4 * k] += filter_val * input_val;
-          }
-        }
-        acc0 = vld1q_s32(buf + 0);
-        acc1 = vld1q_s32(buf + 4);
-        acc2 = vld1q_s32(buf + 8);
-        acc3 = vld1q_s32(buf + 12);
+        // `back` is how many entries to walk back by.
+        // Its value is necessarily between 1 and 7.
+        const int back = in + 8 - lhs_params.cols;
+        TFLITE_DCHECK_GE(back, 1);
+        TFLITE_DCHECK_LE(back, 7);
+        // Load 8 values as usual.
+        int16x8_t input_val = Load8AndSubtractZeroPoint(
+            rhs_data + lhs_params.cols - 8, rhs_params.zero_point);
+        const LhsScalar* local_filter_ptr = filter_ptr - back;
+        filter_ptr += lhs_params.cols - in;
+        int16x8_t filter_val_0 =
+            Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
+        local_filter_ptr += lhs_params.cols;
+        int16x8_t filter_val_1 =
+            Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
+        local_filter_ptr += lhs_params.cols;
+        int16x8_t filter_val_2 =
+            Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
+        local_filter_ptr += lhs_params.cols;
+        int16x8_t filter_val_3 =
+            Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
+        // Now zero out the `back` first entries of input_val.
+        // vsetq_lane_s16 takes a literal index, so we need unrolled code.
+        switch (back) {
+          case 7:
+            input_val = vsetq_lane_s16(0, input_val, 6);
+            [[clang::fallthrough]];
+          case 6:
+            input_val = vsetq_lane_s16(0, input_val, 5);
+            [[clang::fallthrough]];
+          case 5:
+            input_val = vsetq_lane_s16(0, input_val, 4);
+            [[clang::fallthrough]];
+          case 4:
+            input_val = vsetq_lane_s16(0, input_val, 3);
+            [[clang::fallthrough]];
+          case 3:
+            input_val = vsetq_lane_s16(0, input_val, 2);
+            [[clang::fallthrough]];
+          case 2:
+            input_val = vsetq_lane_s16(0, input_val, 1);
+            [[clang::fallthrough]];
+          default:
+            input_val = vsetq_lane_s16(0, input_val, 0);
+        }
+        // Multiply-accumulate 8 values as usual. The `back` first lanes
+        // of filter_val_* are junk, but it doesn't matter since they get
+        // multiplied by the zeros that we just wrote in the corresponding
+        // lanes of input_val.
+        acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0),
+                         vget_low_s16(input_val));
+        acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1),
+                         vget_low_s16(input_val));
+        acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2),
+                         vget_low_s16(input_val));
+        acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3),
+                         vget_low_s16(input_val));
+        acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
+                         vget_high_s16(input_val));
+        acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
+                         vget_high_s16(input_val));
+        acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
+                         vget_high_s16(input_val));
+        acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
+                         vget_high_s16(input_val));
       }
       // Horizontally reduce accumulators
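The walk-back pattern above is independent of NEON. A self-contained scalar illustration of the same idea, with hypothetical names (not code from this patch): process a row in chunks of 8, then reload the last 8 entries with overlap and zero the lanes already counted, so no scalar epilogue loop is needed:

    #include <algorithm>
    #include <cstdint>

    // Requires n >= 8, mirroring the kernel's lhs_params.cols >= 8 check.
    int64_t SumInChunksOf8(const int32_t* data, int n) {
      int64_t acc = 0;
      int i = 0;
      for (; i <= n - 8; i += 8) {
        for (int k = 0; k < 8; ++k) acc += data[i + k];
      }
      if (i < n) {
        // Walk back so the final chunk of 8 ends exactly at data[n - 1].
        const int back = i + 8 - n;  // necessarily in [1, 7]
        int32_t chunk[8];
        std::copy(data + n - 8, data + n, chunk);
        // Zero the `back` first entries: earlier chunks already counted them.
        for (int k = 0; k < back; ++k) chunk[k] = 0;
        for (int k = 0; k < 8; ++k) acc += chunk[k];
      }
      return acc;
    }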
@@ -484,11 +617,12 @@ struct CustomGemvImpl<float, float, float, float,
       const MatrixParams<float>& rhs_params,
       const MatrixParams<float>& dst_params,
       const GemmParams<float, float>& params) {
-    // There are no further requirements on the applicability of this kernel,
-    // beyond the left-hand-side matrix having at least kKernelRows rows,
-    // and the type requirements implied in this template partial
-    // specialization.
-    return true;
+    // The kernel processes 4 LHS columns at once to fill float32x4 registers.
+    // The leftovers-handling code at the end works by loading a partially
+    // overlapping final register by walking back by a few (<4) floats
+    // to avoid running past the row's end. This relies on there being
+    // at least 4 LHS columns.
+    return lhs_params.cols >= 4;
   }
 
   static void Run(const MatrixParams<float>& lhs_params, const float* lhs_data,
                   const MatrixParams<float>& rhs_params, const float* rhs_data,
@@ -505,6 +639,27 @@ struct CustomGemvImpl<float, float, float, float,
       // `row`.
       row = std::min(row, row_end - kKernelRows);
      const float* filter_ptr = lhs_data + row * lhs_params.cols;
+      static constexpr int kCacheLineSize = 64;
+      for (int k = 0; k < rhs_params.rows;
+           k += kCacheLineSize / sizeof(float)) {
+        optimized_ops_preload_l1_keep(rhs_data + k);
+      }
+      // kPreloadAhead is empirically determined.
+      // End-to-end latency (ms) on mobilenet_v2_0.35_96_float, 1 thread,
+      // on Qualcomm S855:
+      //
+      // kPreloadAhead | big core | little core
+      // --------------+----------+------------
+      // 64            | 2.4      | 15.2
+      // 128           | 2.15     | 12.9
+      // 256            | 2        | 12.9
+      // 512           | 2.08     | 13.3
+      // 1024          | 2.05     | 14.7
+      // no prefetch   | 2.1      | 28
+      static constexpr int kPreloadAhead = 256;
       // 4 accumulator registers, one for each row being processed.
       // Each has 4 float32 lanes that corresponds to columns modulo 4, and
       // will need to be horizontally reduced at the end.
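The RHS preload loop requests one cache line per iteration: with 64-byte lines and 4-byte floats the stride is 16 elements (64 for the 8-bit kernel above). The same loop shape works for any element type, as in this sketch (preload_l1_keep is the assumed wrapper sketched earlier):

    // Sketch: touch each cache line of an array once before the hot loops.
    template <typename T>
    void PreloadArray(const T* data, int size) {
      constexpr int kCacheLineSize = 64;                 // bytes, typical ARM
      constexpr int kStep = kCacheLineSize / sizeof(T);  // 16 for float
      for (int k = 0; k < size; k += kStep) {
        preload_l1_keep(data + k);
      }
    }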
@@ -517,36 +672,71 @@ struct CustomGemvImpl<float, float, float, float,
       // at a time. This allows for decent NEON implementation.
       for (; in <= lhs_params.cols - 4; in += 4) {
         float32x4_t input_val = vld1q_f32(rhs_data + in);
-        float32x4_t filter_val_0 = vld1q_f32(filter_ptr + 0 * lhs_params.cols);
-        float32x4_t filter_val_1 = vld1q_f32(filter_ptr + 1 * lhs_params.cols);
-        float32x4_t filter_val_2 = vld1q_f32(filter_ptr + 2 * lhs_params.cols);
-        float32x4_t filter_val_3 = vld1q_f32(filter_ptr + 3 * lhs_params.cols);
+        const float* local_filter_ptr = filter_ptr;
+        float32x4_t filter_val_0 = vld1q_f32(local_filter_ptr);
+        optimized_ops_preload_l1_stream(local_filter_ptr +
+                                        kPreloadAhead / sizeof(float));
+        local_filter_ptr += lhs_params.cols;
+        float32x4_t filter_val_1 = vld1q_f32(local_filter_ptr);
+        optimized_ops_preload_l1_stream(local_filter_ptr +
+                                        kPreloadAhead / sizeof(float));
+        local_filter_ptr += lhs_params.cols;
+        float32x4_t filter_val_2 = vld1q_f32(local_filter_ptr);
+        optimized_ops_preload_l1_stream(local_filter_ptr +
+                                        kPreloadAhead / sizeof(float));
+        local_filter_ptr += lhs_params.cols;
+        float32x4_t filter_val_3 = vld1q_f32(local_filter_ptr);
+        optimized_ops_preload_l1_stream(local_filter_ptr +
+                                        kPreloadAhead / sizeof(float));
         filter_ptr += 4;
         acc0 = mul_add(acc0, filter_val_0, input_val);
         acc1 = mul_add(acc1, filter_val_1, input_val);
         acc2 = mul_add(acc2, filter_val_2, input_val);
         acc3 = mul_add(acc3, filter_val_3, input_val);
       }
-      // Leftovers: fewer than 4 columns remain. Very slow code, could be
-      // improved upon if critical in some application.
+      // Less than 4 values remain. Handle the remaining values
+      // in one more copy of the above code handling 4, where we
+      // walk back a few values to be able to load 4 values without
+      // overrunning the buffer. This is where we make use of the requirement
+      // (see IsSupportedGivenSufficientlyManyRows) that there are at least
+      // 4 LHS columns.
       if (in < lhs_params.cols) {
-        float buf[16];
-        vst1q_f32(buf + 0, acc0);
-        vst1q_f32(buf + 4, acc1);
-        vst1q_f32(buf + 8, acc2);
-        vst1q_f32(buf + 12, acc3);
-        for (; in < lhs_params.cols; in++) {
-          int lane = (in + 4 - lhs_params.cols) % 4;
-          const float input_val = rhs_data[in];
-          for (int k = 0; k < 4; k++) {
-            float filter_val = lhs_data[in + (row + k) * lhs_params.cols];
-            buf[lane + 4 * k] += filter_val * input_val;
-          }
-        }
-        acc0 = vld1q_f32(buf + 0);
-        acc1 = vld1q_f32(buf + 4);
-        acc2 = vld1q_f32(buf + 8);
-        acc3 = vld1q_f32(buf + 12);
+        // `back` is how many entries to walk back by.
+        // Its value is necessarily between 1 and 3.
+        const int back = in + 4 - lhs_params.cols;
+        TFLITE_DCHECK_GE(back, 1);
+        TFLITE_DCHECK_LE(back, 3);
+        // Load 4 values as usual.
+        float32x4_t input_val = vld1q_f32(rhs_data + lhs_params.cols - 4);
+        const float* local_filter_ptr = filter_ptr - back;
+        filter_ptr += lhs_params.cols - in;
+        float32x4_t filter_val_0 = vld1q_f32(local_filter_ptr);
+        local_filter_ptr += lhs_params.cols;
+        float32x4_t filter_val_1 = vld1q_f32(local_filter_ptr);
+        local_filter_ptr += lhs_params.cols;
+        float32x4_t filter_val_2 = vld1q_f32(local_filter_ptr);
+        local_filter_ptr += lhs_params.cols;
+        float32x4_t filter_val_3 = vld1q_f32(local_filter_ptr);
+        // Now zero out the `back` first entries of input_val.
+        // vsetq_lane_f32 takes a literal index, so we need unrolled code.
+        switch (back) {
+          case 3:
+            input_val = vsetq_lane_f32(0, input_val, 2);
+            [[clang::fallthrough]];
+          case 2:
+            input_val = vsetq_lane_f32(0, input_val, 1);
+            [[clang::fallthrough]];
+          default:
+            input_val = vsetq_lane_f32(0, input_val, 0);
+        }
+        // Multiply-accumulate 4 values as usual. The `back` first lanes
+        // of filter_val_* are junk, but it doesn't matter since they get
+        // multiplied by the zeros that we just wrote in the corresponding
+        // lanes of input_val.
+        acc0 = mul_add(acc0, filter_val_0, input_val);
+        acc1 = mul_add(acc1, filter_val_1, input_val);
+        acc2 = mul_add(acc2, filter_val_2, input_val);
+        acc3 = mul_add(acc3, filter_val_3, input_val);
       }
       // Horizontally reduce accumulators
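`mul_add`, used throughout the float kernel, abstracts over fused and non-fused NEON multiply-accumulate. A hedged sketch of how such a helper is typically written; the exact guard TFLite uses may differ:

    #include <arm_neon.h>

    // Assumed shape of a mul_add helper: fused multiply-add where the
    // target supports it, plain multiply-accumulate otherwise.
    inline float32x4_t MulAdd(float32x4_t acc, float32x4_t a, float32x4_t b) {
    #ifdef __ARM_FEATURE_FMA
      return vfmaq_f32(acc, a, b);  // acc + a * b with a single rounding
    #else
      return vmlaq_f32(acc, a, b);  // multiply, then add
    #endif
    }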

File: tensorflow/lite/kernels/internal/optimized/optimized_ops.h

@@ -1205,6 +1205,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
                  uint8* output_data, const RuntimeShape& im2col_shape,
                  uint8* im2col_data, CpuBackendContext* cpu_backend_context) {
   gemmlowp::ScopedProfilingLabel label("Conv/8bit");
   const int stride_width = params.stride_width;
   const int stride_height = params.stride_height;
   const int dilation_width_factor = params.dilation_width_factor;