NEON-optimized HardSwish
PiperOrigin-RevId: 253787700
parent 9fddf47da9
commit 860acc59a0
@@ -144,13 +144,21 @@ TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_STATUS(GenericPrepare(context, node));
   TfLiteTensor* output = GetOutput(context, node, 0);
 
-  if (output->type == kTfLiteUInt8) {
+  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
     HardSwishData* data = static_cast<HardSwishData*>(node->user_data);
     HardSwishParams* params = &data->params;
     const TfLiteTensor* input = GetInput(context, node, 0);
     // TODO(131260336): Maybe pick a better way to select the denominator shift.
     // Include input shift into the shift.
-    const int32_t extra_input_shift = 3;
+    static constexpr int32_t extra_input_shift = 3;
+    // Note: optimized implementations will rely on the ability to perform this
+    // left shift within int16 without overflow. The values being left-shifted
+    // range in [-255, 255] i.e. just under 2^8 in absolute value, and after the
+    // left shift they will still be added the 'three_input' value, which is
+    // safe if they're not greater than 2^14 in absolute value (since 2^15 is
+    // the magnitude of the boundaries of int16 range). 14-8 == 6, so we
+    // require extra_input_shift to be no greater than 6.
+    static_assert(extra_input_shift <= 6, "");
     const auto in_scale = input->params.scale;
     params->input_zero_point = input->params.zero_point;
     const auto out_scale = output->params.scale;
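A note on the static_assert added above: the bound is plain range arithmetic. Worked out with the quantities named in the in-code comment (and assuming, as that comment implies, that three_input itself stays below 2^14 in absolute value):

\[
|x| \le 255 < 2^{8}, \qquad |x \cdot 2^{\texttt{extra\_input\_shift}}| \le 255 \cdot 2^{6} = 16320 < 2^{14},
\]
\[
|x \cdot 2^{\texttt{extra\_input\_shift}} + \texttt{three\_input}| < 2^{14} + 2^{14} = 2^{15},
\]

so the shifted value plus three_input still fits in int16, and extra_input_shift <= 14 - 8 = 6 is exactly the condition being asserted.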
@@ -492,6 +500,7 @@ TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
   }
 }
 
+template <KernelType kernel_type>
 TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
   HardSwishData* data = static_cast<HardSwishData*>(node->user_data);
 
@@ -499,22 +508,47 @@ TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output = GetOutput(context, node, 0);
   switch (input->type) {
     case kTfLiteFloat32: {
-      reference_ops::HardSwish(
-          GetTensorShape(input), GetTensorData<float>(input),
-          GetTensorShape(output), GetTensorData<float>(output));
+      if (kernel_type == kReference) {
+        reference_ops::HardSwish(
+            GetTensorShape(input), GetTensorData<float>(input),
+            GetTensorShape(output), GetTensorData<float>(output));
+      } else {
+        optimized_ops::HardSwish(
+            GetTensorShape(input), GetTensorData<float>(input),
+            GetTensorShape(output), GetTensorData<float>(output));
+      }
       return kTfLiteOk;
     } break;
     case kTfLiteUInt8: {
       HardSwishParams& params = data->params;
-
-      reference_ops::HardSwish<uint8_t>(
-          params, GetTensorShape(input), GetTensorData<uint8_t>(input),
-          GetTensorShape(output), GetTensorData<uint8_t>(output));
+      if (kernel_type == kReference) {
+        reference_ops::HardSwish(
+            params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+            GetTensorShape(output), GetTensorData<uint8_t>(output));
+      } else {
+        optimized_ops::HardSwish(
+            params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+            GetTensorShape(output), GetTensorData<uint8_t>(output));
+      }
       return kTfLiteOk;
     } break;
+    case kTfLiteInt8: {
+      HardSwishParams& params = data->params;
+      if (kernel_type == kReference) {
+        reference_ops::HardSwish(
+            params, GetTensorShape(input), GetTensorData<int8_t>(input),
+            GetTensorShape(output), GetTensorData<int8_t>(output));
+      } else {
+        optimized_ops::HardSwish(
+            params, GetTensorShape(input), GetTensorData<int8_t>(input),
+            GetTensorShape(output), GetTensorData<int8_t>(output));
+      }
+      return kTfLiteOk;
+    } break;
     default:
       context->ReportError(
-          context, "Only float32, uint8 are supported currently, got %s.",
+          context,
+          "Only float32, uint8 and int8 are supported currently, got %s.",
          TfLiteTypeGetName(input->type));
       return kTfLiteError;
   }
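Both arms of each case above compute the same function and differ only in backend; since kernel_type is a template parameter, the `if (kernel_type == kReference)` test is a compile-time constant and the compiler can drop the dead branch. For reference, a minimal scalar sketch of the function being dispatched (the name HardSwishScalar is illustrative and not part of the patch; the formula matches the reference code and the EvalTestReferenceHardSwish helper added later in this commit):

    #include <algorithm>

    // Scalar hard-swish, the function both reference_ops::HardSwish and
    // optimized_ops::HardSwish compute: x * relu6(x + 3) / 6.
    inline float HardSwishScalar(float x) {
      const float relu6 = std::min(6.0f, std::max(0.0f, x + 3.0f));
      return x * relu6 * (1.0f / 6.0f);
    }
    // e.g. HardSwishScalar(-2.0f) == -2.0f * 1.0f / 6.0f, about -0.3333f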
@@ -1070,9 +1104,19 @@ TfLiteRegistration* Register_LEAKY_RELU() {
 TfLiteRegistration* Register_HARD_SWISH() {
   static TfLiteRegistration r = {
       activations::HardSwishInit, activations::HardSwishFree,
-      activations::HardSwishPrepare, activations::HardSwishEval};
+      activations::HardSwishPrepare,
+      activations::HardSwishEval<activations::kGenericOptimized>};
   return &r;
 }
+
+TfLiteRegistration* Register_HARD_SWISH_REF() {
+  static TfLiteRegistration r = {
+      activations::HardSwishInit, activations::HardSwishFree,
+      activations::HardSwishPrepare,
+      activations::HardSwishEval<activations::kReference>};
+  return &r;
+}
+
 }  // namespace builtin
 }  // namespace ops
 }  // namespace tflite
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 #include <cstdarg>
+#include <random>
 
 #include <gtest/gtest.h>
 #include "tensorflow/lite/interpreter.h"
@@ -89,7 +90,7 @@ class FloatActivationsOpModel : public BaseActivationsOpModel {
  public:
   using BaseActivationsOpModel::BaseActivationsOpModel;
 
-  void SetInput(std::initializer_list<float> data) {
+  void SetInput(const std::vector<float>& data) {
     PopulateTensor(input_, data);
   }
   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
@@ -119,7 +120,7 @@ class QuantizedActivationsOpModel : public BaseActivationsOpModel {
   using BaseActivationsOpModel::BaseActivationsOpModel;
 
   template <typename T>
-  void SetInput(std::initializer_list<float> data) {
+  void SetInput(const std::vector<float>& data) {
     QuantizeAndPopulate<T>(input_, data);
   }
   template <typename T>
@@ -191,45 +192,99 @@ TEST(FloatActivationsOpTest, Relu6) {
                              }));
 }
 
-TEST(FloatActivationsOpTest, HardSwish) {
+void GenerateUniformRandomVector(int size, float min, float max,
+                                 std::minstd_rand* random_engine,
+                                 std::vector<float>* result) {
+  // Never use std::uniform_*_distribution in tests, it's
+  // implementation-defined. Likewise, don't use std::default_random_engine,
+  // implementation-defined. Implementation-defined is bad because it means that
+  // any toolchain update or new platform may run into test failures.
+  // std::minstd_rand is a standard instantiation of
+  // std::linear_congruential_engine, the cheapest generator in c++11 stdlib,
+  // it's good enough here.
+  result->resize(size);
+  for (int i = 0; i < size; i++) {
+    // We don't care whether the `max` value may ever be produced exactly.
+    // It may actually be thanks to rounding, as std::minstd_rand::modulus
+    // is 2^31 - 1 is greater than the inverse float epsilon.
+    float random_value_scaled_0_1 =
+        (*random_engine)() *
+        (1.0f / static_cast<float>(std::minstd_rand::modulus));
+    (*result)[i] = min + (max - min) * random_value_scaled_0_1;
+  }
+}
+
+void EvalTestReferenceHardSwish(int size, const std::vector<float>& input,
+                                std::vector<float>* result) {
+  result->resize(size);
+  for (int i = 0; i < size; i++) {
+    const float in = input[i];
+    (*result)[i] = in * std::min(6.0f, std::max(0.0f, in + 3)) * (1.0f / 6.0f);
+  }
+}
+
+void TestFloatHardSwish(int size, std::minstd_rand* random_engine) {
+  std::vector<float> float_input_values;
+  const float kMin = -10.0f;
+  const float kMax = 10.0f;
+  GenerateUniformRandomVector(size, kMin, kMax, random_engine,
+                              &float_input_values);
+  std::vector<float> float_ref_output_values;
+  EvalTestReferenceHardSwish(size, float_input_values,
+                             &float_ref_output_values);
   FloatActivationsOpModel m(BuiltinOperator_HARD_SWISH,
-                            /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
-  m.SetInput({
-      0, -6, 2, 4,  //
-      3, -2, 10, 1,  //
-  });
+                            /*input=*/{TensorType_FLOAT32, {1, 1, 1, size}},
+                            /*output=*/{TensorType_FLOAT32, {1, 1, 1, size}});
+  m.SetInput(float_input_values);
+
   m.Invoke();
   EXPECT_THAT(m.GetOutput(),
-              ElementsAreArray(ArrayFloatNear(
-                  {0, 0, 1.66666667, 4, 3., -0.33333333, 10, 0.66666667})));
+              ElementsAreArray(ArrayFloatNear(float_ref_output_values)));
 }
 
+template <typename QuantizedType>
+void TestQuantizedHardSwish(TensorType tensor_type, int size,
+                            std::minstd_rand* random_engine) {
+  std::vector<float> float_input_values;
+  const float kMin = -10.0f;
+  const float kMax = 10.0f;
+  GenerateUniformRandomVector(size, kMin, kMax, random_engine,
+                              &float_input_values);
+  const float kOutMin = -3;
+  const float kOutMax = kMax;
+  std::vector<float> float_ref_output_values;
+  EvalTestReferenceHardSwish(size, float_input_values,
+                             &float_ref_output_values);
+  QuantizedActivationsOpModel m(
+      BuiltinOperator_HARD_SWISH,
+      /*input=*/{tensor_type, {1, 1, 1, size}, kMin, kMax},
+      /*output=*/{tensor_type, {1, 1, 1, size}, kOutMin, kOutMax});
+  m.SetInput<QuantizedType>(float_input_values);
+
+  m.Invoke();
+  // The numerical error for any 8bit quantized function is at least one half
+  // times the quantization step: 0.5 * (kOutMax - kOutMin) / 256.
+  // To that we add again the quantization step (kOutMax - kOutMin) / 256
+  // to allow for an off-by-one rounding error.
+  const float kTolerance = (kOutMax - kOutMin) * (1.5f / 256.f);
+  EXPECT_THAT(
+      m.GetDequantizedOutput<QuantizedType>(),
+      ElementsAreArray(ArrayFloatNear(float_ref_output_values, kTolerance)));
+}
+
+TEST(FloatActivationsOpTest, HardSwish) {
+  std::minstd_rand random_engine;
+  for (int size : {1, 2, 3, 4, 10, 20, 30, 40, 100}) {
+    TestFloatHardSwish(size, &random_engine);
+  }
+}
+
 TEST(QuantizedActivationsOpTest, HardSwish) {
-  const float kMin = -10;
-  const float kMax = 15;
-  const float kOutMin = -3;
-  const float kOutMax = kMax;
-  QuantizedActivationsOpModel m(
-      BuiltinOperator_HARD_SWISH,
-      /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, kMin, kMax},
-      /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, kOutMin, kOutMax});
-  m.SetInput<uint8_t>({
-      0, -10, 2, 4,  //
-      3, -2, 10, 15,  //
-  });
-  auto r =
-      Dequantize(m.ExtractVector<uint8_t>(0), m.GetScale(0), m.GetZeroPoint(0));
-  for (int i = 0; i < r.size(); i++) {
-    LOG(INFO) << r[i];
+  std::minstd_rand random_engine;
+  for (int size : {1, 2, 3, 4, 10, 20, 30, 40, 100}) {
+    TestQuantizedHardSwish<uint8_t>(TensorType_UINT8, size, &random_engine);
+    TestQuantizedHardSwish<int8_t>(TensorType_INT8, size, &random_engine);
   }
-
-  m.Invoke();
-  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
-              ElementsAreArray(
-                  ArrayFloatNear({0, 0, 1.66666667, 4, 3., -0.33333333, 10, 15},
-                                 (kOutMax - kOutMin) / 1. / 256.)));
-  EXPECT_THAT(m.GetOutput<uint8_t>(),
-              ElementsAreArray({42, 42, 65, 99, 85, 37, 184, 255}));
 }
 
 TEST(FloatActivationsOpTest, Tanh) {
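A quick sanity check of the tolerance used above: with the ranges picked by TestQuantizedHardSwish (kMin = -10, kMax = 10, kOutMin = -3, kOutMax = kMax), the bound works out to

\[
k_{\mathrm{Tolerance}} = (k_{\mathrm{OutMax}} - k_{\mathrm{OutMin}}) \cdot \frac{1.5}{256} = 13 \cdot \frac{1.5}{256} \approx 0.076,
\]

i.e. one and a half steps of the 8-bit output quantization grid, matching the half-step intrinsic error plus one step of rounding slack described in the in-code comment.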
@@ -5677,6 +5677,204 @@ inline void Requantize<uint8_t, int8_t>(const uint8_t* input_data, int32_t size,
   }
 }
 
+inline void HardSwish(const RuntimeShape& input_shape, const float* input_data,
+                      const RuntimeShape& output_shape, float* output_data) {
+  gemmlowp::ScopedProfilingLabel label("HardSwish/Float");
+  auto size = MatchingFlatSize(input_shape, output_shape);
+  int i = 0;
+#ifdef USE_NEON
+  const float32x4_t zero = vdupq_n_f32(0.0f);
+  const float32x4_t three = vdupq_n_f32(3.0f);
+  const float32x4_t six = vdupq_n_f32(6.0f);
+  const float32x4_t one_sixth = vdupq_n_f32(1.0f / 6.0f);
+
+  for (; i <= size - 16; i += 16) {
+    // 4x partially unrolled version of the loop below. Refer to its comments.
+    const float32x4_t in_0 = vld1q_f32(input_data + i + 0);
+    const float32x4_t in_1 = vld1q_f32(input_data + i + 4);
+    const float32x4_t in_2 = vld1q_f32(input_data + i + 8);
+    const float32x4_t in_3 = vld1q_f32(input_data + i + 12);
+    const float32x4_t in_scaled_0 = vmulq_f32(in_0, one_sixth);
+    const float32x4_t in_scaled_1 = vmulq_f32(in_1, one_sixth);
+    const float32x4_t in_scaled_2 = vmulq_f32(in_2, one_sixth);
+    const float32x4_t in_scaled_3 = vmulq_f32(in_3, one_sixth);
+    const float32x4_t in_reluish_0 =
+        vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_0, three)));
+    const float32x4_t in_reluish_1 =
+        vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_1, three)));
+    const float32x4_t in_reluish_2 =
+        vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_2, three)));
+    const float32x4_t in_reluish_3 =
+        vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_3, three)));
+    const float32x4_t product_0 = vmulq_f32(in_scaled_0, in_reluish_0);
+    const float32x4_t product_1 = vmulq_f32(in_scaled_1, in_reluish_1);
+    const float32x4_t product_2 = vmulq_f32(in_scaled_2, in_reluish_2);
+    const float32x4_t product_3 = vmulq_f32(in_scaled_3, in_reluish_3);
+    vst1q_f32(output_data + i + 0, product_0);
+    vst1q_f32(output_data + i + 4, product_1);
+    vst1q_f32(output_data + i + 8, product_2);
+    vst1q_f32(output_data + i + 12, product_3);
+  }
+  for (; i <= size - 4; i += 4) {
+    // The expression to be computed is:
+    //   out = one_sixth * in * min(six, max(zero, (in + three)))
+    // We structure the AST to have two roughly balanced, independent branches:
+    //  - Multiplication: in_scaled = one_sixth * in.
+    //  - Addition and clamping: in_reluish = min(six, max(zero, (in + three))).
+    // Then the remaining multiplication at the root of the tree.
+    const float32x4_t in = vld1q_f32(input_data + i);
+    const float32x4_t in_scaled = vmulq_f32(in, one_sixth);
+    const float32x4_t in_reluish =
+        vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in, three)));
+    const float32x4_t product = vmulq_f32(in_scaled, in_reluish);
+    vst1q_f32(output_data + i, product);
+  }
+#endif
+  for (; i < size; i++) {
+    const float in = input_data[i];
+    output_data[i] =
+        in * std::min(6.0f, std::max(0.0f, in + 3.0f)) * (1.0f / 6.0f);
+  }
+}
+
+#ifdef USE_NEON
+inline void SaturateAndStore(int16x8_t src, std::uint8_t* dst) {
+  // Narrow values down to 8 bit unsigned, saturating.
+  uint8x8_t res8 = vqmovun_s16(src);
+  // Store results to destination.
+  vst1_u8(dst, res8);
+}
+
+inline void SaturateAndStore(int16x8_t src, std::int8_t* dst) {
+  // Narrow values down to 8 bit signed, saturating.
+  int8x8_t res8 = vqmovn_s16(src);
+  // Store results to destination.
+  vst1_s8(dst, res8);
+}
+#endif
+
+template <typename QuantizedType>
+inline void HardSwish(const HardSwishParams& params,
+                      const RuntimeShape& input_shape,
+                      const QuantizedType* input_data,
+                      const RuntimeShape& output_shape,
+                      QuantizedType* output_data) {
+  gemmlowp::ScopedProfilingLabel label("HardSwish/Quantized");
+  // Goal: (x * relu6(x+3))/6
+  const int size = MatchingFlatSize(input_shape, output_shape);
+  const int32_t extra_input_shift = params.clip_input_shift;
+  const auto in_zero_point = params.input_zero_point;
+  const auto three_in = params.three_input;
+  const auto six_in = params.six_input;
+  const auto real_shift = params.shift;
+  const auto scale = params.scale;
+  const auto offset = params.output_offset;
+  int i = 0;
+#ifdef USE_NEON
+  const int16x8_t extra_input_shift_vec = vdupq_n_s16(extra_input_shift);
+  const int16x8_t three_in_vec = vdupq_n_s16(three_in);
+  const int16x8_t six_in_vec = vdupq_n_s16(six_in);
+  // The quantization params of this op are designed around a reference
+  // implementation that performs plain integer multiplication, not
+  // fixed-point multiplication. The 16-bit fixed-point multiplications
+  // that we use here, vqrdmulhq_s16, differ from that by an (rounding)
+  // right shift by 15 bits. So in terms of scale and leaving aside
+  // accuracy considerations, we could simply compensate for that by
+  // adding 15 to real_shift. Doing so results in approximately correct results,
+  // but there is high inaccuracy in the low bits. That is because unlike
+  // the integer multiplications done in the reference code, our fixed-point
+  // multiplication are destructive of low bits. In order to have accurate
+  // enough results, we move some of that bit-shifting from being applied to
+  // the result to being applied to one of the operands of these fixed-point
+  // multiplications, before the information in the low bits is destroyed.
+  // Fortunately, one of the operands is by construction smaller than 2^8
+  // in absolute value, so it's safe to left-shift it by 7 bits.
+  static constexpr int left_shift_on_scaled_input = 7;
+  // We now adjust the tweak to real_shift accordingly: instead of adding 15,
+  // we only add (15 - left_shift_on_scaled_input).
+  const int16x8_t real_shift_vec =
+      vdupq_n_s16(15 - left_shift_on_scaled_input + real_shift);
+  const int16x8_t scale_vec = vdupq_n_s16((scale + (1 << 15)) >> 16);
+  const int16x8_t offset_vec = vdupq_n_s16(offset);
+  const int16x8_t zero = vdupq_n_s16(0);
+  for (; i <= size - 32; i += 32) {
+    using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
+    int16x8x2_t in_0_1 =
+        Load16AndSubtractZeroPoint(input_data + i + 0, in_zero_point);
+    int16x8x2_t in_2_3 =
+        Load16AndSubtractZeroPoint(input_data + i + 16, in_zero_point);
+    int16x8_t in_reluish_0 = vshlq_s16(in_0_1.val[0], extra_input_shift_vec);
+    int16x8_t in_reluish_1 = vshlq_s16(in_0_1.val[1], extra_input_shift_vec);
+    int16x8_t in_reluish_2 = vshlq_s16(in_2_3.val[0], extra_input_shift_vec);
+    int16x8_t in_reluish_3 = vshlq_s16(in_2_3.val[1], extra_input_shift_vec);
+    in_reluish_0 = vaddq_s16(in_reluish_0, three_in_vec);
+    in_reluish_1 = vaddq_s16(in_reluish_1, three_in_vec);
+    in_reluish_2 = vaddq_s16(in_reluish_2, three_in_vec);
+    in_reluish_3 = vaddq_s16(in_reluish_3, three_in_vec);
+    in_reluish_0 = vminq_s16(in_reluish_0, six_in_vec);
+    in_reluish_1 = vminq_s16(in_reluish_1, six_in_vec);
+    in_reluish_2 = vminq_s16(in_reluish_2, six_in_vec);
+    in_reluish_3 = vminq_s16(in_reluish_3, six_in_vec);
+    in_reluish_0 = vmaxq_s16(in_reluish_0, zero);
+    in_reluish_1 = vmaxq_s16(in_reluish_1, zero);
+    in_reluish_2 = vmaxq_s16(in_reluish_2, zero);
+    in_reluish_3 = vmaxq_s16(in_reluish_3, zero);
+    int16x8_t in_scaled_0 = vqrdmulhq_s16(
+        vshlq_n_s16(in_0_1.val[0], left_shift_on_scaled_input), scale_vec);
+    int16x8_t in_scaled_1 = vqrdmulhq_s16(
+        vshlq_n_s16(in_0_1.val[1], left_shift_on_scaled_input), scale_vec);
+    int16x8_t in_scaled_2 = vqrdmulhq_s16(
+        vshlq_n_s16(in_2_3.val[0], left_shift_on_scaled_input), scale_vec);
+    int16x8_t in_scaled_3 = vqrdmulhq_s16(
+        vshlq_n_s16(in_2_3.val[1], left_shift_on_scaled_input), scale_vec);
+    int16x8_t product_0 = vqrdmulhq_s16(in_scaled_0, in_reluish_0);
+    int16x8_t product_1 = vqrdmulhq_s16(in_scaled_1, in_reluish_1);
+    int16x8_t product_2 = vqrdmulhq_s16(in_scaled_2, in_reluish_2);
+    int16x8_t product_3 = vqrdmulhq_s16(in_scaled_3, in_reluish_3);
+    product_0 = vrshlq_s16(product_0, real_shift_vec);
+    product_1 = vrshlq_s16(product_1, real_shift_vec);
+    product_2 = vrshlq_s16(product_2, real_shift_vec);
+    product_3 = vrshlq_s16(product_3, real_shift_vec);
+    SaturateAndStore(vaddq_s16(product_0, offset_vec), output_data + i + 0);
+    SaturateAndStore(vaddq_s16(product_1, offset_vec), output_data + i + 8);
+    SaturateAndStore(vaddq_s16(product_2, offset_vec), output_data + i + 16);
+    SaturateAndStore(vaddq_s16(product_3, offset_vec), output_data + i + 24);
+  }
+  for (; i <= size - 8; i += 8) {
+    using cpu_backend_gemm::detail::Load8AndSubtractZeroPoint;
+    // See comments in the float NEON HardSwish implementation.
+    int16x8_t in = Load8AndSubtractZeroPoint(input_data + i, in_zero_point);
+    int16x8_t in_reluish = vshlq_s16(in, extra_input_shift_vec);
+    in_reluish = vaddq_s16(in_reluish, three_in_vec);
+    in_reluish = vminq_s16(in_reluish, six_in_vec);
+    in_reluish = vmaxq_s16(zero, in_reluish);
+    int16x8_t in_scaled =
+        vqrdmulhq_s16(vshlq_n_s16(in, left_shift_on_scaled_input), scale_vec);
+    int16x8_t product = vqrdmulhq_s16(in_scaled, in_reluish);
+    product = vrshlq_s16(product, real_shift_vec);
+    SaturateAndStore(vaddq_s16(product, offset_vec), output_data + i);
+  }
+#endif
+  for (; i < size; i++) {
+    int32_t v = static_cast<int32>(input_data[i]);
+    v -= in_zero_point;  // Make zeros - zero again!
+
+    // Computes x + 3 in input * 2^extra_input_shift scale.
+    //
+    // Note: three_in is in that scale already.
+    const int32_t v3 = (v << extra_input_shift) + three_in;
+
+    // Computes hard-swish up to a final scale
+    v *= std::min(six_in, std::max(0, v3));
+
+    // this converts from x * relu6(x+3) in input into x * relu6(x+3) / 6
+    // in output scale.
+    v = MultiplyByQuantizedMultiplierSmallerThanOneExp(v, scale, real_shift);
+    v += offset;
+    output_data[i] = reference_ops::Saturate<QuantizedType>(v);
+  }
+}
+
 }  // namespace optimized_ops
 }  // namespace tflite
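A note on the shift bookkeeping in the quantized NEON path above, reconstructed from the in-code comment (a sketch of the arithmetic only): vqrdmulhq_s16 is a saturating, rounding, doubling high multiply, approximately

\[
\mathrm{vqrdmulhq\_s16}(a, b) \approx \operatorname{round}\!\left(\frac{a \cdot b}{2^{15}}\right),
\]

so relative to the plain integer product used by the reference code it is short by a factor of 2^15. Left-shifting one operand by left_shift_on_scaled_input = 7 bits restores 2^7 of that factor before the low bits are rounded away, and the remaining 2^(15-7) is folded into the final rounding shift; that is why real_shift_vec is built from 15 - left_shift_on_scaled_input + real_shift. The expression (scale + (1 << 15)) >> 16 simply rounds the 32-bit quantized multiplier to its top 16 bits so it can feed the 16-bit fixed-point multiply.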
@@ -130,6 +130,7 @@ TfLiteRegistration* Register_SQUARED_DIFFERENCE();
 TfLiteRegistration* Register_FILL();
 TfLiteRegistration* Register_MIRROR_PAD();
 TfLiteRegistration* Register_QUANTIZE();
+TfLiteRegistration* Register_HARD_SWISH_REF();
 
 namespace {
 
@@ -280,6 +281,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
   AddBuiltin(BuiltinOperator_FILL, Register_FILL());
   AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD());
   AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE());
+  AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH_REF());
 
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
   // custom ops aren't always included by default.
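For context, a hedged usage sketch (not part of this change): with these registrations, the default BuiltinOpResolver now serves the NEON/optimized HardSwish kernel, while BuiltinRefOpResolver keeps the reference one. An application that wants to force the reference kernel on top of the default resolver, e.g. to compare numerics, could override the builtin along these lines, assuming it links the builtin kernels and provides its own declaration of Register_HARD_SWISH_REF (this patch only declares it inside register_ref.cc):

    #include "tensorflow/lite/kernels/register.h"

    namespace tflite {
    namespace ops {
    namespace builtin {
    // Defined in activations.cc by this patch; declared here for the sketch.
    TfLiteRegistration* Register_HARD_SWISH_REF();
    }  // namespace builtin
    }  // namespace ops
    }  // namespace tflite

    // Start from the default (optimized) resolver, then override HARD_SWISH
    // with the reference kernel; AddBuiltin replaces an existing registration.
    tflite::ops::builtin::BuiltinOpResolver MakeResolverWithRefHardSwish() {
      tflite::ops::builtin::BuiltinOpResolver resolver;
      resolver.AddBuiltin(tflite::BuiltinOperator_HARD_SWISH,
                          tflite::ops::builtin::Register_HARD_SWISH_REF());
      return resolver;
    }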