From 96b38db29c3ad3b1c1397f57bca85e3df3a1ac5f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 2 Aug 2019 16:06:17 -0700 Subject: [PATCH] Optimize 8bit Softmax op handler PiperOrigin-RevId: 261410252 --- tensorflow/lite/kernels/activations.cc | 104 +++----- .../internal/optimized/legacy_optimized_ops.h | 202 +++++++++++++++ .../internal/optimized/optimized_ops.h | 237 ++++-------------- .../internal/softmax_quantized_test.cc | 7 +- tensorflow/lite/kernels/internal/types.h | 4 + 5 files changed, 298 insertions(+), 256 deletions(-) diff --git a/tensorflow/lite/kernels/activations.cc b/tensorflow/lite/kernels/activations.cc index 7efd5df07fb..793f90b21d5 100644 --- a/tensorflow/lite/kernels/activations.cc +++ b/tensorflow/lite/kernels/activations.cc @@ -55,6 +55,11 @@ struct OpData { uint8_t* table_zero = nullptr; }; +struct SoftmaxOpData { + struct SoftmaxParams params = {}; + float table[256]; +}; + struct LogSoftmaxOpData : public OpData { int32_t reverse_scaling_divisor = 0; int32_t reverse_scaling_right_shift = 0; @@ -131,6 +136,14 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { return new OpData; } +void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) { + return new SoftmaxOpData; +} + +void SoftmaxFree(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + void* LogSoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) { return new LogSoftmaxOpData; @@ -363,7 +376,7 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); - OpData* data = reinterpret_cast(node->user_data); + SoftmaxOpData* data = reinterpret_cast(node->user_data); TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -375,16 +388,11 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, num_dims >= 1 && num_dims <= 4); if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { - if (CheckOutputQuantParams(context, input, output) == kTfLiteError) { - return kTfLiteError; - } - - static const int kScaledDiffIntegerBits = 5; - tflite::PreprocessSoftmaxScaling( - params->beta, input->params.scale, kScaledDiffIntegerBits, - &data->input_multiplier, &data->input_left_shift); - data->diff_min = -1.0 * tflite::CalculateInputRadius( - kScaledDiffIntegerBits, data->input_left_shift); + data->params.table = data->table; + optimized_ops::PopulateSoftmaxLookupTable( + &data->params, input->params.scale, params->beta); + data->params.zero_point = output->params.zero_point; + data->params.scale = output->params.scale; } return context->ResizeTensor(context, output, @@ -749,61 +757,25 @@ TfLiteStatus SoftmaxFloat(TfLiteContext* context, const TfLiteTensor* input, } } -TfLiteStatus SoftmaxQuantizedUint8(TfLiteContext* context, - const TfLiteTensor* input, - TfLiteTensor* output, - TfLiteSoftmaxParams* params, OpData* data) { - switch (NumDimensions(input)) { - case 1: - case 2: - case 3: - case 4: - SoftmaxParams op_params; - op_params.input_multiplier = data->input_multiplier; - op_params.input_left_shift = data->input_left_shift; - op_params.diff_min = data->diff_min; - optimized_ops::Softmax( - op_params, GetTensorShape(input), GetTensorData(input), - GetTensorShape(output), GetTensorData(output)); - return kTfLiteOk; - default: - context->ReportError( - context, - "Only 1D, 2D, 3D and 4D 
tensors supported currently, got %dD.", - NumDimensions(input)); - return kTfLiteError; - } -} - -TfLiteStatus SoftmaxQuantizedInt8(TfLiteContext* context, - const TfLiteTensor* input, - TfLiteTensor* output, - TfLiteSoftmaxParams* params, OpData* data) { - switch (NumDimensions(input)) { - case 1: - case 2: - case 3: - case 4: - SoftmaxParams op_params; - op_params.input_multiplier = data->input_multiplier; - op_params.input_left_shift = data->input_left_shift; - op_params.diff_min = data->diff_min; - optimized_integer_ops::Softmax( - op_params, GetTensorShape(input), GetTensorData(input), - GetTensorShape(output), GetTensorData(output)); - return kTfLiteOk; - default: - context->ReportError( - context, - "Only 1D, 2D, 3D and 4D tensors supported currently, got %dD.", - NumDimensions(input)); - return kTfLiteError; +template +TfLiteStatus SoftmaxQuantized(TfLiteContext* context, const TfLiteTensor* input, + TfLiteTensor* output, SoftmaxOpData* data) { + if (NumDimensions(input) >= 1 && NumDimensions(input) <= 4) { + optimized_ops::Softmax(data->params, GetTensorShape(input), + GetTensorData(input), GetTensorShape(output), + GetTensorData(output)); + return kTfLiteOk; + } else { + context->ReportError( + context, "Only 1D, 2D, 3D and 4D tensors supported currently, got %dD.", + NumDimensions(input)); + return kTfLiteError; } } TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); - OpData* data = reinterpret_cast(node->user_data); + SoftmaxOpData* data = reinterpret_cast(node->user_data); const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); @@ -815,10 +787,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { return SoftmaxFloat(context, input, output, params); } case kTfLiteUInt8: { - return SoftmaxQuantizedUint8(context, input, output, params, data); + return SoftmaxQuantized(context, input, output, data); } case kTfLiteInt8: { - return SoftmaxQuantizedInt8(context, input, output, params, data); + return SoftmaxQuantized(context, input, output, data); } default: @@ -1055,9 +1027,9 @@ TfLiteRegistration* Register_LOGISTIC() { } TfLiteRegistration* Register_SOFTMAX() { - static TfLiteRegistration r = {activations::Init, activations::Free, - activations::SoftmaxPrepare, - activations::SoftmaxEval}; + static TfLiteRegistration r = { + activations::SoftmaxInit, activations::SoftmaxFree, + activations::SoftmaxPrepare, activations::SoftmaxEval}; return &r; } diff --git a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h index fbbaad8bf29..b9305169065 100644 --- a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h @@ -3972,6 +3972,208 @@ void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, filter_width, filter_height, output_data, output_dims); } +inline void Softmax(const SoftmaxParams& params, + const RuntimeShape& input_shape, const uint8* input_data, + const RuntimeShape& output_shape, uint8* output_data) { + const int32 input_beta_multiplier = params.input_multiplier; + const int32 input_beta_left_shift = params.input_left_shift; + const int diff_min = params.diff_min; + // The representation chosen for the input to the exp() function is Q5.26. 
+ // We need to leave extra space since values that we skip might be as large as + // -32 before multiplying by input_beta_multiplier, and therefore as large as + // -16 afterwards. Note that exp(-8) is definitely not insignificant to + // accumulation, but exp(-16) definitely is. + static const int kScaledDiffIntegerBits = 5; + static const int kAccumulationIntegerBits = 12; + using FixedPointScaledDiff = + gemmlowp::FixedPoint; + using FixedPointAccum = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + + gemmlowp::ScopedProfilingLabel label("Softmax/8bit"); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); + + for (int b = 0; b < outer_size; ++b) { + const uint8* input_data_ptr = input_data + b * depth; + uint8* output_data_ptr = output_data + b * depth; + + // Determine the largest entry in the current row + uint8 max_in_row = 0; + { + int c = 0; +#ifdef USE_NEON + uint8x16_t max16_0 = vdupq_n_u8(0); + uint8x16_t max16_1 = vdupq_n_u8(0); + for (; c <= depth - 32; c += 32) { + max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0)); + max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16)); + } + uint8x16_t max16 = vmaxq_u8(max16_0, max16_1); + if (c <= depth - 16) { + max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c)); + c += 16; + } + uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16)); + if (c <= depth - 8) { + max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c)); + c += 8; + } + uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4)); + uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2)); + uint8x8_t max1 = vpmax_u8(max2, max2); + max_in_row = vget_lane_u8(max1, 0); +#endif + for (; c < depth; ++c) { + max_in_row = std::max(max_in_row, input_data_ptr[c]); + } + } + +#ifdef USE_NEON + using FixedPointAccumInt32x4 = + gemmlowp::FixedPoint; + using FixedPointScaledDiffInt32x4 = + gemmlowp::FixedPoint; + using FixedPoint0Int32x4 = gemmlowp::FixedPoint; + FixedPoint0Int32x4 input_beta_multiplier_f0 = + FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier); + int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row); +#endif + + // Compute the sum of exponentials of the differences of entries in the + // current row from the largest entry in the current row. 
+ FixedPointAccum sum_of_exps = FixedPointAccum::Zero(); + { + int c = 0; +#ifdef USE_NEON + int32x4_t diff_min_s32 = vdupq_n_s32(diff_min); + FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero(); + FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero(); + FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero(); + for (; c <= depth - 8; c += 8) { + uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c)); + int16x8_t input_diff_s16 = + vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16); + int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16)); + int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16)); + int32x4_t mask_0 = + gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32); + int32x4_t mask_1 = + gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32); + FixedPointScaledDiffInt32x4 scaled_diff_0 = + input_beta_multiplier_f0 * + FixedPointScaledDiffInt32x4::FromRaw( + gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift)); + FixedPointScaledDiffInt32x4 scaled_diff_1 = + input_beta_multiplier_f0 * + FixedPointScaledDiffInt32x4::FromRaw( + gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift)); + FixedPointAccumInt32x4 exps_0 = + gemmlowp::Rescale( + exp_on_negative_values(scaled_diff_0)); + FixedPointAccumInt32x4 exps_1 = + gemmlowp::Rescale( + exp_on_negative_values(scaled_diff_1)); + FixedPointAccumInt32x4 masked_exps_0 = + SelectUsingMask(mask_0, exps_0, zeros); + FixedPointAccumInt32x4 masked_exps_1 = + SelectUsingMask(mask_1, exps_1, zeros); + sum_of_exps_0 = sum_of_exps_0 + masked_exps_0; + sum_of_exps_1 = sum_of_exps_1 + masked_exps_1; + } + int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw(); + int32x2_t sum_of_exps_reduced_2 = + vadd_s32(vget_low_s32(sum_of_exps_reduced_4), + vget_high_s32(sum_of_exps_reduced_4)); + int32x2_t sum_of_exps_reduced_1 = + vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2); + sum_of_exps = + FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0)); +#endif + for (; c < depth; ++c) { + int32 input_diff = static_cast(input_data_ptr[c]) - max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + sum_of_exps = + sum_of_exps + gemmlowp::Rescale( + exp_on_negative_values(scaled_diff_f8)); + } + } + } + + // Compute the fixed-point multiplier and shift that we need to apply to + // perform a division by the above-computed sum-of-exponentials. + int num_bits_over_unit = 0; + FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal( + sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit)); + + // Compute the quotients of exponentials of differences of entries in the + // current row from the largest entry, over the previously-computed sum of + // exponentials. 
+ { + int c = 0; +#ifdef USE_NEON + int16x8_t diff_min_s16 = vdupq_n_s16(diff_min); + for (; c <= depth - 8; c += 8) { + uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c)); + int16x8_t input_diff_s16 = + vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16); + int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16)); + int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16)); + uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16)); + FixedPointScaledDiffInt32x4 scaled_diff_0 = + input_beta_multiplier_f0 * + FixedPointScaledDiffInt32x4::FromRaw( + gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift)); + FixedPointScaledDiffInt32x4 scaled_diff_1 = + input_beta_multiplier_f0 * + FixedPointScaledDiffInt32x4::FromRaw( + gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift)); + FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0); + FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1); + int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT( + vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()), + num_bits_over_unit + 31 - 8); + int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT( + vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()), + num_bits_over_unit + 31 - 8); + int16x8_t output_s16 = + vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1)); + uint8x8_t output_u8 = vqmovun_s16(output_s16); + uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0)); + vst1_u8(output_data_ptr + c, masked_output); + } +#endif + for (; c < depth; ++c) { + int32 input_diff = static_cast(input_data_ptr[c]) - max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + + FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); + int32 unsat_output = gemmlowp::RoundingDivideByPOT( + (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8); + + output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0); + + } else { + output_data_ptr[c] = 0; + } + } + } + } +} + inline void Softmax(const float* input_data, const RuntimeShape& input_shape, float beta, float* output_data, const RuntimeShape& output_shape) { diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 1b9de834e60..6f246e7a169 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -3493,205 +3493,64 @@ inline void Softmax(const SoftmaxParams& params, out_mat.array().rowwise() *= scale; } -inline void Softmax(const SoftmaxParams& params, - const RuntimeShape& input_shape, const uint8* input_data, - const RuntimeShape& output_shape, uint8* output_data) { - const int32 input_beta_multiplier = params.input_multiplier; - const int32 input_beta_left_shift = params.input_left_shift; - const int diff_min = params.diff_min; - // The representation chosen for the input to the exp() function is Q5.26. - // We need to leave extra space since values that we skip might be as large as - // -32 before multiplying by input_beta_multiplier, and therefore as large as - // -16 afterwards. Note that exp(-8) is definitely not insignificant to - // accumulation, but exp(-16) definitely is. 
- static const int kScaledDiffIntegerBits = 5; - static const int kAccumulationIntegerBits = 12; - using FixedPointScaledDiff = - gemmlowp::FixedPoint; - using FixedPointAccum = gemmlowp::FixedPoint; - using FixedPoint0 = gemmlowp::FixedPoint; +inline int32_t QuantizeSoftmaxOutput(int8_t* output_data, float prob_rescaled, + int32_t zero_point) { + const int32_t prob_rnd = static_cast(std::round(prob_rescaled)); + return prob_rnd + zero_point; +} - gemmlowp::ScopedProfilingLabel label("Softmax/8bit"); +inline int32_t QuantizeSoftmaxOutput(uint8_t* output_data, float prob_rescaled, + int32_t zero_point) { + return static_cast(prob_rescaled + 0.5); +} + +inline void PopulateSoftmaxLookupTable(SoftmaxParams* data, float input_scale, + float beta) { + const float scale = -input_scale * beta; + const int32_t max_uint8 = std::numeric_limits::max(); + for (int32_t val = 0; val <= max_uint8; ++val) { + data->table[max_uint8 - val] = expf(scale * val); + } +} + +template +inline void Softmax(const SoftmaxParams& params, + const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& output_shape, T* output_data) { const int trailing_dim = input_shape.DimensionsCount() - 1; - const int outer_size = + const int excluding_last_dim = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); - const int depth = + const int last_dim = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); - for (int b = 0; b < outer_size; ++b) { - const uint8* input_data_ptr = input_data + b * depth; - uint8* output_data_ptr = output_data + b * depth; - - // Determine the largest entry in the current row - uint8 max_in_row = 0; - { - int c = 0; -#ifdef USE_NEON - uint8x16_t max16_0 = vdupq_n_u8(0); - uint8x16_t max16_1 = vdupq_n_u8(0); - for (; c <= depth - 32; c += 32) { - max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0)); - max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16)); - } - uint8x16_t max16 = vmaxq_u8(max16_0, max16_1); - if (c <= depth - 16) { - max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c)); - c += 16; - } - uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16)); - if (c <= depth - 8) { - max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c)); - c += 8; - } - uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4)); - uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2)); - uint8x8_t max1 = vpmax_u8(max2, max2); - max_in_row = vget_lane_u8(max1, 0); -#endif - for (; c < depth; ++c) { - max_in_row = std::max(max_in_row, input_data_ptr[c]); - } + const int32_t clamp_max = std::numeric_limits::max(); + const int32_t clamp_min = std::numeric_limits::min(); + for (int i = 0; i < excluding_last_dim; ++i) { + int32_t max_val = std::numeric_limits::min(); + // Find max quantized value. + for (int j = 0; j < last_dim; ++j) { + max_val = std::max(max_val, static_cast(input_data[j])); } -#ifdef USE_NEON - using FixedPointAccumInt32x4 = - gemmlowp::FixedPoint; - using FixedPointScaledDiffInt32x4 = - gemmlowp::FixedPoint; - using FixedPoint0Int32x4 = gemmlowp::FixedPoint; - FixedPoint0Int32x4 input_beta_multiplier_f0 = - FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier); - int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row); -#endif - - // Compute the sum of exponentials of the differences of entries in the - // current row from the largest entry in the current row. 
- FixedPointAccum sum_of_exps = FixedPointAccum::Zero(); - { - int c = 0; -#ifdef USE_NEON - int32x4_t diff_min_s32 = vdupq_n_s32(diff_min); - FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero(); - FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero(); - FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero(); - for (; c <= depth - 8; c += 8) { - uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c)); - int16x8_t input_diff_s16 = - vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16); - int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16)); - int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16)); - int32x4_t mask_0 = - gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32); - int32x4_t mask_1 = - gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32); - FixedPointScaledDiffInt32x4 scaled_diff_0 = - input_beta_multiplier_f0 * - FixedPointScaledDiffInt32x4::FromRaw( - gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift)); - FixedPointScaledDiffInt32x4 scaled_diff_1 = - input_beta_multiplier_f0 * - FixedPointScaledDiffInt32x4::FromRaw( - gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift)); - FixedPointAccumInt32x4 exps_0 = - gemmlowp::Rescale( - exp_on_negative_values(scaled_diff_0)); - FixedPointAccumInt32x4 exps_1 = - gemmlowp::Rescale( - exp_on_negative_values(scaled_diff_1)); - FixedPointAccumInt32x4 masked_exps_0 = - SelectUsingMask(mask_0, exps_0, zeros); - FixedPointAccumInt32x4 masked_exps_1 = - SelectUsingMask(mask_1, exps_1, zeros); - sum_of_exps_0 = sum_of_exps_0 + masked_exps_0; - sum_of_exps_1 = sum_of_exps_1 + masked_exps_1; - } - int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw(); - int32x2_t sum_of_exps_reduced_2 = - vadd_s32(vget_low_s32(sum_of_exps_reduced_4), - vget_high_s32(sum_of_exps_reduced_4)); - int32x2_t sum_of_exps_reduced_1 = - vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2); - sum_of_exps = - FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0)); -#endif - for (; c < depth; ++c) { - int32 input_diff = static_cast(input_data_ptr[c]) - max_in_row; - if (input_diff >= diff_min) { - const int32 input_diff_rescaled = - MultiplyByQuantizedMultiplierGreaterThanOne( - input_diff, input_beta_multiplier, input_beta_left_shift); - const FixedPointScaledDiff scaled_diff_f8 = - FixedPointScaledDiff::FromRaw(input_diff_rescaled); - sum_of_exps = - sum_of_exps + gemmlowp::Rescale( - exp_on_negative_values(scaled_diff_f8)); - } - } + float sum_exp = 0.0f; + const int32_t max_uint8 = std::numeric_limits::max(); + const float* table_offset = ¶ms.table[max_uint8 - max_val]; + // Calculate normalizer sum(exp(x)). + for (int j = 0; j < last_dim; ++j) { + sum_exp += table_offset[input_data[j]]; } - // Compute the fixed-point multiplier and shift that we need to apply to - // perform a division by the above-computed sum-of-exponentials. - int num_bits_over_unit = 0; - FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal( - sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit)); - - // Compute the quotients of exponentials of differences of entries in the - // current row from the largest entry, over the previously-computed sum of - // exponentials. 
- { - int c = 0; -#ifdef USE_NEON - int16x8_t diff_min_s16 = vdupq_n_s16(diff_min); - for (; c <= depth - 8; c += 8) { - uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c)); - int16x8_t input_diff_s16 = - vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16); - int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16)); - int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16)); - uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16)); - FixedPointScaledDiffInt32x4 scaled_diff_0 = - input_beta_multiplier_f0 * - FixedPointScaledDiffInt32x4::FromRaw( - gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift)); - FixedPointScaledDiffInt32x4 scaled_diff_1 = - input_beta_multiplier_f0 * - FixedPointScaledDiffInt32x4::FromRaw( - gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift)); - FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0); - FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1); - int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT( - vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()), - num_bits_over_unit + 31 - 8); - int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT( - vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()), - num_bits_over_unit + 31 - 8); - int16x8_t output_s16 = - vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1)); - uint8x8_t output_u8 = vqmovun_s16(output_s16); - uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0)); - vst1_u8(output_data_ptr + c, masked_output); - } -#endif - for (; c < depth; ++c) { - int32 input_diff = static_cast(input_data_ptr[c]) - max_in_row; - if (input_diff >= diff_min) { - const int32 input_diff_rescaled = - MultiplyByQuantizedMultiplierGreaterThanOne( - input_diff, input_beta_multiplier, input_beta_left_shift); - const FixedPointScaledDiff scaled_diff_f8 = - FixedPointScaledDiff::FromRaw(input_diff_rescaled); - - FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); - int32 unsat_output = gemmlowp::RoundingDivideByPOT( - (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8); - - output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0); - - } else { - output_data_ptr[c] = 0; - } - } + const float inv_sum_exp = 1.0f / (sum_exp * params.scale); + // Normalize and quantize probabilities. 
+    for (int j = 0; j < last_dim; ++j) {
+      const float prob_rescaled = table_offset[input_data[j]] * inv_sum_exp;
+      const int32_t prob_quantized =
+          QuantizeSoftmaxOutput(output_data, prob_rescaled, params.zero_point);
+      output_data[j] = static_cast<T>(
+          std::max(std::min(clamp_max, prob_quantized), clamp_min));
     }
+    input_data += last_dim;
+    output_data += last_dim;
   }
 }
diff --git a/tensorflow/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/lite/kernels/internal/softmax_quantized_test.cc
index ea69f493db1..269dc98e129 100644
--- a/tensorflow/lite/kernels/internal/softmax_quantized_test.cc
+++ b/tensorflow/lite/kernels/internal/softmax_quantized_test.cc
@@ -124,9 +124,14 @@ void RunOneSoftmaxTest(const uint8* input_data,
                                                      input_beta_left_shift);
 
   SoftmaxParams params;
+  float table[256];
   params.input_multiplier = input_beta_multiplier;
   params.input_left_shift = input_beta_left_shift;
   params.diff_min = diff_min;
+  params.scale = 1.0f / 256;
+  params.zero_point = 0;
+  params.table = table;
+  optimized_ops::PopulateSoftmaxLookupTable(&params, input_scale, beta);
   optimized_ops::Softmax(params, shape_common, input_data, shape_common,
                          optimized_softmax_output.data());
   reference_ops::Softmax(params, shape_common, input_data, shape_common,
@@ -137,7 +142,7 @@ void RunOneSoftmaxTest(const uint8* input_data,
                   "Optimized vs float reference", false);
   CheckOutputData(optimized_softmax_output.data(),
                   reference_quant_softmax_output.data(), shape_common,
-                  "Optimized vs quant reference", true);
+                  "Optimized vs quant reference", false);
   CheckOutputData(reference_quant_softmax_output.data(),
                   reference_float_softmax_output.data(), shape_common,
                   "Quant reference vs float reference", false);
diff --git a/tensorflow/lite/kernels/internal/types.h b/tensorflow/lite/kernels/internal/types.h
index b786bdeefc2..eb7b630c574 100644
--- a/tensorflow/lite/kernels/internal/types.h
+++ b/tensorflow/lite/kernels/internal/types.h
@@ -16,6 +16,7 @@ limitations under the License.
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_
 
 #include <algorithm>
+#include <cstdint>
 #include <cstring>
 #include <initializer_list>
 
@@ -985,6 +986,9 @@ struct SoftmaxParams {
   int32 reverse_scaling_divisor;
   int32 reverse_scaling_right_shift;
   int diff_min;
+  int32_t zero_point;
+  float scale;
+  float* table;
 };
 
 struct SpaceToBatchParams {
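
Note on the approach (illustrative, condensed from the patch): the optimization replaces the fixed-point exp/reciprocal pipeline with a 256-entry float lookup table filled once per node, so evaluation reduces to table lookups, one division per row, and a rounding multiply per element. The standalone uint8-only sketch below condenses PopulateSoftmaxLookupTable and the templated Softmax from optimized_ops.h; names such as LutSoftmaxParams, PopulateTable, and LutSoftmax are illustrative stand-ins, not TFLite symbols.

// Sketch of the lookup-table softmax (uint8 only, no NEON).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

struct LutSoftmaxParams {   // Illustrative stand-in for SoftmaxParams.
  float table[256];
  float scale;              // Output scale, 1/256 for uint8 softmax.
  int32_t zero_point;       // Output zero point, 0 for uint8 softmax.
};

// Precompute exp(-input_scale * beta * d) for every possible quantized
// difference d = max - x in [0, 255]. The indexing puts exp(0) == 1 at
// table[255], so a row's max value maps to the last entry.
void PopulateTable(LutSoftmaxParams* p, float input_scale, float beta) {
  const float scale = -input_scale * beta;
  for (int32_t val = 0; val <= 255; ++val) {
    p->table[255 - val] = std::exp(scale * val);
  }
}

void LutSoftmax(const LutSoftmaxParams& p, const uint8_t* input,
                uint8_t* output, int outer_size, int depth) {
  for (int i = 0; i < outer_size; ++i) {
    // Row max; offsetting the table by (255 - max) makes table_offset[x]
    // equal to exp(input_scale * beta * (x - max)).
    int32_t max_val = 0;
    for (int j = 0; j < depth; ++j) {
      max_val = std::max(max_val, static_cast<int32_t>(input[j]));
    }
    const float* table_offset = &p.table[255 - max_val];

    // Normalizer: sum of exponentials via table lookups only.
    float sum_exp = 0.0f;
    for (int j = 0; j < depth; ++j) sum_exp += table_offset[input[j]];

    // Normalize and quantize. Dividing by (sum_exp * output_scale) folds the
    // requantization to the output scale into a single multiply per element.
    const float inv_sum = 1.0f / (sum_exp * p.scale);
    for (int j = 0; j < depth; ++j) {
      const float prob = table_offset[input[j]] * inv_sum;
      const int32_t q = static_cast<int32_t>(prob + 0.5f) + p.zero_point;
      output[j] = static_cast<uint8_t>(std::min(std::max(q, 0), 255));
    }
    input += depth;
    output += depth;
  }
}

int main() {
  LutSoftmaxParams p{};
  p.scale = 1.0f / 256;
  p.zero_point = 0;
  PopulateTable(&p, /*input_scale=*/0.1f, /*beta=*/1.0f);

  const uint8_t in[4] = {10, 50, 200, 255};
  uint8_t out[4];
  LutSoftmax(p, in, out, /*outer_size=*/1, /*depth=*/4);
  for (int j = 0; j < 4; ++j) std::printf("%d ", out[j]);
  std::printf("\n");
  return 0;
}

The key design point mirrored here is the Prepare/Eval split in activations.cc: the expensive expf() calls happen once in SoftmaxPrepare when the table is populated from the input scale and beta, so SoftmaxEval performs no transcendental math at all.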
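
A second detail worth calling out: the uint8 and int8 paths share the templated Softmax but quantize differently. The uint8 output uses zero_point 0 and scale 1/256, so a +0.5 truncation suffices; the int8 output is shifted by its zero point (typically -128), so that overload rounds and then adds the zero point. A standalone restatement of the two QuantizeSoftmaxOutput overloads added above, with made-up demo values:

#include <cmath>
#include <cstdint>
#include <cstdio>

// int8 path: round, then shift by the output zero point.
int32_t QuantizeSoftmaxOutput(int8_t* /*output*/, float prob_rescaled,
                              int32_t zero_point) {
  return static_cast<int32_t>(std::round(prob_rescaled)) + zero_point;
}

// uint8 path: zero point is always 0, so +0.5 truncation is enough.
int32_t QuantizeSoftmaxOutput(uint8_t* /*output*/, float prob_rescaled,
                              int32_t zero_point) {
  (void)zero_point;
  return static_cast<int32_t>(prob_rescaled + 0.5f);
}

int main() {
  // A probability of ~0.7 rescaled by 1/output_scale (256) is ~179.2.
  uint8_t u8;
  int8_t i8;
  std::printf("uint8: %d\n", QuantizeSoftmaxOutput(&u8, 179.2f, 0));     // 179
  std::printf("int8:  %d\n", QuantizeSoftmaxOutput(&i8, 179.2f, -128));  // 51
  return 0;
}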