Add quantized TANH operation, mostly copied from LOGISTIC.
PiperOrigin-RevId: 313569148
Change-Id: Id7801e9afaa7cb10dc51234cba9bf4d9320a0dc5
commit f0f84935e3
parent 3c9dfef469
@@ -481,6 +481,7 @@ cc_library(
        "reference/strided_slice.h",
        "reference/sub.h",
        "reference/svdf.h",
        "reference/tanh.h",
    ],
    build_for_embedded = True,
    copts = tflite_copts(),
@@ -551,6 +552,7 @@ cc_library(
        "reference/softmax.h",
        "reference/strided_slice.h",
        "reference/sub.h",
        "reference/tanh.h",
    ],
    copts = tflite_copts(),
    deps = [
@@ -59,6 +59,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
#include "tensorflow/lite/kernels/internal/reference/sub.h"
#include "tensorflow/lite/kernels/internal/reference/tanh.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/types.h"
@@ -1343,59 +1344,6 @@ inline void LogSoftmax(const SoftmaxParams& params,
  }
}

inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
                 const RuntimeShape& output_shape, float* output_data) {
  const int flat_size = MatchingFlatSize(input_shape, output_shape);

  for (int i = 0; i < flat_size; i++) {
    float val = input_data[i];
    float result = std::tanh(val);
    output_data[i] = result;
  }
}

// Convenience version that allows, for example, generated-code calls to be
// uniform between data types.
inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
                 const float* input_data, const RuntimeShape& output_shape,
                 float* output_data) {
  // Drop params: not needed.
  Tanh(input_shape, input_data, output_shape, output_data);
}

inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
                 const int16* input_data, const RuntimeShape& output_shape,
                 int16* output_data) {
  const int input_left_shift = params.input_left_shift;
  // Support for shifts is limited until we have a parameterized version of
  // SaturatingRoundingMultiplyByPOT().
  TFLITE_DCHECK_GE(input_left_shift, 0);
  TFLITE_DCHECK_LE(input_left_shift, 1);

  const int flat_size = MatchingFlatSize(input_shape, output_shape);

  // F0 uses 0 integer bits, range [-1, 1].
  // This is the return type of math functions such as tanh, logistic,
  // whose range is in [-1, 1].
  using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
  // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
  using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;

  if (input_left_shift == 0) {
    for (int i = 0; i < flat_size; i++) {
      F3 input = F3::FromRaw(input_data[i]);
      F0 output = gemmlowp::tanh(input);
      output_data[i] = output.raw();
    }
  } else {
    for (int i = 0; i < flat_size; i++) {
      F3 input = F3::FromRaw(
          gemmlowp::SaturatingRoundingMultiplyByPOT<1>(input_data[i]));
      F0 output = gemmlowp::tanh(input);
      output_data[i] = output.raw();
    }
  }
}

inline void Dequantize(const RuntimeShape& input_shape,
                       const Eigen::half* input_data,
tensorflow/lite/kernels/internal/reference/tanh.h (new file, 86 lines)
@@ -0,0 +1,86 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_TANH_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_TANH_H_

#include <cmath>

#include "fixedpoint/fixedpoint.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace reference_ops {

inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
                 const RuntimeShape& output_shape, float* output_data) {
  const int flat_size = MatchingFlatSize(input_shape, output_shape);

  for (int i = 0; i < flat_size; i++) {
    float val = input_data[i];
    float result = std::tanh(val);
    output_data[i] = result;
  }
}

// Convenience version that allows, for example, generated-code calls to be
// uniform between data types.
inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
                 const float* input_data, const RuntimeShape& output_shape,
                 float* output_data) {
  // Drop params: not needed.
  Tanh(input_shape, input_data, output_shape, output_data);
}

inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
                 const int16* input_data, const RuntimeShape& output_shape,
                 int16* output_data) {
  const int input_left_shift = params.input_left_shift;
  // Support for shifts is limited until we have a parameterized version of
  // SaturatingRoundingMultiplyByPOT().
  TFLITE_DCHECK_GE(input_left_shift, 0);
  TFLITE_DCHECK_LE(input_left_shift, 1);

  const int flat_size = MatchingFlatSize(input_shape, output_shape);

  // F0 uses 0 integer bits, range [-1, 1].
  // This is the return type of math functions such as tanh, logistic,
  // whose range is in [-1, 1].
  using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
  // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
  using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;

  if (input_left_shift == 0) {
    for (int i = 0; i < flat_size; i++) {
      F3 input = F3::FromRaw(input_data[i]);
      F0 output = gemmlowp::tanh(input);
      output_data[i] = output.raw();
    }
  } else {
    for (int i = 0; i < flat_size; i++) {
      F3 input = F3::FromRaw(
          gemmlowp::SaturatingRoundingMultiplyByPOT<1>(input_data[i]));
      F0 output = gemmlowp::tanh(input);
      output_data[i] = output.raw();
    }
  }
}

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_TANH_H_
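Editor's note: the float overload in this new header can be exercised directly. A minimal sketch follows (illustrative only; the initializer-list construction of RuntimeShape is an assumption based on its usual TFLite API, and the buffer contents are invented):

#include "tensorflow/lite/kernels/internal/reference/tanh.h"

int main() {
  const float input[4] = {-2.0f, -0.5f, 0.5f, 2.0f};
  float output[4];
  // Tanh is element-wise, so the same shape is used for input and output.
  tflite::RuntimeShape shape({1, 4});
  tflite::reference_ops::Tanh(shape, input, shape, output);
  // output[i] now holds std::tanh(input[i]).
  return 0;
}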
@@ -51,6 +51,7 @@ cc_library(
        "split.cc",
        "strided_slice.cc",
        "sub.cc",
        "tanh.cc",
        "unpack.cc",
    ] + select({
        "//conditions:default": [
@@ -153,6 +154,7 @@ cc_library(
        "strided_slice.cc",
        "sub.cc",
        "svdf.cc",
        "tanh.cc",
        "unpack.cc",
    ],
    hdrs = ["micro_ops.h"],
@@ -656,3 +658,14 @@ tflite_micro_cc_test(
        "//tensorflow/lite/micro/testing:micro_test",
    ],
)

tflite_micro_cc_test(
    name = "tanh_test",
    srcs = ["tanh_test.cc"],
    deps = [
        ":all_ops_resolver",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/micro:micro_framework",
        "//tensorflow/lite/micro/testing:micro_test",
    ],
)
@@ -106,9 +106,6 @@ TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalLogical(context, node, [](bool v) { return !v; });
}

TfLiteStatus TANHEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, std::tanh);
}

}  // namespace
}  // namespace elementwise
@@ -225,20 +222,6 @@ TfLiteRegistration* Register_LOGICAL_NOT() {
  return &r;
}

TfLiteRegistration* Register_TANH() {
  static TfLiteRegistration r = {
      /*init=*/nullptr,
      /*free=*/nullptr,
      /*prepare=*/
      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
      /*invoke=*/elementwise::TANHEval,
      /*profiling_string=*/nullptr,
      /*builtin_code=*/0,
      /*custom_name=*/nullptr,
      /*version=*/0};
  return &r;
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite
tensorflow/lite/micro/kernels/tanh.cc (new file, 128 lines)
@@ -0,0 +1,128 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/tanh.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace micro {
namespace activations {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

struct OpData {
  int32_t input_zero_point;
  int32_t input_range_radius;
  int32_t input_multiplier;
  int input_left_shift;
};

TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context, TfLiteNode* node,
                                       OpData* data) {
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TF_LITE_ENSURE_EQ(context, input->type, output->type);
  if (input->type == kTfLiteInt8) {
    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);

    // The number of input integer bits is set to be consistent with the
    // required value in reference_integer_ops::Tanh.
    static constexpr int kInputIntegerBits = 4;
    const double input_real_multiplier =
        static_cast<double>(input->params.scale) *
        static_cast<double>(1 << (31 - kInputIntegerBits));

    const double q = std::frexp(input_real_multiplier, &data->input_left_shift);
    data->input_multiplier = static_cast<int32_t>(TfLiteRound(q * (1ll << 31)));

    data->input_range_radius =
        CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 31);
  }
  return kTfLiteOk;
}
}  // namespace

TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  OpData data;
  CalculateArithmeticOpData(context, node, &data);

  if (input->type == kTfLiteFloat32) {
    switch (output->type) {
      case kTfLiteFloat32: {
        reference_ops::Tanh(GetTensorShape(input), GetTensorData<float>(input),
                            GetTensorShape(output),
                            GetTensorData<float>(output));
        return kTfLiteOk;
      }
      default:
        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
                           TfLiteTypeGetName(input->type),
                           TfLiteTypeGetName(output->type));
        return kTfLiteError;
    }
  } else if (input->type == kTfLiteInt8) {
    switch (output->type) {
      case kTfLiteInt8: {
        reference_integer_ops::Tanh(
            input->params.zero_point, data.input_range_radius,
            data.input_multiplier, data.input_left_shift,
            NumElements(input->dims), GetTensorData<int8_t>(input),
            GetTensorData<int8_t>(output));
        return kTfLiteOk;
      }
      default:
        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
                           TfLiteTypeGetName(input->type),
                           TfLiteTypeGetName(output->type));
        return kTfLiteError;
    }
  } else {
    TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
                       TfLiteTypeGetName(input->type),
                       TfLiteTypeGetName(output->type));
    return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace activations

TfLiteRegistration* Register_TANH() {
  static TfLiteRegistration r = {/*init=*/nullptr,
                                 /*free=*/nullptr,
                                 /*prepare=*/nullptr,
                                 /*invoke=*/activations::TanhEval,
                                 /*profiling_string=*/nullptr,
                                 /*builtin_code=*/0,
                                 /*custom_name=*/nullptr,
                                 /*version=*/0};
  return &r;
}
}  // namespace micro
}  // namespace ops
}  // namespace tflite
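Editor's note: to make the Prepare-time arithmetic in CalculateArithmeticOpData concrete, here is a standalone sketch of the multiplier/shift derivation for a hypothetical int8 input scale of 0.25 (the scale value is invented purely for illustration; std::round stands in for TfLiteRound):

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const double input_scale = 0.25;  // Hypothetical tensor scale, not from the diff.
  constexpr int kInputIntegerBits = 4;
  // Express the real-valued scale relative to a representation with
  // kInputIntegerBits integer bits in a 32-bit word.
  const double input_real_multiplier =
      input_scale * (1 << (31 - kInputIntegerBits));  // 0.25 * 2^27 = 2^25.
  int input_left_shift = 0;
  const double q = std::frexp(input_real_multiplier, &input_left_shift);
  const int32_t input_multiplier =
      static_cast<int32_t>(std::round(q * (1ll << 31)));
  // Expected: q = 0.5, input_left_shift = 26, input_multiplier = 1 << 30.
  std::printf("multiplier=%d shift=%d\n", static_cast<int>(input_multiplier),
              input_left_shift);
  return 0;
}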
tensorflow/lite/micro/kernels/tanh_test.cc (new file, 220 lines)
@@ -0,0 +1,220 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/all_ops_resolver.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
#include "tensorflow/lite/micro/testing/test_utils.h"

namespace tflite {
namespace testing {
namespace {

void TestTanhFloat(std::initializer_list<int> input_dims_data,
                   std::initializer_list<float> input_data,
                   std::initializer_list<float> expected_output_data,
                   std::initializer_list<int> output_dims_data,
                   float* output_data) {
  TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
  TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
  const int output_elements_count = ElementCount(*output_dims);

  constexpr int inputs_size = 1;
  constexpr int outputs_size = 1;
  constexpr int tensors_size = inputs_size + outputs_size;
  TfLiteTensor tensors[tensors_size] = {
      CreateFloatTensor(input_data, input_dims, "input_tensor"),
      CreateFloatTensor(output_data, output_dims, "output_tensor"),
  };

  TfLiteContext context;
  PopulateContext(tensors, tensors_size, micro_test::reporter, &context);

  ::tflite::ops::micro::AllOpsResolver resolver;
  const TfLiteRegistration* registration =
      resolver.FindOp(tflite::BuiltinOperator_TANH, 1);
  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);

  const char* init_data = nullptr;
  size_t init_data_size = 0;
  void* user_data = nullptr;
  if (registration->init) {
    user_data = registration->init(&context, init_data, init_data_size);
  }
  int inputs_array_data[] = {1, 0};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 1};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = nullptr;
  node.user_data = user_data;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;
  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
  }
  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
  if (registration->free) {
    registration->free(&context, user_data);
  }
  for (int i = 0; i < output_elements_count; ++i) {
    TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
                              1e-5f);
  }
}

void TestTanhInt8(std::initializer_list<int> input_dims_data,
                  std::initializer_list<int8_t> input_data, float input_min,
                  float input_max,
                  std::initializer_list<int8_t> expected_output_data,
                  std::initializer_list<int> output_dims_data, float output_min,
                  float output_max, int8_t* output_data) {
  TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
  TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
  const int output_elements_count = ElementCount(*output_dims);

  constexpr int inputs_size = 1;
  constexpr int outputs_size = 1;
  constexpr int tensors_size = inputs_size + outputs_size;
  TfLiteTensor tensors[tensors_size] = {
      CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min,
                            input_max),
      CreateQuantizedTensor(output_data, output_dims, "output_tensor",
                            output_min, output_max),
  };

  TfLiteContext context;
  PopulateContext(tensors, tensors_size, micro_test::reporter, &context);

  ::tflite::ops::micro::AllOpsResolver resolver;
  const TfLiteRegistration* registration =
      resolver.FindOp(tflite::BuiltinOperator_TANH, 1);
  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);

  const char* init_data = nullptr;
  size_t init_data_size = 1;
  void* user_data = nullptr;
  if (registration->init) {
    user_data = registration->init(&context, init_data, init_data_size);
  }
  int inputs_array_data[] = {1, 0};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 1};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = nullptr;
  node.user_data = user_data;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;
  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
  }
  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
  if (registration->free) {
    registration->free(&context, user_data);
  }
  for (int i = 0; i < output_elements_count; ++i) {
    TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
                              1);
  }
}

}  // namespace
}  // namespace testing
}  // namespace tflite

TF_LITE_MICRO_TESTS_BEGIN

TF_LITE_MICRO_TEST(SimpleTestFloat) {
  const int output_elements_count = 10;
  float output_data[output_elements_count];
  tflite::testing::TestTanhFloat({2, 1, 5},  // Input shape.
                                 {
                                     1.0,
                                     2.0,
                                     3.0,
                                     4.0,
                                     93.0,
                                     -1.0,
                                     -2.0,
                                     -3.0,
                                     -4.0,
                                     -93.0,
                                 },
                                 {
                                     // Expected results.
                                     0.76159416,
                                     0.96402758,
                                     0.99505475,
                                     0.9993293,
                                     1.0,
                                     -0.76159416,
                                     -0.96402758,
                                     -0.99505475,
                                     -0.9993293,
                                     -1.0,
                                 },
                                 {2, 1, 5},  // Output shape.
                                 output_data);
}

TF_LITE_MICRO_TEST(SimpleTestInt8) {
  using tflite::testing::F2QS;

  const float input_min = -31.75f;
  const float input_max = 32.0f;
  const float output_min = -1.0f;
  const float output_max = (127.0f / 128.0f);

  const int output_elements_count = 10;
  int8_t output_data[output_elements_count];
  tflite::testing::TestTanhInt8(
      {2, 1, output_elements_count},  // Input shape.
      {F2QS(1.0, input_min, input_max), F2QS(2.0, input_min, input_max),
       F2QS(3.0, input_min, input_max), F2QS(4.0, input_min, input_max),
       F2QS(5.0, input_min, input_max), F2QS(-1.0, input_min, input_max),
       F2QS(-2.0, input_min, input_max), F2QS(-3.0, input_min, input_max),
       F2QS(-4.0, input_min, input_max), F2QS(-5.0, input_min, input_max)},
      input_min, input_max,  // Input quantized range.
      {// Expected results.
       F2QS(0.76159416, output_min, output_max),
       F2QS(0.96402758, output_min, output_max),
       F2QS(0.99505475, output_min, output_max),
       F2QS(0.9993293, output_min, output_max),
       F2QS(0.9999092, output_min, output_max),
       F2QS(-0.76159416, output_min, output_max),
       F2QS(-0.96402758, output_min, output_max),
       F2QS(-0.99505475, output_min, output_max),
       F2QS(-0.9993293, output_min, output_max),
       F2QS(-0.9999092, output_min, output_max)},
      {2, 1, output_elements_count},  // Output shape.
      output_min, output_max,  // Output quantized range.
      output_data);
}

TF_LITE_MICRO_TESTS_END
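Editor's note: F2QS is a helper from the micro testing utilities and is not part of this diff. The sketch below is a hypothetical re-derivation of the mapping it is expected to perform (standard affine quantization of a float in [min, max] onto int8), shown only to make the expected test values readable; the real implementation lives in test_utils.h and may differ in detail.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical stand-in for tflite::testing::F2QS (assumption, not the real helper).
int8_t FloatToQuantizedSigned(float value, float min, float max) {
  const float scale = (max - min) / 255.0f;  // 255 quantization steps across the range.
  const int32_t zero_point =
      -128 + static_cast<int32_t>(std::lround(-min / scale));
  const int32_t q =
      zero_point + static_cast<int32_t>(std::lround(value / scale));
  return static_cast<int8_t>(std::min<int32_t>(127, std::max<int32_t>(-128, q)));
}
// With input_min = -31.75 and input_max = 32.0, scale works out to 0.25 and the
// zero point to -1, so an input of 1.0 would be expected to quantize to 3.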
@@ -164,6 +164,8 @@ tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h \
tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h \
tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h \
tensorflow/lite/kernels/internal/reference/integer_ops/mul.h \
tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h \
tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h \
tensorflow/lite/kernels/internal/reference/l2normalization.h \
tensorflow/lite/kernels/internal/reference/maximum_minimum.h \
tensorflow/lite/kernels/internal/reference/mul.h \
@@ -181,7 +183,7 @@ tensorflow/lite/kernels/internal/reference/softmax.h \
tensorflow/lite/kernels/internal/reference/sub.h \
tensorflow/lite/kernels/internal/reference/logistic.h \
tensorflow/lite/kernels/internal/reference/strided_slice.h \
tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h \
tensorflow/lite/kernels/internal/reference/tanh.h \
tensorflow/lite/kernels/internal/cppmath.h \
tensorflow/lite/kernels/internal/strided_slice_logic.h \
tensorflow/lite/kernels/internal/tensor.h \