optimize for int8 add.

PiperOrigin-RevId: 311869888
Change-Id: I4009635592941be39aa5c71e185e3eecbc2ec49c
This commit is contained in:
Renjie Liu 2020-05-16 02:21:00 -07:00 committed by TensorFlower Gardener
parent 74396bcd30
commit fd976b2def

View File

@ -35,58 +35,99 @@ inline void AddElementwise(int size, const ArithmeticParams& params,
TFLITE_DCHECK_GT(params.input2_offset, -256);
TFLITE_DCHECK_LT(params.input1_offset, 256);
TFLITE_DCHECK_LT(params.input2_offset, 256);
#ifdef USE_NEON
const int8x8_t output_activation_min_vector =
vdup_n_s8(params.quantized_activation_min);
const int8x8_t output_activation_max_vector =
vdup_n_s8(params.quantized_activation_max);
for (; i <= size - 8; i += 8) {
const int8x8_t input1_val_original = vld1_s8(input1_data + i);
const int8x8_t input2_val_original = vld1_s8(input2_data + i);
const int16x8_t input1_val_s16 = vmovl_s8(input1_val_original);
const int16x8_t input2_val_s16 = vmovl_s8(input2_val_original);
const int16x8_t input1_val =
vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
const int16x8_t input2_val =
vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
const int16x4_t input1_val_high = vget_high_s16(input1_val);
const int16x4_t input1_val_low = vget_low_s16(input1_val);
const int16x4_t input2_val_high = vget_high_s16(input2_val);
const int16x4_t input2_val_low = vget_low_s16(input2_val);
int32x4_t x11 = vmovl_s16(input1_val_low);
int32x4_t x12 = vmovl_s16(input1_val_high);
int32x4_t x21 = vmovl_s16(input2_val_low);
int32x4_t x22 = vmovl_s16(input2_val_high);
const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
x11 = vshlq_s32(x11, left_shift_dup);
x12 = vshlq_s32(x12, left_shift_dup);
x21 = vshlq_s32(x21, left_shift_dup);
x22 = vshlq_s32(x22, left_shift_dup);
x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
x11 = vshlq_s32(x11, input1_shift_dup);
x12 = vshlq_s32(x12, input1_shift_dup);
x21 = vshlq_s32(x21, input2_shift_dup);
x22 = vshlq_s32(x22, input2_shift_dup);
int32x4_t s1 = vaddq_s32(x11, x21);
int32x4_t s2 = vaddq_s32(x12, x22);
s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
const int8x16_t output_activation_min_vector =
vdupq_n_s8(params.quantized_activation_min);
const int8x16_t output_activation_max_vector =
vdupq_n_s8(params.quantized_activation_max);
const int input1_left_shift = params.left_shift + params.input1_shift;
const int input2_left_shift = params.left_shift + params.input2_shift;
const int32x4_t input1_left_dup = vdupq_n_s32(input1_left_shift);
const int32x4_t input2_left_dup = vdupq_n_s32(input2_left_shift);
for (; i <= size - 16; i += 16) {
const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
const int16x8_t input1_val_s16_high =
vmovl_s8(vget_high_s8(input1_val_original));
const int16x8_t input1_val_s16_low =
vmovl_s8(vget_low_s8(input1_val_original));
const int16x8_t input2_val_s16_high =
vmovl_s8(vget_high_s8(input2_val_original));
const int16x8_t input2_val_s16_low =
vmovl_s8(vget_low_s8(input2_val_original));
const int16x8_t input1_val_high =
vaddq_s16(input1_val_s16_high, vdupq_n_s16(params.input1_offset));
const int16x8_t input2_val_high =
vaddq_s16(input2_val_s16_high, vdupq_n_s16(params.input2_offset));
const int16x8_t input1_val_low =
vaddq_s16(input1_val_s16_low, vdupq_n_s16(params.input1_offset));
const int16x8_t input2_val_low =
vaddq_s16(input2_val_s16_low, vdupq_n_s16(params.input2_offset));
const int16x4_t input1_val_high_high = vget_high_s16(input1_val_high);
const int16x4_t input1_val_high_low = vget_low_s16(input1_val_high);
const int16x4_t input1_val_low_high = vget_high_s16(input1_val_low);
const int16x4_t input1_val_low_low = vget_low_s16(input1_val_low);
const int16x4_t input2_val_high_high = vget_high_s16(input2_val_high);
const int16x4_t input2_val_high_low = vget_low_s16(input2_val_high);
const int16x4_t input2_val_low_high = vget_high_s16(input2_val_low);
const int16x4_t input2_val_low_low = vget_low_s16(input2_val_low);
int32x4_t x111 = vmovl_s16(input1_val_low_low);
int32x4_t x112 = vmovl_s16(input1_val_low_high);
int32x4_t x121 = vmovl_s16(input1_val_high_low);
int32x4_t x122 = vmovl_s16(input1_val_high_high);
int32x4_t x211 = vmovl_s16(input2_val_low_low);
int32x4_t x212 = vmovl_s16(input2_val_low_high);
int32x4_t x221 = vmovl_s16(input2_val_high_low);
int32x4_t x222 = vmovl_s16(input2_val_high_high);
x111 = vshlq_s32(x111, input1_left_dup);
x112 = vshlq_s32(x112, input1_left_dup);
x121 = vshlq_s32(x121, input1_left_dup);
x122 = vshlq_s32(x122, input1_left_dup);
x211 = vshlq_s32(x211, input2_left_dup);
x212 = vshlq_s32(x212, input2_left_dup);
x221 = vshlq_s32(x221, input2_left_dup);
x222 = vshlq_s32(x222, input2_left_dup);
x111 = vqrdmulhq_n_s32(x111, params.input1_multiplier);
x112 = vqrdmulhq_n_s32(x112, params.input1_multiplier);
x121 = vqrdmulhq_n_s32(x121, params.input1_multiplier);
x122 = vqrdmulhq_n_s32(x122, params.input1_multiplier);
x211 = vqrdmulhq_n_s32(x211, params.input2_multiplier);
x212 = vqrdmulhq_n_s32(x212, params.input2_multiplier);
x221 = vqrdmulhq_n_s32(x221, params.input2_multiplier);
x222 = vqrdmulhq_n_s32(x222, params.input2_multiplier);
int32x4_t s11 = vaddq_s32(x111, x211);
int32x4_t s12 = vaddq_s32(x112, x212);
int32x4_t s21 = vaddq_s32(x121, x221);
int32x4_t s22 = vaddq_s32(x122, x222);
s11 = vqrdmulhq_n_s32(s11, params.output_multiplier);
s12 = vqrdmulhq_n_s32(s12, params.output_multiplier);
s21 = vqrdmulhq_n_s32(s21, params.output_multiplier);
s22 = vqrdmulhq_n_s32(s22, params.output_multiplier);
using gemmlowp::RoundingDivideByPOT;
s1 = RoundingDivideByPOT(s1, -params.output_shift);
s2 = RoundingDivideByPOT(s2, -params.output_shift);
const int16x4_t s1_narrowed = vmovn_s32(s1);
const int16x4_t s2_narrowed = vmovn_s32(s2);
const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
vdupq_n_s16(params.output_offset));
const int8x8_t clamped =
vmax_s8(output_activation_min_vector,
vmin_s8(output_activation_max_vector, vqmovn_s16(s)));
vst1_s8(output_data + i, clamped);
s11 = RoundingDivideByPOT(s11, -params.output_shift);
s12 = RoundingDivideByPOT(s12, -params.output_shift);
s21 = RoundingDivideByPOT(s21, -params.output_shift);
s22 = RoundingDivideByPOT(s22, -params.output_shift);
const int16x4_t s11_narrowed = vmovn_s32(s11);
const int16x4_t s12_narrowed = vmovn_s32(s12);
const int16x4_t s21_narrowed = vmovn_s32(s21);
const int16x4_t s22_narrowed = vmovn_s32(s22);
const int16x8_t s1 = vaddq_s16(vcombine_s16(s11_narrowed, s12_narrowed),
vdupq_n_s16(params.output_offset));
const int16x8_t s2 = vaddq_s16(vcombine_s16(s21_narrowed, s22_narrowed),
vdupq_n_s16(params.output_offset));
const int8x16_t s = vcombine_s8(vqmovn_s16(s1), vqmovn_s16(s2));
const int8x16_t clamped =
vmaxq_s8(output_activation_min_vector,
vminq_s8(output_activation_max_vector, s));
vst1q_s8(output_data + i, clamped);
}
#endif // NEON