Add an activation function to the per-channel weights-only quantized conv reference implementation, and update the per-channel weights-only quantized conv kernel to use the optimized asymmetric float quantization method.
PiperOrigin-RevId: 278751385
Change-Id: I3afc689633e6e2f5bf53648f457fe02356e96b2d
This commit is contained in:
parent
399e1c2718
commit
28d1ad34bb
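
In short, kernels that previously computed a batch's min/max and quantization parameters themselves now pass output pointers and let AsymmetricQuantizeFloats derive the scale and zero point internally. A minimal caller-side sketch (hypothetical function and buffer names; only the changed signature is taken from this commit, and the include path is assumed from the TF Lite source layout):

// Sketch only: quantize one batch of floats with the updated out-parameter API.
#include <cstdint>
#include "tensorflow/lite/kernels/internal/tensor_utils.h"

void QuantizeBatch(const float* batch_data, int input_size, int8_t* quantized,
                   float* scaling_factor, int32_t* zero_point) {
  // Scale and zero point are now outputs, derived inside the call from the
  // batch's min/max, instead of being precomputed by the caller.
  tflite::tensor_utils::AsymmetricQuantizeFloats(batch_data, input_size,
                                                 quantized, scaling_factor,
                                                 zero_point);
}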
@@ -668,20 +668,10 @@ void EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
   for (int b = 0; b < batch_size; ++b) {
     const int offset = b * input_size;
-    auto tensor_data = GetTensorData<float>(input) + offset;
-    auto minmax = std::minmax_element(tensor_data, tensor_data + input_size);
-    double min_value = *minmax.first;
-    double max_value = *minmax.second;
-    min_value = 0.0 < min_value ? 0.0 : min_value;
-    max_value = 0.0 > max_value ? 0.0 : max_value;
-    QuantizationParams quantization_params =
-        ChooseQuantizationParams<int8>(min_value, max_value);
-    input_offset_ptr[b] = quantization_params.zero_point;
-    scaling_factors_ptr[b] = quantization_params.scale;
     tensor_utils::AsymmetricQuantizeFloats(
         GetTensorData<float>(input) + offset, input_size,
-        quantized_input_ptr_batch + offset, scaling_factors_ptr[b],
-        input_offset_ptr[b]);
+        quantized_input_ptr_batch + offset, &scaling_factors_ptr[b],
+        &input_offset_ptr[b]);
   }

   int8_t* im2col_ptr = nullptr;

@@ -1966,6 +1966,41 @@ inline int32x4_t RoundToNearest(const float32x4_t input) {
 #endif
 }

+inline void NeonMinMax(const float* values, const int size, float* min,
+                       float* max) {
+  const int postamble_start = RoundDownVectors<kFloatValuesPerNeonVector>(size);
+  double rmin = 0.0, rmax = 0.0;
+  int i = 0;
+  if (postamble_start) {
+    float32x4_t min_f32x4 = vld1q_f32(values);
+    float32x4_t max_f32x4 = min_f32x4;
+    for (i = kFloatValuesPerNeonVector; i < postamble_start;
+         i += kFloatValuesPerNeonVector) {
+      const float32x4_t value0_f32x4 = vld1q_f32(&values[i]);
+      min_f32x4 = vminq_f32(min_f32x4, value0_f32x4);
+      max_f32x4 = vmaxq_f32(max_f32x4, value0_f32x4);
+    }
+    float32x2_t min_f32x2 =
+        vmin_f32(vget_low_f32(min_f32x4), vget_high_f32(min_f32x4));
+    float32x2_t max_f32x2 =
+        vmax_f32(vget_low_f32(max_f32x4), vget_high_f32(max_f32x4));
+    min_f32x2 = vpmin_f32(min_f32x2, min_f32x2);
+    const float fmin = vget_lane_f32(min_f32x2, 0);
+    rmin = rmin < fmin ? rmin : fmin;
+    max_f32x2 = vpmax_f32(max_f32x2, max_f32x2);
+    const float fmax = vget_lane_f32(max_f32x2, 0);
+    rmax = rmax > fmax ? rmax : fmax;
+    *min = rmin;
+    *max = rmax;
+  }
+  if (i < size) {
+    const auto minmax =
+        std::minmax_element(values + postamble_start, values + size);
+    *min = rmin < *minmax.first ? rmin : *minmax.first;
+    *max = rmax > *minmax.second ? rmax : *minmax.second;
+  }
+}
+
 void NeonSymmetricQuantizeFloats(const float* values, const int size,
                                  int8_t* quantized_values, float* min,
                                  float* max, float* scaling_factor) {

@@ -2036,16 +2071,48 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size,

 void NeonAsymmetricQuantizeFloats(const float* values, const int size,
                                   int8_t* quantized_values,
-                                  float scaling_factor, int32_t offset) {
+                                  float* scaling_factor, int32_t* offset) {
+  float rmin = 0.0, rmax = 0.0;
+  NeonMinMax(values, size, &rmin, &rmax);
+
   const int32_t kMinScale = -128;
   const int32_t kMaxScale = 127;
-  const float scaling_factor_inv =
-      scaling_factor == 0 ? 0 : 1.0 / scaling_factor;
+  const double qmin_double = kMinScale;
+  const double qmax_double = kMaxScale;
+  if (rmin == rmax) {
+    *scaling_factor = 0;
+    *offset = 0;
+  } else {
+    const double scale = (rmax - rmin) / (qmax_double - qmin_double);
+    const double zero_point_from_min = qmin_double - rmin / scale;
+    const double zero_point_from_max = qmax_double - rmax / scale;
+    const double zero_point_from_min_error =
+        std::abs(qmin_double) + std::abs(rmin / scale);
+    const double zero_point_from_max_error =
+        std::abs(qmax_double) + std::abs(rmax / scale);
+    const double zero_point_double =
+        zero_point_from_min_error < zero_point_from_max_error
+            ? zero_point_from_min
+            : zero_point_from_max;
+    int8 nudged_zero_point = 0;
+    if (zero_point_double < qmin_double) {
+      nudged_zero_point = kMinScale;
+    } else if (zero_point_double > qmax_double) {
+      nudged_zero_point = kMaxScale;
+    } else {
+      nudged_zero_point = static_cast<int8>(round(zero_point_double));
+    }
+    *scaling_factor = scale;
+    *offset = nudged_zero_point;
+  }
+
   const int postamble_start = size & ~(2 * kFloatValuesPerNeonVector - 1);
+  const float scaling_factor_inv =
+      *scaling_factor == 0 ? 0 : 1.0 / *scaling_factor;
   const float32x4_t q_factor_f32x4 = vmovq_n_f32(scaling_factor_inv);
   const int32x4_t scale_i32x4 = vmovq_n_s32(kMaxScale);
   const int32x4_t neg_scale_i32x4 = vmovq_n_s32(kMinScale);
-  const int32x4_t offset_i32x4 = vmovq_n_s32(offset);
+  const int32x4_t offset_i32x4 = vmovq_n_s32(*offset);

   int i = 0;
   for (; i < postamble_start; i += 2 * kFloatValuesPerNeonVector) {

@@ -2077,7 +2144,7 @@ void NeonAsymmetricQuantizeFloats(const float* values, const int size,

   for (; i < size; ++i) {
     const int32 quantized_value = static_cast<int32>(
-        offset + TfLiteRound(scaling_factor_inv * values[i]));
+        *offset + TfLiteRound(scaling_factor_inv * values[i]));
     quantized_values[i] =
         std::min(kMaxScale, std::max(kMinScale, quantized_value));
   }

@@ -245,8 +245,8 @@ void SymmetricQuantizeFloats(const float* values, const int size,
 }

 void AsymmetricQuantizeFloats(const float* values, const int size,
-                              int8_t* quantized_values, float scaling_factor,
-                              int32_t offset) {
+                              int8_t* quantized_values, float* scaling_factor,
+                              int32_t* offset) {
   NEON_OR_PORTABLE(AsymmetricQuantizeFloats, values, size, quantized_values,
                    scaling_factor, offset);
 }

@@ -172,7 +172,7 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size,
 // Asymmetric quantizer.
 void NeonAsymmetricQuantizeFloats(const float* values, const int size,
                                   int8_t* quantized_values,
-                                  float scaling_factor, int32_t offset);
+                                  float* scaling_factor, int32_t* offset);

 // Shift left a vector in place with v_size size.
 void NeonVectorShiftLeft(float* vector, int v_size, float shift_value);

@@ -251,8 +251,8 @@ void SymmetricQuantizeFloats(const float* values, const int size,
 }

 void AsymmetricQuantizeFloats(const float* values, const int size,
-                              int8_t* quantized_values, float scaling_factor,
-                              int32_t offset) {
+                              int8_t* quantized_values, float* scaling_factor,
+                              int32_t* offset) {
   NEON_OR_PORTABLE(AsymmetricQuantizeFloats, values, size, quantized_values,
                    scaling_factor, offset);
 }

@@ -197,7 +197,8 @@ inline void HybridConvPerChannel(
   const int dilation_height_factor = params.dilation_height_factor;
   const int pad_width = params.padding_values.width;
   const int pad_height = params.padding_values.height;
-
+  const float output_activation_min = params.float_activation_min;
+  const float output_activation_max = params.float_activation_max;
   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

@@ -246,7 +247,8 @@ inline void HybridConvPerChannel(
             acc_float += bias_data[out_channel];
           }
           output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
-              acc_float;
+              ActivationFunctionWithMinMax(acc_float, output_activation_min,
+                                           output_activation_max);
         }
       }
     }

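The new ActivationFunctionWithMinMax call above clamps each accumulated output to the fused activation range carried in the conv params. A standalone sketch of that clamping behavior (assuming the usual TF Lite semantics; not copied from the library):

#include <algorithm>

// Clamp a hybrid-conv accumulator to the fused activation bounds,
// e.g. [0, 6] for RELU6.
inline float ClampToActivationRange(float acc, float activation_min,
                                    float activation_max) {
  return std::min(activation_max, std::max(activation_min, acc));
}
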
@@ -96,14 +96,46 @@ void PortableSymmetricQuantizeFloats(const float* values, const int size,

 void PortableAsymmetricQuantizeFloats(const float* values, const int size,
                                       int8_t* quantized_values,
-                                      float scaling_factor, int32_t offset) {
+                                      float* scaling_factor, int32_t* offset) {
   const int32_t kMinScale = -128;
   const int32_t kMaxScale = 127;
+  const double qmin_double = kMinScale;
+  const double qmax_double = kMaxScale;
+  float rmin = 0.0, rmax = 0.0;
+  const auto minmax = std::minmax_element(values, values + size);
+  rmin = rmin < *minmax.first ? rmin : *minmax.first;
+  rmax = rmax > *minmax.second ? rmax : *minmax.second;
+  if (rmin == rmax) {
+    *scaling_factor = 0;
+    *offset = 0;
+  } else {
+    const double scale = (rmax - rmin) / (qmax_double - qmin_double);
+    const double zero_point_from_min = qmin_double - rmin / scale;
+    const double zero_point_from_max = qmax_double - rmax / scale;
+    const double zero_point_from_min_error =
+        std::abs(qmin_double) + std::abs(rmin / scale);
+    const double zero_point_from_max_error =
+        std::abs(qmax_double) + std::abs(rmax / scale);
+    const double zero_point_double =
+        zero_point_from_min_error < zero_point_from_max_error
+            ? zero_point_from_min
+            : zero_point_from_max;
+    int8 nudged_zero_point = 0;
+    if (zero_point_double < qmin_double) {
+      nudged_zero_point = kMinScale;
+    } else if (zero_point_double > qmax_double) {
+      nudged_zero_point = kMaxScale;
+    } else {
+      nudged_zero_point = static_cast<int8>(round(zero_point_double));
+    }
+    *scaling_factor = scale;
+    *offset = nudged_zero_point;
+  }
   const float scaling_factor_inv =
-      scaling_factor == 0 ? 0 : 1.0 / scaling_factor;
+      *scaling_factor == 0 ? 0 : 1.0 / *scaling_factor;
   for (int i = 0; i < size; ++i) {
     const int32_t quantized_value = static_cast<int32_t>(
-        TfLiteRound(offset + values[i] * scaling_factor_inv));
+        TfLiteRound(*offset + values[i] * scaling_factor_inv));
     quantized_values[i] =
         std::min(kMaxScale, std::max(kMinScale, quantized_value));
   }

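As a worked illustration of the zero-point nudging above (numbers chosen for this note, not taken from the tests): for a batch spanning [0.0, 2.55], scale = (2.55 - 0.0) / (127 - (-128)) = 0.01; both zero-point candidates evaluate to -128 (qmin - rmin/scale = -128 - 0, and qmax - rmax/scale = 127 - 255), so offset = -128. An input of 1.0 then quantizes to TfLiteRound(-128 + 1.0 / 0.01) = -28.
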
@@ -54,8 +54,8 @@ void SymmetricQuantizeFloats(const float* values, const int size,
 }

 void AsymmetricQuantizeFloats(const float* values, const int size,
-                              int8_t* quantized_values, float scaling_factor,
-                              int32_t offset) {
+                              int8_t* quantized_values, float* scaling_factor,
+                              int32_t* offset) {
   return PortableAsymmetricQuantizeFloats(values, size, quantized_values,
                                           scaling_factor, offset);
 }

@@ -45,7 +45,7 @@ void PortableSymmetricQuantizeFloats(const float* values, const int size,

 void PortableAsymmetricQuantizeFloats(const float* values, const int size,
                                       int8_t* quantized_values,
-                                      float scaling_factor, int32_t offset);
+                                      float* scaling_factor, int32_t* offset);

 // Multiply a matrix by a batch vector, and store results in a batch-size
 // vector.

@@ -52,8 +52,8 @@ void SymmetricQuantizeFloats(const float* values, const int size,
                              float max_value, float* scaling_factor);

 void AsymmetricQuantizeFloats(const float* values, const int size,
-                              int8_t* quantized_values, float scaling_factor,
-                              int32_t offset);
+                              int8_t* quantized_values, float* scaling_factor,
+                              int32_t* offset);

 // Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch
 // dimension composed by input vectors independent from each other). The result

@@ -139,9 +139,15 @@ TEST(uKernels, AsymmetricQuantizeFloatsTest) {
   double max = 1000.0;
   QuantizationParams quantization_params =
       ChooseQuantizationParams<int8_t>(min, max);
+  float scale = quantization_params.scale;
   int32_t offset = quantization_params.zero_point;
-  AsymmetricQuantizeFloats(input, kVectorSize, output,
-                           quantization_params.scale, offset);
+  float test_scale;
+  int32_t test_offset;
+  AsymmetricQuantizeFloats(input, kVectorSize, output, &test_scale,
+                           &test_offset);
+  // EQ won't work due to fpoint.
+  EXPECT_NEAR(test_scale, scale, 1e-6);
+  EXPECT_EQ(test_offset, offset);
   EXPECT_THAT(output, testing::ElementsAreArray(
                           {-128, -127, -126, -26, -28, -29, -30, -28, 127}));
 }

@@ -150,7 +156,12 @@ TEST(uKernels, AsymmetricQuantizeFloatsAllZerosTest) {
   constexpr int kVectorSize = 9;
   static float input[kVectorSize] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
   int8_t output[kVectorSize];
-  AsymmetricQuantizeFloats(input, kVectorSize, output, 0, 0);
+  float test_scale;
+  int32_t test_offset;
+  AsymmetricQuantizeFloats(input, kVectorSize, output, &test_scale,
+                           &test_offset);
+  EXPECT_EQ(test_scale, 0);
+  EXPECT_EQ(test_offset, 0);
   EXPECT_THAT(output, testing::ElementsAreArray({0, 0, 0, 0, 0, 0, 0, 0, 0}));
 }

@@ -164,8 +175,13 @@ TEST(uKernels, AsymmetricQuantizeFloatsZeroRangeTest) {
   QuantizationParams quantization_params =
       ChooseQuantizationParams<int8_t>(min, max);
   int32_t offset = quantization_params.zero_point;
-  AsymmetricQuantizeFloats(input, kVectorSize, output,
-                           quantization_params.scale, offset);
+  float scale = quantization_params.scale;
+  float test_scale;
+  int32_t test_offset;
+  AsymmetricQuantizeFloats(input, kVectorSize, output, &test_scale,
+                           &test_offset);
+  EXPECT_NEAR(test_scale, scale, 1e-6);
+  EXPECT_EQ(test_offset, offset);
   EXPECT_THAT(output, testing::ElementsAreArray(
                           {127, 127, 127, 127, 127, 127, 127, 127, 127}));
 }

@@ -180,8 +196,13 @@ TEST(uKernels, AsymmetricQuantizeFloatsAllAlmostZeroTest) {
   QuantizationParams quantization_params =
       ChooseQuantizationParams<int8_t>(min, max);
   int32_t offset = quantization_params.zero_point;
-  AsymmetricQuantizeFloats(input, kVectorSize, output,
-                           quantization_params.scale, offset);
+  float scale = quantization_params.scale;
+  float test_scale;
+  int32_t test_offset;
+  AsymmetricQuantizeFloats(input, kVectorSize, output, &test_scale,
+                           &test_offset);
+  EXPECT_NEAR(test_scale, scale, 1e-6);
+  EXPECT_EQ(test_offset, offset);
   EXPECT_THAT(output, testing::ElementsAreArray(
                           {-58, -23, -55, -128, -48, -14, -41, 127, -49}));
 }