From 50fcac47a2652459a7f9b71255cfa1cf0077447b Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Mon, 18 May 2020 07:49:05 -0700 Subject: [PATCH] Optimize quantized mul. PiperOrigin-RevId: 312077803 Change-Id: Ib6bbf261834a828590748e2c39ad146bad7d80ae --- .../internal/optimized/integer_ops/mul.h | 139 ++++++++++++------ 1 file changed, 97 insertions(+), 42 deletions(-) diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h index 18aeef4c8b5..0d385ec1656 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h @@ -38,49 +38,81 @@ inline void MulElementwise(int size, const ArithmeticParams& params, TFLITE_DCHECK_GT(params.output_offset, -256); TFLITE_DCHECK_LT(params.output_offset, 256); #ifdef USE_NEON - const auto input1_offset_vector = vdupq_n_s16(params.input1_offset); - const auto input2_offset_vector = vdupq_n_s16(params.input2_offset); - const auto output_offset_vector = vdupq_n_s16(params.output_offset); + const int16x8_t input1_offset_vector = vdupq_n_s16(params.input1_offset); + const int16x8_t input2_offset_vector = vdupq_n_s16(params.input2_offset); + const int16x8_t output_offset_vector = vdupq_n_s16(params.output_offset); const auto output_activation_min_vector = - vdup_n_s8(params.quantized_activation_min); + vdupq_n_s8(params.quantized_activation_min); const auto output_activation_max_vector = - vdup_n_s8(params.quantized_activation_max); + vdupq_n_s8(params.quantized_activation_max); const int left_shift = std::max(0, params.output_shift); const int right_shift = std::max(0, -params.output_shift); const int32x4_t left_shift_vec = vdupq_n_s32(left_shift); - for (; i <= size - 8; i += 8) { - // We load / store 8 at a time, multiplying as two sets of 4 int32s. - const auto input1_val_original = vld1_s8(input1_data + i); - const auto input2_val_original = vld1_s8(input2_data + i); - const auto input1_val_s16 = vmovl_s8(input1_val_original); - const auto input2_val_s16 = vmovl_s8(input2_val_original); - const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector); - const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector); + for (; i <= size - 16; i += 16) { + // We load / store 16 at a time, multiplying as four sets of 4 int32s. + const int8x16_t input1_val_original = vld1q_s8(input1_data + i); + const int8x16_t input2_val_original = vld1q_s8(input2_data + i); - const auto input1_val_low = vget_low_s16(input1_val); - const auto input1_val_high = vget_high_s16(input1_val); - const auto input2_val_low = vget_low_s16(input2_val); - const auto input2_val_high = vget_high_s16(input2_val); + const int16x8_t input1_val_s16_high = + vmovl_s8(vget_high_s8(input1_val_original)); + const int16x8_t input1_val_s16_low = + vmovl_s8(vget_low_s8(input1_val_original)); - auto p1 = vmull_s16(input2_val_low, input1_val_low); - auto p2 = vmull_s16(input2_val_high, input1_val_high); + const int16x8_t input2_val_s16_high = + vmovl_s8(vget_high_s8(input2_val_original)); + const int16x8_t input2_val_s16_low = + vmovl_s8(vget_low_s8(input2_val_original)); + const int16x8_t input1_val_high = + vaddq_s16(input1_val_s16_high, input1_offset_vector); + const int16x8_t input2_val_high = + vaddq_s16(input2_val_s16_high, input2_offset_vector); + const int16x8_t input1_val_low = + vaddq_s16(input1_val_s16_low, input1_offset_vector); + const int16x8_t input2_val_low = + vaddq_s16(input2_val_s16_low, input2_offset_vector); + const int16x4_t input1_val_high_high = vget_high_s16(input1_val_high); + const int16x4_t input1_val_high_low = vget_low_s16(input1_val_high); + const int16x4_t input1_val_low_high = vget_high_s16(input1_val_low); + const int16x4_t input1_val_low_low = vget_low_s16(input1_val_low); + const int16x4_t input2_val_high_high = vget_high_s16(input2_val_high); + const int16x4_t input2_val_high_low = vget_low_s16(input2_val_high); + const int16x4_t input2_val_low_high = vget_high_s16(input2_val_low); + const int16x4_t input2_val_low_low = vget_low_s16(input2_val_low); + + auto p1 = vmull_s16(input2_val_high_high, input1_val_high_high); + auto p2 = vmull_s16(input2_val_high_low, input1_val_high_low); + auto p3 = vmull_s16(input2_val_low_high, input1_val_low_high); + auto p4 = vmull_s16(input2_val_low_low, input1_val_low_low); p1 = vshlq_s32(p1, left_shift_vec); p2 = vshlq_s32(p2, left_shift_vec); + p3 = vshlq_s32(p3, left_shift_vec); + p4 = vshlq_s32(p4, left_shift_vec); + p1 = vqrdmulhq_n_s32(p1, params.output_multiplier); p2 = vqrdmulhq_n_s32(p2, params.output_multiplier); + p3 = vqrdmulhq_n_s32(p3, params.output_multiplier); + p4 = vqrdmulhq_n_s32(p4, params.output_multiplier); using gemmlowp::RoundingDivideByPOT; p1 = RoundingDivideByPOT(p1, right_shift); p2 = RoundingDivideByPOT(p2, right_shift); + p3 = RoundingDivideByPOT(p3, right_shift); + p4 = RoundingDivideByPOT(p4, right_shift); const auto p1_narrowed = vqmovn_s32(p1); const auto p2_narrowed = vqmovn_s32(p2); - const auto p = - vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector); - const auto clamped = - vmax_s8(output_activation_min_vector, - vmin_s8(output_activation_max_vector, vqmovn_s16(p))); - vst1_s8(output_data + i, clamped); + const auto p3_narrowed = vqmovn_s32(p3); + const auto p4_narrowed = vqmovn_s32(p4); + + const int16x8_t p_part1 = + vaddq_s16(vcombine_s16(p2_narrowed, p1_narrowed), output_offset_vector); + const int16x8_t p_part2 = + vaddq_s16(vcombine_s16(p4_narrowed, p3_narrowed), output_offset_vector); + const int8x16_t p = vcombine_s8(vqmovn_s16(p_part2), vqmovn_s16(p_part1)); + + const auto clamped = vmaxq_s8(output_activation_min_vector, + vminq_s8(output_activation_max_vector, p)); + vst1q_s8(output_data + i, clamped); } #endif // NEON @@ -117,40 +149,63 @@ inline void MulSimpleBroadcast(int size, const ArithmeticParams& params, const auto input2_offset_vector = vdupq_n_s16(params.input2_offset); const auto output_offset_vector = vdupq_n_s16(params.output_offset); const auto output_activation_min_vector = - vdup_n_s8(params.quantized_activation_min); + vdupq_n_s8(params.quantized_activation_min); const auto output_activation_max_vector = - vdup_n_s8(params.quantized_activation_max); + vdupq_n_s8(params.quantized_activation_max); const int left_shift = std::max(0, params.output_shift); const int right_shift = std::max(0, -params.output_shift); const int32x4_t left_shift_vec = vdupq_n_s32(left_shift); - for (; i <= size - 8; i += 8) { - // We load / store 8 at a time, multiplying as two sets of 4 int32s. - const auto input2_val_original = vld1_s8(input2_data + i); - const auto input2_val_s16 = vmovl_s8(input2_val_original); - const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector); + for (; i <= size - 16; i += 16) { + // We load / store 16 at a time, multiplying as four sets of 4 int32s. + const auto input2_val_original = vld1q_s8(input2_data + i); + const auto input2_val_s16_high = + vmovl_s8(vget_high_s8(input2_val_original)); + const auto input2_val_s16_low = vmovl_s8(vget_low_s8(input2_val_original)); - const auto input2_val_low = vget_low_s16(input2_val); - const auto input2_val_high = vget_high_s16(input2_val); + const auto input2_val_high = + vaddq_s16(input2_val_s16_high, input2_offset_vector); + const auto input2_val_low = + vaddq_s16(input2_val_s16_low, input2_offset_vector); - auto p1 = vmull_n_s16(input2_val_low, input1_val); - auto p2 = vmull_n_s16(input2_val_high, input1_val); + const auto input2_val_low_low = vget_low_s16(input2_val_low); + const auto input2_val_low_high = vget_high_s16(input2_val_low); + const auto input2_val_high_low = vget_low_s16(input2_val_high); + const auto input2_val_high_high = vget_high_s16(input2_val_high); + + auto p1 = vmull_n_s16(input2_val_high_high, input1_val); + auto p2 = vmull_n_s16(input2_val_high_low, input1_val); + auto p3 = vmull_n_s16(input2_val_low_high, input1_val); + auto p4 = vmull_n_s16(input2_val_low_low, input1_val); p1 = vshlq_s32(p1, left_shift_vec); p2 = vshlq_s32(p2, left_shift_vec); + p3 = vshlq_s32(p3, left_shift_vec); + p4 = vshlq_s32(p4, left_shift_vec); + p1 = vqrdmulhq_n_s32(p1, params.output_multiplier); p2 = vqrdmulhq_n_s32(p2, params.output_multiplier); + p3 = vqrdmulhq_n_s32(p3, params.output_multiplier); + p4 = vqrdmulhq_n_s32(p4, params.output_multiplier); using gemmlowp::RoundingDivideByPOT; p1 = RoundingDivideByPOT(p1, right_shift); p2 = RoundingDivideByPOT(p2, right_shift); + p3 = RoundingDivideByPOT(p3, right_shift); + p4 = RoundingDivideByPOT(p4, right_shift); const auto p1_narrowed = vqmovn_s32(p1); const auto p2_narrowed = vqmovn_s32(p2); - const auto p = - vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector); - const auto clamped = - vmax_s8(output_activation_min_vector, - vmin_s8(output_activation_max_vector, vqmovn_s16(p))); - vst1_s8(output_data + i, clamped); + const auto p3_narrowed = vqmovn_s32(p3); + const auto p4_narrowed = vqmovn_s32(p4); + + const int16x8_t p_part1 = + vaddq_s16(vcombine_s16(p2_narrowed, p1_narrowed), output_offset_vector); + const int16x8_t p_part2 = + vaddq_s16(vcombine_s16(p4_narrowed, p3_narrowed), output_offset_vector); + const int8x16_t p = vcombine_s8(vqmovn_s16(p_part2), vqmovn_s16(p_part1)); + + const auto clamped = vmaxq_s8(output_activation_min_vector, + vminq_s8(output_activation_max_vector, p)); + vst1q_s8(output_data + i, clamped); } #endif // NEON