Optimize 8bit Softmax op handler
PiperOrigin-RevId: 261410252
parent 9c14f6ba30
commit 96b38db29c
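
The change replaces the fixed-point 8-bit Softmax path (gemmlowp fixed-point exp plus per-row NEON reductions) with a float lookup-table implementation: SoftmaxPrepare builds a 256-entry exp table once per node, stores it in the new SoftmaxOpData, and a single templated SoftmaxQuantized<T> handler reuses it for both uint8 and int8 inputs at eval time. The hunks below touch the Softmax handler in the activations kernel, the legacy and current optimized op headers, the internal SoftmaxParams type, and the softmax quantized test. As orientation, here is a minimal, self-contained sketch of the lookup-table approach; it is simplified from the diff, and the helper names (BuildExpTable, LutSoftmaxRow) as well as the hard-coded output scale of 1/256 with zero point 0 are illustrative assumptions, not the real TFLite API.

// Sketch only: a standalone uint8 softmax using the exp lookup-table idea
// from this commit. Assumed simplifications: output scale fixed to 1/256,
// zero point 0, no int8 variant.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// "Prepare" step: tabulate exp(-input_scale * beta * v) for v = 0..255, laid
// out so that table[255 - max_val + x] == exp(input_scale * beta * (x - max_val)).
void BuildExpTable(float input_scale, float beta, float table[256]) {
  const float scale = -input_scale * beta;
  for (int v = 0; v <= 255; ++v) {
    table[255 - v] = std::exp(scale * v);
  }
}

// "Eval" step for one row of quantized logits: max scan, table-lookup sum,
// then normalize and requantize into uint8.
void LutSoftmaxRow(const uint8_t* in, uint8_t* out, int depth,
                   const float table[256], float output_scale) {
  int max_val = 0;
  for (int j = 0; j < depth; ++j) max_val = std::max(max_val, int(in[j]));
  const float* offset = &table[255 - max_val];
  float sum_exp = 0.0f;
  for (int j = 0; j < depth; ++j) sum_exp += offset[in[j]];
  const float inv = 1.0f / (sum_exp * output_scale);
  for (int j = 0; j < depth; ++j) {
    const int q = static_cast<int>(offset[in[j]] * inv + 0.5f);  // zero_point = 0
    out[j] = static_cast<uint8_t>(std::min(255, std::max(0, q)));
  }
}

int main() {
  float table[256];
  BuildExpTable(/*input_scale=*/0.1f, /*beta=*/1.0f, table);
  const uint8_t logits[4] = {10, 200, 180, 60};
  uint8_t probs[4];
  LutSoftmaxRow(logits, probs, 4, table, /*output_scale=*/1.0f / 256);
  for (int j = 0; j < 4; ++j) std::printf("%d ", probs[j]);
  std::printf("\n");
  return 0;
}

The per-row cost is one max scan, one table-lookup sum, and one normalize-and-requantize pass, with no fixed-point exp evaluation; that is the optimization the hunks below carry out.
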
@@ -55,6 +55,11 @@ struct OpData {
   uint8_t* table_zero = nullptr;
 };

+struct SoftmaxOpData {
+  struct SoftmaxParams params = {};
+  float table[256];
+};
+
 struct LogSoftmaxOpData : public OpData {
   int32_t reverse_scaling_divisor = 0;
   int32_t reverse_scaling_right_shift = 0;
@@ -131,6 +136,14 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   return new OpData;
 }

+void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
+  return new SoftmaxOpData;
+}
+
+void SoftmaxFree(TfLiteContext* context, void* buffer) {
+  delete reinterpret_cast<SoftmaxOpData*>(buffer);
+}
+
 void* LogSoftmaxInit(TfLiteContext* context, const char* buffer,
                      size_t length) {
   return new LogSoftmaxOpData;
@@ -363,7 +376,7 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {

 TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
-  OpData* data = reinterpret_cast<OpData*>(node->user_data);
+  SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);

   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -375,16 +388,11 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE(context, num_dims >= 1 && num_dims <= 4);

   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
-    if (CheckOutputQuantParams(context, input, output) == kTfLiteError) {
-      return kTfLiteError;
-    }
-
-    static const int kScaledDiffIntegerBits = 5;
-    tflite::PreprocessSoftmaxScaling(
-        params->beta, input->params.scale, kScaledDiffIntegerBits,
-        &data->input_multiplier, &data->input_left_shift);
-    data->diff_min = -1.0 * tflite::CalculateInputRadius(
-                                kScaledDiffIntegerBits, data->input_left_shift);
+    data->params.table = data->table;
+    optimized_ops::PopulateSoftmaxLookupTable(
+        &data->params, input->params.scale, params->beta);
+    data->params.zero_point = output->params.zero_point;
+    data->params.scale = output->params.scale;
   }

   return context->ResizeTensor(context, output,
@@ -749,61 +757,25 @@ TfLiteStatus SoftmaxFloat(TfLiteContext* context, const TfLiteTensor* input,
   }
 }

-TfLiteStatus SoftmaxQuantizedUint8(TfLiteContext* context,
-                                   const TfLiteTensor* input,
-                                   TfLiteTensor* output,
-                                   TfLiteSoftmaxParams* params, OpData* data) {
-  switch (NumDimensions(input)) {
-    case 1:
-    case 2:
-    case 3:
-    case 4:
-      SoftmaxParams op_params;
-      op_params.input_multiplier = data->input_multiplier;
-      op_params.input_left_shift = data->input_left_shift;
-      op_params.diff_min = data->diff_min;
-      optimized_ops::Softmax(
-          op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
-          GetTensorShape(output), GetTensorData<uint8_t>(output));
-      return kTfLiteOk;
-    default:
-      context->ReportError(
-          context,
-          "Only 1D, 2D, 3D and 4D tensors supported currently, got %dD.",
-          NumDimensions(input));
-      return kTfLiteError;
-  }
-}
-
-TfLiteStatus SoftmaxQuantizedInt8(TfLiteContext* context,
-                                  const TfLiteTensor* input,
-                                  TfLiteTensor* output,
-                                  TfLiteSoftmaxParams* params, OpData* data) {
-  switch (NumDimensions(input)) {
-    case 1:
-    case 2:
-    case 3:
-    case 4:
-      SoftmaxParams op_params;
-      op_params.input_multiplier = data->input_multiplier;
-      op_params.input_left_shift = data->input_left_shift;
-      op_params.diff_min = data->diff_min;
-      optimized_integer_ops::Softmax(
-          op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
-          GetTensorShape(output), GetTensorData<int8_t>(output));
-      return kTfLiteOk;
-    default:
-      context->ReportError(
-          context,
-          "Only 1D, 2D, 3D and 4D tensors supported currently, got %dD.",
-          NumDimensions(input));
-      return kTfLiteError;
+template <typename T>
+TfLiteStatus SoftmaxQuantized(TfLiteContext* context, const TfLiteTensor* input,
+                              TfLiteTensor* output, SoftmaxOpData* data) {
+  if (NumDimensions(input) >= 1 && NumDimensions(input) <= 4) {
+    optimized_ops::Softmax(data->params, GetTensorShape(input),
+                           GetTensorData<T>(input), GetTensorShape(output),
+                           GetTensorData<T>(output));
+    return kTfLiteOk;
+  } else {
+    context->ReportError(
+        context, "Only 1D, 2D, 3D and 4D tensors supported currently, got %dD.",
+        NumDimensions(input));
+    return kTfLiteError;
   }
 }

 TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
-  OpData* data = reinterpret_cast<OpData*>(node->user_data);
+  SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);

   const TfLiteTensor* input = GetInput(context, node, 0);
   TfLiteTensor* output = GetOutput(context, node, 0);
@@ -815,10 +787,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
       return SoftmaxFloat(context, input, output, params);
     }
     case kTfLiteUInt8: {
-      return SoftmaxQuantizedUint8(context, input, output, params, data);
+      return SoftmaxQuantized<uint8_t>(context, input, output, data);
     }
     case kTfLiteInt8: {
-      return SoftmaxQuantizedInt8(context, input, output, params, data);
+      return SoftmaxQuantized<int8_t>(context, input, output, data);
     }

     default:
@@ -1055,9 +1027,9 @@ TfLiteRegistration* Register_LOGISTIC() {
 }

 TfLiteRegistration* Register_SOFTMAX() {
-  static TfLiteRegistration r = {activations::Init, activations::Free,
-                                 activations::SoftmaxPrepare,
-                                 activations::SoftmaxEval};
+  static TfLiteRegistration r = {
+      activations::SoftmaxInit, activations::SoftmaxFree,
+      activations::SoftmaxPrepare, activations::SoftmaxEval};
   return &r;
 }

@@ -3972,6 +3972,208 @@ void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
          filter_width, filter_height, output_data, output_dims);
 }

+inline void Softmax(const SoftmaxParams& params,
+                    const RuntimeShape& input_shape, const uint8* input_data,
+                    const RuntimeShape& output_shape, uint8* output_data) {
+  const int32 input_beta_multiplier = params.input_multiplier;
+  const int32 input_beta_left_shift = params.input_left_shift;
+  const int diff_min = params.diff_min;
+  // The representation chosen for the input to the exp() function is Q5.26.
+  // We need to leave extra space since values that we skip might be as large as
+  // -32 before multiplying by input_beta_multiplier, and therefore as large as
+  // -16 afterwards. Note that exp(-8) is definitely not insignificant to
+  // accumulation, but exp(-16) definitely is.
+  static const int kScaledDiffIntegerBits = 5;
+  static const int kAccumulationIntegerBits = 12;
+  using FixedPointScaledDiff =
+      gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
+  using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
+  using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+
+  gemmlowp::ScopedProfilingLabel label("Softmax/8bit");
+  const int trailing_dim = input_shape.DimensionsCount() - 1;
+  const int outer_size =
+      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+  const int depth =
+      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+
+  for (int b = 0; b < outer_size; ++b) {
+    const uint8* input_data_ptr = input_data + b * depth;
+    uint8* output_data_ptr = output_data + b * depth;
+
+    // Determine the largest entry in the current row
+    uint8 max_in_row = 0;
+    {
+      int c = 0;
+#ifdef USE_NEON
+      uint8x16_t max16_0 = vdupq_n_u8(0);
+      uint8x16_t max16_1 = vdupq_n_u8(0);
+      for (; c <= depth - 32; c += 32) {
+        max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0));
+        max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16));
+      }
+      uint8x16_t max16 = vmaxq_u8(max16_0, max16_1);
+      if (c <= depth - 16) {
+        max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c));
+        c += 16;
+      }
+      uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16));
+      if (c <= depth - 8) {
+        max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c));
+        c += 8;
+      }
+      uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4));
+      uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2));
+      uint8x8_t max1 = vpmax_u8(max2, max2);
+      max_in_row = vget_lane_u8(max1, 0);
+#endif
+      for (; c < depth; ++c) {
+        max_in_row = std::max(max_in_row, input_data_ptr[c]);
+      }
+    }
+
+#ifdef USE_NEON
+    using FixedPointAccumInt32x4 =
+        gemmlowp::FixedPoint<int32x4_t, kAccumulationIntegerBits>;
+    using FixedPointScaledDiffInt32x4 =
+        gemmlowp::FixedPoint<int32x4_t, kScaledDiffIntegerBits>;
+    using FixedPoint0Int32x4 = gemmlowp::FixedPoint<int32x4_t, 0>;
+    FixedPoint0Int32x4 input_beta_multiplier_f0 =
+        FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier);
+    int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row);
+#endif
+
+    // Compute the sum of exponentials of the differences of entries in the
+    // current row from the largest entry in the current row.
+    FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
+    {
+      int c = 0;
+#ifdef USE_NEON
+      int32x4_t diff_min_s32 = vdupq_n_s32(diff_min);
+      FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero();
+      FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero();
+      FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero();
+      for (; c <= depth - 8; c += 8) {
+        uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
+        int16x8_t input_diff_s16 =
+            vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
+        int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
+        int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
+        int32x4_t mask_0 =
+            gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32);
+        int32x4_t mask_1 =
+            gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32);
+        FixedPointScaledDiffInt32x4 scaled_diff_0 =
+            input_beta_multiplier_f0 *
+            FixedPointScaledDiffInt32x4::FromRaw(
+                gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
+        FixedPointScaledDiffInt32x4 scaled_diff_1 =
+            input_beta_multiplier_f0 *
+            FixedPointScaledDiffInt32x4::FromRaw(
+                gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
+        FixedPointAccumInt32x4 exps_0 =
+            gemmlowp::Rescale<kAccumulationIntegerBits>(
+                exp_on_negative_values(scaled_diff_0));
+        FixedPointAccumInt32x4 exps_1 =
+            gemmlowp::Rescale<kAccumulationIntegerBits>(
+                exp_on_negative_values(scaled_diff_1));
+        FixedPointAccumInt32x4 masked_exps_0 =
+            SelectUsingMask(mask_0, exps_0, zeros);
+        FixedPointAccumInt32x4 masked_exps_1 =
+            SelectUsingMask(mask_1, exps_1, zeros);
+        sum_of_exps_0 = sum_of_exps_0 + masked_exps_0;
+        sum_of_exps_1 = sum_of_exps_1 + masked_exps_1;
+      }
+      int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw();
+      int32x2_t sum_of_exps_reduced_2 =
+          vadd_s32(vget_low_s32(sum_of_exps_reduced_4),
+                   vget_high_s32(sum_of_exps_reduced_4));
+      int32x2_t sum_of_exps_reduced_1 =
+          vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2);
+      sum_of_exps =
+          FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0));
+#endif
+      for (; c < depth; ++c) {
+        int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
+        if (input_diff >= diff_min) {
+          const int32 input_diff_rescaled =
+              MultiplyByQuantizedMultiplierGreaterThanOne(
+                  input_diff, input_beta_multiplier, input_beta_left_shift);
+          const FixedPointScaledDiff scaled_diff_f8 =
+              FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+          sum_of_exps =
+              sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
+                                exp_on_negative_values(scaled_diff_f8));
+        }
+      }
+    }
+
+    // Compute the fixed-point multiplier and shift that we need to apply to
+    // perform a division by the above-computed sum-of-exponentials.
+    int num_bits_over_unit = 0;
+    FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
+        sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));
+
+    // Compute the quotients of exponentials of differences of entries in the
+    // current row from the largest entry, over the previously-computed sum of
+    // exponentials.
+    {
+      int c = 0;
+#ifdef USE_NEON
+      int16x8_t diff_min_s16 = vdupq_n_s16(diff_min);
+      for (; c <= depth - 8; c += 8) {
+        uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
+        int16x8_t input_diff_s16 =
+            vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
+        int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
+        int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
+        uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16));
+        FixedPointScaledDiffInt32x4 scaled_diff_0 =
+            input_beta_multiplier_f0 *
+            FixedPointScaledDiffInt32x4::FromRaw(
+                gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
+        FixedPointScaledDiffInt32x4 scaled_diff_1 =
+            input_beta_multiplier_f0 *
+            FixedPointScaledDiffInt32x4::FromRaw(
+                gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
+        FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0);
+        FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1);
+        int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT(
+            vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()),
+            num_bits_over_unit + 31 - 8);
+        int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT(
+            vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()),
+            num_bits_over_unit + 31 - 8);
+        int16x8_t output_s16 =
+            vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1));
+        uint8x8_t output_u8 = vqmovun_s16(output_s16);
+        uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0));
+        vst1_u8(output_data_ptr + c, masked_output);
+      }
+#endif
+      for (; c < depth; ++c) {
+        int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
+        if (input_diff >= diff_min) {
+          const int32 input_diff_rescaled =
+              MultiplyByQuantizedMultiplierGreaterThanOne(
+                  input_diff, input_beta_multiplier, input_beta_left_shift);
+          const FixedPointScaledDiff scaled_diff_f8 =
+              FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+
+          FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
+          int32 unsat_output = gemmlowp::RoundingDivideByPOT(
+              (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
+
+          output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0);
+
+        } else {
+          output_data_ptr[c] = 0;
+        }
+      }
+    }
+  }
+}
+
 inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
                     float beta, float* output_data,
                     const RuntimeShape& output_shape) {
@@ -3493,205 +3493,64 @@ inline void Softmax(const SoftmaxParams& params,
   out_mat.array().rowwise() *= scale;
 }

-inline void Softmax(const SoftmaxParams& params,
-                    const RuntimeShape& input_shape, const uint8* input_data,
-                    const RuntimeShape& output_shape, uint8* output_data) {
-  const int32 input_beta_multiplier = params.input_multiplier;
-  const int32 input_beta_left_shift = params.input_left_shift;
-  const int diff_min = params.diff_min;
-  // The representation chosen for the input to the exp() function is Q5.26.
-  // We need to leave extra space since values that we skip might be as large as
-  // -32 before multiplying by input_beta_multiplier, and therefore as large as
-  // -16 afterwards. Note that exp(-8) is definitely not insignificant to
-  // accumulation, but exp(-16) definitely is.
-  static const int kScaledDiffIntegerBits = 5;
-  static const int kAccumulationIntegerBits = 12;
-  using FixedPointScaledDiff =
-      gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
-  using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
-  using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
-
-  gemmlowp::ScopedProfilingLabel label("Softmax/8bit");
-  const int trailing_dim = input_shape.DimensionsCount() - 1;
-  const int outer_size =
-      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
-  const int depth =
-      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
-
-  for (int b = 0; b < outer_size; ++b) {
-    const uint8* input_data_ptr = input_data + b * depth;
-    uint8* output_data_ptr = output_data + b * depth;
-
-    // Determine the largest entry in the current row
-    uint8 max_in_row = 0;
-    {
-      int c = 0;
-#ifdef USE_NEON
-      uint8x16_t max16_0 = vdupq_n_u8(0);
-      uint8x16_t max16_1 = vdupq_n_u8(0);
-      for (; c <= depth - 32; c += 32) {
-        max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0));
-        max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16));
-      }
-      uint8x16_t max16 = vmaxq_u8(max16_0, max16_1);
-      if (c <= depth - 16) {
-        max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c));
-        c += 16;
-      }
-      uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16));
-      if (c <= depth - 8) {
-        max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c));
-        c += 8;
-      }
-      uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4));
-      uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2));
-      uint8x8_t max1 = vpmax_u8(max2, max2);
-      max_in_row = vget_lane_u8(max1, 0);
-#endif
-      for (; c < depth; ++c) {
-        max_in_row = std::max(max_in_row, input_data_ptr[c]);
-      }
-    }
-
-#ifdef USE_NEON
-    using FixedPointAccumInt32x4 =
-        gemmlowp::FixedPoint<int32x4_t, kAccumulationIntegerBits>;
-    using FixedPointScaledDiffInt32x4 =
-        gemmlowp::FixedPoint<int32x4_t, kScaledDiffIntegerBits>;
-    using FixedPoint0Int32x4 = gemmlowp::FixedPoint<int32x4_t, 0>;
-    FixedPoint0Int32x4 input_beta_multiplier_f0 =
-        FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier);
-    int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row);
-#endif
-
-    // Compute the sum of exponentials of the differences of entries in the
-    // current row from the largest entry in the current row.
-    FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
-    {
-      int c = 0;
-#ifdef USE_NEON
-      int32x4_t diff_min_s32 = vdupq_n_s32(diff_min);
-      FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero();
-      FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero();
-      FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero();
-      for (; c <= depth - 8; c += 8) {
-        uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
-        int16x8_t input_diff_s16 =
-            vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
-        int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
-        int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
-        int32x4_t mask_0 =
-            gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32);
-        int32x4_t mask_1 =
-            gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32);
-        FixedPointScaledDiffInt32x4 scaled_diff_0 =
-            input_beta_multiplier_f0 *
-            FixedPointScaledDiffInt32x4::FromRaw(
-                gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
-        FixedPointScaledDiffInt32x4 scaled_diff_1 =
-            input_beta_multiplier_f0 *
-            FixedPointScaledDiffInt32x4::FromRaw(
-                gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
-        FixedPointAccumInt32x4 exps_0 =
-            gemmlowp::Rescale<kAccumulationIntegerBits>(
-                exp_on_negative_values(scaled_diff_0));
-        FixedPointAccumInt32x4 exps_1 =
-            gemmlowp::Rescale<kAccumulationIntegerBits>(
-                exp_on_negative_values(scaled_diff_1));
-        FixedPointAccumInt32x4 masked_exps_0 =
-            SelectUsingMask(mask_0, exps_0, zeros);
-        FixedPointAccumInt32x4 masked_exps_1 =
-            SelectUsingMask(mask_1, exps_1, zeros);
-        sum_of_exps_0 = sum_of_exps_0 + masked_exps_0;
-        sum_of_exps_1 = sum_of_exps_1 + masked_exps_1;
-      }
-      int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw();
-      int32x2_t sum_of_exps_reduced_2 =
-          vadd_s32(vget_low_s32(sum_of_exps_reduced_4),
-                   vget_high_s32(sum_of_exps_reduced_4));
-      int32x2_t sum_of_exps_reduced_1 =
-          vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2);
-      sum_of_exps =
-          FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0));
-#endif
-      for (; c < depth; ++c) {
-        int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
-        if (input_diff >= diff_min) {
-          const int32 input_diff_rescaled =
-              MultiplyByQuantizedMultiplierGreaterThanOne(
-                  input_diff, input_beta_multiplier, input_beta_left_shift);
-          const FixedPointScaledDiff scaled_diff_f8 =
-              FixedPointScaledDiff::FromRaw(input_diff_rescaled);
-          sum_of_exps =
-              sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
-                                exp_on_negative_values(scaled_diff_f8));
-        }
-      }
-    }
-
-    // Compute the fixed-point multiplier and shift that we need to apply to
-    // perform a division by the above-computed sum-of-exponentials.
-    int num_bits_over_unit = 0;
-    FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
-        sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));
-
-    // Compute the quotients of exponentials of differences of entries in the
-    // current row from the largest entry, over the previously-computed sum of
-    // exponentials.
-    {
-      int c = 0;
-#ifdef USE_NEON
-      int16x8_t diff_min_s16 = vdupq_n_s16(diff_min);
-      for (; c <= depth - 8; c += 8) {
-        uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
-        int16x8_t input_diff_s16 =
-            vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
-        int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
-        int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
-        uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16));
-        FixedPointScaledDiffInt32x4 scaled_diff_0 =
-            input_beta_multiplier_f0 *
-            FixedPointScaledDiffInt32x4::FromRaw(
-                gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
-        FixedPointScaledDiffInt32x4 scaled_diff_1 =
-            input_beta_multiplier_f0 *
-            FixedPointScaledDiffInt32x4::FromRaw(
-                gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
-        FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0);
-        FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1);
-        int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT(
-            vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()),
-            num_bits_over_unit + 31 - 8);
-        int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT(
-            vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()),
-            num_bits_over_unit + 31 - 8);
-        int16x8_t output_s16 =
-            vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1));
-        uint8x8_t output_u8 = vqmovun_s16(output_s16);
-        uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0));
-        vst1_u8(output_data_ptr + c, masked_output);
-      }
-#endif
-      for (; c < depth; ++c) {
-        int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
-        if (input_diff >= diff_min) {
-          const int32 input_diff_rescaled =
-              MultiplyByQuantizedMultiplierGreaterThanOne(
-                  input_diff, input_beta_multiplier, input_beta_left_shift);
-          const FixedPointScaledDiff scaled_diff_f8 =
-              FixedPointScaledDiff::FromRaw(input_diff_rescaled);
-
-          FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
-          int32 unsat_output = gemmlowp::RoundingDivideByPOT(
-              (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
-
-          output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0);
-
-        } else {
-          output_data_ptr[c] = 0;
-        }
-      }
-    }
+inline int32_t QuantizeSoftmaxOutput(int8_t* output_data, float prob_rescaled,
+                                     int32_t zero_point) {
+  const int32_t prob_rnd = static_cast<int32_t>(std::round(prob_rescaled));
+  return prob_rnd + zero_point;
+}
+
+inline int32_t QuantizeSoftmaxOutput(uint8_t* output_data, float prob_rescaled,
+                                     int32_t zero_point) {
+  return static_cast<int32_t>(prob_rescaled + 0.5);
+}
+
+inline void PopulateSoftmaxLookupTable(SoftmaxParams* data, float input_scale,
+                                       float beta) {
+  const float scale = -input_scale * beta;
+  const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
+  for (int32_t val = 0; val <= max_uint8; ++val) {
+    data->table[max_uint8 - val] = expf(scale * val);
+  }
+}
+
+template <typename T>
+inline void Softmax(const SoftmaxParams& params,
+                    const RuntimeShape& input_shape, const T* input_data,
+                    const RuntimeShape& output_shape, T* output_data) {
+  const int trailing_dim = input_shape.DimensionsCount() - 1;
+  const int excluding_last_dim =
+      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+  const int last_dim =
+      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+
+  const int32_t clamp_max = std::numeric_limits<T>::max();
+  const int32_t clamp_min = std::numeric_limits<T>::min();
+  for (int i = 0; i < excluding_last_dim; ++i) {
+    int32_t max_val = std::numeric_limits<T>::min();
+    // Find max quantized value.
+    for (int j = 0; j < last_dim; ++j) {
+      max_val = std::max(max_val, static_cast<int32_t>(input_data[j]));
+    }
+
+    float sum_exp = 0.0f;
+    const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
+    const float* table_offset = &params.table[max_uint8 - max_val];
+    // Calculate normalizer sum(exp(x)).
+    for (int j = 0; j < last_dim; ++j) {
+      sum_exp += table_offset[input_data[j]];
+    }
+
+    const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
+    // Normalize and quantize probabilities.
+    for (int j = 0; j < last_dim; ++j) {
+      const float prob_rescaled = table_offset[input_data[j]] * inv_sum_exp;
+      const int32_t prob_quantized =
+          QuantizeSoftmaxOutput(output_data, prob_rescaled, params.zero_point);
+      output_data[j] = static_cast<T>(
+          std::max(std::min(clamp_max, prob_quantized), clamp_min));
+    }
+    input_data += last_dim;
+    output_data += last_dim;
   }
 }

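
A quick check of the table indexing the new kernel above relies on, with s = input_scale * beta as set up in PopulateSoftmaxLookupTable:

    table[255 - v]  = exp(-s * v)                  for v = 0..255
    table_offset    = &params.table[255 - max_val]
    table_offset[x] = table[255 - max_val + x]
                    = exp(-s * (max_val - x))
                    = exp(s * (x - max_val))

so summing table_offset[input_data[j]] over a row accumulates exp(beta * input_scale * (x - max_val)), the numerator of a numerically stable softmax with the per-row maximum already subtracted. The extra division by params.scale inside inv_sum_exp then converts the normalized probability into output quantized units before QuantizeSoftmaxOutput adds the zero point.
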
@@ -124,9 +124,14 @@ void RunOneSoftmaxTest(const uint8* input_data,
                                    input_beta_left_shift);

   SoftmaxParams params;
+  float table[256];
   params.input_multiplier = input_beta_multiplier;
   params.input_left_shift = input_beta_left_shift;
   params.diff_min = diff_min;
+  params.scale = 1.0f / 256;
+  params.zero_point = 0;
+  params.table = table;
+  optimized_ops::PopulateSoftmaxLookupTable(&params, input_scale, beta);
   optimized_ops::Softmax(params, shape_common, input_data, shape_common,
                          optimized_softmax_output.data());
   reference_ops::Softmax(params, shape_common, input_data, shape_common,
@@ -137,7 +142,7 @@ void RunOneSoftmaxTest(const uint8* input_data,
                            "Optimized vs float reference", false);
   CheckOutputData<uint8_t>(optimized_softmax_output.data(),
                            reference_quant_softmax_output.data(), shape_common,
-                           "Optimized vs quant reference", true);
+                           "Optimized vs quant reference", false);
   CheckOutputData<uint8_t>(reference_quant_softmax_output.data(),
                            reference_float_softmax_output.data(), shape_common,
                            "Quant reference vs float reference", false);
@@ -16,6 +16,7 @@ limitations under the License.
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_

 #include <algorithm>
+#include <cstdint>
 #include <cstring>
 #include <initializer_list>

@@ -985,6 +986,9 @@ struct SoftmaxParams {
   int32 reverse_scaling_divisor;
   int32 reverse_scaling_right_shift;
   int diff_min;
+  int32_t zero_point;
+  float scale;
+  float* table;
 };

 struct SpaceToBatchParams {
|
Loading…
x
Reference in New Issue
Block a user