Added post-calibration recording of the max and min values for all input tensors.

PiperOrigin-RevId: 275090636
Change-Id: If3f1aa0f48ba24fe9041b397624cdb5dcb8b81ce
A. Unique TensorFlower 2019-10-16 12:37:40 -07:00 committed by TensorFlower Gardener
parent 7b32b175fe
commit 265126ca30
5 changed files with 342 additions and 84 deletions
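For context, the heart of the change is recording a (min, max) pair over calibrated float values, which later passes turn into quantization scales. A minimal standalone sketch of that idea in plain C++ (the helper name is illustrative, not part of this commit):

#include <algorithm>
#include <vector>

// Records one overall (min, max) pair for a float buffer, mirroring what
// FillSingleMinMax in this commit does on QuantizationParametersT.
void RecordMinMax(const std::vector<float>& calibrated_values,
                  float* min_out, float* max_out) {
  const auto minmax = std::minmax_element(calibrated_values.begin(),
                                          calibrated_values.end());
  *min_out = *minmax.first;   // Smallest observed value.
  *max_out = *minmax.second;  // Largest observed value.
}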

tensorflow/lite/tools/optimize/BUILD

@ -17,8 +17,11 @@ cc_library(
srcs = ["quantization_utils.cc"],
hdrs = ["quantization_utils.h"],
deps = [
":model_utils",
"//tensorflow/lite:framework",
"//tensorflow/lite:minimal_logging",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/core/api",
"//tensorflow/lite/kernels/internal:quantization_util",
"//tensorflow/lite/kernels/internal:round",
"//tensorflow/lite/kernels/internal:tensor_utils",
@ -93,6 +96,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/testing:util",
"@com_google_absl//absl/memory",
"@com_google_googletest//:gtest",
"@flatbuffers",

tensorflow/lite/tools/optimize/quantization_utils.cc

@ -16,14 +16,20 @@ limitations under the License.
#include <cmath>
#include <cstdint>
#include <memory>
#include <string>
#include "absl/memory/memory.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/round.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/optimize/model_utils.h"
namespace tflite {
namespace optimize {
@ -76,16 +82,40 @@ void GetAsymmetricQuantizationParams(
quantization_params->zero_point = std::vector<int64_t>(1, zero_point);
}
// Per-channel quantize a tensor at the given index and returns both scales and
// quantized values.
void SymmetricPerChannelQuantization(const float* const input,
const std::vector<int>& dimension,
int32_t channel_dim_index,
std::vector<float>* output_scales,
std::vector<int8_t>* output_value) {
// Sets the max and min quantization parameters for a single tensor given its
// values.
void FillSingleMinMax(const float* const input, const uint64_t input_size,
QuantizationParametersT* quantization_params) {
const auto minmax = std::minmax_element(input, input + input_size);
quantization_params->min.assign(1, *minmax.first);
quantization_params->max.assign(1, *minmax.second);
}
TfLiteStatus FillPerChannelMinMax(const float* const input,
const std::vector<int32_t>& dimension,
int32_t channel_dim_index,
QuantizationParametersT* quantization_params,
ErrorReporter* error_reporter) {
if (!quantization_params->min.empty() || !quantization_params->max.empty()) {
error_reporter->Report(
"Min or max already present in tensor quantization params.");
return kTfLiteError;
}
if (dimension.size() != 4) {
error_reporter->Report("Expected tensor with four dimensions, but got %d.",
dimension.size());
return kTfLiteError;
}
if (channel_dim_index > 3) {
error_reporter->Report(
"Expected channel_dim_index to be less than four, but got %d.",
channel_dim_index);
return kTfLiteError;
}
const int32_t channel_dim_size = dimension[channel_dim_index];
std::vector<float> min_vals(channel_dim_size);
std::vector<float> max_vals(channel_dim_size);
quantization_params->quantized_dimension = channel_dim_index;
quantization_params->min = std::vector<float>(channel_dim_size);
quantization_params->max = std::vector<float>(channel_dim_size);
std::vector<bool> has_min_max_value(channel_dim_size, false);
int indices[4];
RuntimeShape tensor_dims{dimension[0], dimension[1], dimension[2],
@ -99,27 +129,53 @@ void SymmetricPerChannelQuantization(const float* const input,
int channel_idx = indices[channel_dim_index];
const float val = input[Offset(tensor_dims, indices)];
if (has_min_max_value[channel_idx]) {
if (min_vals[channel_idx] > val) {
min_vals[channel_idx] = val;
} else if (max_vals[channel_idx] < val) {
max_vals[channel_idx] = val;
if (quantization_params->min[channel_idx] > val) {
quantization_params->min[channel_idx] = val;
} else if (quantization_params->max[channel_idx] < val) {
quantization_params->max[channel_idx] = val;
}
} else {
min_vals[channel_idx] = val;
max_vals[channel_idx] = val;
quantization_params->min[channel_idx] = val;
quantization_params->max[channel_idx] = val;
has_min_max_value[channel_idx] = true;
}
}
}
}
}
return kTfLiteOk;
}
// Calculate scales per channel
// Per-channel quantizes a tensor at the given index and fills both scales and
// quantized values.
TfLiteStatus SymmetricPerChannelQuantization(TensorT* tensor,
const float* const input,
int32_t channel_dim_index,
std::vector<float>* output_scales,
std::vector<int8_t>* output_value,
ErrorReporter* error_reporter) {
if (tensor == nullptr) {
error_reporter->Report("Cannot quantize. Tensor is null.");
return kTfLiteError;
}
const int32_t channel_dim_size = tensor->shape[channel_dim_index];
// Fill per channel max and min values if needed.
if (tensor->quantization == nullptr) {
tensor->quantization = absl::make_unique<QuantizationParametersT>();
}
if (!HasMinMax(tensor)) {
TF_LITE_ENSURE_STATUS(
FillPerChannelMinMax(input, tensor->shape, channel_dim_index,
tensor->quantization.get(), error_reporter));
}
// Calculate scales per channel using max and min values from tensor.
std::vector<float> scale_invs(channel_dim_size);
const float half_scale = kMaxQuantizedValue;
for (int channel_idx = 0; channel_idx < channel_dim_size; channel_idx++) {
const float half_range = std::max(std::abs(min_vals[channel_idx]),
std::abs(max_vals[channel_idx]));
const float half_range =
std::max(std::abs(tensor->quantization->min[channel_idx]),
std::abs(tensor->quantization->max[channel_idx]));
output_scales->at(channel_idx) = half_range / half_scale;
if (half_range == 0) {
scale_invs[channel_idx] = 0;
@ -128,14 +184,16 @@ void SymmetricPerChannelQuantization(const float* const input,
}
}
// Quantize the values.
SymmetricPerChannelQuantizeValues(input, scale_invs, dimension,
// Quantize the input values.
SymmetricPerChannelQuantizeValues(input, scale_invs, tensor->shape,
channel_dim_index, output_value);
return kTfLiteOk;
}
TfLiteStatus SymmetricQuantizeFloatsToInt16(ModelT* model, TensorT* tensor,
float input_scale,
float weight_scale) {
float weight_scale,
ErrorReporter* error_reporter) {
// Compute scale and inverse of scale.
const float scaling_factor = input_scale * weight_scale;
const float scaling_factor_inv =
@ -161,12 +219,13 @@ TfLiteStatus SymmetricQuantizeFloatsToInt16(ModelT* model, TensorT* tensor,
std::vector<float> scales(1, scaling_factor);
std::vector<int64_t> zero_points(1, 0);
return AddQuantizationParams(scales, zero_points, 0, uint8_buffer,
buffer_size, TensorType_INT16, model, tensor);
buffer_size, TensorType_INT16, model, tensor,
error_reporter);
}
void SymmetricPerChannelQuantizeValues(const float* const input,
const std::vector<float>& scales_inv,
const std::vector<int>& dimension,
const std::vector<int32_t>& dimension,
int32_t channel_dim_index,
std::vector<int8_t>* output_value) {
// Quantize the values.
@ -193,11 +252,13 @@ void SymmetricPerChannelQuantizeValues(const float* const input,
TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
if (model == nullptr || tensor == nullptr) {
TFLITE_LOG(TFLITE_LOG_ERROR, "No tensor to quantize.");
return kTfLiteError;
}
BufferT* buffer = model->buffers[tensor->buffer].get();
if (buffer == nullptr) {
TFLITE_LOG(TFLITE_LOG_ERROR, "Missing buffer.");
return kTfLiteError;
}
float* float_data = reinterpret_cast<float*>(buffer->data.data());
@ -230,11 +291,13 @@ TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
TfLiteStatus QuantizeTensorFloat16(ModelT* model, TensorT* tensor) {
if (model == nullptr || tensor == nullptr) {
TFLITE_LOG(TFLITE_LOG_ERROR, "No tensor to quantize.");
return kTfLiteError;
}
BufferT* buffer = model->buffers[tensor->buffer].get();
if (buffer == nullptr) {
TFLITE_LOG(TFLITE_LOG_ERROR, "Missing buffer.");
return kTfLiteError;
}
@ -269,25 +332,34 @@ TfLiteStatus AddQuantizationParams(const std::vector<float>& scales,
int quantized_dimension,
const uint8_t* buffer_data,
size_t buffer_size, TensorType output_type,
ModelT* model, TensorT* tensor) {
ModelT* model, TensorT* tensor,
ErrorReporter* error_reporter) {
tensor->quantization = absl::make_unique<QuantizationParametersT>();
tensor->quantization->scale.assign(scales.begin(), scales.end());
if (zero_point.size() != scales.size()) {
error_reporter->Report(
"Received zero_point of size %d and scales of size %d. "
"These sizes should match.",
zero_point.size(), scales.size());
return kTfLiteError;
}
tensor->quantization->zero_point.assign(zero_point.begin(), zero_point.end());
tensor->quantization->quantized_dimension = quantized_dimension;
model->buffers[tensor->buffer]->data.assign(buffer_data,
buffer_data + buffer_size);
// Update the tensor type.
tensor->type = output_type;
return kTfLiteOk;
}
TfLiteStatus SymmetricQuantizeTensorPerChannel(ModelT* model, TensorT* tensor,
int32_t channel_dim_index) {
int32_t channel_dim_index,
ErrorReporter* error_reporter) {
if (tensor->shape.size() != 4) {
error_reporter->Report(
"SymmetricQuantizeTensorPerChannel requires tensor with four "
"dimensions, but got %d dimension(s).",
tensor->shape.size());
return kTfLiteError;
}
@ -307,8 +379,9 @@ TfLiteStatus SymmetricQuantizeTensorPerChannel(ModelT* model, TensorT* tensor,
// Quantize the input data with respect to channel_dim_index.
const std::vector<int> tensor_dims = {tensor->shape[0], tensor->shape[1],
tensor->shape[2], tensor->shape[3]};
SymmetricPerChannelQuantization(float_input_data, tensor_dims,
channel_dim_index, &scales, &final_buffer);
TF_LITE_ENSURE_STATUS(SymmetricPerChannelQuantization(
tensor, float_input_data, channel_dim_index, &scales, &final_buffer,
error_reporter));
// Set the buffers and output type.
uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(final_buffer.data());
@ -316,12 +389,13 @@ TfLiteStatus SymmetricQuantizeTensorPerChannel(ModelT* model, TensorT* tensor,
std::vector<int64_t> zero_point(scales.size(), 0);
return AddQuantizationParams(scales, zero_point, channel_dim_index,
uint8_buffer, buffer_size, TensorType_INT8,
model, tensor);
model, tensor, error_reporter);
}
TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
float input_scale,
float weight_scale) {
float weight_scale,
ErrorReporter* error_reporter) {
// Compute scale and inverse of scale.
const float scaling_factor = input_scale * weight_scale;
const float scaling_factor_inv =
@ -347,13 +421,15 @@ TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
std::vector<float> scales(1, scaling_factor);
std::vector<int64_t> zero_points(1, 0);
return AddQuantizationParams(scales, zero_points, 0, uint8_buffer,
buffer_size, TensorType_INT32, model, tensor);
buffer_size, TensorType_INT32, model, tensor,
error_reporter);
}
TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
float input_scale,
const float* weight_scales,
int number_of_dimension) {
int number_of_dimension,
ErrorReporter* error_reporter) {
// Compute scales.
std::vector<float> scales(number_of_dimension);
for (size_t i = 0; i < number_of_dimension; i++) {
@ -383,16 +459,17 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
size_t buffer_size = num_elements * sizeof(int32_t);
std::vector<int64_t> zero_point(scales.size(), 0);
return AddQuantizationParams(scales, zero_point, 0, uint8_buffer, buffer_size,
TensorType_INT32, model, tensor);
TensorType_INT32, model, tensor, error_reporter);
}
TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,
int per_axis_index) {
int per_axis_index, ErrorReporter* error_reporter) {
// TODO(suharshs): Currently we conflate quantizing weights and constants. Its
// possible that the right thing to do is asymmetric quantize the weight. Add
// support for this.
if (per_channel) {
return SymmetricQuantizeTensorPerChannel(model, tensor, per_axis_index);
return SymmetricQuantizeTensorPerChannel(model, tensor, per_axis_index,
error_reporter);
} else {
return SymmetricQuantizeTensor(model, tensor);
}

tensorflow/lite/tools/optimize/quantization_utils.h

@ -18,6 +18,7 @@ limitations under the License.
#include <cstdint>
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
@ -35,28 +36,40 @@ void GetAsymmetricQuantizationParams(
float min, float max, const int quant_min, const int quant_max,
QuantizationParametersT* quantization_params);
// Populates a single overall max and min value for a tensor.
void FillSingleMinMax(const float* const input, const uint64_t input_size,
QuantizationParametersT* quantization_params);
// Populates the max and min values for per channel quantization.
TfLiteStatus FillPerChannelMinMax(const float* const input,
const std::vector<int32_t>& dimension,
int32_t channel_dim_index,
QuantizationParametersT* quantization_params,
ErrorReporter* error_reporter);
// Per-channel quantizes a tensor at the given index and fills both scales and
// quantized values.
// Parameters:
// - tensor is the tensor to be quantized, needed to access associated
// quantization parameters. Only tensors with four dimensions are supported.
// - input is the float input data to be quantized.
// - channel_dim_index is the channel index within the tensor's shape;
// tensor->shape[channel_dim_index] gives the number of channels.
// - output_scales is the output scales, the size of which equals the number
// of channels.
// - output_value is the output data, the size of which equals the number of
// inputs.
// - error_reporter is used to report any errors encountered.
void SymmetricPerChannelQuantization(const float* const input,
const std::vector<int>& dimension,
int32_t channel_dim_index,
std::vector<float>* output_scales,
std::vector<int8_t>* output_value);
TfLiteStatus SymmetricPerChannelQuantization(TensorT* tensor,
const float* const input,
int32_t channel_dim_index,
std::vector<float>* output_scales,
std::vector<int8_t>* output_value,
ErrorReporter* error_reporter);
// Quantize the values given an array of scales.
void SymmetricPerChannelQuantizeValues(const float* const input,
const std::vector<float>& scales_inv,
const std::vector<int>& dimension,
const std::vector<int32_t>& dimension,
int32_t channel_dim_index,
std::vector<int8_t>* output_value);
@ -73,32 +86,37 @@ TfLiteStatus AddQuantizationParams(const std::vector<float>& scales,
int quantized_dimension,
const uint8_t* buffer_data,
size_t buffer_size, TensorType output_type,
ModelT* model, TensorT* tensor);
ModelT* model, TensorT* tensor,
ErrorReporter* error_reporter);
// Quantizes a tensor per channel.
TfLiteStatus SymmetricQuantizeTensorPerChannel(ModelT* model, TensorT* tensor,
int32_t channel_dim_index);
int32_t channel_dim_index,
ErrorReporter* error_reporter);
// Symmetrically quantizes floats to 16 bits.
TfLiteStatus SymmetricQuantizeFloatsToInt16(ModelT* model, TensorT* tensor,
float input_scale,
float weight_scale);
float weight_scale,
ErrorReporter* error_reporter);
// Symmetrically quantizes the bias for per-layer ops (e.g., FullyConnected).
TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
float input_scale,
float weight_scale);
float weight_scale,
ErrorReporter* error_reporter);
// Symmetrically quantizes the bias for ops like Conv and DepthwiseConv.
// The scale of bias if weight_per_channel_scale[channel] * input_scale
// The scale of the bias is weight_per_channel_scale[channel] * input_scale.
TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
float input_scale,
const float* weight_scales,
int number_of_dimension);
int number_of_dimension,
ErrorReporter* error_reporter);
// Quantizes a weight tensor with or without per-channel quantization.
TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,
int per_axis_index);
int per_axis_index, ErrorReporter* error_reporter);
// Quantize activation.
void QuantizeActivation(TensorT* tensor);
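A usage sketch for the revised signatures (an assumed caller, not part of the commit): every entry point now threads an ErrorReporter, so a caller quantizing a single conv weight might look like this. The include path and the {output-channel}-major layout assumption are inferred from this diff.

#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/optimize/quantization_utils.h"

// Hypothetical helper showing the new error_reporter plumbing; assumes the
// tensor has a constant buffer and a four-dimensional shape with the output
// channel as dimension 0.
TfLiteStatus QuantizeConvWeight(tflite::ModelT* model, tflite::TensorT* tensor,
                                tflite::ErrorReporter* error_reporter) {
  // Per-channel along dimension 0; failures are reported, not just logged.
  return tflite::optimize::utils::SymmetricQuantizeTensorPerChannel(
      model, tensor, /*channel_dim_index=*/0, error_reporter);
}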

tensorflow/lite/tools/optimize/quantization_utils_test.cc

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/testing/util.h"
#include "tensorflow/lite/tools/optimize/test_util.h"
namespace {
@ -44,7 +45,12 @@ std::unique_ptr<FlatBufferModel> ReadConvModel() {
using ::testing::ElementsAreArray;
TEST(QuantizationUtilsTest, NumElements) {
class QuantizationUtilsTest : public testing::Test {
protected:
tflite::TestErrorReporter error_reporter_;
};
TEST_F(QuantizationUtilsTest, NumElements) {
TensorT tensor;
tensor.shape = {1, 2, 3, 4};
uint64_t num_elements;
@ -60,7 +66,7 @@ TEST(QuantizationUtilsTest, NumElements) {
EXPECT_EQ(num_elements, 1);
}
TEST(QuantizationUtilsTest, GetAsymmetricQuantizationParamsUnitRange) {
TEST_F(QuantizationUtilsTest, GetAsymmetricQuantizationParamsUnitRange) {
const float float_min = -128.0;
const float float_max = 127.0;
const int quant_min = -128;
@ -82,7 +88,8 @@ TEST(QuantizationUtilsTest, GetAsymmetricQuantizationParamsUnitRange) {
EXPECT_NEAR(scale, 1, eps);
}
TEST(QuantizationUtilsTest, AsymmetricQuantizationParamsWithAllPositiveRange) {
TEST_F(QuantizationUtilsTest,
AsymmetricQuantizationParamsWithAllPositiveRange) {
// The min should get nudged to include 0, so the effective range is [0, 6].
const float float_min = 1.0;
const float float_max = 6.0;
@ -104,7 +111,8 @@ TEST(QuantizationUtilsTest, AsymmetricQuantizationParamsWithAllPositiveRange) {
EXPECT_NEAR(scale, 6 / 255.0f, eps);
}
TEST(QuantizationUtilsTest, AsymmetricQuantizationParamsWithAllNegativeRange) {
TEST_F(QuantizationUtilsTest,
AsymmetricQuantizationParamsWithAllNegativeRange) {
// The min should get nudged to include 0, so the effective range is [-6, 0].
const float float_min = -6.0;
const float float_max = -1.0;
@ -126,7 +134,7 @@ TEST(QuantizationUtilsTest, AsymmetricQuantizationParamsWithAllNegativeRange) {
EXPECT_NEAR(scale, 6 / 255.0f, eps);
}
TEST(QuantizationUtilsTest, AsymmetricQuantizationParamsWithZeroInRange) {
TEST_F(QuantizationUtilsTest, AsymmetricQuantizationParamsWithZeroInRange) {
const float float_min = -5.0;
const float float_max = 1.0;
const int quant_min = -128;
@ -148,7 +156,7 @@ TEST(QuantizationUtilsTest, AsymmetricQuantizationParamsWithZeroInRange) {
EXPECT_LT(zero_point, quant_max);
}
TEST(QuantizationUtilsTest, AsymmetricQuantizationParamsWithZeroMinMax) {
TEST_F(QuantizationUtilsTest, AsymmetricQuantizationParamsWithZeroMinMax) {
const float float_min = 0;
const float float_max = 0;
const int quant_min = -128;
@ -170,23 +178,27 @@ TEST(QuantizationUtilsTest, AsymmetricQuantizationParamsWithZeroMinMax) {
EXPECT_LT(zero_point, quant_max);
}
TEST(QuantizationUtilsTest, SymmetricPerChannelQuantization) {
TEST_F(QuantizationUtilsTest, SymmetricPerChannelQuantizationWithNullQParams) {
// Set up an input of size [3, 2, 2, 2], with 0 as the channel index.
const std::vector<float> input = {
3.0, 2.0, 5.0, -2.0, 3.0, 2.0, 5.0, -2.0, // Channel 1.
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // Channel 2.
1.0, 0.0, -1.0, -2.0, -3.0, -4.0, -5.0, -6.0, // Channel 3.
};
const std::vector<int32_t> dimension = {3, 2, 2, 2};
const int channel_index = 0;
// Create holders for the output scales and data.
std::vector<float> output_scales(3);
std::vector<int8_t> output_data(3 * 2 * 2 * 2);
// Call SymmetricPerChannelQuantization and verify the result.
SymmetricPerChannelQuantization(input.data(), dimension, channel_index,
&output_scales, &output_data);
// Call SymmetricPerChannelQuantization with quant_params as a null pointer
// and verify the result.
TensorT tensor = TensorT();
tensor.quantization = nullptr;
tensor.shape = {3, 2, 2, 2};
SymmetricPerChannelQuantization(&tensor, input.data(), channel_index,
&output_scales, &output_data,
&error_reporter_);
const std::vector<float> expected_output_scales = {0.0393700786, 0.0629921257,
0.0472440943};
const std::vector<int8_t> expected_output_data = {
@ -198,7 +210,49 @@ TEST(QuantizationUtilsTest, SymmetricPerChannelQuantization) {
EXPECT_THAT(output_data, ElementsAreArray(expected_output_data));
}
TEST(QuantizationUtilsTest, SymmetricPerChannelQuantizeValues) {
TEST_F(QuantizationUtilsTest, SymmetricPerChannelQuantization) {
// Set up an input of size [3, 2, 2, 2], with 0 as the channel index.
const std::vector<float> input = {
3.0, 2.0, 5.0, -2.0, 3.0, 2.0, 5.0, -2.0, // Channel 1.
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // Channel 2.
1.0, 0.0, -1.0, -2.0, -3.0, -4.0, -5.0, -6.0, // Channel 3.
};
const int32_t channel_index = 0;
// Create holders for the output scales and data.
std::vector<float> output_scales(3);
std::vector<int8_t> output_data(3 * 2 * 2 * 2);
// Initialize the pointer to quantization parameters.
TensorT tensor = TensorT();
tensor.quantization = absl::make_unique<QuantizationParametersT>();
tensor.shape = {3, 2, 2, 2};
FillPerChannelMinMax(input.data(), tensor.shape, channel_index,
tensor.quantization.get(), &error_reporter_);
// Test that FillPerChannelMinMax worked.
const std::vector<float> expected_mins = {-2.0, 1.0, -6.0};
const std::vector<float> expected_maxs = {5.0, 8.0, 1.0};
EXPECT_THAT(tensor.quantization->min, ElementsAreArray(expected_mins));
EXPECT_THAT(tensor.quantization->max, ElementsAreArray(expected_maxs));
// Call SymmetricPerChannelQuantization with pre-populated min/max values
// and verify the result.
SymmetricPerChannelQuantization(&tensor, input.data(), channel_index,
&output_scales, &output_data,
&error_reporter_);
const std::vector<float> expected_output_scales = {0.0393700786, 0.0629921257,
0.0472440943};
const std::vector<int8_t> expected_output_data = {
76, 51, 127, -51, 76, 51, 127, -51, // Channel 1.
16, 32, 48, 64, 79, 95, 111, 127, // Channel 2.
21, 0, -21, -42, -64, -85, -106, -127, // Channel 3.
};
EXPECT_THAT(output_scales, ElementsAreArray(expected_output_scales));
EXPECT_THAT(output_data, ElementsAreArray(expected_output_data));
}
TEST_F(QuantizationUtilsTest, SymmetricPerChannelQuantizeValues) {
// Set up an input of size [3, 1, 1, 2], with 0 as the channel index.
const std::vector<float> input = {
13.0, 21.0, // Channel 1.
@ -223,11 +277,34 @@ TEST(QuantizationUtilsTest, SymmetricPerChannelQuantizeValues) {
EXPECT_THAT(output_data, ElementsAreArray(expected_output_data));
}
TEST(QuantizationUtilsTest, SymmetricQuantizeTensorNullInputs) {
TEST_F(QuantizationUtilsTest, FillSingleMinMax) {
// Set up an input of size [3, 1, 1, 2].
const std::vector<float> input = {
13.0, 21.0, // Channel 1.
21.0, 22.0, // Channel 2.
31.0, 40.0, // Channel 3.
};
const uint32_t input_size = input.size();
// Initialize the quantization parameters.
QuantizationParametersT quantization_params = QuantizationParametersT();
FillSingleMinMax(input.data(), input_size, &quantization_params);
const std::vector<float> expected_min_max = {
13, 40, // min max
};
EXPECT_EQ(quantization_params.min.size(), 1);
EXPECT_EQ(quantization_params.max.size(), 1);
EXPECT_EQ(quantization_params.min[0], expected_min_max[0]);
EXPECT_EQ(quantization_params.max[0], expected_min_max[1]);
}
TEST_F(QuantizationUtilsTest, SymmetricQuantizeTensorNullInputs) {
EXPECT_EQ(SymmetricQuantizeTensor(nullptr, nullptr), kTfLiteError);
}
TEST(QuantizationUtilsTest, SymmetricQuantizeTensor) {
TEST_F(QuantizationUtilsTest, SymmetricQuantizeTensor) {
// Conv model has weights between 0 and 10.
// Quantize the weights tensor.
ASSERT_TRUE(g_test_model_dir);
@ -259,7 +336,7 @@ TEST(QuantizationUtilsTest, SymmetricQuantizeTensor) {
EXPECT_EQ(quant_buffer_size * 4, float_buffer_size);
}
TEST(QuantizationUtilsTest, QuantizeFloat16) {
TEST_F(QuantizationUtilsTest, QuantizeFloat16) {
// Conv model has weights between 0 and 10.
// Quantize the weights tensor.
ASSERT_TRUE(g_test_model_dir != nullptr);
@ -291,7 +368,7 @@ TEST(QuantizationUtilsTest, QuantizeFloat16) {
EXPECT_EQ(quant_buffer_size * 2, float_buffer_size);
}
TEST(QuantizationUtilsTest, AddQuantizationParams) {
TEST_F(QuantizationUtilsTest, AddQuantizationParams) {
// Create data.
auto model = absl::make_unique<ModelT>();
auto subgraph = absl::make_unique<tflite::SubGraphT>();
@ -310,11 +387,11 @@ TEST(QuantizationUtilsTest, AddQuantizationParams) {
model->buffers.push_back(std::move(buffer));
// Call and verify.
EXPECT_EQ(
AddQuantizationParams(scales, zero_points, quantizated_dimension,
buffer_data.data(), buffer_size, TensorType_INT8,
model.get(), model->subgraphs[0]->tensors[0].get()),
kTfLiteOk);
EXPECT_EQ(AddQuantizationParams(
scales, zero_points, quantizated_dimension, buffer_data.data(),
buffer_size, TensorType_INT8, model.get(),
model->subgraphs[0]->tensors[0].get(), &error_reporter_),
kTfLiteOk);
EXPECT_THAT(model->subgraphs[0]->tensors[0]->quantization->scale,
ElementsAreArray(scales));
EXPECT_THAT(model->subgraphs[0]->tensors[0]->quantization->zero_point,
@ -324,7 +401,7 @@ TEST(QuantizationUtilsTest, AddQuantizationParams) {
EXPECT_EQ(model->subgraphs[0]->tensors[0]->type, TensorType_INT8);
}
TEST(QuantizationUtilsTest, SymmetricQuantizeFloatsToInt16Test) {
TEST_F(QuantizationUtilsTest, SymmetricQuantizeFloatsToInt16Test) {
// Create data.
auto model = absl::make_unique<ModelT>();
auto subgraph = absl::make_unique<tflite::SubGraphT>();
@ -350,7 +427,7 @@ TEST(QuantizationUtilsTest, SymmetricQuantizeFloatsToInt16Test) {
// Call and verify.
EXPECT_EQ(SymmetricQuantizeFloatsToInt16(
model.get(), model->subgraphs[0]->tensors[0].get(), input_scale,
weight_scale),
weight_scale, &error_reporter_),
kTfLiteOk);
EXPECT_THAT(model->subgraphs[0]->tensors[0]->quantization->scale[0],
@ -366,7 +443,7 @@ TEST(QuantizationUtilsTest, SymmetricQuantizeFloatsToInt16Test) {
EXPECT_EQ(model->subgraphs[0]->tensors[0]->type, TensorType_INT16);
}
TEST(QuantizationUtilsTest, SymmetricPerLayerBiasQuantize) {
TEST_F(QuantizationUtilsTest, SymmetricPerLayerBiasQuantize) {
// Create data.
auto model = absl::make_unique<ModelT>();
auto subgraph = absl::make_unique<tflite::SubGraphT>();
@ -389,9 +466,9 @@ TEST(QuantizationUtilsTest, SymmetricPerLayerBiasQuantize) {
model->buffers.push_back(std::move(buffer));
// Call and verify.
EXPECT_EQ(SymmetricPerLayerBiasQuantize(model.get(),
model->subgraphs[0]->tensors[0].get(),
input_scale, weight_scale),
EXPECT_EQ(SymmetricPerLayerBiasQuantize(
model.get(), model->subgraphs[0]->tensors[0].get(), input_scale,
weight_scale, &error_reporter_),
kTfLiteOk);
EXPECT_THAT(model->subgraphs[0]->tensors[0]->quantization->scale[0],
@ -403,7 +480,7 @@ TEST(QuantizationUtilsTest, SymmetricPerLayerBiasQuantize) {
EXPECT_EQ(model->subgraphs[0]->tensors[0]->type, TensorType_INT32);
}
TEST(QuantizationUtilsTest, SymmetricPerChannelBiasQuantize) {
TEST_F(QuantizationUtilsTest, SymmetricPerChannelBiasQuantize) {
// Create data.
auto model = absl::make_unique<ModelT>();
auto subgraph = absl::make_unique<tflite::SubGraphT>();
@ -428,7 +505,7 @@ TEST(QuantizationUtilsTest, SymmetricPerChannelBiasQuantize) {
// Call and verify.
EXPECT_EQ(SymmetricPerChannelBiasQuantize(
model.get(), model->subgraphs[0]->tensors[0].get(), input_scale,
weight_scales.data(), 2),
weight_scales.data(), 2, &error_reporter_),
kTfLiteOk);
EXPECT_THAT(model->buffers[model->subgraphs[0]->tensors[0]->buffer]->data,
ElementsAreArray({16, 0, 0, 0, 2, 0, 0, 0}));
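The expected quantized values in these tests follow directly from round(value / scale); for instance, a quick sketch reproducing channel 1 of the per-channel test:

#include <cmath>
#include <cstdio>

int main() {
  // Channel 1 of the per-channel test: min -2, max 5 -> scale = 5 / 127.
  const float scale = 5.0f / 127.0f;  // 0.0393700786, as the test expects.
  const float values[] = {3.0f, 2.0f, 5.0f, -2.0f};
  for (const float v : values) {
    // Yields 76, 51, 127, -51: the first four expected_output_data entries.
    std::printf("%g -> %d\n", v, static_cast<int>(std::round(v / scale)));
  }
  return 0;
}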

tensorflow/lite/tools/optimize/quantize_model.cc

@ -89,7 +89,7 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
}
return utils::SymmetricPerChannelBiasQuantize(
model, bias_tensor, input_tensor->quantization->scale[0],
weight_scales.data(), channel_dim_size);
weight_scales.data(), channel_dim_size, error_reporter);
} else {
if (weight_scales.size() != 1) {
error_reporter->Report(
@ -99,7 +99,7 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
}
return utils::SymmetricPerLayerBiasQuantize(
model, bias_tensor, input_tensor->quantization->scale[0],
weight_scales[0]);
weight_scales[0], error_reporter);
}
return kTfLiteError;
}
@ -417,15 +417,16 @@ TfLiteStatus QuantizeOpInput(
}
const int32_t tensor_idx = op->inputs[input_idx];
TensorT* tensor = subgraph->tensors[tensor_idx].get();
const bool is_input_quantized = utils::QuantizationParametersExist(tensor);
// Assumes op is quantized to int8.
const bool is_input_quantized = (tensor->type == TensorType_INT8);
if (property.quantizable && !is_input_quantized) {
// The operation is quantizable, but the input isn't yet quantized.
if (utils::HasBuffer(model, subgraph, tensor_idx)) {
// TODO(suharshs): Look at consumers, throw error if one consumer is
// per-channel and one per-layer.
if (utils::QuantizeWeight(model, tensor, tensor_property.per_axis,
tensor_property.per_axis_index) ==
kTfLiteError) {
tensor_property.per_axis_index,
error_reporter) == kTfLiteError) {
error_reporter->Report(
"Unable to quantize buffer or min/max value for input %d "
"in op %s in subgraph %d, node: %d",
@ -680,6 +681,85 @@ std::unordered_set<string> GetAllOperatorOutputs(ModelT* model) {
}
return operator_names;
}
// Populates the max and min quantization parameters for input tensors.
// Assumes that dynamic tensors already have stored min, max values, and
// throws an error if a tensor has neither min, max quantization parameters
// nor a buffer.
// If any static tensors are not inputs to an operation, their max, min values
// will not be filled by this function.
TfLiteStatus FillQuantizationParams(
ModelT* model, const std::unordered_set<string>& operator_names,
ErrorReporter* error_reporter) {
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
subgraph_idx++) {
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
for (size_t op_idx = 0; op_idx < subgraph->operators.size(); op_idx++) {
OperatorT* op = subgraph->operators[op_idx].get();
const BuiltinOperator op_code =
model->operator_codes[op->opcode_index]->builtin_code;
operator_property::OperatorProperty property = GetOperatorProperty(
operator_names, op_code, subgraph->tensors[op->outputs[0]]->name);
// Populate max, min for each input tensor.
for (const std::pair<int, operator_property::TensorProperty>& input :
property.inputs) {
// Get tensor.
const int32_t input_idx = input.first;
const int32_t tensor_idx = op->inputs[input_idx];
TensorT* tensor = subgraph->tensors[tensor_idx].get();
// Static tensor.
if (!utils::HasMinMax(tensor) &&
utils::HasBuffer(model, subgraph, tensor_idx)) {
// Get input float data and tensor dimensions.
BufferT* buffer = model->buffers[tensor->buffer].get();
float* float_input_data =
reinterpret_cast<float*>(buffer->data.data());
// Fill per channel max and min with respect to channel_dim_index.
if (input.second.per_axis) {
if (tensor->shape.size() == 4) {
int32_t channel_dim_index = input.second.per_axis_index;
TF_LITE_ENSURE_STATUS(utils::FillPerChannelMinMax(
float_input_data, tensor->shape, channel_dim_index,
tensor->quantization.get(), error_reporter));
} else {
error_reporter->Report(
"Could not fill max min for tensor as the dimension is %d "
"and not 4 as expected.",
tensor->shape.size());
return kTfLiteError;
}
} else {
// Fill per layer max and min.
uint64_t input_size;
TF_LITE_ENSURE_STATUS(utils::NumElements(*tensor, &input_size));
utils::FillSingleMinMax(float_input_data, input_size,
tensor->quantization.get());
}
if (tensor->quantization->quantized_dimension !=
input.second.per_axis_index) {
error_reporter->Report(
"Quantized dimension for tensor property and quantization "
"parameters do not match. Got %d and %d respectively.",
input.second.per_axis_index,
tensor->quantization->quantized_dimension);
return kTfLiteError;
}
// Dynamic tensor.
} else if (!utils::HasMinMax(tensor) &&
!utils::HasBuffer(model, subgraph, tensor_idx)) {
error_reporter->Report(
"Max and min for dynamic tensors should be"
" recorded during calibration.");
return kTfLiteError;
}
} // loop over op inputs
} // loop over ops
} // loop over subgraphs
return kTfLiteOk;
}
} // namespace
@ -689,6 +769,8 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
const TensorType& output_type, bool allow_float,
const std::unordered_set<string>& operator_names,
ErrorReporter* error_reporter) {
TF_LITE_ENSURE_STATUS(
FillQuantizationParams(model, operator_names, error_reporter));
TF_LITE_ENSURE_STATUS(QuantizeWeightsInputOutput(
model, allow_float, operator_names, error_reporter));
TF_LITE_ENSURE_STATUS(