diff --git a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc index 8b730cd9069..3e48d95a082 100644 --- a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc @@ -145,9 +145,29 @@ inline void DispatchDepthwiseConv( break; #endif } - case DepthwiseConvImplementation::kUseNeon3x3DotProduct: - // TODO(b/118426582) Placeholder for future dispatches. + case DepthwiseConvImplementation::kUseNeon3x3DotProduct: { +#if defined(__ARM_FEATURE_DOTPROD) && defined(__aarch64__) + DotProduct3x3KernelType kernel_type = + optimized_ops::depthwise_conv::CategorizeDotProductKernel( + input_shape, filter_shape, params); + + ASSERT_TRUE( + kernel_type == DotProduct3x3KernelType::kPlain || + kernel_type == DotProduct3x3KernelType::kStride2 || + kernel_type == + DotProduct3x3KernelType::kWithDepthMultiplicationStride1 || + kernel_type == + DotProduct3x3KernelType::kWithDepthMultiplicationStride2) + << "Kernel type = " << static_cast(kernel_type); + + optimized_ops::depthwise_conv::DepthwiseConvDotProduct3x3< + DepthwiseConvImplementation::kUseNeon3x3DotProduct>( + params, input_shape, input_data, filter_shape, filter_data, + bias_shape, bias_data, output_shape, output_data); + return; +#endif break; + } case DepthwiseConvImplementation::kUseCModel3x3DotProduct: { DotProduct3x3KernelType kernel_type = optimized_ops::depthwise_conv::CategorizeDotProductKernel( @@ -181,7 +201,6 @@ inline void DispatchDepthwiseConv( return; } case DepthwiseConvImplementation::kUseUnwound3x3DotProduct: { - using optimized_ops::depthwise_conv::DotProduct3x3KernelType; DotProduct3x3KernelType kernel_type = optimized_ops::depthwise_conv::CategorizeDotProductKernel( input_shape, filter_shape, params); @@ -200,7 +219,6 @@ inline void DispatchDepthwiseConv( } case DepthwiseConvImplementation::kUseIntrinsics3x3DotProduct: { #if defined(USE_NEON) - using optimized_ops::depthwise_conv::DotProduct3x3KernelType; DotProduct3x3KernelType kernel_type = optimized_ops::depthwise_conv::CategorizeDotProductKernel( input_shape, filter_shape, params); @@ -794,5 +812,21 @@ INSTANTIATE_TEST_SUITE_P( TestParam::TestNameSuffix); #endif +#if defined(__ARM_FEATURE_DOTPROD) && defined(__aarch64__) +INSTANTIATE_TEST_SUITE_P( + NeonAsm, DepthwiseConvTest, + testing::Combine( + Values(DepthwiseConvImplementation:: + kUseNeon3x3DotProduct), // forced_invocation + Values(1000), // tests_to_run + Bool(), // test_stride + Bool(), // test_pad + Bool(), // test_depth_multiplier + Values(DepthwiseConvOutputRounding::kUpward), // output_rounding + Values(false) // loose_tolerance + ), + TestParam::TestNameSuffix); +#endif + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h index 18011c1cc91..50de905db17 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h @@ -473,7 +473,8 @@ struct DepthwiseConvWindow<8, 1, 1> { // Set "constant" registers. These registers may be replaced with temp // values from time to time when there are not enough NEON registers. // We use x9--x15 general purpose registers as they are caller-saved - // temporary registers (see http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf). 
// NOLINT + // temporary registers (see + // http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf). // NOLINT "ldr w9, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" "ldr x3, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" "cmp %w[output_window_height], #2\n" @@ -3752,8 +3753,7 @@ struct KernelMacroBlock { // implementation rather than conforming to style. }; -#ifdef USE_NEON -#ifdef __aarch64__ +#if defined(USE_NEON) && defined(__aarch64__) // Experiments suggest that a modest performance improvement is seen, at least // on 855 chipset big cores, with cache hints. inline void PreloadInputBlock( @@ -3781,8 +3781,3187 @@ inline void PreloadInputBlock( } } } -#endif -#endif // USE_NEON + +template <> +struct ProcessPerDepth { + static void ProcessPerDepthIntrinsics( + const uint8* filter_data, const int32* bias_data, + int8* shuffled_filter_data, int32* adjusted_bias_data, + const DepthwiseConvDotProdParams* function_params) { + const int depth = function_params->output_depth; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int bias_increment = function_params->bias_increment; + + constexpr int kSymmetricZeroPoint = 128; + constexpr uint8 kSignBit = 0x80; + const int32 input_offset = function_params->input_offset; + TFLITE_DCHECK_GE(input_offset, -255); + TFLITE_DCHECK_LE(input_offset, 0); + const int32 input_offset_difference = input_offset + kSymmetricZeroPoint; + const int8x16_t ones_vector = vdupq_n_s8(1); + + // Simulate NEON-register transposition of subset of filter. + int8x16_t filter_reg_0_a; + int8x16_t filter_reg_0_b; + int8x16_t filter_reg_1_a; + int8x16_t filter_reg_1_b; + int8x16_t filter_reg_2_a; + int8x16_t filter_reg_2_b; + + // Register pairs for each height. + // Effect subtraction of zero-point = 128 by XOR of sign bit. + const uint8x16_t sign_bit = vdupq_n_u8(kSignBit); + + const uint8* filter_block = filter_data; + for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { + // Filter data is provided as filter_block[3][3][depth/8][2][4]. + // height 3, width 3, micro-blocks, sub-block 0 or 1, depth 4. + // filter_bank[3][2][4][4]; Sub-block, height 3, depth 4, width 4. + + // Load zero-point into effective position of zero-padding of filter + // (register B, upper part). 
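+      // The unloaded upper half of each "b" register keeps kSignBit, so the
+      // sign-bit XOR below turns those zero-padding lanes into zeros and they
+      // contribute nothing to the later dot products.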
+ filter_reg_0_b = vdupq_n_u8(kSignBit); + filter_reg_1_b = vdupq_n_u8(kSignBit); + filter_reg_2_b = vdupq_n_u8(kSignBit); + + const uint8* filter_block_ptr = filter_block; + filter_reg_0_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_0_a, 0); + filter_block_ptr += depth; + filter_reg_0_b = vld1q_lane_s8x8(filter_block_ptr, filter_reg_0_b, 0); + filter_block_ptr += depth; + filter_reg_0_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_0_a, 1); + filter_block_ptr += depth; + filter_reg_1_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_1_a, 0); + filter_block_ptr += depth; + filter_reg_1_b = vld1q_lane_s8x8(filter_block_ptr, filter_reg_1_b, 0); + filter_block_ptr += depth; + filter_reg_1_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_1_a, 1); + filter_block_ptr += depth; + filter_reg_2_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_2_a, 0); + filter_block_ptr += depth; + filter_reg_2_b = vld1q_lane_s8x8(filter_block_ptr, filter_reg_2_b, 0); + filter_block_ptr += depth; + filter_reg_2_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_2_a, 1); + + filter_reg_0_a = veorq_s8(filter_reg_0_a, sign_bit); + filter_reg_0_b = veorq_s8(filter_reg_0_b, sign_bit); + filter_reg_1_a = veorq_s8(filter_reg_1_a, sign_bit); + filter_reg_1_b = veorq_s8(filter_reg_1_b, sign_bit); + filter_reg_2_a = veorq_s8(filter_reg_2_a, sign_bit); + filter_reg_2_b = veorq_s8(filter_reg_2_b, sign_bit); + + vzipq_s8_in_place(&filter_reg_0_a, &filter_reg_0_b); + vzipq_s8_in_place(&filter_reg_1_a, &filter_reg_1_b); + vzipq_s8_in_place(&filter_reg_2_a, &filter_reg_2_b); + vzipq_s8x2_in_place(&filter_reg_0_a, &filter_reg_0_b); + vzipq_s8x2_in_place(&filter_reg_1_a, &filter_reg_1_b); + vzipq_s8x2_in_place(&filter_reg_2_a, &filter_reg_2_b); + + vst1q_s8(shuffled_filter_data, filter_reg_0_a); + shuffled_filter_data += 16; + vst1q_s8(shuffled_filter_data, filter_reg_0_b); + shuffled_filter_data += 16; + vst1q_s8(shuffled_filter_data, filter_reg_1_a); + shuffled_filter_data += 16; + vst1q_s8(shuffled_filter_data, filter_reg_1_b); + shuffled_filter_data += 16; + vst1q_s8(shuffled_filter_data, filter_reg_2_a); + shuffled_filter_data += 16; + vst1q_s8(shuffled_filter_data, filter_reg_2_b); + shuffled_filter_data += 16; + + int32x4_t adjusted_bias_data_a = vld1q_s32(bias_data); + bias_data += bias_increment; + int32x4_t adjusted_bias_data_b = vld1q_s32(bias_data); + bias_data += bias_increment; + // For instance, if input_offset == 128, no adjustment is needed. 
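+      // Sum each channel's filter taps via SDOT against a vector of ones,
+      // then fold filter_sum * (input_offset + kSymmetricZeroPoint) into the
+      // bias to compensate for re-centering the uint8 inputs to int8.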
+ + int32x4_t filter_sum_a = vdupq_n_s32(0); + filter_sum_a = vdotq_s32(filter_sum_a, filter_reg_0_a, ones_vector); + filter_sum_a = vdotq_s32(filter_sum_a, filter_reg_1_a, ones_vector); + filter_sum_a = vdotq_s32(filter_sum_a, filter_reg_2_a, ones_vector); + int32x4_t filter_sum_b = vdupq_n_s32(0); + filter_sum_b = vdotq_s32(filter_sum_b, filter_reg_0_b, ones_vector); + filter_sum_b = vdotq_s32(filter_sum_b, filter_reg_1_b, ones_vector); + filter_sum_b = vdotq_s32(filter_sum_b, filter_reg_2_b, ones_vector); + + adjusted_bias_data_a = vmlaq_n_s32(adjusted_bias_data_a, filter_sum_a, + input_offset_difference); + adjusted_bias_data_b = vmlaq_n_s32(adjusted_bias_data_b, filter_sum_b, + input_offset_difference); + + vst1q_s32(adjusted_bias_data, adjusted_bias_data_a); + adjusted_bias_data += 4; + vst1q_s32(adjusted_bias_data, adjusted_bias_data_b); + adjusted_bias_data += 4; + + filter_block += 8; + } + } + + static inline void Run(const uint8* filter_data, const int32* bias_data, + int8* shuffled_filter_data, int32* adjusted_bias_data, + const DepthwiseConvDotProdParams* function_params) { + ProcessPerDepthIntrinsics(filter_data, bias_data, shuffled_filter_data, + adjusted_bias_data, function_params); + } +}; + +template <> +struct PackMacroBlock { + static inline void PackMacroBlockNeon( + const uint8* input_block_data, int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + TFLITE_DCHECK_EQ(function_params->padding_bottom, 0); + TFLITE_DCHECK_EQ(function_params->padding_top, 0); + TFLITE_DCHECK_EQ(function_params->padding_left, 0); + TFLITE_DCHECK_EQ(function_params->padding_right, 0); + const int workspace_height_stride = + function_params->workspace_height_stride; + const int width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int input_width_micro_repeats = + function_params->input_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int block_height = function_params->inbound_block_height; + const int residual_width = function_params->residual_width; + const int input_height_stride = function_params->input_height_stride; + const int input_depth = function_params->input_depth; + + static const uint8 perm_data[64] = { + 0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51, // + 4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55, + 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59, + 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63}; + + TFLITE_DCHECK_GE(depth_micro_repeats, 0); + constexpr uint8 kSignBit = 0x80; + const int micro_block_size = 4 * 8; + const int depth_advance = width_overall_micro_repeats * micro_block_size; + const int width_advance = + micro_block_size * + (1 - depth_micro_repeats * width_overall_micro_repeats); + const int height_advance = workspace_height_stride - + width_overall_micro_repeats * micro_block_size; + const int input_depth_skip = 4 * input_depth - 8 * depth_micro_repeats; + + // Transpositions are 4x4, but doing 2 at a time is more efficient in NEON + // code. Note the blocks of 4x4 are still interleaved down the depth. + int8x16_t work_reg_a; + int8x16_t work_reg_b; + const int8x16_t perm_data_0 = vld1q_u8(perm_data); + const int8x16_t perm_data_1 = vld1q_u8(perm_data + 16); + const int8x16_t perm_data_2 = vld1q_u8(perm_data + 32); + const int8x16_t perm_data_3 = vld1q_u8(perm_data + 48); + + // Effect subtraction of zero-point = 128 by XOR of sign bit. 
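+    // XOR with 0x80 flips the top bit, so a uint8 value v becomes the int8
+    // value (v - 128) without needing a widening subtract.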
+ const uint8x16_t sign_bit = vdupq_n_u8(kSignBit); + + // Work through one slice, by row, at a time. + int8* scratch_data_0 = scratch_block_data; + + for (int k_height = 0; k_height < block_height; ++k_height) { + const uint8* input_data_0 = input_block_data; + const uint8* input_data_1 = input_block_data + input_depth; + const uint8* input_data_2 = input_block_data + 2 * input_depth; + const uint8* input_data_3 = input_block_data + 3 * input_depth; + + // Traverse the width one point at a time, but the depth in (micro) blocks + // of size 8. + // + // The depth and width margins, which are filled with "zeros", may be + // larger than is strictly needed to calculate output. This is because the + // conv calculation is performed across complete micro blocks. + for (int j_width = 0; j_width < input_width_micro_repeats; ++j_width) { + int i_depth = 0; + for (; i_depth < depth_micro_repeats - 1; i_depth += 2) { + int8x16x4_t input_data; + input_data.val[0] = vld1q_u8(input_data_0); + input_data.val[1] = vld1q_u8(input_data_1); + input_data.val[2] = vld1q_u8(input_data_2); + input_data.val[3] = vld1q_u8(input_data_3); + input_data_1 += 16; + input_data_0 += 16; + + int8x16_t tmp_0 = vqtbl4q_s8(input_data, perm_data_0); + int8x16_t tmp_1 = vqtbl4q_s8(input_data, perm_data_1); + work_reg_a = veorq_s8(tmp_0, sign_bit); + work_reg_b = veorq_s8(tmp_1, sign_bit); + + vst1q_s8(scratch_data_0, work_reg_a); + vst1q_s8(scratch_data_0 + 16, work_reg_b); + + scratch_data_0 += depth_advance; + input_data_2 += 16; + input_data_3 += 16; + + tmp_0 = vqtbl4q_s8(input_data, perm_data_2); + tmp_1 = vqtbl4q_s8(input_data, perm_data_3); + work_reg_a = veorq_s8(tmp_0, sign_bit); + work_reg_b = veorq_s8(tmp_1, sign_bit); + + vst1q_s8(scratch_data_0, work_reg_a); + vst1q_s8(scratch_data_0 + 16, work_reg_b); + + scratch_data_0 += depth_advance; + } + for (; i_depth < depth_micro_repeats; ++i_depth) { + int8x16x4_t input_data; + input_data.val[0] = + vld1q_lane_s8x8(input_data_0, input_data.val[0], 0); + input_data.val[1] = + vld1q_lane_s8x8(input_data_1, input_data.val[1], 0); + input_data.val[2] = + vld1q_lane_s8x8(input_data_2, input_data.val[2], 0); + input_data.val[3] = + vld1q_lane_s8x8(input_data_3, input_data.val[3], 0); + input_data_1 += 8; + input_data_0 += 8; + + int8x16_t tmp_0 = vqtbl4q_s8(input_data, perm_data_0); + int8x16_t tmp_1 = vqtbl4q_s8(input_data, perm_data_1); + work_reg_a = veorq_s8(tmp_0, sign_bit); + work_reg_b = veorq_s8(tmp_1, sign_bit); + + vst1q_s8(scratch_data_0, work_reg_a); + vst1q_s8(scratch_data_0 + 16, work_reg_b); + + scratch_data_0 += depth_advance; + input_data_2 += 8; + input_data_3 += 8; + } + scratch_data_0 += width_advance; + input_data_0 += input_depth_skip; + input_data_1 += input_depth_skip; + input_data_2 += input_depth_skip; + input_data_3 += input_depth_skip; + } + if (width_overall_micro_repeats > input_width_micro_repeats) { + TFLITE_DCHECK_EQ(width_overall_micro_repeats, + input_width_micro_repeats + 1); + TFLITE_DCHECK_GT(residual_width, 0); + TFLITE_DCHECK_LT(residual_width, 4); + for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + work_reg_a = vdupq_n_u8(kSignBit); + work_reg_a = vld1q_lane_s8x8(input_data_0, work_reg_a, 0); + work_reg_b = vdupq_n_u8(kSignBit); + if (residual_width > 1) { + work_reg_b = + vld1q_lane_s8x8(input_data_0 + input_depth, work_reg_b, 0); + if (residual_width == 3) { + work_reg_a = vld1q_lane_s8x8(input_data_0 + 2 * input_depth, + work_reg_a, 1); + } + } + work_reg_a = veorq_s8(work_reg_a, sign_bit); + work_reg_b = 
veorq_s8(work_reg_b, sign_bit); + + vzipq_s8_in_place(&work_reg_a, &work_reg_b); + vzipq_s8x2_in_place(&work_reg_a, &work_reg_b); + + vst1q_s8(scratch_data_0, work_reg_a); + vst1q_s8(scratch_data_0 + 16, work_reg_b); + + scratch_data_0 += depth_advance; + input_data_0 += 8; + input_data_1 += 8; + input_data_2 += 8; + input_data_3 += 8; + } + scratch_data_0 += width_advance; + input_data_0 += input_depth_skip; + input_data_1 += input_depth_skip; + input_data_2 += input_depth_skip; + input_data_3 += input_depth_skip; + } + scratch_data_0 += height_advance; + input_block_data += input_height_stride; + } + TFLITE_DCHECK_EQ( + scratch_data_0, + scratch_block_data + block_height * workspace_height_stride); + } + + static inline void Run(int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, + int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + PreloadInputBlock(input_block_data, function_params); + PackMacroBlockNeon(input_block_data, scratch_block_data, function_params); + } +}; + +template <> +struct PackMacroBlock { + static inline void PackMacroBlockNeon( + int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + constexpr uint8 kSignBit = 0x80; + + const int workspace_height_stride = + function_params->workspace_height_stride; + const int width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int input_width_micro_repeats = + function_params->input_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int block_height = function_params->inbound_block_height; + const int residual_width = function_params->residual_width; + const int input_height_stride = function_params->input_height_stride; + const int input_depth = function_params->input_depth; + + const int padding_left = function_params->padding_left; + const int padding_right = function_params->padding_right; + const int padding_top = function_params->padding_top; + const int padding_bottom = function_params->padding_bottom; + + TFLITE_DCHECK_GT(depth_micro_repeats, 0); + constexpr int kSymmetricZeroPoint = 128; + + const int micro_block_size = 4 * 8; + const int depth_advance = width_overall_micro_repeats * micro_block_size; + const int width_advance = + micro_block_size * + (1 - depth_micro_repeats * width_overall_micro_repeats); + const int height_advance = workspace_height_stride - + width_overall_micro_repeats * micro_block_size; + const int input_depth_skip = 4 * input_depth - 8 * depth_micro_repeats; + + const bool leading_width_padding = + padding_left > 0 && width_block_number == 0; + const bool trailing_width_padding = + padding_right > 0 && + width_block_number == (function_params->width_macro_count - 1); + const bool leading_height_padding = + padding_top > 0 && height_block_number < 0; + const bool trailing_height_padding = + padding_bottom > 0 && + height_block_number == (function_params->height_macro_count - 1); + + const int32 input_offset = function_params->input_offset; + const int32 input_offset_difference = input_offset + kSymmetricZeroPoint; + + // Transpositions are 4x4, but doing 2 at a time is more efficient in NEON + // code. Note the blocks of 4x4 are still interleaved down the depth. + int8x16_t work_reg_a; + int8x16_t work_reg_b; + + // Effect subtraction of zero-point = 128 by XOR of sign bit. 
+ const uint8x16_t sign_bit = vdupq_n_u8(kSignBit); + + // Work through one slice, by row, at a time. + int8* scratch_data_0 = scratch_block_data; + + int copy_block_height = block_height; + if (leading_height_padding) { + copy_block_height -= 1; + memset(scratch_data_0, -input_offset_difference, workspace_height_stride); + scratch_data_0 += workspace_height_stride; + input_block_data += input_height_stride; + } + if (trailing_height_padding) { + copy_block_height -= 1; + } + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + const uint8* input_data_0 = input_block_data; + const uint8* input_data_1 = input_block_data + input_depth; + const uint8* input_data_2 = input_block_data + 2 * input_depth; + const uint8* input_data_3 = input_block_data + 3 * input_depth; + + // Traverse the width one point at a time, but the depth in (micro) blocks + // of size 8. + // + // The depth and width margins, which are filled with "zeros", may be + // larger than is strictly needed to calculate output. This is because the + // conv calculation is performed across complete micro blocks. + for (int j_width = 0; j_width < width_overall_micro_repeats; ++j_width) { + // Figure out division of work (available input vs zero-ed). + int adjusted_residual_width = + j_width == (input_width_micro_repeats) ? residual_width : 4; + + if (trailing_width_padding && + j_width == (width_overall_micro_repeats - 1)) { + adjusted_residual_width -= 1; + } + int start_width = 0; + if (leading_width_padding && j_width == 0) { + start_width = 1; + } + if (start_width == 0) { + if (adjusted_residual_width == 4) { + // Load, then zero. + for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + work_reg_a = vld1q_lane_s8x8(input_data_2, work_reg_a, 1); + work_reg_b = vld1q_lane_s8x8(input_data_3, work_reg_b, 1); + work_reg_b = vld1q_lane_s8x8(input_data_1, work_reg_b, 0); + input_data_1 += 8; + work_reg_a = vld1q_lane_s8x8(input_data_0, work_reg_a, 0); + input_data_0 += 8; + work_reg_a = veorq_s8(work_reg_a, sign_bit); + work_reg_b = veorq_s8(work_reg_b, sign_bit); + + vzipq_s8_in_place(&work_reg_a, &work_reg_b); + vzipq_s8x2_in_place(&work_reg_a, &work_reg_b); + + vst1q_s8(scratch_data_0, work_reg_a); + scratch_data_0 += 16; + vst1q_s8(scratch_data_0, work_reg_b); + + scratch_data_0 += depth_advance - 16; + input_data_2 += 8; + input_data_3 += 8; + } + scratch_data_0 += width_advance; + input_data_0 += input_depth_skip; + input_data_1 += input_depth_skip; + input_data_2 += input_depth_skip; + input_data_3 += input_depth_skip; + } else { + TFLITE_DCHECK_LT(adjusted_residual_width, 4); + for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + work_reg_a = vdupq_n_u8(-input_offset); + work_reg_b = vdupq_n_u8(-input_offset); + if (adjusted_residual_width > 0) { + work_reg_a = vld1q_lane_s8x8(input_data_0, work_reg_a, 0); + if (adjusted_residual_width > 1) { + work_reg_b = vld1q_lane_s8x8(input_data_0 + input_depth, + work_reg_b, 0); + if (adjusted_residual_width == 3) { + work_reg_a = vld1q_lane_s8x8(input_data_0 + 2 * input_depth, + work_reg_a, 1); + } + } + } + work_reg_a = veorq_s8(work_reg_a, sign_bit); + work_reg_b = veorq_s8(work_reg_b, sign_bit); + + vzipq_s8_in_place(&work_reg_a, &work_reg_b); + vzipq_s8x2_in_place(&work_reg_a, &work_reg_b); + + vst1q_s8(scratch_data_0, work_reg_a); + vst1q_s8(scratch_data_0 + 16, work_reg_b); + + scratch_data_0 += depth_advance; + input_data_0 += 8; + input_data_1 += 8; + input_data_2 += 8; + input_data_3 += 8; + } + scratch_data_0 += width_advance; + 
input_data_0 += input_depth_skip; + input_data_1 += input_depth_skip; + input_data_2 += input_depth_skip; + input_data_3 += input_depth_skip; + } + } else { + if (adjusted_residual_width == 4) { + // Load, then zero. + for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + work_reg_a = vdupq_n_u8(-input_offset); + work_reg_a = vld1q_lane_s8x8(input_data_2, work_reg_a, 1); + work_reg_b = vld1q_lane_s8x8(input_data_3, work_reg_b, 1); + work_reg_b = vld1q_lane_s8x8(input_data_1, work_reg_b, 0); + input_data_1 += 8; + // Skip loading first column. + input_data_0 += 8; + work_reg_a = veorq_s8(work_reg_a, sign_bit); + work_reg_b = veorq_s8(work_reg_b, sign_bit); + + vzipq_s8_in_place(&work_reg_a, &work_reg_b); + vzipq_s8x2_in_place(&work_reg_a, &work_reg_b); + + vst1q_s8(scratch_data_0, work_reg_a); + scratch_data_0 += 16; + vst1q_s8(scratch_data_0, work_reg_b); + + scratch_data_0 += depth_advance - 16; + input_data_2 += 8; + input_data_3 += 8; + } + scratch_data_0 += width_advance; + input_data_0 += input_depth_skip; + input_data_1 += input_depth_skip; + input_data_2 += input_depth_skip; + input_data_3 += input_depth_skip; + } else { + TFLITE_DCHECK_LT(adjusted_residual_width, 4); + for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + work_reg_a = vdupq_n_u8(-input_offset); + // Skip loading first column. + work_reg_b = vdupq_n_u8(-input_offset); + if (adjusted_residual_width > 1) { + work_reg_b = + vld1q_lane_s8x8(input_data_0 + input_depth, work_reg_b, 0); + if (adjusted_residual_width == 3) { + work_reg_a = vld1q_lane_s8x8(input_data_0 + 2 * input_depth, + work_reg_a, 1); + } + } + work_reg_a = veorq_s8(work_reg_a, sign_bit); + work_reg_b = veorq_s8(work_reg_b, sign_bit); + + vzipq_s8_in_place(&work_reg_a, &work_reg_b); + vzipq_s8x2_in_place(&work_reg_a, &work_reg_b); + + vst1q_s8(scratch_data_0, work_reg_a); + vst1q_s8(scratch_data_0 + 16, work_reg_b); + + scratch_data_0 += depth_advance; + input_data_0 += 8; + input_data_1 += 8; + input_data_2 += 8; + input_data_3 += 8; + } + scratch_data_0 += width_advance; + input_data_0 += input_depth_skip; + input_data_1 += input_depth_skip; + input_data_2 += input_depth_skip; + input_data_3 += input_depth_skip; + } + } + } + scratch_data_0 += height_advance; + input_block_data += input_height_stride; + } + + if (trailing_height_padding) { + memset(scratch_data_0, -input_offset_difference, workspace_height_stride); + scratch_data_0 += workspace_height_stride; + } + + TFLITE_DCHECK_EQ( + scratch_data_0, + scratch_block_data + block_height * workspace_height_stride); + } + + static inline void Run(int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, + int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + PreloadInputBlock(input_block_data, function_params); + PackMacroBlockNeon(height_block_number, width_block_number, + input_block_data, scratch_block_data, function_params); + } +}; + +template <> +struct PackMacroBlock { + static inline void PackMacroBlockNeon( + int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int input_width_micro_repeats = + function_params->input_width_micro_repeats; + const int block_height = function_params->inbound_block_height; + const int 
residual_width = function_params->residual_width; + const int input_height_stride = function_params->input_height_stride; + + const int padding_left = function_params->padding_left; + const int padding_right = function_params->padding_right; + const int padding_top = function_params->padding_top; + const int padding_bottom = function_params->padding_bottom; + + constexpr int kSymmetricZeroPoint = 128; + + TFLITE_DCHECK_GE(workspace_height_stride, 4 * width_overall_micro_repeats); + + const bool leading_width_padding = + padding_left > 0 && width_block_number == 0; + const bool trailing_width_padding = + padding_right > 0 && + width_block_number == (function_params->width_macro_count - 1); + const bool leading_height_padding = + padding_top > 0 && height_block_number < 0; + const bool trailing_height_padding = + padding_bottom > 0 && + height_block_number == (function_params->height_macro_count - 1); + + const int32 input_offset = function_params->input_offset; + const int32 input_offset_difference = input_offset + kSymmetricZeroPoint; + + // Work through one slice, by row, at a time. + int8* scratch_data_base = scratch_block_data; + + int copy_block_height = block_height; + if (leading_height_padding) { + copy_block_height -= 1; + memset(scratch_data_base, -input_offset_difference, + workspace_height_stride + kWorkspaceExtension); + scratch_data_base += workspace_height_stride; + input_block_data += input_height_stride; + } + if (trailing_height_padding) { + copy_block_height -= 1; + } + + int adjusted_residual_width = + input_width_micro_repeats < width_overall_micro_repeats ? residual_width + : 4; + + if (trailing_width_padding) { + adjusted_residual_width -= 1; + } + int start_width = 0; + if (leading_width_padding) { + start_width = 1; + input_block_data += 1; + } + + const int copy_size = (width_overall_micro_repeats - 1) * 4 + + adjusted_residual_width - start_width; + // Adjusted so that later conditionals are simplified. + const int copy_size_adjusted = + trailing_width_padding ? copy_size + 1 : copy_size; + + TFLITE_DCHECK_LE( + copy_size, + input_height_stride - width_block_number * input_width_micro_repeats); + // We may drop up to stride-1 of trailing input. + TFLITE_DCHECK_GE(copy_size, input_height_stride - 1); + + int scratch_data_offset = 0; + int input_block_offset = 0; + + constexpr uint8 kSignBit = 0x80; + + // Transpositions are 4x4, but doing 2 at a time is more efficient in NEON + // code. Note the blocks of 4x4 are still interleaved down the depth. + int8x16_t work_reg; + int8x8_t half_work_reg; + int8x8_t padding_mask; + + // Effect subtraction of zero-point = 128 by XOR of sign bit. + const uint8x16_t sign_bit = vdupq_n_u8(kSignBit); + const uint8x16_t padding_reg = vdupq_n_u8(-input_offset); + padding_mask = vdup_n_s8(-1); + half_work_reg = vdup_n_s8(0); + + if (copy_size >= 16) { + const int copy_remaining = (copy_size + start_width) & 0x7; + padding_mask = vshl_u64(padding_mask, vdup_n_s64(8 * copy_remaining)); + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + // Work through one slice, by row, at a time. + int8* scratch_data = scratch_data_base + scratch_data_offset; + + int copy_done = 0; + + // The surrounding condition ensures that we always need at least one + // iteration of the main copy loop. In the case of leading width + // padding, we unroll this specially. 
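+        // vextq_s8 below shifts the 16 loaded bytes up by one lane and places
+        // the padding value (-input_offset) in lane 0, so the left-padding
+        // column and the first 15 input bytes are written in a single store.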
+ if (leading_width_padding) { + work_reg = vld1q_u8(input_block_data + input_block_offset); + work_reg = vextq_s8(padding_reg, work_reg, 15); + work_reg = veorq_s8(work_reg, sign_bit); + vst1q_s8(scratch_data, work_reg); + copy_done += 15; + } + + // Main copy loop. + for (; (copy_done + 16) <= copy_size; copy_done += 16) { + work_reg = + vld1q_u8(input_block_data + input_block_offset + copy_done); + work_reg = veorq_s8(work_reg, sign_bit); + TFLITE_DCHECK_EQ((start_width + copy_done) % 16, 0); + vst1q_s8(scratch_data + start_width + copy_done, work_reg); + } + + if (copy_done + 8 <= copy_size) { + half_work_reg = + vld1_u8(input_block_data + input_block_offset + copy_done); + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ((start_width + copy_done) % 8, 0); + vst1_s8(scratch_data + start_width + copy_done, half_work_reg); + copy_done += 8; + } + + TFLITE_DCHECK_EQ(copy_remaining, copy_size - copy_done); + // Total amount + // = copy_size - copy_done + 4 - adjusted_residual_width + // = width_overall_micro_repeats * 4 - start_width - copy_done. + // Undone micro blocks + // = width_overall_micro_repeats - (start_width + copy_done) / 4. + + // Conditional is (copy_remaining > 0 || trailing_width_padding). + if (copy_done < copy_size_adjusted) { + // Employ overlapping-load strategy in order to load full register, + // but use only part. + // This has the advantage of resulting in zeros after shifting. + half_work_reg = + vld1_u8(input_block_data + input_block_offset + copy_size - 8); + + half_work_reg = + vshl_u64(half_work_reg, vdup_n_s64(-8 * (8 - copy_remaining))); + half_work_reg = + vbsl_s8(padding_mask, vget_low_s8(padding_reg), half_work_reg); + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ((start_width + copy_done) % 8, 0); + vst1_s8(scratch_data + start_width + copy_done, half_work_reg); + } + + // Trailing guard. + vst1_s8(scratch_data + start_width + copy_done, half_work_reg); + vst1_s8(scratch_data + start_width + copy_done + 8, half_work_reg); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } else if (copy_size >= 4) { + const int copy_remaining = (copy_size + start_width) & 0x3; + padding_mask = vshl_u64(padding_mask, vdup_n_s64(8 * copy_remaining)); + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + // Work through one slice, by row, at a time. + int8* scratch_data = scratch_data_base + scratch_data_offset; + + int copy_done = 0; + + // The surrounding condition ensures that we always need at least one + // iteration of the main copy loop. In the case of leading width + // padding, we unroll this specially. + if (leading_width_padding) { + half_work_reg = vld1_lane_8x4(input_block_data + input_block_offset, + half_work_reg, 0); + half_work_reg = vext_s8(vget_low_s8(padding_reg), half_work_reg, 7); + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + vst1_lane_8x4(scratch_data, half_work_reg, 0); + copy_done += 3; + } + + // Main copy loop. + for (; (copy_done + 4) <= copy_size; copy_done += 4) { + // Important! Most compilation configurations will compile and run + // without the reinterpret_cast. Sanitizers may fail silently on + // lane-loading, with a obscure bug or mis-feature probably in + // unhygienic macro expansion. 
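+          // vld1_lane_8x4 brings four consecutive uint8 values in as a single
+          // 32-bit lane; the sign-bit XOR then re-centers all four bytes at
+          // once before the matching 32-bit lane store.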
+ half_work_reg = + vld1_lane_8x4(input_block_data + input_block_offset + copy_done, + half_work_reg, 0); + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ((start_width + copy_done) % 4, 0); + vst1_lane_8x4(scratch_data + start_width + copy_done, half_work_reg, + 0); + } + + TFLITE_DCHECK_EQ(copy_remaining, copy_size - copy_done); + // Total amount + // = copy_size - copy_done + 4 - adjusted_residual_width + // = width_overall_micro_repeats * 4 - start_width - copy_done. + // Undone micro blocks + // = width_overall_micro_repeats - (start_width + copy_done) / 4. + + // Conditional is (copy_remaining > 0 || trailing_width_padding). + if (copy_done < copy_size_adjusted) { + TFLITE_DCHECK_LT(copy_remaining, 4); + // Employ overlapping-load strategy in order to load full register, + // but use only part. + // This has the advantage of resulting in zeros after shifting. + half_work_reg = vld1_lane_8x4( + input_block_data + input_block_offset + copy_size - 4, + half_work_reg, 0); + + half_work_reg = + vshl_u64(half_work_reg, vdup_n_s64(-8 * (4 - copy_remaining))); + half_work_reg = + vbsl_s8(padding_mask, vget_low_s8(padding_reg), half_work_reg); + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ((start_width + copy_done) % 4, 0); + vst1_lane_8x4(scratch_data + start_width + copy_done, half_work_reg, + 0); + copy_done += 4; + } + // Trailing guard. + vst1_lane_8x4(scratch_data + start_width + copy_done, half_work_reg, 0); + vst1_lane_8x4(scratch_data + start_width + copy_done + 4, half_work_reg, + 0); + vst1_lane_8x4(scratch_data + start_width + copy_done + 8, half_work_reg, + 0); + vst1_lane_8x4(scratch_data + start_width + copy_done + 12, + half_work_reg, 0); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } else if (width_overall_micro_repeats == 2) { + // Special case of 1 + 3 + 1, padding + copy + padding. + // This is rarely executed in practice. + TFLITE_DCHECK_EQ(copy_size, 3); + TFLITE_DCHECK_EQ(start_width, 1); + TFLITE_DCHECK(leading_width_padding); + TFLITE_DCHECK(trailing_width_padding); + // ASM should use MOVI 64-bit set. + padding_mask = vcreate_u64(~0xffffff00L); + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + half_work_reg = vld1_lane_s8(reinterpret_cast( + input_block_data + input_block_offset), + half_work_reg, 1); + half_work_reg = + vld1_lane_s8(reinterpret_cast(input_block_data + + input_block_offset + 1), + half_work_reg, 2); + half_work_reg = + vld1_lane_s8(reinterpret_cast(input_block_data + + input_block_offset + 2), + half_work_reg, 3); + half_work_reg = + vbsl_s8(padding_mask, vget_low_s8(padding_reg), half_work_reg); + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(scratch_data_offset % 8, 0); + vst1_s8(scratch_data_base + scratch_data_offset, half_work_reg); + + // Trailing guard. 
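+        // The guard stores pad the workspace row (see kWorkspaceExtension) so
+        // that the consuming kernel can always load whole registers from
+        // initialized memory; the extra lanes lie beyond the valid width.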
+ vst1_lane_8x4(scratch_data_base + scratch_data_offset + 4, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 8, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 12, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 16, + half_work_reg, 0); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } else { + TFLITE_DCHECK_EQ(width_overall_micro_repeats, 1); + const int copy_remaining = (copy_size + start_width) & 0x3; + padding_mask = vshl_u64(padding_mask, vdup_n_s64(8 * copy_remaining)); + if (leading_width_padding) { + padding_mask = vset_lane_u8(255, padding_mask, 0); + } + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + for (int i = 0; i < copy_size; ++i) { + half_work_reg = vshl_n_u64(half_work_reg, 8); + half_work_reg = vld1_lane_s8( + reinterpret_cast( + input_block_data + input_block_offset + copy_size - 1 - i), + half_work_reg, 0); + } + if (leading_width_padding) { + half_work_reg = vshl_n_s64(half_work_reg, 8); + } + half_work_reg = + vbsl_s8(padding_mask, vget_low_s8(padding_reg), half_work_reg); + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(scratch_data_offset % 4, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset, half_work_reg, + 0); + + // Trailing guard. + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 4, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 8, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 12, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 16, + half_work_reg, 0); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } + + scratch_data_base += copy_block_height * workspace_height_stride; + + if (trailing_height_padding) { + memset(scratch_data_base, -input_offset_difference, + workspace_height_stride + kWorkspaceExtension); + scratch_data_base += workspace_height_stride; + } + + TFLITE_DCHECK_EQ( + scratch_data_base, + scratch_block_data + block_height * workspace_height_stride); + } + + static inline void Run(int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, + int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + PreloadInputBlock(input_block_data, function_params); + PackMacroBlockNeon(height_block_number, width_block_number, + input_block_data, scratch_block_data, function_params); + } +}; + +template <> +struct PackMacroBlock { + static inline void PackMacroBlockNeon( + int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int input_width_micro_repeats = + function_params->input_width_micro_repeats; + const int block_height = function_params->inbound_block_height; + const int residual_width = function_params->residual_width; + const int input_height_stride = function_params->input_height_stride; + + TFLITE_DCHECK_EQ(function_params->padding_left, 0); + TFLITE_DCHECK_EQ(function_params->padding_right, 0); + TFLITE_DCHECK_EQ(function_params->padding_top, 0); + TFLITE_DCHECK_EQ(function_params->padding_bottom, 0); + + 
TFLITE_DCHECK_GE(workspace_height_stride, 4 * width_overall_micro_repeats); + + // Work through one slice, by row, at a time. + int8* scratch_data_base = scratch_block_data; + + const int copy_block_height = block_height; + + int adjusted_residual_width = + input_width_micro_repeats < width_overall_micro_repeats ? residual_width + : 4; + + const int copy_size = + (width_overall_micro_repeats - 1) * 4 + adjusted_residual_width; + + TFLITE_DCHECK_LE( + copy_size, + input_height_stride - width_block_number * input_width_micro_repeats); + // We may drop up to stride-1 of trailing input. + TFLITE_DCHECK_GE(copy_size, input_height_stride - 1); + + int scratch_data_offset = 0; + int input_block_offset = 0; + + constexpr uint8 kSignBit = 0x80; + + // Transpositions are 4x4, but doing 2 at a time is more efficient in NEON + // code. Note the blocks of 4x4 are still interleaved down the depth. + int8x16_t work_reg; + int8x8_t half_work_reg; + + // Effect subtraction of zero-point = 128 by XOR of sign bit. + const uint8x16_t sign_bit = vdupq_n_u8(kSignBit); + half_work_reg = vdup_n_s8(0); + + if (copy_size >= 16) { + const int copy_remaining = copy_size & 0x7; + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + // Work through one slice, by row, at a time. + int8* scratch_data = scratch_data_base + scratch_data_offset; + + int copy_done = 0; + + // Main copy loop. + for (; (copy_done + 16) <= copy_size; copy_done += 16) { + work_reg = + vld1q_u8(input_block_data + input_block_offset + copy_done); + work_reg = veorq_s8(work_reg, sign_bit); + TFLITE_DCHECK_EQ(copy_done % 16, 0); + vst1q_s8(scratch_data + copy_done, work_reg); + } + + if (copy_done + 8 <= copy_size) { + half_work_reg = + vld1_u8(input_block_data + input_block_offset + copy_done); + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(copy_done % 8, 0); + vst1_s8(scratch_data + copy_done, half_work_reg); + copy_done += 8; + } + + TFLITE_DCHECK_EQ(copy_remaining, copy_size - copy_done); + // Total amount + // = copy_size - copy_done + 4 - adjusted_residual_width + // = width_overall_micro_repeats * 4 - start_width - copy_done. + // Undone micro blocks + // = width_overall_micro_repeats - (start_width + copy_done) / 4. + + // Conditional is (copy_remaining > 0 || trailing_width_padding). + if (copy_done < copy_size) { + // Employ overlapping-load strategy in order to load full register, + // but use only part. + // This has the advantage of resulting in zeros after shifting. + half_work_reg = + vld1_u8(input_block_data + input_block_offset + copy_size - 8); + + half_work_reg = + vshl_u64(half_work_reg, vdup_n_s64(-8 * (8 - copy_remaining))); + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(copy_done % 8, 0); + vst1_s8(scratch_data + copy_done, half_work_reg); + copy_done += 8; + } + + // Trailing guard. + vst1_s8(scratch_data + copy_done, half_work_reg); + vst1_s8(scratch_data + copy_done + 8, half_work_reg); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } else if (copy_size >= 4) { + const int copy_remaining = copy_size & 0x3; + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + // Work through one slice, by row, at a time. + int8* scratch_data = scratch_data_base + scratch_data_offset; + + int copy_done = 0; + + // Main copy loop. + for (; (copy_done + 4) <= copy_size; copy_done += 4) { + // Important! 
Most compilation configurations will compile and run + // without the reinterpret_cast. Sanitizers may fail silently on + // lane-loading, with a obscure bug or mis-feature probably in + // unhygienic macro expansion. + half_work_reg = + vld1_lane_8x4(input_block_data + input_block_offset + copy_done, + half_work_reg, 0); + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(copy_done % 4, 0); + vst1_lane_8x4(scratch_data + copy_done, half_work_reg, 0); + } + + TFLITE_DCHECK_EQ(copy_remaining, copy_size - copy_done); + // Total amount + // = copy_size - copy_done + 4 - adjusted_residual_width + // = width_overall_micro_repeats * 4 - start_width - copy_done. + // Undone micro blocks + // = width_overall_micro_repeats - (start_width + copy_done) / 4. + + // Conditional is (copy_remaining > 0 || trailing_width_padding). + if (copy_done < copy_size) { + TFLITE_DCHECK_LT(copy_remaining, 4); + // Employ overlapping-load strategy in order to load full register, + // but use only part. + // This has the advantage of resulting in zeros after shifting. + half_work_reg = vld1_lane_8x4( + input_block_data + input_block_offset + copy_size - 4, + half_work_reg, 0); + + half_work_reg = + vshl_u64(half_work_reg, vdup_n_s64(-8 * (4 - copy_remaining))); + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(copy_done % 4, 0); + vst1_lane_8x4(scratch_data + copy_done, half_work_reg, 0); + copy_done += 4; + } + // Trailing guard. + vst1_lane_8x4(scratch_data + copy_done, half_work_reg, 0); + vst1_lane_8x4(scratch_data + copy_done + 4, half_work_reg, 0); + vst1_lane_8x4(scratch_data + copy_done + 8, half_work_reg, 0); + vst1_lane_8x4(scratch_data + copy_done + 12, half_work_reg, 0); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } else { + TFLITE_DCHECK_EQ(width_overall_micro_repeats, 1); + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + for (int i = 0; i < copy_size; ++i) { + half_work_reg = vshl_n_u64(half_work_reg, 8); + half_work_reg = vld1_lane_s8( + reinterpret_cast( + input_block_data + input_block_offset + copy_size - 1 - i), + half_work_reg, 0); + } + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(scratch_data_offset % 4, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset, half_work_reg, + 0); + + // Trailing guard. 
+ vst1_lane_8x4(scratch_data_base + scratch_data_offset + 4, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 8, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 12, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 16, + half_work_reg, 0); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } + + scratch_data_base += copy_block_height * workspace_height_stride; + + TFLITE_DCHECK_EQ( + scratch_data_base, + scratch_block_data + block_height * workspace_height_stride); + } + + static inline void Run(int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, + int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + PreloadInputBlock(input_block_data, function_params); + PackMacroBlockNeon(height_block_number, width_block_number, + input_block_data, scratch_block_data, function_params); + } +}; + +template <> +struct KernelMacroBlock { + static inline void KernelMacroBlockNeon( + const int8* scratch_block_data, const int8* filter_workspace, + const int32* bias_data, uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int input_width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int output_width_micro_repeats = + function_params->output_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int depth = function_params->input_depth; + + const int output_width_overall_micro_repeats = + function_params->output_width_overall_micro_repeats; + const int block_height = function_params->outbound_block_height; + const int residual_width = function_params->output_residual_width; + const int output_height_stride = function_params->output_height_stride; + const int bias_increment = function_params->bias_increment; + + TFLITE_DCHECK(depth_micro_repeats > 0); + const int width_micro_stride = 4 * 8; + const int depth_micro_stride = + width_micro_stride * input_width_overall_micro_repeats; + + const int32 output_activation_min = + function_params->quantized_activation_min; + const int32 output_activation_max = + function_params->quantized_activation_max; + const int32 output_multiplier = function_params->output_multiplier; + const int32 output_shift = function_params->output_shift; + const int32 output_offset = function_params->output_offset; + TFLITE_DCHECK_GE(output_activation_min, 0); + TFLITE_DCHECK_LT(output_activation_min, 256); + TFLITE_DCHECK_GE(output_activation_max, 0); + TFLITE_DCHECK_LT(output_activation_max, 256); + TFLITE_DCHECK_GE(output_offset, -32878); + TFLITE_DCHECK_LT(output_offset, 32768); + + const int16x8_t output_offset_vec = + vdupq_n_s16(static_cast(output_offset)); + const uint8x16_t output_activation_min_vec = + vdupq_n_u8(static_cast(output_activation_min)); + const uint8x16_t output_activation_max_vec = + vdupq_n_u8(static_cast(output_activation_max)); + + const int8* input_data_depthwise = scratch_block_data; + uint8* output_data_depthwise = output_block_data; + for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { + // Simulate NEON-register transposition of subset of filter. 
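+      // Load the six pre-shuffled filter registers produced by
+      // ProcessPerDepth. vshlq_n_u32(..., 8) moves each 32-bit lane up one
+      // byte; the shifted copies feed the second and fourth of the four
+      // accumulation blocks computed per micro block below.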
+ int8x16_t filter_reg_0_a; + int8x16_t filter_reg_0_b; + int8x16_t filter_reg_1_a; + int8x16_t filter_reg_1_b; + int8x16_t filter_reg_2_a; + int8x16_t filter_reg_2_b; + int8x16_t filter_reg_0_a_shifted; + int8x16_t filter_reg_1_a_shifted; + int8x16_t filter_reg_2_a_shifted; + + filter_reg_0_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_0_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_1_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_1_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_2_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_2_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + + filter_reg_0_a_shifted = vshlq_n_u32(filter_reg_0_a, 8); + filter_reg_1_a_shifted = vshlq_n_u32(filter_reg_1_a, 8); + filter_reg_2_a_shifted = vshlq_n_u32(filter_reg_2_a, 8); + + if (block_height == 4) { + for (int s = 0; s < 2; ++s) { + // Work through one slice, by row, at a time. + const int8* input_data_base = input_data_depthwise + 2 * 8 * s; + uint8* output_data_base = output_data_depthwise + 4 * s; + + const int8* next_input_data = input_data_base; + uint8* output_data = output_data_base; + + const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); + TFLITE_DCHECK_EQ(bias_increment, 4); + bias_data += bias_increment; + + // Load first sub-micro block of data into operational banks. + int8x16_t left_bank_0_reg = vld1q_s8(next_input_data); + int8x16_t left_bank_1_reg = + vld1q_s8(next_input_data + workspace_height_stride); + int8x16_t left_bank_2_reg = + vld1q_s8(next_input_data + 2 * workspace_height_stride); + int8x16_t left_bank_3_reg = + vld1q_s8(next_input_data + 3 * workspace_height_stride); + int8x16_t left_bank_4_reg = + vld1q_s8(next_input_data + 4 * workspace_height_stride); + int8x16_t left_bank_5_reg = + vld1q_s8(next_input_data + 5 * workspace_height_stride); + + int32x4_t acc0; + int32x4_t acc1; + int32x4_t acc2; + int32x4_t acc3; + + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_2_reg); + acc2 = vdotq_s32(acc2, filter_reg_0_a, left_bank_2_reg); + acc3 = vdotq_s32(acc3, filter_reg_0_a, left_bank_3_reg); + + for (int i_width = 0; i_width < output_width_micro_repeats; + ++i_width) { + next_input_data += width_micro_stride; + + // Iterate over input width shifts within 4x4 blocks. + { + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_1_a, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_2_a, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_1_a, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_2_a, left_bank_5_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. 
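+            // Saturating-narrow the four int32 accumulators to int16, add the
+            // output zero point, then saturate to uint8 and clamp to the
+            // activation range before the four per-row lane stores.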
+ int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + output_data += depth; + } + + // Load next sub-micro block of data. + int8x16_t right_bank_0_reg; + int8x16_t right_bank_1_reg; + int8x16_t right_bank_2_reg; + int8x16_t right_bank_3_reg; + int8x16_t right_bank_4_reg; + int8x16_t right_bank_5_reg; + // Logic: (i_width == output_width_micro_repeats) && + // ((residual_width - 1) * stride_val < 2) + const bool no_right_block = + i_width == output_width_micro_repeats && residual_width < 3; + + if (no_right_block) { + // Only needed for santizer checks. + right_bank_0_reg = vdupq_n_s8(0); + right_bank_1_reg = vdupq_n_s8(0); + right_bank_2_reg = vdupq_n_s8(0); + right_bank_3_reg = vdupq_n_s8(0); + right_bank_4_reg = vdupq_n_s8(0); + right_bank_5_reg = vdupq_n_s8(0); + } else { + right_bank_0_reg = vld1q_s8(next_input_data); + right_bank_1_reg = + vld1q_s8(next_input_data + workspace_height_stride); + right_bank_2_reg = + vld1q_s8(next_input_data + 2 * workspace_height_stride); + right_bank_3_reg = + vld1q_s8(next_input_data + 3 * workspace_height_stride); + right_bank_4_reg = + vld1q_s8(next_input_data + 4 * workspace_height_stride); + right_bank_5_reg = + vld1q_s8(next_input_data + 5 * workspace_height_stride); + } + + { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a_shifted, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a_shifted, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a_shifted, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a_shifted, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a_shifted, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a_shifted, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_0_a_shifted, left_bank_2_reg); + acc2 = vdotq_s32(acc2, filter_reg_1_a_shifted, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_2_a_shifted, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_0_a_shifted, left_bank_3_reg); + acc3 = vdotq_s32(acc3, filter_reg_1_a_shifted, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_2_a_shifted, left_bank_5_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. 
+ int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + left_bank_0_reg = vrev32q_u16(left_bank_0_reg); + left_bank_1_reg = vrev32q_u16(left_bank_1_reg); + left_bank_2_reg = vrev32q_u16(left_bank_2_reg); + left_bank_3_reg = vrev32q_u16(left_bank_3_reg); + left_bank_4_reg = vrev32q_u16(left_bank_4_reg); + left_bank_5_reg = vrev32q_u16(left_bank_5_reg); + vtrn1_s8x2_in_place(&left_bank_0_reg, &right_bank_0_reg); + vtrn1_s8x2_in_place(&left_bank_1_reg, &right_bank_1_reg); + vtrn1_s8x2_in_place(&left_bank_2_reg, &right_bank_2_reg); + vtrn1_s8x2_in_place(&left_bank_3_reg, &right_bank_3_reg); + vtrn1_s8x2_in_place(&left_bank_4_reg, &right_bank_4_reg); + vtrn1_s8x2_in_place(&left_bank_5_reg, &right_bank_5_reg); + + output_data += depth; + } + + { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_0_a, left_bank_2_reg); + acc2 = vdotq_s32(acc2, filter_reg_1_a, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_2_a, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_0_a, left_bank_3_reg); + acc3 = vdotq_s32(acc3, filter_reg_1_a, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_2_a, left_bank_5_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. 
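+            // The max/min against the quantized activation bounds applies the
+            // fused activation (e.g. ReLU/ReLU6) directly on the uint8 values.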
+ uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + output_data += depth; + } + + { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a_shifted, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a_shifted, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a_shifted, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a_shifted, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a_shifted, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a_shifted, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_0_a_shifted, left_bank_2_reg); + acc2 = vdotq_s32(acc2, filter_reg_1_a_shifted, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_2_a_shifted, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_0_a_shifted, left_bank_3_reg); + acc3 = vdotq_s32(acc3, filter_reg_1_a_shifted, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_2_a_shifted, left_bank_5_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + left_bank_0_reg = right_bank_0_reg; + left_bank_1_reg = right_bank_1_reg; + left_bank_2_reg = right_bank_2_reg; + left_bank_3_reg = right_bank_3_reg; + left_bank_4_reg = right_bank_4_reg; + left_bank_5_reg = right_bank_5_reg; + + output_data += depth; + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_2_reg); + acc2 = vdotq_s32(acc2, filter_reg_0_a, left_bank_2_reg); + acc3 = vdotq_s32(acc3, filter_reg_0_a, left_bank_3_reg); + } + } + + if (residual_width > 0) { + next_input_data += width_micro_stride; + const int output_width = residual_width; + + // Load next sub-micro block of data. 
+ int8x16_t right_bank_0_reg; + int8x16_t right_bank_1_reg; + int8x16_t right_bank_2_reg; + int8x16_t right_bank_3_reg; + int8x16_t right_bank_4_reg; + int8x16_t right_bank_5_reg; + // Logic: (output_width - 1) * stride_val < 2. + const bool no_right_block = output_width < 3; + + if (no_right_block) { + // Only needed for santizer checks. + right_bank_0_reg = vdupq_n_s8(0); + right_bank_1_reg = vdupq_n_s8(0); + right_bank_2_reg = vdupq_n_s8(0); + right_bank_3_reg = vdupq_n_s8(0); + right_bank_4_reg = vdupq_n_s8(0); + right_bank_5_reg = vdupq_n_s8(0); + } else { + right_bank_0_reg = vld1q_s8(next_input_data); + right_bank_1_reg = + vld1q_s8(next_input_data + workspace_height_stride); + right_bank_2_reg = + vld1q_s8(next_input_data + 2 * workspace_height_stride); + right_bank_3_reg = + vld1q_s8(next_input_data + 3 * workspace_height_stride); + right_bank_4_reg = + vld1q_s8(next_input_data + 4 * workspace_height_stride); + right_bank_5_reg = + vld1q_s8(next_input_data + 5 * workspace_height_stride); + } + + // Iterate over input width shifts within 4x4 blocks. + for (int x = 0; x < output_width; ++x) { + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_1_a, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_2_a, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_1_a, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_2_a, left_bank_5_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. 
+ uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + biregister_rotate_8(&left_bank_0_reg, &right_bank_0_reg); + biregister_rotate_8(&left_bank_1_reg, &right_bank_1_reg); + biregister_rotate_8(&left_bank_2_reg, &right_bank_2_reg); + biregister_rotate_8(&left_bank_3_reg, &right_bank_3_reg); + biregister_rotate_8(&left_bank_4_reg, &right_bank_4_reg); + biregister_rotate_8(&left_bank_5_reg, &right_bank_5_reg); + + output_data += depth; + + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_2_reg); + acc2 = vdotq_s32(acc2, filter_reg_0_a, left_bank_2_reg); + acc3 = vdotq_s32(acc3, filter_reg_0_a, left_bank_3_reg); + } + } + input_data_base += 4 * workspace_height_stride; + output_data_base += 4 * output_height_stride; + + // Move to next sub-block: advance to second set of filters, to new + // bias. + filter_reg_0_a = filter_reg_0_b; + filter_reg_1_a = filter_reg_1_b; + filter_reg_2_a = filter_reg_2_b; + filter_reg_0_a_shifted = vshlq_n_u32(filter_reg_0_a, 8); + filter_reg_1_a_shifted = vshlq_n_u32(filter_reg_1_a, 8); + filter_reg_2_a_shifted = vshlq_n_u32(filter_reg_2_a, 8); + } + } else { + for (int s = 0; s < 2; ++s) { + // Work through one slice, by row, at a time. + const int8* input_data_base = input_data_depthwise + 2 * 8 * s; + uint8* output_data_base = output_data_depthwise + 4 * s; + + const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); + TFLITE_DCHECK_EQ(bias_increment, 4); + bias_data += bias_increment; + + for (int k_height = 0; k_height < block_height; ++k_height) { + const int8* next_input_data = input_data_base; + uint8* output_data = output_data_base; + + // Load first sub-micro block of data into operational banks. + int8x16_t left_bank_0_reg = vld1q_s8(next_input_data); + int8x16_t left_bank_1_reg = + vld1q_s8(next_input_data + workspace_height_stride); + int8x16_t left_bank_2_reg = + vld1q_s8(next_input_data + 2 * workspace_height_stride); + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + next_input_data += width_micro_stride; + const int output_width = + i_width == output_width_micro_repeats ? residual_width : 4; + + // Load next sub-micro block of data. + int8x16_t right_bank_0_reg; + int8x16_t right_bank_1_reg; + int8x16_t right_bank_2_reg; + // Logic: (output_width - 1) * stride_val < 2. + const bool no_right_block = output_width < 3; + + if (no_right_block) { + // Only needed for santizer checks. + right_bank_0_reg = vdupq_n_s8(0); + right_bank_1_reg = vdupq_n_s8(0); + right_bank_2_reg = vdupq_n_s8(0); + } else { + right_bank_0_reg = vld1q_s8(next_input_data); + right_bank_1_reg = + vld1q_s8(next_input_data + workspace_height_stride); + right_bank_2_reg = + vld1q_s8(next_input_data + 2 * workspace_height_stride); + } + // Load next sub-micro block of data. + + // Iterate over input width shifts within 4x4 blocks. 
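+                // Height-remainder path: a single accumulator per output
+                // column, producing one output row per k_height iteration.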
+ for (int x = 0; x < output_width; ++x) { + int32x4_t acc = adjusted_bias_data; + acc = vdotq_s32(acc, filter_reg_0_a, left_bank_0_reg); + acc = vdotq_s32(acc, filter_reg_1_a, left_bank_1_reg); + acc = vdotq_s32(acc, filter_reg_2_a, left_bank_2_reg); + + // Fixed-point multiplication. + acc = vqrdmulhq_n_s32(acc, output_multiplier); + acc = DivideByPOT::Run( + acc, -output_shift); + // Add the output offset. + // Note that we need to fill the top half with vcombine, but can + // drop the instruction in ASM code. + int16x8_t acc_s16_0_0 = + vcombine_s16(vqmovn_s32(acc), vqmovn_s32(acc)); + acc_s16_0_0 = vqaddq_s16(acc_s16_0_0, output_offset_vec); + // Apply the activation function. + uint8x8_t acc_u8_0_0 = vqmovun_s16(acc_s16_0_0); + acc_u8_0_0 = + vmax_u8(acc_u8_0_0, vget_low_u8(output_activation_min_vec)); + acc_u8_0_0 = + vmin_u8(acc_u8_0_0, vget_low_u8(output_activation_max_vec)); + + vst1_lane_8x4(output_data, acc_u8_0_0, 0); + + biregister_rotate_8(&left_bank_0_reg, &right_bank_0_reg); + biregister_rotate_8(&left_bank_1_reg, &right_bank_1_reg); + biregister_rotate_8(&left_bank_2_reg, &right_bank_2_reg); + + output_data += depth; + } + } + input_data_base += workspace_height_stride; + output_data_base += output_height_stride; + } + + // Move to next sub-block: advance to second set of filters. + filter_reg_0_a = filter_reg_0_b; + filter_reg_1_a = filter_reg_1_b; + filter_reg_2_a = filter_reg_2_b; + } + } + input_data_depthwise += depth_micro_stride; + output_data_depthwise += 8; + } + } // NOLINT(readability/fn_size) Manually unrolled. + + static inline void Run(const int8* scratch_block_data, + const int8* filter_workspace, const int32* bias_data, + uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + KernelMacroBlockNeon(scratch_block_data, filter_workspace, bias_data, + output_block_data, function_params); + } +}; + +template <> +struct KernelMacroBlock { + static inline void KernelMacroBlockNeon( + const int8* scratch_block_data, const int8* filter_workspace, + const int32* bias_data, uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int input_width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int output_width_micro_repeats = + function_params->output_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int depth = function_params->input_depth; + const int stride_val = function_params->stride; + const int four_over_stride = function_params->four_over_stride; + + const int workspace_width_micro_repeats = + function_params->workspace_width_micro_repeats; + const int output_width_overall_micro_repeats = + function_params->output_width_overall_micro_repeats; + const int block_height = function_params->outbound_block_height; + const int residual_width = function_params->output_residual_width; + const int output_height_stride = function_params->output_height_stride; + const int bias_increment = function_params->bias_increment; + + TFLITE_DCHECK(depth_micro_repeats > 0); + const int width_micro_stride = 4 * 8; + const int depth_micro_stride = + width_micro_stride * input_width_overall_micro_repeats; + + const int32 output_activation_min = + function_params->quantized_activation_min; + const int32 output_activation_max = + function_params->quantized_activation_max; + const int32 output_multiplier = function_params->output_multiplier; + const int32 
output_shift = function_params->output_shift; + const int32 output_offset = function_params->output_offset; + TFLITE_DCHECK_GE(output_activation_min, 0); + TFLITE_DCHECK_LT(output_activation_min, 256); + TFLITE_DCHECK_GE(output_activation_max, 0); + TFLITE_DCHECK_LT(output_activation_max, 256); + TFLITE_DCHECK_GE(output_offset, -32878); + TFLITE_DCHECK_LT(output_offset, 32768); + + // This version only does min/max on 64 bits. + const int16x8_t output_offset_vec = + vdupq_n_s16(static_cast(output_offset)); + const uint8x8_t output_activation_min_vec = + vdup_n_u8(static_cast(output_activation_min)); + const uint8x8_t output_activation_max_vec = + vdup_n_u8(static_cast(output_activation_max)); + + constexpr int shuffled_filter_increment = 2 * 3 * 4 * 4; + + TFLITE_DCHECK_EQ(stride_val, 2); + TFLITE_DCHECK_LE(block_height, 2); + + for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { + const int8* filter_block = + filter_workspace + shuffled_filter_increment * j_depth; + + if (block_height == 2) { + for (int s = 0; s < 2; ++s) { + // Simulate NEON-register transposition of subset of filter. + int8x16_t filter_reg_0_a; + int8x16_t filter_reg_1_a; + int8x16_t filter_reg_2_a; + + filter_reg_0_a = vld1q_s8(filter_block + s * 16); + filter_reg_1_a = vld1q_s8(filter_block + s * 16 + 32); + filter_reg_2_a = vld1q_s8(filter_block + s * 16 + 64); + + const int8* scratch_data = + scratch_block_data + depth_micro_stride * j_depth; + uint8* output_data = output_block_data + 8 * j_depth; + const int8* input_data_0 = scratch_data + s * 2 * 8; + + const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); + TFLITE_DCHECK_EQ(bias_increment, 4); + + // Load first sub-micro block of data into operational banks. + int8x16_t left_bank_0_reg = vld1q_s8(input_data_0); + int8x16_t left_bank_1_reg = + vld1q_s8(input_data_0 + workspace_height_stride); + int8x16_t left_bank_2_reg = + vld1q_s8(input_data_0 + 2 * workspace_height_stride); + int8x16_t left_bank_3_reg = + vld1q_s8(input_data_0 + 3 * workspace_height_stride); + int8x16_t left_bank_4_reg = + vld1q_s8(input_data_0 + 4 * workspace_height_stride); + + int8x16_t right_bank_0_reg; + int8x16_t right_bank_1_reg; + int8x16_t right_bank_2_reg; + int8x16_t right_bank_3_reg; + int8x16_t right_bank_4_reg; + + int32x4_t acc0; + int32x4_t acc1; + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + const int output_width = i_width == output_width_micro_repeats + ? residual_width + : four_over_stride; + TFLITE_DCHECK_LE(output_width * stride_val, 4); + const int8* input_data = + input_data_0 + width_micro_stride * i_width; + const bool no_right_block = i_width == output_width_micro_repeats && + output_width_overall_micro_repeats == + workspace_width_micro_repeats; + + if (!no_right_block) { + // Load next sub-micro block of data. + right_bank_0_reg = vld1q_s8(input_data + width_micro_stride); + right_bank_1_reg = vld1q_s8(input_data + width_micro_stride + + workspace_height_stride); + right_bank_2_reg = vld1q_s8(input_data + width_micro_stride + + 2 * workspace_height_stride); + right_bank_3_reg = vld1q_s8(input_data + width_micro_stride + + 3 * workspace_height_stride); + right_bank_4_reg = vld1q_s8(input_data + width_micro_stride + + 4 * workspace_height_stride); + } + + uint8* output_data_base = output_data + depth * 2 * i_width + 4 * s; + + // Iterate over input width shifts within 4x4 blocks. 
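+            // Stride-2, two-row case: acc0 and acc1 compute the same output
+            // column for two adjacent output rows. The vrev32q/vtrn1 shuffle
+            // below advances the input banks by two positions so the second
+            // column (at output_data_base + depth) can reuse the same loads.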
+ { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_3_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a, left_bank_4_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + uint8x8_t acc_u8 = vqmovun_s16(acc_s16_0_1); + acc_u8 = vmax_u8(acc_u8, output_activation_min_vec); + acc_u8 = vmin_u8(acc_u8, output_activation_max_vec); + + vst1_lane_8x4(output_data_base, acc_u8, 0); + vst1_lane_8x4(output_data_base + output_height_stride, acc_u8, 1); + + left_bank_0_reg = vrev32q_u16(left_bank_0_reg); + left_bank_1_reg = vrev32q_u16(left_bank_1_reg); + left_bank_2_reg = vrev32q_u16(left_bank_2_reg); + left_bank_3_reg = vrev32q_u16(left_bank_3_reg); + left_bank_4_reg = vrev32q_u16(left_bank_4_reg); + vtrn1_s8x2_in_place(&left_bank_0_reg, &right_bank_0_reg); + vtrn1_s8x2_in_place(&left_bank_1_reg, &right_bank_1_reg); + vtrn1_s8x2_in_place(&left_bank_2_reg, &right_bank_2_reg); + vtrn1_s8x2_in_place(&left_bank_3_reg, &right_bank_3_reg); + vtrn1_s8x2_in_place(&left_bank_4_reg, &right_bank_4_reg); + } + + if (output_width > 1) { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_3_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a, left_bank_4_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + uint8x8_t acc_u8 = vqmovun_s16(acc_s16_0_1); + acc_u8 = vmax_u8(acc_u8, output_activation_min_vec); + acc_u8 = vmin_u8(acc_u8, output_activation_max_vec); + + vst1_lane_8x4(output_data_base + depth, acc_u8, 0); + vst1_lane_8x4(output_data_base + depth + output_height_stride, + acc_u8, 1); + + left_bank_0_reg = right_bank_0_reg; + left_bank_1_reg = right_bank_1_reg; + left_bank_2_reg = right_bank_2_reg; + left_bank_3_reg = right_bank_3_reg; + left_bank_4_reg = right_bank_4_reg; + } + } + bias_data += bias_increment; + } + } else { + for (int s = 0; s < 2; ++s) { + // Simulate NEON-register transposition of subset of filter. 
+ int8x16_t filter_reg_0_a; + int8x16_t filter_reg_1_a; + int8x16_t filter_reg_2_a; + + filter_reg_0_a = vld1q_s8(filter_block + s * 16); + filter_reg_1_a = vld1q_s8(filter_block + s * 16 + 32); + filter_reg_2_a = vld1q_s8(filter_block + s * 16 + 64); + + const int8* scratch_data = + scratch_block_data + depth_micro_stride * j_depth; + uint8* output_data = output_block_data + 8 * j_depth; + const int8* input_data_0 = scratch_data + s * 2 * 8; + + const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); + TFLITE_DCHECK_EQ(bias_increment, 4); + + // Load first sub-micro block of data into operational banks. + int8x16_t left_bank_0_reg = vld1q_s8(input_data_0); + int8x16_t left_bank_1_reg = + vld1q_s8(input_data_0 + workspace_height_stride); + int8x16_t left_bank_2_reg = + vld1q_s8(input_data_0 + 2 * workspace_height_stride); + + int8x16_t right_bank_0_reg; + int8x16_t right_bank_1_reg; + int8x16_t right_bank_2_reg; + + int32x4_t acc0; + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + const int output_width = i_width == output_width_micro_repeats + ? residual_width + : four_over_stride; + TFLITE_DCHECK_LE(output_width * stride_val, 4); + const int8* input_data = + input_data_0 + width_micro_stride * i_width; + const bool no_right_block = i_width == output_width_micro_repeats && + output_width_overall_micro_repeats == + workspace_width_micro_repeats; + + if (!no_right_block) { + // Load next sub-micro block of data. + right_bank_0_reg = vld1q_s8(input_data + width_micro_stride); + right_bank_1_reg = vld1q_s8(input_data + width_micro_stride + + workspace_height_stride); + right_bank_2_reg = vld1q_s8(input_data + width_micro_stride + + 2 * workspace_height_stride); + } + + uint8* output_data_base = output_data + depth * 2 * i_width + 4 * s; + + // Iterate over input width shifts within 4x4 blocks. + { + acc0 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc0)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + uint8x8_t acc_u8 = vqmovun_s16(acc_s16_0_1); + acc_u8 = vmax_u8(acc_u8, output_activation_min_vec); + acc_u8 = vmin_u8(acc_u8, output_activation_max_vec); + + vst1_lane_8x4(output_data_base, acc_u8, 0); + + left_bank_0_reg = vrev32q_u16(left_bank_0_reg); + left_bank_1_reg = vrev32q_u16(left_bank_1_reg); + left_bank_2_reg = vrev32q_u16(left_bank_2_reg); + vtrn1_s8x2_in_place(&left_bank_0_reg, &right_bank_0_reg); + vtrn1_s8x2_in_place(&left_bank_1_reg, &right_bank_1_reg); + vtrn1_s8x2_in_place(&left_bank_2_reg, &right_bank_2_reg); + } + + if (output_width > 1) { + acc0 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc0)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. 
+ uint8x8_t acc_u8 = vqmovun_s16(acc_s16_0_1); + acc_u8 = vmax_u8(acc_u8, output_activation_min_vec); + acc_u8 = vmin_u8(acc_u8, output_activation_max_vec); + + vst1_lane_8x4(output_data_base + depth, acc_u8, 0); + + left_bank_0_reg = right_bank_0_reg; + left_bank_1_reg = right_bank_1_reg; + left_bank_2_reg = right_bank_2_reg; + } + } + bias_data += bias_increment; + } + } + } + } // NOLINT(readability/fn_size) Manually unrolled. + + static inline void Run(const int8* scratch_block_data, + const int8* filter_workspace, const int32* bias_data, + uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + KernelMacroBlockNeon(scratch_block_data, filter_workspace, bias_data, + output_block_data, function_params); + } +}; + +template <> +struct KernelMacroBlock { + static inline void KernelMacroBlockNeon( + const int8* scratch_block_data, const int8* filter_workspace, + const int32* bias_data, uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + TFLITE_DCHECK_EQ(function_params->stride, 1); + const int workspace_height_stride = + function_params->workspace_height_stride; + const int output_width_micro_repeats = + function_params->output_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int output_depth = function_params->output_depth; + + const int output_width_overall_micro_repeats = + function_params->output_width_overall_micro_repeats; + const int block_height = function_params->outbound_block_height; + const int residual_width = function_params->output_residual_width; + const int output_height_stride = function_params->output_height_stride; + const int bias_increment = function_params->bias_increment; + + TFLITE_DCHECK(depth_micro_repeats > 0); + + TFLITE_DCHECK_EQ(bias_increment, 4); + + const int32 output_activation_min = + function_params->quantized_activation_min; + const int32 output_activation_max = + function_params->quantized_activation_max; + const int32 output_multiplier = function_params->output_multiplier; + const int32 output_shift = function_params->output_shift; + const int32 output_offset = function_params->output_offset; + TFLITE_DCHECK_GE(output_activation_min, 0); + TFLITE_DCHECK_LT(output_activation_min, 256); + TFLITE_DCHECK_GE(output_activation_max, 0); + TFLITE_DCHECK_LT(output_activation_max, 256); + TFLITE_DCHECK_GE(output_offset, -32878); + TFLITE_DCHECK_LT(output_offset, 32768); + + const int16x8_t output_offset_vec = + vdupq_n_s16(static_cast(output_offset)); + const uint8x16_t output_activation_min_vec = + vdupq_n_u8(static_cast(output_activation_min)); + const uint8x16_t output_activation_max_vec = + vdupq_n_u8(static_cast(output_activation_max)); + + uint8* output_data_depthwise = output_block_data; + for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { + // Simulate NEON-register transposition of subset of filter. 
+ int8x16_t filter_reg_0_a; + int8x16_t filter_reg_0_b; + int8x16_t filter_reg_1_a; + int8x16_t filter_reg_1_b; + int8x16_t filter_reg_2_a; + int8x16_t filter_reg_2_b; + int8x16_t filter_reg_0_a_shifted; + int8x16_t filter_reg_1_a_shifted; + int8x16_t filter_reg_2_a_shifted; + + filter_reg_0_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_0_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_1_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_1_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_2_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_2_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + + filter_reg_0_a_shifted = vshlq_n_u32(filter_reg_0_a, 8); + filter_reg_1_a_shifted = vshlq_n_u32(filter_reg_1_a, 8); + filter_reg_2_a_shifted = vshlq_n_u32(filter_reg_2_a, 8); + + if (block_height == 4) { + for (int s = 0; s < 2; ++s) { + // Work through one slice, by row, at a time. + uint8* output_data_base = output_data_depthwise + 4 * s; + + const int8* next_input_data = scratch_block_data; + uint8* output_data = output_data_base; + + const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); + TFLITE_DCHECK_EQ(bias_increment, 4); + bias_data += bias_increment; + + int8x16_t input_bank_a_reg; // left 0, right 0, left 1, right 1. + int8x16_t input_bank_b_reg; // left 2, right 2, left 3, right 3. + int8x16_t input_bank_c_reg; // left 4, right 4, left 5, right 5. + + // Load first sub-micro block of data into operational banks. + input_bank_a_reg = + vld1q_dup_s8x4(next_input_data); // Load lane 0, avoiding + // uninitialized variable. + input_bank_a_reg = vld1q_lane_8x4( + next_input_data + workspace_height_stride, input_bank_a_reg, 2); + input_bank_b_reg = vld1q_dup_s8x4( + next_input_data + + 2 * workspace_height_stride); // Load lane 0, avoiding + // uninitialized variable. + input_bank_b_reg = + vld1q_lane_8x4(next_input_data + 3 * workspace_height_stride, + input_bank_b_reg, 2); + input_bank_c_reg = vld1q_dup_s8x4( + next_input_data + + 4 * workspace_height_stride); // Load lane 0, avoiding + // uninitialized variable. + input_bank_c_reg = + vld1q_lane_8x4(next_input_data + 5 * workspace_height_stride, + input_bank_c_reg, 2); + + int32x4_t acc0; + int32x4_t acc1; + int32x4_t acc2; + int32x4_t acc3; + + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_1_a, input_bank_b_reg, 0); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_0_a, input_bank_b_reg, 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_0_a, input_bank_b_reg, 2); + + for (int i_width = 0; i_width < output_width_micro_repeats; + ++i_width) { + next_input_data += 4; + + // Iterate over input width shifts within 4x4 blocks. 
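+            // vdotq_four_lane_s32 selects one 32-bit group of input bytes
+            // from the named bank register (by lane index) and accumulates
+            // its dot product against each 4-byte group of the filter
+            // register, one group per output channel of the sub-block.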
+ { + acc0 = vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, + 0); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, + 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_0_a, input_bank_a_reg, + 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_2_a, input_bank_b_reg, + 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_1_a, input_bank_b_reg, + 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_2_a, input_bank_c_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_1_a, input_bank_c_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_2_a, input_bank_c_reg, + 2); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + output_data += output_depth; + } + // Load next sub-micro block of data. 
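+            // The incoming 4-byte groups land in the odd lanes (the "right"
+            // halves), next to the data already held in the even lanes for
+            // the same rows.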
+ input_bank_a_reg = + vld1q_lane_8x4(next_input_data, input_bank_a_reg, 1); + input_bank_a_reg = vld1q_lane_8x4( + next_input_data + workspace_height_stride, input_bank_a_reg, 3); + input_bank_b_reg = + vld1q_lane_8x4(next_input_data + 2 * workspace_height_stride, + input_bank_b_reg, 1); + input_bank_b_reg = + vld1q_lane_8x4(next_input_data + 3 * workspace_height_stride, + input_bank_b_reg, 3); + input_bank_c_reg = + vld1q_lane_8x4(next_input_data + 4 * workspace_height_stride, + input_bank_c_reg, 1); + input_bank_c_reg = + vld1q_lane_8x4(next_input_data + 5 * workspace_height_stride, + input_bank_c_reg, 3); + + { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_four_lane_s32(acc0, filter_reg_0_a_shifted, + input_bank_a_reg, 0); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_1_a_shifted, + input_bank_a_reg, 2); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_2_a_shifted, + input_bank_b_reg, 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_0_a_shifted, + input_bank_a_reg, 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_1_a_shifted, + input_bank_b_reg, 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_2_a_shifted, + input_bank_b_reg, 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_0_a_shifted, + input_bank_b_reg, 0); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_1_a_shifted, + input_bank_b_reg, 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_2_a_shifted, + input_bank_c_reg, 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_0_a_shifted, + input_bank_b_reg, 2); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_1_a_shifted, + input_bank_c_reg, 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_2_a_shifted, + input_bank_c_reg, 2); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. 
+ uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 16); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 16); + input_bank_c_reg = vshrq_n_u64(input_bank_c_reg, 16); + + output_data += output_depth; + } + + { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, + 0); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, + 2); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, + 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_0_a, input_bank_a_reg, + 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_1_a, input_bank_b_reg, + 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_2_a, input_bank_b_reg, + 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_0_a, input_bank_b_reg, + 0); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_1_a, input_bank_b_reg, + 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_2_a, input_bank_c_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_0_a, input_bank_b_reg, + 2); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_1_a, input_bank_c_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_2_a, input_bank_c_reg, + 2); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. 
+ uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + output_data += output_depth; + } + + { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_four_lane_s32(acc0, filter_reg_0_a_shifted, + input_bank_a_reg, 0); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_1_a_shifted, + input_bank_a_reg, 2); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_2_a_shifted, + input_bank_b_reg, 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_0_a_shifted, + input_bank_a_reg, 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_1_a_shifted, + input_bank_b_reg, 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_2_a_shifted, + input_bank_b_reg, 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_0_a_shifted, + input_bank_b_reg, 0); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_1_a_shifted, + input_bank_b_reg, 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_2_a_shifted, + input_bank_c_reg, 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_0_a_shifted, + input_bank_b_reg, 2); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_1_a_shifted, + input_bank_c_reg, 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_2_a_shifted, + input_bank_c_reg, 2); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. 
+ uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 16); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 16); + input_bank_c_reg = vshrq_n_u64(input_bank_c_reg, 16); + + output_data += output_depth; + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, + 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_1_a, input_bank_b_reg, + 0); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_0_a, input_bank_b_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_0_a, input_bank_b_reg, + 2); + } + } + + if (residual_width > 0) { + next_input_data += 4; + const int output_width = residual_width; + + // Load next sub-micro block of data. + input_bank_a_reg = + vld1q_lane_8x4(next_input_data, input_bank_a_reg, 1); + input_bank_a_reg = vld1q_lane_8x4( + next_input_data + workspace_height_stride, input_bank_a_reg, 3); + input_bank_b_reg = + vld1q_lane_8x4(next_input_data + 2 * workspace_height_stride, + input_bank_b_reg, 1); + input_bank_b_reg = + vld1q_lane_8x4(next_input_data + 3 * workspace_height_stride, + input_bank_b_reg, 3); + input_bank_c_reg = + vld1q_lane_8x4(next_input_data + 4 * workspace_height_stride, + input_bank_c_reg, 1); + input_bank_c_reg = + vld1q_lane_8x4(next_input_data + 5 * workspace_height_stride, + input_bank_c_reg, 3); + + // Iterate over input width shifts within 4x4 blocks. + for (int x = 0; x < output_width; ++x) { + acc0 = vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, + 0); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, + 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_0_a, input_bank_a_reg, + 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_2_a, input_bank_b_reg, + 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_1_a, input_bank_b_reg, + 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_2_a, input_bank_c_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_1_a, input_bank_c_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_2_a, input_bank_c_reg, + 2); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. 
+ uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 8); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 8); + input_bank_c_reg = vshrq_n_u64(input_bank_c_reg, 8); + + output_data += output_depth; + + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, + 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_1_a, input_bank_b_reg, + 0); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_0_a, input_bank_b_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_0_a, input_bank_b_reg, + 2); + } + } + // scratch_block_data += 4 * workspace_height_stride; + output_data_base += 4 * output_height_stride; + + // Move to next sub-block: advance to second set of filters, to new + // bias. + filter_reg_0_a = filter_reg_0_b; + filter_reg_1_a = filter_reg_1_b; + filter_reg_2_a = filter_reg_2_b; + filter_reg_0_a_shifted = vshlq_n_u32(filter_reg_0_a, 8); + filter_reg_1_a_shifted = vshlq_n_u32(filter_reg_1_a, 8); + filter_reg_2_a_shifted = vshlq_n_u32(filter_reg_2_a, 8); + } + } else { + // Block height < 4. + for (int s = 0; s < 2; ++s) { + // Work through one slice, by row, at a time. + uint8* output_data_base = output_data_depthwise + 4 * s; + + const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); + TFLITE_DCHECK_EQ(bias_increment, 4); + bias_data += bias_increment; + + for (int k_height = 0; k_height < block_height; ++k_height) { + const int8* next_input_data = + scratch_block_data + k_height * workspace_height_stride; + uint8* output_data = output_data_base; + + int8x16_t input_bank_a_reg; // left 0, right 0, left 1, right 1. + int8x16_t input_bank_b_reg; // left 2, right 2, left 3, right 3. + + // Load first sub-micro block of data into operational banks. + input_bank_a_reg = + vld1q_dup_s8x4(next_input_data); // Load lane 0, avoiding + // uninitialized variable. + input_bank_a_reg = vld1q_lane_8x4( + next_input_data + workspace_height_stride, input_bank_a_reg, 2); + input_bank_b_reg = vld1q_dup_s8x4( + next_input_data + + 2 * workspace_height_stride); // Load lane 0, avoiding + // uninitialized variable. + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + next_input_data += 4; + const int output_width = + i_width == output_width_micro_repeats ? residual_width : 4; + + // Load next sub-micro block of data. + input_bank_a_reg = + vld1q_lane_8x4(next_input_data, input_bank_a_reg, 1); + input_bank_a_reg = + vld1q_lane_8x4(next_input_data + workspace_height_stride, + input_bank_a_reg, 3); + input_bank_b_reg = + vld1q_lane_8x4(next_input_data + 2 * workspace_height_stride, + input_bank_b_reg, 1); + // Iterate over input width shifts within 4x4 blocks. 
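+            // Each 64-bit half of an input bank holds one row (left then
+            // right 4-byte group), so shifting the banks right by 8 bits
+            // after a column advances the window by one position.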
+ for (int x = 0; x < output_width; ++x) { + int32x4_t acc = adjusted_bias_data; + acc = vdotq_four_lane_s32(acc, filter_reg_0_a, input_bank_a_reg, + 0); + acc = vdotq_four_lane_s32(acc, filter_reg_1_a, input_bank_a_reg, + 2); + acc = vdotq_four_lane_s32(acc, filter_reg_2_a, input_bank_b_reg, + 0); + + // Fixed-point multiplication. + acc = vqrdmulhq_n_s32(acc, output_multiplier); + acc = DivideByPOT::Run( + acc, -output_shift); + // Add the output offset. + // Note that we need to fill the top half with vcombine, but can + // drop the instruction in ASM code. + int16x8_t acc_s16_0_0 = + vcombine_s16(vqmovn_s32(acc), vqmovn_s32(acc)); + acc_s16_0_0 = vqaddq_s16(acc_s16_0_0, output_offset_vec); + // Apply the activation function. + uint8x8_t acc_u8_0_0 = vqmovun_s16(acc_s16_0_0); + acc_u8_0_0 = + vmax_u8(acc_u8_0_0, vget_low_u8(output_activation_min_vec)); + acc_u8_0_0 = + vmin_u8(acc_u8_0_0, vget_low_u8(output_activation_max_vec)); + + vst1_lane_8x4(output_data, acc_u8_0_0, 0); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 8); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 8); + + output_data += output_depth; + } + } + output_data_base += output_height_stride; + } + + // Move to next sub-block: advance to second set of filters. + filter_reg_0_a = filter_reg_0_b; + filter_reg_1_a = filter_reg_1_b; + filter_reg_2_a = filter_reg_2_b; + } + } + output_data_depthwise += 8; + } + } // NOLINT(readability/fn_size) Manually unrolled. + + static inline void Run(const int8* scratch_block_data, + const int8* filter_workspace, const int32* bias_data, + uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + KernelMacroBlockNeon(scratch_block_data, filter_workspace, bias_data, + output_block_data, function_params); + } +}; + +template <> +struct KernelMacroBlock { + static inline void KernelMacroBlockNeon( + const int8* scratch_block_data, const int8* filter_workspace, + const int32* bias_data, uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int output_width_micro_repeats = + function_params->output_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int output_depth = function_params->output_depth; + const int stride_val = function_params->stride; + const int four_over_stride = function_params->four_over_stride; + + const int output_width_overall_micro_repeats = + function_params->output_width_overall_micro_repeats; + const int block_height = function_params->outbound_block_height; + const int residual_width = function_params->output_residual_width; + const int output_height_stride = function_params->output_height_stride; + const int bias_increment = function_params->bias_increment; + + const int32 output_activation_min = + function_params->quantized_activation_min; + const int32 output_activation_max = + function_params->quantized_activation_max; + const int32 output_multiplier = function_params->output_multiplier; + const int32 output_shift = function_params->output_shift; + const int32 output_offset = function_params->output_offset; + TFLITE_DCHECK_GE(output_activation_min, 0); + TFLITE_DCHECK_LT(output_activation_min, 256); + TFLITE_DCHECK_GE(output_activation_max, 0); + TFLITE_DCHECK_LT(output_activation_max, 256); + TFLITE_DCHECK_GE(output_offset, -32878); + TFLITE_DCHECK_LT(output_offset, 32768); + + TFLITE_DCHECK_GE(depth_micro_repeats, 1); + TFLITE_DCHECK_EQ(bias_increment, 4); + + const 
int16x8_t output_offset_vec = + vdupq_n_s16(static_cast(output_offset)); + const uint8x16_t output_activation_min_vec = + vdupq_n_u8(static_cast(output_activation_min)); + const uint8x16_t output_activation_max_vec = + vdupq_n_u8(static_cast(output_activation_max)); + + for (int j_depth = 0; j_depth < (depth_micro_repeats * 1 + 0); ++j_depth) { + int8x16_t filter_reg_0_a; + int8x16_t filter_reg_0_b; + int8x16_t filter_reg_1_a; + int8x16_t filter_reg_1_b; + int8x16_t filter_reg_2_a; + int8x16_t filter_reg_2_b; + + filter_reg_0_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_0_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_1_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_1_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_2_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_2_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + + TFLITE_DCHECK_EQ(bias_increment, 4); + const int32x4_t adjusted_bias_data_s_0 = vld1q_s32(bias_data); + bias_data += bias_increment; + const int32x4_t adjusted_bias_data_s_1 = vld1q_s32(bias_data); + bias_data += bias_increment; + + if (block_height == 2) { + const int8* scratch_data = scratch_block_data; + uint8* output_data = output_block_data + 8 * j_depth; + + int8x16_t input_bank_a_reg; // left 0, right 0, left 1, right 1. + int8x16_t input_bank_b_reg; // left 2, right 2, left 3, right 3. + int8x16_t input_bank_c_reg; // left 4, right 4, xxx, xxx. + + // Load first sub-micro block of data into operational banks. + input_bank_a_reg = + vld1q_dup_s8x4(scratch_data); // Load lane 0, avoiding + // uninitialized variable. + input_bank_a_reg = vld1q_lane_8x4( + scratch_data + workspace_height_stride, input_bank_a_reg, 2); + input_bank_b_reg = vld1q_dup_s8x4( + scratch_data + + 2 * workspace_height_stride); // Load lane 0, avoiding + // uninitialized variable. + input_bank_b_reg = vld1q_lane_8x4( + scratch_data + 3 * workspace_height_stride, input_bank_b_reg, 2); + input_bank_c_reg = vld1q_dup_s8x4( + scratch_data + + 4 * workspace_height_stride); // Load lane 0, avoiding + // uninitialized variable. + + int32x4_t acc0; + int32x4_t acc1; + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + const int output_width = i_width == output_width_micro_repeats + ? residual_width + : four_over_stride; + + TFLITE_DCHECK_LE(output_width, 2); + TFLITE_DCHECK_GE(output_width, 1); + TFLITE_DCHECK_LE(output_width * stride_val, 4); + const int8* input_data = scratch_data + 4 + 4 * i_width; + + // Load next sub-micro block of data. + input_bank_a_reg = vld1q_lane_8x4(input_data, input_bank_a_reg, 1); + input_bank_a_reg = vld1q_lane_8x4( + input_data + workspace_height_stride, input_bank_a_reg, 3); + input_bank_b_reg = vld1q_lane_8x4( + input_data + 2 * workspace_height_stride, input_bank_b_reg, 1); + input_bank_b_reg = vld1q_lane_8x4( + input_data + 3 * workspace_height_stride, input_bank_b_reg, 3); + input_bank_c_reg = vld1q_lane_8x4( + input_data + 4 * workspace_height_stride, input_bank_c_reg, 1); + + int16x8_t acc_s16_0_1; + uint8x8_t acc_u8_0_1; + // Iterate over input width shifts within 4x4 blocks. 
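+          // Both 4-channel sub-blocks are handled per column: the *_a filter
+          // registers with bias s_0 produce depth 0-3 for the two output
+          // rows, and the *_b registers with bias s_1 produce depth 4-7
+          // (written at output_data + 4).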
+ { + acc0 = adjusted_bias_data_s_0; + acc1 = adjusted_bias_data_s_0; + + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, 2); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_0_a, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_1_a, input_bank_b_reg, 2); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_2_a, input_bank_c_reg, 0); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + // Add the output offset. + acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + acc_u8_0_1 = vqmovun_s16(acc_s16_0_1); + acc_u8_0_1 = + vmax_u8(acc_u8_0_1, vget_low_u8(output_activation_min_vec)); + acc_u8_0_1 = + vmin_u8(acc_u8_0_1, vget_low_u8(output_activation_max_vec)); + + vst1_lane_8x4(output_data, acc_u8_0_1, 0); + vst1_lane_8x4(output_data + output_height_stride, acc_u8_0_1, 1); + + acc0 = adjusted_bias_data_s_1; + acc1 = adjusted_bias_data_s_1; + + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_0_b, input_bank_a_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_1_b, input_bank_a_reg, 2); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_2_b, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_0_b, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_1_b, input_bank_b_reg, 2); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_2_b, input_bank_c_reg, 0); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + // Add the output offset. + acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + acc_u8_0_1 = vqmovun_s16(acc_s16_0_1); + acc_u8_0_1 = + vmax_u8(acc_u8_0_1, vget_low_u8(output_activation_min_vec)); + acc_u8_0_1 = + vmin_u8(acc_u8_0_1, vget_low_u8(output_activation_max_vec)); + + vst1_lane_8x4(output_data + 4, acc_u8_0_1, 0); + vst1_lane_8x4(output_data + 4 + output_height_stride, acc_u8_0_1, + 1); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 16); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 16); + input_bank_c_reg = vshrq_n_u64(input_bank_c_reg, 16); + + output_data += output_depth; + } + if (output_width == 2) { + acc0 = adjusted_bias_data_s_0; + acc1 = adjusted_bias_data_s_0; + + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, 2); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_0_a, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_1_a, input_bank_b_reg, 2); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_2_a, input_bank_c_reg, 0); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + // Add the output offset. 
+ acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + acc_u8_0_1 = vqmovun_s16(acc_s16_0_1); + acc_u8_0_1 = + vmax_u8(acc_u8_0_1, vget_low_u8(output_activation_min_vec)); + acc_u8_0_1 = + vmin_u8(acc_u8_0_1, vget_low_u8(output_activation_max_vec)); + + vst1_lane_8x4(output_data, acc_u8_0_1, 0); + vst1_lane_8x4(output_data + output_height_stride, acc_u8_0_1, 1); + + acc0 = adjusted_bias_data_s_1; + acc1 = adjusted_bias_data_s_1; + + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_0_b, input_bank_a_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_1_b, input_bank_a_reg, 2); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_2_b, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_0_b, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_1_b, input_bank_b_reg, 2); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_2_b, input_bank_c_reg, 0); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + // Add the output offset. + acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + acc_u8_0_1 = vqmovun_s16(acc_s16_0_1); + acc_u8_0_1 = + vmax_u8(acc_u8_0_1, vget_low_u8(output_activation_min_vec)); + acc_u8_0_1 = + vmin_u8(acc_u8_0_1, vget_low_u8(output_activation_max_vec)); + + vst1_lane_8x4(output_data + 4, acc_u8_0_1, 0); + vst1_lane_8x4(output_data + 4 + output_height_stride, acc_u8_0_1, + 1); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 16); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 16); + input_bank_c_reg = vshrq_n_u64(input_bank_c_reg, 16); + + output_data += output_depth; + } + } + } else { + TFLITE_DCHECK_EQ(block_height, 1); + // Work through one slice, by row, at a time. + const int8* scratch_data = scratch_block_data; + uint8* output_data = output_block_data + 8 * j_depth; + + // + int8x16_t input_bank_a_reg; // left 0, right 0, left 1, right 1. + int8x16_t input_bank_b_reg; // left 2, right 2, xxx, xxx. + + // Load first sub-micro block of data into operational banks. + input_bank_a_reg = + vld1q_dup_s8x4(scratch_data); // Load lane 0, avoiding + // uninitialized variable. + input_bank_a_reg = vld1q_lane_8x4( + scratch_data + workspace_height_stride, input_bank_a_reg, 2); + input_bank_b_reg = vld1q_dup_s8x4( + scratch_data + + 2 * workspace_height_stride); // Load lane 0, avoiding + // uninitialized variable. + + int32x4_t acc0; + int32x4_t acc1; + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + const int output_width = + i_width == output_width_micro_repeats ? residual_width : 2; + + TFLITE_DCHECK_LE(output_width, 2); + TFLITE_DCHECK_GE(output_width, 1); + TFLITE_DCHECK_LE(output_width * stride_val, 4); + const int8* input_data = scratch_data + 4 + 4 * i_width; + + // Load next sub-micro block of data. + input_bank_a_reg = vld1q_lane_8x4(input_data, input_bank_a_reg, 1); + input_bank_a_reg = vld1q_lane_8x4( + input_data + workspace_height_stride, input_bank_a_reg, 3); + input_bank_b_reg = vld1q_lane_8x4( + input_data + 2 * workspace_height_stride, input_bank_b_reg, 1); + + int16x8_t acc_s16_0_1; + uint8x8_t acc_u8_0_1; + + // Iterate over input width shifts within 4x4 blocks. 
+ { + acc0 = adjusted_bias_data_s_0; + + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, 2); + + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + + // Second sub-block accumulation. + acc1 = adjusted_bias_data_s_1; + + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_2_b, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_0_b, input_bank_a_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_1_b, input_bank_a_reg, 2); + + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + + // Add the output offset. + acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + acc_u8_0_1 = vqmovun_s16(acc_s16_0_1); + acc_u8_0_1 = + vmax_u8(acc_u8_0_1, vget_low_u8(output_activation_min_vec)); + acc_u8_0_1 = + vmin_u8(acc_u8_0_1, vget_low_u8(output_activation_max_vec)); + + // This stores the results for both sub-blocks together. + vst1_u8(output_data, acc_u8_0_1); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 16); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 16); + + output_data += output_depth; + } + if (output_width == 2) { + acc0 = adjusted_bias_data_s_0; + + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, 2); + + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + + // Second sub-block accumulation. + acc1 = adjusted_bias_data_s_1; + + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_2_b, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_0_b, input_bank_a_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_1_b, input_bank_a_reg, 2); + + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + + // Add the output offset. + acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + acc_u8_0_1 = vqmovun_s16(acc_s16_0_1); + acc_u8_0_1 = + vmax_u8(acc_u8_0_1, vget_low_u8(output_activation_min_vec)); + acc_u8_0_1 = + vmin_u8(acc_u8_0_1, vget_low_u8(output_activation_max_vec)); + + // This stores the results for both sub-blocks together. + vst1_u8(output_data, acc_u8_0_1); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 16); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 16); + + output_data += output_depth; + } + } + } + } + } + + static inline void Run(const int8* scratch_block_data, + const int8* filter_workspace, const int32* bias_data, + uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + KernelMacroBlockNeon(scratch_block_data, filter_workspace, bias_data, + output_block_data, function_params); + } +}; + +#endif // USE_NEON && __aarch64__ // Top-level implementation function for 3x3 depthwise convolution using NEON // dot-product instructions.