Fix depthwise flax vector

PiperOrigin-RevId: 253918605

parent 1fef4919bb
commit ad32377d85
@@ -32,9 +32,10 @@ namespace depthwise_conv {
 // 4 8-bit lanes together. So these are treated much like 32-bit loads and
 // 32-bit stores. Stores require 32-bit alignment.
 
 #define vst1_lane_8x4(dst, reg, lane_num)                                  \
   TFLITE_DCHECK_EQ(reinterpret_cast<std::uintptr_t>(dst) % 4, 0);          \
-  vst1_lane_u32(reinterpret_cast<uint32_t*>(dst), reg, lane_num)
+  vst1_lane_s32(reinterpret_cast<int32_t*>(dst), vreinterpret_s32_s8(reg), \
+                lane_num)
 #define vst1q_lane_8x4(dst, reg, lane_num)                                 \
   TFLITE_DCHECK_EQ(reinterpret_cast<std::uintptr_t>(dst) % 4, 0);          \
   vst1q_lane_u32(reinterpret_cast<uint32_t*>(dst), reg, lane_num)
@@ -42,10 +43,12 @@ namespace depthwise_conv {
 // Important! Most compilation configurations will compile and run without
 // reinterpret_cast. Sanitizers may fail silently on lane-loading, with an
 // obscure bug or mis-feature probably in unhygienic macro expansion.
 #define vld1q_lane_s8x8(src, reg, lane_num)                                  \
-  vld1q_lane_u64(reinterpret_cast<const uint64_t*>(src), reg, lane_num)
+  vreinterpretq_s8_s64(vld1q_lane_s64(reinterpret_cast<const int64_t*>(src), \
+                                      vreinterpretq_s64_s8(reg), lane_num))
 #define vld1_lane_8x4(src, reg, lane_num)                                    \
-  vld1_lane_s32(reinterpret_cast<const int32*>(src), reg, lane_num)
+  vreinterpret_s8_s32(vld1_lane_s32(reinterpret_cast<const int32*>(src),     \
+                                    vreinterpret_s32_s8(reg), lane_num))
 #define vld1q_lane_8x4(src, reg, lane_num) \
   vld1q_lane_s32(reinterpret_cast<const int32*>(src), reg, lane_num)
 #define vld1q_dup_s8x4(src) vld1q_dup_s32(reinterpret_cast<const int32*>(src))
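For readers unfamiliar with the reinterpret-based pattern above: the rewritten vld1q_lane_s8x8 loads eight int8 values as a single 64-bit lane and then views the register as int8x16_t again, which is what keeps sanitizers happy on lane loads. A minimal standalone sketch of the same idea for lane 0 (the helper name LoadLowS8x8 is hypothetical and not part of this change; assumes an ARM NEON target):

```cpp
#include <arm_neon.h>
#include <cstdint>

// Sketch of what the rewritten macro expands to for lane 0: load 8 int8
// values as one 64-bit lane, then view the result as 16 int8 lanes again.
static inline int8x16_t LoadLowS8x8(const int8_t* src, int8x16_t reg) {
  return vreinterpretq_s8_s64(
      vld1q_lane_s64(reinterpret_cast<const int64_t*>(src),
                     vreinterpretq_s64_s8(reg), 0));
}
```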
@@ -5949,13 +5952,14 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
       int8x16_t work_reg_b;
 
       // Effect subtraction of zero-point = 128 by XOR of sign bit.
-      const uint8x16_t sign_bit = vdupq_n_u8(kSignBit);
+      const int8x16_t sign_bit = vdupq_n_s8(kSignBit);
 
       // Work through one slice, by row, at a time.
       int8* scratch_data_0 = scratch_block_data;
 
       for (int k_height = 0; k_height < block_height; ++k_height) {
-        const uint8* input_data_0 = input_block_data;
+        const int8* input_data_0 =
+            reinterpret_cast<const int8*>(input_block_data);
         int8x16_t input_data_a;
         int8x16_t input_data_b;
         int8x16_t input_data_c;
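The "XOR of sign bit" trick used with sign_bit above relies on the identity (int8)(u ^ 0x80) == u - 128 for every uint8 value u, so XOR-ing quantized uint8 data with 0x80 recenters it around a zero point of 128 without any widening arithmetic. A small scalar check of that identity (a standalone sketch, not part of the diff):

```cpp
#include <cassert>
#include <cstdint>

// Verifies that XOR with the sign bit (0x80) maps a uint8 value u with
// zero-point 128 to the int8 value u - 128, i.e. the scalar equivalent of
// veorq_s8(reg, sign_bit) in the vector code.
int main() {
  for (int u = 0; u < 256; ++u) {
    const int8_t recentered =
        static_cast<int8_t>(static_cast<uint8_t>(u) ^ 0x80u);
    assert(static_cast<int>(recentered) == u - 128);
  }
  return 0;
}
```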
@@ -5978,10 +5982,10 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
 
           //
 
-          input_data_a = vld1q_u8(input_data_0);
-          input_data_b = vld1q_u8(input_data_0 + 1 * input_depth);
-          input_data_c = vld1q_u8(input_data_0 + 2 * input_depth);
-          input_data_d = vld1q_u8(input_data_0 + 3 * input_depth);
+          input_data_a = vld1q_s8(input_data_0);
+          input_data_b = vld1q_s8(input_data_0 + 1 * input_depth);
+          input_data_c = vld1q_s8(input_data_0 + 2 * input_depth);
+          input_data_d = vld1q_s8(input_data_0 + 3 * input_depth);
           input_data_0 += 16;
 
           //
@@ -5997,8 +6001,8 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
           work_reg_b_sp = vzip2q_s8(input_data_c, input_data_d);
           vzipq_s8x2_in_place(&work_reg_a_sp, &work_reg_b_sp);
 
-          input_data_a = vld1q_u8(input_data_0);
-          input_data_b = vld1q_u8(input_data_0 + 1 * input_depth);
+          input_data_a = vld1q_s8(input_data_0);
+          input_data_b = vld1q_s8(input_data_0 + 1 * input_depth);
           optimized_ops_prefetch_write_l1_keep(scratch_data_0);
           optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
           vst1q_s8(scratch_data_0, work_reg_a);
@@ -6009,8 +6013,8 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
           work_reg_a_sp = veorq_s8(work_reg_a_sp, sign_bit);
           work_reg_b_sp = veorq_s8(work_reg_b_sp, sign_bit);
 
-          input_data_c = vld1q_u8(input_data_0 + 2 * input_depth);
-          input_data_d = vld1q_u8(input_data_0 + 3 * input_depth);
+          input_data_c = vld1q_s8(input_data_0 + 2 * input_depth);
+          input_data_d = vld1q_s8(input_data_0 + 3 * input_depth);
           optimized_ops_prefetch_write_l1_keep(scratch_data_0);
           optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
           vst1q_s8(scratch_data_0, work_reg_a_sp);
@@ -6082,9 +6086,9 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
         TFLITE_DCHECK_GT(residual_width, 0);
         TFLITE_DCHECK_LT(residual_width, 4);
         for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) {
-          input_data_c = vdupq_n_u8(kSignBit);
+          input_data_c = vdupq_n_s8(kSignBit);
           input_data_a = vld1q_lane_s8x8(input_data_0, input_data_a, 0);
-          input_data_d = vdupq_n_u8(kSignBit);
+          input_data_d = vdupq_n_s8(kSignBit);
           if (residual_width > 1) {
             input_data_b =
                 vld1q_lane_s8x8(input_data_0 + input_depth, input_data_b, 0);
@@ -6187,7 +6191,7 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
       int8x16_t work_reg_b;
 
       // Effect subtraction of zero-point = 128 by XOR of sign bit.
-      const uint8x16_t sign_bit = vdupq_n_u8(kSignBit);
+      const int8x16_t sign_bit = vdupq_n_s8(kSignBit);
 
       // Work through one slice, by row, at a time.
       int8* scratch_data_0 = scratch_block_data;
@@ -6204,7 +6208,8 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
      }
 
      for (int k_height = 0; k_height < copy_block_height; ++k_height) {
-       const uint8* input_data_0 = input_block_data;
+       const int8* input_data_0 =
+           reinterpret_cast<const int8*>(input_block_data);
        int8x16_t input_data_a;
        int8x16_t input_data_b;
        int8x16_t input_data_c;
@@ -6241,10 +6246,10 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
 
           //
 
-          input_data_a = vld1q_u8(input_data_0);
-          input_data_b = vld1q_u8(input_data_0 + 1 * input_depth);
-          input_data_c = vld1q_u8(input_data_0 + 2 * input_depth);
-          input_data_d = vld1q_u8(input_data_0 + 3 * input_depth);
+          input_data_a = vld1q_s8(input_data_0);
+          input_data_b = vld1q_s8(input_data_0 + 1 * input_depth);
+          input_data_c = vld1q_s8(input_data_0 + 2 * input_depth);
+          input_data_d = vld1q_s8(input_data_0 + 3 * input_depth);
           input_data_0 += 16;
 
           //
@@ -6260,8 +6265,8 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
           work_reg_b_sp = vzip2q_s8(input_data_c, input_data_d);
           vzipq_s8x2_in_place(&work_reg_a_sp, &work_reg_b_sp);
 
-          input_data_a = vld1q_u8(input_data_0);
-          input_data_b = vld1q_u8(input_data_0 + 1 * input_depth);
+          input_data_a = vld1q_s8(input_data_0);
+          input_data_b = vld1q_s8(input_data_0 + 1 * input_depth);
           optimized_ops_prefetch_write_l1_keep(scratch_data_0);
           optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
           vst1q_s8(scratch_data_0, work_reg_a);
@@ -6272,8 +6277,8 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
           work_reg_a_sp = veorq_s8(work_reg_a_sp, sign_bit);
           work_reg_b_sp = veorq_s8(work_reg_b_sp, sign_bit);
 
-          input_data_c = vld1q_u8(input_data_0 + 2 * input_depth);
-          input_data_d = vld1q_u8(input_data_0 + 3 * input_depth);
+          input_data_c = vld1q_s8(input_data_0 + 2 * input_depth);
+          input_data_d = vld1q_s8(input_data_0 + 3 * input_depth);
           optimized_ops_prefetch_write_l1_keep(scratch_data_0);
           optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
           vst1q_s8(scratch_data_0, work_reg_a_sp);
@@ -6341,10 +6346,10 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
       } else {
         TFLITE_DCHECK_LT(adjusted_residual_width, 4);
         for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) {
-          input_data_a = vdupq_n_u8(-input_offset);
-          input_data_b = vdupq_n_u8(-input_offset);
-          input_data_c = vdupq_n_u8(-input_offset);
-          input_data_d = vdupq_n_u8(-input_offset);
+          input_data_a = vdupq_n_s8(-input_offset);
+          input_data_b = vdupq_n_s8(-input_offset);
+          input_data_c = vdupq_n_s8(-input_offset);
+          input_data_d = vdupq_n_s8(-input_offset);
           if (adjusted_residual_width > 0) {
             input_data_a = vld1q_lane_s8x8(input_data_0, input_data_a, 0);
             if (adjusted_residual_width > 1) {
@@ -6386,10 +6391,10 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
 
           //
 
-          input_data_a = vdupq_n_u8(-input_offset);
-          input_data_b = vld1q_u8(input_data_0 + 1 * input_depth);
-          input_data_c = vld1q_u8(input_data_0 + 2 * input_depth);
-          input_data_d = vld1q_u8(input_data_0 + 3 * input_depth);
+          input_data_a = vdupq_n_s8(-input_offset);
+          input_data_b = vld1q_s8(input_data_0 + 1 * input_depth);
+          input_data_c = vld1q_s8(input_data_0 + 2 * input_depth);
+          input_data_d = vld1q_s8(input_data_0 + 3 * input_depth);
           input_data_0 += 16;
 
           //
@@ -6405,8 +6410,8 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
           work_reg_b_sp = vzip2q_s8(input_data_c, input_data_d);
           vzipq_s8x2_in_place(&work_reg_a_sp, &work_reg_b_sp);
 
-          input_data_a = vdupq_n_u8(-input_offset);
-          input_data_b = vld1q_u8(input_data_0 + 1 * input_depth);
+          input_data_a = vdupq_n_s8(-input_offset);
+          input_data_b = vld1q_s8(input_data_0 + 1 * input_depth);
           optimized_ops_prefetch_write_l1_keep(scratch_data_0);
           optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
           vst1q_s8(scratch_data_0, work_reg_a);
@@ -6417,8 +6422,8 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
           work_reg_a_sp = veorq_s8(work_reg_a_sp, sign_bit);
           work_reg_b_sp = veorq_s8(work_reg_b_sp, sign_bit);
 
-          input_data_c = vld1q_u8(input_data_0 + 2 * input_depth);
-          input_data_d = vld1q_u8(input_data_0 + 3 * input_depth);
+          input_data_c = vld1q_s8(input_data_0 + 2 * input_depth);
+          input_data_d = vld1q_s8(input_data_0 + 3 * input_depth);
           optimized_ops_prefetch_write_l1_keep(scratch_data_0);
           optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
           vst1q_s8(scratch_data_0, work_reg_a_sp);
@@ -6458,7 +6463,7 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
           scratch_data_0 += depth_advance;
         }
         for (; i_depth < depth_micro_repeats; ++i_depth) {
-          input_data_a = vdupq_n_u8(-input_offset);
+          input_data_a = vdupq_n_s8(-input_offset);
           input_data_b = vld1q_lane_s8x8(input_data_0 + 1 * input_depth,
                                          input_data_b, 0);
           input_data_c = vld1q_lane_s8x8(input_data_0 + 2 * input_depth,
@@ -6487,10 +6492,10 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
         TFLITE_DCHECK_LT(adjusted_residual_width, 4);
 
         for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) {
-          input_data_a = vdupq_n_u8(-input_offset);
-          input_data_b = vdupq_n_u8(-input_offset);
-          input_data_c = vdupq_n_u8(-input_offset);
-          input_data_d = vdupq_n_u8(-input_offset);
+          input_data_a = vdupq_n_s8(-input_offset);
+          input_data_b = vdupq_n_s8(-input_offset);
+          input_data_c = vdupq_n_s8(-input_offset);
+          input_data_d = vdupq_n_s8(-input_offset);
           // Skip loading first column.
           if (adjusted_residual_width > 1) {
             input_data_b = vld1q_lane_s8x8(input_data_0 + input_depth,
@@ -6637,14 +6642,15 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
     int8x8_t padding_mask;
 
     // Effect subtraction of zero-point = 128 by XOR of sign bit.
-    const uint8x16_t sign_bit = vdupq_n_u8(kSignBit);
-    const uint8x16_t padding_reg = vdupq_n_u8(-input_offset);
+    const int8x16_t sign_bit = vdupq_n_s8(kSignBit);
+    const int8x16_t padding_reg = vdupq_n_s8(-input_offset);
     padding_mask = vdup_n_s8(-1);
     half_work_reg = vdup_n_s8(0);
 
     if (copy_size >= 16) {
       const int copy_remaining = (copy_size + start_width) & 0x7;
-      padding_mask = vshl_u64(padding_mask, vdup_n_s64(8 * copy_remaining));
+      padding_mask = vreinterpret_s8_s64(vshl_s64(
+          vreinterpret_s64_s8(padding_mask), vdup_n_s64(8 * copy_remaining)));
 
       for (int k_height = 0; k_height < copy_block_height; ++k_height) {
         // Work through one slice, by row, at a time.
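The padding_mask manipulation above starts from an all-ones int8x8_t, shifts it left as a single 64-bit value by 8 * copy_remaining bits so that (on a little-endian target) the first copy_remaining byte lanes become zero, and later uses vbsl to substitute the padding value into the masked-off lanes. A hedged standalone sketch of that selection step (PadTail is a hypothetical helper, not part of this change):

```cpp
#include <arm_neon.h>
#include <cstdint>

// Keep the low `valid_bytes` lanes of `loaded` and replace the rest with
// `padding_value`, using the same mask-by-64-bit-shift plus vbsl approach
// as the code above. Assumes a little-endian ARM NEON target.
static inline int8x8_t PadTail(int8x8_t loaded, int valid_bytes,
                               int8_t padding_value) {
  int8x8_t mask = vdup_n_s8(-1);  // All bits set.
  // Shifting left by 8 * valid_bytes zeroes the low (valid) byte lanes.
  mask = vreinterpret_s8_s64(
      vshl_s64(vreinterpret_s64_s8(mask), vdup_n_s64(8 * valid_bytes)));
  // Where the mask is set, take the padding value; elsewhere keep the data.
  return vbsl_s8(vreinterpret_u8_s8(mask), vdup_n_s8(padding_value), loaded);
}
```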
@@ -6656,7 +6662,8 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
         // iteration of the main copy loop. In the case of leading width
         // padding, we unroll this specially.
         if (leading_width_padding) {
-          work_reg = vld1q_u8(input_block_data + input_block_offset);
+          work_reg = vld1q_s8(reinterpret_cast<const int8*>(
+              input_block_data + input_block_offset));
           work_reg = vextq_s8(padding_reg, work_reg, 15);
           work_reg = veorq_s8(work_reg, sign_bit);
           optimized_ops_prefetch_write_l1_keep(scratch_data);
@@ -6666,8 +6673,8 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
 
         // Main copy loop.
         for (; (copy_done + 16) <= copy_size; copy_done += 16) {
-          work_reg =
-              vld1q_u8(input_block_data + input_block_offset + copy_done);
+          work_reg = vld1q_s8(reinterpret_cast<const int8*>(
+              input_block_data + input_block_offset + copy_done));
           work_reg = veorq_s8(work_reg, sign_bit);
           TFLITE_DCHECK_EQ((start_width + copy_done) % 16, 0);
           optimized_ops_prefetch_write_l1_keep(scratch_data + start_width +
@@ -6676,8 +6683,8 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
         }
 
         if (copy_done + 8 <= copy_size) {
-          half_work_reg =
-              vld1_u8(input_block_data + input_block_offset + copy_done);
+          half_work_reg = vld1_s8(reinterpret_cast<const int8*>(
+              input_block_data + input_block_offset + copy_done));
           half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit));
           TFLITE_DCHECK_EQ((start_width + copy_done) % 8, 0);
           optimized_ops_prefetch_write_l1_keep(scratch_data + start_width +
@@ -6698,13 +6705,14 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
           // Employ overlapping-load strategy in order to load full register,
           // but use only part.
           // This has the advantage of resulting in zeros after shifting.
-          half_work_reg =
-              vld1_u8(input_block_data + input_block_offset + copy_size - 8);
+          half_work_reg = vld1_s8(reinterpret_cast<const int8*>(
+              input_block_data + input_block_offset + copy_size - 8));
 
-          half_work_reg =
-              vshl_u64(half_work_reg, vdup_n_s64(-8 * (8 - copy_remaining)));
-          half_work_reg =
-              vbsl_s8(padding_mask, vget_low_s8(padding_reg), half_work_reg);
+          half_work_reg = vreinterpret_s8_s64(
+              vshl_s64(vreinterpret_s64_s8(half_work_reg),
+                       vdup_n_s64(-8 * (8 - copy_remaining))));
+          half_work_reg = vbsl_s8(vreinterpret_u8_s8(padding_mask),
+                                  vget_low_s8(padding_reg), half_work_reg);
 
           half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit));
           TFLITE_DCHECK_EQ((start_width + copy_done) % 8, 0);
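The overlapping-load strategy in the hunk above reads the last 8 bytes of the row, overlapping bytes that were already copied, and then shifts the register right as a 64-bit value so only the copy_remaining genuinely new bytes survive in the low lanes; the masked-off upper lanes are then filled via the padding mask. A standalone sketch under the same little-endian assumption (LoadTail is a hypothetical name; this sketch routes the shift through the unsigned 64-bit view so the vacated lanes are zero-filled):

```cpp
#include <arm_neon.h>
#include <cstdint>

// Load the final `remaining` bytes of a row without reading past its end:
// load the last 8 bytes, then shift right (as a 64-bit value) so the new
// bytes land in the low lanes and the upper lanes become zero.
static inline int8x8_t LoadTail(const int8_t* row_end_minus_8, int remaining) {
  const int8x8_t overlapped = vld1_s8(row_end_minus_8);
  return vreinterpret_s8_u64(
      vshl_u64(vreinterpret_u64_s8(overlapped),
               vdup_n_s64(-8 * (8 - remaining))));
}
```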
@@ -6726,7 +6734,8 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
      }
    } else if (copy_size >= 4) {
      const int copy_remaining = (copy_size + start_width) & 0x3;
-     padding_mask = vshl_u64(padding_mask, vdup_n_s64(8 * copy_remaining));
+     padding_mask = vreinterpret_s8_s64(vshl_s64(
+         vreinterpret_s64_s8(padding_mask), vdup_n_s64(8 * copy_remaining)));
 
      for (int k_height = 0; k_height < copy_block_height; ++k_height) {
        // Work through one slice, by row, at a time.
@@ -6777,10 +6786,11 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
              input_block_data + input_block_offset + copy_size - 4,
              half_work_reg, 0);
 
-         half_work_reg =
-             vshl_u64(half_work_reg, vdup_n_s64(-8 * (4 - copy_remaining)));
-         half_work_reg =
-             vbsl_s8(padding_mask, vget_low_s8(padding_reg), half_work_reg);
+         half_work_reg = vreinterpret_s8_s64(
+             vshl_s64(vreinterpret_s64_s8(half_work_reg),
+                      vdup_n_s64(-8 * (4 - copy_remaining))));
+         half_work_reg = vbsl_s8(vreinterpret_u8_s8(padding_mask),
+                                 vget_low_s8(padding_reg), half_work_reg);
 
         half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit));
         TFLITE_DCHECK_EQ((start_width + copy_done) % 4, 0);
@@ -6815,7 +6825,7 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
      TFLITE_DCHECK(trailing_width_padding);
 
      for (int k_height = 0; k_height < copy_block_height; ++k_height) {
-       half_work_reg = vdup_n_u8(-input_offset);
+       half_work_reg = vdup_n_s8(-input_offset);
        half_work_reg = vld1_lane_s8(reinterpret_cast<const int8*>(
                                         input_block_data + input_block_offset),
                                     half_work_reg, 1);
@@ -6854,24 +6864,27 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
    } else {
      TFLITE_DCHECK_EQ(width_overall_micro_repeats, 1);
      const int copy_remaining = (copy_size + start_width) & 0x3;
-     padding_mask = vshl_u64(padding_mask, vdup_n_s64(8 * copy_remaining));
+     padding_mask = vreinterpret_s8_s64(vshl_s64(
+         vreinterpret_s64_s8(padding_mask), vdup_n_s64(8 * copy_remaining)));
      if (leading_width_padding) {
-       padding_mask = vset_lane_u8(255, padding_mask, 0);
+       padding_mask = vset_lane_s8(255, padding_mask, 0);
      }
 
      for (int k_height = 0; k_height < copy_block_height; ++k_height) {
        for (int i = 0; i < copy_size; ++i) {
-         half_work_reg = vshl_n_u64(half_work_reg, 8);
+         half_work_reg = vreinterpret_s8_s64(
+             vshl_n_s64(vreinterpret_s64_s8(half_work_reg), 8));
          half_work_reg = vld1_lane_s8(
              reinterpret_cast<const int8*>(
                  input_block_data + input_block_offset + copy_size - 1 - i),
              half_work_reg, 0);
        }
        if (leading_width_padding) {
-         half_work_reg = vshl_n_s64(half_work_reg, 8);
+         half_work_reg = vreinterpret_s8_s64(
+             vshl_n_s64(vreinterpret_s64_s8(half_work_reg), 8));
        }
-       half_work_reg =
-           vbsl_s8(padding_mask, vget_low_s8(padding_reg), half_work_reg);
+       half_work_reg = vbsl_s8(vreinterpret_u8_s8(padding_mask),
+                               vget_low_s8(padding_reg), half_work_reg);
 
        half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit));
        TFLITE_DCHECK_EQ(scratch_data_offset % 4, 0);
@@ -6976,7 +6989,7 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
     int8x8_t half_work_reg;
 
     // Effect subtraction of zero-point = 128 by XOR of sign bit.
-    const uint8x16_t sign_bit = vdupq_n_u8(kSignBit);
+    const int8x16_t sign_bit = vdupq_n_s8(kSignBit);
     half_work_reg = vdup_n_s8(0);
 
     if (copy_size >= 16) {
@@ -6990,8 +7003,8 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
 
       // Main copy loop.
       for (; (copy_done + 16) <= copy_size; copy_done += 16) {
-        work_reg =
-            vld1q_u8(input_block_data + input_block_offset + copy_done);
+        work_reg = vld1q_s8(reinterpret_cast<const int8*>(
+            input_block_data + input_block_offset + copy_done));
         work_reg = veorq_s8(work_reg, sign_bit);
         TFLITE_DCHECK_EQ(copy_done % 16, 0);
         optimized_ops_prefetch_write_l1_keep(scratch_data + copy_done);
@@ -6999,8 +7012,8 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
       }
 
      if (copy_done + 8 <= copy_size) {
-       half_work_reg =
-           vld1_u8(input_block_data + input_block_offset + copy_done);
+       half_work_reg = vld1_s8(reinterpret_cast<const int8*>(
+           input_block_data + input_block_offset + copy_done));
        half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit));
        TFLITE_DCHECK_EQ(copy_done % 8, 0);
        optimized_ops_prefetch_write_l1_keep(scratch_data + copy_done);
@@ -7020,11 +7033,12 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
        // Employ overlapping-load strategy in order to load full register,
        // but use only part.
        // This has the advantage of resulting in zeros after shifting.
-       half_work_reg =
-           vld1_u8(input_block_data + input_block_offset + copy_size - 8);
+       half_work_reg = vld1_s8(reinterpret_cast<const int8*>(
+           input_block_data + input_block_offset + copy_size - 8));
 
-       half_work_reg =
-           vshl_u64(half_work_reg, vdup_n_s64(-8 * (8 - copy_remaining)));
+       half_work_reg = vreinterpret_s8_s64(
+           vshl_s64(vreinterpret_s64_s8(half_work_reg),
+                    vdup_n_s64(-8 * (8 - copy_remaining))));
 
        half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit));
        TFLITE_DCHECK_EQ(copy_done % 8, 0);
||||||
@ -7079,8 +7093,9 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
|
|||||||
input_block_data + input_block_offset + copy_size - 4,
|
input_block_data + input_block_offset + copy_size - 4,
|
||||||
half_work_reg, 0);
|
half_work_reg, 0);
|
||||||
|
|
||||||
half_work_reg =
|
half_work_reg = vreinterpret_s8_s64(
|
||||||
vshl_u64(half_work_reg, vdup_n_s64(-8 * (4 - copy_remaining)));
|
vshl_s64(vreinterpret_s64_s8(half_work_reg),
|
||||||
|
vdup_n_s64(-8 * (4 - copy_remaining))));
|
||||||
|
|
||||||
half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit));
|
half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit));
|
||||||
TFLITE_DCHECK_EQ(copy_done % 4, 0);
|
TFLITE_DCHECK_EQ(copy_done % 4, 0);
|
||||||
@ -7104,7 +7119,8 @@ struct PackMacroBlock<DepthwiseConvImplementation::kUseNeon3x3DotProduct,
|
|||||||
|
|
||||||
for (int k_height = 0; k_height < copy_block_height; ++k_height) {
|
for (int k_height = 0; k_height < copy_block_height; ++k_height) {
|
||||||
for (int i = 0; i < copy_size; ++i) {
|
for (int i = 0; i < copy_size; ++i) {
|
||||||
half_work_reg = vshl_n_u64(half_work_reg, 8);
|
half_work_reg = vreinterpret_s8_s64(
|
||||||
|
vshl_n_s64(vreinterpret_s64_s8(half_work_reg), 8));
|
||||||
half_work_reg = vld1_lane_s8(
|
half_work_reg = vld1_lane_s8(
|
||||||
reinterpret_cast<const int8*>(
|
reinterpret_cast<const int8*>(
|
||||||
input_block_data + input_block_offset + copy_size - 1 - i),
|
input_block_data + input_block_offset + copy_size - 1 - i),
|
||||||