Refactor Symmetric quantize tensor to quantization utils.
PiperOrigin-RevId: 236367083
This commit is contained in:
parent
132d1d4c76
commit
d73090d9fd
@ -19,20 +19,31 @@ cc_library(
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/kernels/internal:round",
|
||||
"//tensorflow/lite/kernels/internal:tensor_utils",
|
||||
"//tensorflow/lite/kernels/internal:types",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "quantization_utils_test",
|
||||
srcs = ["quantization_utils_test.cc"],
|
||||
args = [
|
||||
"--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
|
||||
],
|
||||
data = [
|
||||
"//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
|
||||
],
|
||||
tags = [
|
||||
"tflite_not_portable_android",
|
||||
"tflite_not_portable_ios",
|
||||
],
|
||||
deps = [
|
||||
":quantization_utils",
|
||||
":test_util",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@com_google_googletest//:gtest",
|
||||
|
@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/tools/optimize/quantization_utils.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/internal/round.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
#include <cmath>
|
||||
@ -158,6 +160,43 @@ void SymmetricPerChannelQuantizeValues(const float* const input,
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
|
||||
if (model == nullptr || tensor == nullptr) {
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
BufferT* buffer = model->buffers[tensor->buffer].get();
|
||||
if (buffer == nullptr) {
|
||||
return kTfLiteError;
|
||||
}
|
||||
float* float_data = reinterpret_cast<float*>(buffer->data.data());
|
||||
uint64_t num_elements;
|
||||
TF_LITE_ENSURE_STATUS(utils::NumElements(*tensor, &num_elements));
|
||||
|
||||
std::vector<int8_t> quantized_buffer;
|
||||
quantized_buffer.resize(num_elements);
|
||||
|
||||
float min_value, max_value, scaling_factor;
|
||||
tensor_utils::SymmetricQuantizeFloats(float_data, num_elements,
|
||||
quantized_buffer.data(), &min_value,
|
||||
&max_value, &scaling_factor);
|
||||
|
||||
if (tensor->quantization == nullptr) {
|
||||
tensor->quantization = absl::make_unique<QuantizationParametersT>();
|
||||
}
|
||||
tensor->quantization->scale = std::vector<float>(1, scaling_factor);
|
||||
tensor->quantization->zero_point = std::vector<int64_t>(1, 0);
|
||||
|
||||
uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(quantized_buffer.data());
|
||||
model->buffers[tensor->buffer]->data.assign(uint8_buffer,
|
||||
uint8_buffer + num_elements);
|
||||
|
||||
// Update the tensor type.
|
||||
tensor->type = TensorType_INT8;
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace optimize
|
||||
} // namespace tflite
|
||||
|
@ -60,6 +60,10 @@ void SymmetricPerChannelQuantizeValues(const float* const input,
|
||||
int32_t channel_dim_index,
|
||||
std::vector<int8_t>* output_value);
|
||||
|
||||
// Quantizes tensor using symmetric quantization with the min and max elements
|
||||
// of the tensor.
|
||||
TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor);
|
||||
|
||||
} // namespace utils
|
||||
} // namespace optimize
|
||||
} // namespace tflite
|
||||
|
@ -15,12 +15,31 @@ limitations under the License.
|
||||
#include "tensorflow/lite/tools/optimize/quantization_utils.h"
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/tools/optimize/test_util.h"
|
||||
|
||||
namespace {
|
||||
tensorflow::string* g_test_model_dir = nullptr;
|
||||
} // namespace
|
||||
|
||||
namespace tflite {
|
||||
namespace optimize {
|
||||
namespace utils {
|
||||
namespace {
|
||||
|
||||
std::unique_ptr<FlatBufferModel> ReadModel(const char* model) {
|
||||
auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, model);
|
||||
return FlatBufferModel::BuildFromFile(model_path.c_str());
|
||||
}
|
||||
|
||||
std::unique_ptr<FlatBufferModel> ReadConvModel() {
|
||||
return ReadModel(internal::kConvModelWith0Plus10Weights);
|
||||
}
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
TEST(QuantizationUtilsTest, NumElements) {
|
||||
@ -201,12 +220,61 @@ TEST(QuantizationUtilsTest, SymmetricPerChannelQuantizeValues) {
|
||||
EXPECT_THAT(output_data, ElementsAreArray(expected_output_data));
|
||||
}
|
||||
|
||||
TEST(QuantizationUtilsTest, SymmetricQuantizeTensorNullInputs) {
|
||||
EXPECT_EQ(SymmetricQuantizeTensor(nullptr, nullptr), kTfLiteError);
|
||||
}
|
||||
|
||||
TEST(QuantizationUtilsTest, SymmetricQuantizeTensor) {
|
||||
// Conv model has weights between 0 and 10.
|
||||
// Quantize the weights tensor.
|
||||
ASSERT_TRUE(g_test_model_dir);
|
||||
ASSERT_FALSE(g_test_model_dir->empty());
|
||||
auto test_model = ReadConvModel();
|
||||
ASSERT_TRUE(test_model);
|
||||
auto readonly_model = test_model->GetModel();
|
||||
ASSERT_TRUE(readonly_model);
|
||||
ASSERT_TRUE(readonly_model->subgraphs());
|
||||
ASSERT_GE(readonly_model->subgraphs()->size(), 1);
|
||||
tflite::ModelT model;
|
||||
readonly_model->UnPackTo(&model);
|
||||
auto subgraph = model.subgraphs[0].get();
|
||||
auto conv_op = subgraph->operators.at(0).get();
|
||||
ASSERT_EQ(model.operator_codes.at(conv_op->opcode_index)->builtin_code,
|
||||
BuiltinOperator_CONV_2D);
|
||||
int32_t weights_tensor_idx = conv_op->inputs[1];
|
||||
TensorT* weights_tensor = subgraph->tensors.at(weights_tensor_idx).get();
|
||||
|
||||
EXPECT_EQ(weights_tensor->type, TensorType_FLOAT32);
|
||||
size_t float_buffer_size =
|
||||
model.buffers.at(weights_tensor->buffer)->data.size();
|
||||
|
||||
EXPECT_EQ(SymmetricQuantizeTensor(&model, weights_tensor), kTfLiteOk);
|
||||
|
||||
size_t quant_buffer_size =
|
||||
model.buffers.at(weights_tensor->buffer)->data.size();
|
||||
EXPECT_EQ(weights_tensor->type, TensorType_INT8);
|
||||
EXPECT_EQ(quant_buffer_size * 4, float_buffer_size);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace utils
|
||||
} // namespace optimize
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
tensorflow::string model_file;
|
||||
const std::vector<tensorflow::Flag> flag_list = {
|
||||
tensorflow::Flag("test_model_file", &model_file,
|
||||
"Path to test tflite model file."),
|
||||
};
|
||||
|
||||
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
||||
if (!parse_result) {
|
||||
std::cerr << "Required test_model_file\n";
|
||||
std::abort();
|
||||
}
|
||||
g_test_model_dir =
|
||||
new tensorflow::string(tensorflow::io::Dirname(model_file));
|
||||
::tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
@ -190,41 +190,6 @@ TfLiteStatus InsertQuantizableInputTensorsFromOperator(
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Quantizes tensor using symmetric quantization with the min and max elements
|
||||
// of the tensor. This is need for operations with hybrid evaluation
|
||||
// implemented.
|
||||
TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
|
||||
BufferT* buffer = model->buffers[tensor->buffer].get();
|
||||
float* float_data = reinterpret_cast<float*>(buffer->data.data());
|
||||
uint64_t num_elements;
|
||||
TF_LITE_ENSURE_STATUS(utils::NumElements(*tensor, &num_elements));
|
||||
LOG(INFO) << "Quantizing tensor " << tensor->name << " with " << num_elements
|
||||
<< " elements.";
|
||||
|
||||
std::vector<int8_t> quantized_buffer;
|
||||
quantized_buffer.resize(num_elements);
|
||||
|
||||
float min_value, max_value, scaling_factor;
|
||||
tensor_utils::SymmetricQuantizeFloats(float_data, num_elements,
|
||||
quantized_buffer.data(), &min_value,
|
||||
&max_value, &scaling_factor);
|
||||
|
||||
if (tensor->quantization == nullptr) {
|
||||
tensor->quantization = absl::make_unique<QuantizationParametersT>();
|
||||
}
|
||||
tensor->quantization->scale = std::vector<float>(1, scaling_factor);
|
||||
tensor->quantization->zero_point = std::vector<int64_t>(1, 0);
|
||||
|
||||
uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(quantized_buffer.data());
|
||||
model->buffers[tensor->buffer]->data.assign(uint8_buffer,
|
||||
uint8_buffer + num_elements);
|
||||
|
||||
// Update the tensor type.
|
||||
tensor->type = TensorType_INT8;
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Returns the index of the Dequantize op_code.
|
||||
// If a Dequantize op_code doesn't exist, adds it and returns its index.
|
||||
int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) {
|
||||
@ -314,7 +279,7 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
|
||||
for (std::pair<int32_t, TensorT*> tensor_pair : tensor_map) {
|
||||
// Quantize the tensor.
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
SymmetricQuantizeTensor(model.get(), tensor_pair.second));
|
||||
utils::SymmetricQuantizeTensor(model.get(), tensor_pair.second));
|
||||
}
|
||||
|
||||
// Examine the tensor consumers to determine which require dequantize ops.
|
||||
|
Loading…
Reference in New Issue
Block a user