Optimize 8bit Softmax op handler

PiperOrigin-RevId: 261410252
A. Unique TensorFlower 2019-08-02 16:06:17 -07:00 committed by TensorFlower Gardener
parent 9c14f6ba30
commit 96b38db29c
5 changed files with 298 additions and 256 deletions
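The change swaps the fixed-point exp/reciprocal softmax for 8-bit inputs for a 256-entry float lookup table that is built once per node in Prepare and reused on every Eval; the old NEON/fixed-point kernel is kept as a legacy copy. A minimal standalone sketch of the table approach, with illustrative helper names and values (this is not the TFLite code itself):

// Minimal sketch of the lookup-table softmax (illustrative names, not TFLite code).
#include <algorithm>
#include <cmath>
#include <cstdint>

// table[255 - v] == exp(-input_scale * beta * v), so that
// table[255 - max_val + x] == exp(input_scale * beta * (x - max_val)).
void PopulateTableSketch(float table[256], float input_scale, float beta) {
  const float scale = -input_scale * beta;
  for (int v = 0; v <= 255; ++v) table[255 - v] = std::exp(scale * v);
}

// One softmax row over uint8 data: exponentials come from the table,
// normalization and requantization happen in float.
void SoftmaxRowSketch(const uint8_t* in, uint8_t* out, int depth,
                      const float table[256], float output_scale,
                      int32_t output_zero_point) {  // the uint8 path in the diff uses zero_point 0
  int32_t max_val = 0;
  for (int j = 0; j < depth; ++j) max_val = std::max<int32_t>(max_val, in[j]);
  const float* table_offset = &table[255 - max_val];
  float sum_exp = 0.0f;
  for (int j = 0; j < depth; ++j) sum_exp += table_offset[in[j]];
  // Dividing by (sum_exp * output_scale) normalizes and rescales to the
  // quantized output domain in one multiply.
  const float inv_sum = 1.0f / (sum_exp * output_scale);
  for (int j = 0; j < depth; ++j) {
    const int32_t q =
        static_cast<int32_t>(table_offset[in[j]] * inv_sum + 0.5f) + output_zero_point;
    out[j] = static_cast<uint8_t>(std::min<int32_t>(255, std::max<int32_t>(0, q)));
  }
}

In the real kernel the table lives in the new SoftmaxOpData and is passed through SoftmaxParams, as shown in the hunks below.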

@@ -55,6 +55,11 @@ struct OpData {
uint8_t* table_zero = nullptr;
};
struct SoftmaxOpData {
struct SoftmaxParams params = {};
float table[256];
};
struct LogSoftmaxOpData : public OpData {
int32_t reverse_scaling_divisor = 0;
int32_t reverse_scaling_right_shift = 0;
@@ -131,6 +136,14 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return new OpData;
}
void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
return new SoftmaxOpData;
}
void SoftmaxFree(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<SoftmaxOpData*>(buffer);
}
void* LogSoftmaxInit(TfLiteContext* context, const char* buffer,
size_t length) {
return new LogSoftmaxOpData;
@@ -363,7 +376,7 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -375,16 +388,11 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, num_dims >= 1 && num_dims <= 4);
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
if (CheckOutputQuantParams(context, input, output) == kTfLiteError) {
return kTfLiteError;
}
static const int kScaledDiffIntegerBits = 5;
tflite::PreprocessSoftmaxScaling(
params->beta, input->params.scale, kScaledDiffIntegerBits,
&data->input_multiplier, &data->input_left_shift);
data->diff_min = -1.0 * tflite::CalculateInputRadius(
kScaledDiffIntegerBits, data->input_left_shift);
data->params.table = data->table;
optimized_ops::PopulateSoftmaxLookupTable(
&data->params, input->params.scale, params->beta);
data->params.zero_point = output->params.zero_point;
data->params.scale = output->params.scale;
}
return context->ResizeTensor(context, output,
@@ -749,61 +757,25 @@ TfLiteStatus SoftmaxFloat(TfLiteContext* context, const TfLiteTensor* input,
}
}
TfLiteStatus SoftmaxQuantizedUint8(TfLiteContext* context,
const TfLiteTensor* input,
TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
switch (NumDimensions(input)) {
case 1:
case 2:
case 3:
case 4:
SoftmaxParams op_params;
op_params.input_multiplier = data->input_multiplier;
op_params.input_left_shift = data->input_left_shift;
op_params.diff_min = data->diff_min;
optimized_ops::Softmax(
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(output), GetTensorData<uint8_t>(output));
return kTfLiteOk;
default:
context->ReportError(
context,
"Only 1D, 2D, 3D and 4D tensors supported currently, got %dD.",
NumDimensions(input));
return kTfLiteError;
}
}
TfLiteStatus SoftmaxQuantizedInt8(TfLiteContext* context,
const TfLiteTensor* input,
TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
switch (NumDimensions(input)) {
case 1:
case 2:
case 3:
case 4:
SoftmaxParams op_params;
op_params.input_multiplier = data->input_multiplier;
op_params.input_left_shift = data->input_left_shift;
op_params.diff_min = data->diff_min;
optimized_integer_ops::Softmax(
op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
GetTensorShape(output), GetTensorData<int8_t>(output));
return kTfLiteOk;
default:
context->ReportError(
context,
"Only 1D, 2D, 3D and 4D tensors supported currently, got %dD.",
NumDimensions(input));
return kTfLiteError;
template <typename T>
TfLiteStatus SoftmaxQuantized(TfLiteContext* context, const TfLiteTensor* input,
TfLiteTensor* output, SoftmaxOpData* data) {
if (NumDimensions(input) >= 1 && NumDimensions(input) <= 4) {
optimized_ops::Softmax(data->params, GetTensorShape(input),
GetTensorData<T>(input), GetTensorShape(output),
GetTensorData<T>(output));
return kTfLiteOk;
} else {
context->ReportError(
context, "Only 1D, 2D, 3D and 4D tensors supported currently, got %dD.",
NumDimensions(input));
return kTfLiteError;
}
}
TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
@@ -815,10 +787,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
return SoftmaxFloat(context, input, output, params);
}
case kTfLiteUInt8: {
return SoftmaxQuantizedUint8(context, input, output, params, data);
return SoftmaxQuantized<uint8_t>(context, input, output, data);
}
case kTfLiteInt8: {
return SoftmaxQuantizedInt8(context, input, output, params, data);
return SoftmaxQuantized<int8_t>(context, input, output, data);
}
default:
@@ -1055,9 +1027,9 @@ TfLiteRegistration* Register_LOGISTIC() {
}
TfLiteRegistration* Register_SOFTMAX() {
static TfLiteRegistration r = {activations::Init, activations::Free,
activations::SoftmaxPrepare,
activations::SoftmaxEval};
static TfLiteRegistration r = {
activations::SoftmaxInit, activations::SoftmaxFree,
activations::SoftmaxPrepare, activations::SoftmaxEval};
return &r;
}
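Because the table is per-node state, SOFTMAX now registers its own Init/Free pair instead of the shared OpData ones: each node owns a SoftmaxOpData (256 floats, roughly 1 KB) that Init allocates, Prepare fills, Eval reads, and Free deletes. A standalone sketch of that ownership pattern with placeholder names (the real signatures take TfLiteContext/buffer arguments, as in the hunks above):

// Ownership sketch only; names are placeholders, not the TFLite API.
#include <cmath>
#include <cstdint>

struct SoftmaxOpDataSketch {
  float table[256];             // ~1 KB per softmax node, filled once in prepare
  float output_scale = 1.0f / 256;
  int32_t output_zero_point = 0;
};

void* InitSketch() { return new SoftmaxOpDataSketch; }  // once per node

void PrepareSketch(void* user_data, float input_scale, float beta) {
  auto* data = static_cast<SoftmaxOpDataSketch*>(user_data);
  for (int v = 0; v <= 255; ++v)                         // build the exp() table
    data->table[255 - v] = std::exp(-input_scale * beta * v);
}

void FreeSketch(void* user_data) {
  delete static_cast<SoftmaxOpDataSketch*>(user_data);   // release per-node state
}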

@@ -3972,6 +3972,208 @@ void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
filter_width, filter_height, output_data, output_dims);
}
inline void Softmax(const SoftmaxParams& params,
const RuntimeShape& input_shape, const uint8* input_data,
const RuntimeShape& output_shape, uint8* output_data) {
const int32 input_beta_multiplier = params.input_multiplier;
const int32 input_beta_left_shift = params.input_left_shift;
const int diff_min = params.diff_min;
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
// -32 before multiplying by input_beta_multiplier, and therefore as large as
// -16 afterwards. Note that exp(-8) is definitely not insignificant to
// accumulation, but exp(-16) definitely is.
static const int kScaledDiffIntegerBits = 5;
static const int kAccumulationIntegerBits = 12;
using FixedPointScaledDiff =
gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
gemmlowp::ScopedProfilingLabel label("Softmax/8bit");
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int outer_size =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
const int depth =
MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int b = 0; b < outer_size; ++b) {
const uint8* input_data_ptr = input_data + b * depth;
uint8* output_data_ptr = output_data + b * depth;
// Determine the largest entry in the current row
uint8 max_in_row = 0;
{
int c = 0;
#ifdef USE_NEON
uint8x16_t max16_0 = vdupq_n_u8(0);
uint8x16_t max16_1 = vdupq_n_u8(0);
for (; c <= depth - 32; c += 32) {
max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0));
max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16));
}
uint8x16_t max16 = vmaxq_u8(max16_0, max16_1);
if (c <= depth - 16) {
max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c));
c += 16;
}
uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16));
if (c <= depth - 8) {
max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c));
c += 8;
}
uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4));
uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2));
uint8x8_t max1 = vpmax_u8(max2, max2);
max_in_row = vget_lane_u8(max1, 0);
#endif
for (; c < depth; ++c) {
max_in_row = std::max(max_in_row, input_data_ptr[c]);
}
}
#ifdef USE_NEON
using FixedPointAccumInt32x4 =
gemmlowp::FixedPoint<int32x4_t, kAccumulationIntegerBits>;
using FixedPointScaledDiffInt32x4 =
gemmlowp::FixedPoint<int32x4_t, kScaledDiffIntegerBits>;
using FixedPoint0Int32x4 = gemmlowp::FixedPoint<int32x4_t, 0>;
FixedPoint0Int32x4 input_beta_multiplier_f0 =
FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier);
int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row);
#endif
// Compute the sum of exponentials of the differences of entries in the
// current row from the largest entry in the current row.
FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
{
int c = 0;
#ifdef USE_NEON
int32x4_t diff_min_s32 = vdupq_n_s32(diff_min);
FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero();
FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero();
FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero();
for (; c <= depth - 8; c += 8) {
uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
int16x8_t input_diff_s16 =
vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
int32x4_t mask_0 =
gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32);
int32x4_t mask_1 =
gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32);
FixedPointScaledDiffInt32x4 scaled_diff_0 =
input_beta_multiplier_f0 *
FixedPointScaledDiffInt32x4::FromRaw(
gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
FixedPointScaledDiffInt32x4 scaled_diff_1 =
input_beta_multiplier_f0 *
FixedPointScaledDiffInt32x4::FromRaw(
gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
FixedPointAccumInt32x4 exps_0 =
gemmlowp::Rescale<kAccumulationIntegerBits>(
exp_on_negative_values(scaled_diff_0));
FixedPointAccumInt32x4 exps_1 =
gemmlowp::Rescale<kAccumulationIntegerBits>(
exp_on_negative_values(scaled_diff_1));
FixedPointAccumInt32x4 masked_exps_0 =
SelectUsingMask(mask_0, exps_0, zeros);
FixedPointAccumInt32x4 masked_exps_1 =
SelectUsingMask(mask_1, exps_1, zeros);
sum_of_exps_0 = sum_of_exps_0 + masked_exps_0;
sum_of_exps_1 = sum_of_exps_1 + masked_exps_1;
}
int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw();
int32x2_t sum_of_exps_reduced_2 =
vadd_s32(vget_low_s32(sum_of_exps_reduced_4),
vget_high_s32(sum_of_exps_reduced_4));
int32x2_t sum_of_exps_reduced_1 =
vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2);
sum_of_exps =
FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0));
#endif
for (; c < depth; ++c) {
int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
if (input_diff >= diff_min) {
const int32 input_diff_rescaled =
MultiplyByQuantizedMultiplierGreaterThanOne(
input_diff, input_beta_multiplier, input_beta_left_shift);
const FixedPointScaledDiff scaled_diff_f8 =
FixedPointScaledDiff::FromRaw(input_diff_rescaled);
sum_of_exps =
sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
exp_on_negative_values(scaled_diff_f8));
}
}
}
// Compute the fixed-point multiplier and shift that we need to apply to
// perform a division by the above-computed sum-of-exponentials.
int num_bits_over_unit = 0;
FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));
// Compute the quotients of exponentials of differences of entries in the
// current row from the largest entry, over the previously-computed sum of
// exponentials.
{
int c = 0;
#ifdef USE_NEON
int16x8_t diff_min_s16 = vdupq_n_s16(diff_min);
for (; c <= depth - 8; c += 8) {
uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
int16x8_t input_diff_s16 =
vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16));
FixedPointScaledDiffInt32x4 scaled_diff_0 =
input_beta_multiplier_f0 *
FixedPointScaledDiffInt32x4::FromRaw(
gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
FixedPointScaledDiffInt32x4 scaled_diff_1 =
input_beta_multiplier_f0 *
FixedPointScaledDiffInt32x4::FromRaw(
gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0);
FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1);
int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT(
vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()),
num_bits_over_unit + 31 - 8);
int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT(
vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()),
num_bits_over_unit + 31 - 8);
int16x8_t output_s16 =
vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1));
uint8x8_t output_u8 = vqmovun_s16(output_s16);
uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0));
vst1_u8(output_data_ptr + c, masked_output);
}
#endif
for (; c < depth; ++c) {
int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
if (input_diff >= diff_min) {
const int32 input_diff_rescaled =
MultiplyByQuantizedMultiplierGreaterThanOne(
input_diff, input_beta_multiplier, input_beta_left_shift);
const FixedPointScaledDiff scaled_diff_f8 =
FixedPointScaledDiff::FromRaw(input_diff_rescaled);
FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
int32 unsat_output = gemmlowp::RoundingDivideByPOT(
(shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0);
} else {
output_data_ptr[c] = 0;
}
}
}
}
}
inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
float beta, float* output_data,
const RuntimeShape& output_shape) {

@@ -3493,205 +3493,64 @@ inline void Softmax(const SoftmaxParams& params,
out_mat.array().rowwise() *= scale;
}
inline void Softmax(const SoftmaxParams& params,
const RuntimeShape& input_shape, const uint8* input_data,
const RuntimeShape& output_shape, uint8* output_data) {
const int32 input_beta_multiplier = params.input_multiplier;
const int32 input_beta_left_shift = params.input_left_shift;
const int diff_min = params.diff_min;
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
// -32 before multiplying by input_beta_multiplier, and therefore as large as
// -16 afterwards. Note that exp(-8) is definitely not insignificant to
// accumulation, but exp(-16) definitely is.
static const int kScaledDiffIntegerBits = 5;
static const int kAccumulationIntegerBits = 12;
using FixedPointScaledDiff =
gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
inline int32_t QuantizeSoftmaxOutput(int8_t* output_data, float prob_rescaled,
int32_t zero_point) {
const int32_t prob_rnd = static_cast<int32_t>(std::round(prob_rescaled));
return prob_rnd + zero_point;
}
gemmlowp::ScopedProfilingLabel label("Softmax/8bit");
inline int32_t QuantizeSoftmaxOutput(uint8_t* output_data, float prob_rescaled,
int32_t zero_point) {
return static_cast<int32_t>(prob_rescaled + 0.5);
}
inline void PopulateSoftmaxLookupTable(SoftmaxParams* data, float input_scale,
float beta) {
const float scale = -input_scale * beta;
const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
for (int32_t val = 0; val <= max_uint8; ++val) {
data->table[max_uint8 - val] = expf(scale * val);
}
}
template <typename T>
inline void Softmax(const SoftmaxParams& params,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& output_shape, T* output_data) {
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int outer_size =
const int excluding_last_dim =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
const int depth =
const int last_dim =
MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int b = 0; b < outer_size; ++b) {
const uint8* input_data_ptr = input_data + b * depth;
uint8* output_data_ptr = output_data + b * depth;
// Determine the largest entry in the current row
uint8 max_in_row = 0;
{
int c = 0;
#ifdef USE_NEON
uint8x16_t max16_0 = vdupq_n_u8(0);
uint8x16_t max16_1 = vdupq_n_u8(0);
for (; c <= depth - 32; c += 32) {
max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0));
max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16));
}
uint8x16_t max16 = vmaxq_u8(max16_0, max16_1);
if (c <= depth - 16) {
max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c));
c += 16;
}
uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16));
if (c <= depth - 8) {
max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c));
c += 8;
}
uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4));
uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2));
uint8x8_t max1 = vpmax_u8(max2, max2);
max_in_row = vget_lane_u8(max1, 0);
#endif
for (; c < depth; ++c) {
max_in_row = std::max(max_in_row, input_data_ptr[c]);
}
const int32_t clamp_max = std::numeric_limits<T>::max();
const int32_t clamp_min = std::numeric_limits<T>::min();
for (int i = 0; i < excluding_last_dim; ++i) {
int32_t max_val = std::numeric_limits<T>::min();
// Find max quantized value.
for (int j = 0; j < last_dim; ++j) {
max_val = std::max(max_val, static_cast<int32_t>(input_data[j]));
}
#ifdef USE_NEON
using FixedPointAccumInt32x4 =
gemmlowp::FixedPoint<int32x4_t, kAccumulationIntegerBits>;
using FixedPointScaledDiffInt32x4 =
gemmlowp::FixedPoint<int32x4_t, kScaledDiffIntegerBits>;
using FixedPoint0Int32x4 = gemmlowp::FixedPoint<int32x4_t, 0>;
FixedPoint0Int32x4 input_beta_multiplier_f0 =
FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier);
int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row);
#endif
// Compute the sum of exponentials of the differences of entries in the
// current row from the largest entry in the current row.
FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
{
int c = 0;
#ifdef USE_NEON
int32x4_t diff_min_s32 = vdupq_n_s32(diff_min);
FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero();
FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero();
FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero();
for (; c <= depth - 8; c += 8) {
uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
int16x8_t input_diff_s16 =
vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
int32x4_t mask_0 =
gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32);
int32x4_t mask_1 =
gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32);
FixedPointScaledDiffInt32x4 scaled_diff_0 =
input_beta_multiplier_f0 *
FixedPointScaledDiffInt32x4::FromRaw(
gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
FixedPointScaledDiffInt32x4 scaled_diff_1 =
input_beta_multiplier_f0 *
FixedPointScaledDiffInt32x4::FromRaw(
gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
FixedPointAccumInt32x4 exps_0 =
gemmlowp::Rescale<kAccumulationIntegerBits>(
exp_on_negative_values(scaled_diff_0));
FixedPointAccumInt32x4 exps_1 =
gemmlowp::Rescale<kAccumulationIntegerBits>(
exp_on_negative_values(scaled_diff_1));
FixedPointAccumInt32x4 masked_exps_0 =
SelectUsingMask(mask_0, exps_0, zeros);
FixedPointAccumInt32x4 masked_exps_1 =
SelectUsingMask(mask_1, exps_1, zeros);
sum_of_exps_0 = sum_of_exps_0 + masked_exps_0;
sum_of_exps_1 = sum_of_exps_1 + masked_exps_1;
}
int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw();
int32x2_t sum_of_exps_reduced_2 =
vadd_s32(vget_low_s32(sum_of_exps_reduced_4),
vget_high_s32(sum_of_exps_reduced_4));
int32x2_t sum_of_exps_reduced_1 =
vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2);
sum_of_exps =
FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0));
#endif
for (; c < depth; ++c) {
int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
if (input_diff >= diff_min) {
const int32 input_diff_rescaled =
MultiplyByQuantizedMultiplierGreaterThanOne(
input_diff, input_beta_multiplier, input_beta_left_shift);
const FixedPointScaledDiff scaled_diff_f8 =
FixedPointScaledDiff::FromRaw(input_diff_rescaled);
sum_of_exps =
sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
exp_on_negative_values(scaled_diff_f8));
}
}
float sum_exp = 0.0f;
const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
const float* table_offset = &params.table[max_uint8 - max_val];
// Calculate normalizer sum(exp(x)).
for (int j = 0; j < last_dim; ++j) {
sum_exp += table_offset[input_data[j]];
}
// Compute the fixed-point multiplier and shift that we need to apply to
// perform a division by the above-computed sum-of-exponentials.
int num_bits_over_unit = 0;
FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));
// Compute the quotients of exponentials of differences of entries in the
// current row from the largest entry, over the previously-computed sum of
// exponentials.
{
int c = 0;
#ifdef USE_NEON
int16x8_t diff_min_s16 = vdupq_n_s16(diff_min);
for (; c <= depth - 8; c += 8) {
uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
int16x8_t input_diff_s16 =
vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16));
FixedPointScaledDiffInt32x4 scaled_diff_0 =
input_beta_multiplier_f0 *
FixedPointScaledDiffInt32x4::FromRaw(
gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
FixedPointScaledDiffInt32x4 scaled_diff_1 =
input_beta_multiplier_f0 *
FixedPointScaledDiffInt32x4::FromRaw(
gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0);
FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1);
int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT(
vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()),
num_bits_over_unit + 31 - 8);
int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT(
vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()),
num_bits_over_unit + 31 - 8);
int16x8_t output_s16 =
vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1));
uint8x8_t output_u8 = vqmovun_s16(output_s16);
uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0));
vst1_u8(output_data_ptr + c, masked_output);
}
#endif
for (; c < depth; ++c) {
int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
if (input_diff >= diff_min) {
const int32 input_diff_rescaled =
MultiplyByQuantizedMultiplierGreaterThanOne(
input_diff, input_beta_multiplier, input_beta_left_shift);
const FixedPointScaledDiff scaled_diff_f8 =
FixedPointScaledDiff::FromRaw(input_diff_rescaled);
FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
int32 unsat_output = gemmlowp::RoundingDivideByPOT(
(shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0);
} else {
output_data_ptr[c] = 0;
}
}
const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
// Normalize and quantize probabilities.
for (int j = 0; j < last_dim; ++j) {
const float prob_rescaled = table_offset[input_data[j]] * inv_sum_exp;
const int32_t prob_quantized =
QuantizeSoftmaxOutput(output_data, prob_rescaled, params.zero_point);
output_data[j] = static_cast<T>(
std::max(std::min(clamp_max, prob_quantized), clamp_min));
}
input_data += last_dim;
output_data += last_dim;
}
}
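A quick hedged sanity check (not part of the commit) of the indexing above: because PopulateSoftmaxLookupTable stores exp(-input_scale * beta * v) at index 255 - v, offsetting the table by 255 - max_val turns a plain lookup at the quantized input into exp(input_scale * beta * (input - max_val)), i.e. the softmax numerator on dequantized values.

// Standalone check of the table-indexing identity (illustrative scales).
#include <cassert>
#include <cmath>

int main() {
  const float input_scale = 0.05f, beta = 1.0f;
  float table[256];
  for (int v = 0; v <= 255; ++v) table[255 - v] = std::exp(-input_scale * beta * v);
  const int max_val = 200, x = 150;  // arbitrary quantized values with x <= max_val
  const float* table_offset = &table[255 - max_val];
  // The lookup equals exp(input_scale * beta * (x - max_val)) up to float rounding.
  assert(std::fabs(table_offset[x] -
                   std::exp(input_scale * beta * (x - max_val))) < 1e-6f);
  return 0;
}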

@@ -124,9 +124,14 @@ void RunOneSoftmaxTest(const uint8* input_data,
input_beta_left_shift);
SoftmaxParams params;
float table[256];
params.input_multiplier = input_beta_multiplier;
params.input_left_shift = input_beta_left_shift;
params.diff_min = diff_min;
params.scale = 1.0f / 256;
params.zero_point = 0;
params.table = table;
optimized_ops::PopulateSoftmaxLookupTable(&params, input_scale, beta);
optimized_ops::Softmax(params, shape_common, input_data, shape_common,
optimized_softmax_output.data());
reference_ops::Softmax(params, shape_common, input_data, shape_common,
@@ -137,7 +142,7 @@ void RunOneSoftmaxTest(const uint8* input_data,
"Optimized vs float reference", false);
CheckOutputData<uint8_t>(optimized_softmax_output.data(),
reference_quant_softmax_output.data(), shape_common,
"Optimized vs quant reference", true);
"Optimized vs quant reference", false);
CheckOutputData<uint8_t>(reference_quant_softmax_output.data(),
reference_float_softmax_output.data(), shape_common,
"Quant reference vs float reference", false);

@@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <initializer_list>
@@ -985,6 +986,9 @@ struct SoftmaxParams {
int32 reverse_scaling_divisor;
int32 reverse_scaling_right_shift;
int diff_min;
int32_t zero_point;
float scale;
float* table;
};
struct SpaceToBatchParams {