From 8e0b3713a460eef87a67584ded585aa6670b068b Mon Sep 17 00:00:00 2001 From: Jian Li Date: Tue, 19 Feb 2019 12:54:51 -0800 Subject: [PATCH] Create int8 log softmax. PiperOrigin-RevId: 234656496 --- tensorflow/lite/kernels/activations.cc | 25 +- tensorflow/lite/kernels/activations_test.cc | 26 +- tensorflow/lite/kernels/internal/BUILD | 3 +- tensorflow/lite/kernels/internal/common.h | 215 ++++++++++++++++ .../kernels/internal/log_quantized_test.cc | 37 +-- .../internal/logsoftmax_quantized_test.cc | 121 +++++++-- .../internal/optimized/optimized_ops.h | 152 ------------ .../reference/integer_ops/log_softmax.h | 111 +++++++++ .../internal/reference/reference_ops.h | 230 ------------------ tensorflow/lite/kernels/register.cc | 4 +- tensorflow/lite/toco/tflite/operator.cc | 18 +- tensorflow/lite/toco/tflite/operator_test.cc | 4 + 12 files changed, 510 insertions(+), 436 deletions(-) create mode 100644 tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h diff --git a/tensorflow/lite/kernels/activations.cc b/tensorflow/lite/kernels/activations.cc index 0066d189948..7ef99cd0652 100644 --- a/tensorflow/lite/kernels/activations.cc +++ b/tensorflow/lite/kernels/activations.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h" @@ -270,8 +271,13 @@ TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, 0); TF_LITE_ENSURE_EQ(context, input->type, output->type); - if (input->type == kTfLiteUInt8) { - TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255); + if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { + if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255); + } + if (input->type == kTfLiteInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 127); + } TF_LITE_ENSURE_EQ(context, output->params.scale, 16.0 / 256); static const double kBeta = 1.0; @@ -854,6 +860,21 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { } return kTfLiteOk; } + case kTfLiteInt8: { + const auto input_shape = GetTensorShape(input); + const auto output_shape = GetTensorShape(output); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); + reference_integer_ops::LogSoftmax( + data->input_multiplier, data->input_left_shift, + data->reverse_scaling_divisor, data->reverse_scaling_right_shift, + data->diff_min, outer_size, depth, GetTensorData(input), + GetTensorData(output)); + return kTfLiteOk; + } default: context->ReportError(context, "Only float32 supported currently., got %s", TfLiteTypeGetName(input->type)); diff --git a/tensorflow/lite/kernels/activations_test.cc b/tensorflow/lite/kernels/activations_test.cc index f005dac438f..ccce4b0becf 100644 --- a/tensorflow/lite/kernels/activations_test.cc +++ b/tensorflow/lite/kernels/activations_test.cc @@ -772,7 +772,7 @@ TEST(FloatActivationsOpTest, LogSoftmax) { }))); } -TEST(QuantizedActivationsOpTest, LogSoftmax) { +TEST(QuantizedActivationsOpTest, LogSoftmaxUint8) { const float kLogSoftmaxQuantizedTolerance = 16 / 256.0; QuantizedActivationsOpModel m( BuiltinOperator_LOG_SOFTMAX, @@ -794,6 +794,30 @@ TEST(QuantizedActivationsOpTest, LogSoftmax) { ElementsAreArray({189, 93, 221, 253, 142, 63, 255, 111})); } +TEST(QuantizedActivationsOpTest, LogSoftmaxInt8) { + const float kLogSoftmaxQuantizedTolerance = 0.06355; + QuantizedActivationsOpModel m( + BuiltinOperator_LOG_SOFTMAX, + /*input=*/{TensorType_INT8, {2, 4}, -10, 10}, + /*output=*/{TensorType_INT8, {}, 0, 0, 16. / 256, 127}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + -4.14297, -10.14297, -2.14297, -.142971, // + -7.00104, -12.00104, -.00104087, -9.00104, // + }, + kLogSoftmaxQuantizedTolerance))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 61, -36, 93, 125, // + 15, -65, 127, -16, // + })); +} + // A base class of PRelu op model. It provides the constructor for // FloatPReluOpModel and QuantizedPReluOpModel. class BasePReluOpModel : public SingleOpModel { diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 8527ec648ad..816b88d675c 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -311,6 +311,7 @@ cc_library( "reference/integer_ops/depthwise_conv.h", "reference/integer_ops/dequantize.h", "reference/integer_ops/fully_connected.h", + "reference/integer_ops/log_softmax.h", "reference/integer_ops/logistic.h", "reference/integer_ops/mul.h", "reference/integer_ops/pooling.h", @@ -652,7 +653,7 @@ cc_test( srcs = [ "logsoftmax_quantized_test.cc", ], - shard_count = 3, + shard_count = 4, tags = [ # TODO(b/122242739): Reenable after fixing the flakiness? "nomac", diff --git a/tensorflow/lite/kernels/internal/common.h b/tensorflow/lite/kernels/internal/common.h index bc30ac91220..e00a3f405e0 100644 --- a/tensorflow/lite/kernels/internal/common.h +++ b/tensorflow/lite/kernels/internal/common.h @@ -131,6 +131,221 @@ int CountLeadingZeros(T integer_input) { #endif } +// TODO(b/77858996): Add these to gemmlowp. +template +IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) { + static_assert(std::is_same::value, "unimplemented"); + return a; +} + +template <> +inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) { + std::int64_t a64 = a; + std::int64_t b64 = b; + std::int64_t sum = a64 + b64; + return static_cast(std::min( + static_cast(std::numeric_limits::max()), + std::max( + static_cast(std::numeric_limits::min()), + sum))); +} + +template +gemmlowp::FixedPoint SaturatingAddNonGemmlowp( + gemmlowp::FixedPoint a, + gemmlowp::FixedPoint b) { + return gemmlowp::FixedPoint::FromRaw( + SaturatingAddNonGemmlowp(a.raw(), b.raw())); +} + +template +IntegerType SaturatingSub(IntegerType a, IntegerType b) { + static_assert(std::is_same::value, "unimplemented"); + return a; +} + +template <> +inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) { + std::int32_t a32 = a; + std::int32_t b32 = b; + std::int32_t diff = a32 - b32; + return static_cast(std::min(32767, std::max(-32768, diff))); +} + +template <> +inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) { + std::int64_t a64 = a; + std::int64_t b64 = b; + std::int64_t diff = a64 - b64; + return static_cast(std::min( + static_cast(std::numeric_limits::max()), + std::max( + static_cast(std::numeric_limits::min()), + diff))); +} + +template +gemmlowp::FixedPoint SaturatingSub( + gemmlowp::FixedPoint a, + gemmlowp::FixedPoint b) { + return gemmlowp::FixedPoint::FromRaw( + SaturatingSub(a.raw(), b.raw())); +} +// End section to be moved to gemmlowp. + +template +IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) { + if (exponent == 0) { + return x; + } + using ScalarIntegerType = + typename gemmlowp::FixedPointRawTypeTraits::ScalarRawType; + const IntegerType min = + gemmlowp::Dup(std::numeric_limits::min()); + const IntegerType max = + gemmlowp::Dup(std::numeric_limits::max()); + const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); + + const std::int32_t threshold = + ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1); + const IntegerType positive_mask = + gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup(threshold)); + const IntegerType negative_mask = + gemmlowp::MaskIfLessThan(x, gemmlowp::Dup(-threshold)); + + IntegerType result = gemmlowp::ShiftLeft(x, exponent); + result = gemmlowp::SelectUsingMask(positive_mask, max, result); + result = gemmlowp::SelectUsingMask(negative_mask, min, result); + return result; +} + +// If we want to leave IntegerBits fixed, then multiplication +// by a power of two has to be saturating/rounding, not exact anymore. +template +gemmlowp::FixedPoint +SaturatingRoundingMultiplyByPOTParam( + gemmlowp::FixedPoint a, int exponent) { + return gemmlowp::FixedPoint::FromRaw( + SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent)); +} + +// Minimum output bits to accommodate log of maximum input range. It actually +// does not matter if one considers, say, [-64,64] or [-64,64). +// +// For example, run this through Octave: +// [0:127; ... +// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ... +// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))] +constexpr int min_log_x_output_bits(int input_bits) { + return input_bits > 90 + ? 7 + : input_bits > 44 + ? 6 + : input_bits > 21 + ? 5 + : input_bits > 10 + ? 4 + : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1; +} + +// Although currently the name of this function says that it cannot handle +// values less than 1, in practice it can handle as low as 1/x_max, where +// x_max is the largest representable input. In other words, the output range +// is symmetric. +template +inline gemmlowp::FixedPoint +log_x_for_x_greater_than_or_equal_to_1_impl( + gemmlowp::FixedPoint input_val) { + // assert(__builtin_clz(0u) >= std::numeric_limits::digits - 1); + // assert(__builtin_clz(0u) <= std::numeric_limits::digits); + using FixedPoint0 = gemmlowp::FixedPoint; + // The reason for accumulating the result with an extra bit of headroom is + // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled * + // recip_denom will otherwise introduce an error. + static constexpr int kAccumIntegerBits = OutputIntegerBits + 1; + using FixedPointAccum = gemmlowp::FixedPoint; + + const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1488522236, std::log(2.0)); + const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5))); + const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1518500250, std::sqrt(0.5)); + const FixedPoint0 one_quarter = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0); + + const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1057819769, + 2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0))); + + const FixedPointAccum shifted_quarter = + gemmlowp::Rescale(one_quarter); + + // Reinterpret the input value as Q0.31, because we will figure out the + // required shift "ourselves" instead of using, say, Rescale. + FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw()); + // z_a_pow_2 = input_integer_bits - z_a_headroom; + int z_a_headroom_plus_1 = CountLeadingZeros(static_cast(z_a.raw())); + FixedPoint0 r_a_tmp = + SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1)); + const int32 r_a_raw = + SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1); + // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25); + // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25, + // InputIntegerBits - z_b_headroom - 0.25); + const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp( + FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( + InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)), + shifted_quarter); + + // z_b is treated like z_a, but premultiplying by sqrt(0.5). + FixedPoint0 z_b = z_a * sqrt_half; + int z_b_headroom = CountLeadingZeros(static_cast(z_b.raw())) - 1; + const int32 r_b_raw = + SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom); + const FixedPointAccum z_b_pow_2_adj = SaturatingSub( + FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( + InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)), + shifted_quarter); + + const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw)); + const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw( + std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw())); + + const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half); + FixedPoint0 q = r - sqrt_sqrt_half; + q = q + q; + + const FixedPoint0 common_sq = q * q; + const FixedPoint0 num = q * r + q * common_sq * alpha_n; + const FixedPoint0 denom_minus_one_0 = + p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q; + const FixedPoint0 recip_denom = + one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0); + + const FixedPointAccum num_scaled = gemmlowp::Rescale(num); + return gemmlowp::Rescale(z_pow_2_adj * log_2 + + num_scaled * recip_denom); +} + +template +inline gemmlowp::FixedPoint +log_x_for_x_greater_than_or_equal_to_1( + gemmlowp::FixedPoint input_val) { + static_assert( + OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits), + "Output integer bits must be sufficent to accommodate logs of inputs."); + return log_x_for_x_greater_than_or_equal_to_1_impl( + input_val); +} + inline int32 GetReciprocal(int32 x, int x_integer_digits, int* num_bits_over_unit) { int headroom_plus_one = CountLeadingZeros(static_cast(x)); diff --git a/tensorflow/lite/kernels/internal/log_quantized_test.cc b/tensorflow/lite/kernels/internal/log_quantized_test.cc index 8c39350ab1d..c31c8e30775 100644 --- a/tensorflow/lite/kernels/internal/log_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/log_quantized_test.cc @@ -121,8 +121,7 @@ void RunSingleTest(const std::vector& test_input, const string& check_label, int tolerance) { const int n = test_input.size(); std::vector float_gen_output(n, 0); - std::vector reference_output(n, 0); - std::vector optimized_output(n, 0); + std::vector quantized_output(n, 0); // Workaround the stupid things that intelligent humans do. // Consequence of __builtin_clz(0u) may equal 31 instead of 32. @@ -132,45 +131,21 @@ void RunSingleTest(const std::vector& test_input, } for (int i = 0; i < n; ++i) { - reference_output[i] = - tflite::reference_ops::log_x_for_x_greater_than_or_equal_to_1_impl< - OutputIntegerBits, InputIntegerBits>( - gemmlowp::FixedPoint::FromRaw( - fudged_input[i])) - .raw(); - optimized_output[i] = - tflite::optimized_ops::log_x_for_x_greater_than_or_equal_to_1_impl< - OutputIntegerBits, InputIntegerBits>( + quantized_output[i] = + tflite::log_x_for_x_greater_than_or_equal_to_1_impl( gemmlowp::FixedPoint::FromRaw( fudged_input[i])) .raw(); float_gen_output[i] = LogPositiveValuesViaFloat( fudged_input[i], InputIntegerBits, OutputIntegerBits); } - // Note that first check is intolerant. - { - std::ostringstream label; - label << check_label << " / optimized vs reference / InputIntegerBits=" - << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits; - CheckOutputData( - optimized_output, reference_output, test_input, label.str(), - InputIntegerBits, OutputIntegerBits, 0); - } { std::ostringstream label; label << check_label << " / reference vs float-gen / InputIntegerBits=" << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits; - CheckOutputData( - reference_output, float_gen_output, test_input, label.str(), - InputIntegerBits, OutputIntegerBits, tolerance); - } - { - std::ostringstream label; - label << check_label << " optimized vs float-gen / InputIntegerBits=" - << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits; - CheckOutputData( - optimized_output, float_gen_output, test_input, label.str(), - InputIntegerBits, OutputIntegerBits, tolerance); + CheckOutputData(quantized_output, float_gen_output, test_input, label.str(), + InputIntegerBits, OutputIntegerBits, tolerance); } } diff --git a/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc index 945300dad16..d0d2654d412 100644 --- a/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,6 +25,8 @@ limitations under the License. #include #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/test_util.h" #include "tensorflow/lite/string.h" @@ -61,7 +63,42 @@ void RunLogSoftmaxFloatReference(const uint8* input_data, } } -void CheckOutputData(const uint8* test_output, const uint8* reference_output, +// Same as above except for the following change: +// - input and output data type +// - Dequnatize function +// - clamping values +void RunLogSoftmaxFloatReference(const int8* input_data, + const RuntimeShape& shape_common, + int32 input_offset, const double input_scale, + int stride, float beta, + int8* reference_output_data) { + const int ref_buffer_size = shape_common.FlatSize(); + std::vector reference_dequant_data(ref_buffer_size); + std::vector reference_output_float_data(ref_buffer_size); + + // Reference data generated via Dequant of input into float, and then applying + // float LogSoftmax. + DequantizationParams dq_params; + dq_params.zero_point = input_offset; + dq_params.scale = input_scale; + reference_integer_ops::Dequantize(dq_params, shape_common, input_data, + shape_common, + reference_dequant_data.data()); + SoftmaxParams sm_params; + optimized_ops::LogSoftmax(sm_params, shape_common, + reference_dequant_data.data(), shape_common, + reference_output_float_data.data()); + // Work with quantized scaling for LogSoftmax, under which 255 represents 0, + // and -16 gets nudged up to 0. + for (int i = 0; i < ref_buffer_size; i++) { + reference_output_data[i] = std::max( + -128, static_cast( + 127 + std::round(16.0f * reference_output_float_data[i]))); + } +} + +template +void CheckOutputData(const T* test_output, const T* reference_output, const RuntimeShape& shape_common, const string& check_label, bool be_exacting) { const int buffer_size = shape_common.FlatSize(); @@ -144,15 +181,58 @@ void RunOneLogSoftmaxTest(const uint8* input_data, reference_ops::LogSoftmax(params, shape_common, input_data, shape_common, reference_quant_logsoftmax_output.data()); - CheckOutputData(optimized_logsoftmax_output.data(), - reference_float_logsoftmax_output.data(), shape_common, - "Optimized vs float reference", false); - CheckOutputData(optimized_logsoftmax_output.data(), - reference_quant_logsoftmax_output.data(), shape_common, - "Optimized vs quant reference", true); - CheckOutputData(reference_quant_logsoftmax_output.data(), - reference_float_logsoftmax_output.data(), shape_common, - "Quant reference vs float reference", false); + CheckOutputData(optimized_logsoftmax_output.data(), + reference_float_logsoftmax_output.data(), + shape_common, "Optimized vs float reference", false); + CheckOutputData(optimized_logsoftmax_output.data(), + reference_quant_logsoftmax_output.data(), + shape_common, "Optimized vs quant reference", true); + CheckOutputData(reference_quant_logsoftmax_output.data(), + reference_float_logsoftmax_output.data(), + shape_common, "Quant reference vs float reference", + false); +} + +// Runs the LogSoftmax and compares against the float reference implementation +// and the int8 quantized reference implementation. +void RunOneLogSoftmaxTest(const int8* input_data, + const RuntimeShape& shape_common, int32 input_offset, + const double input_scale, int stride, float beta) { + const int buffer_size = shape_common.FlatSize(); + std::vector quantized_logsoftmax_reference_implementation(buffer_size); + std::vector float_logsoftmax_optimized_implementation(buffer_size); + + RunLogSoftmaxFloatReference(input_data, shape_common, input_offset, + input_scale, stride, beta, + float_logsoftmax_optimized_implementation.data()); + + int32 input_beta_multiplier; + int input_beta_left_shift; + int32 reverse_scaling_divisor; + int reverse_scaling_right_shift; + static const int kScaledDiffIntegerBits = 5; + tflite::PreprocessLogSoftmaxScalingExp( + beta, input_scale, kScaledDiffIntegerBits, &input_beta_multiplier, + &input_beta_left_shift, &reverse_scaling_divisor, + &reverse_scaling_right_shift); + reverse_scaling_right_shift *= -1; + // diff_min has a negative value, and is used to limit the maximum magnitude + // of the diffs, which are <= 0. + const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits, + input_beta_left_shift); + + const int outer_size = + shape_common.Dims(0) * shape_common.Dims(1) * shape_common.Dims(2); + const int inner_size = shape_common.Dims(3); + reference_integer_ops::LogSoftmax( + input_beta_multiplier, input_beta_left_shift, reverse_scaling_divisor, + reverse_scaling_right_shift, diff_min, outer_size, inner_size, input_data, + quantized_logsoftmax_reference_implementation.data()); + + CheckOutputData(quantized_logsoftmax_reference_implementation.data(), + float_logsoftmax_optimized_implementation.data(), + shape_common, "Quant reference vs float reference", + false); } // This function picks some random LogSoftmax params, which are checked for @@ -161,6 +241,7 @@ void RunOneLogSoftmaxTest(const uint8* input_data, // to loop until a test has been run. // // Currently we do not reject for any reason. +template bool TryOneUniformLogSoftmax() { // We pick mostly positive values, on the whole emphasizing smaller values and // therefore faster tests. We test a wider range of depths. In the case of @@ -178,7 +259,7 @@ bool TryOneUniformLogSoftmax() { RuntimeShape({batch, input_height, input_width, input_depth}); const int buffer_size = shape_common.FlatSize(); - std::vector input_data(buffer_size); + std::vector input_data(buffer_size); FillRandom(&input_data); RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, stride, beta); @@ -224,15 +305,23 @@ bool TryOneSkyscraperLogSoftmax(bool small_depth) { return true; } -TEST(TestQuantizedLogSoftmax, UniformLogSoftmaxTests) { +TEST(TestQuantizedLogSoftmax, UniformLogSoftmaxUint8Tests) { const int kTestsToRun = 100; for (int i = 0; i < kTestsToRun; i++) { - while (!TryOneUniformLogSoftmax()) { + while (!TryOneUniformLogSoftmax()) { } } } -TEST(TestQuantizedLogSoftmax, SkyscraperLogSoftmaxTests) { +TEST(TestQuantizedLogSoftmax, UniformLogSoftmaxUint8Int8Tests) { + const int kTestsToRun = 100; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneUniformLogSoftmax()) { + } + } +} + +TEST(TestQuantizedLogSoftmax, SkyscraperLogSoftmaxUint8Tests) { const int kTestsToRun = 100; for (int i = 0; i < kTestsToRun; i++) { while (!TryOneSkyscraperLogSoftmax(false)) { @@ -240,7 +329,7 @@ TEST(TestQuantizedLogSoftmax, SkyscraperLogSoftmaxTests) { } } -TEST(TestQuantizedLogSoftmax, SmallSkyscraperLogSoftmaxTests) { +TEST(TestQuantizedLogSoftmax, SmallSkyscraperLogSoftmaxUint8Tests) { const int kTestsToRun = 100; for (int i = 0; i < kTestsToRun; i++) { while (!TryOneSkyscraperLogSoftmax(true)) { diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 72d2698da0c..3927540cc32 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -182,45 +182,6 @@ MatrixMap MapAsMatrixWithGivenNumberOfRows(Scalar* data, return MatrixMap(data, rows, cols); } -// This is like the template-parameter version, except that the power-of-two is -// passed as a function parameter. The template version is to be preferred, -// since some target hardware optimizations depend on the range of the exponent. -template -IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) { - if (exponent == 0) { - return x; - } - using ScalarIntegerType = - typename gemmlowp::FixedPointRawTypeTraits::ScalarRawType; - const IntegerType min = - gemmlowp::Dup(std::numeric_limits::min()); - const IntegerType max = - gemmlowp::Dup(std::numeric_limits::max()); - const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); - - const std::int32_t threshold = - ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1); - const IntegerType positive_mask = - gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup(threshold)); - const IntegerType negative_mask = - gemmlowp::MaskIfLessThan(x, gemmlowp::Dup(-threshold)); - - IntegerType result = gemmlowp::ShiftLeft(x, exponent); - result = gemmlowp::SelectUsingMask(positive_mask, max, result); - result = gemmlowp::SelectUsingMask(negative_mask, min, result); - return result; -} - -// This is like the template-parameter version, except that the power-of-two is -// passed as a function parameter. See raw-integer version for further comments. -template -gemmlowp::FixedPoint -SaturatingRoundingMultiplyByPOTParam( - gemmlowp::FixedPoint a, int exponent) { - return gemmlowp::FixedPoint::FromRaw( - SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent)); -} - inline void AddBiasAndEvalActivationFunction(float output_activation_min, float output_activation_max, const RuntimeShape& bias_shape, @@ -4557,119 +4518,6 @@ inline void LogSoftmax(const SoftmaxParams& params, } } -template -inline gemmlowp::FixedPoint -log_x_for_x_greater_than_or_equal_to_1_impl( - gemmlowp::FixedPoint input_val) { - // assert(__builtin_clz(0u) >= std::numeric_limits::digits - 1); - // assert(__builtin_clz(0u) <= std::numeric_limits::digits); - using FixedPoint0 = gemmlowp::FixedPoint; - // The reason for accumulating the result with an extra bit of headroom is - // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled * - // recip_denom will otherwise introduce an error. - static constexpr int kAccumIntegerBits = OutputIntegerBits + 1; - using FixedPointAccum = gemmlowp::FixedPoint; - - const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 1488522236, std::log(2.0)); - const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5))); - const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 1518500250, std::sqrt(0.5)); - const FixedPoint0 one_quarter = - GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0); - - const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0))); - const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0))); - const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 1057819769, - 2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0))); - const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0))); - - const FixedPointAccum shifted_quarter = - gemmlowp::Rescale(one_quarter); - - // Reinterpret the input value as Q0.31, because we will figure out the - // required shift "ourselves" instead of using, say, Rescale. - FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw()); - // z_a_pow_2 = input_integer_bits - z_a_headroom; - int z_a_headroom_plus_1 = CountLeadingZeros(static_cast(z_a.raw())); - FixedPoint0 r_a_tmp = - SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1)); - const int32 r_a_raw = - SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1); - // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25); - // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25, - // InputIntegerBits - z_b_headroom - 0.25); - const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp( - FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( - InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)), - shifted_quarter); - - // z_b is treated like z_a, but premultiplying by sqrt(0.5). - FixedPoint0 z_b = z_a * sqrt_half; - int z_b_headroom = CountLeadingZeros(static_cast(z_b.raw())) - 1; - const int32 r_b_raw = - SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom); - const FixedPointAccum z_b_pow_2_adj = SaturatingSub( - FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( - InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)), - shifted_quarter); - - const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw)); - const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw( - std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw())); - - const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half); - FixedPoint0 q = r - sqrt_sqrt_half; - q = q + q; - - const FixedPoint0 common_sq = q * q; - const FixedPoint0 num = q * r + q * common_sq * alpha_n; - const FixedPoint0 denom_minus_one_0 = - p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q; - const FixedPoint0 recip_denom = - one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0); - - const FixedPointAccum num_scaled = gemmlowp::Rescale(num); - return gemmlowp::Rescale(z_pow_2_adj * log_2 + - num_scaled * recip_denom); -} - -// Minimum output bits to accommodate log of maximum input range. It actually -// does not matter if one considers, say, [-64,64] or [-64,64). -// -// For example, run this through Octave: -// [0:127; ... -// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ... -// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))] -constexpr int min_log_x_output_bits(int input_bits) { - return input_bits > 90 - ? 7 - : input_bits > 44 - ? 6 - : input_bits > 21 - ? 5 - : input_bits > 10 - ? 4 - : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1; -} - -template -inline gemmlowp::FixedPoint -log_x_for_x_greater_than_or_equal_to_1( - gemmlowp::FixedPoint input_val) { - static_assert( - OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits), - "Output integer bits must be sufficent to accommodate logs of inputs."); - return log_x_for_x_greater_than_or_equal_to_1_impl( - input_val); -} - // Currently just a copy of the reference code. inline void LogSoftmax(const SoftmaxParams& params, const RuntimeShape& input_shape, const uint8* input_data, diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h b/tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h new file mode 100644 index 00000000000..f22bb4f1380 --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h @@ -0,0 +1,111 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOG_SOFTMAX_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOG_SOFTMAX_H_ + +#include "tensorflow/lite/kernels/internal/common.h" + +namespace tflite { +namespace reference_integer_ops { + +inline void LogSoftmax(int32_t input_multiplier, int32_t input_shift, + int32_t reverse_multiplier, int32_t reverse_shift, + int32_t diff_min, int32_t outer_size, int32_t depth, + const int8* input_data, int8* output_data) { + static constexpr int8_t kMinInt8 = std::numeric_limits::min(); + static constexpr int8_t kMaxInt8 = std::numeric_limits::max(); + static constexpr int32_t kMinInt32 = std::numeric_limits::min(); + + // [-16, 0] is mapped to [-128, 127] with 1/16 as scale and 127 as zero + // point. This nudges the output to [-255/16, 0]. + static constexpr int32_t kOutputZeroPoint = 127; + + // All IntegerBits must agree with Prepare function. + // Input is chosen as Q5.26 so exp(-1 * 2^5 * 2^-1) = exp(-16) is negligible. + static constexpr int kInputIntegerBits = 5; + static constexpr int kAccumulationIntegerBits = 12; + static constexpr int kOutputIntegerBits = 4; + using F5 = gemmlowp::FixedPoint; + using F12 = gemmlowp::FixedPoint; + + for (int outer_index = 0; outer_index < outer_size; ++outer_index) { + int8 max_in_row = kMinInt8; + for (int inner_index = 0; inner_index < depth; ++inner_index) { + max_in_row = + std::max(max_in_row, input_data[outer_index * depth + inner_index]); + } + + // Accumulator "sum_of_exps_in_q12" is safe from overflowing in 2^12 steps. + F12 sum_of_exps_in_q12 = F12::FromRaw(0); + for (int inner_index = 0; inner_index < depth; ++inner_index) { + int32_t input_diff = + static_cast(input_data[outer_index * depth + inner_index]) - + max_in_row; + if (input_diff >= diff_min) { + const int32_t input_diff_in_q5 = MultiplyByQuantizedMultiplier( + input_diff, input_multiplier, input_shift); + sum_of_exps_in_q12 = + sum_of_exps_in_q12 + + gemmlowp::Rescale( + exp_on_negative_values(F5::FromRaw(input_diff_in_q5))); + } + } + + const int32_t log_sum_of_exps_in_q5 = + log_x_for_x_greater_than_or_equal_to_1( + sum_of_exps_in_q12) + .raw(); + + // Potentially reduced the valid range. shifted_log_sum_of_exps_in_q5 is + // smallest representable in Q5.26 plus the log_sum_of_exps. + const int32_t shifted_log_sum_of_exps_in_q5 = + log_sum_of_exps_in_q5 + kMinInt32; + const int32_t adjusted_diff_min = std::max( + diff_min - 1, + MultiplyByQuantizedMultiplier(shifted_log_sum_of_exps_in_q5, + reverse_multiplier, -reverse_shift)); + + for (int inner_index = 0; inner_index < depth; ++inner_index) { + int32_t input_diff = + static_cast(input_data[outer_index * depth + inner_index]) - + max_in_row; + // Note use of > below instead of >= above. + if (input_diff > adjusted_diff_min) { + const int32_t input_diff_in_q5 = MultiplyByQuantizedMultiplier( + input_diff, input_multiplier, input_shift); + + // Rescale and downcast. + int32_t output_in_q27 = + gemmlowp::RoundingDivideByPOT( + (input_diff_in_q5 - log_sum_of_exps_in_q5), + 31 - kInputIntegerBits - kOutputIntegerBits) + + kOutputZeroPoint; + + output_in_q27 = + std::max(std::min(output_in_q27, static_cast(kMaxInt8)), + static_cast(kMinInt8)); + output_data[outer_index * depth + inner_index] = + static_cast(output_in_q27); + } else { + output_data[outer_index * depth + inner_index] = kMinInt8; + } + } + } +} + +} // namespace reference_integer_ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOG_SOFTMAX_H_ diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index 7652bcf798d..66f18d29993 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -36,68 +36,6 @@ limitations under the License. namespace tflite { -// TODO(b/77858996): Add these to gemmlowp. -template -IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) { - static_assert(std::is_same::value, "unimplemented"); - return a; -} - -template <> -inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) { - std::int64_t a64 = a; - std::int64_t b64 = b; - std::int64_t sum = a64 + b64; - return static_cast(std::min( - static_cast(std::numeric_limits::max()), - std::max( - static_cast(std::numeric_limits::min()), - sum))); -} - -template -gemmlowp::FixedPoint SaturatingAddNonGemmlowp( - gemmlowp::FixedPoint a, - gemmlowp::FixedPoint b) { - return gemmlowp::FixedPoint::FromRaw( - SaturatingAddNonGemmlowp(a.raw(), b.raw())); -} - -template -IntegerType SaturatingSub(IntegerType a, IntegerType b) { - static_assert(std::is_same::value, "unimplemented"); - return a; -} - -template <> -inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) { - std::int32_t a32 = a; - std::int32_t b32 = b; - std::int32_t diff = a32 - b32; - return static_cast(std::min(32767, std::max(-32768, diff))); -} - -template <> -inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) { - std::int64_t a64 = a; - std::int64_t b64 = b; - std::int64_t diff = a64 - b64; - return static_cast(std::min( - static_cast(std::numeric_limits::max()), - std::max( - static_cast(std::numeric_limits::min()), - diff))); -} - -template -gemmlowp::FixedPoint SaturatingSub( - gemmlowp::FixedPoint a, - gemmlowp::FixedPoint b) { - return gemmlowp::FixedPoint::FromRaw( - SaturatingSub(a.raw(), b.raw())); -} -// End section to be moved to gemmlowp. - namespace reference_ops { // Return true for broadcast case, false otherwise. @@ -192,59 +130,6 @@ inline bool ProcessBroadcastShapes(const RuntimeShape& shape0, return true; } -template -int CountLeadingZeros(T integer_input) { - static_assert(std::is_unsigned::value, - "Only unsigned integer types handled."); - if (integer_input == 0) { - return std::numeric_limits::digits; - } - const T one_in_leading_positive = static_cast(1) - << (std::numeric_limits::digits - 1); - int leading_zeros = 0; - while (integer_input < one_in_leading_positive) { - integer_input <<= 1; - ++leading_zeros; - } - return leading_zeros; -} - -template -IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) { - if (exponent == 0) { - return x; - } - using ScalarIntegerType = - typename gemmlowp::FixedPointRawTypeTraits::ScalarRawType; - const IntegerType min = - gemmlowp::Dup(std::numeric_limits::min()); - const IntegerType max = - gemmlowp::Dup(std::numeric_limits::max()); - const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); - - const std::int32_t threshold = - ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1); - const IntegerType positive_mask = - gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup(threshold)); - const IntegerType negative_mask = - gemmlowp::MaskIfLessThan(x, gemmlowp::Dup(-threshold)); - - IntegerType result = gemmlowp::ShiftLeft(x, exponent); - result = gemmlowp::SelectUsingMask(positive_mask, max, result); - result = gemmlowp::SelectUsingMask(negative_mask, min, result); - return result; -} - -// If we want to leave IntegerBits fixed, then multiplication -// by a power of two has to be saturating/rounding, not exact anymore. -template -gemmlowp::FixedPoint -SaturatingRoundingMultiplyByPOTParam( - gemmlowp::FixedPoint a, int exponent) { - return gemmlowp::FixedPoint::FromRaw( - SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent)); -} - inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, const float* input_data, const RuntimeShape& filter_shape, const float* filter_data, const RuntimeShape& bias_shape, @@ -2752,121 +2637,6 @@ inline void LogSoftmax(const SoftmaxParams& params, } } -// Although currently the name of this function says that it cannot handle -// values less than 1, in practice it can handle as low as 1/x_max, where -// x_max is the largest representable input. In other words, the output range -// is symmetric. -template -inline gemmlowp::FixedPoint -log_x_for_x_greater_than_or_equal_to_1_impl( - gemmlowp::FixedPoint input_val) { - using FixedPoint0 = gemmlowp::FixedPoint; - // The reason for accumulating the result with an extra bit of headroom is - // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled * - // recip_denom will otherwise introduce an error. - static constexpr int kAccumIntegerBits = OutputIntegerBits + 1; - using FixedPointAccum = gemmlowp::FixedPoint; - - const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 1488522236, std::log(2.0)); - const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5))); - const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 1518500250, std::sqrt(0.5)); - const FixedPoint0 one_quarter = - GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0); - - const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0))); - const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0))); - const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 1057819769, - 2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0))); - const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0))); - - const FixedPointAccum shifted_quarter = - gemmlowp::Rescale(one_quarter); - - // Reinterpret the input value as Q0.31, because we will figure out the - // required shift "ourselves" instead of using, say, Rescale. - FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw()); - // z_a_pow_2 = input_integer_bits - z_a_headroom; - int z_a_headroom_plus_1 = CountLeadingZeros(static_cast(z_a.raw())); - FixedPoint0 r_a_tmp = - SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1)); - const int32 r_a_raw = - SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1); - // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25); - // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25, - // InputIntegerBits - z_b_headroom - 0.25); - const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp( - FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( - InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)), - shifted_quarter); - - // z_b is treated like z_a, but premultiplying by sqrt(0.5). - FixedPoint0 z_b = z_a * sqrt_half; - int z_b_headroom = CountLeadingZeros(static_cast(z_b.raw())) - 1; - const int32 r_b_raw = - SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom); - const FixedPointAccum z_b_pow_2_adj = SaturatingSub( - FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( - InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)), - shifted_quarter); - - const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw)); - const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw( - std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw())); - - const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half); - FixedPoint0 q = r - sqrt_sqrt_half; - q = q + q; - - const FixedPoint0 common_sq = q * q; - const FixedPoint0 num = q * r + q * common_sq * alpha_n; - const FixedPoint0 denom_minus_one_0 = - p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q; - const FixedPoint0 recip_denom = - one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0); - - const FixedPointAccum num_scaled = gemmlowp::Rescale(num); - return gemmlowp::Rescale(z_pow_2_adj * log_2 + - num_scaled * recip_denom); -} - -// Minimum output bits to accommodate log of maximum input range. It actually -// does not matter if one considers, say, [-64,64] or [-64,64). -// -// For example, run this through Octave: -// [0:127; ... -// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ... -// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))] -constexpr int min_log_x_output_bits(int input_bits) { - return input_bits > 90 - ? 7 - : input_bits > 44 - ? 6 - : input_bits > 21 - ? 5 - : input_bits > 10 - ? 4 - : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1; -} - -template -inline gemmlowp::FixedPoint -log_x_for_x_greater_than_or_equal_to_1( - gemmlowp::FixedPoint input_val) { - static_assert( - OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits), - "Output integer bits must be sufficent to accommodate logs of inputs."); - return log_x_for_x_greater_than_or_equal_to_1_impl( - input_val); -} - inline void LogSoftmax(const SoftmaxParams& params, const RuntimeShape& input_shape, const uint8* input_data, const RuntimeShape& output_shape, uint8* output_data) { diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index a050f7e3ad7..aa0358085ce 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -277,7 +277,9 @@ BuiltinOpResolver::BuiltinOpResolver() { /* min_version */ 1, /* max_version */ 2); AddBuiltin(BuiltinOperator_LOG, Register_LOG()); - AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX()); + AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_CAST, Register_CAST()); AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(), /* min_version */ 1, diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index 8ab5f1b37e8..93f7b99ca02 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -1813,6 +1813,21 @@ class Logistic : public SimpleOperator { } }; +class LogSoftmax : public SimpleOperator { + public: + explicit LogSoftmax() + : SimpleOperator("LOG_SOFTMAX", OperatorType::kLogSoftmax) {} + int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // Version 2 supports signed int8 input types. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; + } +}; + class SquaredDifference : public BuiltinOperator< SquaredDifferenceOperator, ::tflite::SquaredDifferenceOptions, @@ -2451,8 +2466,7 @@ std::vector> BuildOperatorList( MakeUnique>("EXP", OperatorType::kExp)); ops.push_back( MakeUnique>("COS", OperatorType::kCos)); - ops.push_back(MakeUnique>( - "LOG_SOFTMAX", OperatorType::kLogSoftmax)); + ops.push_back(MakeUnique()); ops.push_back(MakeUnique()); // Element-wise Maximum ops.push_back(MakeUnique()); // Element-wise Minimum ops.push_back(MakeUnique()); diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc index 864f83e501f..7a541a85caf 100644 --- a/tensorflow/lite/toco/tflite/operator_test.cc +++ b/tensorflow/lite/toco/tflite/operator_test.cc @@ -820,6 +820,10 @@ TEST_F(OperatorTest, VersioningSpaceToBatchNDTest) { SimpleVersioningTest(); } +TEST_F(OperatorTest, VersioningLogSoftmaxTest) { + SimpleVersioningTest(); +} + TEST_F(OperatorTest, VersioningPackTest) { SimpleVersioningTest(); }