From 11272973aced68d95d2f47a9fd8110ce3dd82793 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 7 May 2019 20:51:59 -0700 Subject: [PATCH] DepthwiseConv, dot-product kernel asm, copy optimizations (per-depth). PiperOrigin-RevId: 247143866 --- .../depthwiseconv_uint8_transitional.h | 98 ++++++++++--------- 1 file changed, 52 insertions(+), 46 deletions(-) diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_transitional.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_transitional.h index 47d969abf10..b787f3f4e4c 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_transitional.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_transitional.h @@ -293,12 +293,22 @@ struct ProcessPerDepth< const int8x16_t ones_vector = vdupq_n_s8(1); // Simulate NEON-register transposition of subset of filter. - int8x16_t filter_reg_0_a; - int8x16_t filter_reg_0_b; - int8x16_t filter_reg_1_a; - int8x16_t filter_reg_1_b; - int8x16_t filter_reg_2_a; - int8x16_t filter_reg_2_b; + int8x16_t input_0_a; + int8x16_t input_0_b; + int8x16_t input_0_c; + int8x16_t input_1_a; + int8x16_t input_1_b; + int8x16_t input_1_c; + int8x16_t input_2_a; + int8x16_t input_2_b; + int8x16_t input_2_c; + + int8x16_t filter_0_a; + int8x16_t filter_0_b; + int8x16_t filter_1_a; + int8x16_t filter_1_b; + int8x16_t filter_2_a; + int8x16_t filter_2_b; // Register pairs for each height. // Effect subtraction of zero-point = 128 by XOR of sign bit. @@ -310,56 +320,52 @@ struct ProcessPerDepth< // height 3, width 3, micro-blocks, sub-block 0 or 1, depth 4. // filter_bank[3][2][4][4]; Sub-block, height 3, depth 4, width 4. - // Load zero-point into effective position of zero-padding of filter - // (register B, upper part). - filter_reg_0_b = vdupq_n_u8(kSignBit); - filter_reg_1_b = vdupq_n_u8(kSignBit); - filter_reg_2_b = vdupq_n_u8(kSignBit); - const uint8* filter_block_ptr = filter_block; - filter_reg_0_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_0_a, 0); + input_0_a = vld1q_lane_s8x8(filter_block_ptr, input_0_a, 0); filter_block_ptr += depth; - filter_reg_0_b = vld1q_lane_s8x8(filter_block_ptr, filter_reg_0_b, 0); + input_0_b = vld1q_lane_s8x8(filter_block_ptr, input_0_b, 0); filter_block_ptr += depth; - filter_reg_0_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_0_a, 1); + input_0_c = vld1q_lane_s8x8(filter_block_ptr, input_0_c, 0); filter_block_ptr += depth; - filter_reg_1_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_1_a, 0); + input_1_a = vld1q_lane_s8x8(filter_block_ptr, input_1_a, 0); filter_block_ptr += depth; - filter_reg_1_b = vld1q_lane_s8x8(filter_block_ptr, filter_reg_1_b, 0); + input_1_b = vld1q_lane_s8x8(filter_block_ptr, input_1_b, 0); filter_block_ptr += depth; - filter_reg_1_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_1_a, 1); + input_1_c = vld1q_lane_s8x8(filter_block_ptr, input_1_c, 0); filter_block_ptr += depth; - filter_reg_2_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_2_a, 0); + input_2_a = vld1q_lane_s8x8(filter_block_ptr, input_2_a, 0); filter_block_ptr += depth; - filter_reg_2_b = vld1q_lane_s8x8(filter_block_ptr, filter_reg_2_b, 0); + input_2_b = vld1q_lane_s8x8(filter_block_ptr, input_2_b, 0); filter_block_ptr += depth; - filter_reg_2_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_2_a, 1); + input_2_c = vld1q_lane_s8x8(filter_block_ptr, input_2_c, 0); - filter_reg_0_a = veorq_s8(filter_reg_0_a, sign_bit); - filter_reg_0_b = veorq_s8(filter_reg_0_b, sign_bit); - filter_reg_1_a = veorq_s8(filter_reg_1_a, sign_bit); - filter_reg_1_b = veorq_s8(filter_reg_1_b, sign_bit); - filter_reg_2_a = veorq_s8(filter_reg_2_a, sign_bit); - filter_reg_2_b = veorq_s8(filter_reg_2_b, sign_bit); + filter_0_a = vzip1q_s8(input_0_a, input_0_b); + filter_0_b = vzip1q_s8(input_0_c, sign_bit); + filter_1_a = vzip1q_s8(input_1_a, input_1_b); + filter_1_b = vzip1q_s8(input_1_c, sign_bit); + filter_2_a = vzip1q_s8(input_2_a, input_2_b); + filter_2_b = vzip1q_s8(input_2_c, sign_bit); + filter_0_a = veorq_s8(filter_0_a, sign_bit); + filter_0_b = veorq_s8(filter_0_b, sign_bit); + filter_1_a = veorq_s8(filter_1_a, sign_bit); + filter_1_b = veorq_s8(filter_1_b, sign_bit); + filter_2_a = veorq_s8(filter_2_a, sign_bit); + filter_2_b = veorq_s8(filter_2_b, sign_bit); + vzipq_s8x2_in_place(&filter_0_a, &filter_0_b); + vzipq_s8x2_in_place(&filter_1_a, &filter_1_b); + vzipq_s8x2_in_place(&filter_2_a, &filter_2_b); - vzipq_s8_in_place(&filter_reg_0_a, &filter_reg_0_b); - vzipq_s8_in_place(&filter_reg_1_a, &filter_reg_1_b); - vzipq_s8_in_place(&filter_reg_2_a, &filter_reg_2_b); - vzipq_s8x2_in_place(&filter_reg_0_a, &filter_reg_0_b); - vzipq_s8x2_in_place(&filter_reg_1_a, &filter_reg_1_b); - vzipq_s8x2_in_place(&filter_reg_2_a, &filter_reg_2_b); - - vst1q_s8(shuffled_filter_data, filter_reg_0_a); + vst1q_s8(shuffled_filter_data, filter_0_a); shuffled_filter_data += 16; - vst1q_s8(shuffled_filter_data, filter_reg_0_b); + vst1q_s8(shuffled_filter_data, filter_0_b); shuffled_filter_data += 16; - vst1q_s8(shuffled_filter_data, filter_reg_1_a); + vst1q_s8(shuffled_filter_data, filter_1_a); shuffled_filter_data += 16; - vst1q_s8(shuffled_filter_data, filter_reg_1_b); + vst1q_s8(shuffled_filter_data, filter_1_b); shuffled_filter_data += 16; - vst1q_s8(shuffled_filter_data, filter_reg_2_a); + vst1q_s8(shuffled_filter_data, filter_2_a); shuffled_filter_data += 16; - vst1q_s8(shuffled_filter_data, filter_reg_2_b); + vst1q_s8(shuffled_filter_data, filter_2_b); shuffled_filter_data += 16; int32x4_t adjusted_bias_data_a = vld1q_s32(bias_data); @@ -369,13 +375,13 @@ struct ProcessPerDepth< // For instance, if input_offset == 128, no adjustment is needed. int32x4_t filter_sum_a = vdupq_n_s32(0); - filter_sum_a = vdotq_s32(filter_sum_a, filter_reg_0_a, ones_vector); - filter_sum_a = vdotq_s32(filter_sum_a, filter_reg_1_a, ones_vector); - filter_sum_a = vdotq_s32(filter_sum_a, filter_reg_2_a, ones_vector); + filter_sum_a = vdotq_s32(filter_sum_a, filter_0_a, ones_vector); + filter_sum_a = vdotq_s32(filter_sum_a, filter_1_a, ones_vector); + filter_sum_a = vdotq_s32(filter_sum_a, filter_2_a, ones_vector); int32x4_t filter_sum_b = vdupq_n_s32(0); - filter_sum_b = vdotq_s32(filter_sum_b, filter_reg_0_b, ones_vector); - filter_sum_b = vdotq_s32(filter_sum_b, filter_reg_1_b, ones_vector); - filter_sum_b = vdotq_s32(filter_sum_b, filter_reg_2_b, ones_vector); + filter_sum_b = vdotq_s32(filter_sum_b, filter_0_b, ones_vector); + filter_sum_b = vdotq_s32(filter_sum_b, filter_1_b, ones_vector); + filter_sum_b = vdotq_s32(filter_sum_b, filter_2_b, ones_vector); adjusted_bias_data_a = vmlaq_n_s32(adjusted_bias_data_a, filter_sum_a, input_offset_difference);