Create int8 log softmax.

PiperOrigin-RevId: 234656496
Jian Li 2019-02-19 12:54:51 -08:00 committed by TensorFlower Gardener
parent 0b4cfb42a6
commit 8e0b3713a4
12 changed files with 510 additions and 436 deletions

View File

@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
@@ -270,8 +271,13 @@ TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
-if (input->type == kTfLiteUInt8) {
-TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255);
+if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
+if (input->type == kTfLiteUInt8) {
+TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255);
+}
+if (input->type == kTfLiteInt8) {
+TF_LITE_ENSURE_EQ(context, output->params.zero_point, 127);
+}
TF_LITE_ENSURE_EQ(context, output->params.scale, 16.0 / 256);
static const double kBeta = 1.0;
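// Illustration of the output quantization enforced above: with scale
// 16/256 = 1/16, uint8 maps [0, 255] to real values [-255/16, 0] via
// (q - 255) / 16, and int8 maps [-128, 127] to the same range via
// (q - 127) / 16, so real 0 -- the maximum log softmax can produce --
// lands exactly on the zero point in both cases.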
@@ -854,6 +860,21 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
}
return kTfLiteOk;
}
case kTfLiteInt8: {
const auto input_shape = GetTensorShape(input);
const auto output_shape = GetTensorShape(output);
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int outer_size =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
const int depth =
MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
reference_integer_ops::LogSoftmax(
data->input_multiplier, data->input_left_shift,
data->reverse_scaling_divisor, data->reverse_scaling_right_shift,
data->diff_min, outer_size, depth, GetTensorData<int8_t>(input),
GetTensorData<int8_t>(output));
return kTfLiteOk;
}
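// Shape bookkeeping illustration: for the {2, 4} int8 test added in the
// test change below, trailing_dim = 1, so outer_size = 2 (the product of
// all dimensions but the last) and depth = 4, and the reference LogSoftmax
// runs once per row of four values.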
default:
context->ReportError(context, "Only float32 supported currently., got %s",
TfLiteTypeGetName(input->type));

View File

@@ -772,7 +772,7 @@ TEST(FloatActivationsOpTest, LogSoftmax) {
})));
}
-TEST(QuantizedActivationsOpTest, LogSoftmax) {
+TEST(QuantizedActivationsOpTest, LogSoftmaxUint8) {
const float kLogSoftmaxQuantizedTolerance = 16 / 256.0;
QuantizedActivationsOpModel m(
BuiltinOperator_LOG_SOFTMAX,
@@ -794,6 +794,30 @@ TEST(QuantizedActivationsOpTest, LogSoftmax) {
ElementsAreArray({189, 93, 221, 253, 142, 63, 255, 111}));
}
TEST(QuantizedActivationsOpTest, LogSoftmaxInt8) {
const float kLogSoftmaxQuantizedTolerance = 0.06355;
QuantizedActivationsOpModel m(
BuiltinOperator_LOG_SOFTMAX,
/*input=*/{TensorType_INT8, {2, 4}, -10, 10},
/*output=*/{TensorType_INT8, {}, 0, 0, 16. / 256, 127});
m.SetInput<int8_t>({
0, -6, 2, 4, //
3, -2, 10, 1, //
});
m.Invoke();
EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
ElementsAreArray(ArrayFloatNear(
{
-4.14297, -10.14297, -2.14297, -.142971, //
-7.00104, -12.00104, -.00104087, -9.00104, //
},
kLogSoftmaxQuantizedTolerance)));
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({
61, -36, 93, 125, //
15, -65, 127, -16, //
}));
}
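// Worked check of the int8 quantization in LogSoftmaxInt8 above (scale =
// 16/256, zero point = 127): the expected output 61 dequantizes to
// (61 - 127) * 16 / 256 = -4.125, within the 0.06355 tolerance of the float
// reference -4.14297, and -65 dequantizes to (-65 - 127) / 16 = -12.0
// against -12.00104.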
// A base class of PRelu op model. It provides the constructor for
// FloatPReluOpModel and QuantizedPReluOpModel.
class BasePReluOpModel : public SingleOpModel {

View File

@@ -311,6 +311,7 @@ cc_library(
"reference/integer_ops/depthwise_conv.h",
"reference/integer_ops/dequantize.h",
"reference/integer_ops/fully_connected.h",
"reference/integer_ops/log_softmax.h",
"reference/integer_ops/logistic.h",
"reference/integer_ops/mul.h",
"reference/integer_ops/pooling.h",
@@ -652,7 +653,7 @@ cc_test(
srcs = [
"logsoftmax_quantized_test.cc",
],
-shard_count = 3,
+shard_count = 4,
tags = [
# TODO(b/122242739): Reenable after fixing the flakiness?
"nomac",

View File

@@ -131,6 +131,221 @@ int CountLeadingZeros(T integer_input) {
#endif
}
// TODO(b/77858996): Add these to gemmlowp.
template <typename IntegerType>
IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {
static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
return a;
}
template <>
inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) {
std::int64_t a64 = a;
std::int64_t b64 = b;
std::int64_t sum = a64 + b64;
return static_cast<std::int32_t>(std::min(
static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
std::max(
static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
sum)));
}
template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingAddNonGemmlowp(
gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
SaturatingAddNonGemmlowp(a.raw(), b.raw()));
}
template <typename IntegerType>
IntegerType SaturatingSub(IntegerType a, IntegerType b) {
static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
return a;
}
template <>
inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) {
std::int32_t a32 = a;
std::int32_t b32 = b;
std::int32_t diff = a32 - b32;
return static_cast<std::int16_t>(std::min(32767, std::max(-32768, diff)));
}
template <>
inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) {
std::int64_t a64 = a;
std::int64_t b64 = b;
std::int64_t diff = a64 - b64;
return static_cast<std::int32_t>(std::min(
static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
std::max(
static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
diff)));
}
template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
SaturatingSub(a.raw(), b.raw()));
}
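// Example of the saturating behavior these helpers provide:
//   SaturatingAddNonGemmlowp(std::numeric_limits<std::int32_t>::max(), 1)
//   returns std::numeric_limits<std::int32_t>::max(), and
//   SaturatingSub(std::int16_t{-30000}, std::int16_t{10000}) returns -32768.
// Plain + and - would overflow (undefined behavior for signed integers);
// these clamp to the representable range instead.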
// End section to be moved to gemmlowp.
template <typename IntegerType>
IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
if (exponent == 0) {
return x;
}
using ScalarIntegerType =
typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
const IntegerType min =
gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
const IntegerType max =
gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
const std::int32_t threshold =
((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
const IntegerType positive_mask =
gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
const IntegerType negative_mask =
gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));
IntegerType result = gemmlowp::ShiftLeft(x, exponent);
result = gemmlowp::SelectUsingMask(positive_mask, max, result);
result = gemmlowp::SelectUsingMask(negative_mask, min, result);
return result;
}
// If we want to leave IntegerBits fixed, then multiplication
// by a power of two has to be saturating/rounding, not exact anymore.
template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits>
SaturatingRoundingMultiplyByPOTParam(
gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}
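// Example of the saturation above: for a Q0.31 raw value x = 1 << 28 (real
// 0.125), SaturatingRoundingMultiplyByPOTParam(x, 3) would ideally produce
// real 1.0, which Q0.31 cannot represent; x exceeds the positive threshold
// (1 << 28) - 1, so the result clamps to the maximum raw value, just under
// 1.0.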
// Minimum output bits to accommodate log of maximum input range. It actually
// does not matter if one considers, say, [-64,64] or [-64,64).
//
// For example, run this through Octave:
// [0:127; ...
// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
constexpr int min_log_x_output_bits(int input_bits) {
return input_bits > 90
? 7
: input_bits > 44
? 6
: input_bits > 21
? 5
: input_bits > 10
? 4
: input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1;
}
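// Illustrative spot checks of the bound above, following the Octave formula
// ceil(log2(log(2^b) + 1)); these hold as compile-time asserts:
//   static_assert(min_log_x_output_bits(5) == 3, "");   // log(2^5)  ~  3.47
//   static_assert(min_log_x_output_bits(21) == 4, "");  // log(2^21) ~ 14.6
//   static_assert(min_log_x_output_bits(22) == 5, "");  // log(2^22) ~ 15.2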
// Although currently the name of this function says that it cannot handle
// values less than 1, in practice it can handle as low as 1/x_max, where
// x_max is the largest representable input. In other words, the output range
// is symmetric.
template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1_impl(
gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
// assert(__builtin_clz(0u) >= std::numeric_limits<uint32>::digits - 1);
// assert(__builtin_clz(0u) <= std::numeric_limits<uint32>::digits);
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
// The reason for accumulating the result with an extra bit of headroom is
// that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
// recip_denom will otherwise introduce an error.
static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumIntegerBits>;
const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 1488522236, std::log(2.0));
const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 1518500250, std::sqrt(0.5));
const FixedPoint0 one_quarter =
GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);
const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 1057819769,
2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));
const FixedPointAccum shifted_quarter =
gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);
// Reinterpret the input value as Q0.31, because we will figure out the
// required shift "ourselves" instead of using, say, Rescale.
FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
// z_a_pow_2 = input_integer_bits - z_a_headroom;
int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32>(z_a.raw()));
FixedPoint0 r_a_tmp =
SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
const int32 r_a_raw =
SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
// z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
// z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
// InputIntegerBits - z_b_headroom - 0.25);
const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)),
shifted_quarter);
// z_b is treated like z_a, but premultiplying by sqrt(0.5).
FixedPoint0 z_b = z_a * sqrt_half;
int z_b_headroom = CountLeadingZeros(static_cast<uint32>(z_b.raw())) - 1;
const int32 r_b_raw =
SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)),
shifted_quarter);
const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));
const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
FixedPoint0 q = r - sqrt_sqrt_half;
q = q + q;
const FixedPoint0 common_sq = q * q;
const FixedPoint0 num = q * r + q * common_sq * alpha_n;
const FixedPoint0 denom_minus_one_0 =
p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
const FixedPoint0 recip_denom =
one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);
const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
num_scaled * recip_denom);
}
template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1(
gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
static_assert(
OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
"Output integer bits must be sufficent to accommodate logs of inputs.");
return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
InputIntegerBits>(
input_val);
}
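// Usage sketch (assumes gemmlowp's FixedPoint::FromDouble and ToDouble
// helpers):
//   using F5 = gemmlowp::FixedPoint<int32, 5>;  // Q5.26 input
//   auto ln_x = log_x_for_x_greater_than_or_equal_to_1<3>(F5::FromDouble(8.0));
//   // gemmlowp::ToDouble(ln_x) is then close to std::log(8.0) = 2.0794;
//   // OutputIntegerBits = 3 satisfies min_log_x_output_bits(5) == 3 above.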
inline int32 GetReciprocal(int32 x, int x_integer_digits,
int* num_bits_over_unit) {
int headroom_plus_one = CountLeadingZeros(static_cast<uint32>(x));

View File

@@ -121,8 +121,7 @@ void RunSingleTest(const std::vector<int32>& test_input,
const string& check_label, int tolerance) {
const int n = test_input.size();
std::vector<int32> float_gen_output(n, 0);
-std::vector<int32> reference_output(n, 0);
-std::vector<int32> optimized_output(n, 0);
+std::vector<int32> quantized_output(n, 0);
// Work around __builtin_clz(0u) being undefined behavior: it may return
// 31 instead of 32, so zero inputs are fudged (see fudged_input below).
@@ -132,45 +131,21 @@ void RunSingleTest(const std::vector<int32>& test_input,
}
for (int i = 0; i < n; ++i) {
-reference_output[i] =
-tflite::reference_ops::log_x_for_x_greater_than_or_equal_to_1_impl<
-OutputIntegerBits, InputIntegerBits>(
-gemmlowp::FixedPoint<int32, InputIntegerBits>::FromRaw(
-fudged_input[i]))
-.raw();
-optimized_output[i] =
-tflite::optimized_ops::log_x_for_x_greater_than_or_equal_to_1_impl<
-OutputIntegerBits, InputIntegerBits>(
+quantized_output[i] =
+tflite::log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
+InputIntegerBits>(
gemmlowp::FixedPoint<int32, InputIntegerBits>::FromRaw(
fudged_input[i]))
.raw();
float_gen_output[i] = LogPositiveValuesViaFloat(
fudged_input[i], InputIntegerBits, OutputIntegerBits);
}
-// Note that first check is intolerant.
-{
-std::ostringstream label;
-label << check_label << " / optimized vs reference / InputIntegerBits="
-<< InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits;
-CheckOutputData(
-optimized_output, reference_output, test_input, label.str(),
-InputIntegerBits, OutputIntegerBits, 0);
-}
{
std::ostringstream label;
label << check_label << " / reference vs float-gen / InputIntegerBits="
<< InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits;
-CheckOutputData(
-reference_output, float_gen_output, test_input, label.str(),
-InputIntegerBits, OutputIntegerBits, tolerance);
-}
-{
-std::ostringstream label;
-label << check_label << " optimized vs float-gen / InputIntegerBits="
-<< InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits;
-CheckOutputData(
-optimized_output, float_gen_output, test_input, label.str(),
-InputIntegerBits, OutputIntegerBits, tolerance);
+CheckOutputData(quantized_output, float_gen_output, test_input, label.str(),
+InputIntegerBits, OutputIntegerBits, tolerance);
}
}

View File

@@ -1,4 +1,4 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -25,6 +25,8 @@ limitations under the License.
#include <gtest/gtest.h>
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/test_util.h"
#include "tensorflow/lite/string.h"
@@ -61,7 +63,42 @@ void RunLogSoftmaxFloatReference(const uint8* input_data,
}
}
-void CheckOutputData(const uint8* test_output, const uint8* reference_output,
// Same as above except for the following changes:
// - input and output data types
// - Dequantize function
// - clamping values
void RunLogSoftmaxFloatReference(const int8* input_data,
const RuntimeShape& shape_common,
int32 input_offset, const double input_scale,
int stride, float beta,
int8* reference_output_data) {
const int ref_buffer_size = shape_common.FlatSize();
std::vector<float> reference_dequant_data(ref_buffer_size);
std::vector<float> reference_output_float_data(ref_buffer_size);
// Reference data generated via Dequant of input into float, and then applying
// float LogSoftmax.
DequantizationParams dq_params;
dq_params.zero_point = input_offset;
dq_params.scale = input_scale;
reference_integer_ops::Dequantize(dq_params, shape_common, input_data,
shape_common,
reference_dequant_data.data());
SoftmaxParams sm_params;
optimized_ops::LogSoftmax(sm_params, shape_common,
reference_dequant_data.data(), shape_common,
reference_output_float_data.data());
// Work with quantized scaling for LogSoftmax, under which 127 represents 0,
// and -16 gets nudged up to -128.
for (int i = 0; i < ref_buffer_size; i++) {
reference_output_data[i] = std::max(
-128, static_cast<int>(
127 + std::round(16.0f * reference_output_float_data[i])));
}
}
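// Worked example of the clamp above: a float log softmax value of -12.00104
// maps to max(-128, 127 + round(16 * -12.00104)) = 127 - 192 = -65, matching
// the expected int8 output in the kernel test; anything below -255/16 pins
// to -128.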
+template <typename T>
+void CheckOutputData(const T* test_output, const T* reference_output,
const RuntimeShape& shape_common,
const string& check_label, bool be_exacting) {
const int buffer_size = shape_common.FlatSize();
@@ -144,15 +181,58 @@ void RunOneLogSoftmaxTest(const uint8* input_data,
reference_ops::LogSoftmax(params, shape_common, input_data, shape_common,
reference_quant_logsoftmax_output.data());
-CheckOutputData(optimized_logsoftmax_output.data(),
-reference_float_logsoftmax_output.data(), shape_common,
-"Optimized vs float reference", false);
-CheckOutputData(optimized_logsoftmax_output.data(),
-reference_quant_logsoftmax_output.data(), shape_common,
-"Optimized vs quant reference", true);
-CheckOutputData(reference_quant_logsoftmax_output.data(),
-reference_float_logsoftmax_output.data(), shape_common,
-"Quant reference vs float reference", false);
+CheckOutputData<uint8_t>(optimized_logsoftmax_output.data(),
+reference_float_logsoftmax_output.data(),
+shape_common, "Optimized vs float reference", false);
+CheckOutputData<uint8_t>(optimized_logsoftmax_output.data(),
+reference_quant_logsoftmax_output.data(),
+shape_common, "Optimized vs quant reference", true);
+CheckOutputData<uint8_t>(reference_quant_logsoftmax_output.data(),
+reference_float_logsoftmax_output.data(),
+shape_common, "Quant reference vs float reference",
+false);
}
// Runs the int8 quantized reference LogSoftmax and compares it against the
// float reference implementation.
void RunOneLogSoftmaxTest(const int8* input_data,
const RuntimeShape& shape_common, int32 input_offset,
const double input_scale, int stride, float beta) {
const int buffer_size = shape_common.FlatSize();
std::vector<int8> quantized_logsoftmax_reference_implementation(buffer_size);
std::vector<int8> float_logsoftmax_optimized_implementation(buffer_size);
RunLogSoftmaxFloatReference(input_data, shape_common, input_offset,
input_scale, stride, beta,
float_logsoftmax_optimized_implementation.data());
int32 input_beta_multiplier;
int input_beta_left_shift;
int32 reverse_scaling_divisor;
int reverse_scaling_right_shift;
static const int kScaledDiffIntegerBits = 5;
tflite::PreprocessLogSoftmaxScalingExp(
beta, input_scale, kScaledDiffIntegerBits, &input_beta_multiplier,
&input_beta_left_shift, &reverse_scaling_divisor,
&reverse_scaling_right_shift);
reverse_scaling_right_shift *= -1;
// diff_min has a negative value, and is used to limit the maximum magnitude
// of the diffs, which are <= 0.
const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
input_beta_left_shift);
const int outer_size =
shape_common.Dims(0) * shape_common.Dims(1) * shape_common.Dims(2);
const int inner_size = shape_common.Dims(3);
reference_integer_ops::LogSoftmax(
input_beta_multiplier, input_beta_left_shift, reverse_scaling_divisor,
reverse_scaling_right_shift, diff_min, outer_size, inner_size, input_data,
quantized_logsoftmax_reference_implementation.data());
CheckOutputData<int8_t>(quantized_logsoftmax_reference_implementation.data(),
float_logsoftmax_optimized_implementation.data(),
shape_common, "Quant reference vs float reference",
false);
}
// This function picks some random LogSoftmax params, which are checked for
@@ -161,6 +241,7 @@ void RunOneLogSoftmaxTest(const uint8* input_data,
// to loop until a test has been run.
//
// Currently we do not reject for any reason.
+template <typename T>
bool TryOneUniformLogSoftmax() {
// We pick mostly positive values, on the whole emphasizing smaller values and
// therefore faster tests. We test a wider range of depths. In the case of
@@ -178,7 +259,7 @@ bool TryOneUniformLogSoftmax() {
RuntimeShape({batch, input_height, input_width, input_depth});
const int buffer_size = shape_common.FlatSize();
-std::vector<uint8> input_data(buffer_size);
+std::vector<T> input_data(buffer_size);
FillRandom(&input_data);
RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset,
input_scale, stride, beta);
@@ -224,15 +305,23 @@ bool TryOneSkyscraperLogSoftmax(bool small_depth) {
return true;
}
-TEST(TestQuantizedLogSoftmax, UniformLogSoftmaxTests) {
+TEST(TestQuantizedLogSoftmax, UniformLogSoftmaxUint8Tests) {
const int kTestsToRun = 100;
for (int i = 0; i < kTestsToRun; i++) {
-while (!TryOneUniformLogSoftmax()) {
+while (!TryOneUniformLogSoftmax<uint8_t>()) {
}
}
}
-TEST(TestQuantizedLogSoftmax, SkyscraperLogSoftmaxTests) {
+TEST(TestQuantizedLogSoftmax, UniformLogSoftmaxInt8Tests) {
+const int kTestsToRun = 100;
+for (int i = 0; i < kTestsToRun; i++) {
+while (!TryOneUniformLogSoftmax<int8_t>()) {
+}
+}
+}
+TEST(TestQuantizedLogSoftmax, SkyscraperLogSoftmaxUint8Tests) {
const int kTestsToRun = 100;
for (int i = 0; i < kTestsToRun; i++) {
while (!TryOneSkyscraperLogSoftmax(false)) {
@@ -240,7 +329,7 @@ TEST(TestQuantizedLogSoftmax, SkyscraperLogSoftmaxTests) {
}
}
-TEST(TestQuantizedLogSoftmax, SmallSkyscraperLogSoftmaxTests) {
+TEST(TestQuantizedLogSoftmax, SmallSkyscraperLogSoftmaxUint8Tests) {
const int kTestsToRun = 100;
for (int i = 0; i < kTestsToRun; i++) {
while (!TryOneSkyscraperLogSoftmax(true)) {

View File

@@ -182,45 +182,6 @@ MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
return MatrixMap<Scalar>(data, rows, cols);
}
// This is like the template-parameter version, except that the power-of-two is
// passed as a function parameter. The template version is to be preferred,
// since some target hardware optimizations depend on the range of the exponent.
template <typename IntegerType>
IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
if (exponent == 0) {
return x;
}
using ScalarIntegerType =
typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
const IntegerType min =
gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
const IntegerType max =
gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
const std::int32_t threshold =
((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
const IntegerType positive_mask =
gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
const IntegerType negative_mask =
gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));
IntegerType result = gemmlowp::ShiftLeft(x, exponent);
result = gemmlowp::SelectUsingMask(positive_mask, max, result);
result = gemmlowp::SelectUsingMask(negative_mask, min, result);
return result;
}
// This is like the template-parameter version, except that the power-of-two is
// passed as a function parameter. See raw-integer version for further comments.
template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits>
SaturatingRoundingMultiplyByPOTParam(
gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}
inline void AddBiasAndEvalActivationFunction(float output_activation_min,
float output_activation_max,
const RuntimeShape& bias_shape,
@@ -4557,119 +4518,6 @@ inline void LogSoftmax(const SoftmaxParams& params,
}
}
template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1_impl(
gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
// assert(__builtin_clz(0u) >= std::numeric_limits<uint32>::digits - 1);
// assert(__builtin_clz(0u) <= std::numeric_limits<uint32>::digits);
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
// The reason for accumulating the result with an extra bit of headroom is
// that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
// recip_denom will otherwise introduce an error.
static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumIntegerBits>;
const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 1488522236, std::log(2.0));
const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 1518500250, std::sqrt(0.5));
const FixedPoint0 one_quarter =
GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);
const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 1057819769,
2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));
const FixedPointAccum shifted_quarter =
gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);
// Reinterpret the input value as Q0.31, because we will figure out the
// required shift "ourselves" instead of using, say, Rescale.
FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
// z_a_pow_2 = input_integer_bits - z_a_headroom;
int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32>(z_a.raw()));
FixedPoint0 r_a_tmp =
SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
const int32 r_a_raw =
SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
// z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
// z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
// InputIntegerBits - z_b_headroom - 0.25);
const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)),
shifted_quarter);
// z_b is treated like z_a, but premultiplying by sqrt(0.5).
FixedPoint0 z_b = z_a * sqrt_half;
int z_b_headroom = CountLeadingZeros(static_cast<uint32>(z_b.raw())) - 1;
const int32 r_b_raw =
SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)),
shifted_quarter);
const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));
const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
FixedPoint0 q = r - sqrt_sqrt_half;
q = q + q;
const FixedPoint0 common_sq = q * q;
const FixedPoint0 num = q * r + q * common_sq * alpha_n;
const FixedPoint0 denom_minus_one_0 =
p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
const FixedPoint0 recip_denom =
one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);
const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
num_scaled * recip_denom);
}
// Minimum output bits to accommodate log of maximum input range. It actually
// does not matter if one considers, say, [-64,64] or [-64,64).
//
// For example, run this through Octave:
// [0:127; ...
// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
constexpr int min_log_x_output_bits(int input_bits) {
return input_bits > 90
? 7
: input_bits > 44
? 6
: input_bits > 21
? 5
: input_bits > 10
? 4
: input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1;
}
template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1(
gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
static_assert(
OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
"Output integer bits must be sufficent to accommodate logs of inputs.");
return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
InputIntegerBits>(
input_val);
}
// Currently just a copy of the reference code.
inline void LogSoftmax(const SoftmaxParams& params,
const RuntimeShape& input_shape, const uint8* input_data,

View File

@@ -0,0 +1,111 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOG_SOFTMAX_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOG_SOFTMAX_H_
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {
namespace reference_integer_ops {
inline void LogSoftmax(int32_t input_multiplier, int32_t input_shift,
int32_t reverse_multiplier, int32_t reverse_shift,
int32_t diff_min, int32_t outer_size, int32_t depth,
const int8* input_data, int8* output_data) {
static constexpr int8_t kMinInt8 = std::numeric_limits<int8_t>::min();
static constexpr int8_t kMaxInt8 = std::numeric_limits<int8_t>::max();
static constexpr int32_t kMinInt32 = std::numeric_limits<int32_t>::min();
// [-16, 0] is mapped to [-128, 127] with 1/16 as scale and 127 as zero
// point. This nudges the output to [-255/16, 0].
static constexpr int32_t kOutputZeroPoint = 127;
// All IntegerBits values must agree with the Prepare function.
// Input is chosen as Q5.26 so exp(-1 * 2^5 * 2^-1) = exp(-16) is negligible.
static constexpr int kInputIntegerBits = 5;
static constexpr int kAccumulationIntegerBits = 12;
static constexpr int kOutputIntegerBits = 4;
using F5 = gemmlowp::FixedPoint<int32, kInputIntegerBits>;
using F12 = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
for (int outer_index = 0; outer_index < outer_size; ++outer_index) {
int8 max_in_row = kMinInt8;
for (int inner_index = 0; inner_index < depth; ++inner_index) {
max_in_row =
std::max(max_in_row, input_data[outer_index * depth + inner_index]);
}
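// Subtracting max_in_row below keeps every input_diff <= 0, so each
// exp(input_diff) lies in (0, 1] and the Q12 accumulation cannot overflow
// within 2^12 additions.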
// Accumulator "sum_of_exps_in_q12" is safe from overflowing in 2^12 steps.
F12 sum_of_exps_in_q12 = F12::FromRaw(0);
for (int inner_index = 0; inner_index < depth; ++inner_index) {
int32_t input_diff =
static_cast<int32_t>(input_data[outer_index * depth + inner_index]) -
max_in_row;
if (input_diff >= diff_min) {
const int32_t input_diff_in_q5 = MultiplyByQuantizedMultiplier(
input_diff, input_multiplier, input_shift);
sum_of_exps_in_q12 =
sum_of_exps_in_q12 +
gemmlowp::Rescale<kAccumulationIntegerBits>(
exp_on_negative_values(F5::FromRaw(input_diff_in_q5)));
}
}
const int32_t log_sum_of_exps_in_q5 =
log_x_for_x_greater_than_or_equal_to_1<kInputIntegerBits>(
sum_of_exps_in_q12)
.raw();
// This may reduce the valid output range: shifted_log_sum_of_exps_in_q5 is
// the smallest value representable in Q5.26 plus log_sum_of_exps.
const int32_t shifted_log_sum_of_exps_in_q5 =
log_sum_of_exps_in_q5 + kMinInt32;
const int32_t adjusted_diff_min = std::max(
diff_min - 1,
MultiplyByQuantizedMultiplier(shifted_log_sum_of_exps_in_q5,
reverse_multiplier, -reverse_shift));
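// Any input_diff at or below adjusted_diff_min would, after subtracting
// log_sum_of_exps and adding the zero point, fall below the int8 minimum,
// so the loop below writes kMinInt8 for it directly.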
for (int inner_index = 0; inner_index < depth; ++inner_index) {
int32_t input_diff =
static_cast<int32_t>(input_data[outer_index * depth + inner_index]) -
max_in_row;
// Note use of > below instead of >= above.
if (input_diff > adjusted_diff_min) {
const int32_t input_diff_in_q5 = MultiplyByQuantizedMultiplier(
input_diff, input_multiplier, input_shift);
// Rescale and downcast.
int32_t output_in_q27 =
gemmlowp::RoundingDivideByPOT(
(input_diff_in_q5 - log_sum_of_exps_in_q5),
31 - kInputIntegerBits - kOutputIntegerBits) +
kOutputZeroPoint;
output_in_q27 =
std::max(std::min(output_in_q27, static_cast<int32_t>(kMaxInt8)),
static_cast<int32_t>(kMinInt8));
output_data[outer_index * depth + inner_index] =
static_cast<int8_t>(output_in_q27);
} else {
output_data[outer_index * depth + inner_index] = kMinInt8;
}
}
}
}
} // namespace reference_integer_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOG_SOFTMAX_H_
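For orientation, a minimal driver sketch (not part of the commit) that mirrors the parameter setup of RunOneLogSoftmaxTest in logsoftmax_quantized_test.cc above; the input scale and row shape are arbitrary example values.

// Sketch: run reference_integer_ops::LogSoftmax on one row of four int8
// values, deriving the fixed-point parameters the same way the test does.
#include <cstdint>
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h"

void LogSoftmaxRow(const int8_t* input, int8_t* output) {
  const double kInputScale = 16.0 / 256.0;      // example input scale
  static const int kScaledDiffIntegerBits = 5;  // matches kInputIntegerBits
  int32_t input_multiplier, reverse_scaling_divisor;
  int input_left_shift, reverse_scaling_right_shift;
  tflite::PreprocessLogSoftmaxScalingExp(
      /*beta=*/1.0, kInputScale, kScaledDiffIntegerBits, &input_multiplier,
      &input_left_shift, &reverse_scaling_divisor,
      &reverse_scaling_right_shift);
  // diff_min bounds the magnitude of the (non-positive) input diffs.
  const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
                                                     input_left_shift);
  tflite::reference_integer_ops::LogSoftmax(
      input_multiplier, input_left_shift, reverse_scaling_divisor,
      -reverse_scaling_right_shift, diff_min, /*outer_size=*/1, /*depth=*/4,
      input, output);
}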

View File

@@ -36,68 +36,6 @@ limitations under the License.
namespace tflite {
// TODO(b/77858996): Add these to gemmlowp.
template <typename IntegerType>
IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {
static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
return a;
}
template <>
inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) {
std::int64_t a64 = a;
std::int64_t b64 = b;
std::int64_t sum = a64 + b64;
return static_cast<std::int32_t>(std::min(
static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
std::max(
static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
sum)));
}
template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingAddNonGemmlowp(
gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
SaturatingAddNonGemmlowp(a.raw(), b.raw()));
}
template <typename IntegerType>
IntegerType SaturatingSub(IntegerType a, IntegerType b) {
static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
return a;
}
template <>
inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) {
std::int32_t a32 = a;
std::int32_t b32 = b;
std::int32_t diff = a32 - b32;
return static_cast<std::int16_t>(std::min(32767, std::max(-32768, diff)));
}
template <>
inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) {
std::int64_t a64 = a;
std::int64_t b64 = b;
std::int64_t diff = a64 - b64;
return static_cast<std::int32_t>(std::min(
static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
std::max(
static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
diff)));
}
template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
SaturatingSub(a.raw(), b.raw()));
}
// End section to be moved to gemmlowp.
namespace reference_ops {
// Return true for broadcast case, false otherwise.
@@ -192,59 +130,6 @@ inline bool ProcessBroadcastShapes(const RuntimeShape& shape0,
return true;
}
template <typename T>
int CountLeadingZeros(T integer_input) {
static_assert(std::is_unsigned<T>::value,
"Only unsigned integer types handled.");
if (integer_input == 0) {
return std::numeric_limits<T>::digits;
}
const T one_in_leading_positive = static_cast<T>(1)
<< (std::numeric_limits<T>::digits - 1);
int leading_zeros = 0;
while (integer_input < one_in_leading_positive) {
integer_input <<= 1;
++leading_zeros;
}
return leading_zeros;
}
template <typename IntegerType>
IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
if (exponent == 0) {
return x;
}
using ScalarIntegerType =
typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
const IntegerType min =
gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
const IntegerType max =
gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
const std::int32_t threshold =
((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
const IntegerType positive_mask =
gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
const IntegerType negative_mask =
gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));
IntegerType result = gemmlowp::ShiftLeft(x, exponent);
result = gemmlowp::SelectUsingMask(positive_mask, max, result);
result = gemmlowp::SelectUsingMask(negative_mask, min, result);
return result;
}
// If we want to leave IntegerBits fixed, then multiplication
// by a power of two has to be saturating/rounding, not exact anymore.
template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits>
SaturatingRoundingMultiplyByPOTParam(
gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}
inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& filter_shape,
const float* filter_data, const RuntimeShape& bias_shape,
@@ -2752,121 +2637,6 @@ inline void LogSoftmax(const SoftmaxParams& params,
}
}
// Although currently the name of this function says that it cannot handle
// values less than 1, in practice it can handle as low as 1/x_max, where
// x_max is the largest representable input. In other words, the output range
// is symmetric.
template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1_impl(
gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
// The reason for accumulating the result with an extra bit of headroom is
// that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
// recip_denom will otherwise introduce an error.
static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumIntegerBits>;
const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 1488522236, std::log(2.0));
const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 1518500250, std::sqrt(0.5));
const FixedPoint0 one_quarter =
GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);
const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 1057819769,
2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));
const FixedPointAccum shifted_quarter =
gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);
// Reinterpret the input value as Q0.31, because we will figure out the
// required shift "ourselves" instead of using, say, Rescale.
FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
// z_a_pow_2 = input_integer_bits - z_a_headroom;
int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32>(z_a.raw()));
FixedPoint0 r_a_tmp =
SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
const int32 r_a_raw =
SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
// z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
// z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
// InputIntegerBits - z_b_headroom - 0.25);
const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)),
shifted_quarter);
// z_b is treated like z_a, but premultiplying by sqrt(0.5).
FixedPoint0 z_b = z_a * sqrt_half;
int z_b_headroom = CountLeadingZeros(static_cast<uint32>(z_b.raw())) - 1;
const int32 r_b_raw =
SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)),
shifted_quarter);
const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));
const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
FixedPoint0 q = r - sqrt_sqrt_half;
q = q + q;
const FixedPoint0 common_sq = q * q;
const FixedPoint0 num = q * r + q * common_sq * alpha_n;
const FixedPoint0 denom_minus_one_0 =
p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
const FixedPoint0 recip_denom =
one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);
const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
num_scaled * recip_denom);
}
// Minimum output bits to accommodate log of maximum input range. It actually
// does not matter if one considers, say, [-64,64] or [-64,64).
//
// For example, run this through Octave:
// [0:127; ...
// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
constexpr int min_log_x_output_bits(int input_bits) {
return input_bits > 90
? 7
: input_bits > 44
? 6
: input_bits > 21
? 5
: input_bits > 10
? 4
: input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1;
}
template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1(
gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
static_assert(
OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
"Output integer bits must be sufficent to accommodate logs of inputs.");
return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
InputIntegerBits>(
input_val);
}
inline void LogSoftmax(const SoftmaxParams& params,
const RuntimeShape& input_shape, const uint8* input_data,
const RuntimeShape& output_shape, uint8* output_data) {

View File

@@ -277,7 +277,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_LOG, Register_LOG());
-AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX());
+AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX(),
+/* min_version */ 1,
+/* max_version */ 2);
AddBuiltin(BuiltinOperator_CAST, Register_CAST());
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
/* min_version */ 1,

View File

@@ -1813,6 +1813,21 @@ class Logistic : public SimpleOperator<LogisticOperator> {
}
};
class LogSoftmax : public SimpleOperator<LogSoftmaxOperator> {
public:
explicit LogSoftmax()
: SimpleOperator("LOG_SOFTMAX", OperatorType::kLogSoftmax) {}
int GetVersion(const OperatorSignature& op_signature) const override {
const string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name);
// Version 2 supports signed int8 input types.
if (input_array.data_type == ArrayDataType::kInt8) {
return 2;
}
return 1;
}
};
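// Example of the version gate above: a model whose LOG_SOFTMAX input array
// has ArrayDataType::kInt8 is exported as op version 2 (the new int8
// kernel); float and uint8 inputs keep version 1.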
class SquaredDifference
: public BuiltinOperator<
SquaredDifferenceOperator, ::tflite::SquaredDifferenceOptions,
@@ -2451,8 +2466,7 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
MakeUnique<SimpleOperator<ExpOperator>>("EXP", OperatorType::kExp));
ops.push_back(
MakeUnique<SimpleOperator<CosOperator>>("COS", OperatorType::kCos));
-ops.push_back(MakeUnique<SimpleOperator<LogSoftmaxOperator>>(
-"LOG_SOFTMAX", OperatorType::kLogSoftmax));
+ops.push_back(MakeUnique<LogSoftmax>());
ops.push_back(MakeUnique<Maximum>()); // Element-wise Maximum
ops.push_back(MakeUnique<Minimum>()); // Element-wise Minimum
ops.push_back(MakeUnique<Greater>());

View File

@@ -820,6 +820,10 @@ TEST_F(OperatorTest, VersioningSpaceToBatchNDTest) {
SimpleVersioningTest<SpaceToBatchNDOperator>();
}
+TEST_F(OperatorTest, VersioningLogSoftmaxTest) {
+SimpleVersioningTest<LogSoftmaxOperator>();
+}
TEST_F(OperatorTest, VersioningPackTest) {
SimpleVersioningTest<PackOperator>();
}