From ad32377d85e58ed071df46d73e45fc28c8a4f258 Mon Sep 17 00:00:00 2001
From: Renjie Liu
Date: Tue, 18 Jun 2019 19:47:05 -0700
Subject: [PATCH] Fix depthwise flax vector

PiperOrigin-RevId: 253918605
---
 .../depthwiseconv_uint8_3x3_filter.h          | 184 ++++++++++--------
 1 file changed, 100 insertions(+), 84 deletions(-)

diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
index 778bf28b70a..0ef1dcc9b20 100644
--- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
+++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
@@ -32,9 +32,10 @@ namespace depthwise_conv {
 // 4 8-bit lanes together. So these are treated much like 32-bit loads and
 // 32-bit stores. Stores require 32-bit alignment.
 
-#define vst1_lane_8x4(dst, reg, lane_num)                          \
-  TFLITE_DCHECK_EQ(reinterpret_cast<std::uintptr_t>(dst) % 4, 0);  \
-  vst1_lane_u32(reinterpret_cast<uint32_t*>(dst), reg, lane_num)
+#define vst1_lane_8x4(dst, reg, lane_num)                                  \
+  TFLITE_DCHECK_EQ(reinterpret_cast<std::uintptr_t>(dst) % 4, 0);          \
+  vst1_lane_s32(reinterpret_cast<int32_t*>(dst), vreinterpret_s32_s8(reg), \
+                lane_num)
 #define vst1q_lane_8x4(dst, reg, lane_num)                        \
   TFLITE_DCHECK_EQ(reinterpret_cast<std::uintptr_t>(dst) % 4, 0); \
   vst1q_lane_u32(reinterpret_cast<uint32_t*>(dst), reg, lane_num)
@@ -42,10 +43,12 @@ namespace depthwise_conv {
 // Important! Most compilation configurations will compile and run without
 // reinterpret_cast. Sanitizers may fail silently on lane-loading, with an
 // obscure bug or mis-feature probably in unhygienic macro expansion.
-#define vld1q_lane_s8x8(src, reg, lane_num) \
-  vld1q_lane_u64(reinterpret_cast<const uint64_t*>(src), reg, lane_num)
-#define vld1_lane_8x4(src, reg, lane_num) \
-  vld1_lane_s32(reinterpret_cast<const int32_t*>(src), reg, lane_num)
+#define vld1q_lane_s8x8(src, reg, lane_num)                                   \
+  vreinterpretq_s8_s64(vld1q_lane_s64(reinterpret_cast<const int64_t*>(src), \
+                                      vreinterpretq_s64_s8(reg), lane_num))
+#define vld1_lane_8x4(src, reg, lane_num)                                  \
+  vreinterpret_s8_s32(vld1_lane_s32(reinterpret_cast<const int32_t*>(src), \
+                                    vreinterpret_s32_s8(reg), lane_num))
 #define vld1q_lane_8x4(src, reg, lane_num) \
   vld1q_lane_s32(reinterpret_cast<const int32_t*>(src), reg, lane_num)
 #define vld1q_dup_s8x4(src) vld1q_dup_s32(reinterpret_cast<const int32_t*>(src))
@@ -5949,13 +5952,14 @@ struct PackMacroBlock
           (input_block_data);
       int8x16_t input_data_a;
       int8x16_t input_data_b;
       int8x16_t input_data_c;
@@ -5978,10 +5982,10 @@ struct PackMacroBlock
               > 1) {
             input_data_b = vld1q_lane_s8x8(input_data_0 + input_depth,
                                            input_data_b, 0);
@@ -6187,7 +6191,7 @@ struct PackMacroBlock
           (input_block_data);
       int8x16_t input_data_a;
       int8x16_t input_data_b;
       int8x16_t input_data_c;
@@ -6241,10 +6246,10 @@ struct PackMacroBlock
               > 0) {
           input_data_a = vld1q_lane_s8x8(input_data_0, input_data_a, 0);
           if (adjusted_residual_width > 1) {
@@ -6386,10 +6391,10 @@ struct PackMacroBlock
               > 1) {
             input_data_b = vld1q_lane_s8x8(input_data_0 + input_depth,
@@ -6637,14 +6642,15 @@ struct PackMacroBlock
             >= 16) {
       const int copy_remaining = (copy_size + start_width) & 0x7;
-      padding_mask = vshl_u64(padding_mask, vdup_n_s64(8 * copy_remaining));
+      padding_mask = vreinterpret_s8_s64(vshl_s64(
+          vreinterpret_s64_s8(padding_mask), vdup_n_s64(8 * copy_remaining)));
       for (int k_height = 0; k_height < copy_block_height; ++k_height) {
         // Work through one slice, by row, at a time.
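Note on the macro changes above: the fixed vst1_lane_8x4 / vld1_lane_8x4 / vld1q_lane_s8x8 macros keep the NEON lane loads and stores type-consistent by viewing the same int8 register through a 32-bit or 64-bit reinterpret, instead of handing an int8 vector directly to a u32/u64 intrinsic. A minimal standalone sketch of that pattern follows; it assumes an ARM target with arm_neon.h, and the buffer and variable names are illustrative only, not taken from the patch.

// Sketch: type-consistent 4-byte lane store/load on an int8x8_t, mirroring
// what the fixed vst1_lane_8x4 / vld1_lane_8x4 macros expand to.
#include <arm_neon.h>
#include <cstdint>
#include <cstdio>

int main() {
  alignas(4) std::int8_t dst[8] = {};
  const std::int8_t src[8] = {1, 2, 3, 4, 5, 6, 7, 8};

  int8x8_t reg = vld1_s8(src);

  // Store the low 4 bytes of reg as one 32-bit lane, going through the s32
  // view of the register rather than passing an int8 vector to a u32 store.
  // The destination must be 4-byte aligned.
  vst1_lane_s32(reinterpret_cast<std::int32_t*>(dst),
                vreinterpret_s32_s8(reg), 0);

  // Load 4 bytes back into lane 1 of an int8x8_t the same way: reinterpret
  // to s32, load the lane, reinterpret back to s8.
  int8x8_t merged = vdup_n_s8(0);
  merged = vreinterpret_s8_s32(
      vld1_lane_s32(reinterpret_cast<const std::int32_t*>(dst),
                    vreinterpret_s32_s8(merged), 1));

  std::printf("%d %d\n", dst[0], vget_lane_s8(merged, 4));  // prints: 1 1
  return 0;
}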
@@ -6656,7 +6662,8 @@ struct PackMacroBlock
              (
+              input_block_data + input_block_offset));
           work_reg = vextq_s8(padding_reg, work_reg, 15);
           work_reg = veorq_s8(work_reg, sign_bit);
           optimized_ops_prefetch_write_l1_keep(scratch_data);
@@ -6666,8 +6673,8 @@ struct PackMacroBlock
              (
+              input_block_data + input_block_offset + copy_done));
           work_reg = veorq_s8(work_reg, sign_bit);
           TFLITE_DCHECK_EQ((start_width + copy_done) % 16, 0);
           optimized_ops_prefetch_write_l1_keep(scratch_data + start_width +
@@ -6676,8 +6683,8 @@ struct PackMacroBlock
              (
+              input_block_data + input_block_offset + copy_done));
           half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit));
           TFLITE_DCHECK_EQ((start_width + copy_done) % 8, 0);
           optimized_ops_prefetch_write_l1_keep(scratch_data + start_width +
@@ -6698,13 +6705,14 @@ struct PackMacroBlock
              (
+              input_block_data + input_block_offset + copy_size - 8));
-          half_work_reg =
-              vshl_u64(half_work_reg, vdup_n_s64(-8 * (8 - copy_remaining)));
-          half_work_reg =
-              vbsl_s8(padding_mask, vget_low_s8(padding_reg), half_work_reg);
+          half_work_reg = vreinterpret_s8_s64(
+              vshl_s64(vreinterpret_s64_s8(half_work_reg),
+                       vdup_n_s64(-8 * (8 - copy_remaining))));
+          half_work_reg = vbsl_s8(vreinterpret_u8_s8(padding_mask),
+                                  vget_low_s8(padding_reg), half_work_reg);
           half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit));
           TFLITE_DCHECK_EQ((start_width + copy_done) % 8, 0);
@@ -6726,7 +6734,8 @@ struct PackMacroBlock
             >= 4) {
       const int copy_remaining = (copy_size + start_width) & 0x3;
-      padding_mask = vshl_u64(padding_mask, vdup_n_s64(8 * copy_remaining));
+      padding_mask = vreinterpret_s8_s64(vshl_s64(
+          vreinterpret_s64_s8(padding_mask), vdup_n_s64(8 * copy_remaining)));
       for (int k_height = 0; k_height < copy_block_height; ++k_height) {
         // Work through one slice, by row, at a time.
@@ -6777,10 +6786,11 @@ struct PackMacroBlock
              (
               input_block_data + input_block_offset),
               half_work_reg, 1);
@@ -6854,24 +6864,27 @@ struct PackMacroBlock
              (
               input_block_data + input_block_offset + copy_size - 1 - i),
               half_work_reg, 0);
         }
         if (leading_width_padding) {
-          half_work_reg = vshl_n_s64(half_work_reg, 8);
+          half_work_reg = vreinterpret_s8_s64(
+              vshl_n_s64(vreinterpret_s64_s8(half_work_reg), 8));
         }
-        half_work_reg =
-            vbsl_s8(padding_mask, vget_low_s8(padding_reg), half_work_reg);
+        half_work_reg = vbsl_s8(vreinterpret_u8_s8(padding_mask),
+                                vget_low_s8(padding_reg), half_work_reg);
         half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit));
         TFLITE_DCHECK_EQ(scratch_data_offset % 4, 0);
@@ -6976,7 +6989,7 @@ struct PackMacroBlock
             >= 16) {
@@ -6990,8 +7003,8 @@ struct PackMacroBlock
              (
+              input_block_data + input_block_offset + copy_done));
           work_reg = veorq_s8(work_reg, sign_bit);
           TFLITE_DCHECK_EQ(copy_done % 16, 0);
           optimized_ops_prefetch_write_l1_keep(scratch_data + copy_done);
@@ -6999,8 +7012,8 @@ struct PackMacroBlock
              (
+              input_block_data + input_block_offset + copy_done));
           half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit));
           TFLITE_DCHECK_EQ(copy_done % 8, 0);
           optimized_ops_prefetch_write_l1_keep(scratch_data + copy_done);
@@ -7020,11 +7033,12 @@ struct PackMacroBlock
              (
+              input_block_data + input_block_offset + copy_size - 8));
-          half_work_reg =
-              vshl_u64(half_work_reg, vdup_n_s64(-8 * (8 - copy_remaining)));
+          half_work_reg = vreinterpret_s8_s64(
+              vshl_s64(vreinterpret_s64_s8(half_work_reg),
+                       vdup_n_s64(-8 * (8 - copy_remaining))));
           half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit));
           TFLITE_DCHECK_EQ(copy_done % 8, 0);
@@ -7079,8 +7093,9 @@ struct PackMacroBlock
              (
               input_block_data + input_block_offset + copy_size - 1 - i),
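Note on the residual-copy hunks: the same reinterpret discipline is applied to the padding mask, which is an int8x8_t but is shifted as a single 64-bit lane (vshl_s64 on its s64 view) and then used as the unsigned select mask of vbsl_s8 through a u8 view. Going through the 64-bit view also lets the shift amount be a runtime value, which per-byte immediate shifts would not allow. A minimal sketch of that shift-and-select pattern follows; it assumes an ARM target with arm_neon.h, and copy_remaining and the padding/data values are made up for illustration.

// Sketch: shift an int8x8_t mask by a runtime byte count via its s64 view,
// then select padding vs. data bytes with vbsl_s8, as the fixed code does.
#include <arm_neon.h>
#include <cstdint>
#include <cstdio>

int main() {
  // Start with an all-ones mask; shifting left by 8 * copy_remaining bits
  // moves the "padding" bytes to the high end of the 64-bit view.
  int8x8_t padding_mask = vdup_n_s8(-1);
  const int copy_remaining = 3;
  padding_mask = vreinterpret_s8_s64(vshl_s64(
      vreinterpret_s64_s8(padding_mask), vdup_n_s64(8 * copy_remaining)));

  const int8x8_t padding_reg = vdup_n_s8(127);  // value used where padded
  const int8x8_t data_reg = vdup_n_s8(-5);      // value used where real data

  // vbsl_s8 takes a uint8x8_t mask, so the s8 mask is reinterpreted.
  const int8x8_t result = vbsl_s8(vreinterpret_u8_s8(padding_mask),
                                  padding_reg, data_reg);

  std::int8_t out[8];
  vst1_s8(out, result);
  for (int i = 0; i < 8; ++i) std::printf("%d ", out[i]);
  std::printf("\n");  // prints: -5 -5 -5 127 127 127 127 127
  return 0;
}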