Create int8 log softmax.
PiperOrigin-RevId: 234656496
parent 0b4cfb42a6
commit 8e0b3713a4
tensorflow/lite/kernels/activations.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
@@ -270,8 +271,13 @@ TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output = GetOutput(context, node, 0);
   TF_LITE_ENSURE_EQ(context, input->type, output->type);
 
-  if (input->type == kTfLiteUInt8) {
-    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255);
+  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
+    if (input->type == kTfLiteUInt8) {
+      TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255);
+    }
+    if (input->type == kTfLiteInt8) {
+      TF_LITE_ENSURE_EQ(context, output->params.zero_point, 127);
+    }
     TF_LITE_ENSURE_EQ(context, output->params.scale, 16.0 / 256);
 
     static const double kBeta = 1.0;
@@ -854,6 +860,21 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
       }
       return kTfLiteOk;
     }
+    case kTfLiteInt8: {
+      const auto input_shape = GetTensorShape(input);
+      const auto output_shape = GetTensorShape(output);
+      const int trailing_dim = input_shape.DimensionsCount() - 1;
+      const int outer_size =
+          MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+      const int depth =
+          MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+      reference_integer_ops::LogSoftmax(
+          data->input_multiplier, data->input_left_shift,
+          data->reverse_scaling_divisor, data->reverse_scaling_right_shift,
+          data->diff_min, outer_size, depth, GetTensorData<int8_t>(input),
+          GetTensorData<int8_t>(output));
+      return kTfLiteOk;
+    }
     default:
       context->ReportError(context, "Only float32 supported currently., got %s",
                            TfLiteTypeGetName(input->type));
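For reference, the Prepare checks above pin the LogSoftmax output quantization to scale 16/256 with zero point 255 (uint8) or 127 (int8), so the representable output range is roughly [-255/16, 0]. A minimal standalone sketch of that mapping (not part of this commit; the function names are illustrative):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative only: the fixed int8 LogSoftmax output quantization enforced
// in LogSoftmaxPrepare. real_value = (q - zero_point) * scale, with
// scale = 16/256 = 1/16 and zero_point = 127.
float DequantizeInt8LogSoftmax(int8_t q) {
  return (q - 127) * (16.0f / 256.0f);
}

int8_t QuantizeInt8LogSoftmax(float x) {
  const int v = static_cast<int>(std::round(x * 16.0f)) + 127;
  return static_cast<int8_t>(std::min(127, std::max(-128, v)));
}

int main() {
  std::printf("%d\n", QuantizeInt8LogSoftmax(0.0f));    // 127: log-softmax maximum
  std::printf("%f\n", DequantizeInt8LogSoftmax(-128));  // -15.9375: range floor
  return 0;
}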
tensorflow/lite/kernels/activations_test.cc
@@ -772,7 +772,7 @@ TEST(FloatActivationsOpTest, LogSoftmax) {
                              })));
 }
 
-TEST(QuantizedActivationsOpTest, LogSoftmax) {
+TEST(QuantizedActivationsOpTest, LogSoftmaxUint8) {
   const float kLogSoftmaxQuantizedTolerance = 16 / 256.0;
   QuantizedActivationsOpModel m(
       BuiltinOperator_LOG_SOFTMAX,
@@ -794,6 +794,30 @@ TEST(QuantizedActivationsOpTest, LogSoftmax) {
               ElementsAreArray({189, 93, 221, 253, 142, 63, 255, 111}));
 }
 
+TEST(QuantizedActivationsOpTest, LogSoftmaxInt8) {
+  const float kLogSoftmaxQuantizedTolerance = 0.06355;
+  QuantizedActivationsOpModel m(
+      BuiltinOperator_LOG_SOFTMAX,
+      /*input=*/{TensorType_INT8, {2, 4}, -10, 10},
+      /*output=*/{TensorType_INT8, {}, 0, 0, 16. / 256, 127});
+  m.SetInput<int8_t>({
+      0, -6, 2, 4,   //
+      3, -2, 10, 1,  //
+  });
+  m.Invoke();
+  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      -4.14297, -10.14297, -2.14297, -.142971,    //
+                      -7.00104, -12.00104, -.00104087, -9.00104,  //
+                  },
+                  kLogSoftmaxQuantizedTolerance)));
+  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({
+                                         61, -36, 93, 125,   //
+                                         15, -65, 127, -16,  //
+                                     }));
+}
+
 // A base class of PRelu op model. It provides the constructor for
 // FloatPReluOpModel and QuantizedPReluOpModel.
 class BasePReluOpModel : public SingleOpModel {
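The float expectations in LogSoftmaxInt8 follow directly from log_softmax(x_i) = x_i - log(sum_j exp(x_j)); for the first row {0, -6, 2, 4}, the log-sum-exp is about 4.14297. A quick sketch reproducing them (not part of the test file):

#include <cmath>
#include <cstdio>

// Recomputes the float reference values asserted above for row {0, -6, 2, 4}.
int main() {
  const float row[4] = {0.f, -6.f, 2.f, 4.f};
  float sum_exp = 0.f;
  for (float v : row) sum_exp += std::exp(v);
  const float log_sum_exp = std::log(sum_exp);  // ~4.14297
  for (float v : row) std::printf("%f ", v - log_sum_exp);
  std::printf("\n");  // approximately: -4.14297 -10.14297 -2.14297 -0.142971
  return 0;
}

The quantized expectations are the kernel's own int8 outputs, which can differ by one quantization step from naive rounding of these floats; that is why the test checks the dequantized values against kLogSoftmaxQuantizedTolerance rather than exact equality.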
tensorflow/lite/kernels/internal/BUILD
@@ -311,6 +311,7 @@ cc_library(
         "reference/integer_ops/depthwise_conv.h",
         "reference/integer_ops/dequantize.h",
         "reference/integer_ops/fully_connected.h",
+        "reference/integer_ops/log_softmax.h",
         "reference/integer_ops/logistic.h",
         "reference/integer_ops/mul.h",
         "reference/integer_ops/pooling.h",
@@ -652,7 +653,7 @@ cc_test(
     srcs = [
         "logsoftmax_quantized_test.cc",
     ],
-    shard_count = 3,
+    shard_count = 4,
     tags = [
        # TODO(b/122242739): Reenable after fixing the flakiness?
        "nomac",
tensorflow/lite/kernels/internal/common.h
@@ -131,6 +131,221 @@ int CountLeadingZeros(T integer_input) {
 #endif
 }
 
+// TODO(b/77858996): Add these to gemmlowp.
+template <typename IntegerType>
+IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {
+  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+  return a;
+}
+
+template <>
+inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) {
+  std::int64_t a64 = a;
+  std::int64_t b64 = b;
+  std::int64_t sum = a64 + b64;
+  return static_cast<std::int32_t>(std::min(
+      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
+      std::max(
+          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
+          sum)));
+}
+
+template <typename tRawType, int tIntegerBits>
+gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingAddNonGemmlowp(
+    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
+    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
+  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
+      SaturatingAddNonGemmlowp(a.raw(), b.raw()));
+}
+
+template <typename IntegerType>
+IntegerType SaturatingSub(IntegerType a, IntegerType b) {
+  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+  return a;
+}
+
+template <>
+inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) {
+  std::int32_t a32 = a;
+  std::int32_t b32 = b;
+  std::int32_t diff = a32 - b32;
+  return static_cast<std::int16_t>(std::min(32767, std::max(-32768, diff)));
+}
+
+template <>
+inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) {
+  std::int64_t a64 = a;
+  std::int64_t b64 = b;
+  std::int64_t diff = a64 - b64;
+  return static_cast<std::int32_t>(std::min(
+      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
+      std::max(
+          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
+          diff)));
+}
+
+template <typename tRawType, int tIntegerBits>
+gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
+    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
+    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
+  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
+      SaturatingSub(a.raw(), b.raw()));
+}
+// End section to be moved to gemmlowp.
+
+template <typename IntegerType>
+IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
+  if (exponent == 0) {
+    return x;
+  }
+  using ScalarIntegerType =
+      typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
+  const IntegerType min =
+      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
+  const IntegerType max =
+      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
+  const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
+
+  const std::int32_t threshold =
+      ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
+  const IntegerType positive_mask =
+      gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
+  const IntegerType negative_mask =
+      gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));
+
+  IntegerType result = gemmlowp::ShiftLeft(x, exponent);
+  result = gemmlowp::SelectUsingMask(positive_mask, max, result);
+  result = gemmlowp::SelectUsingMask(negative_mask, min, result);
+  return result;
+}
+
+// If we want to leave IntegerBits fixed, then multiplication
+// by a power of two has to be saturating/rounding, not exact anymore.
+template <typename tRawType, int tIntegerBits>
+gemmlowp::FixedPoint<tRawType, tIntegerBits>
+SaturatingRoundingMultiplyByPOTParam(
+    gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
+  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
+      SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
+}
+
+// Minimum output bits to accommodate log of maximum input range. It actually
+// does not matter if one considers, say, [-64,64] or [-64,64).
+//
+// For example, run this through Octave:
+// [0:127; ...
+//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
+//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
+constexpr int min_log_x_output_bits(int input_bits) {
+  return input_bits > 90
+             ? 7
+             : input_bits > 44
+                   ? 6
+                   : input_bits > 21
+                         ? 5
+                         : input_bits > 10
+                               ? 4
+                               : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1;
+}
+
+// Although currently the name of this function says that it cannot handle
+// values less than 1, in practice it can handle as low as 1/x_max, where
+// x_max is the largest representable input. In other words, the output range
+// is symmetric.
+template <int OutputIntegerBits, int InputIntegerBits>
+inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
+log_x_for_x_greater_than_or_equal_to_1_impl(
+    gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
+  // assert(__builtin_clz(0u) >= std::numeric_limits<uint32>::digits - 1);
+  // assert(__builtin_clz(0u) <= std::numeric_limits<uint32>::digits);
+  using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+  // The reason for accumulating the result with an extra bit of headroom is
+  // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
+  // recip_denom will otherwise introduce an error.
+  static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
+  using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumIntegerBits>;
+
+  const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 1488522236, std::log(2.0));
+  const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
+  const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 1518500250, std::sqrt(0.5));
+  const FixedPoint0 one_quarter =
+      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);
+
+  const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
+  const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
+  const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 1057819769,
+      2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
+  const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));
+
+  const FixedPointAccum shifted_quarter =
+      gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);
+
+  // Reinterpret the input value as Q0.31, because we will figure out the
+  // required shift "ourselves" instead of using, say, Rescale.
+  FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
+  // z_a_pow_2 = input_integer_bits - z_a_headroom;
+  int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32>(z_a.raw()));
+  FixedPoint0 r_a_tmp =
+      SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
+  const int32 r_a_raw =
+      SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
+  // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
+  // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
+  //                   InputIntegerBits - z_b_headroom - 0.25);
+  const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
+      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
+          InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)),
+      shifted_quarter);
+
+  // z_b is treated like z_a, but premultiplying by sqrt(0.5).
+  FixedPoint0 z_b = z_a * sqrt_half;
+  int z_b_headroom = CountLeadingZeros(static_cast<uint32>(z_b.raw())) - 1;
+  const int32 r_b_raw =
+      SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
+  const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
+      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
+          InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)),
+      shifted_quarter);
+
+  const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
+  const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
+      std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));
+
+  const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
+  FixedPoint0 q = r - sqrt_sqrt_half;
+  q = q + q;
+
+  const FixedPoint0 common_sq = q * q;
+  const FixedPoint0 num = q * r + q * common_sq * alpha_n;
+  const FixedPoint0 denom_minus_one_0 =
+      p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
+  const FixedPoint0 recip_denom =
+      one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);
+
+  const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
+  return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
+                                              num_scaled * recip_denom);
+}
+
+template <int OutputIntegerBits, int InputIntegerBits>
+inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
+log_x_for_x_greater_than_or_equal_to_1(
+    gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
+  static_assert(
+      OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
+      "Output integer bits must be sufficent to accommodate logs of inputs.");
+  return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
+                                                     InputIntegerBits>(
+      input_val);
+}
+
 inline int32 GetReciprocal(int32 x, int x_integer_digits,
                            int* num_bits_over_unit) {
   int headroom_plus_one = CountLeadingZeros(static_cast<uint32>(x));
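The saturating helpers added above clamp at the type bounds instead of wrapping. A self-contained analogue of the int16 SaturatingSub specialization (not part of the commit):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Widen to 32 bits, subtract, then clamp to the int16 range; plain int16
// arithmetic would wrap around on overflow.
int16_t SaturatingSubInt16(int16_t a, int16_t b) {
  const int32_t diff = static_cast<int32_t>(a) - static_cast<int32_t>(b);
  return static_cast<int16_t>(std::min(32767, std::max(-32768, diff)));
}

int main() {
  std::printf("%d\n", SaturatingSubInt16(32767, -1));  // 32767, not a wrap to -32768
  std::printf("%d\n", SaturatingSubInt16(-32768, 1));  // -32768, not a wrap to 32767
  return 0;
}

min_log_x_output_bits mirrors the Octave expression in its comment: for input_bits = 5, log(2^5) is about 3.47, and the function returns 3, the smallest OutputIntegerBits that the static_assert in log_x_for_x_greater_than_or_equal_to_1 will accept.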
@ -121,8 +121,7 @@ void RunSingleTest(const std::vector<int32>& test_input,
|
||||
const string& check_label, int tolerance) {
|
||||
const int n = test_input.size();
|
||||
std::vector<int32> float_gen_output(n, 0);
|
||||
std::vector<int32> reference_output(n, 0);
|
||||
std::vector<int32> optimized_output(n, 0);
|
||||
std::vector<int32> quantized_output(n, 0);
|
||||
|
||||
// Workaround the stupid things that intelligent humans do.
|
||||
// Consequence of __builtin_clz(0u) may equal 31 instead of 32.
|
||||
@ -132,45 +131,21 @@ void RunSingleTest(const std::vector<int32>& test_input,
|
||||
}
|
||||
|
||||
for (int i = 0; i < n; ++i) {
|
||||
reference_output[i] =
|
||||
tflite::reference_ops::log_x_for_x_greater_than_or_equal_to_1_impl<
|
||||
OutputIntegerBits, InputIntegerBits>(
|
||||
gemmlowp::FixedPoint<int32, InputIntegerBits>::FromRaw(
|
||||
fudged_input[i]))
|
||||
.raw();
|
||||
optimized_output[i] =
|
||||
tflite::optimized_ops::log_x_for_x_greater_than_or_equal_to_1_impl<
|
||||
OutputIntegerBits, InputIntegerBits>(
|
||||
quantized_output[i] =
|
||||
tflite::log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
|
||||
InputIntegerBits>(
|
||||
gemmlowp::FixedPoint<int32, InputIntegerBits>::FromRaw(
|
||||
fudged_input[i]))
|
||||
.raw();
|
||||
float_gen_output[i] = LogPositiveValuesViaFloat(
|
||||
fudged_input[i], InputIntegerBits, OutputIntegerBits);
|
||||
}
|
||||
// Note that first check is intolerant.
|
||||
{
|
||||
std::ostringstream label;
|
||||
label << check_label << " / optimized vs reference / InputIntegerBits="
|
||||
<< InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits;
|
||||
CheckOutputData(
|
||||
optimized_output, reference_output, test_input, label.str(),
|
||||
InputIntegerBits, OutputIntegerBits, 0);
|
||||
}
|
||||
{
|
||||
std::ostringstream label;
|
||||
label << check_label << " / reference vs float-gen / InputIntegerBits="
|
||||
<< InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits;
|
||||
CheckOutputData(
|
||||
reference_output, float_gen_output, test_input, label.str(),
|
||||
InputIntegerBits, OutputIntegerBits, tolerance);
|
||||
}
|
||||
{
|
||||
std::ostringstream label;
|
||||
label << check_label << " optimized vs float-gen / InputIntegerBits="
|
||||
<< InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits;
|
||||
CheckOutputData(
|
||||
optimized_output, float_gen_output, test_input, label.str(),
|
||||
InputIntegerBits, OutputIntegerBits, tolerance);
|
||||
CheckOutputData(quantized_output, float_gen_output, test_input, label.str(),
|
||||
InputIntegerBits, OutputIntegerBits, tolerance);
|
||||
}
|
||||
}
|
||||
|
||||
|
tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -25,6 +25,8 @@ limitations under the License.
 #include <gtest/gtest.h>
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/lite/kernels/internal/test_util.h"
 #include "tensorflow/lite/string.h"
@@ -61,7 +63,42 @@ void RunLogSoftmaxFloatReference(const uint8* input_data,
   }
 }
 
-void CheckOutputData(const uint8* test_output, const uint8* reference_output,
+// Same as above except for the following change:
+// - input and output data type
+// - Dequnatize function
+// - clamping values
+void RunLogSoftmaxFloatReference(const int8* input_data,
+                                 const RuntimeShape& shape_common,
+                                 int32 input_offset, const double input_scale,
+                                 int stride, float beta,
+                                 int8* reference_output_data) {
+  const int ref_buffer_size = shape_common.FlatSize();
+  std::vector<float> reference_dequant_data(ref_buffer_size);
+  std::vector<float> reference_output_float_data(ref_buffer_size);
+
+  // Reference data generated via Dequant of input into float, and then applying
+  // float LogSoftmax.
+  DequantizationParams dq_params;
+  dq_params.zero_point = input_offset;
+  dq_params.scale = input_scale;
+  reference_integer_ops::Dequantize(dq_params, shape_common, input_data,
+                                    shape_common,
+                                    reference_dequant_data.data());
+  SoftmaxParams sm_params;
+  optimized_ops::LogSoftmax(sm_params, shape_common,
+                            reference_dequant_data.data(), shape_common,
+                            reference_output_float_data.data());
+  // Work with quantized scaling for LogSoftmax, under which 255 represents 0,
+  // and -16 gets nudged up to 0.
+  for (int i = 0; i < ref_buffer_size; i++) {
+    reference_output_data[i] = std::max(
+        -128, static_cast<int>(
+                  127 + std::round(16.0f * reference_output_float_data[i])));
+  }
+}
+
+template <typename T>
+void CheckOutputData(const T* test_output, const T* reference_output,
                      const RuntimeShape& shape_common,
                      const string& check_label, bool be_exacting) {
   const int buffer_size = shape_common.FlatSize();
@@ -144,15 +181,58 @@ void RunOneLogSoftmaxTest(const uint8* input_data,
   reference_ops::LogSoftmax(params, shape_common, input_data, shape_common,
                             reference_quant_logsoftmax_output.data());
 
-  CheckOutputData(optimized_logsoftmax_output.data(),
-                  reference_float_logsoftmax_output.data(), shape_common,
-                  "Optimized vs float reference", false);
-  CheckOutputData(optimized_logsoftmax_output.data(),
-                  reference_quant_logsoftmax_output.data(), shape_common,
-                  "Optimized vs quant reference", true);
-  CheckOutputData(reference_quant_logsoftmax_output.data(),
-                  reference_float_logsoftmax_output.data(), shape_common,
-                  "Quant reference vs float reference", false);
+  CheckOutputData<uint8_t>(optimized_logsoftmax_output.data(),
+                           reference_float_logsoftmax_output.data(),
+                           shape_common, "Optimized vs float reference", false);
+  CheckOutputData<uint8_t>(optimized_logsoftmax_output.data(),
+                           reference_quant_logsoftmax_output.data(),
+                           shape_common, "Optimized vs quant reference", true);
+  CheckOutputData<uint8_t>(reference_quant_logsoftmax_output.data(),
+                           reference_float_logsoftmax_output.data(),
+                           shape_common, "Quant reference vs float reference",
+                           false);
+}
+
+// Runs the LogSoftmax and compares against the float reference implementation
+// and the int8 quantized reference implementation.
+void RunOneLogSoftmaxTest(const int8* input_data,
+                          const RuntimeShape& shape_common, int32 input_offset,
+                          const double input_scale, int stride, float beta) {
+  const int buffer_size = shape_common.FlatSize();
+  std::vector<int8> quantized_logsoftmax_reference_implementation(buffer_size);
+  std::vector<int8> float_logsoftmax_optimized_implementation(buffer_size);
+
+  RunLogSoftmaxFloatReference(input_data, shape_common, input_offset,
+                              input_scale, stride, beta,
+                              float_logsoftmax_optimized_implementation.data());
+
+  int32 input_beta_multiplier;
+  int input_beta_left_shift;
+  int32 reverse_scaling_divisor;
+  int reverse_scaling_right_shift;
+  static const int kScaledDiffIntegerBits = 5;
+  tflite::PreprocessLogSoftmaxScalingExp(
+      beta, input_scale, kScaledDiffIntegerBits, &input_beta_multiplier,
+      &input_beta_left_shift, &reverse_scaling_divisor,
+      &reverse_scaling_right_shift);
+  reverse_scaling_right_shift *= -1;
+  // diff_min has a negative value, and is used to limit the maximum magnitude
+  // of the diffs, which are <= 0.
+  const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
+                                                     input_beta_left_shift);
+
+  const int outer_size =
+      shape_common.Dims(0) * shape_common.Dims(1) * shape_common.Dims(2);
+  const int inner_size = shape_common.Dims(3);
+  reference_integer_ops::LogSoftmax(
+      input_beta_multiplier, input_beta_left_shift, reverse_scaling_divisor,
+      reverse_scaling_right_shift, diff_min, outer_size, inner_size, input_data,
+      quantized_logsoftmax_reference_implementation.data());
+
+  CheckOutputData<int8_t>(quantized_logsoftmax_reference_implementation.data(),
+                          float_logsoftmax_optimized_implementation.data(),
+                          shape_common, "Quant reference vs float reference",
+                          false);
 }
 
 // This function picks some random LogSoftmax params, which are checked for
@@ -161,6 +241,7 @@ void RunOneLogSoftmaxTest(const uint8* input_data,
 // to loop until a test has been run.
 //
 // Currently we do not reject for any reason.
+template <typename T>
 bool TryOneUniformLogSoftmax() {
   // We pick mostly positive values, on the whole emphasizing smaller values and
   // therefore faster tests. We test a wider range of depths. In the case of
@@ -178,7 +259,7 @@ bool TryOneUniformLogSoftmax() {
       RuntimeShape({batch, input_height, input_width, input_depth});
   const int buffer_size = shape_common.FlatSize();
 
-  std::vector<uint8> input_data(buffer_size);
+  std::vector<T> input_data(buffer_size);
   FillRandom(&input_data);
   RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset,
                        input_scale, stride, beta);
@@ -224,15 +305,23 @@ bool TryOneSkyscraperLogSoftmax(bool small_depth) {
   return true;
 }
 
-TEST(TestQuantizedLogSoftmax, UniformLogSoftmaxTests) {
+TEST(TestQuantizedLogSoftmax, UniformLogSoftmaxUint8Tests) {
   const int kTestsToRun = 100;
   for (int i = 0; i < kTestsToRun; i++) {
-    while (!TryOneUniformLogSoftmax()) {
+    while (!TryOneUniformLogSoftmax<uint8_t>()) {
     }
   }
 }
 
-TEST(TestQuantizedLogSoftmax, SkyscraperLogSoftmaxTests) {
+TEST(TestQuantizedLogSoftmax, UniformLogSoftmaxUint8Int8Tests) {
+  const int kTestsToRun = 100;
+  for (int i = 0; i < kTestsToRun; i++) {
+    while (!TryOneUniformLogSoftmax<int8_t>()) {
+    }
+  }
+}
+
+TEST(TestQuantizedLogSoftmax, SkyscraperLogSoftmaxUint8Tests) {
   const int kTestsToRun = 100;
   for (int i = 0; i < kTestsToRun; i++) {
     while (!TryOneSkyscraperLogSoftmax(false)) {
@@ -240,7 +329,7 @@ TEST(TestQuantizedLogSoftmax, SkyscraperLogSoftmaxTests) {
   }
 }
 
-TEST(TestQuantizedLogSoftmax, SmallSkyscraperLogSoftmaxTests) {
+TEST(TestQuantizedLogSoftmax, SmallSkyscraperLogSoftmaxUint8Tests) {
   const int kTestsToRun = 100;
   for (int i = 0; i < kTestsToRun; i++) {
     while (!TryOneSkyscraperLogSoftmax(true)) {
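In the int8 reference above, float log-softmax outputs are quantized with scale 1/16 and zero point 127, clamping at the int8 minimum: everything at or below -255/16 collapses to -128. A one-function sketch of that loop (not part of the test file):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Mirrors the quantization loop in the int8 RunLogSoftmaxFloatReference:
// q = max(-128, 127 + round(16 * x)) for a float log-softmax output x <= 0.
int8_t QuantizeReferenceOutput(float x) {
  return static_cast<int8_t>(
      std::max(-128, static_cast<int>(127 + std::round(16.0f * x))));
}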
tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -182,45 +182,6 @@ MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
   return MatrixMap<Scalar>(data, rows, cols);
 }
 
-// This is like the template-parameter version, except that the power-of-two is
-// passed as a function parameter. The template version is to be preferred,
-// since some target hardware optimizations depend on the range of the exponent.
-template <typename IntegerType>
-IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
-  if (exponent == 0) {
-    return x;
-  }
-  using ScalarIntegerType =
-      typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
-  const IntegerType min =
-      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
-  const IntegerType max =
-      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
-  const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
-
-  const std::int32_t threshold =
-      ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
-  const IntegerType positive_mask =
-      gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
-  const IntegerType negative_mask =
-      gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));
-
-  IntegerType result = gemmlowp::ShiftLeft(x, exponent);
-  result = gemmlowp::SelectUsingMask(positive_mask, max, result);
-  result = gemmlowp::SelectUsingMask(negative_mask, min, result);
-  return result;
-}
-
-// This is like the template-parameter version, except that the power-of-two is
-// passed as a function parameter. See raw-integer version for further comments.
-template <typename tRawType, int tIntegerBits>
-gemmlowp::FixedPoint<tRawType, tIntegerBits>
-SaturatingRoundingMultiplyByPOTParam(
-    gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
-  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
-      SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
-}
-
 inline void AddBiasAndEvalActivationFunction(float output_activation_min,
                                              float output_activation_max,
                                              const RuntimeShape& bias_shape,
@@ -4557,119 +4518,6 @@ inline void LogSoftmax(const SoftmaxParams& params,
   }
 }
 
-template <int OutputIntegerBits, int InputIntegerBits>
-inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
-log_x_for_x_greater_than_or_equal_to_1_impl(
-    gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
-  // assert(__builtin_clz(0u) >= std::numeric_limits<uint32>::digits - 1);
-  // assert(__builtin_clz(0u) <= std::numeric_limits<uint32>::digits);
-  using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
-  // The reason for accumulating the result with an extra bit of headroom is
-  // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
-  // recip_denom will otherwise introduce an error.
-  static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
-  using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumIntegerBits>;
-
-  const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
-      FixedPoint0, 1488522236, std::log(2.0));
-  const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
-      FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
-  const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
-      FixedPoint0, 1518500250, std::sqrt(0.5));
-  const FixedPoint0 one_quarter =
-      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);
-
-  const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
-      FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
-  const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
-      FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
-  const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
-      FixedPoint0, 1057819769,
-      2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
-  const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
-      FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));
-
-  const FixedPointAccum shifted_quarter =
-      gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);
-
-  // Reinterpret the input value as Q0.31, because we will figure out the
-  // required shift "ourselves" instead of using, say, Rescale.
-  FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
-  // z_a_pow_2 = input_integer_bits - z_a_headroom;
-  int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32>(z_a.raw()));
-  FixedPoint0 r_a_tmp =
-      SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
-  const int32 r_a_raw =
-      SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
-  // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
-  // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
-  //                   InputIntegerBits - z_b_headroom - 0.25);
-  const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
-      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
-          InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)),
-      shifted_quarter);
-
-  // z_b is treated like z_a, but premultiplying by sqrt(0.5).
-  FixedPoint0 z_b = z_a * sqrt_half;
-  int z_b_headroom = CountLeadingZeros(static_cast<uint32>(z_b.raw())) - 1;
-  const int32 r_b_raw =
-      SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
-  const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
-      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
-          InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)),
-      shifted_quarter);
-
-  const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
-  const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
-      std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));
-
-  const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
-  FixedPoint0 q = r - sqrt_sqrt_half;
-  q = q + q;
-
-  const FixedPoint0 common_sq = q * q;
-  const FixedPoint0 num = q * r + q * common_sq * alpha_n;
-  const FixedPoint0 denom_minus_one_0 =
-      p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
-  const FixedPoint0 recip_denom =
-      one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);
-
-  const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
-  return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
-                                              num_scaled * recip_denom);
-}
-
-// Minimum output bits to accommodate log of maximum input range. It actually
-// does not matter if one considers, say, [-64,64] or [-64,64).
-//
-// For example, run this through Octave:
-// [0:127; ...
-//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
-//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
-constexpr int min_log_x_output_bits(int input_bits) {
-  return input_bits > 90
-             ? 7
-             : input_bits > 44
-                   ? 6
-                   : input_bits > 21
-                         ? 5
-                         : input_bits > 10
-                               ? 4
-                               : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1;
-}
-
-template <int OutputIntegerBits, int InputIntegerBits>
-inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
-log_x_for_x_greater_than_or_equal_to_1(
-    gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
-  static_assert(
-      OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
-      "Output integer bits must be sufficent to accommodate logs of inputs.");
-  return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
-                                                     InputIntegerBits>(
-      input_val);
-}
-
 // Currently just a copy of the reference code.
 inline void LogSoftmax(const SoftmaxParams& params,
                        const RuntimeShape& input_shape, const uint8* input_data,
tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h (new file)
@@ -0,0 +1,111 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOG_SOFTMAX_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOG_SOFTMAX_H_
+
+#include "tensorflow/lite/kernels/internal/common.h"
+
+namespace tflite {
+namespace reference_integer_ops {
+
+inline void LogSoftmax(int32_t input_multiplier, int32_t input_shift,
+                       int32_t reverse_multiplier, int32_t reverse_shift,
+                       int32_t diff_min, int32_t outer_size, int32_t depth,
+                       const int8* input_data, int8* output_data) {
+  static constexpr int8_t kMinInt8 = std::numeric_limits<int8_t>::min();
+  static constexpr int8_t kMaxInt8 = std::numeric_limits<int8_t>::max();
+  static constexpr int32_t kMinInt32 = std::numeric_limits<int32_t>::min();
+
+  // [-16, 0] is mapped to [-128, 127] with 1/16 as scale and 127 as zero
+  // point. This nudges the output to [-255/16, 0].
+  static constexpr int32_t kOutputZeroPoint = 127;
+
+  // All IntegerBits must agree with Prepare function.
+  // Input is chosen as Q5.26 so exp(-1 * 2^5 * 2^-1) = exp(-16) is negligible.
+  static constexpr int kInputIntegerBits = 5;
+  static constexpr int kAccumulationIntegerBits = 12;
+  static constexpr int kOutputIntegerBits = 4;
+  using F5 = gemmlowp::FixedPoint<int32, kInputIntegerBits>;
+  using F12 = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
+
+  for (int outer_index = 0; outer_index < outer_size; ++outer_index) {
+    int8 max_in_row = kMinInt8;
+    for (int inner_index = 0; inner_index < depth; ++inner_index) {
+      max_in_row =
+          std::max(max_in_row, input_data[outer_index * depth + inner_index]);
+    }
+
+    // Accumulator "sum_of_exps_in_q12" is safe from overflowing in 2^12 steps.
+    F12 sum_of_exps_in_q12 = F12::FromRaw(0);
+    for (int inner_index = 0; inner_index < depth; ++inner_index) {
+      int32_t input_diff =
+          static_cast<int32_t>(input_data[outer_index * depth + inner_index]) -
+          max_in_row;
+      if (input_diff >= diff_min) {
+        const int32_t input_diff_in_q5 = MultiplyByQuantizedMultiplier(
+            input_diff, input_multiplier, input_shift);
+        sum_of_exps_in_q12 =
+            sum_of_exps_in_q12 +
+            gemmlowp::Rescale<kAccumulationIntegerBits>(
+                exp_on_negative_values(F5::FromRaw(input_diff_in_q5)));
+      }
+    }
+
+    const int32_t log_sum_of_exps_in_q5 =
+        log_x_for_x_greater_than_or_equal_to_1<kInputIntegerBits>(
+            sum_of_exps_in_q12)
+            .raw();
+
+    // Potentially reduced the valid range. shifted_log_sum_of_exps_in_q5 is
+    // smallest representable in Q5.26 plus the log_sum_of_exps.
+    const int32_t shifted_log_sum_of_exps_in_q5 =
+        log_sum_of_exps_in_q5 + kMinInt32;
+    const int32_t adjusted_diff_min = std::max(
+        diff_min - 1,
+        MultiplyByQuantizedMultiplier(shifted_log_sum_of_exps_in_q5,
+                                      reverse_multiplier, -reverse_shift));
+
+    for (int inner_index = 0; inner_index < depth; ++inner_index) {
+      int32_t input_diff =
+          static_cast<int32_t>(input_data[outer_index * depth + inner_index]) -
+          max_in_row;
+      // Note use of > below instead of >= above.
+      if (input_diff > adjusted_diff_min) {
+        const int32_t input_diff_in_q5 = MultiplyByQuantizedMultiplier(
+            input_diff, input_multiplier, input_shift);
+
+        // Rescale and downcast.
+        int32_t output_in_q27 =
+            gemmlowp::RoundingDivideByPOT(
+                (input_diff_in_q5 - log_sum_of_exps_in_q5),
+                31 - kInputIntegerBits - kOutputIntegerBits) +
+            kOutputZeroPoint;
+
+        output_in_q27 =
+            std::max(std::min(output_in_q27, static_cast<int32_t>(kMaxInt8)),
+                     static_cast<int32_t>(kMinInt8));
+        output_data[outer_index * depth + inner_index] =
+            static_cast<int8_t>(output_in_q27);
+      } else {
+        output_data[outer_index * depth + inner_index] = kMinInt8;
+      }
+    }
+  }
+}
+
+}  // namespace reference_integer_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOG_SOFTMAX_H_
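A hedged usage sketch for the new kernel: the fixed-point parameters come from the same tflite helpers the quantized test above uses (PreprocessLogSoftmaxScalingExp, CalculateInputRadius); the wrapper name here is illustrative, not part of the library.

#include <cstdint>

#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h"

// Illustrative driver: derive multipliers/shifts for a given input scale, then
// run the int8 reference LogSoftmax over `outer_size` rows of `depth` values.
void RunInt8LogSoftmax(const int8_t* input, int8_t* output, int outer_size,
                       int depth, double input_scale) {
  static const int kScaledDiffIntegerBits = 5;  // must match the kernel's Q5.26
  int32_t input_multiplier, reverse_scaling_divisor;
  int input_left_shift, reverse_scaling_right_shift;
  tflite::PreprocessLogSoftmaxScalingExp(
      /*beta=*/1.0, input_scale, kScaledDiffIntegerBits, &input_multiplier,
      &input_left_shift, &reverse_scaling_divisor,
      &reverse_scaling_right_shift);
  reverse_scaling_right_shift *= -1;
  const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
                                                     input_left_shift);
  tflite::reference_integer_ops::LogSoftmax(
      input_multiplier, input_left_shift, reverse_scaling_divisor,
      reverse_scaling_right_shift, diff_min, outer_size, depth, input, output);
}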
tensorflow/lite/kernels/internal/reference/reference_ops.h
@@ -36,68 +36,6 @@ limitations under the License.
 
 namespace tflite {
 
-// TODO(b/77858996): Add these to gemmlowp.
-template <typename IntegerType>
-IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {
-  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
-  return a;
-}
-
-template <>
-inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) {
-  std::int64_t a64 = a;
-  std::int64_t b64 = b;
-  std::int64_t sum = a64 + b64;
-  return static_cast<std::int32_t>(std::min(
-      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
-      std::max(
-          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
-          sum)));
-}
-
-template <typename tRawType, int tIntegerBits>
-gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingAddNonGemmlowp(
-    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
-    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
-  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
-      SaturatingAddNonGemmlowp(a.raw(), b.raw()));
-}
-
-template <typename IntegerType>
-IntegerType SaturatingSub(IntegerType a, IntegerType b) {
-  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
-  return a;
-}
-
-template <>
-inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) {
-  std::int32_t a32 = a;
-  std::int32_t b32 = b;
-  std::int32_t diff = a32 - b32;
-  return static_cast<std::int16_t>(std::min(32767, std::max(-32768, diff)));
-}
-
-template <>
-inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) {
-  std::int64_t a64 = a;
-  std::int64_t b64 = b;
-  std::int64_t diff = a64 - b64;
-  return static_cast<std::int32_t>(std::min(
-      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
-      std::max(
-          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
-          diff)));
-}
-
-template <typename tRawType, int tIntegerBits>
-gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
-    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
-    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
-  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
-      SaturatingSub(a.raw(), b.raw()));
-}
-// End section to be moved to gemmlowp.
-
 namespace reference_ops {
 
 // Return true for broadcast case, false otherwise.
@@ -192,59 +130,6 @@ inline bool ProcessBroadcastShapes(const RuntimeShape& shape0,
   return true;
 }
 
-template <typename T>
-int CountLeadingZeros(T integer_input) {
-  static_assert(std::is_unsigned<T>::value,
-                "Only unsigned integer types handled.");
-  if (integer_input == 0) {
-    return std::numeric_limits<T>::digits;
-  }
-  const T one_in_leading_positive = static_cast<T>(1)
-                                    << (std::numeric_limits<T>::digits - 1);
-  int leading_zeros = 0;
-  while (integer_input < one_in_leading_positive) {
-    integer_input <<= 1;
-    ++leading_zeros;
-  }
-  return leading_zeros;
-}
-
-template <typename IntegerType>
-IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
-  if (exponent == 0) {
-    return x;
-  }
-  using ScalarIntegerType =
-      typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
-  const IntegerType min =
-      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
-  const IntegerType max =
-      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
-  const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
-
-  const std::int32_t threshold =
-      ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
-  const IntegerType positive_mask =
-      gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
-  const IntegerType negative_mask =
-      gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));
-
-  IntegerType result = gemmlowp::ShiftLeft(x, exponent);
-  result = gemmlowp::SelectUsingMask(positive_mask, max, result);
-  result = gemmlowp::SelectUsingMask(negative_mask, min, result);
-  return result;
-}
-
-// If we want to leave IntegerBits fixed, then multiplication
-// by a power of two has to be saturating/rounding, not exact anymore.
-template <typename tRawType, int tIntegerBits>
-gemmlowp::FixedPoint<tRawType, tIntegerBits>
-SaturatingRoundingMultiplyByPOTParam(
-    gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
-  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
-      SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
-}
-
 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
                  const float* input_data, const RuntimeShape& filter_shape,
                  const float* filter_data, const RuntimeShape& bias_shape,
@@ -2752,121 +2637,6 @@ inline void LogSoftmax(const SoftmaxParams& params,
   }
 }
 
-// Although currently the name of this function says that it cannot handle
-// values less than 1, in practice it can handle as low as 1/x_max, where
-// x_max is the largest representable input. In other words, the output range
-// is symmetric.
-template <int OutputIntegerBits, int InputIntegerBits>
-inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
-log_x_for_x_greater_than_or_equal_to_1_impl(
-    gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
-  using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
-  // The reason for accumulating the result with an extra bit of headroom is
-  // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
-  // recip_denom will otherwise introduce an error.
-  static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
-  using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumIntegerBits>;
-
-  const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
-      FixedPoint0, 1488522236, std::log(2.0));
-  const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
-      FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
-  const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
-      FixedPoint0, 1518500250, std::sqrt(0.5));
-  const FixedPoint0 one_quarter =
-      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);
-
-  const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
-      FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
-  const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
-      FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
-  const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
-      FixedPoint0, 1057819769,
-      2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
-  const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
-      FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));
-
-  const FixedPointAccum shifted_quarter =
-      gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);
-
-  // Reinterpret the input value as Q0.31, because we will figure out the
-  // required shift "ourselves" instead of using, say, Rescale.
-  FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
-  // z_a_pow_2 = input_integer_bits - z_a_headroom;
-  int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32>(z_a.raw()));
-  FixedPoint0 r_a_tmp =
-      SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
-  const int32 r_a_raw =
-      SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
-  // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
-  // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
-  //                   InputIntegerBits - z_b_headroom - 0.25);
-  const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
-      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
-          InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)),
-      shifted_quarter);
-
-  // z_b is treated like z_a, but premultiplying by sqrt(0.5).
-  FixedPoint0 z_b = z_a * sqrt_half;
-  int z_b_headroom = CountLeadingZeros(static_cast<uint32>(z_b.raw())) - 1;
-  const int32 r_b_raw =
-      SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
-  const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
-      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
-          InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)),
-      shifted_quarter);
-
-  const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
-  const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
-      std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));
-
-  const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
-  FixedPoint0 q = r - sqrt_sqrt_half;
-  q = q + q;
-
-  const FixedPoint0 common_sq = q * q;
-  const FixedPoint0 num = q * r + q * common_sq * alpha_n;
-  const FixedPoint0 denom_minus_one_0 =
-      p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
-  const FixedPoint0 recip_denom =
-      one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);
-
-  const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
-  return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
-                                              num_scaled * recip_denom);
-}
-
-// Minimum output bits to accommodate log of maximum input range. It actually
-// does not matter if one considers, say, [-64,64] or [-64,64).
-//
-// For example, run this through Octave:
-// [0:127; ...
-//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
-//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
-constexpr int min_log_x_output_bits(int input_bits) {
-  return input_bits > 90
-             ? 7
-             : input_bits > 44
-                   ? 6
-                   : input_bits > 21
-                         ? 5
-                         : input_bits > 10
-                               ? 4
-                               : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1;
-}
-
-template <int OutputIntegerBits, int InputIntegerBits>
-inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
-log_x_for_x_greater_than_or_equal_to_1(
-    gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
-  static_assert(
-      OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
-      "Output integer bits must be sufficent to accommodate logs of inputs.");
-  return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
-                                                     InputIntegerBits>(
-      input_val);
-}
-
 inline void LogSoftmax(const SoftmaxParams& params,
                        const RuntimeShape& input_shape, const uint8* input_data,
                        const RuntimeShape& output_shape, uint8* output_data) {
tensorflow/lite/kernels/register.cc
@@ -277,7 +277,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* min_version */ 1,
              /* max_version */ 2);
   AddBuiltin(BuiltinOperator_LOG, Register_LOG());
-  AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX());
+  AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX(),
+             /* min_version */ 1,
+             /* max_version */ 2);
   AddBuiltin(BuiltinOperator_CAST, Register_CAST());
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
              /* min_version */ 1,
tensorflow/lite/toco/tflite/operator.cc
@@ -1813,6 +1813,21 @@ class Logistic : public SimpleOperator<LogisticOperator> {
   }
 };
 
+class LogSoftmax : public SimpleOperator<LogSoftmaxOperator> {
+ public:
+  explicit LogSoftmax()
+      : SimpleOperator("LOG_SOFTMAX", OperatorType::kLogSoftmax) {}
+  int GetVersion(const OperatorSignature& op_signature) const override {
+    const string& input_name = op_signature.op->inputs[0];
+    const Array& input_array = op_signature.model->GetArray(input_name);
+    // Version 2 supports signed int8 input types.
+    if (input_array.data_type == ArrayDataType::kInt8) {
+      return 2;
+    }
+    return 1;
+  }
+};
+
 class SquaredDifference
     : public BuiltinOperator<
           SquaredDifferenceOperator, ::tflite::SquaredDifferenceOptions,
@@ -2451,8 +2466,7 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
       MakeUnique<SimpleOperator<ExpOperator>>("EXP", OperatorType::kExp));
   ops.push_back(
       MakeUnique<SimpleOperator<CosOperator>>("COS", OperatorType::kCos));
-  ops.push_back(MakeUnique<SimpleOperator<LogSoftmaxOperator>>(
-      "LOG_SOFTMAX", OperatorType::kLogSoftmax));
+  ops.push_back(MakeUnique<LogSoftmax>());
   ops.push_back(MakeUnique<Maximum>());  //  Element-wise Maximum
   ops.push_back(MakeUnique<Minimum>());  //  Element-wise Minimum
   ops.push_back(MakeUnique<Greater>());
tensorflow/lite/toco/tflite/operator_test.cc
@@ -820,6 +820,10 @@ TEST_F(OperatorTest, VersioningSpaceToBatchNDTest) {
   SimpleVersioningTest<SpaceToBatchNDOperator>();
 }
 
+TEST_F(OperatorTest, VersioningLogSoftmaxTest) {
+  SimpleVersioningTest<LogSoftmaxOperator>();
+}
+
 TEST_F(OperatorTest, VersioningPackTest) {
   SimpleVersioningTest<PackOperator>();
 }