Merge pull request #43320 from tensorflow:revert-43127-revert-38873-tflu_softnax_int16_ref
PiperOrigin-RevId: 333185684 Change-Id: I4d8a2845286ba905b16cebda204ef006c1a8e535
This commit is contained in:
commit
72c2f694be
@ -226,6 +226,17 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a);
|
|||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
|
// Checks that `a` and `b` differ by at most `epsilon`. On failure, logs both
// operands (as source text and as values converted to double) through
// TF_LITE_KERNEL_LOG and returns kTfLiteError from the enclosing function.
//
// `a` and `b` are expanded more than once, so pass side-effect-free
// expressions. `epsilon` is parenthesized in the comparison so that compound
// arguments (for example a conditional expression) are compared as a whole.
#define TF_LITE_ENSURE_NEAR(context, a, b, epsilon)                          \
  do {                                                                       \
    auto delta = ((a) > (b)) ? ((a) - (b)) : ((b) - (a));                    \
    if (delta > (epsilon)) {                                                 \
      TF_LITE_KERNEL_LOG((context), "%s:%d %s not near %s (%f != %f)",       \
                         __FILE__, __LINE__, #a, #b, static_cast<double>(a), \
                         static_cast<double>(b));                            \
      return kTfLiteError;                                                   \
    }                                                                        \
  } while (0)
|
||||||
|
|
||||||
#define TF_LITE_ENSURE_OK(context, status) \
|
#define TF_LITE_ENSURE_OK(context, status) \
|
||||||
do { \
|
do { \
|
||||||
const TfLiteStatus s = (status); \
|
const TfLiteStatus s = (status); \
|
||||||
|
@ -241,8 +241,12 @@ inline Integer FloorLog2(Integer n) {
|
|||||||
|
|
||||||
// generate INT16 LUT for function(), e.g., table exp(x) and 1/(1+x) used in
|
// generate INT16 LUT for function(), e.g., table exp(x) and 1/(1+x) used in
|
||||||
// softmax
|
// softmax
|
||||||
inline void gen_lut(const std::function<double(double)>& func, double min,
|
// func - the function to build the LUT for (e.g. exp(x))
|
||||||
double max, int16_t* table, const int num) {
|
// min,max - table limits
|
||||||
|
// table - pointer to buffer
|
||||||
|
// num - number of elements in the LUT
|
||||||
|
inline void gen_lut(double (*func)(double), double min, double max,
|
||||||
|
int16_t* table, const int num) {
|
||||||
// size of table should equal to num + 1
|
// size of table should equal to num + 1
|
||||||
// last element only for slope calculation
|
// last element only for slope calculation
|
||||||
double step = (max - min) / (num - 1);
|
double step = (max - min) / (num - 1);
|
||||||
@ -263,6 +267,34 @@ inline void gen_lut(const std::function<double(double)>& func, double min,
|
|||||||
std::min(std::max(TfLiteRound(func(max) * 32768.0), -32768.0), 32767.0);
|
std::min(std::max(TfLiteRound(func(max) * 32768.0), -32768.0), 32767.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generate INT16 LUT for function(), e.g., table exp(x) and 1/(1+x) used in
|
||||||
|
// softmax
|
||||||
|
// func - the function to build the LUT for (e.g exp(x))
|
||||||
|
// min,max - table limits
|
||||||
|
// table - pointer to buffer
|
||||||
|
// num - number of elements in the LUT
|
||||||
|
inline void gen_lut(float (*func)(float), float min, float max, int16_t* table,
|
||||||
|
const int num) {
|
||||||
|
// size of table should equal to num + 1
|
||||||
|
// last element only for slope calculation
|
||||||
|
float step = (max - min) / (num - 1);
|
||||||
|
float half_step = step / 2.0f;
|
||||||
|
for (int i = 0; i < num - 1; i++) {
|
||||||
|
float sample_val = TfLiteRound(func(min + i * step) * 32768.0f);
|
||||||
|
float midpoint_interp_val =
|
||||||
|
TfLiteRound((func(min + (i + 1) * step) * 32768.0f +
|
||||||
|
TfLiteRound(func(min + i * step) * 32768.0f)) /
|
||||||
|
2.0f);
|
||||||
|
float midpoint_val =
|
||||||
|
TfLiteRound(func(min + i * step + half_step) * 32768.0f);
|
||||||
|
float midpoint_err = midpoint_interp_val - midpoint_val;
|
||||||
|
float bias = TfLiteRound(midpoint_err / 2.0f);
|
||||||
|
table[i] = std::min(std::max(sample_val - bias, -32768.0f), 32767.0f);
|
||||||
|
}
|
||||||
|
table[num - 1] = std::min(
|
||||||
|
std::max(TfLiteRound(func(max) * 32768.0f), -32768.0f), 32767.0f);
|
||||||
|
}
|
||||||
|
|
||||||
// int16_t func table lookup, e.g., lookup exp() and 1/(1+x) used in softmax
|
// int16_t func table lookup, e.g., lookup exp() and 1/(1+x) used in softmax
|
||||||
inline int16_t generic_int16_table_lookup(int16_t value, const int16_t* lut) {
|
inline int16_t generic_int16_table_lookup(int16_t value, const int16_t* lut) {
|
||||||
// 512 base value, lut[513] only for calculate slope
|
// 512 base value, lut[513] only for calculate slope
|
||||||
|
@ -1044,7 +1044,9 @@ struct SoftmaxParams {
|
|||||||
int32_t zero_point;
|
int32_t zero_point;
|
||||||
float scale;
|
float scale;
|
||||||
float* table;
|
float* table;
|
||||||
|
// int16 LUT for exp(x), where x uniform distributed between [-10.0 , 0.0]
|
||||||
int16_t* exp_lut;
|
int16_t* exp_lut;
|
||||||
|
// int16 LUT for 1 / (1 + x), where x uniform distributed between [0.0 , 1.0]
|
||||||
int16_t* one_over_one_plus_x_lut;
|
int16_t* one_over_one_plus_x_lut;
|
||||||
uint8_t* uint8_table1;
|
uint8_t* uint8_table1;
|
||||||
uint8_t* uint8_table2;
|
uint8_t* uint8_table2;
|
||||||
|
@ -30,23 +30,30 @@ namespace micro {
|
|||||||
namespace activations {
|
namespace activations {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// Softmax parameter data that persists in user_data
|
||||||
|
static constexpr int kInt16LUTArraySize = 513;
|
||||||
|
|
||||||
TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
|
TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
|
||||||
const TfLiteTensor* input,
|
const TfLiteTensor* input,
|
||||||
TfLiteTensor* output,
|
TfLiteTensor* output,
|
||||||
const TfLiteSoftmaxParams* params,
|
const TfLiteSoftmaxParams* params,
|
||||||
SoftmaxParams* op_data) {
|
SoftmaxParams* op_data) {
|
||||||
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
|
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 ||
|
||||||
|
input->type == kTfLiteInt16) {
|
||||||
if (input->type == kTfLiteUInt8) {
|
if (input->type == kTfLiteUInt8) {
|
||||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8);
|
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8);
|
||||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||||
} else {
|
} else if (input->type == kTfLiteInt16) {
|
||||||
|
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||||
|
TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 32768,
|
||||||
|
(0.001f * 1.f / 32768));
|
||||||
|
} else { // input->type == kTfLiteInt8
|
||||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
|
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
|
||||||
if (output->type == kTfLiteInt16) {
|
if (output->type == kTfLiteInt16) {
|
||||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768);
|
TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768);
|
||||||
// NOTE: Current int16_t softmax output does not require symmetric
|
TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 65536,
|
||||||
// scaling
|
(0.001f * 1.f / 65536));
|
||||||
// - so no need to verify scale here.
|
} else { // output->type == kTfLiteint8
|
||||||
} else {
|
|
||||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
|
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
|
||||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
|
TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
|
||||||
TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
|
TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
|
||||||
@ -55,6 +62,18 @@ TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
|
|||||||
|
|
||||||
static const int kScaledDiffIntegerBits = 5;
|
static const int kScaledDiffIntegerBits = 5;
|
||||||
|
|
||||||
|
// Calculate input_multiplier and input_left_shift
|
||||||
|
if (input->type == kTfLiteInt16) {
|
||||||
|
int input_left_shift;
|
||||||
|
double input_scale_beta_rescale =
|
||||||
|
static_cast<double>(input->params.scale) *
|
||||||
|
static_cast<double>(params->beta) /
|
||||||
|
(10.0 / 65535.0); // scale the input_diff such that [-65535, 0]
|
||||||
|
// correspond to [-10.0, 0.0]
|
||||||
|
QuantizeMultiplier(input_scale_beta_rescale, &op_data->input_multiplier,
|
||||||
|
&input_left_shift);
|
||||||
|
op_data->input_left_shift = input_left_shift;
|
||||||
|
} else {
|
||||||
int input_left_shift;
|
int input_left_shift;
|
||||||
tflite::PreprocessSoftmaxScaling(
|
tflite::PreprocessSoftmaxScaling(
|
||||||
static_cast<double>(params->beta),
|
static_cast<double>(params->beta),
|
||||||
@ -64,6 +83,7 @@ TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
|
|||||||
op_data->diff_min =
|
op_data->diff_min =
|
||||||
-1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
|
-1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
|
||||||
op_data->input_left_shift);
|
op_data->input_left_shift);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
|
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
|
||||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
|
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
|
||||||
@ -91,7 +111,7 @@ void SoftmaxQuantized(const TfLiteEvalTensor* input, TfLiteEvalTensor* output,
|
|||||||
tflite::micro::GetTensorData<uint8_t>(input),
|
tflite::micro::GetTensorData<uint8_t>(input),
|
||||||
tflite::micro::GetTensorShape(output),
|
tflite::micro::GetTensorShape(output),
|
||||||
tflite::micro::GetTensorData<uint8_t>(output));
|
tflite::micro::GetTensorData<uint8_t>(output));
|
||||||
} else {
|
} else if (input->type == kTfLiteInt8) {
|
||||||
if (output->type == kTfLiteInt16) {
|
if (output->type == kTfLiteInt16) {
|
||||||
tflite::reference_ops::Softmax(
|
tflite::reference_ops::Softmax(
|
||||||
op_data, tflite::micro::GetTensorShape(input),
|
op_data, tflite::micro::GetTensorShape(input),
|
||||||
@ -105,6 +125,12 @@ void SoftmaxQuantized(const TfLiteEvalTensor* input, TfLiteEvalTensor* output,
|
|||||||
tflite::micro::GetTensorShape(output),
|
tflite::micro::GetTensorShape(output),
|
||||||
tflite::micro::GetTensorData<int8_t>(output));
|
tflite::micro::GetTensorData<int8_t>(output));
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
tflite::reference_ops::SoftmaxInt16(
|
||||||
|
op_data, tflite::micro::GetTensorShape(input),
|
||||||
|
tflite::micro::GetTensorData<int16_t>(input),
|
||||||
|
tflite::micro::GetTensorShape(output),
|
||||||
|
tflite::micro::GetTensorData<int16_t>(output));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -114,20 +140,52 @@ void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
|
|
||||||
|
|
||||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||||
TF_LITE_ENSURE(context, input != nullptr);
|
TF_LITE_ENSURE(context, input != nullptr);
|
||||||
TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
|
TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
|
||||||
|
|
||||||
TfLiteTensor* output = GetOutput(context, node, 0);
|
TfLiteTensor* output = GetOutput(context, node, 0);
|
||||||
TF_LITE_ENSURE(context, output != nullptr);
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
|
||||||
TFLITE_DCHECK(node->user_data != nullptr);
|
TF_LITE_ENSURE(context, node->user_data != nullptr);
|
||||||
SoftmaxParams* data = static_cast<SoftmaxParams*>(node->user_data);
|
SoftmaxParams* op_data = static_cast<SoftmaxParams*>(node->user_data);
|
||||||
return CalculateSoftmaxParams(context, input, output, params, data);
|
// Only allocate LUTs for the kTfLiteInt16 data type
|
||||||
|
if (input->type == kTfLiteInt16) {
|
||||||
|
void* raw_exp_lut = context->AllocatePersistentBuffer(
|
||||||
|
context, sizeof(int16_t) * kInt16LUTArraySize);
|
||||||
|
TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
|
||||||
|
op_data->exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
|
||||||
|
void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
|
||||||
|
context, sizeof(int16_t) * kInt16LUTArraySize);
|
||||||
|
TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
|
||||||
|
op_data->one_over_one_plus_x_lut =
|
||||||
|
reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (output->type == kTfLiteInt16) {
|
||||||
|
TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
|
||||||
|
input->type == kTfLiteUInt8 ||
|
||||||
|
input->type == kTfLiteInt16);
|
||||||
|
} else {
|
||||||
|
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate LUT if required
|
||||||
|
if (input->type == kTfLiteInt16) {
|
||||||
|
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||||
|
// exp LUT only used on negative values
|
||||||
|
// we consider exp(-10.0) is insignificant to accumulation
|
||||||
|
gen_lut([](float value) { return std::exp(value); }, -10.0f, 0.0f,
|
||||||
|
op_data->exp_lut, kInt16LUTArraySize);
|
||||||
|
gen_lut([](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f,
|
||||||
|
op_data->one_over_one_plus_x_lut, kInt16LUTArraySize);
|
||||||
|
op_data->zero_point = output->params.zero_point;
|
||||||
|
op_data->scale = output->params.scale;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
|
||||||
|
return CalculateSoftmaxParams(context, input, output, params, op_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
@ -135,16 +193,17 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
|
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
|
||||||
|
|
||||||
TFLITE_DCHECK(node->user_data != nullptr);
|
TFLITE_DCHECK(node->user_data != nullptr);
|
||||||
SoftmaxParams* data = static_cast<SoftmaxParams*>(node->user_data);
|
SoftmaxParams op_data = *static_cast<SoftmaxParams*>(node->user_data);
|
||||||
|
|
||||||
switch (input->type) {
|
switch (input->type) {
|
||||||
case kTfLiteFloat32: {
|
case kTfLiteFloat32: {
|
||||||
SoftmaxFloat(input, output, *data);
|
SoftmaxFloat(input, output, op_data);
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
case kTfLiteInt8:
|
case kTfLiteInt8:
|
||||||
case kTfLiteUInt8: {
|
case kTfLiteUInt8:
|
||||||
SoftmaxQuantized(input, output, *data);
|
case kTfLiteInt16: {
|
||||||
|
SoftmaxQuantized(input, output, op_data);
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
@ -28,8 +28,13 @@ namespace {
|
|||||||
// quantization parameters.
|
// quantization parameters.
|
||||||
const float output_scale_int8 = 1.0f / 256.0f;
|
const float output_scale_int8 = 1.0f / 256.0f;
|
||||||
const float output_scale_uint8 = 1.0f / 256.0f;
|
const float output_scale_uint8 = 1.0f / 256.0f;
|
||||||
|
const float output_scale_int16 = 1.0f / 32768.0f;
|
||||||
const int output_zero_point_int8 = -128;
|
const int output_zero_point_int8 = -128;
|
||||||
const int output_zero_point_uint8 = 0;
|
const int output_zero_point_uint8 = 0;
|
||||||
|
const int output_zero_point_int16 = 0;
|
||||||
|
|
||||||
|
// Empirical tolerance in quantization space
|
||||||
|
const float tolerance_int16 = 7.0;
|
||||||
|
|
||||||
// 1-dimensional test data.
|
// 1-dimensional test data.
|
||||||
const int flat_size_1d = 5;
|
const int flat_size_1d = 5;
|
||||||
@ -291,7 +296,7 @@ void TestSoftmaxQuantized(const int* input_dims_data, const float* input_data,
|
|||||||
int input_zero_point, const int* output_dims_data,
|
int input_zero_point, const int* output_dims_data,
|
||||||
const float* golden, T* golden_quantized,
|
const float* golden, T* golden_quantized,
|
||||||
float output_scale, int output_zero_point,
|
float output_scale, int output_zero_point,
|
||||||
T* output_data) {
|
T* output_data, float tolerance = 1.0) {
|
||||||
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
|
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
|
||||||
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
|
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
|
||||||
const int output_dims_count = ElementCount(*output_dims);
|
const int output_dims_count = ElementCount(*output_dims);
|
||||||
@ -310,7 +315,7 @@ void TestSoftmaxQuantized(const int* input_dims_data, const float* input_data,
|
|||||||
output_zero_point);
|
output_zero_point);
|
||||||
|
|
||||||
ValidateSoftmaxGoldens(tensors, tensors_size, output_data, golden_quantized,
|
ValidateSoftmaxGoldens(tensors, tensors_size, output_data, golden_quantized,
|
||||||
output_dims_count, 1.0);
|
output_dims_count, tolerance);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -356,6 +361,21 @@ TF_LITE_MICRO_TEST(Softmax1DQuantizedInt8ShouldMatchGolden) {
|
|||||||
tflite::testing::output_zero_point_int8, output_data);
|
tflite::testing::output_zero_point_int8, output_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runs the quantized softmax test helper with int16 buffers on the 1-D test
// data, using the int16 output scale/zero point and the helper's default
// tolerance.
TF_LITE_MICRO_TEST(Softmax1DQuantizedInt16ShouldMatchGolden) {
  namespace tt = tflite::testing;

  // Quantization parameters for the input tensor.
  constexpr float kInputScale = 0.1f;
  constexpr int kInputZeroPoint = 0;

  // Scratch buffers for the quantized input, quantized golden and output.
  int16_t quantized_input[tt::flat_size_1d];
  int16_t quantized_golden[tt::flat_size_1d];
  int16_t result[tt::flat_size_1d];

  tt::TestSoftmaxQuantized(tt::shape_1d, tt::input_data_1d, quantized_input,
                           kInputScale, kInputZeroPoint, tt::shape_1d,
                           tt::golden_1d, quantized_golden,
                           tt::output_scale_int16, tt::output_zero_point_int16,
                           result);
}
|
||||||
|
|
||||||
TF_LITE_MICRO_TEST(Softmax2DFloatShouldMatchGolden) {
|
TF_LITE_MICRO_TEST(Softmax2DFloatShouldMatchGolden) {
|
||||||
float output_data[tflite::testing::flat_size_2d];
|
float output_data[tflite::testing::flat_size_2d];
|
||||||
tflite::testing::TestSoftmaxFloat(
|
tflite::testing::TestSoftmaxFloat(
|
||||||
@ -393,6 +413,21 @@ TF_LITE_MICRO_TEST(Softmax2DQuantizedInt8ShouldMatchGolden) {
|
|||||||
tflite::testing::output_zero_point_int8, output_data);
|
tflite::testing::output_zero_point_int8, output_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runs the quantized softmax test helper with int16 buffers on the 2-D test
// data, using the int16 output scale/zero point and the helper's default
// tolerance.
TF_LITE_MICRO_TEST(Softmax2DQuantizedInt16ShouldMatchGolden) {
  namespace tt = tflite::testing;

  // Quantization parameters for the input tensor.
  constexpr float kInputScale = 0.1f;
  constexpr int kInputZeroPoint = 0;

  // Scratch buffers for the quantized input, quantized golden and output.
  int16_t quantized_input[tt::flat_size_2d];
  int16_t quantized_golden[tt::flat_size_2d];
  int16_t result[tt::flat_size_2d];

  tt::TestSoftmaxQuantized(tt::shape_2d, tt::input_data_2d, quantized_input,
                           kInputScale, kInputZeroPoint, tt::shape_2d,
                           tt::golden_2d, quantized_golden,
                           tt::output_scale_int16, tt::output_zero_point_int16,
                           result);
}
|
||||||
|
|
||||||
TF_LITE_MICRO_TEST(Softmax3DFloatShouldMatchGolden) {
|
TF_LITE_MICRO_TEST(Softmax3DFloatShouldMatchGolden) {
|
||||||
float output_data[tflite::testing::flat_size_3d];
|
float output_data[tflite::testing::flat_size_3d];
|
||||||
tflite::testing::TestSoftmaxFloat(
|
tflite::testing::TestSoftmaxFloat(
|
||||||
@ -430,6 +465,22 @@ TF_LITE_MICRO_TEST(Softmax3DQuantizedInt8ShouldMatchGolden) {
|
|||||||
tflite::testing::output_zero_point_int8, output_data);
|
tflite::testing::output_zero_point_int8, output_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runs the quantized softmax test helper with int16 buffers on the 3-D test
// data, using the int16 output scale/zero point and the explicit int16
// tolerance.
TF_LITE_MICRO_TEST(Softmax3DQuantizedInt16ShouldMatchGolden) {
  namespace tt = tflite::testing;

  // Quantization parameters for the input tensor.
  constexpr float kInputScale = 0.1f;
  constexpr int kInputZeroPoint = 0;

  // Scratch buffers for the quantized input, quantized golden and output.
  int16_t quantized_input[tt::flat_size_3d];
  int16_t quantized_golden[tt::flat_size_3d];
  int16_t result[tt::flat_size_3d];

  tt::TestSoftmaxQuantized(tt::shape_3d, tt::input_data_3d, quantized_input,
                           kInputScale, kInputZeroPoint, tt::shape_3d,
                           tt::golden_3d, quantized_golden,
                           tt::output_scale_int16, tt::output_zero_point_int16,
                           result, tt::tolerance_int16);
}
|
||||||
|
|
||||||
TF_LITE_MICRO_TEST(Softmax4DFloatShouldMatchGolden) {
|
TF_LITE_MICRO_TEST(Softmax4DFloatShouldMatchGolden) {
|
||||||
float output_data[tflite::testing::flat_size_4d];
|
float output_data[tflite::testing::flat_size_4d];
|
||||||
tflite::testing::TestSoftmaxFloat(
|
tflite::testing::TestSoftmaxFloat(
|
||||||
@ -467,4 +518,19 @@ TF_LITE_MICRO_TEST(Softmax4DQuantizedInt8ShouldMatchGolden) {
|
|||||||
tflite::testing::output_zero_point_int8, output_data);
|
tflite::testing::output_zero_point_int8, output_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runs the quantized softmax test helper with int16 buffers on the 4-D test
// data, using the int16 output scale/zero point and the explicit int16
// tolerance.
TF_LITE_MICRO_TEST(Softmax4DQuantizedInt16ShouldMatchGolden) {
  namespace tt = tflite::testing;

  // Quantization parameters for the input tensor.
  constexpr float kInputScale = 0.1f;
  constexpr int kInputZeroPoint = 0;

  // Scratch buffers for the quantized input, quantized golden and output.
  int16_t quantized_input[tt::flat_size_4d];
  int16_t quantized_golden[tt::flat_size_4d];
  int16_t result[tt::flat_size_4d];

  tt::TestSoftmaxQuantized(tt::shape_4d, tt::input_data_4d, quantized_input,
                           kInputScale, kInputZeroPoint, tt::shape_4d,
                           tt::golden_4d, quantized_golden,
                           tt::output_scale_int16, tt::output_zero_point_int16,
                           result, tt::tolerance_int16);
}
|
||||||
TF_LITE_MICRO_TESTS_END
|
TF_LITE_MICRO_TESTS_END
|
||||||
|
@ -226,6 +226,17 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a);
|
|||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
|
// Checks that `a` and `b` differ by at most `epsilon`. On failure, logs both
// operands (as source text and as values converted to double) through
// TF_LITE_KERNEL_LOG and returns kTfLiteError from the enclosing function.
//
// `a` and `b` are expanded more than once, so pass side-effect-free
// expressions. `epsilon` is parenthesized in the comparison so that compound
// arguments (for example a conditional expression) are compared as a whole.
#define TF_LITE_ENSURE_NEAR(context, a, b, epsilon)                          \
  do {                                                                       \
    auto delta = ((a) > (b)) ? ((a) - (b)) : ((b) - (a));                    \
    if (delta > (epsilon)) {                                                 \
      TF_LITE_KERNEL_LOG((context), "%s:%d %s not near %s (%f != %f)",       \
                         __FILE__, __LINE__, #a, #b, static_cast<double>(a), \
                         static_cast<double>(b));                            \
      return kTfLiteError;                                                   \
    }                                                                        \
  } while (0)
|
||||||
|
|
||||||
#define TF_LITE_ENSURE_OK(context, status) \
|
#define TF_LITE_ENSURE_OK(context, status) \
|
||||||
do { \
|
do { \
|
||||||
const TfLiteStatus s = (status); \
|
const TfLiteStatus s = (status); \
|
||||||
|
Loading…
Reference in New Issue
Block a user