Merge pull request #34589 from wwwind:16bit_tanh_sigmoid

PiperOrigin-RevId: 305328243
Change-Id: I43354818c4ce16bbba136aac9a006f5659c06a15
This commit is contained in:
TensorFlower Gardener 2020-04-07 13:35:26 -07:00
commit c57bb8bdc6
5 changed files with 134 additions and 23 deletions

View File

@ -501,6 +501,9 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {
PopulateLookupTable<int8_t>(data, input, output, [](float value) {
return 1.0f / (1.0f + std::exp(-value));
});
} else if (input->type == kTfLiteInt16) {
TF_LITE_ENSURE(context, output->params.scale == 1. / 32768);
TF_LITE_ENSURE(context, output->params.zero_point == 0);
}
}
@ -795,9 +798,12 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
TanhParams params;
params.input_left_shift = data->input_left_shift;
if (kernel_type == kReference) {
reference_ops::Tanh(
params, GetTensorShape(input), GetTensorData<int16_t>(input),
GetTensorShape(output), GetTensorData<int16_t>(output));
const int size =
MatchingFlatSize(GetTensorShape(input), GetTensorShape(output));
reference_integer_ops::Tanh(data->input_left_shift, size,
GetTensorData<int16_t>(input),
GetTensorData<int16_t>(output));
} else {
optimized_ops::Tanh(
params, GetTensorShape(input), GetTensorData<int16_t>(input),
@ -867,9 +873,11 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt16: {
LogisticParams params;
if (kernel_type == kReference) {
reference_ops::Logistic(
params, GetTensorShape(input), GetTensorData<int16_t>(input),
GetTensorShape(output), GetTensorData<int16_t>(output));
const int size =
MatchingFlatSize(GetTensorShape(input), GetTensorShape(output));
reference_integer_ops::Logistic(size, GetTensorData<int16_t>(input),
GetTensorData<int16_t>(output));
} else {
optimized_ops::Logistic(
params, GetTensorShape(input), GetTensorData<int16_t>(input),

View File

@ -764,19 +764,19 @@ TEST_P(TanhOpTest, TanhInt16) {
const float kMax = 32767.f / 32768.f;
QuantizedActivationsOpModel m(
GetRegistration(), BuiltinOperator_TANH,
/*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
/*output=*/{TensorType_INT16, {1, 2, 4, 1}, kMin, kMax});
m.SetInput<int16_t>({
0, -6, 2, 4, //
/*input=*/{TensorType_INT16, {1, 2, 8, 1}, 8 * kMin, 8 * kMax},
/*output=*/{TensorType_INT16, {1, 2, 8, 1}, kMin, kMax});
m.SetInput<int16_t>({0, -6, 2, 4, //
-4, -2, 8, 1, //
});
7, -8, 3, -5, //
6, -1, -3, 5});
m.Invoke();
EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
ElementsAreArray(ArrayFloatNear(
{
0.0, -0.999987, 0.964027, 0.999329, //
{0.0, -0.999987, 0.964027, 0.999329, //
-0.999329, -0.96402, 0.99999, 0.76159, //
},
0.999998337, -0.99999, 0.995054754, -0.999909204, //
0.999999996, -0.76159, -0.995054754, 0.999909204},
kQuantizedToleranceInt16)));
}
@ -905,18 +905,18 @@ TEST_P(LogisticOpTest, SigmoidInt16) {
const float kMax = 32767.f / 32768.f;
QuantizedActivationsOpModel m(
GetRegistration(), BuiltinOperator_LOGISTIC,
/*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
/*output=*/{TensorType_INT16, {1, 2, 4, 1}, kMin, kMax});
m.SetInput<int16_t>({
0, -6, 2, 4, //
3, -2, 10, 1, //
});
/*input=*/{TensorType_INT16, {1, 2, 6, 1}, 8 * kMin, 8 * kMax},
/*output=*/{TensorType_INT16, {1, 2, 6, 1}, kMin, kMax});
m.SetInput<int16_t>({0, -6, 2, 4, //
3, -2, 8, 1, //
5, -8, 7, -3});
m.Invoke();
EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
ElementsAreArray(ArrayFloatNear(
{
0.5, 0.002473, 0.880797, 0.982014, //
0.952574, 0.119203, 0.999955, 0.731059, //
0.952574, 0.119203, 0.9995, 0.731059, //
0.993307, 0.0003535, 0.999089, 0.047426 //
},
kQuantizedToleranceInt16)));
}

View File

@ -195,6 +195,38 @@ inline int CountLeadingSignBits(T integer_input) {
#endif
}
// Table of sigmoid(i/24) at 0.16 format - 256 elements.
// We use combined sigmoid and tanh look-up table, since
// tanh(x) = 2*sigmoid(2*x) -1.
// Both functions are symmetric, so the LUT table is only needed
// for the absolute value of the input.
static const uint16_t sigmoid_table_uint16[256] = {
32768, 33451, 34133, 34813, 35493, 36169, 36843, 37513, 38180, 38841, 39498,
40149, 40794, 41432, 42064, 42688, 43304, 43912, 44511, 45102, 45683, 46255,
46817, 47369, 47911, 48443, 48964, 49475, 49975, 50464, 50942, 51409, 51865,
52311, 52745, 53169, 53581, 53983, 54374, 54755, 55125, 55485, 55834, 56174,
56503, 56823, 57133, 57433, 57724, 58007, 58280, 58544, 58800, 59048, 59288,
59519, 59743, 59959, 60168, 60370, 60565, 60753, 60935, 61110, 61279, 61441,
61599, 61750, 61896, 62036, 62172, 62302, 62428, 62549, 62666, 62778, 62886,
62990, 63090, 63186, 63279, 63368, 63454, 63536, 63615, 63691, 63765, 63835,
63903, 63968, 64030, 64090, 64148, 64204, 64257, 64308, 64357, 64405, 64450,
64494, 64536, 64576, 64614, 64652, 64687, 64721, 64754, 64786, 64816, 64845,
64873, 64900, 64926, 64950, 64974, 64997, 65019, 65039, 65060, 65079, 65097,
65115, 65132, 65149, 65164, 65179, 65194, 65208, 65221, 65234, 65246, 65258,
65269, 65280, 65291, 65301, 65310, 65319, 65328, 65337, 65345, 65352, 65360,
65367, 65374, 65381, 65387, 65393, 65399, 65404, 65410, 65415, 65420, 65425,
65429, 65433, 65438, 65442, 65445, 65449, 65453, 65456, 65459, 65462, 65465,
65468, 65471, 65474, 65476, 65479, 65481, 65483, 65485, 65488, 65489, 65491,
65493, 65495, 65497, 65498, 65500, 65501, 65503, 65504, 65505, 65507, 65508,
65509, 65510, 65511, 65512, 65513, 65514, 65515, 65516, 65517, 65517, 65518,
65519, 65520, 65520, 65521, 65522, 65522, 65523, 65523, 65524, 65524, 65525,
65525, 65526, 65526, 65526, 65527, 65527, 65528, 65528, 65528, 65529, 65529,
65529, 65529, 65530, 65530, 65530, 65530, 65531, 65531, 65531, 65531, 65531,
65532, 65532, 65532, 65532, 65532, 65532, 65533, 65533, 65533, 65533, 65533,
65533, 65533, 65533, 65534, 65534, 65534, 65534, 65534, 65534, 65534, 65534,
65534, 65534, 65535};
// TODO(b/77858996): Add these to gemmlowp.
template <typename IntegerType>
IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {

View File

@ -58,6 +58,38 @@ inline void Logistic(int32_t input_zero_point, int32_t input_range_radius,
}
}
inline void Logistic(int32_t input_size, const int16_t* ptr_input_data,
int16_t* ptr_output_data) {
// We use the LUT for sigmoid and take into account, that
// tanh(x) = 2*sigmoid(2*x) - 1
for (int i = 0; i < input_size; ++i, ptr_input_data++, ptr_output_data++) {
int32_t input_data = *ptr_input_data;
// Scale by 3/4 to expand range [-8,8]->[-10.7,10.7] and
// we do interpolation on unsigned values.
uint32_t abs_input_data = 3 * abs(input_data);
// We divide by 2 power of 9, because
// we need to divide by 2 in power of 7 for
// the input conversion + 1/4 from the scale above.
uint8_t uh = abs_input_data >> 9;
uint32_t ua = sigmoid_table_uint16[uh];
uint32_t ub = sigmoid_table_uint16[uh + 1];
uint32_t ut = abs_input_data & 0x1ff;
// Interpolation is done using the fractional bit.
uint32_t result = (ua << 9) + ut * (ub - ua);
result = (input_data >= 0) ? (result + (1 << 9))
: ((1 << (16 + 9)) - result + (1 << 9) - 1);
// Back to 16-bit.
result >>= 10;
*ptr_output_data = result;
}
}
} // namespace reference_integer_ops
} // namespace tflite

View File

@ -57,6 +57,45 @@ inline void Tanh(int32_t input_zero_point, int32_t input_range_radius,
}
}
inline void Tanh(int32_t input_left_shift, int32_t input_size,
const int16_t* ptr_input_data, int16_t* ptr_output_data) {
// We use the LUT for sigmoid and take into account, that
// tanh(x) = 2*sigmoid(2*x) - 1
for (int i = 0; i < input_size; ++i, ptr_input_data++, ptr_output_data++) {
int32_t input_data = *ptr_input_data;
if (input_left_shift == 1) {
input_data <<= 1;
}
// Scale by 3/4 to expand range [-8,8]->[-10.7,10.7].
uint32_t abs_input_data = 3 * abs(input_data);
uint32_t uh = abs_input_data >> 8;
int32_t result;
if (uh >= 255) {
// Saturate to maximum.
result = 0xFFFF << 8;
} else {
uint32_t ua = sigmoid_table_uint16[uh];
uint32_t ub = sigmoid_table_uint16[uh + 1];
uint8_t ut = abs_input_data & 0xFF;
result = (ua << 8) + ut * (ub - ua);
}
result = (input_data >= 0)
? (result - (1 << (14 + 9)) + (1 << (9 - 2)))
: (-result + (1 << (14 + 9)) + (1 << (9 - 2)) - 1);
// Convert back to 16-bit.
result >>= (9 - 1);
*ptr_output_data = result;
}
}
} // namespace reference_integer_ops
} // namespace tflite