Optimize quantized softmax using two uint8 lookup tables.

PiperOrigin-RevId: 315225925
Change-Id: I787ce554a45ccdf1d29834c6b0f74c8f409e30d4
Renjie Liu 2020-06-08 00:47:08 -07:00 committed by TensorFlower Gardener
parent 4e7ce793d9
commit b997946576
4 changed files with 377 additions and 36 deletions


@@ -60,13 +60,18 @@ struct OpData {
struct SoftmaxOpData {
struct SoftmaxParams params = {};
float table[256]{};
const int size_of_lut = 513;
int16_t exp_lut[513]{}; // int16 LUT for exp(x), where x is uniformly
// distributed in [-10.0, 0.0]
int16_t one_over_one_plus_x_lut[513]{}; // int16 LUT for 1 / (1 + x), where
// x is uniformly distributed in
// [0.0, 1.0]
float table[256];
#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
uint8_t uint8_table1[256];
uint8_t uint8_table2[256];
#endif
static constexpr int kInt16LUTArraySize = 513;
int16_t exp_lut[kInt16LUTArraySize]; // int16 LUT for exp(x), where x is
// uniformly distributed in [-10.0, 0.0]
int16_t one_over_one_plus_x_lut[kInt16LUTArraySize]; // int16 LUT for 1 /
// (1 + x), where x is
// uniformly distributed
// in [0.0, 1.0]
};
struct LogSoftmaxOpData : public OpData {
@@ -134,29 +139,6 @@ void PopulateLookupTable(struct OpData* data, const TfLiteTensor* input,
}
}
#if __aarch64__ && __clang__
namespace {
// Looks up each element of <indices> in <table>, returns them in a vector.
inline uint8x16_t aarch64_lookup_vector(const uint8x16x4_t table[4],
uint8x16_t indices) {
// Look up in 1st quarter of the table: top 2 bits of indices == 00
uint8x16_t output1 = vqtbl4q_u8(table[0], indices);
// Look up in 2nd quarter of the table: top 2 bits of indices == 01
uint8x16_t output2 =
vqtbl4q_u8(table[1], veorq_u8(indices, vdupq_n_u8(0x40)));
// Look up in 3rd quarter of the table: top 2 bits of indices == 10
uint8x16_t output3 =
vqtbl4q_u8(table[2], veorq_u8(indices, vdupq_n_u8(0x80)));
// Look up in 4th quarter of the table: top 2 bits of indices == 11
uint8x16_t output4 =
vqtbl4q_u8(table[3], veorq_u8(indices, vdupq_n_u8(0xc0)));
// Combine result of the 4 lookups.
return vorrq_u8(vorrq_u8(output1, output2), vorrq_u8(output3, output4));
}
} // namespace
#endif
// TODO(b/143696793): move this to optimized_ops.
void EvalUsingLookupTable(struct OpData* data, const TfLiteTensor* input,
TfLiteTensor* output) {
@@ -182,7 +164,7 @@ void EvalUsingLookupTable(struct OpData* data, const TfLiteTensor* input,
size / vectorized_16_loop_step * vectorized_16_loop_step;
for (; i < vectorized_16_loop_end; i += vectorized_16_loop_step) {
uint8x16_t input = vld1q_u8(input_data + i);
uint8x16_t output = aarch64_lookup_vector(table, input);
uint8x16_t output = optimized_ops::aarch64_lookup_vector(table, input);
vst1q_u8(output_data + i, output);
}
// Postamble and non-ARM64 code: simple for loop.
@@ -583,9 +565,26 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
switch (output->type) {
case kTfLiteUInt8:
case kTfLiteInt8:
#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
// Only applied when both input and output are uint8/int8, and when built
// with clang on aarch64.
// TODO(b/143709993): Port to ARMv7 and other platforms.
data->params.uint8_table1 = data->uint8_table1;
data->params.uint8_table2 = data->uint8_table2;
optimized_ops::PopulateSoftmaxUInt8LookupTable(
&data->params, input->params.scale, params->beta);
break;
#endif
case kTfLiteInt16:
default:
data->params.table = data->table;
optimized_ops::PopulateSoftmaxLookupTable(
&data->params, input->params.scale, params->beta);
}
data->params.zero_point = output->params.zero_point;
data->params.scale = output->params.scale;
}
@@ -597,10 +596,10 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
// The exp LUT is only used for negative values; we consider exp(-10.0)
// insignificant to the accumulation.
gen_lut([](double value) { return std::exp(value); }, -10.0, 0.0,
data->params.exp_lut, data->size_of_lut);
data->params.exp_lut, data->kInt16LUTArraySize);
data->params.one_over_one_plus_x_lut = data->one_over_one_plus_x_lut;
gen_lut([](double value) { return 1.0 / (1.0 + value); }, 0.0, 1.0,
data->params.one_over_one_plus_x_lut, data->size_of_lut);
data->params.one_over_one_plus_x_lut, data->kInt16LUTArraySize);
data->params.zero_point = output->params.zero_point;
data->params.scale = output->params.scale;
@@ -1019,6 +1018,40 @@ TfLiteStatus SoftmaxQuantized(TfLiteContext* context, const TfLiteTensor* input,
return kTfLiteOk;
}
template <>
TfLiteStatus SoftmaxQuantized<int8_t, int8_t>(TfLiteContext* context,
const TfLiteTensor* input,
TfLiteTensor* output,
SoftmaxOpData* data) {
#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
optimized_ops::SoftmaxInt8LUT(
data->params, GetTensorShape(input), GetTensorData<int8_t>(input),
GetTensorShape(output), GetTensorData<int8_t>(output));
#else
optimized_ops::Softmax(data->params, GetTensorShape(input),
GetTensorData<int8_t>(input), GetTensorShape(output),
GetTensorData<int8_t>(output));
#endif
return kTfLiteOk;
}
template <>
TfLiteStatus SoftmaxQuantized<uint8_t, uint8_t>(TfLiteContext* context,
const TfLiteTensor* input,
TfLiteTensor* output,
SoftmaxOpData* data) {
#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
optimized_ops::SoftmaxInt8LUT(
data->params, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(output), GetTensorData<uint8_t>(output));
#else
optimized_ops::Softmax(data->params, GetTensorShape(input),
GetTensorData<uint8_t>(input), GetTensorShape(output),
GetTensorData<uint8_t>(output));
#endif
return kTfLiteOk;
}
template <>
TfLiteStatus SoftmaxQuantized<int16, int16>(TfLiteContext* context,
const TfLiteTensor* input,


@@ -56,6 +56,10 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/transpose_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
#if __aarch64__ && __clang__
#define TFLITE_SOFTMAX_USE_UINT16_LUT
#endif
namespace tflite {
namespace optimized_ops {
@@ -328,6 +332,29 @@ inline void BinaryBroadcastFiveFold(const ArithmeticParams& unswitched_params,
}
}
#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
// Looks up each element of <indices> in <table>, returns them in a vector.
inline uint8x16_t aarch64_lookup_vector(const uint8x16x4_t table[4],
uint8x16_t indices) {
// Look up in 1st quarter of the table: top 2 bits of indices == 00
uint8x16_t output1 = vqtbl4q_u8(table[0], indices);
// Look up in 2nd quarter of the table: top 2 bits of indices == 01
uint8x16_t output2 =
vqtbl4q_u8(table[1], veorq_u8(indices, vdupq_n_u8(0x40)));
// Look up in 3rd quarter of the table: top 2 bits of indices == 10
uint8x16_t output3 =
vqtbl4q_u8(table[2], veorq_u8(indices, vdupq_n_u8(0x80)));
// Look up in 4th quarter of the table: top 2 bits of indices == 11
uint8x16_t output4 =
vqtbl4q_u8(table[3], veorq_u8(indices, vdupq_n_u8(0xc0)));
// Combine result of the 4 lookups.
return vorrq_u8(vorrq_u8(output1, output2), vorrq_u8(output3, output4));
}
#endif
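For reference, here is a scalar model of the lookup above (a standalone sketch, not part of this change; the helper name is illustrative and a flat 256-byte table stands in for the four uint8x16x4_t registers). vqtbl4q_u8 returns 0 for any index outside [0, 64), so XOR-ing the top two index bits selects exactly one 64-byte quarter per lane, and OR-ing the four partial results reassembles the full 256-entry lookup.

#include <cstdint>

// Scalar equivalent of aarch64_lookup_vector for a single lane (illustrative).
inline uint8_t scalar_lookup_256(const uint8_t table[256], uint8_t index) {
  uint8_t result = 0;
  for (int quarter = 0; quarter < 4; ++quarter) {
    // Same as veorq_u8(indices, vdupq_n_u8(quarter << 6)) on that lane.
    const uint8_t local = index ^ static_cast<uint8_t>(quarter << 6);
    // vqtbl4q_u8 yields 0 when the xor-ed index falls outside [0, 64).
    const uint8_t partial = (local < 64) ? table[quarter * 64 + local] : 0;
    result |= partial;  // Only the matching quarter contributes a value.
  }
  return result;
}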
inline void AddBiasAndEvalActivationFunction(float output_activation_min,
float output_activation_max,
const RuntimeShape& bias_shape,
@@ -3969,6 +3996,271 @@ inline void Softmax(const SoftmaxParams& params,
}
}
// Here's the softmax LUT optimization strategy:
// For softmax, we can apply a mathematically equivalent transformation:
//
// softmax(x) = e^x / sum(e^x, 0...n) ===> equals to
// softmax(x) = e^(x - CONST) / sum(e^(x - CONST), 0...n)
//
// For quantization, `x` in our case is (input_q - input_zp) * input_s.
// For the uint8 case (int8 can be handled similarly), the range is [0, 255].
//
// So if we let
// CONST = (255 - input_zp) * input_s
// then we will have:
// softmax(x) = e^((input_q - 255) * input_s) --------- (1)
// /
// sum(e^((input_q - 255) * input_s), 0...n) -------- (2)
//
// The good thing about (1) is that it lies within the range (0, 1], so we can
// approximate its result with a uint16:
// (1) = uint16_out * 1 / 2^16.
//
// So (1) is lookup_uint8_table(input_q) * 1 / 2^16,
// and (2) is essentially the following:
// sum(lookup_uint8_table(input_q), 0...n) / 2^16.
//
// Since (output_q - output_zp) * output_s = softmax(x), we get
// output_q = lookup_uint8_table(input_q)
// /
// (sum(lookup_uint8_table(input_q), 0...n) * output_s)
// +
// output_zp
//
// We could further improve performance by using uint8 instead of uint16, but
// that may cost some accuracy, so it needs care.
inline void PopulateSoftmaxUInt8LookupTable(SoftmaxParams* data,
float input_scale, float beta) {
const float scale = input_scale * beta;
const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
const int32_t max_uint16 = std::numeric_limits<uint16_t>::max();
for (int32_t val = 0; val <= max_uint8; ++val) {
float input_to_exp = scale * (val - max_uint8);
int32_t temp = static_cast<int>(expf(input_to_exp) * max_uint16 + 0.5);
temp = std::min(max_uint16, temp);
uint8_t part1 = temp >> 8;
uint8_t part2 = temp & 0xff;
data->uint8_table1[val] = static_cast<uint8_t>(part1);
data->uint8_table2[val] = static_cast<uint8_t>(part2);
}
}
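As a quick sanity check of the two-table layout populated above (a standalone sketch, not part of this change; the scale values and the main() harness are illustrative): each quantized input v maps to round(exp(scale * (v - 255)) * 65535), with the high byte stored in uint8_table1 and the low byte in uint8_table2, and the kernel later recombines the two lookups as (part1 << 8) + part2.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const float input_scale = 0.1f, beta = 1.0f;  // illustrative values
  const float scale = input_scale * beta;
  uint8_t table1[256], table2[256];
  for (int v = 0; v <= 255; ++v) {
    int32_t q =
        static_cast<int32_t>(std::exp(scale * (v - 255)) * 65535.0f + 0.5f);
    q = std::min<int32_t>(q, 65535);
    table1[v] = static_cast<uint8_t>(q >> 8);    // high byte
    table2[v] = static_cast<uint8_t>(q & 0xff);  // low byte
  }
  // v == 255 corresponds to exp(0) == 1.0, i.e. the saturated value 65535.
  const int v = 255;
  std::printf("%d\n", (table1[v] << 8) + table2[v]);  // prints 65535
  return 0;
}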
inline int FindMaxValue(int size, const uint8_t* input_data, uint8_t offset) {
int32_t max_val = std::numeric_limits<uint8_t>::min();
int j = 0;
#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
uint8x16_t max_val_dup = vdupq_n_u8(max_val);
uint8x16_t offset_dup = vdupq_n_u8(offset);
for (; j <= size - 16; j += 16) {
uint8x16_t input_value = vld1q_u8(input_data + j);
input_value = veorq_u8(input_value, offset_dup);
max_val_dup = vmaxq_u8(input_value, max_val_dup);
}
max_val = std::max(max_val, static_cast<int32>(vmaxvq_u8(max_val_dup)));
#endif
for (; j < size; ++j) {
max_val = std::max(max_val, static_cast<int32_t>(input_data[j] ^ offset));
}
return max_val;
}
#ifdef USE_NEON
// Value_to_store layout:
// [high_high, high_low, low_high, low_low].
inline void StoreValue(int32x4x4_t value_to_store, int8_t* output) {
const int16x8_t result_1 = vcombine_s16(vqmovn_s32(value_to_store.val[1]),
vqmovn_s32(value_to_store.val[0]));
const int16x8_t result_2 = vcombine_s16(vqmovn_s32(value_to_store.val[3]),
vqmovn_s32(value_to_store.val[2]));
const int8x16_t result =
vcombine_s8(vqmovn_s16(result_2), vqmovn_s16(result_1));
vst1q_s8(output, result);
}
// Value_to_store layout:
// [high_high, high_low, low_high, low_low].
inline void StoreValue(int32x4x4_t value_to_store, uint8_t* output) {
const uint16x8_t result_1 =
vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[1])),
vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[0])));
const uint16x8_t result_2 =
vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[3])),
vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[2])));
const uint8x16_t result =
vcombine_u8(vqmovn_u16(result_2), vqmovn_u16(result_1));
vst1q_u8(output, result);
}
#endif
template <typename In, typename Out>
inline void SoftmaxInt8LUT(const SoftmaxParams& params,
const RuntimeShape& input_shape,
const In* input_data,
const RuntimeShape& output_shape, Out* output_data) {
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int excluding_last_dim =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
const int last_dim =
MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
const int32_t clamp_max = std::numeric_limits<Out>::max();
const int32_t clamp_min = std::numeric_limits<Out>::min();
// Offset is used to interpret the input data "correctly".
// If the input is uint8, the data is unchanged.
// If the input is int8, it is reinterpreted as uint8, and the offset flips
// the sign bit so that, e.g., int8 127 becomes 255 in uint8.
uint8_t offset = 0;
if (std::is_same<In, int8>::value) {
offset = 0x80;
}
const uint8_t* input_data_uint = reinterpret_cast<const uint8_t*>(input_data);
#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
// This code uses ARM64-only instructions.
// TODO(b/143709993): Port to ARMv7
// Load the tables into registers. (4*4 128-bit registers)
uint8x16x4_t table1[4];
table1[0] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 0);
table1[1] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 1);
table1[2] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 2);
table1[3] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 3);
uint8x16x4_t table2[4];
table2[0] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 0);
table2[1] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 1);
table2[2] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 2);
table2[3] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 3);
#endif
for (int i = 0; i < excluding_last_dim; ++i) {
// Find max quantized value.
int32_t max_val = FindMaxValue(last_dim, input_data_uint, offset);
int32 sum_exp = 0;
const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
const uint8_t table_offset = max_uint8 - max_val;
// Calculate normalizer sum(exp(x)).
int sum_j = 0;
#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
uint8x16_t table_offset_dup = vdupq_n_u8(table_offset);
uint8x16_t offset_dup = vdupq_n_u8(offset);
uint32x4_t sum_4 = vdupq_n_u32(0);
const int multiplier_shift = 8;
for (; sum_j <= last_dim - 16; sum_j += 16) {
uint8x16_t input_value = vld1q_u8(input_data_uint + sum_j);
input_value = veorq_u8(input_value, offset_dup);
input_value = vaddq_u8(input_value, table_offset_dup);
const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);
uint16x8_t exp_value1 =
vshll_n_u8(vget_high_u8(output1), multiplier_shift);
uint16x8_t exp_value2 =
vshll_n_u8(vget_low_u8(output1), multiplier_shift);
exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));
sum_4 = vpadalq_u16(sum_4, exp_value1);
sum_4 = vpadalq_u16(sum_4, exp_value2);
}
int temp = vgetq_lane_u32(sum_4, 0) + vgetq_lane_u32(sum_4, 1) +
vgetq_lane_u32(sum_4, 2) + vgetq_lane_u32(sum_4, 3);
sum_exp += temp;
#endif
for (; sum_j < last_dim; ++sum_j) {
const uint8_t index = (input_data_uint[sum_j] ^ offset) + table_offset;
uint8_t part1 = params.uint8_table1[index];
uint8_t part2 = params.uint8_table2[index];
sum_exp += ((part1 << 8) + part2);
}
const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
int32 multiplier, shift;
QuantizeMultiplier(inv_sum_exp, &multiplier, &shift);
// Normalize and quantize probabilities.
int j = 0;
#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
const int32x4_t output_zp_dup = vdupq_n_s32(params.zero_point);
const int32x4_t max_val_dup = vdupq_n_s32(clamp_max);
const int32x4_t min_val_dup = vdupq_n_s32(clamp_min);
for (; j <= last_dim - 16; j += 16) {
uint8x16_t input_value = vld1q_u8(input_data_uint + j);
input_value = veorq_u8(input_value, offset_dup);
input_value = vaddq_u8(input_value, table_offset_dup);
const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);
uint16x8_t exp_value1 =
vshll_n_u8(vget_high_u8(output1), multiplier_shift);
uint16x8_t exp_value2 =
vshll_n_u8(vget_low_u8(output1), multiplier_shift);
exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));
int32x4x4_t output_value;
output_value.val[0] =
vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value1)));
output_value.val[1] =
vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value1)));
output_value.val[2] =
vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value2)));
output_value.val[3] =
vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value2)));
int32x4x4_t temp_val =
MultiplyByQuantizedMultiplier4Rows(output_value, multiplier, shift);
temp_val.val[0] = vaddq_s32(temp_val.val[0], output_zp_dup);
temp_val.val[1] = vaddq_s32(temp_val.val[1], output_zp_dup);
temp_val.val[2] = vaddq_s32(temp_val.val[2], output_zp_dup);
temp_val.val[3] = vaddq_s32(temp_val.val[3], output_zp_dup);
temp_val.val[0] =
vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
temp_val.val[1] =
vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
temp_val.val[2] =
vmaxq_s32(vminq_s32(temp_val.val[2], max_val_dup), min_val_dup);
temp_val.val[3] =
vmaxq_s32(vminq_s32(temp_val.val[3], max_val_dup), min_val_dup);
StoreValue(temp_val, output_data + j);
}
#endif
for (; j < last_dim; ++j) {
const uint8_t index = (input_data_uint[j] ^ offset) + table_offset;
const uint8_t part1 = params.uint8_table1[index];
const uint8_t part2 = params.uint8_table2[index];
const int32_t exp_value = (part1 << 8) + part2;
const int32_t output_value =
MultiplyByQuantizedMultiplier(exp_value, multiplier, shift);
output_data[j] = static_cast<Out>(std::max(
std::min(clamp_max, output_value + params.zero_point), clamp_min));
}
input_data_uint += last_dim;
output_data += last_dim;
}
}
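For reference, the per-element normalization and re-quantization above boils down to the following scalar sketch (a standalone illustration, not part of this change; plain float math with truncation stands in for the QuantizeMultiplier / MultiplyByQuantizedMultiplier fixed-point rounding path, and the helper name is hypothetical):

#include <algorithm>
#include <cstdint>

// Scalar sketch of one output element of SoftmaxInt8LUT for uint8 output.
inline uint8_t QuantizeOneSoftmaxOutput(int32_t exp_value,    // (part1 << 8) + part2
                                        int32_t sum_exp,      // row sum of exp_value
                                        float output_scale,   // params.scale
                                        int32_t output_zp) {  // params.zero_point
  // softmax = exp_value / sum_exp, and output_q = softmax / output_scale + zp,
  // so the kernel folds both divisions into one multiplier: 1 / (sum * scale).
  const float inv_sum_exp = 1.0f / (sum_exp * output_scale);
  const int32_t output_value = static_cast<int32_t>(exp_value * inv_sum_exp);
  // Clamp to the uint8 output range, as the vector path does with max/min.
  return static_cast<uint8_t>(
      std::max<int32_t>(0, std::min<int32_t>(255, output_value + output_zp)));
}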
// TODO(myenik): This is the same as the reference implementation, not actually
// optimized yet.
inline void LogSoftmax(const SoftmaxParams& params,


@@ -133,6 +133,7 @@ void RunOneSoftmaxTest(const uint8* input_data,
params.scale = 1.0f / 256;
params.zero_point = 0;
params.table = table;
optimized_ops::PopulateSoftmaxLookupTable(&params, input_scale, beta);
optimized_ops::Softmax(params, shape_common, input_data, shape_common,
optimized_softmax_output.data());
@@ -148,6 +149,19 @@ void RunOneSoftmaxTest(const uint8* input_data,
CheckOutputData<uint8_t>(reference_quant_softmax_output.data(),
reference_float_softmax_output.data(), shape_common,
"Quant reference vs float reference", false);
#if __aarch64__ && __clang__
uint8_t uint8_table1[256];
uint8_t uint8_table2[256];
params.uint8_table1 = uint8_table1;
params.uint8_table2 = uint8_table2;
optimized_ops::PopulateSoftmaxUInt8LookupTable(&params, input_scale, beta);
optimized_ops::SoftmaxInt8LUT(params, shape_common, input_data, shape_common,
optimized_softmax_output.data());
CheckOutputData<uint8_t>(
optimized_softmax_output.data(), reference_quant_softmax_output.data(),
shape_common, "Optimized (Uint8 Lookup table) vs quant reference", false);
#endif
}
// This function picks some random Softmax params, which are checked for


@@ -1035,6 +1035,8 @@ struct SoftmaxParams {
float* table;
int16_t* exp_lut;
int16_t* one_over_one_plus_x_lut;
uint8_t* uint8_table1;
uint8_t* uint8_table2;
};
struct SpaceToBatchParams {