Change Softmax to use tflite::reference_ops::Softmax instead of xtensa::hifimini::Softmax.

The two functions are the same when instantiated with the same input/output types.

PiperOrigin-RevId: 304428220
Change-Id: Id6d27f7a0a459c164944cc2793b12059b4bcbbb4
commit 1e9b684f25
parent 549434018c
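The equivalence claim is concrete at the call site: both functions take the same (params, input shape, input data, output shape, output data) argument list, and the reference version is a template over the input/output element types, so the int8-in/int8-out and int8-in/int16-out instantiations used in the diff below are drop-in replacements for the removed hifimini version. A minimal sketch of the post-change calls, assuming the reference softmax header path; the wrapper functions are purely illustrative and not part of this commit:

#include <cstdint>

#include "tensorflow/lite/kernels/internal/reference/softmax.h"

// Illustrative wrappers (not part of this commit): the two instantiations
// the kernel below selects between, depending on output->type.
void SoftmaxInt8(const tflite::SoftmaxParams& params,
                 const tflite::RuntimeShape& input_shape, const int8_t* input,
                 const tflite::RuntimeShape& output_shape, int8_t* output) {
  // Same argument list the removed xtensa::hifimini::Softmax accepted;
  // InputT/OutputT are deduced from the pointer arguments.
  tflite::reference_ops::Softmax(params, input_shape, input, output_shape,
                                 output);
}

void SoftmaxInt16(const tflite::SoftmaxParams& params,
                  const tflite::RuntimeShape& input_shape, const int8_t* input,
                  const tflite::RuntimeShape& output_shape, int16_t* output) {
  tflite::reference_ops::Softmax(params, input_shape, input, output_shape,
                                 output);
}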
@@ -26,97 +26,14 @@ limitations under the License.
 
 namespace tflite {
 namespace ops {
 namespace micro {
-
-namespace xtensa {
-namespace hifimini {
-
-// Quantized softmax with int8 input and int8/int16 output.
-template <typename OutputT = int8_t>
-inline void Softmax(const SoftmaxParams& params,
-                    const RuntimeShape& input_shape, const int8* input_data,
-                    const RuntimeShape& output_shape, OutputT* output_data) {
-  const int32_t input_beta_multiplier = params.input_multiplier;
-  const int32_t input_beta_left_shift = params.input_left_shift;
-  const int diff_min = params.diff_min;
-  // The representation chosen for the input to the exp() function is Q5.26.
-  // We need to leave extra space since values that we skip might be as large as
-  // -32 before multiplying by input_beta_multiplier, and therefore as large as
-  // -16 afterwards. Note that exp(-8) is definitely not insignificant to
-  // accumulation, but exp(-16) definitely is.
-  static const int kScaledDiffIntegerBits = 5;
-  static const int kAccumulationIntegerBits = 12;
-  using FixedPointScaledDiff =
-      gemmlowp::FixedPoint<int32_t, kScaledDiffIntegerBits>;
-  using FixedPointAccum =
-      gemmlowp::FixedPoint<int32_t, kAccumulationIntegerBits>;
-  using FixedPoint0 = gemmlowp::FixedPoint<int32_t, 0>;
-
-  const int trailing_dim = input_shape.DimensionsCount() - 1;
-  const int outer_size =
-      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
-  const int depth =
-      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
-
-  for (int i = 0; i < outer_size; ++i) {
-    int8 max_in_row = -128;
-    for (int c = 0; c < depth; ++c) {
-      max_in_row = std::max(max_in_row, input_data[i * depth + c]);
-    }
-
-    FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
-    for (int c = 0; c < depth; ++c) {
-      int32_t input_diff =
-          static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
-      if (input_diff >= diff_min) {
-        const int32_t input_diff_rescaled =
-            MultiplyByQuantizedMultiplierGreaterThanOne(
-                input_diff, input_beta_multiplier, input_beta_left_shift);
-        const FixedPointScaledDiff scaled_diff_f8 =
-            FixedPointScaledDiff::FromRaw(input_diff_rescaled);
-        sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
-                                        exp_on_negative_values(scaled_diff_f8));
-      }
-    }
-
-    int num_bits_over_unit;
-    FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
-        sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));
-
-    for (int c = 0; c < depth; ++c) {
-      int32_t input_diff =
-          static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
-      if (input_diff >= diff_min) {
-        const int32_t input_diff_rescaled =
-            MultiplyByQuantizedMultiplierGreaterThanOne(
-                input_diff, input_beta_multiplier, input_beta_left_shift);
-        const FixedPointScaledDiff scaled_diff_f8 =
-            FixedPointScaledDiff::FromRaw(input_diff_rescaled);
-
-        FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
-        const int32_t unsat_output = gemmlowp::RoundingDivideByPOT(
-            (shifted_scale * exp_in_0).raw(),
-            num_bits_over_unit + 31 - (sizeof(OutputT) * 8));
-        // TODO(b/148494470): Handle int32 shifts properly:
-        const int32_t shifted_output =
-            unsat_output -
-            (static_cast<int32_t>(std::numeric_limits<OutputT>::max()) + 1);
-        output_data[i * depth + c] = static_cast<OutputT>(std::max(
-            std::min(shifted_output,
-                     static_cast<int32_t>(std::numeric_limits<OutputT>::max())),
-            static_cast<int32_t>(std::numeric_limits<OutputT>::min())));
-      } else {
-        output_data[i * depth + c] = std::numeric_limits<OutputT>::min();
-      }
-    }
-  }
-}
-
-}  // namespace hifimini
-}  // namespace xtensa
-
 namespace activations {
 namespace {
 
+// TODO(b/141176180): This code is currently a strict subset of the portable
+// implementation (softmax.cc one directory up). When TFLM implements
+// registrations for selective types (e.g. compile without float support), this
+// can be removed. Otherwise, any HiFi specific optimizations should land here.
+
 // This size will work for both the hotword (1) and ambient music (0):
 static SoftmaxParams kStaticOpData;
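For context on the kernel removed above: Q5.26 means 5 integer bits and 26 fractional bits packed into an int32, i.e. a real value x is stored as round(x * 2^26) and the representable range is [-32, 32), which is why the comment notes that skipped values may be as large as -32 before rescaling. A standalone sketch of that encoding; the helper names below are invented for illustration and are not TFLite or gemmlowp APIs:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical helpers illustrating the Q5.26 fixed-point format used by the
// removed kernel: 5 integer bits, 26 fractional bits, packed into an int32.
constexpr int kFractionalBits = 26;

int32_t ToQ5_26(double x) {
  // Representable range is [-32, 32) with a resolution of 2^-26.
  return static_cast<int32_t>(std::lround(x * (1 << kFractionalBits)));
}

double FromQ5_26(int32_t raw) {
  return static_cast<double>(raw) / (1 << kFractionalBits);
}

int main() {
  // After rescaling, input differences lie roughly in [-16, 0]; exp(-16) is
  // about 1.1e-7, far below int8/int16 output resolution, which is why the
  // kernel skipped terms with input_diff < diff_min.
  const int32_t q = ToQ5_26(-16.0);
  std::printf("raw=%ld value=%f exp=%e\n", static_cast<long>(q), FromQ5_26(q),
              std::exp(FromQ5_26(q)));
  return 0;
}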
@@ -159,12 +76,11 @@ TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
 void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output,
                       const SoftmaxParams& op_params) {
   if (output->type == kTfLiteInt16) {
-    xtensa::hifimini::Softmax(
+    tflite::reference_ops::Softmax(
         op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
         GetTensorShape(output), GetTensorData<int16_t>(output));
-
   } else {
-    xtensa::hifimini::Softmax(
+    tflite::reference_ops::Softmax(
         op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
         GetTensorShape(output), GetTensorData<int8_t>(output));
   }
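One way to see the agreement the commit message asserts: both implementations end in the same affine output mapping. For int8 output, TFLite's quantized softmax produces values on a 1/256 scale with zero point -128, which is in effect what the removed kernel's subtract-and-clamp output stage computed. A float-domain sketch of that mapping; the function below is illustrative only, not a TFLite API:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative only: the int8 output stage both softmax implementations
// realize, i.e. quantized = round(p * 256) - 128, clamped to [-128, 127],
// corresponding to an output scale of 1/256 and a zero point of -128.
int8_t QuantizeSoftmaxOutput(double p) {
  const int32_t shifted = static_cast<int32_t>(std::lround(p * 256.0)) - 128;
  return static_cast<int8_t>(
      std::max<int32_t>(-128, std::min<int32_t>(127, shifted)));
}

int main() {
  // p = 1.0 saturates to 127; p = 0.5 maps to 0; p = 0.0 maps to -128.
  std::printf("%d %d %d\n", QuantizeSoftmaxOutput(1.0),
              QuantizeSoftmaxOutput(0.5), QuantizeSoftmaxOutput(0.0));
  return 0;
}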