Change Softmax to use tflite::reference_ops::Softmax instead of xtensa::hifimini::Softmax.
The two functions are the same when instantiated with the same input/output types.

PiperOrigin-RevId: 304428220
Change-Id: Id6d27f7a0a459c164944cc2793b12059b4bcbbb4
This commit is contained in:
parent
549434018c
commit
1e9b684f25
@@ -26,97 +26,14 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace micro {
|
||||
|
||||
namespace xtensa {
|
||||
namespace hifimini {
|
||||
|
||||
// Quantized softmax with int8 input and int8/int16 output.
|
||||
template <typename OutputT = int8_t>
|
||||
inline void Softmax(const SoftmaxParams& params,
|
||||
const RuntimeShape& input_shape, const int8* input_data,
|
||||
const RuntimeShape& output_shape, OutputT* output_data) {
|
||||
const int32_t input_beta_multiplier = params.input_multiplier;
|
||||
const int32_t input_beta_left_shift = params.input_left_shift;
|
||||
const int diff_min = params.diff_min;
|
||||
// The representation chosen for the input to the exp() function is Q5.26.
|
||||
// We need to leave extra space since values that we skip might be as large as
|
||||
// -32 before multiplying by input_beta_multiplier, and therefore as large as
|
||||
// -16 afterwards. Note that exp(-8) is definitely not insignificant to
|
||||
// accumulation, but exp(-16) definitely is.
|
||||
static const int kScaledDiffIntegerBits = 5;
|
||||
static const int kAccumulationIntegerBits = 12;
|
||||
using FixedPointScaledDiff =
|
||||
gemmlowp::FixedPoint<int32_t, kScaledDiffIntegerBits>;
|
||||
using FixedPointAccum =
|
||||
gemmlowp::FixedPoint<int32_t, kAccumulationIntegerBits>;
|
||||
using FixedPoint0 = gemmlowp::FixedPoint<int32_t, 0>;
|
||||
|
||||
const int trailing_dim = input_shape.DimensionsCount() - 1;
|
||||
const int outer_size =
|
||||
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
|
||||
const int depth =
|
||||
MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
|
||||
|
||||
for (int i = 0; i < outer_size; ++i) {
|
||||
int8 max_in_row = -128;
|
||||
for (int c = 0; c < depth; ++c) {
|
||||
max_in_row = std::max(max_in_row, input_data[i * depth + c]);
|
||||
}
|
||||
|
||||
FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
|
||||
for (int c = 0; c < depth; ++c) {
|
||||
int32_t input_diff =
|
||||
static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
|
||||
if (input_diff >= diff_min) {
|
||||
const int32_t input_diff_rescaled =
|
||||
MultiplyByQuantizedMultiplierGreaterThanOne(
|
||||
input_diff, input_beta_multiplier, input_beta_left_shift);
|
||||
const FixedPointScaledDiff scaled_diff_f8 =
|
||||
FixedPointScaledDiff::FromRaw(input_diff_rescaled);
|
||||
sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
|
||||
exp_on_negative_values(scaled_diff_f8));
|
||||
}
|
||||
}
|
||||
|
||||
int num_bits_over_unit;
|
||||
FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
|
||||
sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));
|
||||
|
||||
for (int c = 0; c < depth; ++c) {
|
||||
int32_t input_diff =
|
||||
static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
|
||||
if (input_diff >= diff_min) {
|
||||
const int32_t input_diff_rescaled =
|
||||
MultiplyByQuantizedMultiplierGreaterThanOne(
|
||||
input_diff, input_beta_multiplier, input_beta_left_shift);
|
||||
const FixedPointScaledDiff scaled_diff_f8 =
|
||||
FixedPointScaledDiff::FromRaw(input_diff_rescaled);
|
||||
|
||||
FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
|
||||
const int32_t unsat_output = gemmlowp::RoundingDivideByPOT(
|
||||
(shifted_scale * exp_in_0).raw(),
|
||||
num_bits_over_unit + 31 - (sizeof(OutputT) * 8));
|
||||
// TODO(b/148494470): Handle int32 shifts properly:
|
||||
const int32_t shifted_output =
|
||||
unsat_output -
|
||||
(static_cast<int32_t>(std::numeric_limits<OutputT>::max()) + 1);
|
||||
output_data[i * depth + c] = static_cast<OutputT>(std::max(
|
||||
std::min(shifted_output,
|
||||
static_cast<int32_t>(std::numeric_limits<OutputT>::max())),
|
||||
static_cast<int32_t>(std::numeric_limits<OutputT>::min())));
|
||||
} else {
|
||||
output_data[i * depth + c] = std::numeric_limits<OutputT>::min();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace hifimini
|
||||
} // namespace xtensa
|
||||
|
||||
namespace activations {
|
||||
namespace {
|
||||
|
||||
// TODO(b/141176180): This code is currently a strict subset of the portable
|
||||
// implementation (softmax.cc one directory up). When TFLM implements
|
||||
// registrations for selective types (e.g. compile without float support), this
|
||||
// can be removed. Otherwise, any HiFi specific optimizations should land here.
|
||||
|
||||
// This size will work for both the hotword (1) and ambient music (0):
|
||||
static SoftmaxParams kStaticOpData;
|
||||
|
||||
@@ -159,12 +76,11 @@ TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
|
||||
void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output,
|
||||
const SoftmaxParams& op_params) {
|
||||
if (output->type == kTfLiteInt16) {
|
||||
xtensa::hifimini::Softmax(
|
||||
tflite::reference_ops::Softmax(
|
||||
op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
|
||||
GetTensorShape(output), GetTensorData<int16_t>(output));
|
||||
|
||||
} else {
|
||||
xtensa::hifimini::Softmax(
|
||||
tflite::reference_ops::Softmax(
|
||||
op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
|
||||
GetTensorShape(output), GetTensorData<int8_t>(output));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user