DepthwiseConv, dot-product kernel asm, copy optimizations (per-depth).

PiperOrigin-RevId: 247143866
A. Unique TensorFlower authored on 2019-05-07 20:51:59 -07:00; committed by TensorFlower Gardener.
parent e1e71dac5d
commit 11272973ac


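For orientation when reading the hunks below: the per-depth pass takes one depth-8 micro block of the 3x3 uint8 filter, re-centers it to int8 by flipping the sign bit, writes it out in the interleaved order the dot-product kernel consumes, and folds the per-channel filter sums into the bias. The scalar sketch below states that contract under assumed shapes; the function and parameter names are illustrative, not the TFLite routine, and the exact interleaved store order (what the NEON zips below implement) is left out.

    #include <cstdint>

    // Sketch only: per-depth preprocessing for a 3x3 filter micro block of
    // 8 channels. filter[tap][channel] is uint8 with zero point 128, as the
    // kernel below assumes; bias and adjusted bias are int32 per channel.
    void ProcessPerDepthSketch(const uint8_t filter[9][8],
                               const int32_t bias[8],
                               int32_t input_offset_difference,
                               int8_t recentered_filter[9][8],
                               int32_t adjusted_bias[8]) {
      for (int c = 0; c < 8; ++c) {
        int32_t filter_sum = 0;
        for (int tap = 0; tap < 9; ++tap) {
          // XOR of the sign bit == subtraction of the zero point 128.
          const int8_t f = static_cast<int8_t>(filter[tap][c] ^ 0x80);
          recentered_filter[tap][c] = f;  // real kernel stores a zipped layout
          filter_sum += f;
        }
        // Folding the input offset into the bias once per channel saves the
        // inner loop from adding it to every input element.
        adjusted_bias[c] = bias[c] + filter_sum * input_offset_difference;
      }
    }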
@@ -293,12 +293,22 @@ struct ProcessPerDepth<
const int8x16_t ones_vector = vdupq_n_s8(1);
// Simulate NEON-register transposition of subset of filter.
- int8x16_t filter_reg_0_a;
- int8x16_t filter_reg_0_b;
- int8x16_t filter_reg_1_a;
- int8x16_t filter_reg_1_b;
- int8x16_t filter_reg_2_a;
- int8x16_t filter_reg_2_b;
+ int8x16_t input_0_a;
+ int8x16_t input_0_b;
+ int8x16_t input_0_c;
+ int8x16_t input_1_a;
+ int8x16_t input_1_b;
+ int8x16_t input_1_c;
+ int8x16_t input_2_a;
+ int8x16_t input_2_b;
+ int8x16_t input_2_c;
+ int8x16_t filter_0_a;
+ int8x16_t filter_0_b;
+ int8x16_t filter_1_a;
+ int8x16_t filter_1_b;
+ int8x16_t filter_2_a;
+ int8x16_t filter_2_b;
// Register pairs for each height.
// Effect subtraction of zero-point = 128 by XOR of sign bit.
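The sign-bit XOR named in the comment above works because flipping the top bit of a uint8 value and reinterpreting it as int8 subtracts exactly 128. A minimal standalone check in plain C++ (no NEON; the loop is illustrative only):

    #include <cassert>
    #include <cstdint>

    int main() {
      // For every uint8 v, (int8)(v ^ 0x80) == (int)v - 128.
      // veorq_s8 with a 0x80 splat applies the same identity lane-wise.
      for (int v = 0; v < 256; ++v) {
        const uint8_t u = static_cast<uint8_t>(v);
        const int8_t recentered = static_cast<int8_t>(u ^ 0x80);
        assert(static_cast<int>(recentered) == v - 128);
      }
      return 0;
    }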
@@ -310,56 +320,52 @@ struct ProcessPerDepth<
// height 3, width 3, micro-blocks, sub-block 0 or 1, depth 4.
// filter_bank[3][2][4][4]; Sub-block, height 3, depth 4, width 4.
// Load zero-point into effective position of zero-padding of filter
// (register B, upper part).
- filter_reg_0_b = vdupq_n_u8(kSignBit);
- filter_reg_1_b = vdupq_n_u8(kSignBit);
- filter_reg_2_b = vdupq_n_u8(kSignBit);
const uint8* filter_block_ptr = filter_block;
- filter_reg_0_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_0_a, 0);
+ input_0_a = vld1q_lane_s8x8(filter_block_ptr, input_0_a, 0);
filter_block_ptr += depth;
- filter_reg_0_b = vld1q_lane_s8x8(filter_block_ptr, filter_reg_0_b, 0);
+ input_0_b = vld1q_lane_s8x8(filter_block_ptr, input_0_b, 0);
filter_block_ptr += depth;
- filter_reg_0_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_0_a, 1);
+ input_0_c = vld1q_lane_s8x8(filter_block_ptr, input_0_c, 0);
filter_block_ptr += depth;
- filter_reg_1_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_1_a, 0);
+ input_1_a = vld1q_lane_s8x8(filter_block_ptr, input_1_a, 0);
filter_block_ptr += depth;
- filter_reg_1_b = vld1q_lane_s8x8(filter_block_ptr, filter_reg_1_b, 0);
+ input_1_b = vld1q_lane_s8x8(filter_block_ptr, input_1_b, 0);
filter_block_ptr += depth;
- filter_reg_1_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_1_a, 1);
+ input_1_c = vld1q_lane_s8x8(filter_block_ptr, input_1_c, 0);
filter_block_ptr += depth;
- filter_reg_2_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_2_a, 0);
+ input_2_a = vld1q_lane_s8x8(filter_block_ptr, input_2_a, 0);
filter_block_ptr += depth;
- filter_reg_2_b = vld1q_lane_s8x8(filter_block_ptr, filter_reg_2_b, 0);
+ input_2_b = vld1q_lane_s8x8(filter_block_ptr, input_2_b, 0);
filter_block_ptr += depth;
- filter_reg_2_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_2_a, 1);
+ input_2_c = vld1q_lane_s8x8(filter_block_ptr, input_2_c, 0);
- filter_reg_0_a = veorq_s8(filter_reg_0_a, sign_bit);
- filter_reg_0_b = veorq_s8(filter_reg_0_b, sign_bit);
- filter_reg_1_a = veorq_s8(filter_reg_1_a, sign_bit);
- filter_reg_1_b = veorq_s8(filter_reg_1_b, sign_bit);
- filter_reg_2_a = veorq_s8(filter_reg_2_a, sign_bit);
- filter_reg_2_b = veorq_s8(filter_reg_2_b, sign_bit);
- vzipq_s8_in_place(&filter_reg_0_a, &filter_reg_0_b);
- vzipq_s8_in_place(&filter_reg_1_a, &filter_reg_1_b);
- vzipq_s8_in_place(&filter_reg_2_a, &filter_reg_2_b);
- vzipq_s8x2_in_place(&filter_reg_0_a, &filter_reg_0_b);
- vzipq_s8x2_in_place(&filter_reg_1_a, &filter_reg_1_b);
- vzipq_s8x2_in_place(&filter_reg_2_a, &filter_reg_2_b);
+ filter_0_a = vzip1q_s8(input_0_a, input_0_b);
+ filter_0_b = vzip1q_s8(input_0_c, sign_bit);
+ filter_1_a = vzip1q_s8(input_1_a, input_1_b);
+ filter_1_b = vzip1q_s8(input_1_c, sign_bit);
+ filter_2_a = vzip1q_s8(input_2_a, input_2_b);
+ filter_2_b = vzip1q_s8(input_2_c, sign_bit);
+ filter_0_a = veorq_s8(filter_0_a, sign_bit);
+ filter_0_b = veorq_s8(filter_0_b, sign_bit);
+ filter_1_a = veorq_s8(filter_1_a, sign_bit);
+ filter_1_b = veorq_s8(filter_1_b, sign_bit);
+ filter_2_a = veorq_s8(filter_2_a, sign_bit);
+ filter_2_b = veorq_s8(filter_2_b, sign_bit);
+ vzipq_s8x2_in_place(&filter_0_a, &filter_0_b);
+ vzipq_s8x2_in_place(&filter_1_a, &filter_1_b);
+ vzipq_s8x2_in_place(&filter_2_a, &filter_2_b);
- vst1q_s8(shuffled_filter_data, filter_reg_0_a);
+ vst1q_s8(shuffled_filter_data, filter_0_a);
shuffled_filter_data += 16;
- vst1q_s8(shuffled_filter_data, filter_reg_0_b);
+ vst1q_s8(shuffled_filter_data, filter_0_b);
shuffled_filter_data += 16;
- vst1q_s8(shuffled_filter_data, filter_reg_1_a);
+ vst1q_s8(shuffled_filter_data, filter_1_a);
shuffled_filter_data += 16;
- vst1q_s8(shuffled_filter_data, filter_reg_1_b);
+ vst1q_s8(shuffled_filter_data, filter_1_b);
shuffled_filter_data += 16;
- vst1q_s8(shuffled_filter_data, filter_reg_2_a);
+ vst1q_s8(shuffled_filter_data, filter_2_a);
shuffled_filter_data += 16;
- vst1q_s8(shuffled_filter_data, filter_reg_2_b);
+ vst1q_s8(shuffled_filter_data, filter_2_b);
shuffled_filter_data += 16;
int32x4_t adjusted_bias_data_a = vld1q_s32(bias_data);
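The rewritten shuffle above is built from zip operations: vzip1q_s8 interleaves the low eight lanes of its two operands, pairing up the two 8-byte filter rows that were loaded into the low halves of the input registers, before the remaining in-place zips finish the transposition. A scalar model of that first interleave (plain C++, illustrative values, not the TFLite helpers):

    #include <array>
    #include <cstdint>
    #include <cstdio>

    // Scalar model of AArch64 vzip1q_s8: the result interleaves the low 8
    // lanes of a and b as {a0, b0, a1, b1, ..., a7, b7}.
    std::array<int8_t, 16> Zip1(const std::array<int8_t, 16>& a,
                                const std::array<int8_t, 16>& b) {
      std::array<int8_t, 16> r{};
      for (int i = 0; i < 8; ++i) {
        r[2 * i] = a[i];
        r[2 * i + 1] = b[i];
      }
      return r;
    }

    int main() {
      // Two 8-byte filter rows loaded into the low halves of two registers,
      // as vld1q_lane_s8x8 does in the hunk above (values are made up).
      std::array<int8_t, 16> row_w0{}, row_w1{};
      for (int d = 0; d < 8; ++d) {
        row_w0[d] = static_cast<int8_t>(10 + d);
        row_w1[d] = static_cast<int8_t>(20 + d);
      }
      const std::array<int8_t, 16> zipped = Zip1(row_w0, row_w1);
      for (int i = 0; i < 16; ++i) {
        std::printf("%d ", zipped[i]);  // prints 10 20 11 21 ... 17 27
      }
      std::printf("\n");
      return 0;
    }

In the new code the third row of each filter height is zipped against sign_bit, so after the XOR those padding bytes become zero, matching what the dropped vdupq_n_u8(kSignBit) initialization used to provide.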
@@ -369,13 +375,13 @@ struct ProcessPerDepth<
// For instance, if input_offset == 128, no adjustment is needed.
int32x4_t filter_sum_a = vdupq_n_s32(0);
- filter_sum_a = vdotq_s32(filter_sum_a, filter_reg_0_a, ones_vector);
- filter_sum_a = vdotq_s32(filter_sum_a, filter_reg_1_a, ones_vector);
- filter_sum_a = vdotq_s32(filter_sum_a, filter_reg_2_a, ones_vector);
+ filter_sum_a = vdotq_s32(filter_sum_a, filter_0_a, ones_vector);
+ filter_sum_a = vdotq_s32(filter_sum_a, filter_1_a, ones_vector);
+ filter_sum_a = vdotq_s32(filter_sum_a, filter_2_a, ones_vector);
int32x4_t filter_sum_b = vdupq_n_s32(0);
- filter_sum_b = vdotq_s32(filter_sum_b, filter_reg_0_b, ones_vector);
- filter_sum_b = vdotq_s32(filter_sum_b, filter_reg_1_b, ones_vector);
- filter_sum_b = vdotq_s32(filter_sum_b, filter_reg_2_b, ones_vector);
+ filter_sum_b = vdotq_s32(filter_sum_b, filter_0_b, ones_vector);
+ filter_sum_b = vdotq_s32(filter_sum_b, filter_1_b, ones_vector);
+ filter_sum_b = vdotq_s32(filter_sum_b, filter_2_b, ones_vector);
adjusted_bias_data_a = vmlaq_n_s32(adjusted_bias_data_a, filter_sum_a,
input_offset_difference);
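The vdotq_s32 accumulations against ones_vector reduce the shuffled filter registers to per-channel sums: with an all-ones operand, each 32-bit lane simply accumulates its four corresponding int8 lanes. vmlaq_n_s32 then folds filter_sum * input_offset_difference into the bias once per channel, so the main loop never adds the input offset to individual input values. A scalar sketch of this reduction and bias adjustment (plain C++; the values and the example offset are made up):

    #include <array>
    #include <cstdint>
    #include <cstdio>

    // Scalar model of vdotq_s32(acc, bytes, ones): 32-bit lane i accumulates
    // the sum of int8 lanes 4*i .. 4*i+3 (each multiplied by 1).
    std::array<int32_t, 4> DotWithOnes(std::array<int32_t, 4> acc,
                                       const std::array<int8_t, 16>& bytes) {
      for (int lane = 0; lane < 4; ++lane) {
        for (int j = 0; j < 4; ++j) {
          acc[lane] += bytes[4 * lane + j];  // times 1 from the ones vector
        }
      }
      return acc;
    }

    int main() {
      // Stand-ins for three shuffled filter registers of one sub-block
      // (values made up); each 32-bit lane maps to one output channel.
      std::array<int8_t, 16> f0{}, f1{}, f2{};
      for (int i = 0; i < 16; ++i) {
        f0[i] = static_cast<int8_t>(i - 8);
        f1[i] = static_cast<int8_t>(3);
        f2[i] = static_cast<int8_t>(-1);
      }
      std::array<int32_t, 4> filter_sum{};  // zero, like vdupq_n_s32(0)
      filter_sum = DotWithOnes(filter_sum, f0);
      filter_sum = DotWithOnes(filter_sum, f1);
      filter_sum = DotWithOnes(filter_sum, f2);

      // vmlaq_n_s32: adjusted_bias[lane] = bias[lane] + filter_sum[lane] * k.
      const int32_t input_offset_difference = -128;  // example value only
      std::array<int32_t, 4> bias = {100, 200, 300, 400};
      for (int lane = 0; lane < 4; ++lane) {
        bias[lane] += filter_sum[lane] * input_offset_difference;
        std::printf("adjusted_bias[%d] = %d\n", lane,
                    static_cast<int>(bias[lane]));
      }
      return 0;
    }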