Optimize quantized softmax with two uint8 lookup-table methods.

PiperOrigin-RevId: 315225925
Change-Id: I787ce554a45ccdf1d29834c6b0f74c8f409e30d4
Parent: 4e7ce793d9
Commit: b997946576
@@ -60,13 +60,18 @@ struct OpData {
 struct SoftmaxOpData {
   struct SoftmaxParams params = {};
-  float table[256]{};
-  const int size_of_lut = 513;
-  int16_t exp_lut[513]{};  // int16 LUT for exp(x), where x uniform distributed
-                           // between [-10.0 , 0.0]
-  int16_t one_over_one_plus_x_lut[513]{};  // int16 LUT for 1 / (1 + x), where
-                                           // x uniform distributed between
-                                           // [0.0 , 1.0]
+  float table[256];
+#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
+  uint8_t uint8_table1[256];
+  uint8_t uint8_table2[256];
+#endif
+  static constexpr int kInt16LUTArraySize = 513;
+  int16_t exp_lut[kInt16LUTArraySize];  // int16 LUT for exp(x), where x uniform
+                                        // distributed between [-10.0 , 0.0]
+  int16_t one_over_one_plus_x_lut[kInt16LUTArraySize];  // int16 LUT for 1 /
+                                                        // (1 + x), where x
+                                                        // uniform distributed
+                                                        // between [0.0 , 1.0]
 };

 struct LogSoftmaxOpData : public OpData {
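An aside on the sizing, not part of the diff: kInt16LUTArraySize is 513 because the int16 LUTs cover their input range with 512 uniform segments plus one closing endpoint, so a lookup can interpolate between adjacent entries without reading past the table. A minimal sketch of such an interpolated read, with a hypothetical helper name:

    #include <cstdint>

    // Hypothetical helper, not TFLite code: interpolated read of a 513-entry
    // int16 LUT covering [min, max]; the 513th entry closes the last segment
    // so lut[i + 1] is always valid.
    inline int16_t LutLookup(const int16_t lut[513], float x, float min,
                             float max) {
      const float pos = (x - min) / (max - min) * 512.0f;
      int i = static_cast<int>(pos);
      if (i > 511) i = 511;  // keep i + 1 in range when x == max
      const float frac = pos - static_cast<float>(i);
      return static_cast<int16_t>(lut[i] + frac * (lut[i + 1] - lut[i]));
    }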
@@ -134,29 +139,6 @@ void PopulateLookupTable(struct OpData* data, const TfLiteTensor* input,
   }
 }

-#if __aarch64__ && __clang__
-namespace {
-// Looks up each element of <indices> in <table>, returns them in a vector.
-inline uint8x16_t aarch64_lookup_vector(const uint8x16x4_t table[4],
-                                        uint8x16_t indices) {
-  // Look up in 1st quarter of the table: top 2 bits of indices == 00
-  uint8x16_t output1 = vqtbl4q_u8(table[0], indices);
-  // Look up in 2nd quarter of the table: top 2 bits of indices == 01
-  uint8x16_t output2 =
-      vqtbl4q_u8(table[1], veorq_u8(indices, vdupq_n_u8(0x40)));
-  // Look up in 3rd quarter of the table: top 2 bits of indices == 10
-  uint8x16_t output3 =
-      vqtbl4q_u8(table[2], veorq_u8(indices, vdupq_n_u8(0x80)));
-  // Look up in 4th quarter of the table: top 2 bits of indices == 11
-  uint8x16_t output4 =
-      vqtbl4q_u8(table[3], veorq_u8(indices, vdupq_n_u8(0xc0)));
-
-  // Combine result of the 4 lookups.
-  return vorrq_u8(vorrq_u8(output1, output2), vorrq_u8(output3, output4));
-}
-}  // namespace
-#endif
-
 // TODO(b/143696793): move this to optimized_ops.
 void EvalUsingLookupTable(struct OpData* data, const TfLiteTensor* input,
                           TfLiteTensor* output) {
@@ -182,7 +164,7 @@ void EvalUsingLookupTable(struct OpData* data, const TfLiteTensor* input,
       size / vectorized_16_loop_step * vectorized_16_loop_step;
   for (; i < vectorized_16_loop_end; i += vectorized_16_loop_step) {
     uint8x16_t input = vld1q_u8(input_data + i);
-    uint8x16_t output = aarch64_lookup_vector(table, input);
+    uint8x16_t output = optimized_ops::aarch64_lookup_vector(table, input);
     vst1q_u8(output_data + i, output);
   }
   // Postamble and non-ARM64 code: simple for loop.
@@ -583,9 +565,26 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE(context, NumDimensions(input) >= 1);

   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
-    data->params.table = data->table;
-    optimized_ops::PopulateSoftmaxLookupTable(
-        &data->params, input->params.scale, params->beta);
+    switch (output->type) {
+      case kTfLiteUInt8:
+      case kTfLiteInt8:
+#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
+        // Only apply when both input & output are uint8/int8 & build with clang
+        // on aarch64.
+        // TODO(b/143709993): Port to ARMv7 and other platforms.
+        data->params.uint8_table1 = data->uint8_table1;
+        data->params.uint8_table2 = data->uint8_table2;
+        optimized_ops::PopulateSoftmaxUInt8LookupTable(
+            &data->params, input->params.scale, params->beta);
+        break;
+#endif
+      case kTfLiteInt16:
+      default:
+        data->params.table = data->table;
+        optimized_ops::PopulateSoftmaxLookupTable(
+            &data->params, input->params.scale, params->beta);
+    }
+
     data->params.zero_point = output->params.zero_point;
     data->params.scale = output->params.scale;
   }
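Note the preprocessor-dependent fallthrough in the switch above: when TFLITE_SOFTMAX_USE_UINT16_LUT is undefined, the kTfLiteUInt8/kTfLiteInt8 cases contain no statements and fall through to the default branch, so every build still populates a valid table. A reduced model of that shape (illustrative only; the stubs stand in for the two PopulateSoftmax*LookupTable calls):

    // Reduced model of the dispatch above; names are ours, not TFLite's.
    enum class OutType { kUInt8, kInt8, kInt16 };

    void PopulateUInt8Tables() {}  // fast path: two uint8 tables
    void PopulateFloatTable() {}   // portable path: one float table

    void Prepare(OutType out) {
      switch (out) {
        case OutType::kUInt8:
        case OutType::kInt8:
    #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
          PopulateUInt8Tables();
          break;
    #endif
        case OutType::kInt16:
        default:
          PopulateFloatTable();
      }
    }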
@@ -597,10 +596,10 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
     // exp LUT only used on negative values
     // we consider exp(-10.0) is insignificant to accumulation
     gen_lut([](double value) { return std::exp(value); }, -10.0, 0.0,
-            data->params.exp_lut, data->size_of_lut);
+            data->params.exp_lut, data->kInt16LUTArraySize);
     data->params.one_over_one_plus_x_lut = data->one_over_one_plus_x_lut;
     gen_lut([](double value) { return 1.0 / (1.0 + value); }, 0.0, 1.0,
-            data->params.one_over_one_plus_x_lut, data->size_of_lut);
+            data->params.one_over_one_plus_x_lut, data->kInt16LUTArraySize);
     data->params.zero_point = output->params.zero_point;
     data->params.scale = output->params.scale;

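For context, gen_lut fills those int16 tables by sampling a function at uniformly spaced points over [min, max]. A simplified sketch of the idea, under the assumption that outputs near [-1, 1) are stored in Q15 fixed point; the real gen_lut in TFLite's internal headers differs in details:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <functional>

    // Simplified gen_lut-style filler (illustrative, not the TFLite helper):
    // sample func at num uniformly spaced points over [min, max] and store
    // each value in Q15 with rounding and saturation.
    void FillInt16Lut(const std::function<double(double)>& func, double min,
                      double max, int16_t* lut, int num) {
      const double step = (max - min) / (num - 1);
      for (int i = 0; i < num; ++i) {
        const double q = std::round(func(min + i * step) * 32768.0);
        lut[i] = static_cast<int16_t>(std::min(32767.0, std::max(-32768.0, q)));
      }
    }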
@@ -1019,6 +1018,40 @@ TfLiteStatus SoftmaxQuantized(TfLiteContext* context, const TfLiteTensor* input,
   return kTfLiteOk;
 }

+template <>
+TfLiteStatus SoftmaxQuantized<int8_t, int8_t>(TfLiteContext* context,
+                                              const TfLiteTensor* input,
+                                              TfLiteTensor* output,
+                                              SoftmaxOpData* data) {
+#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
+  optimized_ops::SoftmaxInt8LUT(
+      data->params, GetTensorShape(input), GetTensorData<int8_t>(input),
+      GetTensorShape(output), GetTensorData<int8_t>(output));
+#else
+  optimized_ops::Softmax(data->params, GetTensorShape(input),
+                         GetTensorData<int8_t>(input), GetTensorShape(output),
+                         GetTensorData<int8_t>(output));
+#endif
+  return kTfLiteOk;
+}
+
+template <>
+TfLiteStatus SoftmaxQuantized<uint8_t, uint8_t>(TfLiteContext* context,
+                                                const TfLiteTensor* input,
+                                                TfLiteTensor* output,
+                                                SoftmaxOpData* data) {
+#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
+  optimized_ops::SoftmaxInt8LUT(
+      data->params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+      GetTensorShape(output), GetTensorData<uint8_t>(output));
+#else
+  optimized_ops::Softmax(data->params, GetTensorShape(input),
+                         GetTensorData<uint8_t>(input), GetTensorShape(output),
+                         GetTensorData<uint8_t>(output));
+#endif
+  return kTfLiteOk;
+}
+
 template <>
 TfLiteStatus SoftmaxQuantized<int16, int16>(TfLiteContext* context,
                                             const TfLiteTensor* input,
@@ -56,6 +56,10 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/transpose_utils.h"
 #include "tensorflow/lite/kernels/internal/types.h"

+#if __aarch64__ && __clang__
+#define TFLITE_SOFTMAX_USE_UINT16_LUT
+#endif
+
 namespace tflite {
 namespace optimized_ops {

@@ -328,6 +332,29 @@ inline void BinaryBroadcastFiveFold(const ArithmeticParams& unswitched_params,
   }
 }

+#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
+
+// Looks up each element of <indices> in <table>, returns them in a vector.
+inline uint8x16_t aarch64_lookup_vector(const uint8x16x4_t table[4],
+                                        uint8x16_t indices) {
+  // Look up in 1st quarter of the table: top 2 bits of indices == 00
+  uint8x16_t output1 = vqtbl4q_u8(table[0], indices);
+  // Look up in 2nd quarter of the table: top 2 bits of indices == 01
+  uint8x16_t output2 =
+      vqtbl4q_u8(table[1], veorq_u8(indices, vdupq_n_u8(0x40)));
+  // Look up in 3rd quarter of the table: top 2 bits of indices == 10
+  uint8x16_t output3 =
+      vqtbl4q_u8(table[2], veorq_u8(indices, vdupq_n_u8(0x80)));
+  // Look up in 4th quarter of the table: top 2 bits of indices == 11
+  uint8x16_t output4 =
+      vqtbl4q_u8(table[3], veorq_u8(indices, vdupq_n_u8(0xc0)));
+
+  // Combine result of the 4 lookups.
+  return vorrq_u8(vorrq_u8(output1, output2), vorrq_u8(output3, output4));
+}
+
+#endif
+
 inline void AddBiasAndEvalActivationFunction(float output_activation_min,
                                              float output_activation_max,
                                              const RuntimeShape& bias_shape,
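For reference, a scalar model of the lookup above: per lane, vqtbl4q_u8 reads a 64-byte table and yields 0 for any out-of-range index, so XOR-ing the top two index bits brings exactly one quarter of the 0..255 range into bounds per lookup, and OR-ing the four (mutually exclusive) partial results reconstructs a full 256-entry lookup. A minimal sketch, not commit code; names are ours:

    #include <cstdint>

    // vqtbl4q_u8 semantics for one lane: table[idx] if idx < 64, else 0.
    inline uint8_t tbl64(const uint8_t* table64, uint8_t idx) {
      return idx < 64 ? table64[idx] : 0;
    }

    // XOR with 0x40/0x80/0xc0 maps each quarter of 0..255 into [0, 64) for
    // exactly one of the four lookups; OR combines the results.
    inline uint8_t lookup256(const uint8_t table[256], uint8_t idx) {
      return tbl64(table, idx) | tbl64(table + 64, idx ^ 0x40) |
             tbl64(table + 128, idx ^ 0x80) | tbl64(table + 192, idx ^ 0xc0);
    }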
@@ -3969,6 +3996,271 @@ inline void Softmax(const SoftmaxParams& params,
   }
 }

+// Here's the softmax LUT optimization strategy:
+// For softmax, we can do some mathematically equivalent transformation:
+//
+// softmax(x) = e^x / sum(e^x, 0...n)  ===> equals to
+// softmax(x) = e^(x - CONST) / sum(e^(x - CONST), 0...n)
+//
+// For quantization, `x` in our case is (input_q - input_zp) * input_s
+// For the uint8 case (int8 can be handled similarly), the range is [0, 255]
+//
+// so if we let
+// CONST = (255 - input_zp) * input_s
+// then we will have:
+// softmax(x) = e^((input_q - 255) * input_s) --------- (1)
+//             /
+// sum(e^((input_q - 255) * input_s), 0...n) ---------- (2)
+//
+// the good thing about (1) is that it's within the range of (0, 1), so we can
+// approximate its result with uint16:
+//   (1) = uint8_out * 1 / 2^16.
+//
+// so (1) is lookup_uint8_table(input_zp) * 1 / 2^16.
+// then (2) is essentially the following:
+// sum(lookup_uint8_table(input_zp), 0...n) / 2^16.
+//
+// since (output_q - output_zp) * output_s = softmax(x)
+// output_q = lookup_uint8_table(input_zp)
+//            /
+//            (sum(lookup_uint8_table(input_zp), 0...n) * output_s)
+//            +
+//            output_zp
+//
+// We can actually further improve the performance by using uint8 instead of
+// uint16. But then we may lose some accuracy, so we need to pay attention
+// to that.
+inline void PopulateSoftmaxUInt8LookupTable(SoftmaxParams* data,
+                                            float input_scale, float beta) {
+  const float scale = input_scale * beta;
+  const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
+  const int32_t max_uint16 = std::numeric_limits<uint16_t>::max();
+
+  for (int32_t val = 0; val <= max_uint8; ++val) {
+    float input_to_exp = scale * (val - max_uint8);
+    int32_t temp = static_cast<int>(expf(input_to_exp) * max_uint16 + 0.5);
+    temp = std::min(max_uint16, temp);
+    uint8_t part1 = temp >> 8;
+    uint8_t part2 = temp & 0xff;
+    data->uint8_table1[val] = static_cast<uint8_t>(part1);
+    data->uint8_table2[val] = static_cast<uint8_t>(part2);
+  }
+}
+
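To make the two-table encoding concrete, here is a small standalone check (our example values, not commit code): each entry stores exp(scale * (val - 255)) rounded into uint16, with the high byte in uint8_table1 and the low byte in uint8_table2; recombining the two bytes recovers the uint16 approximation.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
      uint8_t table1[256], table2[256];
      // Example quantization parameters (ours, not from the commit).
      const float input_scale = 0.1f;
      const float beta = 1.0f;
      const float scale = input_scale * beta;
      for (int val = 0; val <= 255; ++val) {
        int temp = static_cast<int>(std::exp(scale * (val - 255)) * 65535 + 0.5);
        temp = std::min(65535, temp);
        table1[val] = static_cast<uint8_t>(temp >> 8);    // high byte
        table2[val] = static_cast<uint8_t>(temp & 0xff);  // low byte
      }
      // exp(0.1 * (200 - 255)) = exp(-5.5) ~= 0.0041, i.e. ~268 / 65535.
      const int v = 200;
      std::printf("%d\n", (table1[v] << 8) + table2[v]);  // prints 268
    }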
+inline int FindMaxValue(int size, const uint8_t* input_data, uint8_t offset) {
+  int32_t max_val = std::numeric_limits<uint8_t>::min();
+  int j = 0;
+#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
+  uint8x16_t max_val_dup = vdupq_n_u8(max_val);
+  uint8x16_t offset_dup = vdupq_n_u8(offset);
+  for (; j <= size - 16; j += 16) {
+    uint8x16_t input_value = vld1q_u8(input_data + j);
+    input_value = veorq_u8(input_value, offset_dup);
+    max_val_dup = vmaxq_u8(input_value, max_val_dup);
+  }
+  max_val = std::max(max_val, static_cast<int32>(vmaxvq_u8(max_val_dup)));
+#endif
+
+  for (; j < size; ++j) {
+    max_val = std::max(max_val, static_cast<int32_t>(input_data[j] ^ offset));
+  }
+  return max_val;
+}
+
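Why the XOR with the 0x80 offset works for int8 inputs (an illustrative note, not commit code): flipping the sign bit of a two's-complement int8 maps [-128, 127] monotonically onto [0, 255], so the max search and the table indexing can treat both input types as uint8.

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Flipping the sign bit shifts int8 values up by 128 in unsigned space.
      for (int v : {-128, -1, 0, 127}) {
        const uint8_t u = static_cast<uint8_t>(static_cast<int8_t>(v)) ^ 0x80;
        std::printf("int8 %4d -> uint8 %3u\n", v, static_cast<unsigned>(u));
      }
      // Prints: -128 -> 0, -1 -> 127, 0 -> 128, 127 -> 255 (order preserved).
    }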
+#ifdef USE_NEON
+// Value_to_store layout:
+// [high_high, high_low, low_high, low_low].
+inline void StoreValue(int32x4x4_t value_to_store, int8_t* output) {
+  const int16x8_t result_1 = vcombine_s16(vqmovn_s32(value_to_store.val[1]),
+                                          vqmovn_s32(value_to_store.val[0]));
+  const int16x8_t result_2 = vcombine_s16(vqmovn_s32(value_to_store.val[3]),
+                                          vqmovn_s32(value_to_store.val[2]));
+  const int8x16_t result =
+      vcombine_s8(vqmovn_s16(result_2), vqmovn_s16(result_1));
+  vst1q_s8(output, result);
+}
+
+// Value_to_store layout:
+// [high_high, high_low, low_high, low_low].
+inline void StoreValue(int32x4x4_t value_to_store, uint8_t* output) {
+  const uint16x8_t result_1 =
+      vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[1])),
+                   vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[0])));
+  const uint16x8_t result_2 =
+      vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[3])),
+                   vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[2])));
+  const uint8x16_t result =
+      vcombine_u8(vqmovn_u16(result_2), vqmovn_u16(result_1));
+  vst1q_u8(output, result);
+}
+
+#endif
+
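A scalar model of what the StoreValue pair computes per lane may help (a sketch, not commit code): each int32 lane is saturate-narrowed twice, to 16 bits and then to 8, matching the vqmovn_s32/vqmovn_s16 chain.

    #include <algorithm>
    #include <cstdint>

    // One lane of the int8 StoreValue, modeled in scalar code: a saturating
    // narrow to 16 bits (vqmovn_s32) followed by a saturating narrow to
    // 8 bits (vqmovn_s16).
    inline int8_t SaturatingNarrowS32ToS8(int32_t v) {
      const int32_t v16 = std::min(32767, std::max(-32768, v));
      return static_cast<int8_t>(std::min(127, std::max(-128, v16)));
    }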
+template <typename In, typename Out>
+inline void SoftmaxInt8LUT(const SoftmaxParams& params,
+                           const RuntimeShape& input_shape,
+                           const In* input_data,
+                           const RuntimeShape& output_shape, Out* output_data) {
+  const int trailing_dim = input_shape.DimensionsCount() - 1;
+  const int excluding_last_dim =
+      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+  const int last_dim =
+      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+
+  const int32_t clamp_max = std::numeric_limits<Out>::max();
+  const int32_t clamp_min = std::numeric_limits<Out>::min();
+
+  // Offset is used to interpret the input data "correctly".
+  // If the input is uint8, the data will be unchanged.
+  // If the input is int8, it will be reinterpreted as uint8:
+  // e.g., int8 127 has the offset applied to become 255 in uint8.
+  uint8_t offset = 0;
+  if (std::is_same<In, int8>::value) {
+    offset = 0x80;
+  }
+
+  const uint8_t* input_data_uint = reinterpret_cast<const uint8_t*>(input_data);
+
+#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
+  // This code uses ARM64-only instructions.
+  // TODO(b/143709993): Port to ARMv7
+
+  // Load the tables into registers. (4*4 128-bit registers)
+  uint8x16x4_t table1[4];
+  table1[0] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 0);
+  table1[1] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 1);
+  table1[2] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 2);
+  table1[3] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 3);
+
+  uint8x16x4_t table2[4];
+  table2[0] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 0);
+  table2[1] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 1);
+  table2[2] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 2);
+  table2[3] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 3);
+#endif
+
+  for (int i = 0; i < excluding_last_dim; ++i) {
+    // Find max quantized value.
+    int32_t max_val = FindMaxValue(last_dim, input_data_uint, offset);
+
+    int32 sum_exp = 0;
+    const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
+    const uint8_t table_offset = max_uint8 - max_val;
+
+    // Calculate normalizer sum(exp(x)).
+    int sum_j = 0;
+#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
+    uint8x16_t table_offset_dup = vdupq_n_u8(table_offset);
+    uint8x16_t offset_dup = vdupq_n_u8(offset);
+    uint32x4_t sum_4 = vdupq_n_u32(0);
+    const int multiplier_shift = 8;
+    for (; sum_j <= last_dim - 16; sum_j += 16) {
+      uint8x16_t input_value = vld1q_u8(input_data_uint + sum_j);
+      input_value = veorq_u8(input_value, offset_dup);
+      input_value = vaddq_u8(input_value, table_offset_dup);
+
+      const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
+      const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);
+
+      uint16x8_t exp_value1 =
+          vshll_n_u8(vget_high_u8(output1), multiplier_shift);
+      uint16x8_t exp_value2 =
+          vshll_n_u8(vget_low_u8(output1), multiplier_shift);
+
+      exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
+      exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));
+
+      sum_4 = vpadalq_u16(sum_4, exp_value1);
+      sum_4 = vpadalq_u16(sum_4, exp_value2);
+    }
+    int temp = vgetq_lane_u32(sum_4, 0) + vgetq_lane_u32(sum_4, 1) +
+               vgetq_lane_u32(sum_4, 2) + vgetq_lane_u32(sum_4, 3);
+    sum_exp += temp;
+#endif
+    for (; sum_j < last_dim; ++sum_j) {
+      const uint8_t index = (input_data_uint[sum_j] ^ offset) + table_offset;
+
+      uint8_t part1 = params.uint8_table1[index];
+      uint8_t part2 = params.uint8_table2[index];
+      sum_exp += ((part1 << 8) + part2);
+    }
+
+    const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
+
+    int32 multiplier, shift;
+    QuantizeMultiplier(inv_sum_exp, &multiplier, &shift);
+
+    // Normalize and quantize probabilities.
+    int j = 0;
+#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
+    const int32x4_t output_zp_dup = vdupq_n_s32(params.zero_point);
+    const int32x4_t max_val_dup = vdupq_n_s32(clamp_max);
+    const int32x4_t min_val_dup = vdupq_n_s32(clamp_min);
+
+    for (; j <= last_dim - 16; j += 16) {
+      uint8x16_t input_value = vld1q_u8(input_data_uint + j);
+      input_value = veorq_u8(input_value, offset_dup);
+      input_value = vaddq_u8(input_value, table_offset_dup);
+
+      const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
+      const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);
+
+      uint16x8_t exp_value1 =
+          vshll_n_u8(vget_high_u8(output1), multiplier_shift);
+      uint16x8_t exp_value2 =
+          vshll_n_u8(vget_low_u8(output1), multiplier_shift);
+
+      exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
+      exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));
+
+      int32x4x4_t output_value;
+      output_value.val[0] =
+          vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value1)));
+      output_value.val[1] =
+          vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value1)));
+      output_value.val[2] =
+          vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value2)));
+      output_value.val[3] =
+          vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value2)));
+
+      int32x4x4_t temp_val =
+          MultiplyByQuantizedMultiplier4Rows(output_value, multiplier, shift);
+
+      temp_val.val[0] = vaddq_s32(temp_val.val[0], output_zp_dup);
+      temp_val.val[1] = vaddq_s32(temp_val.val[1], output_zp_dup);
+      temp_val.val[2] = vaddq_s32(temp_val.val[2], output_zp_dup);
+      temp_val.val[3] = vaddq_s32(temp_val.val[3], output_zp_dup);
+
+      temp_val.val[0] =
+          vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
+      temp_val.val[1] =
+          vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
+      temp_val.val[2] =
+          vmaxq_s32(vminq_s32(temp_val.val[2], max_val_dup), min_val_dup);
+      temp_val.val[3] =
+          vmaxq_s32(vminq_s32(temp_val.val[3], max_val_dup), min_val_dup);
+
+      StoreValue(temp_val, output_data + j);
+    }
+#endif
+    for (; j < last_dim; ++j) {
+      const uint8_t index = (input_data_uint[j] ^ offset) + table_offset;
+      const uint8_t part1 = params.uint8_table1[index];
+      const uint8_t part2 = params.uint8_table2[index];
+      const int32_t exp_value = (part1 << 8) + part2;
+      const int32_t output_value =
+          MultiplyByQuantizedMultiplier(exp_value, multiplier, shift);
+
+      output_data[j] = static_cast<Out>(std::max(
+          std::min(clamp_max, output_value + params.zero_point), clamp_min));
+    }
+    input_data_uint += last_dim;
+    output_data += last_dim;
+  }
+}
+
 // TODO(myenik): This is the same as the reference implementation, not actually
 // optimized yet.
 inline void LogSoftmax(const SoftmaxParams& params,
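Putting the pieces together, the whole uint8 path can be modeled in a few scalar lines (a sketch under our own names; it uses float for the normalizer where SoftmaxInt8LUT uses QuantizeMultiplier fixed-point arithmetic, and handles only the uint8-input case):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    void ScalarSoftmaxUInt8LUT(const uint8_t* input, uint8_t* output, int size,
                               float input_scale, float beta,
                               float output_scale, int32_t output_zp) {
      // Build the split exp table, as in PopulateSoftmaxUInt8LookupTable.
      uint8_t t1[256], t2[256];
      const float scale = input_scale * beta;
      for (int v = 0; v <= 255; ++v) {
        const int e = std::min(
            65535,
            static_cast<int>(std::exp(scale * (v - 255)) * 65535 + 0.5f));
        t1[v] = static_cast<uint8_t>(e >> 8);
        t2[v] = static_cast<uint8_t>(e & 0xff);
      }
      // Shift indices so the row max lands on entry 255 (exp ~= 1.0).
      const int max_val = *std::max_element(input, input + size);
      const uint8_t table_offset = static_cast<uint8_t>(255 - max_val);
      int32_t sum_exp = 0;
      for (int j = 0; j < size; ++j) {
        const uint8_t idx = static_cast<uint8_t>(input[j] + table_offset);
        sum_exp += (t1[idx] << 8) + t2[idx];
      }
      // Table entries are exp values in units of 1/65535; the units cancel in
      // exp/sum, so the normalizer only needs sum_exp * output_scale.
      const float inv_sum = 1.0f / (sum_exp * output_scale);
      for (int j = 0; j < size; ++j) {
        const uint8_t idx = static_cast<uint8_t>(input[j] + table_offset);
        const int32_t e = (t1[idx] << 8) + t2[idx];
        const int32_t q = static_cast<int32_t>(e * inv_sum + 0.5f) + output_zp;
        output[j] = static_cast<uint8_t>(std::min(255, std::max(0, q)));
      }
    }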
@@ -133,6 +133,7 @@ void RunOneSoftmaxTest(const uint8* input_data,
   params.scale = 1.0f / 256;
   params.zero_point = 0;
   params.table = table;
+
   optimized_ops::PopulateSoftmaxLookupTable(&params, input_scale, beta);
   optimized_ops::Softmax(params, shape_common, input_data, shape_common,
                          optimized_softmax_output.data());
@@ -148,6 +149,19 @@ void RunOneSoftmaxTest(const uint8* input_data,
   CheckOutputData<uint8_t>(reference_quant_softmax_output.data(),
                            reference_float_softmax_output.data(), shape_common,
                            "Quant reference vs float reference", false);

+#if __aarch64__ && __clang__
+  uint8_t uint8_table1[256];
+  uint8_t uint8_table2[256];
+  params.uint8_table1 = uint8_table1;
+  params.uint8_table2 = uint8_table2;
+  optimized_ops::PopulateSoftmaxUInt8LookupTable(&params, input_scale, beta);
+  optimized_ops::SoftmaxInt8LUT(params, shape_common, input_data, shape_common,
+                                optimized_softmax_output.data());
+  CheckOutputData<uint8_t>(
+      optimized_softmax_output.data(), reference_quant_softmax_output.data(),
+      shape_common, "Optimized (Uint8 Lookup table) vs quant reference", false);
+#endif
 }

 // This function picks some random Softmax params, which are checked for
@@ -1035,6 +1035,8 @@ struct SoftmaxParams {
   float* table;
   int16_t* exp_lut;
   int16_t* one_over_one_plus_x_lut;
+  uint8_t* uint8_table1;
+  uint8_t* uint8_table2;
 };

 struct SpaceToBatchParams {