Remove a wrong check.
PiperOrigin-RevId: 251688407
This commit is contained in:
parent
8b5b49f7d2
commit
799448646a
@ -327,7 +327,7 @@ TfLiteStatus InterpreterBuilder::ParseNodes(
|
||||
|
||||
TfLiteStatus InterpreterBuilder::ParseQuantization(
|
||||
const QuantizationParameters* src_quantization,
|
||||
TfLiteQuantization* quantization) {
|
||||
TfLiteQuantization* quantization, const std::vector<int>& dims) {
|
||||
quantization->type = kTfLiteNoQuantization;
|
||||
if (!src_quantization || !src_quantization->scale() ||
|
||||
src_quantization->scale()->size() == 0) {
|
||||
@ -353,14 +353,29 @@ TfLiteStatus InterpreterBuilder::ParseQuantization(
|
||||
// Affine-quantization.
|
||||
quantization->type = kTfLiteAffineQuantization;
|
||||
const size_t num_scales = src_quantization->scale()->size();
|
||||
|
||||
// Ensure that the quantization dimension is valid.
|
||||
if (src_quantization->quantized_dimension() < 0 ||
|
||||
src_quantization->quantized_dimension() >= num_scales) {
|
||||
(!dims.empty() &&
|
||||
src_quantization->quantized_dimension() >= dims.size())) {
|
||||
error_reporter_->Report(
|
||||
"quantized_dimension must be in range [0, %d). Was %d.", num_scales,
|
||||
"quantized_dimension must be in range [0, %d). Was %d.", dims.size(),
|
||||
src_quantization->quantized_dimension());
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// Ensure that the number of scales is 1 for per-layer quantization, and
|
||||
// matches number of quantization dimensions for per-axis quantization.
|
||||
if (num_scales != 1 &&
|
||||
(!dims.empty() &&
|
||||
num_scales != dims[src_quantization->quantized_dimension()])) {
|
||||
error_reporter_->Report(
|
||||
"num_scales must be 1 for per-layer quantization, or %d for per-axis "
|
||||
"quantization, but got %d.",
|
||||
dims[src_quantization->quantized_dimension()], num_scales);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
auto* affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
|
||||
malloc(sizeof(TfLiteAffineQuantization)));
|
||||
affine_quantization->scale = TfLiteFloatArrayCreate(num_scales);
|
||||
@ -429,7 +444,7 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
|
||||
|
||||
const auto* src_quantization = tensor->quantization();
|
||||
TfLiteQuantization quantization;
|
||||
if (ParseQuantization(src_quantization, &quantization) != kTfLiteOk) {
|
||||
if (ParseQuantization(src_quantization, &quantization, dims) != kTfLiteOk) {
|
||||
status = kTfLiteError;
|
||||
continue;
|
||||
}
|
||||
|
@ -210,7 +210,8 @@ class InterpreterBuilder {
|
||||
Subgraph* subgraph);
|
||||
TfLiteStatus ApplyDelegates(Interpreter* interpreter);
|
||||
TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization,
|
||||
TfLiteQuantization* quantization);
|
||||
TfLiteQuantization* quantization,
|
||||
const std::vector<int>& dims);
|
||||
|
||||
const ::tflite::Model* model_;
|
||||
const OpResolver& op_resolver_;
|
||||
|
@ -322,8 +322,7 @@ TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
|
||||
TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
|
||||
float input_scale,
|
||||
const float* weight_scales,
|
||||
int number_of_dimension,
|
||||
int dimension_index) {
|
||||
int number_of_dimension) {
|
||||
// Compute scales.
|
||||
std::vector<float> scales(number_of_dimension);
|
||||
for (size_t i = 0; i < number_of_dimension; i++) {
|
||||
@ -352,9 +351,8 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
|
||||
uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(final_buffer.data());
|
||||
size_t buffer_size = num_elements * sizeof(int32_t);
|
||||
std::vector<int64_t> zero_point(scales.size(), 0);
|
||||
return AddQuantizationParams(scales, zero_point, dimension_index,
|
||||
uint8_buffer, buffer_size, TensorType_INT32,
|
||||
model, tensor);
|
||||
return AddQuantizationParams(scales, zero_point, 0, uint8_buffer, buffer_size,
|
||||
TensorType_INT32, model, tensor);
|
||||
}
|
||||
|
||||
TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,
|
||||
|
@ -89,8 +89,7 @@ TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
|
||||
TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
|
||||
float input_scale,
|
||||
const float* weight_scales,
|
||||
int number_of_dimension,
|
||||
int dimension_index);
|
||||
int number_of_dimension);
|
||||
|
||||
// Quantize weight with or without per channel.
|
||||
TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,
|
||||
|
@ -386,7 +386,7 @@ TEST(QuantizationUtilsTest, SymmetricPerChannelBiasQuantize) {
|
||||
// Call and verify.
|
||||
EXPECT_EQ(SymmetricPerChannelBiasQuantize(
|
||||
model.get(), model->subgraphs[0]->tensors[0].get(), input_scale,
|
||||
weight_scales.data(), 2, 0),
|
||||
weight_scales.data(), 2),
|
||||
kTfLiteOk);
|
||||
EXPECT_THAT(model->buffers[model->subgraphs[0]->tensors[0]->buffer]->data,
|
||||
ElementsAreArray({16, 0, 0, 0, 2, 0, 0, 0}));
|
||||
|
@ -69,7 +69,7 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
|
||||
}
|
||||
return utils::SymmetricPerChannelBiasQuantize(
|
||||
model, bias_tensor, input_tensor->quantization->scale[0],
|
||||
weight_scales.data(), channel_dim_size, channel_dim_index);
|
||||
weight_scales.data(), channel_dim_size);
|
||||
} else {
|
||||
if (weight_scales.size() != 1) {
|
||||
error_reporter->Report(
|
||||
|
Loading…
x
Reference in New Issue
Block a user