diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD index 5629188938c..3011de4b399 100644 --- a/tensorflow/lite/tools/optimize/BUILD +++ b/tensorflow/lite/tools/optimize/BUILD @@ -19,20 +19,31 @@ cc_library( "//tensorflow/lite:framework", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/kernels/internal:round", + "//tensorflow/lite/kernels/internal:tensor_utils", "//tensorflow/lite/kernels/internal:types", "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/memory", ], ) tf_cc_test( name = "quantization_utils_test", srcs = ["quantization_utils_test.cc"], + args = [ + "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + ], + data = [ + "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", + ], tags = [ "tflite_not_portable_android", "tflite_not_portable_ios", ], deps = [ ":quantization_utils", + ":test_util", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", diff --git a/tensorflow/lite/tools/optimize/quantization_utils.cc b/tensorflow/lite/tools/optimize/quantization_utils.cc index 445fffb8dd4..a5b9b00b8a9 100644 --- a/tensorflow/lite/tools/optimize/quantization_utils.cc +++ b/tensorflow/lite/tools/optimize/quantization_utils.cc @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
 ==============================================================================*/
 #include "tensorflow/lite/tools/optimize/quantization_utils.h"
+#include "absl/memory/memory.h"
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/kernels/internal/round.h"
+#include "tensorflow/lite/kernels/internal/tensor_utils.h"
 #include "tensorflow/lite/kernels/internal/types.h"
 #include <cmath>
@@ -158,6 +160,43 @@ void SymmetricPerChannelQuantizeValues(const float* const input,
   }
 }
 
+TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
+  if (model == nullptr || tensor == nullptr) {
+    return kTfLiteError;
+  }
+
+  BufferT* buffer = model->buffers[tensor->buffer].get();
+  if (buffer == nullptr) {
+    return kTfLiteError;
+  }
+  float* float_data = reinterpret_cast<float*>(buffer->data.data());
+  uint64_t num_elements;
+  TF_LITE_ENSURE_STATUS(utils::NumElements(*tensor, &num_elements));
+
+  std::vector<int8_t> quantized_buffer;
+  quantized_buffer.resize(num_elements);
+
+  float min_value, max_value, scaling_factor;
+  tensor_utils::SymmetricQuantizeFloats(float_data, num_elements,
+                                        quantized_buffer.data(), &min_value,
+                                        &max_value, &scaling_factor);
+
+  if (tensor->quantization == nullptr) {
+    tensor->quantization = absl::make_unique<QuantizationParametersT>();
+  }
+  tensor->quantization->scale = std::vector<float>(1, scaling_factor);
+  tensor->quantization->zero_point = std::vector<int64_t>(1, 0);
+
+  uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(quantized_buffer.data());
+  model->buffers[tensor->buffer]->data.assign(uint8_buffer,
+                                              uint8_buffer + num_elements);
+
+  // Update the tensor type.
+ tensor->type = TensorType_INT8; + + return kTfLiteOk; +} + } // namespace utils } // namespace optimize } // namespace tflite diff --git a/tensorflow/lite/tools/optimize/quantization_utils.h b/tensorflow/lite/tools/optimize/quantization_utils.h index d20b3176bf3..010bcb931fb 100644 --- a/tensorflow/lite/tools/optimize/quantization_utils.h +++ b/tensorflow/lite/tools/optimize/quantization_utils.h @@ -60,6 +60,10 @@ void SymmetricPerChannelQuantizeValues(const float* const input, int32_t channel_dim_index, std::vector* output_value); +// Quantizes tensor using symmetric quantization with the min and max elements +// of the tensor. +TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor); + } // namespace utils } // namespace optimize } // namespace tflite diff --git a/tensorflow/lite/tools/optimize/quantization_utils_test.cc b/tensorflow/lite/tools/optimize/quantization_utils_test.cc index ecad09ed612..1562309a9c8 100644 --- a/tensorflow/lite/tools/optimize/quantization_utils_test.cc +++ b/tensorflow/lite/tools/optimize/quantization_utils_test.cc @@ -15,12 +15,31 @@ limitations under the License. 
 #include "tensorflow/lite/tools/optimize/quantization_utils.h"
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/util/command_line_flags.h"
+#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/tools/optimize/test_util.h"
+
+namespace {
+tensorflow::string* g_test_model_dir = nullptr;
+}  // namespace
 
 namespace tflite {
 namespace optimize {
 namespace utils {
 namespace {
 
+std::unique_ptr<FlatBufferModel> ReadModel(const char* model) {
+  auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, model);
+  return FlatBufferModel::BuildFromFile(model_path.c_str());
+}
+
+std::unique_ptr<FlatBufferModel> ReadConvModel() {
+  return ReadModel(internal::kConvModelWith0Plus10Weights);
+}
+
 using ::testing::ElementsAreArray;
 
 TEST(QuantizationUtilsTest, NumElements) {
@@ -201,12 +220,61 @@ TEST(QuantizationUtilsTest, SymmetricPerChannelQuantizeValues) {
   EXPECT_THAT(output_data, ElementsAreArray(expected_output_data));
 }
 
+TEST(QuantizationUtilsTest, SymmetricQuantizeTensorNullInputs) {
+  EXPECT_EQ(SymmetricQuantizeTensor(nullptr, nullptr), kTfLiteError);
+}
+
+TEST(QuantizationUtilsTest, SymmetricQuantizeTensor) {
+  // Conv model has weights between 0 and 10.
+  // Quantize the weights tensor.
+  ASSERT_TRUE(g_test_model_dir);
+  ASSERT_FALSE(g_test_model_dir->empty());
+  auto test_model = ReadConvModel();
+  ASSERT_TRUE(test_model);
+  auto readonly_model = test_model->GetModel();
+  ASSERT_TRUE(readonly_model);
+  ASSERT_TRUE(readonly_model->subgraphs());
+  ASSERT_GE(readonly_model->subgraphs()->size(), 1);
+  tflite::ModelT model;
+  readonly_model->UnPackTo(&model);
+  auto subgraph = model.subgraphs[0].get();
+  auto conv_op = subgraph->operators.at(0).get();
+  ASSERT_EQ(model.operator_codes.at(conv_op->opcode_index)->builtin_code,
+            BuiltinOperator_CONV_2D);
+  int32_t weights_tensor_idx = conv_op->inputs[1];
+  TensorT* weights_tensor = subgraph->tensors.at(weights_tensor_idx).get();
+
+  EXPECT_EQ(weights_tensor->type, TensorType_FLOAT32);
+  size_t float_buffer_size =
+      model.buffers.at(weights_tensor->buffer)->data.size();
+
+  EXPECT_EQ(SymmetricQuantizeTensor(&model, weights_tensor), kTfLiteOk);
+
+  size_t quant_buffer_size =
+      model.buffers.at(weights_tensor->buffer)->data.size();
+  EXPECT_EQ(weights_tensor->type, TensorType_INT8);
+  EXPECT_EQ(quant_buffer_size * 4, float_buffer_size);
+}
+
 }  // namespace
 }  // namespace utils
 }  // namespace optimize
 }  // namespace tflite
 
 int main(int argc, char** argv) {
-  ::testing::InitGoogleTest(&argc, argv);
+  tensorflow::string model_file;
+  const std::vector<tensorflow::Flag> flag_list = {
+      tensorflow::Flag("test_model_file", &model_file,
+                       "Path to test tflite model file."),
+  };
+
+  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+  if (!parse_result) {
+    std::cerr << "Required test_model_file\n";
+    std::abort();
+  }
+  g_test_model_dir =
+      new tensorflow::string(tensorflow::io::Dirname(model_file));
+  ::tensorflow::port::InitMain(argv[0], &argc, &argv);
   return RUN_ALL_TESTS();
 }
diff --git a/tensorflow/lite/tools/optimize/quantize_weights.cc b/tensorflow/lite/tools/optimize/quantize_weights.cc
index f0a280f1c1f..da96c586375 100644
--- a/tensorflow/lite/tools/optimize/quantize_weights.cc
+++
b/tensorflow/lite/tools/optimize/quantize_weights.cc
@@ -190,41 +190,6 @@ TfLiteStatus InsertQuantizableInputTensorsFromOperator(
   return kTfLiteOk;
 }
 
-// Quantizes tensor using symmetric quantization with the min and max elements
-// of the tensor. This is need for operations with hybrid evaluation
-// implemented.
-TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
-  BufferT* buffer = model->buffers[tensor->buffer].get();
-  float* float_data = reinterpret_cast<float*>(buffer->data.data());
-  uint64_t num_elements;
-  TF_LITE_ENSURE_STATUS(utils::NumElements(*tensor, &num_elements));
-  LOG(INFO) << "Quantizing tensor " << tensor->name << " with " << num_elements
-            << " elements.";
-
-  std::vector<int8_t> quantized_buffer;
-  quantized_buffer.resize(num_elements);
-
-  float min_value, max_value, scaling_factor;
-  tensor_utils::SymmetricQuantizeFloats(float_data, num_elements,
-                                        quantized_buffer.data(), &min_value,
-                                        &max_value, &scaling_factor);
-
-  if (tensor->quantization == nullptr) {
-    tensor->quantization = absl::make_unique<QuantizationParametersT>();
-  }
-  tensor->quantization->scale = std::vector<float>(1, scaling_factor);
-  tensor->quantization->zero_point = std::vector<int64_t>(1, 0);
-
-  uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(quantized_buffer.data());
-  model->buffers[tensor->buffer]->data.assign(uint8_buffer,
-                                              uint8_buffer + num_elements);
-
-  // Update the tensor type.
-  tensor->type = TensorType_INT8;
-
-  return kTfLiteOk;
-}
-
 // Returns the index of the Dequantize op_code.
 // If a Dequantize op_code doesn't exist, adds it and returns its index.
 int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) {
@@ -314,7 +279,7 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
   for (std::pair<int32_t, TensorT*> tensor_pair : tensor_map) {
     // Quantize the tensor.
TF_LITE_ENSURE_STATUS( - SymmetricQuantizeTensor(model.get(), tensor_pair.second)); + utils::SymmetricQuantizeTensor(model.get(), tensor_pair.second)); } // Examine the tensor consumers to determine which require dequantize ops.