Move the templatized integer implementation of Softmax to reference/softmax.h and make it work with uint8_t, replacing the uint8-only version there.
PiperOrigin-RevId: 302070491
Change-Id: I3a1148604fbeae271891d2cc202e765aed5b5b9f
parent d0e21cd468
commit a794c690b4
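For orientation, a minimal sketch of how a caller reaches the moved implementation after this change; the wrapper function and its argument names are illustrative, not part of the commit:

// Sketch only: InputT/OutputT are deduced as uint8_t from the pointer
// arguments, which is what replaces the old uint8-only overload in
// reference/softmax.h.
#include <cstdint>
#include "tensorflow/lite/kernels/internal/reference/softmax.h"

void RunUint8Softmax(const tflite::SoftmaxParams& params,
                     const tflite::RuntimeShape& shape,
                     const uint8_t* input_data, uint8_t* output_data) {
  tflite::reference_ops::Softmax(params, shape, input_data, shape,
                                 output_data);
}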
tensorflow/lite/kernels/internal/BUILD
@@ -438,7 +438,6 @@ cc_library(
         "reference/integer_ops/mean.h",
         "reference/integer_ops/mul.h",
         "reference/integer_ops/pooling.h",
-        "reference/integer_ops/softmax.h",
         "reference/integer_ops/tanh.h",
         "reference/integer_ops/transpose_conv.h",
         "reference/logistic.h",
tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h (deleted)
@@ -1,107 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_SOFTMAX_H_
-#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_SOFTMAX_H_
-
-#include "tensorflow/lite/kernels/internal/common.h"
-
-namespace tflite {
-namespace reference_integer_ops {
-
-// Quantized softmax with int8 input and int8/int16 output.
-template <typename OutputT = int8_t>
-inline void Softmax(const SoftmaxParams& params,
-                    const RuntimeShape& input_shape, const int8* input_data,
-                    const RuntimeShape& output_shape, OutputT* output_data) {
-  const int32_t input_beta_multiplier = params.input_multiplier;
-  const int32_t input_beta_left_shift = params.input_left_shift;
-  const int diff_min = params.diff_min;
-  // The representation chosen for the input to the exp() function is Q5.26.
-  // We need to leave extra space since values that we skip might be as large as
-  // -32 before multiplying by input_beta_multiplier, and therefore as large as
-  // -16 afterwards. Note that exp(-8) is definitely not insignificant to
-  // accumulation, but exp(-16) definitely is.
-  static const int kScaledDiffIntegerBits = 5;
-  static const int kAccumulationIntegerBits = 12;
-  using FixedPointScaledDiff =
-      gemmlowp::FixedPoint<int32_t, kScaledDiffIntegerBits>;
-  using FixedPointAccum =
-      gemmlowp::FixedPoint<int32_t, kAccumulationIntegerBits>;
-  using FixedPoint0 = gemmlowp::FixedPoint<int32_t, 0>;
-
-  const int trailing_dim = input_shape.DimensionsCount() - 1;
-  const int outer_size =
-      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
-  const int depth =
-      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
-
-  for (int i = 0; i < outer_size; ++i) {
-    int8 max_in_row = -128;
-    for (int c = 0; c < depth; ++c) {
-      max_in_row = std::max(max_in_row, input_data[i * depth + c]);
-    }
-
-    FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
-    for (int c = 0; c < depth; ++c) {
-      int32_t input_diff =
-          static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
-      if (input_diff >= diff_min) {
-        const int32_t input_diff_rescaled =
-            MultiplyByQuantizedMultiplierGreaterThanOne(
-                input_diff, input_beta_multiplier, input_beta_left_shift);
-        const FixedPointScaledDiff scaled_diff_f8 =
-            FixedPointScaledDiff::FromRaw(input_diff_rescaled);
-        sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
-                                        exp_on_negative_values(scaled_diff_f8));
-      }
-    }
-
-    int num_bits_over_unit;
-    FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
-        sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));
-
-    for (int c = 0; c < depth; ++c) {
-      int32_t input_diff =
-          static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
-      if (input_diff >= diff_min) {
-        const int32_t input_diff_rescaled =
-            MultiplyByQuantizedMultiplierGreaterThanOne(
-                input_diff, input_beta_multiplier, input_beta_left_shift);
-        const FixedPointScaledDiff scaled_diff_f8 =
-            FixedPointScaledDiff::FromRaw(input_diff_rescaled);
-
-        FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
-        const int32_t unsat_output = gemmlowp::RoundingDivideByPOT(
-            (shifted_scale * exp_in_0).raw(),
-            num_bits_over_unit + 31 - (sizeof(OutputT) * 8));
-        // TODO(b/148494470): Handle int32 shifts properly:
-        const int32_t shifted_output =
-            unsat_output -
-            (static_cast<int32_t>(std::numeric_limits<OutputT>::max()) + 1);
-        output_data[i * depth + c] = static_cast<OutputT>(std::max(
-            std::min(shifted_output,
-                     static_cast<int32_t>(std::numeric_limits<OutputT>::max())),
-            static_cast<int32_t>(std::numeric_limits<OutputT>::min())));
-      } else {
-        output_data[i * depth + c] = std::numeric_limits<OutputT>::min();
-      }
-    }
-  }
-}
-
-}  // namespace reference_integer_ops
-}  // namespace tflite
-
-#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_SOFTMAX_H_
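The only output-type-dependent arithmetic in the deleted header is the final rescale, num_bits_over_unit + 31 - (sizeof(OutputT) * 8), which maps a probability held as a raw Q0.31 value onto the 2^8 (or 2^16) output grid. A standalone numeric walk-through of that shift, under the simplifying assumption num_bits_over_unit == 0 (illustration only, not part of the commit):

#include <cassert>
#include <cstdint>

int main() {
  // Pretend softmax assigned a class probability of exactly 0.5,
  // represented in Q0.31 fixed point (as FixedPoint0 stores it).
  const int32_t prob_q0_31 = INT32_C(1) << 30;
  const int num_bits_over_unit = 0;  // assumed for this example
  // The shift from the deleted code, with OutputT = int8_t or uint8_t:
  const int shift = num_bits_over_unit + 31 - 8;
  // The value is an exact power of two, so a plain right shift matches
  // gemmlowp::RoundingDivideByPOT here.
  const int32_t unsat_output = prob_q0_31 >> shift;
  assert(unsat_output == 128);  // 0.5 of the 256 output levels
  // int8 output then subtracts 128 (-> 0); uint8 output keeps 128 as-is.
}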
tensorflow/lite/kernels/internal/reference/softmax.h
@@ -15,6 +15,8 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
 
+#include <limits>
+
 #include "fixedpoint/fixedpoint.h"
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
@@ -59,9 +61,11 @@ inline void Softmax(const SoftmaxParams& params,
   }
 }
 
+// Quantized softmax with int8/uint8 input and int8/uint8/int16 output.
+template <typename InputT, typename OutputT>
 inline void Softmax(const SoftmaxParams& params,
-                    const RuntimeShape& input_shape, const uint8* input_data,
-                    const RuntimeShape& output_shape, uint8* output_data) {
+                    const RuntimeShape& input_shape, const InputT* input_data,
+                    const RuntimeShape& output_shape, OutputT* output_data) {
   const int32 input_beta_multiplier = params.input_multiplier;
   const int32 input_beta_left_shift = params.input_left_shift;
   const int diff_min = params.diff_min;
@@ -84,7 +88,7 @@ inline void Softmax(const SoftmaxParams& params,
       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
 
   for (int i = 0; i < outer_size; ++i) {
-    uint8 max_in_row = 0;
+    InputT max_in_row = std::numeric_limits<InputT>::min();
     for (int c = 0; c < depth; ++c) {
       max_in_row = std::max(max_in_row, input_data[i * depth + c]);
     }
@@ -120,14 +124,19 @@ inline void Softmax(const SoftmaxParams& params,
 
         FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
         int32 unsat_output = gemmlowp::RoundingDivideByPOT(
-            (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
+            (shifted_scale * exp_in_0).raw(),
+            num_bits_over_unit + 31 - (sizeof(OutputT) * 8));
 
-        output_data[i * depth + c] = static_cast<uint8>(
-            std::max(std::min(unsat_output, static_cast<int32>(255)),
-                     static_cast<int32>(0)));
+        const int32 shifted_output =
+            unsat_output +
+            static_cast<int32>(std::numeric_limits<OutputT>::min());
 
+        output_data[i * depth + c] = static_cast<OutputT>(std::max(
+            std::min(shifted_output,
+                     static_cast<int32>(std::numeric_limits<OutputT>::max())),
+            static_cast<int32>(std::numeric_limits<OutputT>::min())));
       } else {
-        output_data[i * depth + c] = 0;
+        output_data[i * depth + c] = std::numeric_limits<OutputT>::min();
       }
     }
   }
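Worth noting in the hunk above: adding std::numeric_limits<OutputT>::min() is what lets one template body replace both earlier versions. The offset is 0 for uint8_t (matching the old clamp to [0, 255]) and -128 for int8_t (matching the old unsat_output - (max() + 1) in the deleted header). A small self-contained check of that equivalence (illustrative, not from the commit):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>

template <typename OutputT>
OutputT Requantize(int32_t unsat_output) {
  // Mirrors the shared requantization epilogue now in reference/softmax.h.
  const int32_t shifted =
      unsat_output + static_cast<int32_t>(std::numeric_limits<OutputT>::min());
  return static_cast<OutputT>(
      std::max(std::min(shifted, static_cast<int32_t>(
                                     std::numeric_limits<OutputT>::max())),
               static_cast<int32_t>(std::numeric_limits<OutputT>::min())));
}

int main() {
  assert(Requantize<uint8_t>(256) == 255);  // offset 0, clamp to [0, 255]
  assert(Requantize<int8_t>(256) == 127);   // offset -128, clamp to [-128, 127]
  assert(Requantize<int8_t>(128) == 0);     // 0.5 * 256 - 128
}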
@@ -25,7 +25,6 @@ limitations under the License.
 #include <gtest/gtest.h>
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/lite/kernels/internal/test_util.h"
 #include "tensorflow/lite/string_type.h"
@@ -19,7 +19,6 @@ limitations under the License.
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/op_macros.h"
@@ -116,13 +115,13 @@ void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
                                    GetTensorData<uint8_t>(output));
   } else {
     if (output->type == kTfLiteInt16) {
-      tflite::reference_integer_ops::Softmax(
-          op_params, shape, GetTensorData<int8_t>(input), shape,
-          GetTensorData<int16_t>(output));
+      tflite::reference_ops::Softmax(op_params, shape,
+                                     GetTensorData<int8_t>(input), shape,
+                                     GetTensorData<int16_t>(output));
     } else {
-      tflite::reference_integer_ops::Softmax(
-          op_params, shape, GetTensorData<int8_t>(input), shape,
-          GetTensorData<int8_t>(output));
+      tflite::reference_ops::Softmax(op_params, shape,
+                                     GetTensorData<int8_t>(input), shape,
+                                     GetTensorData<int8_t>(output));
     }
   }
 }
@@ -147,13 +146,13 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
                                    GetTensorData<uint8_t>(output));
   } else {
     if (output->type == kTfLiteInt16) {
-      tflite::reference_integer_ops::Softmax(
-          op_params, shape, GetTensorData<int8_t>(input), shape,
-          GetTensorData<int16_t>(output));
+      tflite::reference_ops::Softmax(op_params, shape,
+                                     GetTensorData<int8_t>(input), shape,
+                                     GetTensorData<int16_t>(output));
     } else {
-      tflite::reference_integer_ops::Softmax(
-          op_params, shape, GetTensorData<int8_t>(input), shape,
-          GetTensorData<int8_t>(output));
+      tflite::reference_ops::Softmax(op_params, shape,
+                                     GetTensorData<int8_t>(input), shape,
+                                     GetTensorData<int8_t>(output));
     }
   }
 }
@@ -180,11 +179,11 @@ void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
         GetTensorShape(output), GetTensorData<uint8_t>(output));
   } else {
     if (output->type == kTfLiteInt16) {
-      tflite::reference_integer_ops::Softmax(
+      tflite::reference_ops::Softmax(
           op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
           GetTensorShape(output), GetTensorData<int16_t>(output));
     } else {
-      tflite::reference_integer_ops::Softmax(
+      tflite::reference_ops::Softmax(
          op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
           GetTensorShape(output), GetTensorData<int8_t>(output));
     }
@@ -19,7 +19,6 @@ limitations under the License.
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/op_macros.h"
@@ -149,7 +149,6 @@ tensorflow/lite/kernels/internal/reference/integer_ops/conv.h \
 tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h \
 tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h \
 tensorflow/lite/kernels/internal/reference/integer_ops/mul.h \
-tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h \
 tensorflow/lite/kernels/internal/reference/maximum_minimum.h \
 tensorflow/lite/kernels/internal/reference/mul.h \
 tensorflow/lite/kernels/internal/reference/neg.h \