Refactor SymmetricQuantizeTensor into quantization utils.

PiperOrigin-RevId: 236367083
This commit is contained in:
Shashi Shekhar 2019-03-01 13:25:27 -08:00 committed by TensorFlower Gardener
parent 132d1d4c76
commit d73090d9fd
5 changed files with 124 additions and 37 deletions

View File

@ -19,20 +19,31 @@ cc_library(
"//tensorflow/lite:framework",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels/internal:round",
"//tensorflow/lite/kernels/internal:tensor_utils",
"//tensorflow/lite/kernels/internal:types",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/memory",
],
)
tf_cc_test(
name = "quantization_utils_test",
srcs = ["quantization_utils_test.cc"],
args = [
"--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
],
data = [
"//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
],
tags = [
"tflite_not_portable_android",
"tflite_not_portable_ios",
],
deps = [
":quantization_utils",
":test_util",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_googletest//:gtest",

View File

@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/quantization_utils.h"
#include "absl/memory/memory.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/round.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include <cmath>
@ -158,6 +160,43 @@ void SymmetricPerChannelQuantizeValues(const float* const input,
}
}
// Symmetrically quantizes the float contents of `tensor` in place: the
// backing buffer inside `model` is rewritten as int8 data, a single
// scale/zero-point pair is recorded on the tensor's quantization params,
// and the tensor type is switched to INT8.
//
// Returns kTfLiteError on null arguments or a missing backing buffer,
// kTfLiteOk otherwise.
TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
  if (model == nullptr || tensor == nullptr) {
    return kTfLiteError;
  }
  BufferT* tensor_buffer = model->buffers[tensor->buffer].get();
  if (tensor_buffer == nullptr) {
    return kTfLiteError;
  }

  uint64_t element_count;
  TF_LITE_ENSURE_STATUS(utils::NumElements(*tensor, &element_count));

  // The raw byte buffer holds float32 values; view it as such.
  float* as_floats = reinterpret_cast<float*>(tensor_buffer->data.data());

  std::vector<int8_t> int8_values(element_count);
  float range_min;
  float range_max;
  float scale;
  tensor_utils::SymmetricQuantizeFloats(as_floats, element_count,
                                        int8_values.data(), &range_min,
                                        &range_max, &scale);

  if (tensor->quantization == nullptr) {
    tensor->quantization = absl::make_unique<QuantizationParametersT>();
  }
  // Symmetric quantization: one scale, zero point pinned at 0.
  tensor->quantization->scale.assign(1, scale);
  tensor->quantization->zero_point.assign(1, 0);

  // Store the int8 payload back into the model's byte buffer.
  uint8_t* as_bytes = reinterpret_cast<uint8_t*>(int8_values.data());
  model->buffers[tensor->buffer]->data.assign(as_bytes,
                                              as_bytes + element_count);

  tensor->type = TensorType_INT8;
  return kTfLiteOk;
}
} // namespace utils
} // namespace optimize
} // namespace tflite

View File

@ -60,6 +60,10 @@ void SymmetricPerChannelQuantizeValues(const float* const input,
int32_t channel_dim_index,
std::vector<int8_t>* output_value);
// Quantizes tensor using symmetric quantization with the min and max elements
// of the tensor.
TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor);
} // namespace utils
} // namespace optimize
} // namespace tflite

View File

@ -15,12 +15,31 @@ limitations under the License.
#include "tensorflow/lite/tools/optimize/quantization_utils.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/optimize/test_util.h"
namespace {
tensorflow::string* g_test_model_dir = nullptr;
} // namespace
namespace tflite {
namespace optimize {
namespace utils {
namespace {
// Loads the named tflite model file from the test-data directory supplied
// on the command line (g_test_model_dir). Returns null on failure.
std::unique_ptr<FlatBufferModel> ReadModel(const char* model) {
  const auto full_path = tensorflow::io::JoinPath(*g_test_model_dir, model);
  return FlatBufferModel::BuildFromFile(full_path.c_str());
}
// Convenience wrapper: loads the single-conv test model whose weights span
// [0, 10] (see kConvModelWith0Plus10Weights in test_util).
std::unique_ptr<FlatBufferModel> ReadConvModel() {
  return ReadModel(internal::kConvModelWith0Plus10Weights);
}
using ::testing::ElementsAreArray;
TEST(QuantizationUtilsTest, NumElements) {
@ -201,12 +220,61 @@ TEST(QuantizationUtilsTest, SymmetricPerChannelQuantizeValues) {
EXPECT_THAT(output_data, ElementsAreArray(expected_output_data));
}
// Null model/tensor arguments must be rejected with an error, not
// dereferenced.
TEST(QuantizationUtilsTest, SymmetricQuantizeTensorNullInputs) {
  EXPECT_EQ(SymmetricQuantizeTensor(nullptr, nullptr), kTfLiteError);
}
TEST(QuantizationUtilsTest, SymmetricQuantizeTensor) {
  // Conv model has weights between 0 and 10.
  // Quantize the weights tensor.
  ASSERT_TRUE(g_test_model_dir);
  ASSERT_FALSE(g_test_model_dir->empty());
  auto test_model = ReadConvModel();
  ASSERT_TRUE(test_model);
  auto readonly_model = test_model->GetModel();
  ASSERT_TRUE(readonly_model);
  ASSERT_TRUE(readonly_model->subgraphs());
  ASSERT_GE(readonly_model->subgraphs()->size(), 1);
  // Unpack the flatbuffer into a mutable object-API model so the tensor
  // can be quantized in place.
  tflite::ModelT model;
  readonly_model->UnPackTo(&model);
  auto subgraph = model.subgraphs[0].get();
  auto conv_op = subgraph->operators.at(0).get();
  ASSERT_EQ(model.operator_codes.at(conv_op->opcode_index)->builtin_code,
            BuiltinOperator_CONV_2D);
  // Input 1 of a CONV_2D op is the weights tensor (input 0 is activations).
  int32_t weights_tensor_idx = conv_op->inputs[1];
  TensorT* weights_tensor = subgraph->tensors.at(weights_tensor_idx).get();
  EXPECT_EQ(weights_tensor->type, TensorType_FLOAT32);
  size_t float_buffer_size =
      model.buffers.at(weights_tensor->buffer)->data.size();
  EXPECT_EQ(SymmetricQuantizeTensor(&model, weights_tensor), kTfLiteOk);
  size_t quant_buffer_size =
      model.buffers.at(weights_tensor->buffer)->data.size();
  EXPECT_EQ(weights_tensor->type, TensorType_INT8);
  // float32 -> int8 shrinks the byte buffer by exactly 4x.
  EXPECT_EQ(quant_buffer_size * 4, float_buffer_size);
}
} // namespace
} // namespace utils
} // namespace optimize
} // namespace tflite
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
tensorflow::string model_file;
const std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("test_model_file", &model_file,
"Path to test tflite model file."),
};
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
std::cerr << "Required test_model_file\n";
std::abort();
}
g_test_model_dir =
new tensorflow::string(tensorflow::io::Dirname(model_file));
::tensorflow::port::InitMain(argv[0], &argc, &argv);
return RUN_ALL_TESTS();
}

View File

@ -190,41 +190,6 @@ TfLiteStatus InsertQuantizableInputTensorsFromOperator(
return kTfLiteOk;
}
// Quantizes tensor using symmetric quantization with the min and max elements
// of the tensor. This is needed for operations with hybrid evaluation
// implemented.
// NOTE(review): this copy (being removed by this change) duplicates the
// implementation that now lives in quantization_utils; unlike that version it
// does not null-check `model`/`tensor`/`buffer` before dereferencing.
TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
  BufferT* buffer = model->buffers[tensor->buffer].get();
  // Buffer bytes are reinterpreted as float32 weight values.
  float* float_data = reinterpret_cast<float*>(buffer->data.data());
  uint64_t num_elements;
  TF_LITE_ENSURE_STATUS(utils::NumElements(*tensor, &num_elements));
  LOG(INFO) << "Quantizing tensor " << tensor->name << " with " << num_elements
            << " elements.";
  std::vector<int8_t> quantized_buffer;
  quantized_buffer.resize(num_elements);
  float min_value, max_value, scaling_factor;
  tensor_utils::SymmetricQuantizeFloats(float_data, num_elements,
                                        quantized_buffer.data(), &min_value,
                                        &max_value, &scaling_factor);
  if (tensor->quantization == nullptr) {
    tensor->quantization = absl::make_unique<QuantizationParametersT>();
  }
  // Symmetric quantization: single scale, zero point fixed at 0.
  tensor->quantization->scale = std::vector<float>(1, scaling_factor);
  tensor->quantization->zero_point = std::vector<int64_t>(1, 0);
  // Write the int8 payload back into the model's byte buffer.
  uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(quantized_buffer.data());
  model->buffers[tensor->buffer]->data.assign(uint8_buffer,
                                              uint8_buffer + num_elements);
  // Update the tensor type.
  tensor->type = TensorType_INT8;
  return kTfLiteOk;
}
// Returns the index of the Dequantize op_code.
// If a Dequantize op_code doesn't exist, adds it and returns its index.
int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) {
@ -314,7 +279,7 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
for (std::pair<int32_t, TensorT*> tensor_pair : tensor_map) {
// Quantize the tensor.
TF_LITE_ENSURE_STATUS(
SymmetricQuantizeTensor(model.get(), tensor_pair.second));
utils::SymmetricQuantizeTensor(model.get(), tensor_pair.second));
}
// Examine the tensor consumers to determine which require dequantize ops.