diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index d6a96efdbf7..51b58f92de1 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -481,6 +481,7 @@ cc_library( "reference/strided_slice.h", "reference/sub.h", "reference/svdf.h", + "reference/tanh.h", ], build_for_embedded = True, copts = tflite_copts(), @@ -551,6 +552,7 @@ cc_library( "reference/softmax.h", "reference/strided_slice.h", "reference/sub.h", + "reference/tanh.h", ], copts = tflite_copts(), deps = [ diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index 1a6c6d0d80e..e991a21e3bd 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -59,6 +59,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/softmax.h" #include "tensorflow/lite/kernels/internal/reference/strided_slice.h" #include "tensorflow/lite/kernels/internal/reference/sub.h" +#include "tensorflow/lite/kernels/internal/reference/tanh.h" #include "tensorflow/lite/kernels/internal/strided_slice_logic.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/types.h" @@ -1343,59 +1344,6 @@ inline void LogSoftmax(const SoftmaxParams& params, } } -inline void Tanh(const RuntimeShape& input_shape, const float* input_data, - const RuntimeShape& output_shape, float* output_data) { - const int flat_size = MatchingFlatSize(input_shape, output_shape); - - for (int i = 0; i < flat_size; i++) { - float val = input_data[i]; - float result = std::tanh(val); - output_data[i] = result; - } -} - -// Convenience version that allows, for example, generated-code calls to be -// uniform between data types. 
-inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
-                 const float* input_data, const RuntimeShape& output_shape,
-                 float* output_data) {
-  // Drop params: not needed.
-  Tanh(input_shape, input_data, output_shape, output_data);
-}
-
-inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
-                 const int16* input_data, const RuntimeShape& output_shape,
-                 int16* output_data) {
-  const int input_left_shift = params.input_left_shift;
-  // Support for shifts is limited until we have a parameterized version of
-  // SaturatingRoundingMultiplyByPOT().
-  TFLITE_DCHECK_GE(input_left_shift, 0);
-  TFLITE_DCHECK_LE(input_left_shift, 1);
-
-  const int flat_size = MatchingFlatSize(input_shape, output_shape);
-
-  // F0 uses 0 integer bits, range [-1, 1].
-  // This is the return type of math functions such as tanh, logistic,
-  // whose range is in [-1, 1].
-  using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
-  // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
-  using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
-
-  if (input_left_shift == 0) {
-    for (int i = 0; i < flat_size; i++) {
-      F3 input = F3::FromRaw(input_data[i]);
-      F0 output = gemmlowp::tanh(input);
-      output_data[i] = output.raw();
-    }
-  } else {
-    for (int i = 0; i < flat_size; i++) {
-      F3 input = F3::FromRaw(
-          gemmlowp::SaturatingRoundingMultiplyByPOT<1>(input_data[i]));
-      F0 output = gemmlowp::tanh(input);
-      output_data[i] = output.raw();
-    }
-  }
-}
 
 inline void Dequantize(const RuntimeShape& input_shape,
                        const Eigen::half* input_data,
diff --git a/tensorflow/lite/kernels/internal/reference/tanh.h b/tensorflow/lite/kernels/internal/reference/tanh.h
new file mode 100644
index 00000000000..0f31d4ddeef
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/reference/tanh.h
@@ -0,0 +1,92 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_TANH_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_TANH_H_
+
+#include <cmath>
+
+#include "fixedpoint/fixedpoint.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/cppmath.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace reference_ops {
+
+// Float kernel: output_data[i] = std::tanh(input_data[i]) for each of the
+// MatchingFlatSize(input_shape, output_shape) elements.
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+                 const RuntimeShape& output_shape, float* output_data) {
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+
+  for (int i = 0; i < flat_size; i++) {
+    float val = input_data[i];
+    float result = std::tanh(val);
+    output_data[i] = result;
+  }
+}
+
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
+                 const float* input_data, const RuntimeShape& output_shape,
+                 float* output_data) {
+  // Drop params: not needed.
+  Tanh(input_shape, input_data, output_shape, output_data);
+}
+
+// int16 fixed-point kernel. Each raw input is interpreted as a gemmlowp
+// fixed-point value with 3 integer bits (after being passed through
+// SaturatingRoundingMultiplyByPOT<1> when params.input_left_shift is 1) and
+// the raw bits of the 0-integer-bit gemmlowp::tanh result are stored.
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+                 const int16* input_data, const RuntimeShape& output_shape,
+                 int16* output_data) {
+  const int input_left_shift = params.input_left_shift;
+  // Support for shifts is limited until we have a parameterized version of
+  // SaturatingRoundingMultiplyByPOT().
+  TFLITE_DCHECK_GE(input_left_shift, 0);
+  TFLITE_DCHECK_LE(input_left_shift, 1);
+
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+
+  // F0 uses 0 integer bits, range [-1, 1].
+  // This is the return type of math functions such as tanh, logistic,
+  // whose range is in [-1, 1].
+  using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
+  // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
+  using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
+
+  if (input_left_shift == 0) {
+    for (int i = 0; i < flat_size; i++) {
+      F3 input = F3::FromRaw(input_data[i]);
+      F0 output = gemmlowp::tanh(input);
+      output_data[i] = output.raw();
+    }
+  } else {
+    for (int i = 0; i < flat_size; i++) {
+      F3 input = F3::FromRaw(
+          gemmlowp::SaturatingRoundingMultiplyByPOT<1>(input_data[i]));
+      F0 output = gemmlowp::tanh(input);
+      output_data[i] = output.raw();
+    }
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_TANH_H_
diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD
index b6c6054d604..bbb5c67d9e5 100644
--- a/tensorflow/lite/micro/kernels/BUILD
+++ b/tensorflow/lite/micro/kernels/BUILD
@@ -51,6 +51,7 @@ cc_library(
         "split.cc",
         "strided_slice.cc",
         "sub.cc",
+        "tanh.cc",
         "unpack.cc",
     ] + select({
         "//conditions:default": [
@@ -153,6 +154,7 @@ cc_library(
         "strided_slice.cc",
         "sub.cc",
         "svdf.cc",
+        "tanh.cc",
         "unpack.cc",
     ],
     hdrs = ["micro_ops.h"],
@@ -656,3 +658,14 @@ tflite_micro_cc_test(
         "//tensorflow/lite/micro/testing:micro_test",
     ],
 )
+
+tflite_micro_cc_test(
+    name = "tanh_test",
+    srcs = ["tanh_test.cc"],
+    deps = [
+        ":all_ops_resolver",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/micro:micro_framework",
+        "//tensorflow/lite/micro/testing:micro_test",
+    ],
+)
diff --git a/tensorflow/lite/micro/kernels/elementwise.cc b/tensorflow/lite/micro/kernels/elementwise.cc
index 93fc4ec0d88..b69d260a826 100644
--- a/tensorflow/lite/micro/kernels/elementwise.cc
+++ 
b/tensorflow/lite/micro/kernels/elementwise.cc
@@ -106,9 +106,6 @@ TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
   return EvalLogical(context, node, [](bool v) { return !v; });
 }
 
-TfLiteStatus TANHEval(TfLiteContext* context, TfLiteNode* node) {
-  return EvalNumeric(context, node, std::tanh);
-}
 }  // namespace
 }  // namespace elementwise
 
@@ -225,20 +222,6 @@ TfLiteRegistration* Register_LOGICAL_NOT() {
   return &r;
 }
 
-TfLiteRegistration* Register_TANH() {
-  static TfLiteRegistration r = {
-      /*init=*/nullptr,
-      /*free=*/nullptr,
-      /*prepare=*/
-      elementwise::GenericPrepare,
-      /*invoke=*/elementwise::TANHEval,
-      /*profiling_string=*/nullptr,
-      /*builtin_code=*/0,
-      /*custom_name=*/nullptr,
-      /*version=*/0};
-  return &r;
-}
-
 }  // namespace micro
 }  // namespace ops
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/tanh.cc b/tensorflow/lite/micro/kernels/tanh.cc
new file mode 100644
index 00000000000..9ee5b74bde4
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/tanh.cc
@@ -0,0 +1,141 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
+
+#include <cmath>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/tanh.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace activations {
+namespace {
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+// Quantization parameters consumed by the int8 kernel. They are only filled
+// in by CalculateArithmeticOpData() when the input tensor is int8.
+struct OpData {
+  int32_t input_zero_point;
+  int32_t input_range_radius;
+  int32_t input_multiplier;
+  int input_left_shift;
+};
+
+// Verifies that input and output types match and, for int8 tensors, that the
+// output zero point is 0, then derives the input multiplier, left shift and
+// range radius from the input scale.
+TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context, TfLiteNode* node,
+                                       OpData* data) {
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  TF_LITE_ENSURE_EQ(context, input->type, output->type);
+  if (input->type == kTfLiteInt8) {
+    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+    data->input_zero_point = input->params.zero_point;
+
+    // The number of input integer bits is set to be consistent with the
+    // required value in reference_integer_ops::Tanh
+    static constexpr int kInputIntegerBits = 4;
+    const double input_real_multiplier =
+        static_cast<double>(input->params.scale) *
+        static_cast<double>(1 << (31 - kInputIntegerBits));
+
+    const double q = std::frexp(input_real_multiplier, &data->input_left_shift);
+    data->input_multiplier = static_cast<int32_t>(TfLiteRound(q * (1ll << 31)));
+
+    data->input_range_radius =
+        CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 31);
+  }
+  return kTfLiteOk;
+}
+}  // namespace
+
+// Evaluates TANH for float32 and int8 tensors using the reference kernels;
+// any other input/output type combination is logged and returns kTfLiteError.
+TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  OpData data = {};
+  // Propagate validation failures rather than evaluating with unset OpData.
+  TF_LITE_ENSURE_OK(context, CalculateArithmeticOpData(context, node, &data));
+
+  if (input->type == kTfLiteFloat32) {
+    switch (output->type) {
+      case kTfLiteFloat32: {
+        reference_ops::Tanh(GetTensorShape(input), GetTensorData<float>(input),
+                            GetTensorShape(output),
+                            GetTensorData<float>(output));
+        return kTfLiteOk;
+      }
+      default:
+        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
+                           TfLiteTypeGetName(input->type),
+                           TfLiteTypeGetName(output->type));
+        return kTfLiteError;
+    }
+  } else if (input->type == kTfLiteInt8) {
+    switch (output->type) {
+      case kTfLiteInt8: {
+        reference_integer_ops::Tanh(
+            data.input_zero_point, data.input_range_radius,
+            data.input_multiplier, data.input_left_shift,
+            NumElements(input->dims), GetTensorData<int8_t>(input),
+            GetTensorData<int8_t>(output));
+        return kTfLiteOk;
+      }
+      default:
+        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
+                           TfLiteTypeGetName(input->type),
+                           TfLiteTypeGetName(output->type));
+        return kTfLiteError;
+    }
+  } else {
+    TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
+                       TfLiteTypeGetName(input->type),
+                       TfLiteTypeGetName(output->type));
+    return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace activations
+
+// Returns the TANH registration. Only invoke is set, so the quantization
+// parameters are recomputed on every TanhEval() call.
+TfLiteRegistration* Register_TANH() {
+  static TfLiteRegistration r = {/*init=*/nullptr,
+                                 /*free=*/nullptr,
+                                 /*prepare=*/nullptr,
+                                 /*invoke=*/activations::TanhEval,
+                                 /*profiling_string=*/nullptr,
+                                 /*builtin_code=*/0,
+                                 /*custom_name=*/nullptr,
+                                 /*version=*/0};
+  return &r;
+}
+}  // namespace micro
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/tanh_test.cc b/tensorflow/lite/micro/kernels/tanh_test.cc
new file mode 100644
index 00000000000..2a367107771
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/tanh_test.cc
@@ -0,0 +1,220 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/micro/kernels/all_ops_resolver.h"
+#include "tensorflow/lite/micro/testing/micro_test.h"
+#include "tensorflow/lite/micro/testing/test_utils.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+// Runs the registered TANH kernel on a float input tensor and compares each
+// output element with expected_output_data to within 1e-5. The *_dims_data
+// lists are in TfLiteIntArray layout: the first entry is the number of
+// dimensions and the remaining entries are the dimension sizes.
+void TestTanhFloat(std::initializer_list<int> input_dims_data,
+                   std::initializer_list<float> input_data,
+                   std::initializer_list<float> expected_output_data,
+                   std::initializer_list<int> output_dims_data,
+                   float* output_data) {
+  TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+  const int output_elements_count = ElementCount(*output_dims);
+
+  constexpr int inputs_size = 1;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateFloatTensor(input_data, input_dims, "input_tensor"),
+      CreateFloatTensor(output_data, output_dims, "output_tensor"),
+  };
+
+  TfLiteContext context;
+  PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
+
+  ::tflite::ops::micro::AllOpsResolver resolver;
+  const TfLiteRegistration* registration =
+      resolver.FindOp(tflite::BuiltinOperator_TANH, 1);
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+  if (registration == nullptr) return;  // Nothing to invoke.
+
+  const char* init_data = nullptr;
+  size_t init_data_size = 0;
+  void* user_data = nullptr;
+  if (registration->init) {
+    user_data = registration->init(&context, init_data, init_data_size);
+  }
+  // Same layout as the dims lists: {count, tensor index}.
+  int inputs_array_data[] = {1, 0};
+  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+  int outputs_array_data[] = {1, 1};
+  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+
+  TfLiteNode node;
+  node.inputs = inputs_array;
+  node.outputs = outputs_array;
+  node.temporaries = nullptr;
+  node.user_data = user_data;
+  node.builtin_data = nullptr;
+  node.custom_initial_data = nullptr;
+  node.custom_initial_data_size = 0;
+  node.delegate = nullptr;
+  if (registration->prepare) {
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+  }
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+  if (registration->free) {
+    registration->free(&context, user_data);
+  }
+  // Only the elements described by output_dims_data are compared.
+  for (int i = 0; i < output_elements_count; ++i) {
+    TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
+                              1e-5f);
+  }
+}
+
+// Runs the registered TANH kernel on int8 tensors quantized over the given
+// [min, max] ranges and compares each output element with
+// expected_output_data to within one quantized unit. The *_dims_data lists
+// use the same {rank, dim sizes...} layout as in TestTanhFloat.
+void TestTanhInt8(std::initializer_list<int> input_dims_data,
+                  std::initializer_list<int8_t> input_data, float input_min,
+                  float input_max,
+                  std::initializer_list<int8_t> expected_output_data,
+                  std::initializer_list<int> output_dims_data, float output_min,
+                  float output_max, int8_t* output_data) {
+  TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+  const int output_elements_count = ElementCount(*output_dims);
+
+  constexpr int inputs_size = 1;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min,
+                            input_max),
+      CreateQuantizedTensor(output_data, output_dims, "output_tensor",
+                            output_min, output_max),
+  };
+
+  TfLiteContext context;
+  PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
+
+  ::tflite::ops::micro::AllOpsResolver resolver;
+  const TfLiteRegistration* registration =
+      resolver.FindOp(tflite::BuiltinOperator_TANH, 1);
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+  if (registration == nullptr) return;  // Nothing to invoke.
+
+  const char* init_data = nullptr;
+  size_t init_data_size = 0;
+  void* user_data = nullptr;
+  if (registration->init) {
+    user_data = registration->init(&context, init_data, init_data_size);
+  }
+  // {count, tensor index}: one input at index 0, one output at index 1.
+  int inputs_array_data[] = {1, 0};
+  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+  int outputs_array_data[] = {1, 1};
+  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+
+  TfLiteNode node;
+  node.inputs = inputs_array;
+  node.outputs = outputs_array;
+  node.temporaries = nullptr;
+  node.user_data = user_data;
+  node.builtin_data = nullptr;
+  node.custom_initial_data = nullptr;
+  node.custom_initial_data_size = 0;
+  node.delegate = nullptr;
+  if (registration->prepare) {
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+  }
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+  if (registration->free) {
+    registration->free(&context, user_data);
+  }
+  // Allow a difference of one quantized unit per element.
+  for (int i = 0; i < output_elements_count; ++i) {
+    TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
+                              1);
+  }
+}
+
+}  // namespace
+}  // namespace testing
+}  // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+// Checks the float kernel on ten values, including +/-93 where the expected
+// result is +/-1. The shapes cover all ten elements (rank 2, 1 x 10).
+TF_LITE_MICRO_TEST(SimpleTestFloat) {
+  const int output_elements_count = 10;
+  float output_data[output_elements_count];
+  tflite::testing::TestTanhFloat(
+      {2, 1, output_elements_count},  // Input shape.
+      {1.0, 2.0, 3.0, 4.0, 93.0,      // Inputs.
+       -1.0, -2.0, -3.0, -4.0, -93.0},
+      {// Expected results.
+       0.76159416, 0.96402758, 0.99505475, 0.9993293, 1.0,
+       -0.76159416, -0.96402758, -0.99505475, -0.9993293, -1.0},
+      {2, 1, output_elements_count},  // Output shape.
+      output_data);
+}
+
+// Quantizes ten inputs over [-31.75, 32] and expects the int8 kernel to
+// reproduce tanh, quantized over [-1, 127/128], to within one quantized
+// unit.
+TF_LITE_MICRO_TEST(SimpleTestInt8) {
+  using tflite::testing::F2QS;
+
+  const float input_min = -31.75f;
+  const float input_max = 32.0f;
+  const float output_min = -1.0f;
+  const float output_max = (127.0f / 128.0f);
+
+  const int output_elements_count = 10;
+  int8_t output_data[output_elements_count];
+  tflite::testing::TestTanhInt8(
+      {2, 1, output_elements_count},  // Input shape.
+      {F2QS(1.0, input_min, input_max), F2QS(2.0, input_min, input_max),
+       F2QS(3.0, input_min, input_max), F2QS(4.0, input_min, input_max),
+       F2QS(5.0, input_min, input_max), F2QS(-1.0, input_min, input_max),
+       F2QS(-2.0, input_min, input_max), F2QS(-3.0, input_min, input_max),
+       F2QS(-4.0, input_min, input_max), F2QS(-5.0, input_min, input_max)},
+      input_min, input_max,  // Input quantized range.
+      {                      // Expected results.
+       F2QS(0.76159416, output_min, output_max),
+       F2QS(0.96402758, output_min, output_max),
+       F2QS(0.99505475, output_min, output_max),
+       F2QS(0.9993293, output_min, output_max),
+       F2QS(0.9999092, output_min, output_max),
+       F2QS(-0.76159416, output_min, output_max),
+       F2QS(-0.96402758, output_min, output_max),
+       F2QS(-0.99505475, output_min, output_max),
+       F2QS(-0.9993293, output_min, output_max),
+       F2QS(-0.9999092, output_min, output_max)},
+      {2, 1, output_elements_count},  // Output shape.
+      output_min, output_max,         // Output quantized range.
+      output_data);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile
index 13761cca28b..a94d643a3d0 100644
--- a/tensorflow/lite/micro/tools/make/Makefile
+++ b/tensorflow/lite/micro/tools/make/Makefile
@@ -164,6 +164,8 @@ tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h \
 tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h \
 tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h \
 tensorflow/lite/kernels/internal/reference/integer_ops/mul.h \
+tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h \
+tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h \
 tensorflow/lite/kernels/internal/reference/l2normalization.h \
 tensorflow/lite/kernels/internal/reference/maximum_minimum.h \
 tensorflow/lite/kernels/internal/reference/mul.h \
@@ -181,7 +183,7 @@ tensorflow/lite/kernels/internal/reference/softmax.h \
 tensorflow/lite/kernels/internal/reference/sub.h \
 tensorflow/lite/kernels/internal/reference/logistic.h \
 tensorflow/lite/kernels/internal/reference/strided_slice.h \
-tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h \
+tensorflow/lite/kernels/internal/reference/tanh.h \
 tensorflow/lite/kernels/internal/cppmath.h \
 tensorflow/lite/kernels/internal/strided_slice_logic.h \
 tensorflow/lite/kernels/internal/tensor.h \