diff --git a/tensorflow/lite/kernels/activations.cc b/tensorflow/lite/kernels/activations.cc
index 983ceecfcf5..39c2d0d168a 100644
--- a/tensorflow/lite/kernels/activations.cc
+++ b/tensorflow/lite/kernels/activations.cc
@@ -144,13 +144,21 @@ TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_STATUS(GenericPrepare(context, node));
   TfLiteTensor* output = GetOutput(context, node, 0);
 
-  if (output->type == kTfLiteUInt8) {
+  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
     HardSwishData* data = static_cast<HardSwishData*>(node->user_data);
     HardSwishParams* params = &data->params;
     const TfLiteTensor* input = GetInput(context, node, 0);
     // TODO(131260336): Maybe pick a better way to select the denominator shift.
     // Include input shift into the shift.
-    const int32_t extra_input_shift = 3;
+    static constexpr int32_t extra_input_shift = 3;
+    // Note: optimized implementations will rely on the ability to perform this
+    // left shift within int16 without overflow. The values being left-shifted
+    // range in [-255, 255], i.e. just under 2^8 in absolute value, and after
+    // the left shift they still have the 'three_input' value added to them,
+    // which is safe as long as they are no greater than 2^14 in absolute value
+    // (2^15 being the magnitude of the int16 range boundaries). 14 - 8 == 6,
+    // so we require extra_input_shift to be no greater than 6.
+    static_assert(extra_input_shift <= 6, "");
     const auto in_scale = input->params.scale;
     params->input_zero_point = input->params.zero_point;
     const auto out_scale = output->params.scale;
@@ -492,6 +500,7 @@ TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
   }
 }
 
+template <KernelType kernel_type>
 TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
   HardSwishData* data = static_cast<HardSwishData*>(node->user_data);
 
@@ -499,22 +508,47 @@ TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output = GetOutput(context, node, 0);
   switch (input->type) {
     case kTfLiteFloat32: {
-      reference_ops::HardSwish(
-          GetTensorShape(input), GetTensorData<float>(input),
-          GetTensorShape(output), GetTensorData<float>(output));
+      if (kernel_type == kReference) {
+        reference_ops::HardSwish(
+            GetTensorShape(input), GetTensorData<float>(input),
+            GetTensorShape(output), GetTensorData<float>(output));
+      } else {
+        optimized_ops::HardSwish(
+            GetTensorShape(input), GetTensorData<float>(input),
+            GetTensorShape(output), GetTensorData<float>(output));
+      }
       return kTfLiteOk;
     } break;
     case kTfLiteUInt8: {
       HardSwishParams& params = data->params;
-
-      reference_ops::HardSwish(
-          params, GetTensorShape(input), GetTensorData<uint8_t>(input),
-          GetTensorShape(output), GetTensorData<uint8_t>(output));
+      if (kernel_type == kReference) {
+        reference_ops::HardSwish(
+            params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+            GetTensorShape(output), GetTensorData<uint8_t>(output));
+      } else {
+        optimized_ops::HardSwish(
+            params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+            GetTensorShape(output), GetTensorData<uint8_t>(output));
+      }
+      return kTfLiteOk;
+    } break;
+    case kTfLiteInt8: {
+      HardSwishParams& params = data->params;
+      if (kernel_type == kReference) {
+        reference_ops::HardSwish(
+            params, GetTensorShape(input), GetTensorData<int8_t>(input),
+            GetTensorShape(output), GetTensorData<int8_t>(output));
+      } else {
+        optimized_ops::HardSwish(
+            params, GetTensorShape(input), GetTensorData<int8_t>(input),
+            GetTensorShape(output), GetTensorData<int8_t>(output));
+      }
       return kTfLiteOk;
     } break;
     default:
       context->ReportError(
-          context, "Only float32, uint8 are supported currently, got %s.",
+          context,
+          "Only float32, uint8 and int8 are supported currently, got %s.",
           TfLiteTypeGetName(input->type));
       return kTfLiteError;
   }
@@ -1070,9 +1104,19 @@ TfLiteRegistration* Register_LEAKY_RELU() {
 TfLiteRegistration* Register_HARD_SWISH() {
   static TfLiteRegistration r = {
       activations::HardSwishInit, activations::HardSwishFree,
-      activations::HardSwishPrepare, activations::HardSwishEval};
+      activations::HardSwishPrepare,
+      activations::HardSwishEval<activations::kGenericOptimized>};
   return &r;
 }
+
+TfLiteRegistration* Register_HARD_SWISH_REF() {
+  static TfLiteRegistration r = {
+      activations::HardSwishInit, activations::HardSwishFree,
+      activations::HardSwishPrepare,
+      activations::HardSwishEval<activations::kReference>};
+  return &r;
+}
+
 }  // namespace builtin
 }  // namespace ops
 }  // namespace tflite
diff --git a/tensorflow/lite/kernels/activations_test.cc b/tensorflow/lite/kernels/activations_test.cc
index 10e13d612a8..20199e4008b 100644
--- a/tensorflow/lite/kernels/activations_test.cc
+++ b/tensorflow/lite/kernels/activations_test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 #include <cstdarg>
+#include <random>
 #include <vector>
 
 #include "tensorflow/lite/interpreter.h"
@@ -89,7 +90,7 @@ class FloatActivationsOpModel : public BaseActivationsOpModel {
  public:
   using BaseActivationsOpModel::BaseActivationsOpModel;
 
-  void SetInput(std::initializer_list<float> data) {
+  void SetInput(const std::vector<float>& data) {
     PopulateTensor(input_, data);
   }
   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
@@ -119,7 +120,7 @@ class QuantizedActivationsOpModel : public BaseActivationsOpModel {
   using BaseActivationsOpModel::BaseActivationsOpModel;
 
   template <typename T>
-  void SetInput(std::initializer_list<float> data) {
+  void SetInput(const std::vector<float>& data) {
     QuantizeAndPopulate<T>(input_, data);
   }
   template <typename T>
@@ -191,45 +192,99 @@ TEST(FloatActivationsOpTest, Relu6) {
                              }));
 }
 
-TEST(FloatActivationsOpTest, HardSwish) {
+void GenerateUniformRandomVector(int size, float min, float max,
+                                 std::minstd_rand* random_engine,
+                                 std::vector<float>* result) {
+  // Never use std::uniform_*_distribution in tests: it's
+  // implementation-defined. Likewise, don't use std::default_random_engine,
+  // which is also implementation-defined. Implementation-defined behavior is
+  // a problem because any toolchain update or new platform may cause test
+  // failures. std::minstd_rand is a standard instantiation of
+  // std::linear_congruential_engine, the cheapest generator in the C++11
+  // standard library; it's good enough here.
+  result->resize(size);
+  for (int i = 0; i < size; i++) {
+    // We don't care whether the `max` value may ever be produced exactly.
+    // It may actually be, thanks to rounding: std::minstd_rand::modulus is
+    // 2^31 - 1, which is greater than the inverse of the float epsilon.
+    float random_value_scaled_0_1 =
+        (*random_engine)() *
+        (1.0f / static_cast<float>(std::minstd_rand::modulus));
+    (*result)[i] = min + (max - min) * random_value_scaled_0_1;
+  }
+}
+
+void EvalTestReferenceHardSwish(int size, const std::vector<float>& input,
+                                std::vector<float>* result) {
+  result->resize(size);
+  for (int i = 0; i < size; i++) {
+    const float in = input[i];
+    (*result)[i] = in * std::min(6.0f, std::max(0.0f, in + 3)) * (1.0f / 6.0f);
+  }
+}
+
+void TestFloatHardSwish(int size, std::minstd_rand* random_engine) {
+  std::vector<float> float_input_values;
+  const float kMin = -10.0f;
+  const float kMax = 10.0f;
+  GenerateUniformRandomVector(size, kMin, kMax, random_engine,
+                              &float_input_values);
+  std::vector<float> float_ref_output_values;
+  EvalTestReferenceHardSwish(size, float_input_values,
+                             &float_ref_output_values);
   FloatActivationsOpModel m(BuiltinOperator_HARD_SWISH,
-                            /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
-  m.SetInput({
-      0, -6, 2, 4,   //
-      3, -2, 10, 1,  //
-  });
+                            /*input=*/{TensorType_FLOAT32, {1, 1, 1, size}},
+                            /*output=*/{TensorType_FLOAT32, {1, 1, 1, size}});
+  m.SetInput(float_input_values);
+
   m.Invoke();
   EXPECT_THAT(m.GetOutput(),
-              ElementsAreArray(ArrayFloatNear(
-                  {0, 0, 1.66666667, 4, 3., -0.33333333, 10, 0.66666667})));
+              ElementsAreArray(ArrayFloatNear(float_ref_output_values)));
+}
+
+template <typename QuantizedType>
+void TestQuantizedHardSwish(TensorType tensor_type, int size,
+                            std::minstd_rand* random_engine) {
+  std::vector<float> float_input_values;
+  const float kMin = -10.0f;
+  const float kMax = 10.0f;
+  GenerateUniformRandomVector(size, kMin, kMax, random_engine,
+                              &float_input_values);
+  const float kOutMin = -3;
+  const float kOutMax = kMax;
+  std::vector<float> float_ref_output_values;
+  EvalTestReferenceHardSwish(size, float_input_values,
+                             &float_ref_output_values);
+  QuantizedActivationsOpModel m(
+      BuiltinOperator_HARD_SWISH,
+      /*input=*/{tensor_type, {1, 1, 1, size}, kMin, kMax},
+      /*output=*/{tensor_type, {1, 1, 1, size}, kOutMin, kOutMax});
+  m.SetInput<QuantizedType>(float_input_values);
+
+  m.Invoke();
+  // The numerical error for any 8-bit quantized function is at least one half
+  // of the quantization step: 0.5 * (kOutMax - kOutMin) / 256.
+  // To that we add one more quantization step, (kOutMax - kOutMin) / 256,
+  // to allow for an off-by-one rounding error.
+  const float kTolerance = (kOutMax - kOutMin) * (1.5f / 256.f);
+  EXPECT_THAT(
+      m.GetDequantizedOutput<QuantizedType>(),
+      ElementsAreArray(ArrayFloatNear(float_ref_output_values, kTolerance)));
+}
+
+TEST(FloatActivationsOpTest, HardSwish) {
+  std::minstd_rand random_engine;
+  for (int size : {1, 2, 3, 4, 10, 20, 30, 40, 100}) {
+    TestFloatHardSwish(size, &random_engine);
+  }
 }
 
 TEST(QuantizedActivationsOpTest, HardSwish) {
-  const float kMin = -10;
-  const float kMax = 15;
-  const float kOutMin = -3;
-  const float kOutMax = kMax;
-  QuantizedActivationsOpModel m(
-      BuiltinOperator_HARD_SWISH,
-      /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, kMin, kMax},
-      /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, kOutMin, kOutMax});
-  m.SetInput<uint8_t>({
-      0, -10, 2, 4,   //
-      3, -2, 10, 15,  //
-  });
-  auto r =
-      Dequantize(m.ExtractVector<uint8_t>(0), m.GetScale(0), m.GetZeroPoint(0));
-  for (int i = 0; i < r.size(); i++) {
-    LOG(INFO) << r[i];
+  std::minstd_rand random_engine;
+  for (int size : {1, 2, 3, 4, 10, 20, 30, 40, 100}) {
+    TestQuantizedHardSwish<uint8_t>(TensorType_UINT8, size, &random_engine);
+    TestQuantizedHardSwish<int8_t>(TensorType_INT8, size, &random_engine);
   }
-
-  m.Invoke();
-  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
-              ElementsAreArray(
-                  ArrayFloatNear({0, 0, 1.66666667, 4, 3., -0.33333333, 10, 15},
-                                 (kOutMax - kOutMin) / 1. / 256.)));
-  EXPECT_THAT(m.GetOutput<uint8_t>(),
-              ElementsAreArray({42, 42, 65, 99, 85, 37, 184, 255}));
 }
 
 TEST(FloatActivationsOpTest, Tanh) {
diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index ec1af798dfd..e2d318b1372 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -5677,6 +5677,204 @@ inline void Requantize(const uint8_t* input_data, int32_t size,
   }
 }
 
+inline void HardSwish(const RuntimeShape& input_shape, const float* input_data,
+                      const RuntimeShape& output_shape, float* output_data) {
+  gemmlowp::ScopedProfilingLabel label("HardSwish/Float");
+  auto size = MatchingFlatSize(input_shape, output_shape);
+  int i = 0;
+#ifdef USE_NEON
+  const float32x4_t zero = vdupq_n_f32(0.0f);
+  const float32x4_t three = vdupq_n_f32(3.0f);
+  const float32x4_t six = vdupq_n_f32(6.0f);
+  const float32x4_t one_sixth = vdupq_n_f32(1.0f / 6.0f);
+
+  for (; i <= size - 16; i += 16) {
+    // 4x partially unrolled version of the loop below. Refer to its comments.
+    const float32x4_t in_0 = vld1q_f32(input_data + i + 0);
+    const float32x4_t in_1 = vld1q_f32(input_data + i + 4);
+    const float32x4_t in_2 = vld1q_f32(input_data + i + 8);
+    const float32x4_t in_3 = vld1q_f32(input_data + i + 12);
+    const float32x4_t in_scaled_0 = vmulq_f32(in_0, one_sixth);
+    const float32x4_t in_scaled_1 = vmulq_f32(in_1, one_sixth);
+    const float32x4_t in_scaled_2 = vmulq_f32(in_2, one_sixth);
+    const float32x4_t in_scaled_3 = vmulq_f32(in_3, one_sixth);
+    const float32x4_t in_reluish_0 =
+        vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_0, three)));
+    const float32x4_t in_reluish_1 =
+        vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_1, three)));
+    const float32x4_t in_reluish_2 =
+        vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_2, three)));
+    const float32x4_t in_reluish_3 =
+        vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_3, three)));
+    const float32x4_t product_0 = vmulq_f32(in_scaled_0, in_reluish_0);
+    const float32x4_t product_1 = vmulq_f32(in_scaled_1, in_reluish_1);
+    const float32x4_t product_2 = vmulq_f32(in_scaled_2, in_reluish_2);
+    const float32x4_t product_3 = vmulq_f32(in_scaled_3, in_reluish_3);
+    vst1q_f32(output_data + i + 0, product_0);
+    vst1q_f32(output_data + i + 4, product_1);
+    vst1q_f32(output_data + i + 8, product_2);
+    vst1q_f32(output_data + i + 12, product_3);
+  }
+  for (; i <= size - 4; i += 4) {
+    // The expression to be computed is:
+    //   out = one_sixth * in * min(six, max(zero, (in + three)))
+    // We structure the AST to have two roughly balanced, independent branches:
+    //  - Multiplication: in_scaled = one_sixth * in.
+    //  - Addition and clamping: in_reluish = min(six, max(zero, (in + three))).
+    // The remaining multiplication is then at the root of the tree.
+    const float32x4_t in = vld1q_f32(input_data + i);
+    const float32x4_t in_scaled = vmulq_f32(in, one_sixth);
+    const float32x4_t in_reluish =
+        vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in, three)));
+    const float32x4_t product = vmulq_f32(in_scaled, in_reluish);
+    vst1q_f32(output_data + i, product);
+  }
+#endif
+  for (; i < size; i++) {
+    const float in = input_data[i];
+    output_data[i] =
+        in * std::min(6.0f, std::max(0.0f, in + 3.0f)) * (1.0f / 6.0f);
+  }
+}
+
+#ifdef USE_NEON
+inline void SaturateAndStore(int16x8_t src, std::uint8_t* dst) {
+  // Narrow values down to 8-bit unsigned, saturating.
+  uint8x8_t res8 = vqmovun_s16(src);
+  // Store results to destination.
+  vst1_u8(dst, res8);
+}
+
+inline void SaturateAndStore(int16x8_t src, std::int8_t* dst) {
+  // Narrow values down to 8-bit signed, saturating.
+  int8x8_t res8 = vqmovn_s16(src);
+  // Store results to destination.
+  vst1_s8(dst, res8);
+}
+#endif
+
+template <typename QuantizedType>
+inline void HardSwish(const HardSwishParams& params,
+                      const RuntimeShape& input_shape,
+                      const QuantizedType* input_data,
+                      const RuntimeShape& output_shape,
+                      QuantizedType* output_data) {
+  gemmlowp::ScopedProfilingLabel label("HardSwish/Quantized");
+  // Goal: (x * relu6(x+3))/6
+  const int size = MatchingFlatSize(input_shape, output_shape);
+  const int32_t extra_input_shift = params.clip_input_shift;
+  const auto in_zero_point = params.input_zero_point;
+  const auto three_in = params.three_input;
+  const auto six_in = params.six_input;
+  const auto real_shift = params.shift;
+  const auto scale = params.scale;
+  const auto offset = params.output_offset;
+  int i = 0;
+#ifdef USE_NEON
+  const int16x8_t extra_input_shift_vec = vdupq_n_s16(extra_input_shift);
+  const int16x8_t three_in_vec = vdupq_n_s16(three_in);
+  const int16x8_t six_in_vec = vdupq_n_s16(six_in);
+  // The quantization params of this op are designed around a reference
+  // implementation that performs plain integer multiplication, not
+  // fixed-point multiplication. The 16-bit fixed-point multiplications
+  // that we use here, vqrdmulhq_s16, differ from that by an (rounding)
+  // right shift by 15 bits. So in terms of scale, and leaving aside
+  // accuracy considerations, we could simply compensate for that by
+  // adding 15 to real_shift. Doing so yields approximately correct results,
+  // but with high inaccuracy in the low bits. That is because, unlike
+  // the integer multiplications done in the reference code, our fixed-point
+  // multiplications are destructive of the low bits. In order to get accurate
+  // enough results, we move some of that bit-shifting from being applied to
+  // the result to being applied to one of the operands of these fixed-point
+  // multiplications, before the information in the low bits is destroyed.
+  // Fortunately, one of the operands is by construction smaller than 2^8
+  // in absolute value, so it's safe to left-shift it by 7 bits.
+  static constexpr int left_shift_on_scaled_input = 7;
+  // We now adjust the tweak to real_shift accordingly: instead of adding 15,
+  // we only add (15 - left_shift_on_scaled_input).
+  const int16x8_t real_shift_vec =
+      vdupq_n_s16(15 - left_shift_on_scaled_input + real_shift);
+  const int16x8_t scale_vec = vdupq_n_s16((scale + (1 << 15)) >> 16);
+  const int16x8_t offset_vec = vdupq_n_s16(offset);
+  const int16x8_t zero = vdupq_n_s16(0);
+  for (; i <= size - 32; i += 32) {
+    using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
+    int16x8x2_t in_0_1 =
+        Load16AndSubtractZeroPoint(input_data + i + 0, in_zero_point);
+    int16x8x2_t in_2_3 =
+        Load16AndSubtractZeroPoint(input_data + i + 16, in_zero_point);
+    int16x8_t in_reluish_0 = vshlq_s16(in_0_1.val[0], extra_input_shift_vec);
+    int16x8_t in_reluish_1 = vshlq_s16(in_0_1.val[1], extra_input_shift_vec);
+    int16x8_t in_reluish_2 = vshlq_s16(in_2_3.val[0], extra_input_shift_vec);
+    int16x8_t in_reluish_3 = vshlq_s16(in_2_3.val[1], extra_input_shift_vec);
+    in_reluish_0 = vaddq_s16(in_reluish_0, three_in_vec);
+    in_reluish_1 = vaddq_s16(in_reluish_1, three_in_vec);
+    in_reluish_2 = vaddq_s16(in_reluish_2, three_in_vec);
+    in_reluish_3 = vaddq_s16(in_reluish_3, three_in_vec);
+    in_reluish_0 = vminq_s16(in_reluish_0, six_in_vec);
+    in_reluish_1 = vminq_s16(in_reluish_1, six_in_vec);
+    in_reluish_2 = vminq_s16(in_reluish_2, six_in_vec);
+    in_reluish_3 = vminq_s16(in_reluish_3, six_in_vec);
+    in_reluish_0 = vmaxq_s16(in_reluish_0, zero);
+    in_reluish_1 = vmaxq_s16(in_reluish_1, zero);
+    in_reluish_2 = vmaxq_s16(in_reluish_2, zero);
+    in_reluish_3 = vmaxq_s16(in_reluish_3, zero);
+    int16x8_t in_scaled_0 = vqrdmulhq_s16(
+        vshlq_n_s16(in_0_1.val[0], left_shift_on_scaled_input), scale_vec);
+    int16x8_t in_scaled_1 = vqrdmulhq_s16(
+        vshlq_n_s16(in_0_1.val[1], left_shift_on_scaled_input), scale_vec);
+    int16x8_t in_scaled_2 = vqrdmulhq_s16(
+        vshlq_n_s16(in_2_3.val[0], left_shift_on_scaled_input), scale_vec);
+    int16x8_t in_scaled_3 = vqrdmulhq_s16(
+        vshlq_n_s16(in_2_3.val[1], left_shift_on_scaled_input), scale_vec);
+    int16x8_t product_0 = vqrdmulhq_s16(in_scaled_0, in_reluish_0);
+    int16x8_t product_1 = vqrdmulhq_s16(in_scaled_1, in_reluish_1);
+    int16x8_t product_2 = vqrdmulhq_s16(in_scaled_2, in_reluish_2);
+    int16x8_t product_3 = vqrdmulhq_s16(in_scaled_3, in_reluish_3);
+    product_0 = vrshlq_s16(product_0, real_shift_vec);
+    product_1 = vrshlq_s16(product_1, real_shift_vec);
+    product_2 = vrshlq_s16(product_2, real_shift_vec);
+    product_3 = vrshlq_s16(product_3, real_shift_vec);
+    SaturateAndStore(vaddq_s16(product_0, offset_vec), output_data + i + 0);
+    SaturateAndStore(vaddq_s16(product_1, offset_vec), output_data + i + 8);
+    SaturateAndStore(vaddq_s16(product_2, offset_vec), output_data + i + 16);
+    SaturateAndStore(vaddq_s16(product_3, offset_vec), output_data + i + 24);
+  }
+  for (; i <= size - 8; i += 8) {
+    using cpu_backend_gemm::detail::Load8AndSubtractZeroPoint;
+    // See comments in the float NEON HardSwish implementation.
+    int16x8_t in = Load8AndSubtractZeroPoint(input_data + i, in_zero_point);
+    int16x8_t in_reluish = vshlq_s16(in, extra_input_shift_vec);
+    in_reluish = vaddq_s16(in_reluish, three_in_vec);
+    in_reluish = vminq_s16(in_reluish, six_in_vec);
+    in_reluish = vmaxq_s16(zero, in_reluish);
+    int16x8_t in_scaled =
+        vqrdmulhq_s16(vshlq_n_s16(in, left_shift_on_scaled_input), scale_vec);
+    int16x8_t product = vqrdmulhq_s16(in_scaled, in_reluish);
+    product = vrshlq_s16(product, real_shift_vec);
+    SaturateAndStore(vaddq_s16(product, offset_vec), output_data + i);
+  }
+#endif
+  for (; i < size; i++) {
+    int32_t v = static_cast<int32_t>(input_data[i]);
+    v -= in_zero_point;  // Make zeros - zero again!
+
+    // Computes x + 3 in the (input * 2^extra_input_shift) scale.
+    //
+    // Note: three_in is already expressed in that scale.
+    const int32_t v3 = (v << extra_input_shift) + three_in;
+
+    // Computes x * relu6(x + 3), i.e. hard-swish up to a final rescaling.
+    v *= std::min(six_in, std::max(0, v3));
+
+    // This converts from x * relu6(x+3) in the input scale to
+    // x * relu6(x+3) / 6 in the output scale.
+    v = MultiplyByQuantizedMultiplierSmallerThanOneExp(v, scale, real_shift);
+    v += offset;
+    output_data[i] = reference_ops::Saturate<QuantizedType>(v);
+  }
+}
+
 }  // namespace optimized_ops
 }  // namespace tflite
diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc
index b9467b99cc8..fc0cbb3f251 100644
--- a/tensorflow/lite/kernels/register_ref.cc
+++ b/tensorflow/lite/kernels/register_ref.cc
@@ -130,6 +130,7 @@ TfLiteRegistration* Register_SQUARED_DIFFERENCE();
 TfLiteRegistration* Register_FILL();
 TfLiteRegistration* Register_MIRROR_PAD();
 TfLiteRegistration* Register_QUANTIZE();
+TfLiteRegistration* Register_HARD_SWISH_REF();
 
 namespace {
 
@@ -280,6 +281,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
   AddBuiltin(BuiltinOperator_FILL, Register_FILL());
   AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD());
   AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE());
+  AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH_REF());
 
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
   // custom ops aren't always included by default.
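
Both the reference and optimized kernels in this patch, as well as the test helper EvalTestReferenceHardSwish, compute the same function: out = x * relu6(x + 3) / 6, where relu6(y) = min(6, max(0, y)). The snippet below is a minimal standalone float sketch of that formula, separate from the patch and from the TFLite test harness (the HardSwishReference helper and the printed output are illustrative only), evaluated on the sample inputs of the old fixed-vector test.

// Standalone sketch (not part of the patch above): the float hard-swish
// formula, out = x * relu6(x + 3) / 6, on the old test's sample inputs.
// Expected outputs: {0, 0, 1.66666667, 4, 3, -0.33333333, 10, 0.66666667}.
#include <algorithm>
#include <cstdio>
#include <vector>

float HardSwishReference(float x) {
  // relu6(y) == min(6, max(0, y)); the division by 6 is a multiplication.
  return x * std::min(6.0f, std::max(0.0f, x + 3.0f)) * (1.0f / 6.0f);
}

int main() {
  const std::vector<float> inputs = {0, -6, 2, 4, 3, -2, 10, 1};
  for (float x : inputs) {
    std::printf("hardswish(%g) = %g\n", x, HardSwishReference(x));
  }
  return 0;
}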