Optimize quantized mul.
PiperOrigin-RevId: 312077803 Change-Id: Ib6bbf261834a828590748e2c39ad146bad7d80ae
This commit is contained in:
parent
f40a063d84
commit
50fcac47a2
@ -38,49 +38,81 @@ inline void MulElementwise(int size, const ArithmeticParams& params,
|
||||
TFLITE_DCHECK_GT(params.output_offset, -256);
|
||||
TFLITE_DCHECK_LT(params.output_offset, 256);
|
||||
#ifdef USE_NEON
|
||||
const auto input1_offset_vector = vdupq_n_s16(params.input1_offset);
|
||||
const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
|
||||
const auto output_offset_vector = vdupq_n_s16(params.output_offset);
|
||||
const int16x8_t input1_offset_vector = vdupq_n_s16(params.input1_offset);
|
||||
const int16x8_t input2_offset_vector = vdupq_n_s16(params.input2_offset);
|
||||
const int16x8_t output_offset_vector = vdupq_n_s16(params.output_offset);
|
||||
const auto output_activation_min_vector =
|
||||
vdup_n_s8(params.quantized_activation_min);
|
||||
vdupq_n_s8(params.quantized_activation_min);
|
||||
const auto output_activation_max_vector =
|
||||
vdup_n_s8(params.quantized_activation_max);
|
||||
vdupq_n_s8(params.quantized_activation_max);
|
||||
const int left_shift = std::max(0, params.output_shift);
|
||||
const int right_shift = std::max(0, -params.output_shift);
|
||||
const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
|
||||
for (; i <= size - 8; i += 8) {
|
||||
// We load / store 8 at a time, multiplying as two sets of 4 int32s.
|
||||
const auto input1_val_original = vld1_s8(input1_data + i);
|
||||
const auto input2_val_original = vld1_s8(input2_data + i);
|
||||
const auto input1_val_s16 = vmovl_s8(input1_val_original);
|
||||
const auto input2_val_s16 = vmovl_s8(input2_val_original);
|
||||
const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector);
|
||||
const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
|
||||
for (; i <= size - 16; i += 16) {
|
||||
// We load / store 16 at a time, multiplying as four sets of 4 int32s.
|
||||
const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
|
||||
const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
|
||||
|
||||
const auto input1_val_low = vget_low_s16(input1_val);
|
||||
const auto input1_val_high = vget_high_s16(input1_val);
|
||||
const auto input2_val_low = vget_low_s16(input2_val);
|
||||
const auto input2_val_high = vget_high_s16(input2_val);
|
||||
const int16x8_t input1_val_s16_high =
|
||||
vmovl_s8(vget_high_s8(input1_val_original));
|
||||
const int16x8_t input1_val_s16_low =
|
||||
vmovl_s8(vget_low_s8(input1_val_original));
|
||||
|
||||
auto p1 = vmull_s16(input2_val_low, input1_val_low);
|
||||
auto p2 = vmull_s16(input2_val_high, input1_val_high);
|
||||
const int16x8_t input2_val_s16_high =
|
||||
vmovl_s8(vget_high_s8(input2_val_original));
|
||||
const int16x8_t input2_val_s16_low =
|
||||
vmovl_s8(vget_low_s8(input2_val_original));
|
||||
const int16x8_t input1_val_high =
|
||||
vaddq_s16(input1_val_s16_high, input1_offset_vector);
|
||||
const int16x8_t input2_val_high =
|
||||
vaddq_s16(input2_val_s16_high, input2_offset_vector);
|
||||
const int16x8_t input1_val_low =
|
||||
vaddq_s16(input1_val_s16_low, input1_offset_vector);
|
||||
const int16x8_t input2_val_low =
|
||||
vaddq_s16(input2_val_s16_low, input2_offset_vector);
|
||||
const int16x4_t input1_val_high_high = vget_high_s16(input1_val_high);
|
||||
const int16x4_t input1_val_high_low = vget_low_s16(input1_val_high);
|
||||
const int16x4_t input1_val_low_high = vget_high_s16(input1_val_low);
|
||||
const int16x4_t input1_val_low_low = vget_low_s16(input1_val_low);
|
||||
const int16x4_t input2_val_high_high = vget_high_s16(input2_val_high);
|
||||
const int16x4_t input2_val_high_low = vget_low_s16(input2_val_high);
|
||||
const int16x4_t input2_val_low_high = vget_high_s16(input2_val_low);
|
||||
const int16x4_t input2_val_low_low = vget_low_s16(input2_val_low);
|
||||
|
||||
auto p1 = vmull_s16(input2_val_high_high, input1_val_high_high);
|
||||
auto p2 = vmull_s16(input2_val_high_low, input1_val_high_low);
|
||||
auto p3 = vmull_s16(input2_val_low_high, input1_val_low_high);
|
||||
auto p4 = vmull_s16(input2_val_low_low, input1_val_low_low);
|
||||
|
||||
p1 = vshlq_s32(p1, left_shift_vec);
|
||||
p2 = vshlq_s32(p2, left_shift_vec);
|
||||
p3 = vshlq_s32(p3, left_shift_vec);
|
||||
p4 = vshlq_s32(p4, left_shift_vec);
|
||||
|
||||
p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
|
||||
p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
|
||||
p3 = vqrdmulhq_n_s32(p3, params.output_multiplier);
|
||||
p4 = vqrdmulhq_n_s32(p4, params.output_multiplier);
|
||||
using gemmlowp::RoundingDivideByPOT;
|
||||
p1 = RoundingDivideByPOT(p1, right_shift);
|
||||
p2 = RoundingDivideByPOT(p2, right_shift);
|
||||
p3 = RoundingDivideByPOT(p3, right_shift);
|
||||
p4 = RoundingDivideByPOT(p4, right_shift);
|
||||
|
||||
const auto p1_narrowed = vqmovn_s32(p1);
|
||||
const auto p2_narrowed = vqmovn_s32(p2);
|
||||
const auto p =
|
||||
vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
|
||||
const auto clamped =
|
||||
vmax_s8(output_activation_min_vector,
|
||||
vmin_s8(output_activation_max_vector, vqmovn_s16(p)));
|
||||
vst1_s8(output_data + i, clamped);
|
||||
const auto p3_narrowed = vqmovn_s32(p3);
|
||||
const auto p4_narrowed = vqmovn_s32(p4);
|
||||
|
||||
const int16x8_t p_part1 =
|
||||
vaddq_s16(vcombine_s16(p2_narrowed, p1_narrowed), output_offset_vector);
|
||||
const int16x8_t p_part2 =
|
||||
vaddq_s16(vcombine_s16(p4_narrowed, p3_narrowed), output_offset_vector);
|
||||
const int8x16_t p = vcombine_s8(vqmovn_s16(p_part2), vqmovn_s16(p_part1));
|
||||
|
||||
const auto clamped = vmaxq_s8(output_activation_min_vector,
|
||||
vminq_s8(output_activation_max_vector, p));
|
||||
vst1q_s8(output_data + i, clamped);
|
||||
}
|
||||
#endif // NEON
|
||||
|
||||
@ -117,40 +149,63 @@ inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
|
||||
const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
|
||||
const auto output_offset_vector = vdupq_n_s16(params.output_offset);
|
||||
const auto output_activation_min_vector =
|
||||
vdup_n_s8(params.quantized_activation_min);
|
||||
vdupq_n_s8(params.quantized_activation_min);
|
||||
const auto output_activation_max_vector =
|
||||
vdup_n_s8(params.quantized_activation_max);
|
||||
vdupq_n_s8(params.quantized_activation_max);
|
||||
const int left_shift = std::max(0, params.output_shift);
|
||||
const int right_shift = std::max(0, -params.output_shift);
|
||||
const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
|
||||
for (; i <= size - 8; i += 8) {
|
||||
// We load / store 8 at a time, multiplying as two sets of 4 int32s.
|
||||
const auto input2_val_original = vld1_s8(input2_data + i);
|
||||
const auto input2_val_s16 = vmovl_s8(input2_val_original);
|
||||
const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
|
||||
for (; i <= size - 16; i += 16) {
|
||||
// We load / store 16 at a time, multiplying as four sets of 4 int32s.
|
||||
const auto input2_val_original = vld1q_s8(input2_data + i);
|
||||
const auto input2_val_s16_high =
|
||||
vmovl_s8(vget_high_s8(input2_val_original));
|
||||
const auto input2_val_s16_low = vmovl_s8(vget_low_s8(input2_val_original));
|
||||
|
||||
const auto input2_val_low = vget_low_s16(input2_val);
|
||||
const auto input2_val_high = vget_high_s16(input2_val);
|
||||
const auto input2_val_high =
|
||||
vaddq_s16(input2_val_s16_high, input2_offset_vector);
|
||||
const auto input2_val_low =
|
||||
vaddq_s16(input2_val_s16_low, input2_offset_vector);
|
||||
|
||||
auto p1 = vmull_n_s16(input2_val_low, input1_val);
|
||||
auto p2 = vmull_n_s16(input2_val_high, input1_val);
|
||||
const auto input2_val_low_low = vget_low_s16(input2_val_low);
|
||||
const auto input2_val_low_high = vget_high_s16(input2_val_low);
|
||||
const auto input2_val_high_low = vget_low_s16(input2_val_high);
|
||||
const auto input2_val_high_high = vget_high_s16(input2_val_high);
|
||||
|
||||
auto p1 = vmull_n_s16(input2_val_high_high, input1_val);
|
||||
auto p2 = vmull_n_s16(input2_val_high_low, input1_val);
|
||||
auto p3 = vmull_n_s16(input2_val_low_high, input1_val);
|
||||
auto p4 = vmull_n_s16(input2_val_low_low, input1_val);
|
||||
|
||||
p1 = vshlq_s32(p1, left_shift_vec);
|
||||
p2 = vshlq_s32(p2, left_shift_vec);
|
||||
p3 = vshlq_s32(p3, left_shift_vec);
|
||||
p4 = vshlq_s32(p4, left_shift_vec);
|
||||
|
||||
p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
|
||||
p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
|
||||
p3 = vqrdmulhq_n_s32(p3, params.output_multiplier);
|
||||
p4 = vqrdmulhq_n_s32(p4, params.output_multiplier);
|
||||
using gemmlowp::RoundingDivideByPOT;
|
||||
p1 = RoundingDivideByPOT(p1, right_shift);
|
||||
p2 = RoundingDivideByPOT(p2, right_shift);
|
||||
p3 = RoundingDivideByPOT(p3, right_shift);
|
||||
p4 = RoundingDivideByPOT(p4, right_shift);
|
||||
|
||||
const auto p1_narrowed = vqmovn_s32(p1);
|
||||
const auto p2_narrowed = vqmovn_s32(p2);
|
||||
const auto p =
|
||||
vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
|
||||
const auto clamped =
|
||||
vmax_s8(output_activation_min_vector,
|
||||
vmin_s8(output_activation_max_vector, vqmovn_s16(p)));
|
||||
vst1_s8(output_data + i, clamped);
|
||||
const auto p3_narrowed = vqmovn_s32(p3);
|
||||
const auto p4_narrowed = vqmovn_s32(p4);
|
||||
|
||||
const int16x8_t p_part1 =
|
||||
vaddq_s16(vcombine_s16(p2_narrowed, p1_narrowed), output_offset_vector);
|
||||
const int16x8_t p_part2 =
|
||||
vaddq_s16(vcombine_s16(p4_narrowed, p3_narrowed), output_offset_vector);
|
||||
const int8x16_t p = vcombine_s8(vqmovn_s16(p_part2), vqmovn_s16(p_part1));
|
||||
|
||||
const auto clamped = vmaxq_s8(output_activation_min_vector,
|
||||
vminq_s8(output_activation_max_vector, p));
|
||||
vst1q_s8(output_data + i, clamped);
|
||||
}
|
||||
#endif // NEON
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user