Optimize the custom GEMV paths a little bit.
PiperOrigin-RevId: 248744138
commit 50fd5fb74e
parent ed958af041
@@ -92,6 +92,7 @@ void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
          const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
          const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
          CpuBackendContext* context) {
  gemmlowp::ScopedProfilingLabel label("cpu_backend_gemm::Gemm");
  ValidateParams(lhs_params, rhs_params, dst_params, params);
  if (dst_params.cols == 1) {
    // GEMV case: try a custom fast GEMV path.
@@ -100,6 +101,7 @@ void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
      return;
    }
  }
  gemmlowp::ScopedProfilingLabel label2("cpu_backend_gemm::Gemm: general GEMM");
  GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
           quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
                                     dst_params, dst_data, params, context);
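The fast path above kicks in when the destination has a single column, i.e. the product is a matrix-times-vector (GEMV). For reference, the computation that the custom path specializes is the plain GEMV below; this is an illustrative scalar sketch with a hypothetical name, not the library's code, and it omits zero points and clamping:

// dst[r] = sum over c of lhs[r][c] * rhs[c], with lhs stored row-major.
void NaiveGemvSketch(const float* lhs, int rows, int cols,
                     const float* rhs, float* dst) {
  for (int r = 0; r < rows; ++r) {
    float acc = 0.0f;
    for (int c = 0; c < cols; ++c) {
      acc += lhs[r * cols + c] * rhs[c];
    }
    dst[r] = acc;
  }
}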
@@ -144,6 +144,7 @@ bool CustomGemv(
    const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
    CpuBackendContext* context) {
  gemmlowp::ScopedProfilingLabel label("cpu_backend_gemm::Gemm: CustomGemv");
  using Impl = CustomGemvImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
                              quantization_flavor>;
  if (lhs_params.rows < Impl::kKernelRows) {
@@ -186,8 +187,8 @@ bool CustomGemv(
// Some NEON helper functions used by CustomGemvImpl specializations below,
// allowing for some type genericity in them.

inline int16x8x2_t LoadAndSubtractZeroPoint(const std::uint8_t* src,
                                            std::uint8_t zero_point) {
inline int16x8x2_t Load16AndSubtractZeroPoint(const std::uint8_t* src,
                                              std::uint8_t zero_point) {
  uint8x16_t src_u8 = vld1q_u8(src);
  int16x8_t src_s16_0 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(src_u8)));
  int16x8_t src_s16_1 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(src_u8)));
@@ -198,8 +199,8 @@ inline int16x8x2_t LoadAndSubtractZeroPoint(const std::uint8_t* src,
  return result;
}

inline int16x8x2_t LoadAndSubtractZeroPoint(const std::int8_t* src,
                                            std::int8_t zero_point) {
inline int16x8x2_t Load16AndSubtractZeroPoint(const std::int8_t* src,
                                              std::int8_t zero_point) {
  int8x16_t src_s8 = vld1q_s8(src);
  int16x8_t src_s16_0 = vmovl_s8(vget_low_s8(src_s8));
  int16x8_t src_s16_1 = vmovl_s8(vget_high_s8(src_s8));
@@ -210,6 +211,22 @@ inline int16x8x2_t LoadAndSubtractZeroPoint(const std::int8_t* src,
  return result;
}

inline int16x8_t Load8AndSubtractZeroPoint(const std::uint8_t* src,
                                           std::uint8_t zero_point) {
  uint8x8_t src_u8 = vld1_u8(src);
  int16x8_t src_s16 = vreinterpretq_s16_u16(vmovl_u8(src_u8));
  int16x8_t zero_point_vec = vdupq_n_s16(zero_point);
  return vsubq_s16(src_s16, zero_point_vec);
}

inline int16x8_t Load8AndSubtractZeroPoint(const std::int8_t* src,
                                           std::int8_t zero_point) {
  int8x8_t src_s8 = vld1_s8(src);
  int16x8_t src_s16 = vmovl_s8(src_s8);
  int16x8_t zero_point_vec = vdupq_n_s16(zero_point);
  return vsubq_s16(src_s16, zero_point_vec);
}

inline void ClampAndStore(int32x4_t src, std::uint8_t clamp_min,
                          std::uint8_t clamp_max, std::uint8_t* dst) {
  // Narrow values down to 16 bit signed.
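The new Load8AndSubtractZeroPoint helpers do for 8 values what the renamed Load16AndSubtractZeroPoint helpers do for 16: widen the quantized values to int16 and subtract the zero point. A scalar sketch of the same transform, for illustration only (hypothetical name, not part of the file):

#include <cstdint>

// Scalar equivalent of vmovl_u8 (widen to 16 bit) followed by vsubq_s16.
inline void Load8AndSubtractZeroPointScalarSketch(const std::uint8_t* src,
                                                  std::uint8_t zero_point,
                                                  std::int16_t dst[8]) {
  for (int i = 0; i < 8; ++i) {
    dst[i] = static_cast<std::int16_t>(src[i]) -
             static_cast<std::int16_t>(zero_point);
  }
}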
@@ -288,11 +305,12 @@ struct CustomGemvImpl<LhsScalar, RhsScalar, std::int32_t, DstScalar,
      const MatrixParams<RhsScalar>& rhs_params,
      const MatrixParams<DstScalar>& dst_params,
      const GemmParams<std::int32_t, DstScalar, quantization_flavor>& params) {
    // There are no further requirements on the applicability of this kernel,
    // beyond the left-hand-side matrix having at least kKernelRows rows,
    // and the type requirements implied in this template partial
    // specialization.
    return true;
    // The kernel processes at least 8 LHS columns at once to fill NEON
    // registers. The leftovers-handling code at the end works by loading a
    // partially overlapping final register by walking back by a few (<8) values
    // to avoid running past the row's end. This relies on there being
    // at least 8 LHS columns.
    return lhs_params.cols >= 8;
  }

  static void Run(
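The old predicate returned true unconditionally; the new one also requires at least 8 LHS columns, because the leftovers pass reloads the last 8 columns of each row. A small worked example of that constraint, using a hypothetical helper for illustration only:

#include <cassert>

// in = first column not yet processed when the final leftovers pass starts.
int WalkBackSketch(int cols, int in) { return in + 8 - cols; }

int main() {
  // cols = 20: the 16-wide pass stops at in = 16 and no 8-wide pass fits,
  // so the final pass walks back by 4 and re-reads columns 12..15, whose
  // lanes are zeroed so they do not contribute twice.
  assert(WalkBackSketch(20, 16) == 4);
  // cols < 8 would make the reload at column cols - 8 start before the row,
  // which is exactly what the new `return lhs_params.cols >= 8;` rules out.
  return 0;
}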
@@ -311,6 +329,27 @@ struct CustomGemvImpl<LhsScalar, RhsScalar, std::int32_t, DstScalar,
      // `row`.
      row = std::min(row, row_end - kKernelRows);
      const LhsScalar* filter_ptr = lhs_data + row * lhs_params.cols;

      static constexpr int kCacheLineSize = 64;
      for (int k = 0; k < rhs_params.rows;
           k += kCacheLineSize / sizeof(RhsScalar)) {
        optimized_ops_preload_l1_keep(rhs_data + k);
      }

      // kPreloadAhead is empirically determined.
      // End-to-end latency (ms) on mobilenet_v2_0.35_96_8bit, 1 thread,
      // on Qualcomm S855:
      //
      // kPreloadAhead | big core | little core
      // --------------+----------+------------
      // 64            | 1.26     | 5.45
      // 128           | 1.23     | 5.01
      // 256            | 1.18     | 4.9
      // 512            | 1.18     | 5.45
      // 1024           | 1.18     | 6.5
      // no prefetch    | 1.25     | 8.1
      static constexpr int kPreloadAhead = 256;

      // 4 accumulator registers, one for each row being processed.
      // Each has 4 int32 lanes that correspond to columns modulo 4, and
      // will need to be horizontally reduced at the end.
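Two kinds of prefetch are issued here: the whole RHS vector is preloaded once per cache line with a "keep" hint, and the LHS rows are streamed kPreloadAhead bytes ahead of the current read position in the main loop below. The exact definitions of optimized_ops_preload_l1_keep and optimized_ops_preload_l1_stream are not part of this diff; on GCC/Clang, hints of this shape are typically expressed with __builtin_prefetch, roughly as in this sketch (an assumption, not the macros' actual bodies):

// Sketch only: plausible stand-ins for the preload helpers used above.
// __builtin_prefetch(addr, rw, locality): rw = 0 means prefetch for read;
// locality 3 asks to keep the line in cache, locality 0 marks it streaming.
inline void PreloadL1KeepSketch(const void* ptr) {
  __builtin_prefetch(ptr, /*rw=*/0, /*locality=*/3);
}
inline void PreloadL1StreamSketch(const void* ptr) {
  __builtin_prefetch(ptr, /*rw=*/0, /*locality=*/0);
}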
@@ -322,16 +361,28 @@ struct CustomGemvImpl<LhsScalar, RhsScalar, std::int32_t, DstScalar,
      // As much as possible, handle 16 columns of the left-hand side matrix
      // at a time. This allows for decent NEON implementation.
      for (; in <= lhs_params.cols - 16; in += 16) {
        const LhsScalar* local_filter_ptr = filter_ptr;
        int16x8x2_t input_val =
            LoadAndSubtractZeroPoint(rhs_data + in, rhs_params.zero_point);
        int16x8x2_t filter_val_0 = LoadAndSubtractZeroPoint(
            filter_ptr + 0 * lhs_params.cols, lhs_params.zero_point);
        int16x8x2_t filter_val_1 = LoadAndSubtractZeroPoint(
            filter_ptr + 1 * lhs_params.cols, lhs_params.zero_point);
        int16x8x2_t filter_val_2 = LoadAndSubtractZeroPoint(
            filter_ptr + 2 * lhs_params.cols, lhs_params.zero_point);
        int16x8x2_t filter_val_3 = LoadAndSubtractZeroPoint(
            filter_ptr + 3 * lhs_params.cols, lhs_params.zero_point);
            Load16AndSubtractZeroPoint(rhs_data + in, rhs_params.zero_point);
        int16x8x2_t filter_val_0 =
            Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
        optimized_ops_preload_l1_stream(local_filter_ptr +
                                        kPreloadAhead / sizeof(LhsScalar));
        local_filter_ptr += lhs_params.cols;
        int16x8x2_t filter_val_1 =
            Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
        optimized_ops_preload_l1_stream(local_filter_ptr +
                                        kPreloadAhead / sizeof(LhsScalar));
        local_filter_ptr += lhs_params.cols;
        int16x8x2_t filter_val_2 =
            Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
        optimized_ops_preload_l1_stream(local_filter_ptr +
                                        kPreloadAhead / sizeof(LhsScalar));
        local_filter_ptr += lhs_params.cols;
        int16x8x2_t filter_val_3 =
            Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
        optimized_ops_preload_l1_stream(local_filter_ptr +
                                        kPreloadAhead / sizeof(LhsScalar));
        filter_ptr += 16;
        acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0.val[0]),
                         vget_low_s16(input_val.val[0]));
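Each vmlal_s16 call widens and multiply-accumulates one int16x4 slice of the filter against the matching slice of the input. A scalar statement of the same semantics, for readers not fluent in NEON intrinsics (illustrative only):

#include <cstdint>

// Scalar equivalent of acc = vmlal_s16(acc, a, b) on one 4-lane slice.
inline void VmlalS16Sketch(std::int32_t acc[4], const std::int16_t a[4],
                           const std::int16_t b[4]) {
  for (int i = 0; i < 4; ++i) {
    acc[i] += static_cast<std::int32_t>(a[i]) * static_cast<std::int32_t>(b[i]);
  }
}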
@@ -366,27 +417,109 @@ struct CustomGemvImpl<LhsScalar, RhsScalar, std::int32_t, DstScalar,
        acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3.val[1]),
                         vget_high_s16(input_val.val[1]));
      }
      // Leftovers: fewer than 16 columns remain. Very slow code, could be
      // improved upon if critical in some application.
      // Less than 16 values remain. Try to handle 8 more.
      if (in <= lhs_params.cols - 8) {
        int16x8_t input_val =
            Load8AndSubtractZeroPoint(rhs_data + in, rhs_params.zero_point);
        int16x8_t filter_val_0 = Load8AndSubtractZeroPoint(
            filter_ptr + 0 * lhs_params.cols, lhs_params.zero_point);
        int16x8_t filter_val_1 = Load8AndSubtractZeroPoint(
            filter_ptr + 1 * lhs_params.cols, lhs_params.zero_point);
        int16x8_t filter_val_2 = Load8AndSubtractZeroPoint(
            filter_ptr + 2 * lhs_params.cols, lhs_params.zero_point);
        int16x8_t filter_val_3 = Load8AndSubtractZeroPoint(
            filter_ptr + 3 * lhs_params.cols, lhs_params.zero_point);
        filter_ptr += 8;
        acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0),
                         vget_low_s16(input_val));
        acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1),
                         vget_low_s16(input_val));
        acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2),
                         vget_low_s16(input_val));
        acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3),
                         vget_low_s16(input_val));
        acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
                         vget_high_s16(input_val));
        acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
                         vget_high_s16(input_val));
        acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
                         vget_high_s16(input_val));
        acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
                         vget_high_s16(input_val));
        in += 8;
      }
      // Less than 8 values remain. Handle the remaining values
      // in one more copy of the above code handling 8, where we
      // walk back a few values to be able to load 8 values without
      // overrunning the buffer. This is where we make use of the requirement
      // (see IsSupportedGivenSufficientlyManyRows) that there are at least
      // 8 LHS columns.
      if (in < lhs_params.cols) {
        int32 buf[16];
        vst1q_s32(buf + 0, acc0);
        vst1q_s32(buf + 4, acc1);
        vst1q_s32(buf + 8, acc2);
        vst1q_s32(buf + 12, acc3);
        for (; in < lhs_params.cols; in++) {
          int lane = (in + 16 - lhs_params.cols) % 4;
          const int32 input_val = rhs_data[in] - rhs_params.zero_point;
          for (int k = 0; k < 4; k++) {
            int32 filter_val = lhs_data[in + (row + k) * lhs_params.cols] -
                               lhs_params.zero_point;
            buf[lane + 4 * k] += filter_val * input_val;
          }
        // `back` is how many entries to walk back by.
        // Its value is necessarily between 1 and 7.
        const int back = in + 8 - lhs_params.cols;
        TFLITE_DCHECK_GE(back, 1);
        TFLITE_DCHECK_LE(back, 7);
        // Load 8 values as usual.
        int16x8_t input_val = Load8AndSubtractZeroPoint(
            rhs_data + lhs_params.cols - 8, rhs_params.zero_point);
        const LhsScalar* local_filter_ptr = filter_ptr - back;
        filter_ptr += lhs_params.cols - in;
        int16x8_t filter_val_0 =
            Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
        local_filter_ptr += lhs_params.cols;
        int16x8_t filter_val_1 =
            Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
        local_filter_ptr += lhs_params.cols;
        int16x8_t filter_val_2 =
            Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
        local_filter_ptr += lhs_params.cols;
        int16x8_t filter_val_3 =
            Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
        // Now zero out the `back` first entries of input_val.
        // vsetq_lane_s16 takes a literal index, so we need unrolled code.
        switch (back) {
          case 7:
            input_val = vsetq_lane_s16(0, input_val, 6);
            [[clang::fallthrough]];
          case 6:
            input_val = vsetq_lane_s16(0, input_val, 5);
            [[clang::fallthrough]];
          case 5:
            input_val = vsetq_lane_s16(0, input_val, 4);
            [[clang::fallthrough]];
          case 4:
            input_val = vsetq_lane_s16(0, input_val, 3);
            [[clang::fallthrough]];
          case 3:
            input_val = vsetq_lane_s16(0, input_val, 2);
            [[clang::fallthrough]];
          case 2:
            input_val = vsetq_lane_s16(0, input_val, 1);
            [[clang::fallthrough]];
          default:
            input_val = vsetq_lane_s16(0, input_val, 0);
        }
        acc0 = vld1q_s32(buf + 0);
        acc1 = vld1q_s32(buf + 4);
        acc2 = vld1q_s32(buf + 8);
        acc3 = vld1q_s32(buf + 12);
        // Multiply-accumulate 8 values as usual. The `back` first lanes
        // of filter_val_* are junk, but it doesn't matter since they get
        // multiplied by the zeros that we just wrote in the corresponding
        // lanes of input_val.
        acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0),
                         vget_low_s16(input_val));
        acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1),
                         vget_low_s16(input_val));
        acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2),
                         vget_low_s16(input_val));
        acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3),
                         vget_low_s16(input_val));
        acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
                         vget_high_s16(input_val));
        acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
                         vget_high_s16(input_val));
        acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
                         vget_high_s16(input_val));
        acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
                         vget_high_s16(input_val));
      }

      // Horizontally reduce accumulators
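The replacement leftovers code reloads the last 8 columns of the row and zeroes the `back` lanes that were already accumulated, so the overlapping values contribute nothing. A scalar sketch of that masking idea, with hypothetical names and for illustration only:

#include <cstdint>

// Illustrative scalar version of the walk-back trick: re-read the last 8
// columns of a row of length `cols`, but zero the `back` values that were
// already handled so they add nothing to the accumulators.
inline void MaskedTailSketch(const std::int16_t* input, int cols, int in,
                             std::int16_t masked[8]) {
  const int back = in + 8 - cols;  // number of re-read, already-handled values
  for (int lane = 0; lane < 8; ++lane) {
    masked[lane] = (lane < back) ? 0 : input[cols - 8 + lane];
  }
}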
@@ -484,11 +617,12 @@ struct CustomGemvImpl<float, float, float, float,
      const MatrixParams<float>& rhs_params,
      const MatrixParams<float>& dst_params,
      const GemmParams<float, float>& params) {
    // There are no further requirements on the applicability of this kernel,
    // beyond the left-hand-side matrix having at least kKernelRows rows,
    // and the type requirements implied in this template partial
    // specialization.
    return true;
    // The kernel processes 4 LHS columns at once to fill float32x4 registers.
    // The leftovers-handling code at the end works by loading a partially
    // overlapping final register by walking back by a few (<4) floats
    // to avoid running past the row's end. This relies on there being
    // at least 4 LHS columns.
    return lhs_params.cols >= 4;
  }
  static void Run(const MatrixParams<float>& lhs_params, const float* lhs_data,
                  const MatrixParams<float>& rhs_params, const float* rhs_data,
@@ -505,6 +639,27 @@ struct CustomGemvImpl<float, float, float, float,
      // `row`.
      row = std::min(row, row_end - kKernelRows);
      const float* filter_ptr = lhs_data + row * lhs_params.cols;

      static constexpr int kCacheLineSize = 64;
      for (int k = 0; k < rhs_params.rows;
           k += kCacheLineSize / sizeof(float)) {
        optimized_ops_preload_l1_keep(rhs_data + k);
      }

      // kPreloadAhead is empirically determined.
      // End-to-end latency (ms) on mobilenet_v2_0.35_96_float, 1 thread,
      // on Qualcomm S855:
      //
      // kPreloadAhead | big core | little core
      // --------------+----------+------------
      // 64            | 2.4      | 15.2
      // 128            | 2.15     | 12.9
      // 256            | 2        | 12.9
      // 512            | 2.08     | 13.3
      // 1024           | 2.05     | 14.7
      // no prefetch    | 2.1      | 28
      static constexpr int kPreloadAhead = 256;

      // 4 accumulator registers, one for each row being processed.
      // Each has 4 float32 lanes that correspond to columns modulo 4, and
      // will need to be horizontally reduced at the end.
@@ -517,36 +672,71 @@ struct CustomGemvImpl<float, float, float, float,
      // at a time. This allows for decent NEON implementation.
      for (; in <= lhs_params.cols - 4; in += 4) {
        float32x4_t input_val = vld1q_f32(rhs_data + in);
        float32x4_t filter_val_0 = vld1q_f32(filter_ptr + 0 * lhs_params.cols);
        float32x4_t filter_val_1 = vld1q_f32(filter_ptr + 1 * lhs_params.cols);
        float32x4_t filter_val_2 = vld1q_f32(filter_ptr + 2 * lhs_params.cols);
        float32x4_t filter_val_3 = vld1q_f32(filter_ptr + 3 * lhs_params.cols);
        const float* local_filter_ptr = filter_ptr;
        float32x4_t filter_val_0 = vld1q_f32(local_filter_ptr);
        optimized_ops_preload_l1_stream(local_filter_ptr +
                                        kPreloadAhead / sizeof(float));
        local_filter_ptr += lhs_params.cols;
        float32x4_t filter_val_1 = vld1q_f32(local_filter_ptr);
        optimized_ops_preload_l1_stream(local_filter_ptr +
                                        kPreloadAhead / sizeof(float));
        local_filter_ptr += lhs_params.cols;
        float32x4_t filter_val_2 = vld1q_f32(local_filter_ptr);
        optimized_ops_preload_l1_stream(local_filter_ptr +
                                        kPreloadAhead / sizeof(float));
        local_filter_ptr += lhs_params.cols;
        float32x4_t filter_val_3 = vld1q_f32(local_filter_ptr);
        optimized_ops_preload_l1_stream(local_filter_ptr +
                                        kPreloadAhead / sizeof(float));
        filter_ptr += 4;
        acc0 = mul_add(acc0, filter_val_0, input_val);
        acc1 = mul_add(acc1, filter_val_1, input_val);
        acc2 = mul_add(acc2, filter_val_2, input_val);
        acc3 = mul_add(acc3, filter_val_3, input_val);
      }
      // Leftovers: fewer than 4 columns remain. Very slow code, could be
      // improved upon if critical in some application.
      // Less than 4 values remain. Handle the remaining values
      // in one more copy of the above code handling 4, where we
      // walk back a few values to be able to load 4 values without
      // overrunning the buffer. This is where we make use of the requirement
      // (see IsSupportedGivenSufficientlyManyRows) that there are at least
      // 4 LHS columns.
      if (in < lhs_params.cols) {
        float buf[16];
        vst1q_f32(buf + 0, acc0);
        vst1q_f32(buf + 4, acc1);
        vst1q_f32(buf + 8, acc2);
        vst1q_f32(buf + 12, acc3);
        for (; in < lhs_params.cols; in++) {
          int lane = (in + 4 - lhs_params.cols) % 4;
          const float input_val = rhs_data[in];
          for (int k = 0; k < 4; k++) {
            float filter_val = lhs_data[in + (row + k) * lhs_params.cols];
            buf[lane + 4 * k] += filter_val * input_val;
          }
        // `back` is how many entries to walk back by.
        // Its value is necessarily between 1 and 3.
        const int back = in + 4 - lhs_params.cols;
        TFLITE_DCHECK_GE(back, 1);
        TFLITE_DCHECK_LE(back, 3);
        // Load 4 values as usual.
        float32x4_t input_val = vld1q_f32(rhs_data + lhs_params.cols - 4);
        const float* local_filter_ptr = filter_ptr - back;
        filter_ptr += lhs_params.cols - in;
        float32x4_t filter_val_0 = vld1q_f32(local_filter_ptr);
        local_filter_ptr += lhs_params.cols;
        float32x4_t filter_val_1 = vld1q_f32(local_filter_ptr);
        local_filter_ptr += lhs_params.cols;
        float32x4_t filter_val_2 = vld1q_f32(local_filter_ptr);
        local_filter_ptr += lhs_params.cols;
        float32x4_t filter_val_3 = vld1q_f32(local_filter_ptr);
        // Now zero out the `back` first entries of input_val.
        // vsetq_lane_f32 takes a literal index, so we need unrolled code.
        switch (back) {
          case 3:
            input_val = vsetq_lane_f32(0, input_val, 2);
            [[clang::fallthrough]];
          case 2:
            input_val = vsetq_lane_f32(0, input_val, 1);
            [[clang::fallthrough]];
          default:
            input_val = vsetq_lane_f32(0, input_val, 0);
        }
        acc0 = vld1q_f32(buf + 0);
        acc1 = vld1q_f32(buf + 4);
        acc2 = vld1q_f32(buf + 8);
        acc3 = vld1q_f32(buf + 12);
        // Multiply-accumulate 4 values as usual. The `back` first lanes
        // of filter_val_* are junk, but it doesn't matter since they get
        // multiplied by the zeros that we just wrote in the corresponding
        // lanes of input_val.
        acc0 = mul_add(acc0, filter_val_0, input_val);
        acc1 = mul_add(acc1, filter_val_1, input_val);
        acc2 = mul_add(acc2, filter_val_2, input_val);
        acc3 = mul_add(acc3, filter_val_3, input_val);
      }

      // Horizontally reduce accumulators
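The hunk ends just before the horizontal-reduction step named in the comment: each accumulator holds partial sums for columns modulo 4 and has to be collapsed to one value per output row. A minimal NEON sketch of such a horizontal sum, illustrative only and not the code in this file:

#include <arm_neon.h>

// Sum the four lanes of one float accumulator into a single value.
inline float HorizontalSumSketch(float32x4_t acc) {
  float32x2_t sum2 = vadd_f32(vget_low_f32(acc), vget_high_f32(acc));
  sum2 = vpadd_f32(sum2, sum2);  // pairwise add: both lanes now hold the sum
  return vget_lane_f32(sum2, 0);
}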
@@ -1205,6 +1205,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
                 uint8* output_data, const RuntimeShape& im2col_shape,
                 uint8* im2col_data, CpuBackendContext* cpu_backend_context) {
  gemmlowp::ScopedProfilingLabel label("Conv/8bit");

  const int stride_width = params.stride_width;
  const int stride_height = params.stride_height;
  const int dilation_width_factor = params.dilation_width_factor;