From 28d1ad34bb59e3e1631b5807eebc46563ef3382c Mon Sep 17 00:00:00 2001
From: David Rim
Date: Tue, 5 Nov 2019 18:24:15 -0800
Subject: [PATCH] Add an activation function to the per-channel weights-only
 quantized conv reference implementation, and update the per-channel
 weights-only quantized conv to use the optimized asymmetric quantize
 floats method.

PiperOrigin-RevId: 278751385
Change-Id: I3afc689633e6e2f5bf53648f457fe02356e96b2d
---
 tensorflow/lite/kernels/conv.cc               | 14 +---
 .../internal/optimized/neon_tensor_utils.cc   | 77 +++++++++++++++++--
 .../internal/optimized/neon_tensor_utils.h    |  4 +-
 .../optimized/neon_tensor_utils_impl.h        |  2 +-
 .../internal/optimized/sse_tensor_utils.h     |  4 +-
 .../lite/kernels/internal/reference/conv.h    |  6 +-
 .../reference/portable_tensor_utils.cc        | 38 ++++++++-
 .../reference/portable_tensor_utils.h         |  4 +-
 .../reference/portable_tensor_utils_impl.h    |  2 +-
 .../lite/kernels/internal/tensor_utils.h      |  4 +-
 .../kernels/internal/tensor_utils_test.cc     | 35 +++++++-
 11 files changed, 151 insertions(+), 39 deletions(-)
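
Note on the quantization change: both the portable and the NEON asymmetric
quantizers now derive the scale and zero point from the observed min/max of
the input (with the range extended to include zero) using the same "nudged
zero point" scheme as ChooseQuantizationParams, so callers no longer
precompute them. Below is a minimal standalone sketch of that math; the
function name is illustrative only, not code from this patch.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Chooses int8 quantization parameters for the range [rmin, rmax].
    void ChooseAsymmetricParams(float rmin, float rmax, float* scale,
                                int32_t* zero_point) {
      // The representable range must include zero.
      rmin = std::min(rmin, 0.0f);
      rmax = std::max(rmax, 0.0f);
      if (rmin == rmax) {  // All-zero input: no meaningful scale.
        *scale = 0.0f;
        *zero_point = 0;
        return;
      }
      const double qmin = -128.0;
      const double qmax = 127.0;
      const double s = (rmax - rmin) / (qmax - qmin);
      // Two candidates for the zero point; keep the one with the smaller
      // rounding error, then clamp ("nudge") it into [qmin, qmax].
      const double zp_from_min = qmin - rmin / s;
      const double zp_from_max = qmax - rmax / s;
      const double zp = std::abs(qmin) + std::abs(rmin / s) <
                                std::abs(qmax) + std::abs(rmax / s)
                            ? zp_from_min
                            : zp_from_max;
      *scale = static_cast<float>(s);
      *zero_point = static_cast<int32_t>(
          std::max(qmin, std::min(qmax, std::round(zp))));
    }
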
diff --git a/tensorflow/lite/kernels/conv.cc b/tensorflow/lite/kernels/conv.cc
index 7fa89aaeaaf..49090075626 100644
--- a/tensorflow/lite/kernels/conv.cc
+++ b/tensorflow/lite/kernels/conv.cc
@@ -668,20 +668,10 @@ void EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
 
   for (int b = 0; b < batch_size; ++b) {
     const int offset = b * input_size;
-    auto tensor_data = GetTensorData<float>(input) + offset;
-    auto minmax = std::minmax_element(tensor_data, tensor_data + input_size);
-    double min_value = *minmax.first;
-    double max_value = *minmax.second;
-    min_value = 0.0 < min_value ? 0.0 : min_value;
-    max_value = 0.0 > max_value ? 0.0 : max_value;
-    QuantizationParams quantization_params =
-        ChooseQuantizationParams<int8_t>(min_value, max_value);
-    input_offset_ptr[b] = quantization_params.zero_point;
-    scaling_factors_ptr[b] = quantization_params.scale;
     tensor_utils::AsymmetricQuantizeFloats(
         GetTensorData<float>(input) + offset, input_size,
-        quantized_input_ptr_batch + offset, scaling_factors_ptr[b],
-        input_offset_ptr[b]);
+        quantized_input_ptr_batch + offset, &scaling_factors_ptr[b],
+        &input_offset_ptr[b]);
   }
 
   int8_t* im2col_ptr = nullptr;
diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
index a2c85aa395d..46349219ab4 100644
--- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -1966,6 +1966,41 @@ inline int32x4_t RoundToNearest(const float32x4_t input) {
 #endif
 }
 
+inline void NeonMinMax(const float* values, const int size, float* min,
+                       float* max) {
+  const int postamble_start = RoundDownVectors<kFloatValuesPerNeonVector>(size);
+  double rmin = 0.0, rmax = 0.0;
+  int i = 0;
+  if (postamble_start) {
+    float32x4_t min_f32x4 = vld1q_f32(values);
+    float32x4_t max_f32x4 = min_f32x4;
+    for (i = kFloatValuesPerNeonVector; i < postamble_start;
+         i += kFloatValuesPerNeonVector) {
+      const float32x4_t value0_f32x4 = vld1q_f32(&values[i]);
+      min_f32x4 = vminq_f32(min_f32x4, value0_f32x4);
+      max_f32x4 = vmaxq_f32(max_f32x4, value0_f32x4);
+    }
+    float32x2_t min_f32x2 =
+        vmin_f32(vget_low_f32(min_f32x4), vget_high_f32(min_f32x4));
+    float32x2_t max_f32x2 =
+        vmax_f32(vget_low_f32(max_f32x4), vget_high_f32(max_f32x4));
+    min_f32x2 = vpmin_f32(min_f32x2, min_f32x2);
+    const float fmin = vget_lane_f32(min_f32x2, 0);
+    rmin = rmin < fmin ? rmin : fmin;
+    max_f32x2 = vpmax_f32(max_f32x2, max_f32x2);
+    const float fmax = vget_lane_f32(max_f32x2, 0);
+    rmax = rmax > fmax ? rmax : fmax;
+    *min = rmin;
+    *max = rmax;
+  }
+  if (i < size) {
+    const auto minmax =
+        std::minmax_element(values + postamble_start, values + size);
+    *min = rmin < *minmax.first ? rmin : *minmax.first;
+    *max = rmax > *minmax.second ? rmax : *minmax.second;
+  }
+}
+
 void NeonSymmetricQuantizeFloats(const float* values, const int size,
                                  int8_t* quantized_values, float* min,
                                  float* max, float* scaling_factor) {
@@ -2036,16 +2071,48 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size,
 
 void NeonAsymmetricQuantizeFloats(const float* values, const int size,
                                   int8_t* quantized_values,
-                                  float scaling_factor, int32_t offset) {
+                                  float* scaling_factor, int32_t* offset) {
+  float rmin = 0.0, rmax = 0.0;
+  NeonMinMax(values, size, &rmin, &rmax);
+
   const int32_t kMinScale = -128;
   const int32_t kMaxScale = 127;
-  const float scaling_factor_inv =
-      scaling_factor == 0 ? 0 : 1.0 / scaling_factor;
+  const double qmin_double = kMinScale;
+  const double qmax_double = kMaxScale;
+  if (rmin == rmax) {
+    *scaling_factor = 0;
+    *offset = 0;
+  } else {
+    const double scale = (rmax - rmin) / (qmax_double - qmin_double);
+    const double zero_point_from_min = qmin_double - rmin / scale;
+    const double zero_point_from_max = qmax_double - rmax / scale;
+    const double zero_point_from_min_error =
+        std::abs(qmin_double) + std::abs(rmin / scale);
+    const double zero_point_from_max_error =
+        std::abs(qmax_double) + std::abs(rmax / scale);
+    const double zero_point_double =
+        zero_point_from_min_error < zero_point_from_max_error
+            ? zero_point_from_min
+            : zero_point_from_max;
+    int8 nudged_zero_point = 0;
+    if (zero_point_double < qmin_double) {
+      nudged_zero_point = kMinScale;
+    } else if (zero_point_double > qmax_double) {
+      nudged_zero_point = kMaxScale;
+    } else {
+      nudged_zero_point = static_cast<int8>(round(zero_point_double));
+    }
+    *scaling_factor = scale;
+    *offset = nudged_zero_point;
+  }
   const int postamble_start = size & ~(2 * kFloatValuesPerNeonVector - 1);
+  const float scaling_factor_inv =
+      *scaling_factor == 0 ? 0 : 1.0 / *scaling_factor;
   const float32x4_t q_factor_f32x4 = vmovq_n_f32(scaling_factor_inv);
   const int32x4_t scale_i32x4 = vmovq_n_s32(kMaxScale);
   const int32x4_t neg_scale_i32x4 = vmovq_n_s32(kMinScale);
-  const int32x4_t offset_i32x4 = vmovq_n_s32(offset);
+  const int32x4_t offset_i32x4 = vmovq_n_s32(*offset);
 
   int i = 0;
   for (; i < postamble_start; i += 2 * kFloatValuesPerNeonVector) {
@@ -2077,7 +2144,7 @@ void NeonAsymmetricQuantizeFloats(const float* values, const int size,
 
   for (; i < size; ++i) {
     const int32 quantized_value = static_cast<int32>(
-        offset + TfLiteRound(scaling_factor_inv * values[i]));
+        *offset + TfLiteRound(scaling_factor_inv * values[i]));
     quantized_values[i] =
         std::min(kMaxScale, std::max(kMinScale, quantized_value));
   }
diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
index d950e1ceeaa..f250f1aa6d0 100644
--- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -245,8 +245,8 @@ void SymmetricQuantizeFloats(const float* values, const int size,
 }
 
 void AsymmetricQuantizeFloats(const float* values, const int size,
-                              int8_t* quantized_values, float scaling_factor,
-                              int32_t offset) {
+                              int8_t* quantized_values, float* scaling_factor,
+                              int32_t* offset) {
   NEON_OR_PORTABLE(AsymmetricQuantizeFloats, values, size, quantized_values,
                    scaling_factor, offset);
 }
diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h
index 41cfe6b8dcf..e634673c838 100644
--- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h
+++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h
@@ -172,7 +172,7 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size,
 // Asymmetric quantizer.
 void NeonAsymmetricQuantizeFloats(const float* values, const int size,
                                   int8_t* quantized_values,
-                                  float scaling_factor, int32_t offset);
+                                  float* scaling_factor, int32_t* offset);
 
 // Shift left a vector in place with v_size size.
 void NeonVectorShiftLeft(float* vector, int v_size, float shift_value);
diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h
index 10ed8c5ea85..bc2676bc7d4 100644
--- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h
@@ -251,8 +251,8 @@ void SymmetricQuantizeFloats(const float* values, const int size,
 }
 
 void AsymmetricQuantizeFloats(const float* values, const int size,
-                              int8_t* quantized_values, float scaling_factor,
-                              int32_t offset) {
+                              int8_t* quantized_values, float* scaling_factor,
+                              int32_t* offset) {
   NEON_OR_PORTABLE(AsymmetricQuantizeFloats, values, size, quantized_values,
                    scaling_factor, offset);
 }
diff --git a/tensorflow/lite/kernels/internal/reference/conv.h b/tensorflow/lite/kernels/internal/reference/conv.h
index 3ce5a5f9b98..55dd869a4b1 100644
--- a/tensorflow/lite/kernels/internal/reference/conv.h
+++ b/tensorflow/lite/kernels/internal/reference/conv.h
@@ -197,7 +197,8 @@ inline void HybridConvPerChannel(
   const int dilation_height_factor = params.dilation_height_factor;
   const int pad_width = params.padding_values.width;
   const int pad_height = params.padding_values.height;
-
+  const float output_activation_min = params.float_activation_min;
+  const float output_activation_max = params.float_activation_max;
   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
@@ -246,7 +247,8 @@ inline void HybridConvPerChannel(
             acc_float += bias_data[out_channel];
           }
           output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
-              acc_float;
+              ActivationFunctionWithMinMax(acc_float, output_activation_min,
+                                           output_activation_max);
         }
       }
     }
diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
index a44d00908da..e1cbac94551 100644
--- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -96,14 +96,46 @@ void PortableSymmetricQuantizeFloats(const float* values, const int size,
 
 void PortableAsymmetricQuantizeFloats(const float* values, const int size,
                                       int8_t* quantized_values,
-                                      float scaling_factor, int32_t offset) {
+                                      float* scaling_factor, int32_t* offset) {
   const int32_t kMinScale = -128;
   const int32_t kMaxScale = 127;
+  const double qmin_double = kMinScale;
+  const double qmax_double = kMaxScale;
+  float rmin = 0.0, rmax = 0.0;
+  const auto minmax = std::minmax_element(values, values + size);
+  rmin = rmin < *minmax.first ? rmin : *minmax.first;
+  rmax = rmax > *minmax.second ? rmax : *minmax.second;
+  if (rmin == rmax) {
+    *scaling_factor = 0;
+    *offset = 0;
+  } else {
+    const double scale = (rmax - rmin) / (qmax_double - qmin_double);
+    const double zero_point_from_min = qmin_double - rmin / scale;
+    const double zero_point_from_max = qmax_double - rmax / scale;
+    const double zero_point_from_min_error =
+        std::abs(qmin_double) + std::abs(rmin / scale);
+    const double zero_point_from_max_error =
+        std::abs(qmax_double) + std::abs(rmax / scale);
+    const double zero_point_double =
+        zero_point_from_min_error < zero_point_from_max_error
+            ? zero_point_from_min
+            : zero_point_from_max;
+    int8 nudged_zero_point = 0;
+    if (zero_point_double < qmin_double) {
+      nudged_zero_point = kMinScale;
+    } else if (zero_point_double > qmax_double) {
+      nudged_zero_point = kMaxScale;
+    } else {
+      nudged_zero_point = static_cast<int8>(round(zero_point_double));
+    }
+    *scaling_factor = scale;
+    *offset = nudged_zero_point;
+  }
   const float scaling_factor_inv =
-      scaling_factor == 0 ? 0 : 1.0 / scaling_factor;
+      *scaling_factor == 0 ? 0 : 1.0 / *scaling_factor;
   for (int i = 0; i < size; ++i) {
     const int32_t quantized_value = static_cast<int32_t>(
-        TfLiteRound(offset + values[i] * scaling_factor_inv));
+        TfLiteRound(*offset + values[i] * scaling_factor_inv));
     quantized_values[i] =
         std::min(kMaxScale, std::max(kMinScale, quantized_value));
   }
diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
index 76b4f8bd8bd..0d29b0ec1b1 100644
--- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -54,8 +54,8 @@ void SymmetricQuantizeFloats(const float* values, const int size,
 }
 
 void AsymmetricQuantizeFloats(const float* values, const int size,
-                              int8_t* quantized_values, float scaling_factor,
-                              int32_t offset) {
+                              int8_t* quantized_values, float* scaling_factor,
+                              int32_t* offset) {
   return PortableAsymmetricQuantizeFloats(values, size, quantized_values,
                                           scaling_factor, offset);
 }
diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h
index 86d79eac335..a0f7580fd5b 100644
--- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h
+++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h
@@ -45,7 +45,7 @@ void PortableSymmetricQuantizeFloats(const float* values, const int size,
 
 void PortableAsymmetricQuantizeFloats(const float* values, const int size,
                                       int8_t* quantized_values,
-                                      float scaling_factor, int32_t offset);
+                                      float* scaling_factor, int32_t* offset);
 
 // Multiply a matrix by a batch vector, and store results in a batch-size
 // vector.
diff --git a/tensorflow/lite/kernels/internal/tensor_utils.h b/tensorflow/lite/kernels/internal/tensor_utils.h
index 6a04a56b12b..487035c2f30 100644
--- a/tensorflow/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/tensor_utils.h
@@ -52,8 +52,8 @@ void SymmetricQuantizeFloats(const float* values, const int size,
                              float max_value, float* scaling_factor);
 
 void AsymmetricQuantizeFloats(const float* values, const int size,
-                              int8_t* quantized_values, float scaling_factor,
-                              int32_t offset);
+                              int8_t* quantized_values, float* scaling_factor,
+                              int32_t* offset);
 
 // Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch
 // dimension composed by input vectors independent from each other). The result
diff --git a/tensorflow/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/lite/kernels/internal/tensor_utils_test.cc
index 10cd61be4f5..c5519fa2016 100644
--- a/tensorflow/lite/kernels/internal/tensor_utils_test.cc
+++ b/tensorflow/lite/kernels/internal/tensor_utils_test.cc
@@ -139,9 +139,15 @@ TEST(uKernels, AsymmetricQuantizeFloatsTest) {
   double max = 1000.0;
   QuantizationParams quantization_params =
       ChooseQuantizationParams<int8_t>(min, max);
+  float scale = quantization_params.scale;
   int32_t offset = quantization_params.zero_point;
-  AsymmetricQuantizeFloats(input, kVectorSize, output,
-                           quantization_params.scale, offset);
+  float test_scale;
+  int32_t test_offset;
+  AsymmetricQuantizeFloats(input, kVectorSize, output, &test_scale,
+                           &test_offset);
+  // EXPECT_EQ won't hold for the scale due to floating-point rounding.
+  EXPECT_NEAR(test_scale, scale, 1e-6);
+  EXPECT_EQ(test_offset, offset);
   EXPECT_THAT(output, testing::ElementsAreArray(
                           {-128, -127, -126, -26, -28, -29, -30, -28, 127}));
 }
@@ -150,7 +156,12 @@ TEST(uKernels, AsymmetricQuantizeFloatsAllZerosTest) {
   constexpr int kVectorSize = 9;
   static float input[kVectorSize] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
   int8_t output[kVectorSize];
-  AsymmetricQuantizeFloats(input, kVectorSize, output, 0, 0);
+  float test_scale;
+  int32_t test_offset;
+  AsymmetricQuantizeFloats(input, kVectorSize, output, &test_scale,
+                           &test_offset);
+  EXPECT_EQ(test_scale, 0);
+  EXPECT_EQ(test_offset, 0);
   EXPECT_THAT(output, testing::ElementsAreArray({0, 0, 0, 0, 0, 0, 0, 0, 0}));
 }
 
@@ -164,8 +175,13 @@ TEST(uKernels, AsymmetricQuantizeFloatsZeroRangeTest) {
   QuantizationParams quantization_params =
       ChooseQuantizationParams<int8_t>(min, max);
   int32_t offset = quantization_params.zero_point;
-  AsymmetricQuantizeFloats(input, kVectorSize, output,
-                           quantization_params.scale, offset);
+  float scale = quantization_params.scale;
+  float test_scale;
+  int32_t test_offset;
+  AsymmetricQuantizeFloats(input, kVectorSize, output, &test_scale,
+                           &test_offset);
+  EXPECT_NEAR(test_scale, scale, 1e-6);
+  EXPECT_EQ(test_offset, offset);
   EXPECT_THAT(output, testing::ElementsAreArray(
                           {127, 127, 127, 127, 127, 127, 127, 127, 127}));
 }
 
@@ -180,8 +196,13 @@ TEST(uKernels, AsymmetricQuantizeFloatsAllAlmostZeroTest) {
   QuantizationParams quantization_params =
       ChooseQuantizationParams<int8_t>(min, max);
   int32_t offset = quantization_params.zero_point;
-  AsymmetricQuantizeFloats(input, kVectorSize, output,
-                           quantization_params.scale, offset);
+  float scale = quantization_params.scale;
+  float test_scale;
+  int32_t test_offset;
+  AsymmetricQuantizeFloats(input, kVectorSize, output, &test_scale,
+                           &test_offset);
+  EXPECT_NEAR(test_scale, scale, 1e-6);
+  EXPECT_EQ(test_offset, offset);
   EXPECT_THAT(output, testing::ElementsAreArray(
                           {-58, -23, -55, -128, -48, -14, -41, 127, -49}));
 }
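
Usage note: after this change, callers receive the scale and zero point as
out-parameters instead of supplying them. A minimal caller-side sketch
follows; it is illustrative, not part of the patch, and assumes a TFLite
build environment.

    #include <cstdint>
    #include <vector>

    #include "tensorflow/lite/kernels/internal/tensor_utils.h"

    int main() {
      const std::vector<float> input = {-1.5f, 0.0f, 0.25f, 3.0f};
      std::vector<int8_t> quantized(input.size());
      float scale = 0.0f;
      int32_t offset = 0;
      // Scale and zero point are computed inside the call from the input's
      // min/max and returned through the new pointer arguments.
      tflite::tensor_utils::AsymmetricQuantizeFloats(
          input.data(), static_cast<int>(input.size()), quantized.data(),
          &scale, &offset);
      // Approximate recovery of the originals: x ~= scale * (q - offset).
      std::vector<float> dequantized(input.size());
      for (size_t i = 0; i < input.size(); ++i) {
        dequantized[i] = scale * (quantized[i] - offset);
      }
      return 0;
    }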