Merge pull request #46709 from advaitjain:softmax-refactor
PiperOrigin-RevId: 353952307 Change-Id: Iefcc2d4f8d28693a25aafd4c5ce3a35592a671ce
This commit is contained in:
commit
a4c269747a
@ -132,6 +132,7 @@ cc_library(
|
||||
"resize_nearest_neighbor.cc",
|
||||
"round.cc",
|
||||
"shape.cc",
|
||||
"softmax_common.cc",
|
||||
"split.cc",
|
||||
"split_v.cc",
|
||||
"strided_slice.cc",
|
||||
@ -159,6 +160,7 @@ cc_library(
|
||||
hdrs = [
|
||||
"micro_ops.h",
|
||||
"quantize.h",
|
||||
"softmax.h",
|
||||
"svdf.h",
|
||||
],
|
||||
copts = micro_copts(),
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
|
||||
#include "tensorflow/lite/micro/kernels/softmax.h"
|
||||
|
||||
#include "CMSIS/NN/Include/arm_nnfunctions.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
@ -27,131 +28,6 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
// Softmax parameter data that persists in user_data
|
||||
static constexpr int kInt16LUTArraySize = 513;
|
||||
|
||||
TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
|
||||
const TfLiteTensor* input,
|
||||
TfLiteTensor* output,
|
||||
const TfLiteSoftmaxParams* params,
|
||||
SoftmaxParams* op_data) {
|
||||
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 ||
|
||||
input->type == kTfLiteInt16) {
|
||||
if (input->type == kTfLiteUInt8) {
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8);
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||
} else if (input->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||
TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 32768,
|
||||
(0.001f * 1.f / 32768));
|
||||
} else { // input->type == kTfLiteInt8
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
|
||||
if (output->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768);
|
||||
TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 65536,
|
||||
(0.001f * 1.f / 65536));
|
||||
} else { // output->type == kTfLiteint8
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
|
||||
TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
|
||||
}
|
||||
}
|
||||
|
||||
static const int kScaledDiffIntegerBits = 5;
|
||||
|
||||
// Calculate input_multiplier and input_left_shift
|
||||
if (input->type == kTfLiteInt16) {
|
||||
int input_left_shift;
|
||||
double input_scale_beta_rescale =
|
||||
static_cast<double>(input->params.scale) *
|
||||
static_cast<double>(params->beta) /
|
||||
(10.0 / 65535.0); // scale the input_diff such that [-65535, 0]
|
||||
// correspond to [-10.0, 0.0]
|
||||
QuantizeMultiplier(input_scale_beta_rescale, &op_data->input_multiplier,
|
||||
&input_left_shift);
|
||||
op_data->input_left_shift = input_left_shift;
|
||||
} else {
|
||||
int input_left_shift;
|
||||
tflite::PreprocessSoftmaxScaling(
|
||||
static_cast<double>(params->beta),
|
||||
static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
|
||||
&op_data->input_multiplier, &input_left_shift);
|
||||
op_data->input_left_shift = input_left_shift;
|
||||
op_data->diff_min =
|
||||
-1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
|
||||
op_data->input_left_shift);
|
||||
}
|
||||
} else {
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
|
||||
op_data->beta = static_cast<double>(params->beta);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(SoftmaxParams));
|
||||
}
|
||||
|
||||
TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
|
||||
TfLiteTensor* output = GetOutput(context, node, 0);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE(context, node->user_data != nullptr);
|
||||
SoftmaxParams* op_data = static_cast<SoftmaxParams*>(node->user_data);
|
||||
// Only allocate LUTs for KTfLiteInt16 data type
|
||||
if (input->type == kTfLiteInt16) {
|
||||
void* raw_exp_lut = context->AllocatePersistentBuffer(
|
||||
context, sizeof(int16_t) * kInt16LUTArraySize);
|
||||
TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
|
||||
op_data->exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
|
||||
void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
|
||||
context, sizeof(int16_t) * kInt16LUTArraySize);
|
||||
TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
|
||||
op_data->one_over_one_plus_x_lut =
|
||||
reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
|
||||
}
|
||||
|
||||
if (output->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
|
||||
input->type == kTfLiteUInt8 ||
|
||||
input->type == kTfLiteInt16);
|
||||
} else {
|
||||
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
||||
}
|
||||
|
||||
// Populate LUT if required
|
||||
if (input->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||
// exp LUT only used on negative values
|
||||
// we consider exp(-10.0) is insignificant to accumulation
|
||||
gen_lut([](float value) { return std::exp(value); }, -10.0f, 0.0f,
|
||||
op_data->exp_lut, kInt16LUTArraySize);
|
||||
gen_lut([](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f,
|
||||
op_data->one_over_one_plus_x_lut, kInt16LUTArraySize);
|
||||
op_data->zero_point = output->params.zero_point;
|
||||
op_data->scale = output->params.scale;
|
||||
}
|
||||
|
||||
auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
|
||||
return CalculateSoftmaxParams(context, input, output, params, op_data);
|
||||
}
|
||||
|
||||
// Takes a tensor and performs softmax along the last dimension.
|
||||
void SoftmaxFloat(const TfLiteEvalTensor* input, TfLiteEvalTensor* output,
|
||||
const SoftmaxParams& op_data) {
|
||||
tflite::reference_ops::Softmax(op_data, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<float>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
}
|
||||
|
||||
void SoftmaxQuantized(const TfLiteEvalTensor* input, TfLiteEvalTensor* output,
|
||||
const SoftmaxParams& op_data) {
|
||||
if (input->type == kTfLiteUInt8) {
|
||||
@ -200,7 +76,11 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32: {
|
||||
SoftmaxFloat(input, output, data);
|
||||
tflite::reference_ops::Softmax(
|
||||
data, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<float>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
case kTfLiteInt8:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
|
||||
#include "tensorflow/lite/micro/kernels/softmax.h"
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
@ -27,77 +28,6 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
// Softmax parameter data that persists in user_data
|
||||
static constexpr int kInt16LUTArraySize = 513;
|
||||
|
||||
TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
|
||||
const TfLiteTensor* input,
|
||||
TfLiteTensor* output,
|
||||
const TfLiteSoftmaxParams* params,
|
||||
SoftmaxParams* op_data) {
|
||||
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 ||
|
||||
input->type == kTfLiteInt16) {
|
||||
if (input->type == kTfLiteUInt8) {
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8);
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||
} else if (input->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||
TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 32768,
|
||||
(0.001f * 1.f / 32768));
|
||||
} else { // input->type == kTfLiteInt8
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
|
||||
if (output->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768);
|
||||
TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 65536,
|
||||
(0.001f * 1.f / 65536));
|
||||
} else { // output->type == kTfLiteint8
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
|
||||
TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
|
||||
}
|
||||
}
|
||||
|
||||
static const int kScaledDiffIntegerBits = 5;
|
||||
|
||||
// Calculate input_multiplier and input_left_shift
|
||||
if (input->type == kTfLiteInt16) {
|
||||
int input_left_shift;
|
||||
double input_scale_beta_rescale =
|
||||
static_cast<double>(input->params.scale) *
|
||||
static_cast<double>(params->beta) /
|
||||
(10.0 / 65535.0); // scale the input_diff such that [-65535, 0]
|
||||
// correspond to [-10.0, 0.0]
|
||||
QuantizeMultiplier(input_scale_beta_rescale, &op_data->input_multiplier,
|
||||
&input_left_shift);
|
||||
op_data->input_left_shift = input_left_shift;
|
||||
} else {
|
||||
int input_left_shift;
|
||||
tflite::PreprocessSoftmaxScaling(
|
||||
static_cast<double>(params->beta),
|
||||
static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
|
||||
&op_data->input_multiplier, &input_left_shift);
|
||||
op_data->input_left_shift = input_left_shift;
|
||||
op_data->diff_min =
|
||||
-1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
|
||||
op_data->input_left_shift);
|
||||
}
|
||||
} else {
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
|
||||
op_data->beta = static_cast<double>(params->beta);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Takes a tensor and performs softmax along the last dimension.
|
||||
void SoftmaxFloat(const TfLiteEvalTensor* input, TfLiteEvalTensor* output,
|
||||
const SoftmaxParams& op_data) {
|
||||
tflite::reference_ops::Softmax(op_data, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<float>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
}
|
||||
|
||||
void SoftmaxQuantized(const TfLiteEvalTensor* input, TfLiteEvalTensor* output,
|
||||
const SoftmaxParams& op_data) {
|
||||
if (input->type == kTfLiteUInt8) {
|
||||
@ -129,60 +59,6 @@ void SoftmaxQuantized(const TfLiteEvalTensor* input, TfLiteEvalTensor* output,
|
||||
}
|
||||
}
|
||||
|
||||
void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(SoftmaxParams));
|
||||
}
|
||||
|
||||
TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
|
||||
TfLiteTensor* output = GetOutput(context, node, 0);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE(context, node->user_data != nullptr);
|
||||
SoftmaxParams* op_data = static_cast<SoftmaxParams*>(node->user_data);
|
||||
// Only allocate LUTs for KTfLiteInt16 data type
|
||||
if (input->type == kTfLiteInt16) {
|
||||
void* raw_exp_lut = context->AllocatePersistentBuffer(
|
||||
context, sizeof(int16_t) * kInt16LUTArraySize);
|
||||
TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
|
||||
op_data->exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
|
||||
void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
|
||||
context, sizeof(int16_t) * kInt16LUTArraySize);
|
||||
TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
|
||||
op_data->one_over_one_plus_x_lut =
|
||||
reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
|
||||
}
|
||||
|
||||
if (output->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
|
||||
input->type == kTfLiteUInt8 ||
|
||||
input->type == kTfLiteInt16);
|
||||
} else {
|
||||
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
||||
}
|
||||
|
||||
// Populate LUT if required
|
||||
if (input->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||
// exp LUT only used on negative values
|
||||
// we consider exp(-10.0) is insignificant to accumulation
|
||||
gen_lut([](float value) { return std::exp(value); }, -10.0f, 0.0f,
|
||||
op_data->exp_lut, kInt16LUTArraySize);
|
||||
gen_lut([](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f,
|
||||
op_data->one_over_one_plus_x_lut, kInt16LUTArraySize);
|
||||
op_data->zero_point = output->params.zero_point;
|
||||
op_data->scale = output->params.scale;
|
||||
}
|
||||
|
||||
auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
|
||||
return CalculateSoftmaxParams(context, input, output, params, op_data);
|
||||
}
|
||||
|
||||
TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
|
||||
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
|
||||
@ -192,7 +68,11 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32: {
|
||||
SoftmaxFloat(input, output, op_data);
|
||||
tflite::reference_ops::Softmax(
|
||||
op_data, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<float>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
case kTfLiteInt8:
|
||||
|
||||
30
tensorflow/lite/micro/kernels/softmax.h
Normal file
30
tensorflow/lite/micro/kernels/softmax.h
Normal file
@ -0,0 +1,30 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_SOFTMAX_H_
|
||||
#define TENSORFLOW_LITE_MICRO_KERNELS_SOFTMAX_H_
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length);
|
||||
|
||||
TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node);
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_KERNELS_SOFTMAX_H_
|
||||
145
tensorflow/lite/micro/kernels/softmax_common.cc
Normal file
145
tensorflow/lite/micro/kernels/softmax_common.cc
Normal file
@ -0,0 +1,145 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/kernels/softmax.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace {
|
||||
// Softmax parameter data that persists in user_data
|
||||
const int kInt16LUTArraySize = 513;
|
||||
|
||||
TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
|
||||
const TfLiteTensor* input,
|
||||
TfLiteTensor* output,
|
||||
const TfLiteSoftmaxParams* params,
|
||||
SoftmaxParams* op_data) {
|
||||
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 ||
|
||||
input->type == kTfLiteInt16) {
|
||||
if (input->type == kTfLiteUInt8) {
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8);
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||
} else if (input->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||
TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 32768,
|
||||
(0.001f * 1.f / 32768));
|
||||
} else { // input->type == kTfLiteInt8
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
|
||||
if (output->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768);
|
||||
TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 65536,
|
||||
(0.001f * 1.f / 65536));
|
||||
} else { // output->type == kTfLiteint8
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
|
||||
TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
|
||||
}
|
||||
}
|
||||
|
||||
static const int kScaledDiffIntegerBits = 5;
|
||||
|
||||
// Calculate input_multiplier and input_left_shift
|
||||
if (input->type == kTfLiteInt16) {
|
||||
int input_left_shift;
|
||||
double input_scale_beta_rescale =
|
||||
static_cast<double>(input->params.scale) *
|
||||
static_cast<double>(params->beta) /
|
||||
(10.0 / 65535.0); // scale the input_diff such that [-65535, 0]
|
||||
// correspond to [-10.0, 0.0]
|
||||
QuantizeMultiplier(input_scale_beta_rescale, &op_data->input_multiplier,
|
||||
&input_left_shift);
|
||||
op_data->input_left_shift = input_left_shift;
|
||||
} else {
|
||||
int input_left_shift;
|
||||
tflite::PreprocessSoftmaxScaling(
|
||||
static_cast<double>(params->beta),
|
||||
static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
|
||||
&op_data->input_multiplier, &input_left_shift);
|
||||
op_data->input_left_shift = input_left_shift;
|
||||
op_data->diff_min =
|
||||
-1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
|
||||
op_data->input_left_shift);
|
||||
}
|
||||
} else {
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
|
||||
op_data->beta = static_cast<double>(params->beta);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(SoftmaxParams));
|
||||
}
|
||||
|
||||
TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
|
||||
TfLiteTensor* output = GetOutput(context, node, 0);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE(context, node->user_data != nullptr);
|
||||
SoftmaxParams* op_data = static_cast<SoftmaxParams*>(node->user_data);
|
||||
// Only allocate LUTs for KTfLiteInt16 data type
|
||||
if (input->type == kTfLiteInt16) {
|
||||
void* raw_exp_lut = context->AllocatePersistentBuffer(
|
||||
context, sizeof(int16_t) * kInt16LUTArraySize);
|
||||
TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
|
||||
op_data->exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
|
||||
void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
|
||||
context, sizeof(int16_t) * kInt16LUTArraySize);
|
||||
TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
|
||||
op_data->one_over_one_plus_x_lut =
|
||||
reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
|
||||
}
|
||||
|
||||
if (output->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
|
||||
input->type == kTfLiteUInt8 ||
|
||||
input->type == kTfLiteInt16);
|
||||
} else {
|
||||
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
||||
}
|
||||
|
||||
// Populate LUT if required
|
||||
if (input->type == kTfLiteInt16) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||
// exp LUT only used on negative values
|
||||
// we consider exp(-10.0) is insignificant to accumulation
|
||||
gen_lut([](float value) { return std::exp(value); }, -10.0f, 0.0f,
|
||||
op_data->exp_lut, kInt16LUTArraySize);
|
||||
gen_lut([](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f,
|
||||
op_data->one_over_one_plus_x_lut, kInt16LUTArraySize);
|
||||
op_data->zero_point = output->params.zero_point;
|
||||
op_data->scale = output->params.scale;
|
||||
}
|
||||
|
||||
auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
|
||||
return CalculateSoftmaxParams(context, input, output, params, op_data);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
@ -143,12 +143,13 @@ TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
void* SoftmaxInitXtensa(TfLiteContext* context, const char* buffer,
|
||||
size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(OpData));
|
||||
}
|
||||
|
||||
TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteStatus SoftmaxPrepareXtensa(TfLiteContext* context, TfLiteNode* node) {
|
||||
auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
@ -195,9 +196,9 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_SOFTMAX() {
|
||||
return {/*init=*/SoftmaxInit,
|
||||
return {/*init=*/SoftmaxInitXtensa,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/SoftmaxPrepare,
|
||||
/*prepare=*/SoftmaxPrepareXtensa,
|
||||
/*invoke=*/SoftmaxEval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
|
||||
@ -340,6 +340,7 @@ tensorflow/lite/micro/kernels/resize_nearest_neighbor.cc \
|
||||
tensorflow/lite/micro/kernels/round.cc \
|
||||
tensorflow/lite/micro/kernels/shape.cc \
|
||||
tensorflow/lite/micro/kernels/softmax.cc \
|
||||
tensorflow/lite/micro/kernels/softmax_common.cc \
|
||||
tensorflow/lite/micro/kernels/split.cc \
|
||||
tensorflow/lite/micro/kernels/split_v.cc \
|
||||
tensorflow/lite/micro/kernels/strided_slice.cc \
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user