Rewrite quantized hardswish for good accuracy in all cases.
This is motivated by MobileNet v3 experiments which exercised corner cases not anticipated by the current implementation (large quantization scales) and which showed that some fine arithmetic details have a significant impact on classification accuracy. Most importantly, bias must be minimized as any increase in bias translate into a degradation of classification accuracy, and that had 2 specific consequences: 1. As HardSwish inherently requires forming the expression (x + 3)/6, it must as a prerequisite step rescale x on a scale where 3 is exactly representable. 2. There are 3 fixed-point multiplications. If we used for all of them the usual rounding fixed-point multiplication primitive/instruction (e.g. NEON SQRDMULH) that that results in significant bias away from zero. This was fixed by suitably combining this usual rounding multiplication with a truncating multiplication (e.g. NEON SQDMULH), one feeding into the other, so the biases (away from zero, and toward zero) cancel each other. A specific test case was added to guard regressions on this front in the unit test, based on the fact that bias (visible at unit-test level) is empirically seen to be a sufficient proxy for classification accuracy (not visible at unit-test level and too expensive to measure even for integration tests). PiperOrigin-RevId: 257904120
This commit is contained in:
parent
42b8511f63
commit
7d4d60c36e
@ -426,6 +426,7 @@ cc_library(
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/kernels/internal:audio_utils",
|
||||
"//tensorflow/lite/kernels/internal:common",
|
||||
"//tensorflow/lite/kernels/internal:compatibility",
|
||||
"//tensorflow/lite/kernels/internal:cpu_check",
|
||||
"//tensorflow/lite/kernels/internal:kernel_utils",
|
||||
"//tensorflow/lite/kernels/internal:optimized",
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/softmax.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
@ -178,6 +179,22 @@ void HardSwishFree(TfLiteContext* context, void* buffer) {
|
||||
delete static_cast<HardSwishData*>(buffer);
|
||||
}
|
||||
|
||||
void DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32,
|
||||
int16_t* multiplier_int16) {
|
||||
TFLITE_DCHECK_GE(multiplier_int32, 0);
|
||||
static constexpr int32_t kRoundingOffset = 1 << 15;
|
||||
if (multiplier_int32 >=
|
||||
std::numeric_limits<int32_t>::max() - kRoundingOffset) {
|
||||
*multiplier_int16 = std::numeric_limits<int16_t>::max();
|
||||
return;
|
||||
}
|
||||
const int32_t result = (multiplier_int32 + kRoundingOffset) >> 16;
|
||||
TFLITE_DCHECK_LE(result << 16, multiplier_int32 + kRoundingOffset);
|
||||
TFLITE_DCHECK_GT(result << 16, multiplier_int32 - kRoundingOffset);
|
||||
*multiplier_int16 = result;
|
||||
TFLITE_DCHECK_EQ(*multiplier_int16, result);
|
||||
}
|
||||
|
||||
TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_STATUS(GenericPrepare(context, node));
|
||||
TfLiteTensor* output = GetOutput(context, node, 0);
|
||||
@ -186,40 +203,30 @@ TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
HardSwishData* data = static_cast<HardSwishData*>(node->user_data);
|
||||
HardSwishParams* params = &data->params;
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
// TODO(131260336): Maybe pick a better way to select the denominator shift.
|
||||
// Include input shift into the shift.
|
||||
static constexpr int32_t extra_input_shift = 3;
|
||||
// Note: optimized implementations will rely on the ability to perform this
|
||||
// left shift within int16 without overflow. The values being left-shifted
|
||||
// range in [-255, 255] i.e. just under 2^8 in absolute value, and after the
|
||||
// left shift they will still be added the 'three_input' value, which is
|
||||
// safe if they're not greater than 2^14 in absolute value (since 2^15 is
|
||||
// the magnitude of the boundaries of int16 range). 14-8 == 6, so we
|
||||
// require extra_input_shift to be no greater than 6.
|
||||
static_assert(extra_input_shift <= 6, "");
|
||||
const auto in_scale = input->params.scale;
|
||||
params->input_zero_point = input->params.zero_point;
|
||||
const auto out_scale = output->params.scale;
|
||||
const int32_t out_zero_point = output->params.zero_point;
|
||||
// Get 3 and 6 represented in input scale. We avoid intermediate conversion
|
||||
// to the "true" scale, so all operations are done in input scale losslessly
|
||||
// And then converted to the output scale.
|
||||
// However 3 and 6 might not have exact representation in input scale.
|
||||
// We use extra multiplier to avoid precision loss when converting
|
||||
// 3 and 6 from input to output.
|
||||
params->three_input = std::lround((3 << extra_input_shift) / in_scale);
|
||||
params->six_input = std::lround((6 << extra_input_shift) / in_scale);
|
||||
// Compensate for the fact that we multiply two numbers in in_scale
|
||||
// and produce result in output format.
|
||||
// NB: we fold 6 multiplier into the scaling factor here:
|
||||
float from_in_to_out_sq = (in_scale * in_scale / out_scale / 6);
|
||||
params->output_zero_point = output->params.zero_point;
|
||||
const float input_scale = input->params.scale;
|
||||
const float hires_input_scale = (1.0f / 128.0f) * input_scale;
|
||||
const float reluish_scale = 3.0f / 32768.0f;
|
||||
const float output_scale = output->params.scale;
|
||||
|
||||
from_in_to_out_sq /= (1 << extra_input_shift);
|
||||
QuantizeMultiplierSmallerThanOneExp(from_in_to_out_sq, &(params->scale),
|
||||
&(params->shift));
|
||||
const float output_multiplier = hires_input_scale / output_scale;
|
||||
|
||||
params->output_offset = out_zero_point;
|
||||
params->clip_input_shift = extra_input_shift;
|
||||
int32_t output_multiplier_fixedpoint_int32;
|
||||
QuantizeMultiplier(output_multiplier, &output_multiplier_fixedpoint_int32,
|
||||
¶ms->output_multiplier_exponent);
|
||||
DownScaleInt32ToInt16Multiplier(
|
||||
output_multiplier_fixedpoint_int32,
|
||||
¶ms->output_multiplier_fixedpoint_int16);
|
||||
TF_LITE_ENSURE(context, params->output_multiplier_exponent <= 0);
|
||||
|
||||
const float reluish_multiplier = hires_input_scale / reluish_scale;
|
||||
int32_t reluish_multiplier_fixedpoint_int32;
|
||||
QuantizeMultiplier(reluish_multiplier, &reluish_multiplier_fixedpoint_int32,
|
||||
¶ms->reluish_multiplier_exponent);
|
||||
DownScaleInt32ToInt16Multiplier(
|
||||
reluish_multiplier_fixedpoint_int32,
|
||||
¶ms->reluish_multiplier_fixedpoint_int16);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <cstdarg>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
@ -124,7 +125,6 @@ class QuantizedActivationsOpModel : public BaseActivationsOpModel {
|
||||
QuantizeAndPopulate<T>(input_, data);
|
||||
}
|
||||
template <typename T>
|
||||
|
||||
std::vector<T> GetOutput() {
|
||||
return ExtractVector<T>(output_);
|
||||
}
|
||||
@ -243,33 +243,88 @@ void TestFloatHardSwish(int size, std::minstd_rand* random_engine) {
|
||||
}
|
||||
|
||||
template <typename QuantizedType>
|
||||
void TestQuantizedHardSwish(TensorType tensor_type, int size,
|
||||
void TestQuantizedHardSwish(TensorType tensor_type, int size, float input_min,
|
||||
float input_max, float output_min, float output_max,
|
||||
std::minstd_rand* random_engine) {
|
||||
std::vector<float> float_input_values;
|
||||
const float kMin = -10.0f;
|
||||
const float kMax = 10.0f;
|
||||
GenerateUniformRandomVector(size, kMin, kMax, random_engine,
|
||||
GenerateUniformRandomVector(size, input_min, input_max, random_engine,
|
||||
&float_input_values);
|
||||
const float kOutMin = -3;
|
||||
const float kOutMax = kMax;
|
||||
std::vector<float> float_ref_output_values;
|
||||
EvalTestReferenceHardSwish(size, float_input_values,
|
||||
&float_ref_output_values);
|
||||
for (float& val : float_ref_output_values) {
|
||||
val = std::min(output_max, std::max(output_min, val));
|
||||
}
|
||||
QuantizedActivationsOpModel m(
|
||||
BuiltinOperator_HARD_SWISH,
|
||||
/*input=*/{tensor_type, {1, 1, 1, size}, kMin, kMax},
|
||||
/*output=*/{tensor_type, {1, 1, 1, size}, kOutMin, kOutMax});
|
||||
/*input=*/{tensor_type, {1, 1, 1, size}, input_min, input_max},
|
||||
/*output=*/{tensor_type, {1, 1, 1, size}, output_min, output_max});
|
||||
m.SetInput<QuantizedType>(float_input_values);
|
||||
|
||||
m.Invoke();
|
||||
const std::vector<float>& dequantized_output =
|
||||
m.GetDequantizedOutput<QuantizedType>();
|
||||
// The numerical error for any 8bit quantized function is at least one half
|
||||
// times the quantization step: 0.5 * (kOutMax - kOutMin) / 256.
|
||||
// To that we add again the quantization step (kOutMax - kOutMin) / 256
|
||||
// to allow for an off-by-one rounding error.
|
||||
const float kTolerance = (kOutMax - kOutMin) * (1.5f / 256.f);
|
||||
EXPECT_THAT(
|
||||
m.GetDequantizedOutput<QuantizedType>(),
|
||||
ElementsAreArray(ArrayFloatNear(float_ref_output_values, kTolerance)));
|
||||
const float kTolerance =
|
||||
std::max(input_max - input_min, output_max - output_min) * (1.5f / 256.f);
|
||||
EXPECT_THAT(dequantized_output, ElementsAreArray(ArrayFloatNear(
|
||||
float_ref_output_values, kTolerance)));
|
||||
}
|
||||
|
||||
template <typename QuantizedType>
|
||||
void TestQuantizedHardSwishBias(TensorType tensor_type, float input_min,
|
||||
float input_max, float output_min,
|
||||
float output_max, float tolerated_bias) {
|
||||
const float quantized_type_range =
|
||||
static_cast<float>(std::numeric_limits<QuantizedType>::max()) -
|
||||
static_cast<float>(std::numeric_limits<QuantizedType>::min());
|
||||
const float input_scale = (input_max - input_min) / quantized_type_range;
|
||||
const float output_scale = (output_max - output_min) / quantized_type_range;
|
||||
const float max_scale = std::max(output_scale, input_scale);
|
||||
|
||||
// In this bias-focused test case, no need for randomly generated input
|
||||
// values.
|
||||
ASSERT_LE(input_min, -3.0f);
|
||||
ASSERT_GE(input_max, 3.0f);
|
||||
const int quantized_input_negative_three =
|
||||
std::round(std::numeric_limits<QuantizedType>::min() +
|
||||
(-3.0f - input_min) / input_scale);
|
||||
const int quantized_input_positive_three =
|
||||
std::round(std::numeric_limits<QuantizedType>::min() +
|
||||
(3.0f - input_min) / input_scale);
|
||||
std::vector<float> float_input_values;
|
||||
for (int i = quantized_input_negative_three;
|
||||
i <= quantized_input_positive_three; i++) {
|
||||
float_input_values.push_back(
|
||||
input_min +
|
||||
(i - std::numeric_limits<QuantizedType>::min()) * input_scale);
|
||||
}
|
||||
const int size = float_input_values.size();
|
||||
std::vector<float> float_ref_output_values;
|
||||
EvalTestReferenceHardSwish(size, float_input_values,
|
||||
&float_ref_output_values);
|
||||
for (float& val : float_ref_output_values) {
|
||||
val = std::min(output_max, std::max(output_min, val));
|
||||
}
|
||||
QuantizedActivationsOpModel m(
|
||||
BuiltinOperator_HARD_SWISH,
|
||||
/*input=*/{tensor_type, {1, 1, 1, size}, input_min, input_max},
|
||||
/*output=*/{tensor_type, {1, 1, 1, size}, output_min, output_max});
|
||||
m.SetInput<QuantizedType>(float_input_values);
|
||||
|
||||
m.Invoke();
|
||||
const std::vector<float>& dequantized_output =
|
||||
m.GetDequantizedOutput<QuantizedType>();
|
||||
|
||||
float sum_diff = 0;
|
||||
for (int i = 0; i < size; i++) {
|
||||
sum_diff += dequantized_output[i] - float_ref_output_values[i];
|
||||
}
|
||||
const float bias = sum_diff / (size * max_scale);
|
||||
EXPECT_LE(std::abs(bias), tolerated_bias);
|
||||
}
|
||||
|
||||
TEST(FloatActivationsOpTest, HardSwish) {
|
||||
@ -281,12 +336,36 @@ TEST(FloatActivationsOpTest, HardSwish) {
|
||||
|
||||
TEST(QuantizedActivationsOpTest, HardSwish) {
|
||||
std::minstd_rand random_engine;
|
||||
for (int size : {1, 2, 3, 4, 10, 20, 30, 40, 100}) {
|
||||
TestQuantizedHardSwish<uint8_t>(TensorType_UINT8, size, &random_engine);
|
||||
TestQuantizedHardSwish<int8_t>(TensorType_INT8, size, &random_engine);
|
||||
std::vector<std::pair<float, float>> minmax_pairs{
|
||||
{0.f, 1.f}, {-2.f, 1.f}, {-5.f, 10.f}, {-40.f, 60.f}};
|
||||
for (const auto& input_minmax : minmax_pairs) {
|
||||
for (const auto& output_minmax : minmax_pairs) {
|
||||
float input_min = input_minmax.first;
|
||||
float input_max = input_minmax.second;
|
||||
float output_min = output_minmax.first;
|
||||
float output_max = output_minmax.second;
|
||||
for (int size : {1, 3, 10, 100}) {
|
||||
TestQuantizedHardSwish<uint8_t>(TensorType_UINT8, size, input_min,
|
||||
input_max, output_min, output_max,
|
||||
&random_engine);
|
||||
TestQuantizedHardSwish<int8_t>(TensorType_INT8, size, input_min,
|
||||
input_max, output_min, output_max,
|
||||
&random_engine);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// See the comment in the reference implementation of quantized HardSwish:
|
||||
// A numerical issue significantly affecting ImageNet classification accuracy
|
||||
// with MobileNet v3 is only observable at the scale of HardSwish unit tests
|
||||
// if we monitor specifically bias. This testcase is extracted from one of the
|
||||
// HardSwish nodes in that MobileNet v3 that exhibited this issue.
|
||||
TEST(QuantizedActivationsOpTest, HardSwishBias) {
|
||||
TestQuantizedHardSwishBias<uint8_t>(TensorType_UINT8, -11.654928f, 25.036512f,
|
||||
-0.3905796f, 24.50887f, 0.035);
|
||||
}
|
||||
|
||||
TEST(FloatActivationsOpTest, Tanh) {
|
||||
FloatActivationsOpModel m(BuiltinOperator_TANH,
|
||||
/*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
|
||||
|
@ -5527,125 +5527,168 @@ inline void SaturateAndStore(int16x8_t src, std::int8_t* dst) {
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename QuantizedType>
|
||||
template <typename T>
|
||||
inline void HardSwish(const HardSwishParams& params,
|
||||
const RuntimeShape& input_shape,
|
||||
const QuantizedType* input_data,
|
||||
const RuntimeShape& output_shape,
|
||||
QuantizedType* output_data) {
|
||||
const RuntimeShape& input_shape, const T* input_data,
|
||||
const RuntimeShape& output_shape, T* output_data) {
|
||||
gemmlowp::ScopedProfilingLabel label("HardSwish/Quantized");
|
||||
// Goal: (x * relu6(x+3))/6
|
||||
const int size = MatchingFlatSize(input_shape, output_shape);
|
||||
const int32_t extra_input_shift = params.clip_input_shift;
|
||||
const auto in_zero_point = params.input_zero_point;
|
||||
const auto three_in = params.three_input;
|
||||
const auto six_in = params.six_input;
|
||||
const auto real_shift = params.shift;
|
||||
const auto scale = params.scale;
|
||||
const auto offset = params.output_offset;
|
||||
|
||||
const int flat_size = MatchingFlatSize(input_shape, output_shape);
|
||||
|
||||
int i = 0;
|
||||
#ifdef USE_NEON
|
||||
const int16x8_t extra_input_shift_vec = vdupq_n_s16(extra_input_shift);
|
||||
const int16x8_t three_in_vec = vdupq_n_s16(three_in);
|
||||
const int16x8_t six_in_vec = vdupq_n_s16(six_in);
|
||||
// The quantization params of this op are designed around a reference
|
||||
// implementation that performs plain integer multiplication, not
|
||||
// fixed-point multiplication. The 16-bit fixed-point multiplications
|
||||
// that we use here, vqrdmulhq_s16, differ from that by an (rounding)
|
||||
// right shift by 15 bits. So in terms of scale and leaving aside
|
||||
// accuracy considerations, we could simply compensate for that by
|
||||
// adding 15 to real_shift. Doing so results in approximately correct results,
|
||||
// but there is high inaccuracy in the low bits. That is because unlike
|
||||
// the integer multiplications done in the reference code, our fixed-point
|
||||
// multiplication are destructive of low bits. In order to have accurate
|
||||
// enough results, we move some of that bit-shifting from being applied to
|
||||
// the result to being applied to one of the operands of these fixed-point
|
||||
// multiplications, before the information in the low bits is destroyed.
|
||||
// Fortunately, one of the operands is by construction smaller than 2^8
|
||||
// in absolute value, so it's safe to left-shift it by 7 bits.
|
||||
static constexpr int left_shift_on_scaled_input = 7;
|
||||
// We now adjust the tweak to real_shift accordingly: instead of adding 15,
|
||||
// we only add (15 - left_shift_on_scaled_input).
|
||||
const int16x8_t real_shift_vec =
|
||||
vdupq_n_s16(15 - left_shift_on_scaled_input + real_shift);
|
||||
const int16x8_t scale_vec = vdupq_n_s16((scale + (1 << 15)) >> 16);
|
||||
const int16x8_t offset_vec = vdupq_n_s16(offset);
|
||||
const int16x8_t zero = vdupq_n_s16(0);
|
||||
for (; i <= size - 32; i += 32) {
|
||||
// This code heavily uses NEON saturating left shifts (vqshl*) with shift
|
||||
// amounts that can be zero, in which case we rely on the correct behavior
|
||||
// of a left shift by zero returning just its first operand unmodified.
|
||||
// Unfortunately, the Intel arm_neon_sse.h implementation of vqshl* is
|
||||
// buggy in the case of zero shift amounts, see b/137199585. That is why
|
||||
// this NEON code path is restricted to true ARM NEON, excluding
|
||||
// arm_neon_sse.h. Anyway, the arm_neon_sse.h implemenation of saturating
|
||||
// left shifts is slow scalar code, so there may not be much benefit in
|
||||
// running that over just plain reference code.
|
||||
//
|
||||
// TODO(b/137199585): revisit when this is fixed.
|
||||
#ifdef __ARM_NEON
|
||||
const int16x8_t positive_reluish_multiplier_exponent_minus_one =
|
||||
vdupq_n_s16(std::max(0, params.reluish_multiplier_exponent - 1));
|
||||
const int16x8_t positive_reluish_multiplier_exponent_last_bit =
|
||||
vdupq_n_s16(params.reluish_multiplier_exponent > 0 ? 1 : 0);
|
||||
const int16x8_t negative_reluish_multiplier_exponent =
|
||||
vdupq_n_s16(std::min(0, params.reluish_multiplier_exponent));
|
||||
const int16x8_t constant_32767 = vdupq_n_s16(32767);
|
||||
const int16x8_t output_multiplier_exponent =
|
||||
vdupq_n_s16(params.output_multiplier_exponent);
|
||||
const int16x8_t output_zero_point = vdupq_n_s16(params.output_zero_point);
|
||||
// 4x unrolled version of the below NEON loop. Read that first.
|
||||
for (; i <= flat_size - 32; i += 32) {
|
||||
using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
|
||||
int16x8x2_t in_0_1 =
|
||||
Load16AndSubtractZeroPoint(input_data + i + 0, in_zero_point);
|
||||
int16x8x2_t in_2_3 =
|
||||
Load16AndSubtractZeroPoint(input_data + i + 16, in_zero_point);
|
||||
int16x8_t in_reluish_0 = vshlq_s16(in_0_1.val[0], extra_input_shift_vec);
|
||||
int16x8_t in_reluish_1 = vshlq_s16(in_0_1.val[1], extra_input_shift_vec);
|
||||
int16x8_t in_reluish_2 = vshlq_s16(in_2_3.val[0], extra_input_shift_vec);
|
||||
int16x8_t in_reluish_3 = vshlq_s16(in_2_3.val[1], extra_input_shift_vec);
|
||||
in_reluish_0 = vaddq_s16(in_reluish_0, three_in_vec);
|
||||
in_reluish_1 = vaddq_s16(in_reluish_1, three_in_vec);
|
||||
in_reluish_2 = vaddq_s16(in_reluish_2, three_in_vec);
|
||||
in_reluish_3 = vaddq_s16(in_reluish_3, three_in_vec);
|
||||
in_reluish_0 = vminq_s16(in_reluish_0, six_in_vec);
|
||||
in_reluish_1 = vminq_s16(in_reluish_1, six_in_vec);
|
||||
in_reluish_2 = vminq_s16(in_reluish_2, six_in_vec);
|
||||
in_reluish_3 = vminq_s16(in_reluish_3, six_in_vec);
|
||||
in_reluish_0 = vmaxq_s16(in_reluish_0, zero);
|
||||
in_reluish_1 = vmaxq_s16(in_reluish_1, zero);
|
||||
in_reluish_2 = vmaxq_s16(in_reluish_2, zero);
|
||||
in_reluish_3 = vmaxq_s16(in_reluish_3, zero);
|
||||
int16x8_t in_scaled_0 = vqrdmulhq_s16(
|
||||
vshlq_n_s16(in_0_1.val[0], left_shift_on_scaled_input), scale_vec);
|
||||
int16x8_t in_scaled_1 = vqrdmulhq_s16(
|
||||
vshlq_n_s16(in_0_1.val[1], left_shift_on_scaled_input), scale_vec);
|
||||
int16x8_t in_scaled_2 = vqrdmulhq_s16(
|
||||
vshlq_n_s16(in_2_3.val[0], left_shift_on_scaled_input), scale_vec);
|
||||
int16x8_t in_scaled_3 = vqrdmulhq_s16(
|
||||
vshlq_n_s16(in_2_3.val[1], left_shift_on_scaled_input), scale_vec);
|
||||
int16x8_t product_0 = vqrdmulhq_s16(in_scaled_0, in_reluish_0);
|
||||
int16x8_t product_1 = vqrdmulhq_s16(in_scaled_1, in_reluish_1);
|
||||
int16x8_t product_2 = vqrdmulhq_s16(in_scaled_2, in_reluish_2);
|
||||
int16x8_t product_3 = vqrdmulhq_s16(in_scaled_3, in_reluish_3);
|
||||
product_0 = vrshlq_s16(product_0, real_shift_vec);
|
||||
product_1 = vrshlq_s16(product_1, real_shift_vec);
|
||||
product_2 = vrshlq_s16(product_2, real_shift_vec);
|
||||
product_3 = vrshlq_s16(product_3, real_shift_vec);
|
||||
SaturateAndStore(vaddq_s16(product_0, offset_vec), output_data + i + 0);
|
||||
SaturateAndStore(vaddq_s16(product_1, offset_vec), output_data + i + 8);
|
||||
SaturateAndStore(vaddq_s16(product_2, offset_vec), output_data + i + 16);
|
||||
SaturateAndStore(vaddq_s16(product_3, offset_vec), output_data + i + 24);
|
||||
const int16x8x2_t input_value_0_1 =
|
||||
Load16AndSubtractZeroPoint(input_data + i, params.input_zero_point);
|
||||
const int16x8x2_t input_value_2_3 = Load16AndSubtractZeroPoint(
|
||||
input_data + i + 16, params.input_zero_point);
|
||||
const int16x8_t input_value_on_hires_input_scale_0 =
|
||||
vshlq_n_s16(input_value_0_1.val[0], 7);
|
||||
const int16x8_t input_value_on_hires_input_scale_1 =
|
||||
vshlq_n_s16(input_value_0_1.val[1], 7);
|
||||
const int16x8_t input_value_on_hires_input_scale_2 =
|
||||
vshlq_n_s16(input_value_2_3.val[0], 7);
|
||||
const int16x8_t input_value_on_hires_input_scale_3 =
|
||||
vshlq_n_s16(input_value_2_3.val[1], 7);
|
||||
const int16x8_t input_value_on_preshift_output_scale_0 =
|
||||
vqrdmulhq_n_s16(input_value_on_hires_input_scale_0,
|
||||
params.output_multiplier_fixedpoint_int16);
|
||||
const int16x8_t input_value_on_preshift_output_scale_1 =
|
||||
vqrdmulhq_n_s16(input_value_on_hires_input_scale_1,
|
||||
params.output_multiplier_fixedpoint_int16);
|
||||
const int16x8_t input_value_on_preshift_output_scale_2 =
|
||||
vqrdmulhq_n_s16(input_value_on_hires_input_scale_2,
|
||||
params.output_multiplier_fixedpoint_int16);
|
||||
const int16x8_t input_value_on_preshift_output_scale_3 =
|
||||
vqrdmulhq_n_s16(input_value_on_hires_input_scale_3,
|
||||
params.output_multiplier_fixedpoint_int16);
|
||||
int16x8_t reluish_value_0 = input_value_on_hires_input_scale_0;
|
||||
int16x8_t reluish_value_1 = input_value_on_hires_input_scale_1;
|
||||
int16x8_t reluish_value_2 = input_value_on_hires_input_scale_2;
|
||||
int16x8_t reluish_value_3 = input_value_on_hires_input_scale_3;
|
||||
reluish_value_0 = vqshlq_s16(
|
||||
reluish_value_0, positive_reluish_multiplier_exponent_minus_one);
|
||||
reluish_value_1 = vqshlq_s16(
|
||||
reluish_value_1, positive_reluish_multiplier_exponent_minus_one);
|
||||
reluish_value_2 = vqshlq_s16(
|
||||
reluish_value_2, positive_reluish_multiplier_exponent_minus_one);
|
||||
reluish_value_3 = vqshlq_s16(
|
||||
reluish_value_3, positive_reluish_multiplier_exponent_minus_one);
|
||||
reluish_value_0 = vqrdmulhq_n_s16(
|
||||
reluish_value_0, params.reluish_multiplier_fixedpoint_int16);
|
||||
reluish_value_1 = vqrdmulhq_n_s16(
|
||||
reluish_value_1, params.reluish_multiplier_fixedpoint_int16);
|
||||
reluish_value_2 = vqrdmulhq_n_s16(
|
||||
reluish_value_2, params.reluish_multiplier_fixedpoint_int16);
|
||||
reluish_value_3 = vqrdmulhq_n_s16(
|
||||
reluish_value_3, params.reluish_multiplier_fixedpoint_int16);
|
||||
reluish_value_0 = vqshlq_s16(reluish_value_0,
|
||||
positive_reluish_multiplier_exponent_last_bit);
|
||||
reluish_value_1 = vqshlq_s16(reluish_value_1,
|
||||
positive_reluish_multiplier_exponent_last_bit);
|
||||
reluish_value_2 = vqshlq_s16(reluish_value_2,
|
||||
positive_reluish_multiplier_exponent_last_bit);
|
||||
reluish_value_3 = vqshlq_s16(reluish_value_3,
|
||||
positive_reluish_multiplier_exponent_last_bit);
|
||||
reluish_value_0 =
|
||||
vrshlq_s16(reluish_value_0, negative_reluish_multiplier_exponent);
|
||||
reluish_value_1 =
|
||||
vrshlq_s16(reluish_value_1, negative_reluish_multiplier_exponent);
|
||||
reluish_value_2 =
|
||||
vrshlq_s16(reluish_value_2, negative_reluish_multiplier_exponent);
|
||||
reluish_value_3 =
|
||||
vrshlq_s16(reluish_value_3, negative_reluish_multiplier_exponent);
|
||||
reluish_value_0 = vrhaddq_s16(reluish_value_0, constant_32767);
|
||||
reluish_value_1 = vrhaddq_s16(reluish_value_1, constant_32767);
|
||||
reluish_value_2 = vrhaddq_s16(reluish_value_2, constant_32767);
|
||||
reluish_value_3 = vrhaddq_s16(reluish_value_3, constant_32767);
|
||||
const int16x8_t preshift_output_value_0 =
|
||||
vqdmulhq_s16(reluish_value_0, input_value_on_preshift_output_scale_0);
|
||||
const int16x8_t preshift_output_value_1 =
|
||||
vqdmulhq_s16(reluish_value_1, input_value_on_preshift_output_scale_1);
|
||||
const int16x8_t preshift_output_value_2 =
|
||||
vqdmulhq_s16(reluish_value_2, input_value_on_preshift_output_scale_2);
|
||||
const int16x8_t preshift_output_value_3 =
|
||||
vqdmulhq_s16(reluish_value_3, input_value_on_preshift_output_scale_3);
|
||||
int16x8_t output_value_0 =
|
||||
vrshlq_s16(preshift_output_value_0, output_multiplier_exponent);
|
||||
int16x8_t output_value_1 =
|
||||
vrshlq_s16(preshift_output_value_1, output_multiplier_exponent);
|
||||
int16x8_t output_value_2 =
|
||||
vrshlq_s16(preshift_output_value_2, output_multiplier_exponent);
|
||||
int16x8_t output_value_3 =
|
||||
vrshlq_s16(preshift_output_value_3, output_multiplier_exponent);
|
||||
output_value_0 = vaddq_s16(output_value_0, output_zero_point);
|
||||
output_value_1 = vaddq_s16(output_value_1, output_zero_point);
|
||||
output_value_2 = vaddq_s16(output_value_2, output_zero_point);
|
||||
output_value_3 = vaddq_s16(output_value_3, output_zero_point);
|
||||
SaturateAndStore(output_value_0, output_data + i);
|
||||
SaturateAndStore(output_value_1, output_data + i + 8);
|
||||
SaturateAndStore(output_value_2, output_data + i + 16);
|
||||
SaturateAndStore(output_value_3, output_data + i + 24);
|
||||
}
|
||||
for (; i <= size - 8; i += 8) {
|
||||
// NEON version of reference_ops::HardSwish. Read that first.
|
||||
for (; i <= flat_size - 8; i += 8) {
|
||||
using cpu_backend_gemm::detail::Load8AndSubtractZeroPoint;
|
||||
// See comments in the float NEON HardSwish implementation.
|
||||
int16x8_t in = Load8AndSubtractZeroPoint(input_data + i, in_zero_point);
|
||||
int16x8_t in_reluish = vshlq_s16(in, extra_input_shift_vec);
|
||||
in_reluish = vaddq_s16(in_reluish, three_in_vec);
|
||||
in_reluish = vminq_s16(in_reluish, six_in_vec);
|
||||
in_reluish = vmaxq_s16(zero, in_reluish);
|
||||
int16x8_t in_scaled =
|
||||
vqrdmulhq_s16(vshlq_n_s16(in, left_shift_on_scaled_input), scale_vec);
|
||||
int16x8_t product = vqrdmulhq_s16(in_scaled, in_reluish);
|
||||
product = vrshlq_s16(product, real_shift_vec);
|
||||
SaturateAndStore(vaddq_s16(product, offset_vec), output_data + i);
|
||||
const int16x8_t input_value =
|
||||
Load8AndSubtractZeroPoint(input_data + i, params.input_zero_point);
|
||||
const int16x8_t input_value_on_hires_input_scale =
|
||||
vshlq_n_s16(input_value, 7);
|
||||
const int16x8_t input_value_on_preshift_output_scale =
|
||||
vqrdmulhq_n_s16(input_value_on_hires_input_scale,
|
||||
params.output_multiplier_fixedpoint_int16);
|
||||
int16x8_t reluish_value = input_value_on_hires_input_scale;
|
||||
reluish_value = vqshlq_s16(reluish_value,
|
||||
positive_reluish_multiplier_exponent_minus_one);
|
||||
reluish_value = vqrdmulhq_n_s16(reluish_value,
|
||||
params.reluish_multiplier_fixedpoint_int16);
|
||||
reluish_value = vqshlq_s16(reluish_value,
|
||||
positive_reluish_multiplier_exponent_last_bit);
|
||||
reluish_value =
|
||||
vrshlq_s16(reluish_value, negative_reluish_multiplier_exponent);
|
||||
reluish_value = vrhaddq_s16(reluish_value, constant_32767);
|
||||
const int16x8_t preshift_output_value =
|
||||
vqdmulhq_s16(reluish_value, input_value_on_preshift_output_scale);
|
||||
int16x8_t output_value =
|
||||
vrshlq_s16(preshift_output_value, output_multiplier_exponent);
|
||||
output_value = vaddq_s16(output_value, output_zero_point);
|
||||
SaturateAndStore(output_value, output_data + i);
|
||||
}
|
||||
#endif
|
||||
for (; i < size; i++) {
|
||||
int32_t v = static_cast<int32>(input_data[i]);
|
||||
v -= in_zero_point; // Make zeros - zero again!
|
||||
|
||||
// Computes x + 3 in input * 2^extra_input_shift scale.
|
||||
//
|
||||
// Note: three_in is in that scale already.
|
||||
const int32_t v3 = (v << extra_input_shift) + three_in;
|
||||
|
||||
// Computes hard-swish up to a final scale
|
||||
v *= std::min(six_in, std::max(0, v3));
|
||||
|
||||
// this converts from x * relu6(x+3) in input into x * relu6(x+3) / 6
|
||||
// in output scale.
|
||||
v = MultiplyByQuantizedMultiplierSmallerThanOneExp(v, scale, real_shift);
|
||||
v += offset;
|
||||
output_data[i] = reference_ops::Saturate<QuantizedType>(v);
|
||||
// TODO(b/137208495): revisit when unit tests cover reference code.
|
||||
// Fall back to reference_ops::HardSwish. In general we have preferred
|
||||
// to duplicate such scalar code rather than call reference code to handle
|
||||
// leftovers, thinking that code duplication was not a big concern.
|
||||
// However, most of our unit tests happen to test only optimized code,
|
||||
// and the quantized HardSwish implementation is nontrivial enough that
|
||||
// I really want test coverage for the reference code.
|
||||
if (i < flat_size) {
|
||||
const RuntimeShape leftover_shape{flat_size - i};
|
||||
reference_ops::HardSwish(params, leftover_shape, input_data + i,
|
||||
leftover_shape, output_data + i);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4580,11 +4580,23 @@ inline void HardSwish(const RuntimeShape& input_shape, const T* input_data,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T Saturate(int32_t v) {
|
||||
return static_cast<T>(std::min(
|
||||
static_cast<int32_t>(std::numeric_limits<T>::max()),
|
||||
std::max(static_cast<int32_t>(std::numeric_limits<T>::min()), v)));
|
||||
inline int16_t SaturatingLeftShift(int16_t value, int amount) {
|
||||
int32_t result = static_cast<int32_t>(value) * (1 << amount);
|
||||
result = std::min<int32_t>(result, std::numeric_limits<int16_t>::max());
|
||||
result = std::max<int32_t>(result, std::numeric_limits<int16_t>::min());
|
||||
return result;
|
||||
}
|
||||
|
||||
// Similar to ARM instruction SQDMULH.
|
||||
// Similar to gemmlowp::SaturatingRoundingDoublingHighMul except
|
||||
// rounding to zero instead of to nearest (SQRDMULH).
|
||||
inline std::int16_t SaturatingDoublingHighMul(std::int16_t a, std::int16_t b) {
|
||||
bool overflow = a == b && a == std::numeric_limits<std::int16_t>::min();
|
||||
std::int32_t a_32(a);
|
||||
std::int32_t b_32(b);
|
||||
std::int32_t ab_32 = a_32 * b_32;
|
||||
std::int16_t ab_x2_high16 = static_cast<std::int16_t>((ab_32) / (1 << 15));
|
||||
return overflow ? std::numeric_limits<std::int16_t>::max() : ab_x2_high16;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -4592,36 +4604,103 @@ inline void HardSwish(const HardSwishParams& params,
|
||||
const RuntimeShape& input_shape, const T* input_data,
|
||||
const RuntimeShape& output_shape, T* output_data) {
|
||||
gemmlowp::ScopedProfilingLabel label("ReferenceHardSwish/Quantized");
|
||||
// Goal: (x * relu6(x+3))/6
|
||||
const T* in = input_data;
|
||||
T* out = output_data;
|
||||
|
||||
const int flat_size = MatchingFlatSize(input_shape, output_shape);
|
||||
const T* in_end = in + flat_size;
|
||||
const int32_t extra_input_shift = params.clip_input_shift;
|
||||
const auto in_zero_point = params.input_zero_point;
|
||||
const auto three_in = params.three_input;
|
||||
const auto six_in = params.six_input;
|
||||
const auto real_shift = params.shift;
|
||||
const auto scale = params.scale;
|
||||
const auto offset = params.output_offset;
|
||||
|
||||
for (; in < in_end; in++, out++) {
|
||||
int32_t v = static_cast<int32>(*in);
|
||||
v -= in_zero_point; // Make zeros - zero again!
|
||||
|
||||
// Computes x + 3 in input * 2^extra_input_shift scale.
|
||||
for (int i = 0; i < flat_size; i++) {
|
||||
const int16_t input_value = input_data[i] - params.input_zero_point;
|
||||
// Left-shift as much as we can without overflow/saturation to put
|
||||
// significant bits in the high bits of our 16-bit fixedpoint values, so
|
||||
// that fixed-point approximate computations below are as accurate as
|
||||
// possible.
|
||||
const int16_t input_value_on_hires_input_scale = input_value << 7;
|
||||
// Compute the input value on essentially the output scale, just not
|
||||
// right-shifted yet. This is the value that we'll use in the (x >= +3)
|
||||
// case, and that in the general case we'll multiply against the "relu-ish"
|
||||
// fixed-point multiplier in [0, 1].
|
||||
const int16_t input_value_on_preshift_output_scale =
|
||||
gemmlowp::SaturatingRoundingDoublingHighMul(
|
||||
input_value_on_hires_input_scale,
|
||||
params.output_multiplier_fixedpoint_int16);
|
||||
// Now compute the "relu-ish multiplier". In the (-3 <= x <= +3) case, that
|
||||
// is just an affine rescaling of x from [-3, 3] to [0, 1]. In the general
|
||||
// case, it is just that plus saturation at the boundaries of [-3, 3].
|
||||
// First, we rescale from [-3, 3] to [-1, 1], saturating.
|
||||
// That is done by rescaling the input value with a fixed-point multiplier
|
||||
// (reluish_multiplier_fixedpoint) and bit-shift such that we represent
|
||||
// that input value on the scale where the real value 3.0f is represented
|
||||
// by the quantized value 32768. (+32768 is actually not representable as
|
||||
// int16, so this saturates at +32767, and that is seen empirically to be
|
||||
// a negligible contribution to numerical error/bias).
|
||||
//
|
||||
// Note: three_in is in that scale already.
|
||||
const int32_t v3 = (v << extra_input_shift) + three_in;
|
||||
|
||||
// Computes hard-swish up to a final scale
|
||||
v *= std::min(six_in, std::max(0, v3));
|
||||
|
||||
// this converts from x * relu6(x+3) in input into x * relu6(x+3) / 6
|
||||
// in output scale.
|
||||
v = MultiplyByQuantizedMultiplierSmallerThanOneExp(v, scale, real_shift);
|
||||
v += offset;
|
||||
*out = Saturate<uint8>(v);
|
||||
// This code is careful to correctly implement any magnitude of multiplier,
|
||||
// involving either a right shift or a left shift, with correct saturation
|
||||
// behavior in the left-shift case. This forces this code to be more
|
||||
// complicated, but is necessary for real applications: a partially
|
||||
// trained quantized MobileNet v3-small model that motivated this code
|
||||
// exhibits some large [min, max] range boundaries, of the order of
|
||||
// magnitude of 10 or 100 depending on layers.
|
||||
//
|
||||
// The next few lines are basically just an ordinary
|
||||
// MultiplyByQuantizedMultiplier, except that we are more careful here
|
||||
// about the fine details of saturation when left-shifting, because here
|
||||
// overflow in left-shift is a common case, not an anomaly as
|
||||
// MultiplyByQuantizedMultiplier assumes.
|
||||
int16_t reluish_value = input_value_on_hires_input_scale;
|
||||
// Shift left, saturating, as much as we can while ensuring that this
|
||||
// saturation will not contribute to the result. That is, left shift amount
|
||||
// reduced by 1.
|
||||
if (params.reluish_multiplier_exponent > 0) {
|
||||
reluish_value = SaturatingLeftShift(
|
||||
reluish_value, params.reluish_multiplier_exponent - 1);
|
||||
}
|
||||
// Apply the fixed-point multiplier, dividing the value by a divisor
|
||||
// ranging in [1, 2].
|
||||
reluish_value = gemmlowp::SaturatingRoundingDoublingHighMul(
|
||||
reluish_value, params.reluish_multiplier_fixedpoint_int16);
|
||||
// Apply the last bit of left-shift. Thus, in the left-shifting case, if
|
||||
// any saturation affects the result, it is happening here --- any
|
||||
// saturation having occurred above is overwritten here, not affecting the
|
||||
// result.
|
||||
if (params.reluish_multiplier_exponent > 0) {
|
||||
reluish_value = SaturatingLeftShift(reluish_value, 1);
|
||||
}
|
||||
// Shift right, in the right-shifting case.
|
||||
if (params.reluish_multiplier_exponent < 0) {
|
||||
reluish_value = gemmlowp::RoundingDivideByPOT(
|
||||
reluish_value, -params.reluish_multiplier_exponent);
|
||||
}
|
||||
// At this point we have rescaled the value into a 16bit fixedpoint
|
||||
// reluish_value in [-1, 1].
|
||||
// We now convert that to a 16bit fixedpoint value in [0, 1].
|
||||
reluish_value = (reluish_value + (1 << 15)) >> 1;
|
||||
// Use of SaturatingDoublingHighMul here is important to cancel the biases
|
||||
// from the above SaturatingRoundingDoublingHighMul.
|
||||
//
|
||||
// On a partially trained MobileNet-v3-small,
|
||||
//
|
||||
// | bias on | ImageNet
|
||||
// | quantized | Top-1
|
||||
// Operation used here | values | accuracy (50k)
|
||||
// --------------------------------------+------------+-----------
|
||||
// SaturatingDoublingHighMul | -0.0024 | 58.920
|
||||
// SaturatingRoundingDoublingHighMul | -0.0067 | 58.064
|
||||
//
|
||||
// In activations_test, this is covered by this testcase:
|
||||
// QuantizedActivationsOpTest.HardSwishBias
|
||||
//
|
||||
const int16_t preshift_output_value = SaturatingDoublingHighMul(
|
||||
reluish_value, input_value_on_preshift_output_scale);
|
||||
// We were so far operating on the pre-shift output scale. Now we finally
|
||||
// apply that output shift, arriving at the final output scale.
|
||||
int16_t output_value = gemmlowp::RoundingDivideByPOT(
|
||||
preshift_output_value, -params.output_multiplier_exponent);
|
||||
output_value += params.output_zero_point;
|
||||
output_value =
|
||||
std::min<int16_t>(output_value, std::numeric_limits<T>::max());
|
||||
output_value =
|
||||
std::max<int16_t>(output_value, std::numeric_limits<T>::min());
|
||||
output_data[i] = output_value;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -873,35 +873,24 @@ struct LocalResponseNormalizationParams {
|
||||
};
|
||||
|
||||
struct HardSwishParams {
|
||||
// uint8 inference params
|
||||
|
||||
// Contains input->params.zero_point
|
||||
int32_t input_zero_point;
|
||||
|
||||
// when computing relu6(x+3), we scale input by using bit-shift
|
||||
// to avoid loss of precision when doing computation in uint8 (and 6 might not
|
||||
// be exactly representable in that scale.
|
||||
// This flag contains number of bits to shift.
|
||||
int32_t clip_input_shift;
|
||||
|
||||
// Added to the final output to bring the output's zero_point in order.
|
||||
int32_t output_offset;
|
||||
|
||||
// Scale that converts x*relu6(x+3) computed in input range
|
||||
// into output range x * relu6(x + 3) / 6.
|
||||
// This takes into account that hardswish is quadratic
|
||||
// so we have in_scale^2/out_scale.
|
||||
// This is the integer nominator pat of the multiplier
|
||||
int32_t scale;
|
||||
|
||||
// this is the denominator 2^shift of the multiplier
|
||||
int shift;
|
||||
|
||||
// 3 in input 0-centered scale
|
||||
int32_t three_input;
|
||||
|
||||
// 6 in input 0-centered scale
|
||||
int32_t six_input;
|
||||
// zero_point of the input activations.
|
||||
int16_t input_zero_point;
|
||||
// zero_point of the output activations.
|
||||
int16_t output_zero_point;
|
||||
// 16bit fixed-point component of the multiplier to apply to go from the
|
||||
// "high-res input scale", which is the input scale multiplied by 2^7, to the
|
||||
// "relu-ish scale", which 3.0/32768.
|
||||
// See the implementation of HardSwishPrepare.
|
||||
int16_t reluish_multiplier_fixedpoint_int16;
|
||||
// exponent/bit-shift component of the aforementioned multiplier.
|
||||
int reluish_multiplier_exponent;
|
||||
// 16bit fixed-point component of the multiplier to apply to go from the
|
||||
// "high-res input scale", which is the input scale multiplied by 2^7, to the
|
||||
// output scale.
|
||||
// See the implementation of HardSwishPrepare.
|
||||
int16_t output_multiplier_fixedpoint_int16;
|
||||
// exponent/bit-shift component of the aforementioned multiplier.
|
||||
int output_multiplier_exponent;
|
||||
};
|
||||
|
||||
struct LogisticParams {
|
||||
|
Loading…
x
Reference in New Issue
Block a user