Merge pull request #34589 from wwwind:16bit_tanh_sigmoid
PiperOrigin-RevId: 305328243 Change-Id: I43354818c4ce16bbba136aac9a006f5659c06a15
This commit is contained in:
commit
c57bb8bdc6
@ -501,6 +501,9 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
PopulateLookupTable<int8_t>(data, input, output, [](float value) {
|
||||
return 1.0f / (1.0f + std::exp(-value));
|
||||
});
|
||||
} else if (input->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE(context, output->params.scale == 1. / 32768);
|
||||
TF_LITE_ENSURE(context, output->params.zero_point == 0);
|
||||
}
|
||||
}
|
||||
|
||||
@ -795,9 +798,12 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TanhParams params;
|
||||
params.input_left_shift = data->input_left_shift;
|
||||
if (kernel_type == kReference) {
|
||||
reference_ops::Tanh(
|
||||
params, GetTensorShape(input), GetTensorData<int16_t>(input),
|
||||
GetTensorShape(output), GetTensorData<int16_t>(output));
|
||||
const int size =
|
||||
MatchingFlatSize(GetTensorShape(input), GetTensorShape(output));
|
||||
|
||||
reference_integer_ops::Tanh(data->input_left_shift, size,
|
||||
GetTensorData<int16_t>(input),
|
||||
GetTensorData<int16_t>(output));
|
||||
} else {
|
||||
optimized_ops::Tanh(
|
||||
params, GetTensorShape(input), GetTensorData<int16_t>(input),
|
||||
@ -867,9 +873,11 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
case kTfLiteInt16: {
|
||||
LogisticParams params;
|
||||
if (kernel_type == kReference) {
|
||||
reference_ops::Logistic(
|
||||
params, GetTensorShape(input), GetTensorData<int16_t>(input),
|
||||
GetTensorShape(output), GetTensorData<int16_t>(output));
|
||||
const int size =
|
||||
MatchingFlatSize(GetTensorShape(input), GetTensorShape(output));
|
||||
|
||||
reference_integer_ops::Logistic(size, GetTensorData<int16_t>(input),
|
||||
GetTensorData<int16_t>(output));
|
||||
} else {
|
||||
optimized_ops::Logistic(
|
||||
params, GetTensorShape(input), GetTensorData<int16_t>(input),
|
||||
|
@ -764,19 +764,19 @@ TEST_P(TanhOpTest, TanhInt16) {
|
||||
const float kMax = 32767.f / 32768.f;
|
||||
QuantizedActivationsOpModel m(
|
||||
GetRegistration(), BuiltinOperator_TANH,
|
||||
/*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
|
||||
/*output=*/{TensorType_INT16, {1, 2, 4, 1}, kMin, kMax});
|
||||
m.SetInput<int16_t>({
|
||||
0, -6, 2, 4, //
|
||||
-4, -2, 8, 1, //
|
||||
});
|
||||
/*input=*/{TensorType_INT16, {1, 2, 8, 1}, 8 * kMin, 8 * kMax},
|
||||
/*output=*/{TensorType_INT16, {1, 2, 8, 1}, kMin, kMax});
|
||||
m.SetInput<int16_t>({0, -6, 2, 4, //
|
||||
-4, -2, 8, 1, //
|
||||
7, -8, 3, -5, //
|
||||
6, -1, -3, 5});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
|
||||
ElementsAreArray(ArrayFloatNear(
|
||||
{
|
||||
0.0, -0.999987, 0.964027, 0.999329, //
|
||||
-0.999329, -0.96402, 0.99999, 0.76159, //
|
||||
},
|
||||
{0.0, -0.999987, 0.964027, 0.999329, //
|
||||
-0.999329, -0.96402, 0.99999, 0.76159, //
|
||||
0.999998337, -0.99999, 0.995054754, -0.999909204, //
|
||||
0.999999996, -0.76159, -0.995054754, 0.999909204},
|
||||
kQuantizedToleranceInt16)));
|
||||
}
|
||||
|
||||
@ -905,18 +905,18 @@ TEST_P(LogisticOpTest, SigmoidInt16) {
|
||||
const float kMax = 32767.f / 32768.f;
|
||||
QuantizedActivationsOpModel m(
|
||||
GetRegistration(), BuiltinOperator_LOGISTIC,
|
||||
/*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
|
||||
/*output=*/{TensorType_INT16, {1, 2, 4, 1}, kMin, kMax});
|
||||
m.SetInput<int16_t>({
|
||||
0, -6, 2, 4, //
|
||||
3, -2, 10, 1, //
|
||||
});
|
||||
/*input=*/{TensorType_INT16, {1, 2, 6, 1}, 8 * kMin, 8 * kMax},
|
||||
/*output=*/{TensorType_INT16, {1, 2, 6, 1}, kMin, kMax});
|
||||
m.SetInput<int16_t>({0, -6, 2, 4, //
|
||||
3, -2, 8, 1, //
|
||||
5, -8, 7, -3});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
|
||||
ElementsAreArray(ArrayFloatNear(
|
||||
{
|
||||
0.5, 0.002473, 0.880797, 0.982014, //
|
||||
0.952574, 0.119203, 0.999955, 0.731059, //
|
||||
0.952574, 0.119203, 0.9995, 0.731059, //
|
||||
0.993307, 0.0003535, 0.999089, 0.047426 //
|
||||
},
|
||||
kQuantizedToleranceInt16)));
|
||||
}
|
||||
|
@ -195,6 +195,38 @@ inline int CountLeadingSignBits(T integer_input) {
|
||||
#endif
|
||||
}
|
||||
|
||||
// Look-up table with 256 entries: entry i holds sigmoid(i/24) as an unsigned
// 0.16 fixed-point value (32768 == 0.5), so it spans inputs 0 .. 255/24.
|
||||
|
||||
// A single table serves both sigmoid and tanh, because
|
||||
// tanh(x) = 2*sigmoid(2*x) - 1.
|
||||
// Both functions are symmetric, so the table only needs to cover
|
||||
// the absolute value of the input.
|
||||
static const uint16_t sigmoid_table_uint16[256] = {
|
||||
32768, 33451, 34133, 34813, 35493, 36169, 36843, 37513, 38180, 38841, 39498,
|
||||
40149, 40794, 41432, 42064, 42688, 43304, 43912, 44511, 45102, 45683, 46255,
|
||||
46817, 47369, 47911, 48443, 48964, 49475, 49975, 50464, 50942, 51409, 51865,
|
||||
52311, 52745, 53169, 53581, 53983, 54374, 54755, 55125, 55485, 55834, 56174,
|
||||
56503, 56823, 57133, 57433, 57724, 58007, 58280, 58544, 58800, 59048, 59288,
|
||||
59519, 59743, 59959, 60168, 60370, 60565, 60753, 60935, 61110, 61279, 61441,
|
||||
61599, 61750, 61896, 62036, 62172, 62302, 62428, 62549, 62666, 62778, 62886,
|
||||
62990, 63090, 63186, 63279, 63368, 63454, 63536, 63615, 63691, 63765, 63835,
|
||||
63903, 63968, 64030, 64090, 64148, 64204, 64257, 64308, 64357, 64405, 64450,
|
||||
64494, 64536, 64576, 64614, 64652, 64687, 64721, 64754, 64786, 64816, 64845,
|
||||
64873, 64900, 64926, 64950, 64974, 64997, 65019, 65039, 65060, 65079, 65097,
|
||||
65115, 65132, 65149, 65164, 65179, 65194, 65208, 65221, 65234, 65246, 65258,
|
||||
65269, 65280, 65291, 65301, 65310, 65319, 65328, 65337, 65345, 65352, 65360,
|
||||
65367, 65374, 65381, 65387, 65393, 65399, 65404, 65410, 65415, 65420, 65425,
|
||||
65429, 65433, 65438, 65442, 65445, 65449, 65453, 65456, 65459, 65462, 65465,
|
||||
65468, 65471, 65474, 65476, 65479, 65481, 65483, 65485, 65488, 65489, 65491,
|
||||
65493, 65495, 65497, 65498, 65500, 65501, 65503, 65504, 65505, 65507, 65508,
|
||||
65509, 65510, 65511, 65512, 65513, 65514, 65515, 65516, 65517, 65517, 65518,
|
||||
65519, 65520, 65520, 65521, 65522, 65522, 65523, 65523, 65524, 65524, 65525,
|
||||
65525, 65526, 65526, 65526, 65527, 65527, 65528, 65528, 65528, 65529, 65529,
|
||||
65529, 65529, 65530, 65530, 65530, 65530, 65531, 65531, 65531, 65531, 65531,
|
||||
65532, 65532, 65532, 65532, 65532, 65532, 65533, 65533, 65533, 65533, 65533,
|
||||
65533, 65533, 65533, 65534, 65534, 65534, 65534, 65534, 65534, 65534, 65534,
|
||||
65534, 65534, 65535};
|
||||
|
||||
// TODO(b/77858996): Add these to gemmlowp.
|
||||
template <typename IntegerType>
|
||||
IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {
|
||||
|
@ -58,6 +58,38 @@ inline void Logistic(int32_t input_zero_point, int32_t input_range_radius,
|
||||
}
|
||||
}
|
||||
|
||||
inline void Logistic(int32_t input_size, const int16_t* ptr_input_data,
|
||||
int16_t* ptr_output_data) {
|
||||
// We use the LUT for sigmoid and take into account, that
|
||||
// tanh(x) = 2*sigmoid(2*x) - 1
|
||||
for (int i = 0; i < input_size; ++i, ptr_input_data++, ptr_output_data++) {
|
||||
int32_t input_data = *ptr_input_data;
|
||||
|
||||
// Scale by 3/4 to expand range [-8,8]->[-10.7,10.7] and
|
||||
// we do interpolation on unsigned values.
|
||||
uint32_t abs_input_data = 3 * abs(input_data);
|
||||
|
||||
// We divide by 2 power of 9, because
|
||||
// we need to divide by 2 in power of 7 for
|
||||
// the input conversion + 1/4 from the scale above.
|
||||
uint8_t uh = abs_input_data >> 9;
|
||||
uint32_t ua = sigmoid_table_uint16[uh];
|
||||
uint32_t ub = sigmoid_table_uint16[uh + 1];
|
||||
uint32_t ut = abs_input_data & 0x1ff;
|
||||
|
||||
// Interpolation is done using the fractional bit.
|
||||
uint32_t result = (ua << 9) + ut * (ub - ua);
|
||||
|
||||
result = (input_data >= 0) ? (result + (1 << 9))
|
||||
: ((1 << (16 + 9)) - result + (1 << 9) - 1);
|
||||
|
||||
// Back to 16-bit.
|
||||
result >>= 10;
|
||||
|
||||
*ptr_output_data = result;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace reference_integer_ops
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -57,6 +57,45 @@ inline void Tanh(int32_t input_zero_point, int32_t input_range_radius,
|
||||
}
|
||||
}
|
||||
|
||||
inline void Tanh(int32_t input_left_shift, int32_t input_size,
|
||||
const int16_t* ptr_input_data, int16_t* ptr_output_data) {
|
||||
// We use the LUT for sigmoid and take into account, that
|
||||
// tanh(x) = 2*sigmoid(2*x) - 1
|
||||
for (int i = 0; i < input_size; ++i, ptr_input_data++, ptr_output_data++) {
|
||||
int32_t input_data = *ptr_input_data;
|
||||
|
||||
if (input_left_shift == 1) {
|
||||
input_data <<= 1;
|
||||
}
|
||||
|
||||
// Scale by 3/4 to expand range [-8,8]->[-10.7,10.7].
|
||||
uint32_t abs_input_data = 3 * abs(input_data);
|
||||
uint32_t uh = abs_input_data >> 8;
|
||||
int32_t result;
|
||||
|
||||
if (uh >= 255) {
|
||||
// Saturate to maximum.
|
||||
result = 0xFFFF << 8;
|
||||
} else {
|
||||
uint32_t ua = sigmoid_table_uint16[uh];
|
||||
uint32_t ub = sigmoid_table_uint16[uh + 1];
|
||||
|
||||
uint8_t ut = abs_input_data & 0xFF;
|
||||
|
||||
result = (ua << 8) + ut * (ub - ua);
|
||||
}
|
||||
|
||||
result = (input_data >= 0)
|
||||
? (result - (1 << (14 + 9)) + (1 << (9 - 2)))
|
||||
: (-result + (1 << (14 + 9)) + (1 << (9 - 2)) - 1);
|
||||
|
||||
// Convert back to 16-bit.
|
||||
result >>= (9 - 1);
|
||||
|
||||
*ptr_output_data = result;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace reference_integer_ops
|
||||
} // namespace tflite
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user