Optimize 8bit Softmax op handler
PiperOrigin-RevId: 261410252
parent 9c14f6ba30
commit 96b38db29c
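The change replaces the fixed-point exp() arithmetic in the 8-bit Softmax kernels with a 256-entry float lookup table: the table is filled once in Prepare (via PopulateSoftmaxLookupTable) and Eval only looks up exp values and normalizes them. Below is a minimal, self-contained sketch of that table-based approach in plain C++; BuildSoftmaxTable and SoftmaxRowUint8 are illustrative stand-ins, not the TfLite functions added by this commit.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Build a table of exp(-input_scale * beta * d) for every possible difference
// d = max - x in [0, 255]; table[255 - d] corresponds to the quantized value x
// when the row maximum maps to index 255.
void BuildSoftmaxTable(float input_scale, float beta, float table[256]) {
  const float scale = -input_scale * beta;
  for (int d = 0; d <= 255; ++d) {
    table[255 - d] = std::exp(scale * d);
  }
}

// Table-based softmax over one row of uint8 values: look up exp() for each
// entry relative to the row maximum, then normalize into [0, 255].
void SoftmaxRowUint8(const float table[256], const uint8_t* input, int depth,
                     float output_scale, uint8_t* output) {
  int max_val = 0;
  for (int j = 0; j < depth; ++j) {
    max_val = std::max(max_val, static_cast<int>(input[j]));
  }
  const float* offset = &table[255 - max_val];
  float sum = 0.f;
  for (int j = 0; j < depth; ++j) sum += offset[input[j]];
  const float inv_sum = 1.f / (sum * output_scale);
  for (int j = 0; j < depth; ++j) {
    const int q = static_cast<int>(offset[input[j]] * inv_sum + 0.5f);
    output[j] = static_cast<uint8_t>(std::min(255, std::max(0, q)));
  }
}

With input_scale and beta taken from the op's quantization parameters and output_scale = 1/256, this mirrors the normalization performed by the new templated Softmax in this commit.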
@@ -55,6 +55,11 @@ struct OpData {
  uint8_t* table_zero = nullptr;
};

struct SoftmaxOpData {
  struct SoftmaxParams params = {};
  float table[256];
};

struct LogSoftmaxOpData : public OpData {
  int32_t reverse_scaling_divisor = 0;
  int32_t reverse_scaling_right_shift = 0;
@@ -131,6 +136,14 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  return new OpData;
}

void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
  return new SoftmaxOpData;
}

void SoftmaxFree(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<SoftmaxOpData*>(buffer);
}

void* LogSoftmaxInit(TfLiteContext* context, const char* buffer,
                     size_t length) {
  return new LogSoftmaxOpData;
@@ -363,7 +376,7 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {

TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);
  SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -375,16 +388,11 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE(context, num_dims >= 1 && num_dims <= 4);

  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
    if (CheckOutputQuantParams(context, input, output) == kTfLiteError) {
      return kTfLiteError;
    }

    static const int kScaledDiffIntegerBits = 5;
    tflite::PreprocessSoftmaxScaling(
        params->beta, input->params.scale, kScaledDiffIntegerBits,
        &data->input_multiplier, &data->input_left_shift);
    data->diff_min = -1.0 * tflite::CalculateInputRadius(
                                kScaledDiffIntegerBits, data->input_left_shift);
    data->params.table = data->table;
    optimized_ops::PopulateSoftmaxLookupTable(
        &data->params, input->params.scale, params->beta);
    data->params.zero_point = output->params.zero_point;
    data->params.scale = output->params.scale;
  }

  return context->ResizeTensor(context, output,
@@ -749,61 +757,25 @@ TfLiteStatus SoftmaxFloat(TfLiteContext* context, const TfLiteTensor* input,
  }
}

TfLiteStatus SoftmaxQuantizedUint8(TfLiteContext* context,
                                   const TfLiteTensor* input,
                                   TfLiteTensor* output,
                                   TfLiteSoftmaxParams* params, OpData* data) {
  switch (NumDimensions(input)) {
    case 1:
    case 2:
    case 3:
    case 4:
      SoftmaxParams op_params;
      op_params.input_multiplier = data->input_multiplier;
      op_params.input_left_shift = data->input_left_shift;
      op_params.diff_min = data->diff_min;
      optimized_ops::Softmax(
          op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
          GetTensorShape(output), GetTensorData<uint8_t>(output));
      return kTfLiteOk;
    default:
      context->ReportError(
          context,
          "Only 1D, 2D, 3D and 4D tensors supported currently, got %dD.",
          NumDimensions(input));
      return kTfLiteError;
  }
}

TfLiteStatus SoftmaxQuantizedInt8(TfLiteContext* context,
                                  const TfLiteTensor* input,
                                  TfLiteTensor* output,
                                  TfLiteSoftmaxParams* params, OpData* data) {
  switch (NumDimensions(input)) {
    case 1:
    case 2:
    case 3:
    case 4:
      SoftmaxParams op_params;
      op_params.input_multiplier = data->input_multiplier;
      op_params.input_left_shift = data->input_left_shift;
      op_params.diff_min = data->diff_min;
      optimized_integer_ops::Softmax(
          op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
          GetTensorShape(output), GetTensorData<int8_t>(output));
      return kTfLiteOk;
    default:
      context->ReportError(
          context,
          "Only 1D, 2D, 3D and 4D tensors supported currently, got %dD.",
          NumDimensions(input));
      return kTfLiteError;
template <typename T>
TfLiteStatus SoftmaxQuantized(TfLiteContext* context, const TfLiteTensor* input,
                              TfLiteTensor* output, SoftmaxOpData* data) {
  if (NumDimensions(input) >= 1 && NumDimensions(input) <= 4) {
    optimized_ops::Softmax(data->params, GetTensorShape(input),
                           GetTensorData<T>(input), GetTensorShape(output),
                           GetTensorData<T>(output));
    return kTfLiteOk;
  } else {
    context->ReportError(
        context, "Only 1D, 2D, 3D and 4D tensors supported currently, got %dD.",
        NumDimensions(input));
    return kTfLiteError;
  }
}

TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);
  SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);

  const TfLiteTensor* input = GetInput(context, node, 0);
  TfLiteTensor* output = GetOutput(context, node, 0);
@@ -815,10 +787,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
      return SoftmaxFloat(context, input, output, params);
    }
    case kTfLiteUInt8: {
      return SoftmaxQuantizedUint8(context, input, output, params, data);
      return SoftmaxQuantized<uint8_t>(context, input, output, data);
    }
    case kTfLiteInt8: {
      return SoftmaxQuantizedInt8(context, input, output, params, data);
      return SoftmaxQuantized<int8_t>(context, input, output, data);
    }

    default:
@@ -1055,9 +1027,9 @@ TfLiteRegistration* Register_LOGISTIC() {
}

TfLiteRegistration* Register_SOFTMAX() {
  static TfLiteRegistration r = {activations::Init, activations::Free,
                                 activations::SoftmaxPrepare,
                                 activations::SoftmaxEval};
  static TfLiteRegistration r = {
      activations::SoftmaxInit, activations::SoftmaxFree,
      activations::SoftmaxPrepare, activations::SoftmaxEval};
  return &r;
}
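Because the softmax node now stores its state in a SoftmaxOpData rather than the shared OpData, the registration above swaps in the dedicated SoftmaxInit/SoftmaxFree pair so the buffer is deleted through the type it was allocated as. A minimal sketch of that pairing, with hypothetical names and no TfLite types, assuming only standard C++:

#include <cstddef>

// Hypothetical per-op state; stands in for SoftmaxOpData.
struct TableState {
  float table[256];
};

// Allocation and deallocation must agree on the concrete type: deleting this
// buffer through a pointer to an unrelated struct would be undefined behavior.
void* TableInit(const char* /*buffer*/, std::size_t /*length*/) {
  return new TableState;
}

void TableFree(void* buffer) {
  delete static_cast<TableState*>(buffer);
}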
@@ -3972,6 +3972,208 @@ void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
         filter_width, filter_height, output_data, output_dims);
}

inline void Softmax(const SoftmaxParams& params,
                    const RuntimeShape& input_shape, const uint8* input_data,
                    const RuntimeShape& output_shape, uint8* output_data) {
  const int32 input_beta_multiplier = params.input_multiplier;
  const int32 input_beta_left_shift = params.input_left_shift;
  const int diff_min = params.diff_min;
  // The representation chosen for the input to the exp() function is Q5.26.
  // We need to leave extra space since values that we skip might be as large as
  // -32 before multiplying by input_beta_multiplier, and therefore as large as
  // -16 afterwards. Note that exp(-8) is definitely not insignificant to
  // accumulation, but exp(-16) definitely is.
  static const int kScaledDiffIntegerBits = 5;
  static const int kAccumulationIntegerBits = 12;
  using FixedPointScaledDiff =
      gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
  using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
  using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;

  gemmlowp::ScopedProfilingLabel label("Softmax/8bit");
  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size =
      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth =
      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  for (int b = 0; b < outer_size; ++b) {
    const uint8* input_data_ptr = input_data + b * depth;
    uint8* output_data_ptr = output_data + b * depth;

    // Determine the largest entry in the current row
    uint8 max_in_row = 0;
    {
      int c = 0;
#ifdef USE_NEON
      uint8x16_t max16_0 = vdupq_n_u8(0);
      uint8x16_t max16_1 = vdupq_n_u8(0);
      for (; c <= depth - 32; c += 32) {
        max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0));
        max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16));
      }
      uint8x16_t max16 = vmaxq_u8(max16_0, max16_1);
      if (c <= depth - 16) {
        max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c));
        c += 16;
      }
      uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16));
      if (c <= depth - 8) {
        max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c));
        c += 8;
      }
      uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4));
      uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2));
      uint8x8_t max1 = vpmax_u8(max2, max2);
      max_in_row = vget_lane_u8(max1, 0);
#endif
      for (; c < depth; ++c) {
        max_in_row = std::max(max_in_row, input_data_ptr[c]);
      }
    }

#ifdef USE_NEON
    using FixedPointAccumInt32x4 =
        gemmlowp::FixedPoint<int32x4_t, kAccumulationIntegerBits>;
    using FixedPointScaledDiffInt32x4 =
        gemmlowp::FixedPoint<int32x4_t, kScaledDiffIntegerBits>;
    using FixedPoint0Int32x4 = gemmlowp::FixedPoint<int32x4_t, 0>;
    FixedPoint0Int32x4 input_beta_multiplier_f0 =
        FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier);
    int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row);
#endif

    // Compute the sum of exponentials of the differences of entries in the
    // current row from the largest entry in the current row.
    FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
    {
      int c = 0;
#ifdef USE_NEON
      int32x4_t diff_min_s32 = vdupq_n_s32(diff_min);
      FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero();
      FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero();
      FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero();
      for (; c <= depth - 8; c += 8) {
        uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
        int16x8_t input_diff_s16 =
            vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
        int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
        int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
        int32x4_t mask_0 =
            gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32);
        int32x4_t mask_1 =
            gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32);
        FixedPointScaledDiffInt32x4 scaled_diff_0 =
            input_beta_multiplier_f0 *
            FixedPointScaledDiffInt32x4::FromRaw(
                gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
        FixedPointScaledDiffInt32x4 scaled_diff_1 =
            input_beta_multiplier_f0 *
            FixedPointScaledDiffInt32x4::FromRaw(
                gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
        FixedPointAccumInt32x4 exps_0 =
            gemmlowp::Rescale<kAccumulationIntegerBits>(
                exp_on_negative_values(scaled_diff_0));
        FixedPointAccumInt32x4 exps_1 =
            gemmlowp::Rescale<kAccumulationIntegerBits>(
                exp_on_negative_values(scaled_diff_1));
        FixedPointAccumInt32x4 masked_exps_0 =
            SelectUsingMask(mask_0, exps_0, zeros);
        FixedPointAccumInt32x4 masked_exps_1 =
            SelectUsingMask(mask_1, exps_1, zeros);
        sum_of_exps_0 = sum_of_exps_0 + masked_exps_0;
        sum_of_exps_1 = sum_of_exps_1 + masked_exps_1;
      }
      int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw();
      int32x2_t sum_of_exps_reduced_2 =
          vadd_s32(vget_low_s32(sum_of_exps_reduced_4),
                   vget_high_s32(sum_of_exps_reduced_4));
      int32x2_t sum_of_exps_reduced_1 =
          vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2);
      sum_of_exps =
          FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0));
#endif
      for (; c < depth; ++c) {
        int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
        if (input_diff >= diff_min) {
          const int32 input_diff_rescaled =
              MultiplyByQuantizedMultiplierGreaterThanOne(
                  input_diff, input_beta_multiplier, input_beta_left_shift);
          const FixedPointScaledDiff scaled_diff_f8 =
              FixedPointScaledDiff::FromRaw(input_diff_rescaled);
          sum_of_exps =
              sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
                                exp_on_negative_values(scaled_diff_f8));
        }
      }
    }

    // Compute the fixed-point multiplier and shift that we need to apply to
    // perform a division by the above-computed sum-of-exponentials.
    int num_bits_over_unit = 0;
    FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
        sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));

    // Compute the quotients of exponentials of differences of entries in the
    // current row from the largest entry, over the previously-computed sum of
    // exponentials.
    {
      int c = 0;
#ifdef USE_NEON
      int16x8_t diff_min_s16 = vdupq_n_s16(diff_min);
      for (; c <= depth - 8; c += 8) {
        uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
        int16x8_t input_diff_s16 =
            vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
        int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
        int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
        uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16));
        FixedPointScaledDiffInt32x4 scaled_diff_0 =
            input_beta_multiplier_f0 *
            FixedPointScaledDiffInt32x4::FromRaw(
                gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
        FixedPointScaledDiffInt32x4 scaled_diff_1 =
            input_beta_multiplier_f0 *
            FixedPointScaledDiffInt32x4::FromRaw(
                gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
        FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0);
        FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1);
        int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT(
            vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()),
            num_bits_over_unit + 31 - 8);
        int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT(
            vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()),
            num_bits_over_unit + 31 - 8);
        int16x8_t output_s16 =
            vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1));
        uint8x8_t output_u8 = vqmovun_s16(output_s16);
        uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0));
        vst1_u8(output_data_ptr + c, masked_output);
      }
#endif
      for (; c < depth; ++c) {
        int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
        if (input_diff >= diff_min) {
          const int32 input_diff_rescaled =
              MultiplyByQuantizedMultiplierGreaterThanOne(
                  input_diff, input_beta_multiplier, input_beta_left_shift);
          const FixedPointScaledDiff scaled_diff_f8 =
              FixedPointScaledDiff::FromRaw(input_diff_rescaled);

          FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
          int32 unsat_output = gemmlowp::RoundingDivideByPOT(
              (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);

          output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0);

        } else {
          output_data_ptr[c] = 0;
        }
      }
    }
  }
}

inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
                    float beta, float* output_data,
                    const RuntimeShape& output_shape) {
@@ -3493,205 +3493,64 @@ inline void Softmax(const SoftmaxParams& params,
  out_mat.array().rowwise() *= scale;
}

inline void Softmax(const SoftmaxParams& params,
                    const RuntimeShape& input_shape, const uint8* input_data,
                    const RuntimeShape& output_shape, uint8* output_data) {
  const int32 input_beta_multiplier = params.input_multiplier;
  const int32 input_beta_left_shift = params.input_left_shift;
  const int diff_min = params.diff_min;
  // The representation chosen for the input to the exp() function is Q5.26.
  // We need to leave extra space since values that we skip might be as large as
  // -32 before multiplying by input_beta_multiplier, and therefore as large as
  // -16 afterwards. Note that exp(-8) is definitely not insignificant to
  // accumulation, but exp(-16) definitely is.
  static const int kScaledDiffIntegerBits = 5;
  static const int kAccumulationIntegerBits = 12;
  using FixedPointScaledDiff =
      gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
  using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
  using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
inline int32_t QuantizeSoftmaxOutput(int8_t* output_data, float prob_rescaled,
                                     int32_t zero_point) {
  const int32_t prob_rnd = static_cast<int32_t>(std::round(prob_rescaled));
  return prob_rnd + zero_point;
}

  gemmlowp::ScopedProfilingLabel label("Softmax/8bit");
inline int32_t QuantizeSoftmaxOutput(uint8_t* output_data, float prob_rescaled,
                                     int32_t zero_point) {
  return static_cast<int32_t>(prob_rescaled + 0.5);
}

inline void PopulateSoftmaxLookupTable(SoftmaxParams* data, float input_scale,
                                       float beta) {
  const float scale = -input_scale * beta;
  const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
  for (int32_t val = 0; val <= max_uint8; ++val) {
    data->table[max_uint8 - val] = expf(scale * val);
  }
}

template <typename T>
inline void Softmax(const SoftmaxParams& params,
                    const RuntimeShape& input_shape, const T* input_data,
                    const RuntimeShape& output_shape, T* output_data) {
  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size =
  const int excluding_last_dim =
      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth =
  const int last_dim =
      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  for (int b = 0; b < outer_size; ++b) {
    const uint8* input_data_ptr = input_data + b * depth;
    uint8* output_data_ptr = output_data + b * depth;

    // Determine the largest entry in the current row
    uint8 max_in_row = 0;
    {
      int c = 0;
#ifdef USE_NEON
      uint8x16_t max16_0 = vdupq_n_u8(0);
      uint8x16_t max16_1 = vdupq_n_u8(0);
      for (; c <= depth - 32; c += 32) {
        max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0));
        max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16));
      }
      uint8x16_t max16 = vmaxq_u8(max16_0, max16_1);
      if (c <= depth - 16) {
        max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c));
        c += 16;
      }
      uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16));
      if (c <= depth - 8) {
        max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c));
        c += 8;
      }
      uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4));
      uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2));
      uint8x8_t max1 = vpmax_u8(max2, max2);
      max_in_row = vget_lane_u8(max1, 0);
#endif
      for (; c < depth; ++c) {
        max_in_row = std::max(max_in_row, input_data_ptr[c]);
      }
  const int32_t clamp_max = std::numeric_limits<T>::max();
  const int32_t clamp_min = std::numeric_limits<T>::min();
  for (int i = 0; i < excluding_last_dim; ++i) {
    int32_t max_val = std::numeric_limits<T>::min();
    // Find max quantized value.
    for (int j = 0; j < last_dim; ++j) {
      max_val = std::max(max_val, static_cast<int32_t>(input_data[j]));
    }

#ifdef USE_NEON
      using FixedPointAccumInt32x4 =
          gemmlowp::FixedPoint<int32x4_t, kAccumulationIntegerBits>;
      using FixedPointScaledDiffInt32x4 =
          gemmlowp::FixedPoint<int32x4_t, kScaledDiffIntegerBits>;
      using FixedPoint0Int32x4 = gemmlowp::FixedPoint<int32x4_t, 0>;
      FixedPoint0Int32x4 input_beta_multiplier_f0 =
          FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier);
      int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row);
#endif

    // Compute the sum of exponentials of the differences of entries in the
    // current row from the largest entry in the current row.
    FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
    {
      int c = 0;
#ifdef USE_NEON
      int32x4_t diff_min_s32 = vdupq_n_s32(diff_min);
      FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero();
      FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero();
      FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero();
      for (; c <= depth - 8; c += 8) {
        uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
        int16x8_t input_diff_s16 =
            vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
        int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
        int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
        int32x4_t mask_0 =
            gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32);
        int32x4_t mask_1 =
            gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32);
        FixedPointScaledDiffInt32x4 scaled_diff_0 =
            input_beta_multiplier_f0 *
            FixedPointScaledDiffInt32x4::FromRaw(
                gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
        FixedPointScaledDiffInt32x4 scaled_diff_1 =
            input_beta_multiplier_f0 *
            FixedPointScaledDiffInt32x4::FromRaw(
                gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
        FixedPointAccumInt32x4 exps_0 =
            gemmlowp::Rescale<kAccumulationIntegerBits>(
                exp_on_negative_values(scaled_diff_0));
        FixedPointAccumInt32x4 exps_1 =
            gemmlowp::Rescale<kAccumulationIntegerBits>(
                exp_on_negative_values(scaled_diff_1));
        FixedPointAccumInt32x4 masked_exps_0 =
            SelectUsingMask(mask_0, exps_0, zeros);
        FixedPointAccumInt32x4 masked_exps_1 =
            SelectUsingMask(mask_1, exps_1, zeros);
        sum_of_exps_0 = sum_of_exps_0 + masked_exps_0;
        sum_of_exps_1 = sum_of_exps_1 + masked_exps_1;
      }
      int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw();
      int32x2_t sum_of_exps_reduced_2 =
          vadd_s32(vget_low_s32(sum_of_exps_reduced_4),
                   vget_high_s32(sum_of_exps_reduced_4));
      int32x2_t sum_of_exps_reduced_1 =
          vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2);
      sum_of_exps =
          FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0));
#endif
      for (; c < depth; ++c) {
        int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
        if (input_diff >= diff_min) {
          const int32 input_diff_rescaled =
              MultiplyByQuantizedMultiplierGreaterThanOne(
                  input_diff, input_beta_multiplier, input_beta_left_shift);
          const FixedPointScaledDiff scaled_diff_f8 =
              FixedPointScaledDiff::FromRaw(input_diff_rescaled);
          sum_of_exps =
              sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
                                exp_on_negative_values(scaled_diff_f8));
        }
      }

    float sum_exp = 0.0f;
    const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
    const float* table_offset = &params.table[max_uint8 - max_val];
    // Calculate normalizer sum(exp(x)).
    for (int j = 0; j < last_dim; ++j) {
      sum_exp += table_offset[input_data[j]];
    }

    // Compute the fixed-point multiplier and shift that we need to apply to
    // perform a division by the above-computed sum-of-exponentials.
    int num_bits_over_unit = 0;
    FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
        sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));

    // Compute the quotients of exponentials of differences of entries in the
    // current row from the largest entry, over the previously-computed sum of
    // exponentials.
    {
      int c = 0;
#ifdef USE_NEON
      int16x8_t diff_min_s16 = vdupq_n_s16(diff_min);
      for (; c <= depth - 8; c += 8) {
        uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
        int16x8_t input_diff_s16 =
            vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
        int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
        int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
        uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16));
        FixedPointScaledDiffInt32x4 scaled_diff_0 =
            input_beta_multiplier_f0 *
            FixedPointScaledDiffInt32x4::FromRaw(
                gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
        FixedPointScaledDiffInt32x4 scaled_diff_1 =
            input_beta_multiplier_f0 *
            FixedPointScaledDiffInt32x4::FromRaw(
                gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
        FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0);
        FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1);
        int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT(
            vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()),
            num_bits_over_unit + 31 - 8);
        int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT(
            vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()),
            num_bits_over_unit + 31 - 8);
        int16x8_t output_s16 =
            vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1));
        uint8x8_t output_u8 = vqmovun_s16(output_s16);
        uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0));
        vst1_u8(output_data_ptr + c, masked_output);
      }
#endif
      for (; c < depth; ++c) {
        int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
        if (input_diff >= diff_min) {
          const int32 input_diff_rescaled =
              MultiplyByQuantizedMultiplierGreaterThanOne(
                  input_diff, input_beta_multiplier, input_beta_left_shift);
          const FixedPointScaledDiff scaled_diff_f8 =
              FixedPointScaledDiff::FromRaw(input_diff_rescaled);

          FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
          int32 unsat_output = gemmlowp::RoundingDivideByPOT(
              (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);

          output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0);

        } else {
          output_data_ptr[c] = 0;
        }
      }
    const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
    // Normalize and quantize probabilities.
    for (int j = 0; j < last_dim; ++j) {
      const float prob_rescaled = table_offset[input_data[j]] * inv_sum_exp;
      const int32_t prob_quantized =
          QuantizeSoftmaxOutput(output_data, prob_rescaled, params.zero_point);
      output_data[j] = static_cast<T>(
          std::max(std::min(clamp_max, prob_quantized), clamp_min));
    }
    input_data += last_dim;
    output_data += last_dim;
  }
}
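The two QuantizeSoftmaxOutput overloads introduced in this hunk differ only in how the rescaled probability is mapped onto the output type: the int8_t overload rounds and then adds the output zero point, while the uint8_t overload assumes a zero point of 0. The sketch below folds both into one hedged template; QuantizeProb is an illustrative name, not part of the commit. With the usual output scale of 1/256, a probability of 1.0 rescales to 256 and is clamped to 255 for uint8 (zero point 0) or to 127 for int8 (zero point -128).

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Map prob / output_scale onto the quantized output type, shifting by the
// output zero point and clamping to the representable range.
template <typename T>
T QuantizeProb(float prob_rescaled, int32_t zero_point) {
  const int32_t rounded = static_cast<int32_t>(std::round(prob_rescaled));
  const int32_t shifted = rounded + zero_point;
  const int32_t lo = std::numeric_limits<T>::min();
  const int32_t hi = std::numeric_limits<T>::max();
  return static_cast<T>(std::min(hi, std::max(lo, shifted)));
}

// Example: with output scale 1/256, prob 1.0 -> 256 -> 255 for uint8 (zp 0),
// and 256 + (-128) = 128 -> 127 for int8 (zp -128).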
@@ -124,9 +124,14 @@ void RunOneSoftmaxTest(const uint8* input_data,
                                   input_beta_left_shift);

  SoftmaxParams params;
  float table[256];
  params.input_multiplier = input_beta_multiplier;
  params.input_left_shift = input_beta_left_shift;
  params.diff_min = diff_min;
  params.scale = 1.0f / 256;
  params.zero_point = 0;
  params.table = table;
  optimized_ops::PopulateSoftmaxLookupTable(&params, input_scale, beta);
  optimized_ops::Softmax(params, shape_common, input_data, shape_common,
                         optimized_softmax_output.data());
  reference_ops::Softmax(params, shape_common, input_data, shape_common,
@@ -137,7 +142,7 @@ void RunOneSoftmaxTest(const uint8* input_data,
                           "Optimized vs float reference", false);
  CheckOutputData<uint8_t>(optimized_softmax_output.data(),
                           reference_quant_softmax_output.data(), shape_common,
                           "Optimized vs quant reference", true);
                           "Optimized vs quant reference", false);
  CheckOutputData<uint8_t>(reference_quant_softmax_output.data(),
                           reference_float_softmax_output.data(), shape_common,
                           "Quant reference vs float reference", false);
@@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <initializer_list>

@@ -985,6 +986,9 @@ struct SoftmaxParams {
  int32 reverse_scaling_divisor;
  int32 reverse_scaling_right_shift;
  int diff_min;
  int32_t zero_point;
  float scale;
  float* table;
};

struct SpaceToBatchParams {