Define TFLITE_UNLIKELY using __builtin_expect if available.

Annotate all postamble loops in neon_tensor_utils with TFLITE_UNLIKELY.

PiperOrigin-RevId: 352904916
Change-Id: I2067a32af737cf8fc07850ed7e96e3c0fbc5809a
Robert David 2021-01-20 16:52:59 -08:00 committed by TensorFlower Gardener
parent 03acd8ec0b
commit 5c4b601d2d
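For context, the new macro is the standard __builtin_expect idiom: where the builtin is available, the condition guarding a scalar postamble is hinted as unlikely to hold, so the compiler favors the vectorized main path when laying out branches. Below is a minimal standalone sketch of the pattern this change applies; it repeats the macro definitions added in the diff and uses them in a plain-C++ Sub1Vector helper that stands in for the NEON routines (the helper itself is illustrative only and not part of this commit).

// Sketch only: mirrors the TFLITE_UNLIKELY definition added by this change.
#ifdef __has_builtin
#define TFLITE_HAS_BUILTIN(x) __has_builtin(x)
#else
#define TFLITE_HAS_BUILTIN(x) 0
#endif

#if TFLITE_HAS_BUILTIN(__builtin_expect) || \
    (defined(__GNUC__) && !defined(__clang__))
// `false || (x)` coerces the argument to bool before passing it to the hint.
#define TFLITE_UNLIKELY(x) (__builtin_expect(false || (x), false))
#else
#define TFLITE_UNLIKELY(x) (x)
#endif

// Computes result[i] = 1.0f - vector[i]. The main loop handles four elements
// per iteration (standing in for the NEON-vectorized body); the postamble runs
// at most three times, so its condition is hinted as unlikely to be true.
void Sub1Vector(const float* vector, int v_size, float* result) {
  int v = 0;
  for (; v <= v_size - 4; v += 4) {
    result[v + 0] = 1.0f - vector[v + 0];
    result[v + 1] = 1.0f - vector[v + 1];
    result[v + 2] = 1.0f - vector[v + 2];
    result[v + 3] = 1.0f - vector[v + 3];
  }
  // Postamble loop: taken only when v_size is not a multiple of four.
  for (; TFLITE_UNLIKELY(v < v_size); v++) {
    result[v] = 1.0f - vector[v];
  }
}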

@@ -45,6 +45,21 @@ limitations under the License.
#endif
#endif
+// Note: This is the same as ABSL_HAVE_BUILTIN, but can't include the header.
+#ifdef __has_builtin
+#define TFLITE_HAS_BUILTIN(x) __has_builtin(x)
+#else
+#define TFLITE_HAS_BUILTIN(x) 0
+#endif
+// Note: This is the same as ABSL_PREDICT_FALSE, but can't include the header.
+#if TFLITE_HAS_BUILTIN(__builtin_expect) || \
+    (defined(__GNUC__) && !defined(__clang__))
+#define TFLITE_UNLIKELY(x) (__builtin_expect(false || (x), false))
+#else
+#define TFLITE_UNLIKELY(x) (x)
+#endif
namespace tflite {
namespace tensor_utils {
namespace {
@@ -209,7 +224,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
// Add the 4 intermediate sum values to get the final dot-prod value for
// this column.
*result_in_batch += AccumulateNeonLane(acc_32x4);
-for (; c < m_cols; c++) {
+for (; TFLITE_UNLIKELY(c < m_cols); c++) {
*result_in_batch += matrix_row[c] * vector_in_batch[c];
}
matrix_row += m_cols;
@@ -802,8 +817,7 @@ void NeonMatrixBatchVectorMultiplyImpl(const int8_t* input, const int32_t* bias,
} // for col
// Half iteration dealing only 8 elements
-// TODO(raziel): if (ABSL_PREDICT_FALSE(col < postamble_start))
-if (col < postamble_start) {
+if (TFLITE_UNLIKELY(col < postamble_start)) {
// Load 8 8-bit values from the row and column each to operate on.
// Here the assumption is that each buffer is 4-bytes aligned.
// Otherwise, performance may suffer significantly.
@@ -819,8 +833,7 @@ void NeonMatrixBatchVectorMultiplyImpl(const int8_t* input, const int32_t* bias,
// this row.
int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
// Postamble loop.
-// TODO(raziel): if (ABSL_PREDICT_FALSE(col < m_cols))
-for (; col < n_input; ++col) {
+for (; TFLITE_UNLIKELY(col < n_input); ++col) {
dotprod += row_ptr[col] * aligned_vec[col];
} // for col
@@ -874,7 +887,7 @@ inline void NeonMatrixBatchVectorAccumulateImpl(
vcombine_s16(vqmovn_s32(temp_val.val[0]), vqmovn_s32(temp_val.val[1]));
vst1q_s16(output + i, result);
}
-for (; i < total_size; ++i) {
+for (; TFLITE_UNLIKELY(i < total_size); ++i) {
int32_t temp = MultiplyByQuantizedMultiplier(scratch[i], multiplier, shift);
temp += output_zp;
temp += output[i];
@@ -948,7 +961,7 @@ inline void NeonMatrixBatchVectorAccumulateImpl(
vcombine_s8(vqmovn_s16(result_1), vqmovn_s16(result_2));
vst1q_s8(output + i, result);
}
-for (; i < total_size; ++i) {
+for (; TFLITE_UNLIKELY(i < total_size); ++i) {
int32_t temp = MultiplyByQuantizedMultiplier(scratch[i], multiplier, shift);
temp += output_zp;
temp += output[i];
@@ -1128,8 +1141,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix,
} // for col
// Half iteration dealing only 8 elements
-// TODO(raziel): if (ABSL_PREDICT_FALSE(col < postamble_start))
-if (col < postamble_start) {
+if (TFLITE_UNLIKELY(col < postamble_start)) {
// Load 8 8-bit values from the row and column each to operate on.
// Here the assumption is that each buffer is 4-bytes aligned.
// Otherwise, performance may suffer significantly.
@@ -1145,8 +1157,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix,
// this row.
int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
// Postamble loop.
-// TODO(raziel): if (ABSL_PREDICT_FALSE(col < m_cols))
-for (; col < m_cols; ++col) {
+for (; TFLITE_UNLIKELY(col < m_cols); ++col) {
dotprod += row_ptr[col] * aligned_vec[col];
} // for col
@@ -1193,7 +1204,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix,
vst1q_f32(result + 4, result1);
}
scratch += i;
-for (; i < total_size; i++) {
+for (; TFLITE_UNLIKELY(i < total_size); i++) {
const float batch_scaling_factor = scaling_factors[i / m_rows];
int32_t x = *(scratch++);
*result += x * batch_scaling_factor;
@@ -1221,7 +1232,7 @@ void NeonMatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar,
row_sum = vpadalq_s16(row_sum, temp);
}
int32_t sum = AccumulateNeonLane(row_sum);
-for (; j < n_col; ++j) {
+for (; TFLITE_UNLIKELY(j < n_col); ++j) {
sum += *(row_ptr + j);
}
output[i] += sum * scalar;
@@ -1322,7 +1333,7 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
} // for col
// Half iteration dealing only 8 elements
-if (col < postamble_start) {
+if (TFLITE_UNLIKELY(col < postamble_start)) {
// Load 8 8-bit values from the row and column each to operate on.
// Here the assumption is that each buffer is 4-bytes aligned.
// Otherwise, performance may suffer significantly.
@@ -1338,7 +1349,7 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
// Postamble loop.
-for (; col < m_cols; ++col) {
+for (; TFLITE_UNLIKELY(col < m_cols); ++col) {
dotprod += row_ptr[col] * aligned_vec[col];
} // for col
dotprod -= row_sums_ptr[row] * batch_input_offset;
@@ -1432,7 +1443,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
}
scratch_ptr += i;
-for (; i < total_size; i++) {
+for (; TFLITE_UNLIKELY(i < total_size); i++) {
float batch_scaling_factor = scaling_factors[i / m_rows];
if (per_channel_scale) {
batch_scaling_factor *= per_channel_scale[i % m_rows];
@@ -1503,7 +1514,7 @@ void NeonApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
sum_sq += static_cast<int64_t>(
AccumulateNeonLane(vmulq_s32(val_s32_1, val_s32_1)));
}
-for (; j < n_input; ++j) {
+for (; TFLITE_UNLIKELY(j < n_input); ++j) {
const int32 index = i * n_input + j;
int32 val = static_cast<int32_t>(input[index]);
sum += val;
@@ -1589,7 +1600,7 @@ void NeonApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
vst1_s16(output + index + 8, vqmovn_s32(val5_s32.val[2]));
vst1_s16(output + index + 12, vqmovn_s32(val5_s32.val[3]));
}
-for (; j < n_input; ++j) {
+for (; TFLITE_UNLIKELY(j < n_input); ++j) {
const int32 index = i * n_input + j;
int32 val = static_cast<int32_t>(input[index]);
int32 shifted = 1024 * val - mean;
@@ -1727,7 +1738,7 @@ void NeonCwiseMul(const int16_t* input_1, const int16_t* input_2, int n_batch,
const int16x8_t result = vcombine_s16(vmovn_s32(x_0), vmovn_s32(x_1));
vst1q_s16(output + index, result);
}
-for (; i < n_input; ++i) {
+for (; TFLITE_UNLIKELY(i < n_input); ++i) {
const int index = batch * n_input + i;
const int16_t a = input_1[index];
const int16_t b = input_2[index];
@@ -1780,7 +1791,7 @@ void NeonCwiseMul(const int16_t* input_1, const int16_t* input_2,
vcombine_s16(vmovn_s32(temp_val.val[0]), vmovn_s32(temp_val.val[1]));
vst1_s8(output + index, vmovn_s16(result));
}
-for (; i < n_input; ++i) {
+for (; TFLITE_UNLIKELY(i < n_input); ++i) {
const int index = batch * n_input + i;
const int16_t a = input_1[index];
const int16_t b = input_2[index];
@@ -1814,7 +1825,7 @@ void NeonCwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch,
vst1_s16(output + index, vqmovn_s32(sum_0));
vst1_s16(output + index + 4, vqmovn_s32(sum_1));
}
-for (; i < n_input; ++i) {
+for (; TFLITE_UNLIKELY(i < n_input); ++i) {
const int index = batch * n_input + i;
int32_t sum = input_1[index] + input_2[index];
const int32 sum_clamped = std::min(int16_max, std::max(int16_min, sum));
@@ -1839,7 +1850,7 @@ void NeonCwiseClipping(float* vector, const int v_size,
// Save to output.
vst1q_f32(vector + i, v_f32x4);
}
-for (; i < v_size; i++) {
+for (; TFLITE_UNLIKELY(i < v_size); i++) {
vector[i] = std::max(std::min(clipping_value, vector[i]), -clipping_value);
}
}
@@ -1861,7 +1872,7 @@ void NeonCwiseClipping(int16_t* vector, const int v_size,
vst1q_s16(vector + i, val_0);
vst1q_s16(vector + i + kInt16ValuesPerNeonVector, val_1);
}
-for (; i < v_size; i++) {
+for (; TFLITE_UNLIKELY(i < v_size); i++) {
vector[i] = std::max(std::min(clipping_value, vector[i]),
static_cast<int16_t>(-clipping_value));
}
@@ -1884,7 +1895,7 @@ void NeonCwiseClipping(int8_t* vector, const int v_size,
vst1q_s8(vector + i, val_0);
vst1q_s8(vector + i + kInt8ValuesPerNeonVector, val_1);
}
-for (; i < v_size; i++) {
+for (; TFLITE_UNLIKELY(i < v_size); i++) {
vector[i] = std::max(std::min(clipping_value, vector[i]),
static_cast<int8_t>(-clipping_value));
}
@@ -2049,7 +2060,7 @@ void NeonSub1Vector(const float* vector, int v_size, float* result) {
// Save to output.
vst1q_f32(result + v, result_f32x4);
}
-for (; v < v_size; v++) {
+for (; TFLITE_UNLIKELY(v < v_size); v++) {
result[v] = 1.0f - vector[v];
}
}
@@ -2066,7 +2077,7 @@ void NeonSub1Vector(const int16_t* vector, int v_size, int16_t* result) {
const int16x8_t sub1_result = veorq_s16(one_dup, input);
vst1q_s16(result + i, sub1_result);
}
-for (; i < v_size; i++) {
+for (; TFLITE_UNLIKELY(i < v_size); i++) {
result[i] = kOne ^ vector[i];
}
}
@@ -2121,7 +2132,7 @@ bool NeonIsZeroVector(const float* vector, int v_size) {
if (!IsAllZero(v_f32x4)) return false;
}
// Postamble loop
-for (; v < v_size; ++v) {
+for (; TFLITE_UNLIKELY(v < v_size); ++v) {
if (vector[v] != 0.0) return false;
}
return true;
@@ -2140,7 +2151,7 @@ bool NeonIsZeroVector(const int8_t* vector, int v_size) {
if (!IsAllZero(v_s8x16)) return false;
}
// Postamble loop
-for (; v < v_size; ++v) {
+for (; TFLITE_UNLIKELY(v < v_size); ++v) {
if (vector[v] != 0) return false;
}
return true;
@@ -2190,7 +2201,8 @@ void NeonVectorScalarMultiply(const int8_t* vector, const int v_size,
vst1q_f32(result + v + 12, v3_f32x4);
}
-if (v_size - postamble_start >= (kInt8ValuesPerNeonVector >> 1)) {
+if (TFLITE_UNLIKELY(v_size - postamble_start >=
+                    (kInt8ValuesPerNeonVector >> 1))) {
// Load eight int8 values, if there is at least eight remaining.
const int8x8_t v_i8x8 = vld1_s8(vector + v);
// Convert them to int16 first.
@@ -2211,7 +2223,7 @@ void NeonVectorScalarMultiply(const int8_t* vector, const int v_size,
}
// Postamble loop.
-for (; v < v_size; v++) {
+for (; TFLITE_UNLIKELY(v < v_size); v++) {
result[v] = scale * vector[v];
}
}
@@ -2263,7 +2275,7 @@ inline void NeonMinMax(const float* values, const int size, float* min,
rmax = std::max(rmax, vget_lane_f32(max_f32x2, 0));
#endif // __aarch64__
}
-if (i < size) {
+if (TFLITE_UNLIKELY(i < size)) {
const auto minmax =
std::minmax_element(values + postamble_start, values + size);
rmin = std::min(rmin, *minmax.first);
@@ -2335,7 +2347,7 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size,
vst1_s8(&quantized_values[i], min_s8x8);
}
-for (; i < size; ++i) {
+for (; TFLITE_UNLIKELY(i < size); ++i) {
const int32 quantized_value =
static_cast<int32>(TfLiteRound(scaling_factor_inv * values[i]));
quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
@@ -2418,7 +2430,7 @@ void NeonAsymmetricQuantizeFloats(const float* values, const int size,
vst1_s8(&quantized_values[i], min_s8x8);
}
-for (; i < size; ++i) {
+for (; TFLITE_UNLIKELY(i < size); ++i) {
const int32 quantized_value = static_cast<int32>(
*offset + TfLiteRound(scaling_factor_inv * values[i]));
quantized_values[i] =
@@ -2444,7 +2456,7 @@ float NeonVectorVectorDotProduct(const float* vector1, const float* vector2,
}
float result = AccumulateNeonLane(acc_32x4);
// Postamble loop.
-for (; v < v_size; v++) {
+for (; TFLITE_UNLIKELY(v < v_size); v++) {
result += vector1[v] * vector2[v];
}
return result;
@@ -2466,7 +2478,7 @@ void NeonReductionSumVector(const float* input_vector, float* output_vector,
}
float sum = AccumulateNeonLane(sum_f32x4);
// Postamble loop.
-for (; r < reduction_size; r++) {
+for (; TFLITE_UNLIKELY(r < reduction_size); r++) {
sum += input_vector[r];
}
output_vector[o] = sum;
@@ -2487,13 +2499,13 @@ void NeonReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
const int8x16_t s2_8x16 = vld1q_s8(input_vector + r);
sum_32x4 = vpadalq_s16(sum_32x4, vpaddlq_s8(s2_8x16));
}
-if (r < postamble_start) {
+if (TFLITE_UNLIKELY(r < postamble_start)) {
const int8x8_t s2_8x8 = vld1_s8(input_vector + r);
sum_32x4 = vpadalq_s16(sum_32x4, vmovl_s8(s2_8x8));
r += (kInt8ValuesPerNeonVector >> 1);
}
int32_t sum = AccumulateNeonLane(sum_32x4);
-for (; r < reduction_size; ++r) {
+for (; TFLITE_UNLIKELY(r < reduction_size); ++r) {
sum += input_vector[r];
}
output_vector[o] = sum;
@@ -2551,7 +2563,7 @@ void NeonVectorBatchVectorCwiseProductAccumulate(
result += 16;
}
-for (; v < v_size; v++) {
+for (; TFLITE_UNLIKELY(v < v_size); v++) {
int32_t prod = vector[v] * *batch_vector++;
prod = MultiplyByQuantizedMultiplier(prod, multiplier, shift);
int32_t output = prod + *result;