Remove a wrong check.

ParseQuantization validated quantized_dimension against the number of scales,
but the correct upper bound for a dimension index is the tensor's number of
dimensions. Validate against dims instead, and add the previously missing
check that the number of scales matches the size of the quantized dimension
for per-axis quantization. Since bias tensors are one-dimensional, their
quantized dimension is always 0, so the dimension_index parameter of
SymmetricPerChannelBiasQuantize is dropped.

PiperOrigin-RevId: 251688407
commit 799448646a (parent 8b5b49f7d2)
@@ -327,7 +327,7 @@ TfLiteStatus InterpreterBuilder::ParseNodes(
 
 TfLiteStatus InterpreterBuilder::ParseQuantization(
     const QuantizationParameters* src_quantization,
-    TfLiteQuantization* quantization) {
+    TfLiteQuantization* quantization, const std::vector<int>& dims) {
   quantization->type = kTfLiteNoQuantization;
   if (!src_quantization || !src_quantization->scale() ||
       src_quantization->scale()->size() == 0) {
@@ -353,14 +353,29 @@ TfLiteStatus InterpreterBuilder::ParseQuantization(
   // Affine-quantization.
   quantization->type = kTfLiteAffineQuantization;
   const size_t num_scales = src_quantization->scale()->size();
+
+  // Ensure that the quantization dimension is valid.
   if (src_quantization->quantized_dimension() < 0 ||
-      src_quantization->quantized_dimension() >= num_scales) {
+      (!dims.empty() &&
+       src_quantization->quantized_dimension() >= dims.size())) {
     error_reporter_->Report(
-        "quantized_dimension must be in range [0, %d). Was %d.", num_scales,
+        "quantized_dimension must be in range [0, %d). Was %d.", dims.size(),
         src_quantization->quantized_dimension());
     return kTfLiteError;
   }
+
+  // Ensure that the number of scales is 1 for per-layer quantization, and
+  // matches number of quantization dimensions for per-axis quantization.
+  if (num_scales != 1 &&
+      (!dims.empty() &&
+       num_scales != dims[src_quantization->quantized_dimension()])) {
+    error_reporter_->Report(
+        "num_scales must be 1 for per-layer quantization, or %d for per-axis "
+        "quantization, but got %d.",
+        dims[src_quantization->quantized_dimension()], num_scales);
+    return kTfLiteError;
+  }
+
   auto* affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
       malloc(sizeof(TfLiteAffineQuantization)));
   affine_quantization->scale = TfLiteFloatArrayCreate(num_scales);
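To see what the old bound got wrong, consider depthwise-convolution weights of
shape [1, 3, 3, 2] quantized per-channel along dimension 3: there are
num_scales == 2 scales, so the old test `quantized_dimension >= num_scales`
(3 >= 2) rejected a perfectly valid tensor. Below is a minimal standalone
sketch of the corrected logic; this is plain C++, not the TFLite sources, and
the function and variable names are illustrative only.

    #include <cstdio>
    #include <vector>

    // Illustrative re-statement of the two checks ParseQuantization now makes.
    bool QuantizedDimensionIsValid(int quantized_dimension, size_t num_scales,
                                   const std::vector<int>& dims) {
      // The quantized dimension must index into the tensor's shape
      // (only enforceable when the shape is known).
      if (quantized_dimension < 0) return false;
      if (!dims.empty() && quantized_dimension >= static_cast<int>(dims.size()))
        return false;
      // One scale for per-layer quantization; one scale per element of the
      // quantized dimension for per-axis quantization.
      if (num_scales != 1 && !dims.empty() &&
          num_scales != static_cast<size_t>(dims[quantized_dimension]))
        return false;
      return true;
    }

    int main() {
      const std::vector<int> depthwise_dims = {1, 3, 3, 2};  // [1, H, W, C]
      const int quantized_dimension = 3;  // per-channel along the last axis
      const size_t num_scales = 2;        // one scale per channel
      // Old bound: compares a dimension index against a scale count.
      std::printf("old check rejects: %d\n",
                  quantized_dimension >= static_cast<int>(num_scales));  // 1
      // New bound: compares the dimension index against the tensor rank.
      std::printf("new check accepts: %d\n",
                  QuantizedDimensionIsValid(quantized_dimension, num_scales,
                                            depthwise_dims));            // 1
      return 0;
    }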
@@ -429,7 +444,7 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
 
     const auto* src_quantization = tensor->quantization();
     TfLiteQuantization quantization;
-    if (ParseQuantization(src_quantization, &quantization) != kTfLiteOk) {
+    if (ParseQuantization(src_quantization, &quantization, dims) != kTfLiteOk) {
       status = kTfLiteError;
       continue;
     }
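The dims argument threaded into the call above is presumably derived from the
tensor's flatbuffer shape. A hypothetical sketch of that conversion follows;
shape() is the flatbuffers-generated accessor on the tensor object, but the
helper itself is an assumption, not code from this commit.

    #include <vector>

    // Hypothetical helper: convert a flatbuffers::Vector<int32_t> shape (as
    // returned by the generated tensor->shape() accessor) into the
    // std::vector<int> that ParseQuantization bounds-checks against.
    // Templated only so this sketch compiles without the flatbuffers headers.
    template <typename FlatbufferVector>
    std::vector<int> ShapeToDims(const FlatbufferVector* shape) {
      if (shape == nullptr) return {};  // the shape may be absent
      return std::vector<int>(shape->begin(), shape->end());
    }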
@@ -210,7 +210,8 @@ class InterpreterBuilder {
                            Subgraph* subgraph);
   TfLiteStatus ApplyDelegates(Interpreter* interpreter);
   TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization,
-                                 TfLiteQuantization* quantization);
+                                 TfLiteQuantization* quantization,
+                                 const std::vector<int>& dims);
 
   const ::tflite::Model* model_;
   const OpResolver& op_resolver_;
@@ -322,8 +322,7 @@ TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
 TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
                                              float input_scale,
                                              const float* weight_scales,
-                                             int number_of_dimension,
-                                             int dimension_index) {
+                                             int number_of_dimension) {
   // Compute scales.
   std::vector<float> scales(number_of_dimension);
   for (size_t i = 0; i < number_of_dimension; i++) {
@@ -352,9 +351,8 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
   uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(final_buffer.data());
   size_t buffer_size = num_elements * sizeof(int32_t);
   std::vector<int64_t> zero_point(scales.size(), 0);
-  return AddQuantizationParams(scales, zero_point, dimension_index,
-                               uint8_buffer, buffer_size, TensorType_INT32,
-                               model, tensor);
+  return AddQuantizationParams(scales, zero_point, 0, uint8_buffer, buffer_size,
+                               TensorType_INT32, model, tensor);
 }
 
 TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,
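Hardcoding the quantized dimension to 0 here is sound because a bias tensor is
one-dimensional: dimension 0 is the only axis it has, so the dimension_index
parameter carried no information. The loop elided from the hunk above
presumably fills in the per-channel bias scales in the conventional way,
bias_scale[i] = input_scale * weight_scale[i]; the sketch below states that
assumption explicitly and is not the actual TFLite function body.

    #include <vector>

    // Sketch of the conventional per-channel bias scale computation; the real
    // SymmetricPerChannelBiasQuantize body is elided from this diff, so this
    // is an assumption based on the standard int32-bias quantization scheme.
    std::vector<float> BiasScales(float input_scale, const float* weight_scales,
                                  int number_of_dimension) {
      std::vector<float> scales(number_of_dimension);
      for (int i = 0; i < number_of_dimension; ++i) {
        scales[i] = input_scale * weight_scales[i];  // one scale per channel
      }
      return scales;
    }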
@@ -89,8 +89,7 @@ TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
 TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
                                              float input_scale,
                                              const float* weight_scales,
-                                             int number_of_dimension,
-                                             int dimension_index);
+                                             int number_of_dimension);
 
 // Quantize weight with or without per channel.
 TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,
@@ -386,7 +386,7 @@ TEST(QuantizationUtilsTest, SymmetricPerChannelBiasQuantize) {
   // Call and verify.
   EXPECT_EQ(SymmetricPerChannelBiasQuantize(
                 model.get(), model->subgraphs[0]->tensors[0].get(), input_scale,
-                weight_scales.data(), 2, 0),
+                weight_scales.data(), 2),
             kTfLiteOk);
   EXPECT_THAT(model->buffers[model->subgraphs[0]->tensors[0]->buffer]->data,
               ElementsAreArray({16, 0, 0, 0, 2, 0, 0, 0}));
@@ -69,7 +69,7 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
     }
     return utils::SymmetricPerChannelBiasQuantize(
         model, bias_tensor, input_tensor->quantization->scale[0],
-        weight_scales.data(), channel_dim_size, channel_dim_index);
+        weight_scales.data(), channel_dim_size);
   } else {
     if (weight_scales.size() != 1) {
       error_reporter->Report(