diff --git a/tensorflow/lite/experimental/micro/BUILD b/tensorflow/lite/experimental/micro/BUILD index 0a1287d1122..f6d249e6275 100644 --- a/tensorflow/lite/experimental/micro/BUILD +++ b/tensorflow/lite/experimental/micro/BUILD @@ -39,6 +39,7 @@ cc_library( "-Wsign-compare", ], deps = [ + ":micro_utils", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/core/api", "//tensorflow/lite/experimental/micro/memory_planner:greedy_memory_planner", @@ -46,6 +47,19 @@ cc_library( ], ) +cc_library( + name = "micro_utils", + srcs = [ + "micro_utils.cc", + ], + hdrs = [ + "micro_utils.h", + ], + deps = [ + "//tensorflow/lite/c:c_api_internal", + ], +) + tflite_micro_cc_test( name = "micro_error_reporter_test", srcs = [ @@ -110,3 +124,25 @@ tflite_micro_cc_test( "//tensorflow/lite/experimental/micro/testing:micro_test", ], ) + +tflite_micro_cc_test( + name = "testing_helpers_test", + srcs = [ + "testing_helpers_test.cc", + ], + deps = [ + ":micro_framework", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "micro_utils_test", + srcs = [ + "micro_utils_test.cc", + ], + deps = [ + ":micro_utils", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) diff --git a/tensorflow/lite/experimental/micro/kernels/BUILD b/tensorflow/lite/experimental/micro/kernels/BUILD index 7116aee59d2..799bc995bf7 100644 --- a/tensorflow/lite/experimental/micro/kernels/BUILD +++ b/tensorflow/lite/experimental/micro/kernels/BUILD @@ -45,6 +45,7 @@ cc_library( ":activation_utils", ":micro_utils", "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_utils", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels:op_macros", "//tensorflow/lite/kernels:padding", @@ -104,6 +105,7 @@ cc_library( ":activation_utils", ":micro_utils", "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_utils", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels:op_macros", "//tensorflow/lite/kernels:padding", @@ -245,6 +247,7 @@ tflite_micro_cc_test( ":all_ops_resolver", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro:micro_utils", "//tensorflow/lite/experimental/micro/testing:micro_test", ], ) diff --git a/tensorflow/lite/experimental/micro/kernels/svdf.cc b/tensorflow/lite/experimental/micro/kernels/svdf.cc index 756c3c7ccd3..7c25b51ddb7 100644 --- a/tensorflow/lite/experimental/micro/kernels/svdf.cc +++ b/tensorflow/lite/experimental/micro/kernels/svdf.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/experimental/micro/kernels/activation_utils.h" +#include "tensorflow/lite/experimental/micro/micro_utils.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" @@ -41,55 +42,6 @@ namespace { * resizing. */ -// TODO(kreeger): Create a uint8-specific version of this when refactoring. -// TODO(kreeger): Remove these quantization methods when tensor_utils is ready -// for micro (b/140272187). -void SymmetricQuantizeFloats(const float* values, const int size, - int8_t* quantized_values, float* scaling_factor) { - // First, find min/max in values - float min_value = values[0]; - float max_value = values[0]; - for (int i = 1; i < size; ++i) { - if (values[i] < min_value) { - min_value = values[i]; - } - if (values[i] > max_value) { - max_value = values[i]; - } - } - - const float range = fmaxf(fabsf(min_value), fabsf(max_value)); - if (range == 0.0f) { - for (int i = 0; i < size; ++i) { - quantized_values[i] = 0; - } - *scaling_factor = 1; - return; - } - - const int kScale = 127; - *scaling_factor = range / kScale; - const float scaling_factor_inv = kScale / range; - for (int i = 0; i < size; ++i) { - const int32_t quantized_value = - static_cast(roundf(values[i] * scaling_factor_inv)); - // Clamp: just in case some odd numeric offset. - quantized_values[i] = fminf(kScale, fmaxf(-kScale, quantized_value)); - } -} - -// TODO(kreeger): Port this to a micro-utils file. -// TODO(kreeger): Than main difference between svdf.h in the reference kernel is -// the use of tensor_utils/portable_tensor_utils. Those utility methods are not -// currently ready for use in tflite-micro (see b/140272187). -void SymmetricDequantizeFloats(const int8_t* values, const int size, - const float dequantization_scale, - float* dequantized_values) { - for (int i = 0; i < size; ++i) { - dequantized_values[i] = values[i] * dequantization_scale; - } -} - // TODO(kreeger): upstream these reference methods into // `lite/kernels/reference/svdf.h` @@ -289,12 +241,12 @@ inline void EvalHybridSVDF( } if (!is_zero_vector) { + SignedSymmetricPerChannelQuantize(input_ptr_batch, input->dims, 0, + quantized_input_ptr_batch, + scaling_factors_ptr); + // Quantize input from float to int8. for (int b = 0; b < batch_size; ++b) { - const int offset = b * input_size; - SymmetricQuantizeFloats(input_ptr_batch + offset, input_size, - quantized_input_ptr_batch + offset, - &scaling_factors_ptr[b]); scaling_factors_ptr[b] *= weights_feature_scale; } @@ -475,10 +427,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } else { weights_time_ptr = GetTensorData(weights_time); } - SymmetricDequantizeFloats(weights_time_ptr, - NumElements(float_weights_time_scratch), - weights_time->params.scale, - GetTensorData(float_weights_time_scratch)); + SymmetricDequantize(weights_time_ptr, + NumElements(float_weights_time_scratch), + weights_time->params.scale, + GetTensorData(float_weights_time_scratch)); } else { // Validate Input Tensor dtypes: TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32); diff --git a/tensorflow/lite/experimental/micro/memory_helpers_test.cc b/tensorflow/lite/experimental/micro/memory_helpers_test.cc index b522bb901df..a5579424d73 100644 --- a/tensorflow/lite/experimental/micro/memory_helpers_test.cc +++ b/tensorflow/lite/experimental/micro/memory_helpers_test.cc @@ -146,7 +146,8 @@ TF_LITE_MICRO_TEST(TestTypeSizeOf) { } TF_LITE_MICRO_TEST(TestBytesRequiredForTensor) { - const tflite::Tensor* tensor100 = tflite::Create1dFlatbufferTensor(100); + const tflite::Tensor* tensor100 = + tflite::testing::Create1dFlatbufferTensor(100); size_t bytes; size_t type_size; TF_LITE_MICRO_EXPECT_EQ( @@ -155,7 +156,8 @@ TF_LITE_MICRO_TEST(TestBytesRequiredForTensor) { TF_LITE_MICRO_EXPECT_EQ(400, bytes); TF_LITE_MICRO_EXPECT_EQ(4, type_size); - const tflite::Tensor* tensor200 = tflite::Create1dFlatbufferTensor(200); + const tflite::Tensor* tensor200 = + tflite::testing::Create1dFlatbufferTensor(200); TF_LITE_MICRO_EXPECT_EQ( kTfLiteOk, tflite::BytesRequiredForTensor(*tensor200, &bytes, &type_size, micro_test::reporter)); diff --git a/tensorflow/lite/experimental/micro/micro_allocator_test.cc b/tensorflow/lite/experimental/micro/micro_allocator_test.cc index 9d15fe8a8da..dfff66bb4ae 100644 --- a/tensorflow/lite/experimental/micro/micro_allocator_test.cc +++ b/tensorflow/lite/experimental/micro/micro_allocator_test.cc @@ -21,16 +21,16 @@ limitations under the License. TF_LITE_MICRO_TESTS_BEGIN TF_LITE_MICRO_TEST(TestInitializeRuntimeTensor) { - const tflite::Model* model = tflite::GetMockModel(); + const tflite::Model* model = tflite::testing::GetMockModel(); TfLiteContext context; constexpr size_t arena_size = 1024; uint8_t arena[arena_size]; tflite::MicroAllocator allocator(&context, model, arena, arena_size, micro_test::reporter); - const tflite::Tensor* tensor = tflite::Create1dFlatbufferTensor(100); + const tflite::Tensor* tensor = tflite::testing::Create1dFlatbufferTensor(100); const flatbuffers::Vector>* buffers = - tflite::CreateFlatbufferBuffers(); + tflite::testing::CreateFlatbufferBuffers(); TfLiteTensor allocated_tensor; TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, allocator.InitializeRuntimeTensor( @@ -44,7 +44,7 @@ TF_LITE_MICRO_TEST(TestInitializeRuntimeTensor) { } TF_LITE_MICRO_TEST(TestMissingQuantization) { - const tflite::Model* model = tflite::GetMockModel(); + const tflite::Model* model = tflite::testing::GetMockModel(); TfLiteContext context; constexpr size_t arena_size = 1024; uint8_t arena[arena_size]; @@ -52,9 +52,9 @@ TF_LITE_MICRO_TEST(TestMissingQuantization) { micro_test::reporter); const tflite::Tensor* tensor = - tflite::CreateMissingQuantizationFlatbufferTensor(100); + tflite::testing::CreateMissingQuantizationFlatbufferTensor(100); const flatbuffers::Vector>* buffers = - tflite::CreateFlatbufferBuffers(); + tflite::testing::CreateFlatbufferBuffers(); TfLiteTensor allocated_tensor; TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, allocator.InitializeRuntimeTensor( @@ -68,7 +68,7 @@ TF_LITE_MICRO_TEST(TestMissingQuantization) { } TF_LITE_MICRO_TEST(TestAllocateTensors) { - const tflite::Model* model = tflite::GetMockModel(); + const tflite::Model* model = tflite::testing::GetMockModel(); TfLiteContext context; constexpr size_t arena_size = 1024; uint8_t arena[arena_size]; @@ -106,7 +106,7 @@ TF_LITE_MICRO_TEST(TestAllocateTensors) { } TF_LITE_MICRO_TEST(TestPreallocatedInput) { - const tflite::Model* model = tflite::GetMockModel(); + const tflite::Model* model = tflite::testing::GetMockModel(); TfLiteContext context; constexpr size_t arena_size = 1024; uint8_t arena[arena_size]; diff --git a/tensorflow/lite/experimental/micro/micro_interpreter_test.cc b/tensorflow/lite/experimental/micro/micro_interpreter_test.cc index e4cce8406fa..ebd12625c1b 100644 --- a/tensorflow/lite/experimental/micro/micro_interpreter_test.cc +++ b/tensorflow/lite/experimental/micro/micro_interpreter_test.cc @@ -71,7 +71,7 @@ class MockOpResolver : public OpResolver { TF_LITE_MICRO_TESTS_BEGIN TF_LITE_MICRO_TEST(TestInterpreter) { - const tflite::Model* model = tflite::GetMockModel(); + const tflite::Model* model = tflite::testing::GetMockModel(); TF_LITE_MICRO_EXPECT_NE(nullptr, model); tflite::MockOpResolver mock_resolver; constexpr size_t allocator_buffer_size = 1024; @@ -105,7 +105,7 @@ TF_LITE_MICRO_TEST(TestInterpreter) { } TF_LITE_MICRO_TEST(TestInterpreterProvideInputBuffer) { - const tflite::Model* model = tflite::GetMockModel(); + const tflite::Model* model = tflite::testing::GetMockModel(); TF_LITE_MICRO_EXPECT_NE(nullptr, model); tflite::MockOpResolver mock_resolver; int32_t input_buffer = 21; diff --git a/tensorflow/lite/experimental/micro/micro_utils.cc b/tensorflow/lite/experimental/micro/micro_utils.cc new file mode 100644 index 00000000000..cdeb7403569 --- /dev/null +++ b/tensorflow/lite/experimental/micro/micro_utils.cc @@ -0,0 +1,199 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/micro/micro_utils.h" + +#include +#include +#include + +#include "tensorflow/lite/c/c_api_internal.h" + +namespace tflite { + +namespace { + +static const uint8_t kAsymmetricUInt8Min = 0; +static const uint8_t kAsymmetricUInt8Max = 255; +static const uint8_t kSymmetricUInt8Min = 1; +static const uint8_t kSymmetricUInt8Max = 255; +static const int8_t kAsymmetricInt8Min = -128; +static const int8_t kAsymmetricInt8Max = 127; +static const int kSymmetricInt8Scale = kAsymmetricInt8Max; + +} // namespace + +int ElementCount(const TfLiteIntArray& dims) { + int result = 1; + for (int i = 0; i < dims.size; ++i) { + result *= dims.data[i]; + } + return result; +} + +// Converts a float value into an unsigned eight-bit quantized value. +uint8_t FloatToAsymmetricQuantizedUInt8(const float value, const float scale, + const int zero_point) { + int32_t result = round(value / scale) + zero_point; + if (result < kAsymmetricUInt8Min) { + result = kAsymmetricUInt8Min; + } + if (result > kAsymmetricUInt8Max) { + result = kAsymmetricUInt8Max; + } + return result; +} + +uint8_t FloatToSymmetricQuantizedUInt8(const float value, const float scale, + const int zero_point) { + int32_t result = round(value / scale) + zero_point; + if (result < kSymmetricUInt8Min) { + result = kSymmetricUInt8Min; + } + if (result > kSymmetricUInt8Max) { + result = kSymmetricUInt8Max; + } + return result; +} + +int8_t FloatToAsymmetricQuantizedInt8(const float value, const float scale, + const int zero_point) { + return FloatToAsymmetricQuantizedUInt8(value, scale, + zero_point - kAsymmetricInt8Min) + + kAsymmetricInt8Min; +} + +int8_t FloatToSymmetricQuantizedInt8(const float value, const float scale, + const int zero_point) { + return FloatToSymmetricQuantizedUInt8(value, scale, + zero_point - kAsymmetricInt8Min) + + kAsymmetricInt8Min; +} + +int32_t FloatToSymmetricQuantizedInt32(const float value, const float scale) { + float quantized = round(value / scale); + if (quantized > INT_MAX) { + quantized = INT_MAX; + } else if (quantized < INT_MIN) { + quantized = INT_MIN; + } + + return static_cast(quantized); +} + +void AsymmetricQuantize(const float* input, int8_t* output, int num_elements, + float scale, int zero_point) { + for (int i = 0; i < num_elements; i++) { + output[i] = FloatToAsymmetricQuantizedInt8(input[i], scale, zero_point); + } +} + +void AsymmetricQuantize(const float* input, uint8_t* output, int num_elements, + float scale, int zero_point) { + for (int i = 0; i < num_elements; i++) { + output[i] = FloatToAsymmetricQuantizedUInt8(input[i], scale, zero_point); + } +} + +void SymmetricQuantize(const float* input, int32_t* output, int num_elements, + float scale) { + for (int i = 0; i < num_elements; i++) { + output[i] = FloatToSymmetricQuantizedInt32(input[i], scale); + } +} + +void SymmetricPerChannelQuantize(const float* input, int32_t* output, + int num_elements, int num_channels, + float* scales) { + int elements_per_channel = num_elements / num_channels; + for (int i = 0; i < num_channels; i++) { + for (int j = 0; j < elements_per_channel; j++) { + output[i * elements_per_channel + j] = FloatToSymmetricQuantizedInt32( + input[i * elements_per_channel + j], scales[i]); + } + } +} + +void SignedSymmetricPerChannelQuantize(const float* values, + TfLiteIntArray* dims, + int quantized_dimension, + int8_t* quantized_values, + float* scaling_factors) { + int input_size = ElementCount(*dims); + int channel_count = dims->data[quantized_dimension]; + int per_channel_size = input_size / channel_count; + for (int channel = 0; channel < channel_count; channel++) { + float min = 0; + float max = 0; + int stride = 1; + for (int i = 0; i < quantized_dimension; i++) { + stride *= dims->data[i]; + } + int channel_stride = per_channel_size / stride; + // Calculate scales for each channel. + for (int i = 0; i < per_channel_size; i++) { + int idx = channel * channel_stride + i * stride; + min = fminf(min, values[idx]); + max = fmaxf(max, values[idx]); + } + scaling_factors[channel] = + fmaxf(fabs(min), fabs(max)) / kSymmetricInt8Scale; + for (int i = 0; i < per_channel_size; i++) { + int idx = channel * channel_stride + i * stride; + const int32_t quantized_value = + static_cast(roundf(values[idx] / scaling_factors[channel])); + // Clamp: just in case some odd numeric offset. + quantized_values[idx] = fminf( + kSymmetricInt8Scale, fmaxf(-kSymmetricInt8Scale, quantized_value)); + } + } +} + +void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims, + int8_t* quantized_values, float* scaling_factor) { + int input_size = ElementCount(*dims); + + float min = 0; + float max = 0; + for (int i = 0; i < input_size; i++) { + min = fminf(min, values[i]); + max = fmaxf(max, values[i]); + } + *scaling_factor = fmaxf(fabs(min), fabs(max)) / kSymmetricInt8Scale; + for (int i = 0; i < input_size; i++) { + const int32_t quantized_value = + static_cast(roundf(values[i] / *scaling_factor)); + // Clamp: just in case some odd numeric offset. + quantized_values[i] = fminf(kSymmetricInt8Scale, + fmaxf(-kSymmetricInt8Scale, quantized_value)); + } +} + +void SymmetricQuantize(const float* values, TfLiteIntArray* dims, + uint8_t* quantized_values, float* scaling_factor) { + SignedSymmetricQuantize(values, dims, + reinterpret_cast(quantized_values), + scaling_factor); +} + +void SymmetricDequantize(const int8_t* values, const int size, + const float dequantization_scale, + float* dequantized_values) { + for (int i = 0; i < size; ++i) { + dequantized_values[i] = values[i] * dequantization_scale; + } +} + +} // namespace tflite diff --git a/tensorflow/lite/experimental/micro/micro_utils.h b/tensorflow/lite/experimental/micro/micro_utils.h new file mode 100644 index 00000000000..8008ee95199 --- /dev/null +++ b/tensorflow/lite/experimental/micro/micro_utils.h @@ -0,0 +1,88 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_MICRO_UTILS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_MICRO_UTILS_H_ + +#include + +#include "tensorflow/lite/c/c_api_internal.h" + +namespace tflite { + +// Returns number of elements in the shape array. + +int ElementCount(const TfLiteIntArray& dims); + +uint8_t FloatToAsymmetricQuantizedUInt8(const float value, const float scale, + const int zero_point); + +uint8_t FloatToSymmetricQuantizedUInt8(const float value, const float scale, + const int zero_point); + +int8_t FloatToAsymmetricQuantizedInt8(const float value, const float scale, + const int zero_point); + +int8_t FloatToSymmetricQuantizedInt8(const float value, const float scale, + const int zero_point); + +// Converts a float value into a signed thirty-two-bit quantized value. Note +// that values close to max int and min int may see significant error due to +// a lack of floating point granularity for large values. +int32_t FloatToSymmetricQuantizedInt32(const float value, const float scale); + +// Helper methods to quantize arrays of floats to the desired format. +// +// There are several key flavors of quantization in TfLite: +// asymmetric symmetric per channel +// int8 | X | X | X | +// uint8 | X | X | | +// int32 | | X | X | +// +// The per-op quantizaiton spec can be found here: +// https://www.tensorflow.org/lite/performance/quantization_spec + +void AsymmetricQuantize(const float* input, int8_t* output, int num_elements, + float scale, int zero_point = 0); + +void AsymmetricQuantize(const float* input, uint8_t* output, int num_elements, + float scale, int zero_point = 128); + +void SymmetricQuantize(const float* input, int32_t* output, int num_elements, + float scale); + +void SymmetricPerChannelQuantize(const float* input, int32_t* output, + int num_elements, int num_channels, + float* scales); + +void SignedSymmetricPerChannelQuantize(const float* values, + TfLiteIntArray* dims, + int quantized_dimension, + int8_t* quantized_values, + float* scaling_factor); + +void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims, + int8_t* quantized_values, float* scaling_factor); + +void SymmetricQuantize(const float* values, TfLiteIntArray* dims, + uint8_t* quantized_values, float* scaling_factor); + +void SymmetricDequantize(const int8_t* values, const int size, + const float dequantization_scale, + float* dequantized_values); + +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_MICRO_UTILS_H_ diff --git a/tensorflow/lite/experimental/micro/micro_utils_test.cc b/tensorflow/lite/experimental/micro/micro_utils_test.cc new file mode 100644 index 00000000000..d2d9105b9f7 --- /dev/null +++ b/tensorflow/lite/experimental/micro/micro_utils_test.cc @@ -0,0 +1,97 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/micro/micro_utils.h" + +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(FloatToAsymmetricQuantizedUInt8Test) { + using tflite::FloatToAsymmetricQuantizedUInt8; + // [0, 127.5] -> zero_point=0, scale=0.5 + TF_LITE_MICRO_EXPECT_EQ(0, FloatToAsymmetricQuantizedUInt8(0, 0.5, 0)); + TF_LITE_MICRO_EXPECT_EQ(254, FloatToAsymmetricQuantizedUInt8(127, 0.5, 0)); + TF_LITE_MICRO_EXPECT_EQ(255, FloatToAsymmetricQuantizedUInt8(127.5, 0.5, 0)); + // [-10, 245] -> zero_point=10, scale=1.0 + TF_LITE_MICRO_EXPECT_EQ(0, FloatToAsymmetricQuantizedUInt8(-10, 1.0, 10)); + TF_LITE_MICRO_EXPECT_EQ(1, FloatToAsymmetricQuantizedUInt8(-9, 1.0, 10)); + TF_LITE_MICRO_EXPECT_EQ(128, FloatToAsymmetricQuantizedUInt8(118, 1.0, 10)); + TF_LITE_MICRO_EXPECT_EQ(253, FloatToAsymmetricQuantizedUInt8(243, 1.0, 10)); + TF_LITE_MICRO_EXPECT_EQ(254, FloatToAsymmetricQuantizedUInt8(244, 1.0, 10)); + TF_LITE_MICRO_EXPECT_EQ(255, FloatToAsymmetricQuantizedUInt8(245, 1.0, 10)); +} + +TF_LITE_MICRO_TEST(FloatToAsymmetricQuantizedInt8Test) { + using tflite::FloatToAsymmetricQuantizedInt8; + // [-64, 63.5] -> zero_point=0, scale=0.5 + TF_LITE_MICRO_EXPECT_EQ(2, FloatToAsymmetricQuantizedInt8(1, 0.5, 0)); + TF_LITE_MICRO_EXPECT_EQ(4, FloatToAsymmetricQuantizedInt8(2, 0.5, 0)); + TF_LITE_MICRO_EXPECT_EQ(6, FloatToAsymmetricQuantizedInt8(3, 0.5, 0)); + TF_LITE_MICRO_EXPECT_EQ(-10, FloatToAsymmetricQuantizedInt8(-5, 0.5, 0)); + TF_LITE_MICRO_EXPECT_EQ(-128, FloatToAsymmetricQuantizedInt8(-64, 0.5, 0)); + TF_LITE_MICRO_EXPECT_EQ(127, FloatToAsymmetricQuantizedInt8(63.5, 0.5, 0)); + // [-127, 128] -> zero_point=-1, scale=1.0 + TF_LITE_MICRO_EXPECT_EQ(0, FloatToAsymmetricQuantizedInt8(1, 1.0, -1)); + TF_LITE_MICRO_EXPECT_EQ(-1, FloatToAsymmetricQuantizedInt8(0, 1.0, -1)); + TF_LITE_MICRO_EXPECT_EQ(126, FloatToAsymmetricQuantizedInt8(127, 1.0, -1)); + TF_LITE_MICRO_EXPECT_EQ(127, FloatToAsymmetricQuantizedInt8(128, 1.0, -1)); + TF_LITE_MICRO_EXPECT_EQ(-127, FloatToAsymmetricQuantizedInt8(-126, 1.0, -1)); + TF_LITE_MICRO_EXPECT_EQ(-128, FloatToAsymmetricQuantizedInt8(-127, 1.0, -1)); +} + +TF_LITE_MICRO_TEST(FloatToAsymmetricQuantizedInt32Test) { + using tflite::FloatToSymmetricQuantizedInt32; + TF_LITE_MICRO_EXPECT_EQ(0, FloatToSymmetricQuantizedInt32(0, 0.5)); + TF_LITE_MICRO_EXPECT_EQ(2, FloatToSymmetricQuantizedInt32(1, 0.5)); + TF_LITE_MICRO_EXPECT_EQ(-2, FloatToSymmetricQuantizedInt32(-1, 0.5)); + TF_LITE_MICRO_EXPECT_EQ(-100, FloatToSymmetricQuantizedInt32(-50, 0.5)); + TF_LITE_MICRO_EXPECT_EQ(100, FloatToSymmetricQuantizedInt32(50, 0.5)); +} + +TF_LITE_MICRO_TEST(AsymmetricQuantizeInt8) { + float values[] = {-10.3, -3.1, -2.1, -1.9, -0.9, 0.1, 0.9, 1.85, 2.9, 4.1}; + int8_t goldens[] = {-20, -5, -3, -3, -1, 1, 3, 5, 7, 9}; + const int length = sizeof(values) / sizeof(float); + int8_t quantized[length]; + tflite::AsymmetricQuantize(values, quantized, length, 0.5, 1); + for (int i = 0; i < length; i++) { + TF_LITE_MICRO_EXPECT_EQ(quantized[i], goldens[i]); + } +} + +TF_LITE_MICRO_TEST(AsymmetricQuantizeUInt8) { + float values[] = {-10.3, -3.1, -2.1, -1.9, -0.9, 0.1, 0.9, 1.85, 2.9, 4.1}; + uint8_t goldens[] = {106, 121, 123, 123, 125, 127, 129, 131, 133, 135}; + const int length = sizeof(values) / sizeof(float); + uint8_t quantized[length]; + tflite::AsymmetricQuantize(values, quantized, length, 0.5, 127); + for (int i = 0; i < length; i++) { + TF_LITE_MICRO_EXPECT_EQ(quantized[i], goldens[i]); + } +} + +TF_LITE_MICRO_TEST(SymmetricQuantizeInt32) { + float values[] = {-10.3, -3.1, -2.1, -1.9, -0.9, 0.1, 0.9, 1.85, 2.9, 4.1}; + int32_t goldens[] = {-21, -6, -4, -4, -2, 0, 2, 4, 6, 8}; + const int length = sizeof(values) / sizeof(float); + int32_t quantized[length]; + tflite::SymmetricQuantize(values, quantized, length, 0.5); + for (int i = 0; i < length; i++) { + TF_LITE_MICRO_EXPECT_EQ(quantized[i], goldens[i]); + } +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/micro/test_helpers.cc b/tensorflow/lite/experimental/micro/test_helpers.cc index 28d732427c2..19b4da41717 100644 --- a/tensorflow/lite/experimental/micro/test_helpers.cc +++ b/tensorflow/lite/experimental/micro/test_helpers.cc @@ -15,7 +15,12 @@ limitations under the License. #include "tensorflow/lite/experimental/micro/test_helpers.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/core/api/tensor_utils.h" +#include "tensorflow/lite/experimental/micro/micro_utils.h" + namespace tflite { +namespace testing { namespace { class StackAllocator : public flatbuffers::Allocator { @@ -181,4 +186,216 @@ CreateFlatbufferBuffers() { return result; } +int TestStrcmp(const char* a, const char* b) { + if ((a == nullptr) || (b == nullptr)) { + return -1; + } + while ((*a != 0) && (*a == *b)) { + a++; + b++; + } + return *reinterpret_cast(a) - + *reinterpret_cast(b); +} + +// Wrapper to forward kernel errors to the interpreter's error reporter. +void ReportOpError(struct TfLiteContext* context, const char* format, ...) { + ErrorReporter* error_reporter = static_cast(context->impl_); + va_list args; + va_start(args, format); + error_reporter->Report(format, args); + va_end(args); +} + +// Create a TfLiteIntArray from an array of ints. The first element in the +// supplied array must be the size of the array expressed as an int. +TfLiteIntArray* IntArrayFromInts(const int* int_array) { + return const_cast( + reinterpret_cast(int_array)); +} + +// Create a TfLiteFloatArray from an array of floats. The first element in the +// supplied array must be the size of the array expressed as a float. +TfLiteFloatArray* FloatArrayFromFloats(const float* floats) { + static_assert(sizeof(float) == sizeof(int), + "assumes sizeof(float) == sizeof(int) to perform casting"); + int size = static_cast(floats[0]); + *reinterpret_cast(const_cast(floats)) = size; + return reinterpret_cast(const_cast(floats)); +} + +TfLiteTensor CreateTensor(TfLiteIntArray* dims, const char* name, + bool is_variable) { + TfLiteTensor result; + result.dims = dims; + result.name = name; + result.params = {}; + result.quantization = {kTfLiteNoQuantization, nullptr}; + result.is_variable = is_variable; + result.allocation_type = kTfLiteMemNone; + result.allocation = nullptr; + return result; +} + +TfLiteTensor CreateFloatTensor(const float* data, TfLiteIntArray* dims, + const char* name, bool is_variable) { + TfLiteTensor result = CreateTensor(dims, name, is_variable); + result.type = kTfLiteFloat32; + result.data.f = const_cast(data); + result.bytes = ElementCount(*dims) * sizeof(float); + return result; +} + +void PopulateFloatTensor(TfLiteTensor* tensor, float* begin, float* end) { + float* p = begin; + float* v = tensor->data.f; + while (p != end) { + *v++ = *p++; + } +} + +TfLiteTensor CreateBoolTensor(const bool* data, TfLiteIntArray* dims, + const char* name, bool is_variable) { + TfLiteTensor result = CreateTensor(dims, name, is_variable); + result.type = kTfLiteBool; + result.data.b = const_cast(data); + result.bytes = ElementCount(*dims) * sizeof(bool); + return result; +} + +TfLiteTensor CreateInt32Tensor(const int32_t* data, TfLiteIntArray* dims, + const char* name, bool is_variable) { + TfLiteTensor result = CreateTensor(dims, name, is_variable); + result.type = kTfLiteInt32; + result.data.i32 = const_cast(data); + result.bytes = ElementCount(*dims) * sizeof(int32_t); + return result; +} + +TfLiteTensor CreateQuantizedTensor(const uint8_t* data, TfLiteIntArray* dims, + float scale, int zero_point, + const char* name, bool is_variable) { + TfLiteTensor result = CreateTensor(dims, name, is_variable); + result.type = kTfLiteUInt8; + result.data.uint8 = const_cast(data); + result.params = {scale, zero_point}; + result.quantization = {kTfLiteAffineQuantization, nullptr}; + result.bytes = ElementCount(*dims) * sizeof(uint8_t); + return result; +} + +// Create Quantized tensor which contains a quantized version of the supplied +// buffer. +TfLiteTensor CreateQuantizedTensor(const float* input, uint8_t* quantized, + TfLiteIntArray* dims, float scale, + int zero_point, const char* name, + bool is_variable) { + int input_size = ElementCount(*dims); + tflite::AsymmetricQuantize(input, quantized, input_size, scale, zero_point); + return CreateQuantizedTensor(quantized, dims, scale, zero_point, name); +} + +TfLiteTensor CreateQuantizedTensor(const int8_t* data, TfLiteIntArray* dims, + float scale, int zero_point, + const char* name, bool is_variable) { + TfLiteTensor result = CreateTensor(dims, name, is_variable); + result.type = kTfLiteInt8; + result.data.int8 = const_cast(data); + result.params = {scale, zero_point}; + result.quantization = {kTfLiteAffineQuantization, nullptr}; + result.bytes = ElementCount(*dims) * sizeof(int8_t); + return result; +} + +TfLiteTensor CreateQuantizedTensor(const float* input, int8_t* quantized, + TfLiteIntArray* dims, float scale, + int zero_point, const char* name, + bool is_variable) { + int input_size = ElementCount(*dims); + tflite::AsymmetricQuantize(input, quantized, input_size, scale, zero_point); + return CreateQuantizedTensor(quantized, dims, scale, zero_point, name); +} + +TfLiteTensor CreateQuantized32Tensor(const int32_t* data, TfLiteIntArray* dims, + float scale, const char* name, + bool is_variable) { + TfLiteTensor result = CreateTensor(dims, name, is_variable); + result.type = kTfLiteInt32; + result.data.i32 = const_cast(data); + // Quantized int32 tensors always have a zero point of 0, since the range of + // int32 values is large, and because zero point costs extra cycles during + // processing. + result.params = {scale, 0}; + result.quantization = {kTfLiteAffineQuantization, nullptr}; + result.bytes = ElementCount(*dims) * sizeof(int32_t); + return result; +} + +TfLiteTensor CreateQuantizedBiasTensor(const float* data, int32_t* quantized, + TfLiteIntArray* dims, float input_scale, + float weights_scale, const char* name, + bool is_variable) { + float bias_scale = input_scale * weights_scale; + tflite::SymmetricQuantize(data, quantized, ElementCount(*dims), bias_scale); + return CreateQuantized32Tensor(quantized, dims, bias_scale, name, + is_variable); +} + +// Quantizes int32 bias tensor with per-channel weights determined by input +// scale multiplied by weight scale for each channel. +TfLiteTensor CreatePerChannelQuantizedBiasTensor( + const float* input, int32_t* quantized, TfLiteIntArray* dims, + float input_scale, float* weight_scales, float* scales, int* zero_points, + TfLiteAffineQuantization* affine_quant, int quantized_dimension, + const char* name, bool is_variable) { + int input_size = ElementCount(*dims); + int num_channels = dims->data[quantized_dimension]; + // First element is reserved for array length + zero_points[0] = num_channels; + scales[0] = static_cast(num_channels); + float* scales_array = &scales[1]; + for (int i = 0; i < num_channels; i++) { + scales_array[i] = input_scale * weight_scales[i]; + zero_points[i + 1] = 0; + } + + SymmetricPerChannelQuantize(input, quantized, input_size, num_channels, + scales_array); + + affine_quant->scale = FloatArrayFromFloats(scales); + affine_quant->zero_point = IntArrayFromInts(zero_points); + affine_quant->quantized_dimension = quantized_dimension; + + TfLiteTensor result = CreateTensor(dims, name, is_variable); + result.type = kTfLiteInt32; + result.data.i32 = const_cast(quantized); + result.quantization = {kTfLiteAffineQuantization, affine_quant}; + result.bytes = ElementCount(*dims) * sizeof(int32_t); + return result; +} + +TfLiteTensor CreateSymmetricPerChannelQuantizedTensor( + const float* input, int8_t* quantized, TfLiteIntArray* dims, float* scales, + int* zero_points, TfLiteAffineQuantization* affine_quant, + int quantized_dimension, const char* name, bool is_variable) { + int channel_count = dims->data[quantized_dimension]; + scales[0] = static_cast(channel_count); + zero_points[0] = channel_count; + + SignedSymmetricPerChannelQuantize(input, dims, quantized_dimension, quantized, + &scales[1]); + + affine_quant->scale = FloatArrayFromFloats(scales); + affine_quant->zero_point = IntArrayFromInts(zero_points); + affine_quant->quantized_dimension = quantized_dimension; + + TfLiteTensor result = CreateTensor(dims, name, is_variable); + result.type = kTfLiteInt8; + result.data.int8 = const_cast(quantized); + result.quantization = {kTfLiteAffineQuantization, affine_quant}; + result.bytes = ElementCount(*dims) * sizeof(int8_t); + return result; +} + +} // namespace testing } // namespace tflite diff --git a/tensorflow/lite/experimental/micro/test_helpers.h b/tensorflow/lite/experimental/micro/test_helpers.h index 32eb11f8452..173e462fecf 100644 --- a/tensorflow/lite/experimental/micro/test_helpers.h +++ b/tensorflow/lite/experimental/micro/test_helpers.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/lite/schema/schema_generated.h" namespace tflite { +namespace testing { // Returns an example flatbuffer TensorFlow Lite model. const Model* GetMockModel(); @@ -37,6 +38,71 @@ const Tensor* CreateMissingQuantizationFlatbufferTensor(int size); const flatbuffers::Vector>* CreateFlatbufferBuffers(); +// Performs a simple string comparison without requiring standard C library. +int TestStrcmp(const char* a, const char* b); + +// Wrapper to forward kernel errors to the interpreter's error reporter. +void ReportOpError(struct TfLiteContext* context, const char* format, ...); + +void PopulateContext(TfLiteTensor* tensors, int tensors_size, + TfLiteContext* context); + +// Create a TfLiteIntArray from an array of ints. The first element in the +// supplied array must be the size of the array expressed as an int. +TfLiteIntArray* IntArrayFromInts(const int* int_array); + +// Create a TfLiteFloatArray from an array of floats. The first element in the +// supplied array must be the size of the array expressed as a float. +TfLiteFloatArray* FloatArrayFromFloats(const float* floats); + +TfLiteTensor CreateFloatTensor(const float* data, TfLiteIntArray* dims, + const char* name, bool is_variable = false); + +void PopulateFloatTensor(TfLiteTensor* tensor, float* begin, float* end); + +TfLiteTensor CreateBoolTensor(const bool* data, TfLiteIntArray* dims, + const char* name, bool is_variable = false); + +TfLiteTensor CreateInt32Tensor(const int32_t*, TfLiteIntArray* dims, + const char* name, bool is_variable = false); + +TfLiteTensor CreateQuantizedTensor(const uint8_t* data, TfLiteIntArray* dims, + float scale, int zero_point, + const char* name, bool is_variable = false); + +TfLiteTensor CreateQuantizedTensor(const float* input, uint8_t* quantized, + TfLiteIntArray* dims, float scale, + int zero_point, const char* name, + bool is_variable = false); + +TfLiteTensor CreateQuantizedTensor(const int8_t* data, TfLiteIntArray* dims, + float scale, int zero_point, + const char* name, bool is_variable = false); + +TfLiteTensor CreateQuantizedTensor(const float* input, int8_t* quantized, + TfLiteIntArray* dims, float scale, + int zero_point, const char* name, + bool is_variable = false); + +TfLiteTensor CreateQuantizedBiasTensor(const float* data, int32_t* quantized, + TfLiteIntArray* dims, float input_scale, + float weights_scale, const char* name, + bool is_variable = false); + +// Quantizes int32 bias tensor with per-channel weights determined by input +// scale multiplied by weight scale for each channel. +TfLiteTensor CreatePerChannelQuantizedBiasTensor( + const float* input, int32_t* quantized, TfLiteIntArray* dims, + float input_scale, float* weight_scales, float* scales, int* zero_points, + TfLiteAffineQuantization* affine_quant, int quantized_dimension, + const char* name, bool is_variable = false); + +TfLiteTensor CreateSymmetricPerChannelQuantizedTensor( + const float* input, int8_t* quantized, TfLiteIntArray* dims, float* scales, + int* zero_points, TfLiteAffineQuantization* affine_quant, + int quantized_dimension, const char* name, bool is_variable = false); + +} // namespace testing } // namespace tflite #endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TEST_HELPERS_H_ diff --git a/tensorflow/lite/experimental/micro/testing/BUILD b/tensorflow/lite/experimental/micro/testing/BUILD index bf7b6fa30e9..dcdb9945532 100644 --- a/tensorflow/lite/experimental/micro/testing/BUILD +++ b/tensorflow/lite/experimental/micro/testing/BUILD @@ -20,15 +20,6 @@ cc_library( "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/core/api", "//tensorflow/lite/experimental/micro:micro_framework", - ], -) - -tflite_micro_cc_test( - name = "test_utils_test", - srcs = [ - "test_utils_test.cc", - ], - deps = [ - ":micro_test", + "//tensorflow/lite/experimental/micro:micro_utils", ], ) diff --git a/tensorflow/lite/experimental/micro/testing/micro_test.h b/tensorflow/lite/experimental/micro/testing/micro_test.h index 32e9a57f76e..17941f38d03 100644 --- a/tensorflow/lite/experimental/micro/testing/micro_test.h +++ b/tensorflow/lite/experimental/micro/testing/micro_test.h @@ -125,14 +125,15 @@ extern tflite::ErrorReporter* reporter; } \ } while (false) -#define TF_LITE_MICRO_EXPECT_NEAR(x, y, epsilon) \ - do { \ - auto delta = ((x) > (y)) ? ((x) - (y)) : ((y) - (x)); \ - if (delta > epsilon) { \ - micro_test::reporter->Report(#x " near " #y " failed at %s:%d", \ - __FILE__, __LINE__); \ - micro_test::did_test_fail = true; \ - } \ +#define TF_LITE_MICRO_EXPECT_NEAR(x, y, epsilon) \ + do { \ + auto delta = ((x) > (y)) ? ((x) - (y)) : ((y) - (x)); \ + if (delta > epsilon) { \ + micro_test::reporter->Report( \ + #x " (%f) near " #y " (%f) failed at %s:%d", static_cast(x), \ + static_cast(y), __FILE__, __LINE__); \ + micro_test::did_test_fail = true; \ + } \ } while (false) #define TF_LITE_MICRO_EXPECT_GT(x, y) \ diff --git a/tensorflow/lite/experimental/micro/testing/test_utils.h b/tensorflow/lite/experimental/micro/testing/test_utils.h index 74482102b9a..b883c9194db 100644 --- a/tensorflow/lite/experimental/micro/testing/test_utils.h +++ b/tensorflow/lite/experimental/micro/testing/test_utils.h @@ -16,82 +16,26 @@ limitations under the License. #define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TESTING_TEST_UTILS_H_ #include -#include #include #include #include -#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/core/api/tensor_utils.h" -#include "tensorflow/lite/experimental/micro/micro_error_reporter.h" +#include "tensorflow/lite/experimental/micro/micro_utils.h" +#include "tensorflow/lite/experimental/micro/test_helpers.h" #include "tensorflow/lite/experimental/micro/testing/micro_test.h" namespace tflite { namespace testing { -// TODO(kreeger): Move to common code! -inline void SignedSymmetricQuantizeFloats(const float* values, const int size, - float* min_value, float* max_value, - int8_t* quantized_values, - float* scaling_factor) { - // First, find min/max in values - *min_value = values[0]; - *max_value = values[0]; - for (int i = 1; i < size; ++i) { - if (values[i] < *min_value) { - *min_value = values[i]; - } - if (values[i] > *max_value) { - *max_value = values[i]; - } - } - const float range = fmaxf(fabsf(*min_value), fabsf(*max_value)); - if (range == 0.0f) { - for (int i = 0; i < size; ++i) { - quantized_values[i] = 0; - } - *scaling_factor = 1; - return; - } +// Note: These methods are deprecated, do not use. See b/141332970. - const int kScale = 127; - *scaling_factor = range / kScale; - const float scaling_factor_inv = kScale / range; - for (int i = 0; i < size; ++i) { - const int32_t quantized_value = - static_cast(roundf(values[i] * scaling_factor_inv)); - // Clamp: just in case some odd numeric offset. - quantized_values[i] = fminf(kScale, fmaxf(-kScale, quantized_value)); - } -} - -inline void SymmetricQuantizeFloats(const float* values, const int size, - float* min_value, float* max_value, - uint8_t* quantized_values, - float* scaling_factor) { - SignedSymmetricQuantizeFloats(values, size, min_value, max_value, - reinterpret_cast(quantized_values), - scaling_factor); -} - -// How many elements are in the array with this shape. -inline int ElementCount(const TfLiteIntArray& dims) { - int result = 1; - for (int i = 0; i < dims.size; ++i) { - result *= dims.data[i]; - } - return result; -} - -// Wrapper to forward kernel errors to the interpreter's error reporter. -inline void ReportOpError(struct TfLiteContext* context, const char* format, - ...) { - ErrorReporter* error_reporter = static_cast(context->impl_); - va_list args; - va_start(args, format); - error_reporter->Report(format, args); - va_end(args); +// TODO(kreeger): Don't use this anymore in our tests. Optimized compiler +// settings can play with pointer placement on the stack (b/140130236). +inline TfLiteIntArray* IntArrayFromInitializer( + std::initializer_list int_initializer) { + return IntArrayFromInts(int_initializer.begin()); } // Derives the quantization range max from scaling factor and zero point. @@ -151,6 +95,7 @@ inline int32_t F2Q32(const float value, const float scale) { return static_cast(quantized); } +// TODO(b/141330728): Move this method elsewhere as part clean up. inline void PopulateContext(TfLiteTensor* tensors, int tensors_size, TfLiteContext* context) { context->tensors_size = tensors_size; @@ -168,70 +113,16 @@ inline void PopulateContext(TfLiteTensor* tensors, int tensors_size, for (int i = 0; i < tensors_size; ++i) { if (context->tensors[i].is_variable) { - tflite::ResetVariableTensor(&context->tensors[i]); + ResetVariableTensor(&context->tensors[i]); } } } - -inline TfLiteIntArray* IntArrayFromInts(const int* int_array) { - return const_cast( - reinterpret_cast(int_array)); -} - -// TODO(kreeger): Don't use this anymore in our tests. Optimized compiler -// settings can play with pointer placement on the stack (b/140130236). -inline TfLiteIntArray* IntArrayFromInitializer( - std::initializer_list int_initializer) { - return IntArrayFromInts(int_initializer.begin()); -} - -inline TfLiteTensor CreateFloatTensor(const float* data, TfLiteIntArray* dims, - const char* name, - bool is_variable = false) { - TfLiteTensor result; - result.type = kTfLiteFloat32; - result.data.f = const_cast(data); - result.dims = dims; - result.params = {}; - result.allocation_type = kTfLiteMemNone; - result.bytes = ElementCount(*dims) * sizeof(float); - result.allocation = nullptr; - result.name = name; - result.is_variable = is_variable; - return result; -} - inline TfLiteTensor CreateFloatTensor(std::initializer_list data, TfLiteIntArray* dims, const char* name, bool is_variable = false) { return CreateFloatTensor(data.begin(), dims, name, is_variable); } -inline void PopulateFloatTensor(TfLiteTensor* tensor, float* begin, - float* end) { - float* p = begin; - float* v = tensor->data.f; - while (p != end) { - *v++ = *p++; - } -} - -inline TfLiteTensor CreateBoolTensor(const bool* data, TfLiteIntArray* dims, - const char* name, - bool is_variable = false) { - TfLiteTensor result; - result.type = kTfLiteBool; - result.data.b = const_cast(data); - result.dims = dims; - result.params = {}; - result.allocation_type = kTfLiteMemNone; - result.bytes = ElementCount(*dims) * sizeof(bool); - result.allocation = nullptr; - result.name = name; - result.is_variable = is_variable; - return result; -} - inline TfLiteTensor CreateBoolTensor(std::initializer_list data, TfLiteIntArray* dims, const char* name, bool is_variable = false) { @@ -293,9 +184,7 @@ inline TfLiteTensor CreateQuantizedTensor(float* data, uint8_t* quantized_data, const char* name, bool is_variable = false) { TfLiteTensor result; - float min, max; - SymmetricQuantizeFloats(data, ElementCount(*dims), &min, &max, quantized_data, - &result.params.scale); + SymmetricQuantize(data, dims, quantized_data, &result.params.scale); result.data.uint8 = quantized_data; result.type = kTfLiteUInt8; result.dims = dims; @@ -313,9 +202,7 @@ inline TfLiteTensor CreateQuantizedTensor(float* data, int8_t* quantized_data, const char* name, bool is_variable = false) { TfLiteTensor result; - float min, max; - SignedSymmetricQuantizeFloats(data, ElementCount(*dims), &min, &max, - quantized_data, &result.params.scale); + SignedSymmetricQuantize(data, dims, quantized_data, &result.params.scale); result.data.int8 = quantized_data; result.type = kTfLiteInt8; result.dims = dims; @@ -380,19 +267,6 @@ inline TfLiteTensor CreateTensor(std::initializer_list data, is_variable); } -// Do a simple string comparison for testing purposes, without requiring the -// standard C library. -inline int TestStrcmp(const char* a, const char* b) { - if ((a == nullptr) || (b == nullptr)) { - return -1; - } - while ((*a != 0) && (*a == *b)) { - a++; - b++; - } - return *(const unsigned char*)a - *(const unsigned char*)b; -} - } // namespace testing } // namespace tflite diff --git a/tensorflow/lite/experimental/micro/testing/test_utils_test.cc b/tensorflow/lite/experimental/micro/testing/test_utils_test.cc deleted file mode 100644 index a65c55452c9..00000000000 --- a/tensorflow/lite/experimental/micro/testing/test_utils_test.cc +++ /dev/null @@ -1,137 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/micro/testing/test_utils.h" - -#include "tensorflow/lite/experimental/micro/testing/micro_test.h" - -TF_LITE_MICRO_TESTS_BEGIN - -TF_LITE_MICRO_TEST(F2QTest) { - using tflite::testing::F2Q; - // [0, 127.5] -> zero_point=0, scale=0.5 - TF_LITE_MICRO_EXPECT_EQ(0, F2Q(0, 0, 127.5)); - TF_LITE_MICRO_EXPECT_EQ(254, F2Q(127, 0, 127.5)); - TF_LITE_MICRO_EXPECT_EQ(255, F2Q(127.5, 0, 127.5)); - // [-10, 245] -> zero_point=-10, scale=1.0 - TF_LITE_MICRO_EXPECT_EQ(0, F2Q(-10, -10, 245)); - TF_LITE_MICRO_EXPECT_EQ(1, F2Q(-9, -10, 245)); - TF_LITE_MICRO_EXPECT_EQ(128, F2Q(118, -10, 245)); - TF_LITE_MICRO_EXPECT_EQ(253, F2Q(243, -10, 245)); - TF_LITE_MICRO_EXPECT_EQ(254, F2Q(244, -10, 245)); - TF_LITE_MICRO_EXPECT_EQ(255, F2Q(245, -10, 245)); -} - -TF_LITE_MICRO_TEST(F2QSTest) { - using tflite::testing::F2QS; - // [-64, 63.5] -> zero_point=0, scale=0.5 - TF_LITE_MICRO_EXPECT_EQ(2, F2QS(1, -64, 63.5)); - TF_LITE_MICRO_EXPECT_EQ(4, F2QS(2, -64, 63.5)); - TF_LITE_MICRO_EXPECT_EQ(6, F2QS(3, -64, 63.5)); - TF_LITE_MICRO_EXPECT_EQ(-10, F2QS(-5, -64, 63.5)); - TF_LITE_MICRO_EXPECT_EQ(-128, F2QS(-64, -64, 63.5)); - TF_LITE_MICRO_EXPECT_EQ(127, F2QS(63.5, -64, 63.5)); - // [-127, 128] -> zero_point=1, scale=1.0 - TF_LITE_MICRO_EXPECT_EQ(0, F2QS(1, -127, 128)); - TF_LITE_MICRO_EXPECT_EQ(-1, F2QS(0, -127, 128)); - TF_LITE_MICRO_EXPECT_EQ(126, F2QS(127, -127, 128)); - TF_LITE_MICRO_EXPECT_EQ(127, F2QS(128, -127, 128)); - TF_LITE_MICRO_EXPECT_EQ(-127, F2QS(-126, -127, 128)); - TF_LITE_MICRO_EXPECT_EQ(-128, F2QS(-127, -127, 128)); -} - -TF_LITE_MICRO_TEST(F2Q32Test) { - using tflite::testing::F2Q32; - TF_LITE_MICRO_EXPECT_EQ(0, F2Q32(0, 0.5)); - TF_LITE_MICRO_EXPECT_EQ(2, F2Q32(1, 0.5)); - TF_LITE_MICRO_EXPECT_EQ(-2, F2Q32(-1, 0.5)); - TF_LITE_MICRO_EXPECT_EQ(-100, F2Q32(-50, 0.5)); - TF_LITE_MICRO_EXPECT_EQ(100, F2Q32(50, 0.5)); -} - -TF_LITE_MICRO_TEST(ZeroPointTest) { - TF_LITE_MICRO_EXPECT_EQ( - 10, tflite::testing::ZeroPointFromMinMax(-69, 58.5)); - TF_LITE_MICRO_EXPECT_EQ( - -10, tflite::testing::ZeroPointFromMinMax(-59, 68.5)); - TF_LITE_MICRO_EXPECT_EQ( - 0, tflite::testing::ZeroPointFromMinMax(0, 255)); - TF_LITE_MICRO_EXPECT_EQ( - 64, tflite::testing::ZeroPointFromMinMax(-32, 95.5)); -} - -TF_LITE_MICRO_TEST(ZeroPointRoundingTest) { - TF_LITE_MICRO_EXPECT_EQ( - -1, tflite::testing::ZeroPointFromMinMax(-126.51, 128.49)); - TF_LITE_MICRO_EXPECT_EQ( - -1, tflite::testing::ZeroPointFromMinMax(-127.49, 127.51)); - TF_LITE_MICRO_EXPECT_EQ( - 0, tflite::testing::ZeroPointFromMinMax(-127.51, 127.49)); - TF_LITE_MICRO_EXPECT_EQ( - 0, tflite::testing::ZeroPointFromMinMax(-128.49, 126.51)); - TF_LITE_MICRO_EXPECT_EQ( - 1, tflite::testing::ZeroPointFromMinMax(-128.51, 126.49)); - TF_LITE_MICRO_EXPECT_EQ( - 1, tflite::testing::ZeroPointFromMinMax(-129.49, 125.51)); -} - -TF_LITE_MICRO_TEST(ScaleTest) { - int min_int = std::numeric_limits::min(); - int max_int = std::numeric_limits::max(); - TF_LITE_MICRO_EXPECT_EQ( - 0.5, tflite::testing::ScaleFromMinMax(-0.5, max_int)); - TF_LITE_MICRO_EXPECT_EQ( - 1.0, tflite::testing::ScaleFromMinMax(min_int, max_int)); - TF_LITE_MICRO_EXPECT_EQ(0.25, tflite::testing::ScaleFromMinMax( - min_int / 4, max_int / 4)); - TF_LITE_MICRO_EXPECT_EQ(0.5, - tflite::testing::ScaleFromMinMax(-64, 63.5)); - TF_LITE_MICRO_EXPECT_EQ(0.25, - tflite::testing::ScaleFromMinMax(0, 63.75)); - TF_LITE_MICRO_EXPECT_EQ(0.5, - tflite::testing::ScaleFromMinMax(0, 127.5)); - TF_LITE_MICRO_EXPECT_EQ( - 0.25, tflite::testing::ScaleFromMinMax(63.75, 127.5)); -} - -TF_LITE_MICRO_TEST(MinMaxTest) { - TF_LITE_MICRO_EXPECT_EQ( - -128, tflite::testing::MinFromZeroPointScale(0, 1.0)); - TF_LITE_MICRO_EXPECT_EQ( - 127, tflite::testing::MaxFromZeroPointScale(0, 1.0)); - TF_LITE_MICRO_EXPECT_EQ( - -64, tflite::testing::MinFromZeroPointScale(0, 0.5)); - TF_LITE_MICRO_EXPECT_EQ( - 63.5, tflite::testing::MaxFromZeroPointScale(0, 0.5)); - TF_LITE_MICRO_EXPECT_EQ( - -65, tflite::testing::MinFromZeroPointScale(2, 0.5)); - TF_LITE_MICRO_EXPECT_EQ( - 62.5, tflite::testing::MaxFromZeroPointScale(2, 0.5)); -} - -TF_LITE_MICRO_TEST(ZeroPointScaleMinMaxSanityTest) { - float min = -150.0f; - float max = 105.0f; - float scale = tflite::testing::ScaleFromMinMax(min, max); - int zero_point = tflite::testing::ZeroPointFromMinMax(min, max); - float min_test = - tflite::testing::MinFromZeroPointScale(zero_point, scale); - float max_test = - tflite::testing::MaxFromZeroPointScale(zero_point, scale); - TF_LITE_MICRO_EXPECT_EQ(min, min_test); - TF_LITE_MICRO_EXPECT_EQ(max, max_test); -} - -TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/micro/testing_helpers_test.cc b/tensorflow/lite/experimental/micro/testing_helpers_test.cc new file mode 100644 index 00000000000..48f46c64496 --- /dev/null +++ b/tensorflow/lite/experimental/micro/testing_helpers_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" +#include "tensorflow/lite/experimental/micro/testing/test_utils.h" + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(CreateQuantizedBiasTensor) { + float input_scale = 0.5; + float weight_scale = 0.5; + const int tensor_size = 12; + int dims_arr[] = {4, 2, 3, 2, 1}; + const char* tensor_name = "test_tensor"; + int32_t quantized[tensor_size]; + float pre_quantized[] = {-10, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 10}; + int32_t expected_quantized_values[] = {-40, -20, -16, -12, -8, -4, + 0, 4, 8, 12, 16, 40}; + TfLiteIntArray* dims = tflite::testing::IntArrayFromInts(dims_arr); + + TfLiteTensor result = tflite::testing::CreateQuantizedBiasTensor( + pre_quantized, quantized, dims, input_scale, weight_scale, tensor_name); + + TF_LITE_MICRO_EXPECT_EQ(result.bytes, tensor_size * sizeof(int32_t)); + TF_LITE_MICRO_EXPECT_EQ(result.dims, dims); + TF_LITE_MICRO_EXPECT_EQ(result.name, tensor_name); + TF_LITE_MICRO_EXPECT_EQ(result.params.scale, input_scale * weight_scale); + for (int i = 0; i < tensor_size; i++) { + TF_LITE_MICRO_EXPECT_EQ(expected_quantized_values[i], result.data.i32[i]); + } +} + +TF_LITE_MICRO_TEST(CreatePerChannelQuantizedBiasTensor) { + float input_scale = 0.5; + float weight_scales[] = {0.5, 1, 2, 4}; + const int tensor_size = 12; + const int channels = 4; + int dims_arr[] = {4, 4, 3, 1, 1}; + const char* tensor_name = "test_tensor"; + int32_t quantized[tensor_size]; + float scales[channels + 1]; + int zero_points[] = {4, 0, 0, 0, 0}; + float pre_quantized[] = {-10, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 10}; + int32_t expected_quantized_values[] = {-40, -20, -16, -6, -4, -2, + 0, 1, 2, 2, 2, 5}; + TfLiteIntArray* dims = tflite::testing::IntArrayFromInts(dims_arr); + + TfLiteAffineQuantization quant; + TfLiteTensor result = tflite::testing::CreatePerChannelQuantizedBiasTensor( + pre_quantized, quantized, dims, input_scale, weight_scales, scales, + zero_points, &quant, 0, tensor_name); + + // Values in scales array start at index 1 since index 0 is dedicated to + // tracking the tensor size. + for (int i = 0; i < channels; i++) { + TF_LITE_MICRO_EXPECT_EQ(scales[i + 1], input_scale * weight_scales[i]); + } + + TF_LITE_MICRO_EXPECT_EQ(result.bytes, tensor_size * sizeof(int32_t)); + TF_LITE_MICRO_EXPECT_EQ(result.dims, dims); + TF_LITE_MICRO_EXPECT_EQ(result.name, tensor_name); + for (int i = 0; i < tensor_size; i++) { + TF_LITE_MICRO_EXPECT_EQ(expected_quantized_values[i], result.data.i32[i]); + } +} + +TF_LITE_MICRO_TEST(CreateSymmetricPerChannelQuantizedTensor) { + const int tensor_size = 12; + const int channels = 2; + const int dims_arr[] = {4, channels, 3, 2, 1}; + const char* tensor_name = "test_tensor"; + int8_t quantized[12]; + const float pre_quantized[] = {-127, -55, -4, -3, -2, -1, + 0, 1, 2, 3, 4, 63.5}; + const int8_t expected_quantized_values[] = {-127, -55, -4, -3, -2, -1, + 0, 2, 4, 6, 8, 127}; + float expected_scales[] = {1.0, 0.5}; + TfLiteIntArray* dims = tflite::testing::IntArrayFromInts(dims_arr); + + int zero_points[channels + 1]; + float scales[channels + 1]; + TfLiteAffineQuantization quant; + TfLiteTensor result = + tflite::testing::CreateSymmetricPerChannelQuantizedTensor( + pre_quantized, quantized, dims, scales, zero_points, &quant, 0, + "test_tensor"); + + TF_LITE_MICRO_EXPECT_EQ(result.bytes, tensor_size * sizeof(int8_t)); + TF_LITE_MICRO_EXPECT_EQ(result.dims, dims); + TF_LITE_MICRO_EXPECT_EQ(result.name, tensor_name); + TfLiteFloatArray* result_scales = + static_cast(result.quantization.params)->scale; + for (int i = 0; i < channels; i++) { + TF_LITE_MICRO_EXPECT_EQ(result_scales->data[i], expected_scales[i]); + } + for (int i = 0; i < tensor_size; i++) { + TF_LITE_MICRO_EXPECT_EQ(expected_quantized_values[i], result.data.int8[i]); + } +} + +TF_LITE_MICRO_TESTS_END