From 5c4b601d2d29b7e1616194463d8444ed8b3977cb Mon Sep 17 00:00:00 2001
From: Robert David
Date: Wed, 20 Jan 2021 16:52:59 -0800
Subject: [PATCH] Define TFLITE_UNLIKELY using __builtin_expect if available.
 Annotate all postamble loops in neon_tensor_utils with TFLITE_UNLIKELY.

PiperOrigin-RevId: 352904916
Change-Id: I2067a32af737cf8fc07850ed7e96e3c0fbc5809a
---
 .../internal/optimized/neon_tensor_utils.cc  | 88 +++++++++++--------
 1 file changed, 50 insertions(+), 38 deletions(-)

diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 565ccb4eec5..405cde0373a 100644
--- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -45,6 +45,21 @@ limitations under the License.
 #endif
 #endif
 
+// Note: This is the same as ABSL_HAVE_BUILTIN, but can't include the header.
+#ifdef __has_builtin
+#define TFLITE_HAS_BUILTIN(x) __has_builtin(x)
+#else
+#define TFLITE_HAS_BUILTIN(x) 0
+#endif
+
+// Note: This is the same as ABSL_PREDICT_FALSE, but can't include the header.
+#if TFLITE_HAS_BUILTIN(__builtin_expect) || \
+    (defined(__GNUC__) && !defined(__clang__))
+#define TFLITE_UNLIKELY(x) (__builtin_expect(false || (x), false))
+#else
+#define TFLITE_UNLIKELY(x) (x)
+#endif
+
 namespace tflite {
 namespace tensor_utils {
 namespace {
@@ -209,7 +224,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
       // Add the 4 intermediate sum values to get the final dot-prod value for
       // this column.
       *result_in_batch += AccumulateNeonLane(acc_32x4);
-      for (; c < m_cols; c++) {
+      for (; TFLITE_UNLIKELY(c < m_cols); c++) {
         *result_in_batch += matrix_row[c] * vector_in_batch[c];
       }
       matrix_row += m_cols;
@@ -802,8 +817,7 @@ void NeonMatrixBatchVectorMultiplyImpl(const int8_t* input, const int32_t* bias,
       }  // for col
 
       // Half iteration dealing only 8 elements
-      // TODO(raziel): if (ABSL_PREDICT_FALSE(col < postamble_start))
-      if (col < postamble_start) {
+      if (TFLITE_UNLIKELY(col < postamble_start)) {
         // Load 8 8-bit values from the row and column each to operate on.
         // Here the assumption is that each buffer is 4-bytes aligned.
         // Otherwise, performance may suffer significantly.
@@ -819,8 +833,7 @@ void NeonMatrixBatchVectorMultiplyImpl(const int8_t* input, const int32_t* bias,
       // this row.
       int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
       // Postamble loop.
-      // TODO(raziel): if (ABSL_PREDICT_FALSE(col < m_cols))
-      for (; col < n_input; ++col) {
+      for (; TFLITE_UNLIKELY(col < n_input); ++col) {
         dotprod += row_ptr[col] * aligned_vec[col];
       }  // for col
 
@@ -874,7 +887,7 @@ inline void NeonMatrixBatchVectorAccumulateImpl(
         vcombine_s16(vqmovn_s32(temp_val.val[0]), vqmovn_s32(temp_val.val[1]));
     vst1q_s16(output + i, result);
   }
-  for (; i < total_size; ++i) {
+  for (; TFLITE_UNLIKELY(i < total_size); ++i) {
     int32_t temp = MultiplyByQuantizedMultiplier(scratch[i], multiplier, shift);
     temp += output_zp;
     temp += output[i];
@@ -948,7 +961,7 @@ inline void NeonMatrixBatchVectorAccumulateImpl(
         vcombine_s8(vqmovn_s16(result_1), vqmovn_s16(result_2));
     vst1q_s8(output + i, result);
   }
-  for (; i < total_size; ++i) {
+  for (; TFLITE_UNLIKELY(i < total_size); ++i) {
     int32_t temp = MultiplyByQuantizedMultiplier(scratch[i], multiplier, shift);
     temp += output_zp;
     temp += output[i];
@@ -1128,8 +1141,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix,
       }  // for col
 
      // Half iteration dealing only 8 elements
-      // TODO(raziel): if (ABSL_PREDICT_FALSE(col < postamble_start))
-      if (col < postamble_start) {
+      if (TFLITE_UNLIKELY(col < postamble_start)) {
        // Load 8 8-bit values from the row and column each to operate on.
        // Here the assumption is that each buffer is 4-bytes aligned.
        // Otherwise, performance may suffer significantly.
@@ -1145,8 +1157,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix,
       // this row.
       int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
       // Postamble loop.
-      // TODO(raziel): if (ABSL_PREDICT_FALSE(col < m_cols))
-      for (; col < m_cols; ++col) {
+      for (; TFLITE_UNLIKELY(col < m_cols); ++col) {
         dotprod += row_ptr[col] * aligned_vec[col];
       }  // for col
 
@@ -1193,7 +1204,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix,
     vst1q_f32(result + 4, result1);
   }
   scratch += i;
-  for (; i < total_size; i++) {
+  for (; TFLITE_UNLIKELY(i < total_size); i++) {
     const float batch_scaling_factor = scaling_factors[i / m_rows];
     int32_t x = *(scratch++);
     *result += x * batch_scaling_factor;
@@ -1221,7 +1232,7 @@ void NeonMatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar,
       row_sum = vpadalq_s16(row_sum, temp);
     }
     int32_t sum = AccumulateNeonLane(row_sum);
-    for (; j < n_col; ++j) {
+    for (; TFLITE_UNLIKELY(j < n_col); ++j) {
       sum += *(row_ptr + j);
     }
     output[i] += sum * scalar;
@@ -1322,7 +1333,7 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
       }  // for col
 
      // Half iteration dealing only 8 elements
-      if (col < postamble_start) {
+      if (TFLITE_UNLIKELY(col < postamble_start)) {
        // Load 8 8-bit values from the row and column each to operate on.
        // Here the assumption is that each buffer is 4-bytes aligned.
        // Otherwise, performance may suffer significantly.
@@ -1338,7 +1349,7 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
       int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
       // Postamble loop.
-      for (; col < m_cols; ++col) {
+      for (; TFLITE_UNLIKELY(col < m_cols); ++col) {
         dotprod += row_ptr[col] * aligned_vec[col];
       }  // for col
       dotprod -= row_sums_ptr[row] * batch_input_offset;
@@ -1432,7 +1443,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
   }
 
   scratch_ptr += i;
-  for (; i < total_size; i++) {
+  for (; TFLITE_UNLIKELY(i < total_size); i++) {
     float batch_scaling_factor = scaling_factors[i / m_rows];
     if (per_channel_scale) {
       batch_scaling_factor *= per_channel_scale[i % m_rows];
@@ -1503,7 +1514,7 @@ void NeonApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
       sum_sq += static_cast<int64_t>(
           AccumulateNeonLane(vmulq_s32(val_s32_1, val_s32_1)));
     }
-    for (; j < n_input; ++j) {
+    for (; TFLITE_UNLIKELY(j < n_input); ++j) {
       const int32 index = i * n_input + j;
       int32 val = static_cast<int32_t>(input[index]);
       sum += val;
@@ -1589,7 +1600,7 @@ void NeonApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
       vst1_s16(output + index + 8, vqmovn_s32(val5_s32.val[2]));
       vst1_s16(output + index + 12, vqmovn_s32(val5_s32.val[3]));
     }
-    for (; j < n_input; ++j) {
+    for (; TFLITE_UNLIKELY(j < n_input); ++j) {
       const int32 index = i * n_input + j;
       int32 val = static_cast<int32_t>(input[index]);
       int32 shifted = 1024 * val - mean;
@@ -1727,7 +1738,7 @@ void NeonCwiseMul(const int16_t* input_1, const int16_t* input_2, int n_batch,
       const int16x8_t result = vcombine_s16(vmovn_s32(x_0), vmovn_s32(x_1));
       vst1q_s16(output + index, result);
     }
-    for (; i < n_input; ++i) {
+    for (; TFLITE_UNLIKELY(i < n_input); ++i) {
       const int index = batch * n_input + i;
       const int16_t a = input_1[index];
       const int16_t b = input_2[index];
@@ -1780,7 +1791,7 @@ void NeonCwiseMul(const int16_t* input_1, const int16_t* input_2,
           vcombine_s16(vmovn_s32(temp_val.val[0]), vmovn_s32(temp_val.val[1]));
       vst1_s8(output + index, vmovn_s16(result));
     }
-    for (; i < n_input; ++i) {
+    for (; TFLITE_UNLIKELY(i < n_input); ++i) {
       const int index = batch * n_input + i;
       const int16_t a = input_1[index];
       const int16_t b = input_2[index];
@@ -1814,7 +1825,7 @@ void NeonCwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch,
       vst1_s16(output + index, vqmovn_s32(sum_0));
       vst1_s16(output + index + 4, vqmovn_s32(sum_1));
     }
-    for (; i < n_input; ++i) {
+    for (; TFLITE_UNLIKELY(i < n_input); ++i) {
       const int index = batch * n_input + i;
       int32_t sum = input_1[index] + input_2[index];
       const int32 sum_clamped = std::min(int16_max, std::max(int16_min, sum));
@@ -1839,7 +1850,7 @@ void NeonCwiseClipping(float* vector, const int v_size,
     // Save to output.
     vst1q_f32(vector + i, v_f32x4);
   }
-  for (; i < v_size; i++) {
+  for (; TFLITE_UNLIKELY(i < v_size); i++) {
     vector[i] = std::max(std::min(clipping_value, vector[i]), -clipping_value);
   }
 }
@@ -1861,7 +1872,7 @@ void NeonCwiseClipping(int16_t* vector, const int v_size,
     vst1q_s16(vector + i, val_0);
     vst1q_s16(vector + i + kInt16ValuesPerNeonVector, val_1);
   }
-  for (; i < v_size; i++) {
+  for (; TFLITE_UNLIKELY(i < v_size); i++) {
     vector[i] = std::max(std::min(clipping_value, vector[i]),
                          static_cast<int16_t>(-clipping_value));
   }
@@ -1884,7 +1895,7 @@ void NeonCwiseClipping(int8_t* vector, const int v_size,
     vst1q_s8(vector + i, val_0);
     vst1q_s8(vector + i + kInt8ValuesPerNeonVector, val_1);
   }
-  for (; i < v_size; i++) {
+  for (; TFLITE_UNLIKELY(i < v_size); i++) {
     vector[i] = std::max(std::min(clipping_value, vector[i]),
                          static_cast<int8_t>(-clipping_value));
   }
@@ -2049,7 +2060,7 @@ void NeonSub1Vector(const float* vector, int v_size, float* result) {
     // Save to output.
     vst1q_f32(result + v, result_f32x4);
   }
-  for (; v < v_size; v++) {
+  for (; TFLITE_UNLIKELY(v < v_size); v++) {
     result[v] = 1.0f - vector[v];
   }
 }
@@ -2066,7 +2077,7 @@ void NeonSub1Vector(const int16_t* vector, int v_size, int16_t* result) {
     const int16x8_t sub1_result = veorq_s16(one_dup, input);
     vst1q_s16(result + i, sub1_result);
   }
-  for (; i < v_size; i++) {
+  for (; TFLITE_UNLIKELY(i < v_size); i++) {
     result[i] = kOne ^ vector[i];
   }
 }
@@ -2121,7 +2132,7 @@ bool NeonIsZeroVector(const float* vector, int v_size) {
     if (!IsAllZero(v_f32x4)) return false;
   }
   // Postamble loop
-  for (; v < v_size; ++v) {
+  for (; TFLITE_UNLIKELY(v < v_size); ++v) {
     if (vector[v] != 0.0) return false;
   }
   return true;
@@ -2140,7 +2151,7 @@ bool NeonIsZeroVector(const int8_t* vector, int v_size) {
     if (!IsAllZero(v_s8x16)) return false;
   }
   // Postamble loop
-  for (; v < v_size; ++v) {
+  for (; TFLITE_UNLIKELY(v < v_size); ++v) {
     if (vector[v] != 0) return false;
   }
   return true;
@@ -2190,7 +2201,8 @@ void NeonVectorScalarMultiply(const int8_t* vector, const int v_size,
     vst1q_f32(result + v + 12, v3_f32x4);
   }
 
-  if (v_size - postamble_start >= (kInt8ValuesPerNeonVector >> 1)) {
+  if (TFLITE_UNLIKELY(v_size - postamble_start >=
+                      (kInt8ValuesPerNeonVector >> 1))) {
     // Load eight int8 values, if there is at least eight remaining.
     const int8x8_t v_i8x8 = vld1_s8(vector + v);
     // Convert them to int16 first.
@@ -2211,7 +2223,7 @@ void NeonVectorScalarMultiply(const int8_t* vector, const int v_size,
   }
 
   // Postamble loop.
-  for (; v < v_size; v++) {
+  for (; TFLITE_UNLIKELY(v < v_size); v++) {
     result[v] = scale * vector[v];
   }
 }
@@ -2263,7 +2275,7 @@ inline void NeonMinMax(const float* values, const int size, float* min,
     rmax = std::max(rmax, vget_lane_f32(max_f32x2, 0));
 #endif  // __aarch64__
   }
-  if (i < size) {
+  if (TFLITE_UNLIKELY(i < size)) {
     const auto minmax =
         std::minmax_element(values + postamble_start, values + size);
     rmin = std::min(rmin, *minmax.first);
@@ -2335,7 +2347,7 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size,
     vst1_s8(&quantized_values[i], min_s8x8);
   }
 
-  for (; i < size; ++i) {
+  for (; TFLITE_UNLIKELY(i < size); ++i) {
     const int32 quantized_value =
         static_cast<int32>(TfLiteRound(scaling_factor_inv * values[i]));
     quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
@@ -2418,7 +2430,7 @@ void NeonAsymmetricQuantizeFloats(const float* values, const int size,
     vst1_s8(&quantized_values[i], min_s8x8);
   }
 
-  for (; i < size; ++i) {
+  for (; TFLITE_UNLIKELY(i < size); ++i) {
     const int32 quantized_value = static_cast<int32>(
         *offset + TfLiteRound(scaling_factor_inv * values[i]));
     quantized_values[i] =
@@ -2444,7 +2456,7 @@ float NeonVectorVectorDotProduct(const float* vector1, const float* vector2,
   }
   float result = AccumulateNeonLane(acc_32x4);
   // Postamble loop.
-  for (; v < v_size; v++) {
+  for (; TFLITE_UNLIKELY(v < v_size); v++) {
     result += vector1[v] * vector2[v];
   }
   return result;
@@ -2466,7 +2478,7 @@ void NeonReductionSumVector(const float* input_vector, float* output_vector,
     }
     float sum = AccumulateNeonLane(sum_f32x4);
     // Postamble loop.
-    for (; r < reduction_size; r++) {
+    for (; TFLITE_UNLIKELY(r < reduction_size); r++) {
       sum += input_vector[r];
     }
     output_vector[o] = sum;
@@ -2487,13 +2499,13 @@ void NeonReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
       const int8x16_t s2_8x16 = vld1q_s8(input_vector + r);
       sum_32x4 = vpadalq_s16(sum_32x4, vpaddlq_s8(s2_8x16));
     }
-    if (r < postamble_start) {
+    if (TFLITE_UNLIKELY(r < postamble_start)) {
      const int8x8_t s2_8x8 = vld1_s8(input_vector + r);
      sum_32x4 = vpadalq_s16(sum_32x4, vmovl_s8(s2_8x8));
      r += (kInt8ValuesPerNeonVector >> 1);
     }
     int32_t sum = AccumulateNeonLane(sum_32x4);
-    for (; r < reduction_size; ++r) {
+    for (; TFLITE_UNLIKELY(r < reduction_size); ++r) {
       sum += input_vector[r];
     }
     output_vector[o] = sum;
@@ -2551,7 +2563,7 @@ void NeonVectorBatchVectorCwiseProductAccumulate(
     result += 16;
   }
 
-  for (; v < v_size; v++) {
+  for (; TFLITE_UNLIKELY(v < v_size); v++) {
     int32_t prod = vector[v] * *batch_vector++;
     prod = MultiplyByQuantizedMultiplier(prod, multiplier, shift);
     int32_t output = prod + *result;