diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc index c95fd0e40a4..a7c5604ef64 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc @@ -29,16 +29,88 @@ namespace micro { namespace activations { namespace { -// TODO(b/141176180): This code is currently a strict subset of the portable -// implementation (softmax.cc one directory up). When TFLM implements -// registrations for selective types (e.g. compile without float support), this -// can be removed. Otherwise, any HiFi specific optimizations should land here. +struct OpData { + uint16_t* exp_lut; +}; + +// Number of unique int8 and int16 values. Used in exponent lookup table +// conputation. +constexpr int kInt8Range = + std::numeric_limits::max() - std::numeric_limits::min() + 1; +constexpr int kInt16Range = + std::numeric_limits::max() - std::numeric_limits::min() + 1; +// Each 16-bit precalculated exponent is expressed as a Q0.16 fixedpoint +// value. We special-case e^0 since 1.0 requires 1 integer bit to +// express. +constexpr int kExpFractionalBits = 16; +// e^0 expressed as Q1.15 exceeds the int16_t range, so it must be handled +// specially. +constexpr int kMaxExponentValue = (1 << kExpFractionalBits); + +// Quantized softmax with int8 input and int16 output. +// TODO(b/155656675): Investigate removing const ref params. +inline TfLiteStatus Softmax(const OpData& op_data, + const RuntimeShape& input_shape, + const int8_t* input_data, + const RuntimeShape& output_shape, + int16_t* output_data) { + // The last dimension is depth. Outer size is the the total input size + // divided by depth. + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); + + for (int i = 0; i < outer_size; ++i) { + int8_t max_in_row = std::numeric_limits::min(); + for (int c = 0; c < depth; ++c) { + max_in_row = std::max(max_in_row, input_data[i * depth + c]); + } + + uint32_t sum_of_exps = 0; + for (int c = 0; c < depth; ++c) { + TFLITE_DCHECK(max_in_row >= input_data[i * depth + c]); + uint8_t input_diff = max_in_row - input_data[i * depth + c]; + + sum_of_exps += + input_diff == 0 ? kMaxExponentValue : op_data.exp_lut[input_diff]; + } + + // Ensure we cannnot overflow the full_range_output value. We need to + // guarantee that kInt16Range * max(input_data) / sum_of_exps < kInt16Range. + TFLITE_DCHECK(sum_of_exps >= kMaxExponentValue); + + for (int c = 0; c < depth; ++c) { + uint8_t input_diff = max_in_row - input_data[i * depth + c]; + // Special case for diff == 0 + uint32_t unscaled_output = + input_diff == 0 ? kMaxExponentValue : op_data.exp_lut[input_diff]; + int64_t scaled_output = static_cast(unscaled_output) * + static_cast(kInt16Range); + int32_t full_range_output = + scaled_output / sum_of_exps + std::numeric_limits::min(); + // Round up if remainder exceeds half of the divider value. + uint32_t remainder = scaled_output % sum_of_exps; + if (remainder * 2 >= sum_of_exps) { + full_range_output++; + } + output_data[i * depth + c] = static_cast(std::max( + std::min(full_range_output, + static_cast(std::numeric_limits::max())), + static_cast(std::numeric_limits::min()))); + } + } + return kTfLiteOk; +} + +} // namespace TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context, const TfLiteTensor* input, TfLiteTensor* output, const TfLiteSoftmaxParams* params, - SoftmaxParams* op_data) { + OpData* op_data) { if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { if (input->type == kTfLiteUInt8) { TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); @@ -55,28 +127,30 @@ TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context, } } - static const int kScaledDiffIntegerBits = 5; + // Precompute e^(-x * input_scale * beta) for every possible int8 input. + // This computation is used for every iteration of Softmax. We must compute + // using pre-scaled inputs to avoid introducing additional error, while + // restricting our input range to the int8 range. This is valid since beta + // and input scale are constant for a given op in the graph. Skip index 0 + // since that is a special case which requires 1 integer bit instead of 0. + for (int i = 1; i <= kInt8Range; i++) { + float scaled_input = i * input->params.scale; + float exp_value = + std::exp((-scaled_input) * static_cast(params->beta)); - int input_left_shift; - tflite::PreprocessSoftmaxScaling( - static_cast(params->beta), - static_cast(input->params.scale), kScaledDiffIntegerBits, - &op_data->input_multiplier, &input_left_shift); - op_data->input_left_shift = input_left_shift; - op_data->diff_min = - -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits, - op_data->input_left_shift); + float exponent_scaled = + std::round(exp_value * static_cast(1 << kExpFractionalBits)); + op_data->exp_lut[i] = static_cast(exponent_scaled); + } } return kTfLiteOk; } -} // namespace - void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) { TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); void* data = nullptr; - if (context->AllocatePersistentBuffer(context, sizeof(SoftmaxParams), - &data) == kTfLiteError) { + if (context->AllocatePersistentBuffer(context, sizeof(OpData), &data) == + kTfLiteError) { return nullptr; } return data; @@ -92,26 +166,34 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, NumDimensions(input) >= 1); TFLITE_DCHECK(node->user_data != nullptr); - SoftmaxParams* op_params = static_cast(node->user_data); + OpData* op_data = static_cast(node->user_data); + + // Allocate an array to precompute exponents over all int8 inputs, applying + // the scale and beta before calculating exp. It is mandatory to apply beta + // and scale here, since each softmax op may have different beta and scale + // values. Beta and scale will remain constant for a given softmax op. + void* allocated_ptr; + TF_LITE_ENSURE_STATUS(context->AllocatePersistentBuffer( + context, kInt8Range * sizeof(int16_t), &allocated_ptr)); + op_data->exp_lut = static_cast(allocated_ptr); TF_LITE_ENSURE_STATUS( - CalculateSoftmaxOpData(context, input, output, params, op_params)); + CalculateSoftmaxOpData(context, input, output, params, op_data)); return kTfLiteOk; } TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { - auto* op_params = static_cast(node->user_data); + auto* op_data = static_cast(node->user_data); const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); if (input->type == kTfLiteInt8 && output->type == kTfLiteInt16) { // TODO(b/155656675): Const ref params can be slow on xtensa. - tflite::reference_ops::Softmax( - *op_params, GetTensorShape(input), GetTensorData(input), - GetTensorShape(output), GetTensorData(output)); - return kTfLiteOk; + return Softmax(*op_data, GetTensorShape(input), + GetTensorData(input), GetTensorShape(output), + GetTensorData(output)); } else { TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", TfLiteTypeGetName(input->type), input->type);