diff --git a/tensorflow/lite/model.cc b/tensorflow/lite/model.cc
index 9f55d25f8b7..890e5118409 100644
--- a/tensorflow/lite/model.cc
+++ b/tensorflow/lite/model.cc
@@ -327,7 +327,7 @@ TfLiteStatus InterpreterBuilder::ParseNodes(
 TfLiteStatus InterpreterBuilder::ParseQuantization(
     const QuantizationParameters* src_quantization,
-    TfLiteQuantization* quantization) {
+    TfLiteQuantization* quantization, const std::vector<int>& dims) {
   quantization->type = kTfLiteNoQuantization;
   if (!src_quantization || !src_quantization->scale() ||
       src_quantization->scale()->size() == 0) {
@@ -353,14 +353,29 @@ TfLiteStatus InterpreterBuilder::ParseQuantization(
   // Affine-quantization.
   quantization->type = kTfLiteAffineQuantization;
   const size_t num_scales = src_quantization->scale()->size();
+
+  // Ensure that the quantization dimension is valid.
   if (src_quantization->quantized_dimension() < 0 ||
-      src_quantization->quantized_dimension() >= num_scales) {
+      (!dims.empty() &&
+       src_quantization->quantized_dimension() >= dims.size())) {
     error_reporter_->Report(
-        "quantized_dimension must be in range [0, %d). Was %d.", num_scales,
+        "quantized_dimension must be in range [0, %d). Was %d.", dims.size(),
         src_quantization->quantized_dimension());
     return kTfLiteError;
   }
+  // Ensure that the number of scales is 1 for per-layer quantization, and
+  // matches number of quantization dimensions for per-axis quantization.
+  if (num_scales != 1 &&
+      (!dims.empty() &&
+       num_scales != dims[src_quantization->quantized_dimension()])) {
+    error_reporter_->Report(
+        "num_scales must be 1 for per-layer quantization, or %d for per-axis "
+        "quantization, but got %d.",
+        dims[src_quantization->quantized_dimension()], num_scales);
+    return kTfLiteError;
+  }
+
   auto* affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
       malloc(sizeof(TfLiteAffineQuantization)));
   affine_quantization->scale = TfLiteFloatArrayCreate(num_scales);
@@ -429,7 +444,7 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
     const auto* src_quantization = tensor->quantization();
     TfLiteQuantization quantization;
-    if (ParseQuantization(src_quantization, &quantization) != kTfLiteOk) {
+    if (ParseQuantization(src_quantization, &quantization, dims) != kTfLiteOk) {
       status = kTfLiteError;
       continue;
     }
diff --git a/tensorflow/lite/model.h b/tensorflow/lite/model.h
index e936010f191..6c569470f34 100644
--- a/tensorflow/lite/model.h
+++ b/tensorflow/lite/model.h
@@ -210,7 +210,8 @@ class InterpreterBuilder {
                            Subgraph* subgraph);
   TfLiteStatus ApplyDelegates(Interpreter* interpreter);
   TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization,
-                                 TfLiteQuantization* quantization);
+                                 TfLiteQuantization* quantization,
+                                 const std::vector<int>& dims);
 
   const ::tflite::Model* model_;
   const OpResolver& op_resolver_;
diff --git a/tensorflow/lite/tools/optimize/quantization_utils.cc b/tensorflow/lite/tools/optimize/quantization_utils.cc
index f26a1b5838e..f8012b024ae 100644
--- a/tensorflow/lite/tools/optimize/quantization_utils.cc
+++ b/tensorflow/lite/tools/optimize/quantization_utils.cc
@@ -322,8 +322,7 @@ TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
 TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
                                              float input_scale,
                                              const float* weight_scales,
-                                             int number_of_dimension,
-                                             int dimension_index) {
+                                             int number_of_dimension) {
   // Compute scales.
   std::vector<float> scales(number_of_dimension);
   for (size_t i = 0; i < number_of_dimension; i++) {
@@ -352,9 +351,8 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
   uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(final_buffer.data());
   size_t buffer_size = num_elements * sizeof(int32_t);
   std::vector<int64_t> zero_point(scales.size(), 0);
-  return AddQuantizationParams(scales, zero_point, dimension_index,
-                               uint8_buffer, buffer_size, TensorType_INT32,
-                               model, tensor);
+  return AddQuantizationParams(scales, zero_point, 0, uint8_buffer, buffer_size,
+                               TensorType_INT32, model, tensor);
 }
 
 TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,
diff --git a/tensorflow/lite/tools/optimize/quantization_utils.h b/tensorflow/lite/tools/optimize/quantization_utils.h
index 4cc67cfe40a..7a37614d990 100644
--- a/tensorflow/lite/tools/optimize/quantization_utils.h
+++ b/tensorflow/lite/tools/optimize/quantization_utils.h
@@ -89,8 +89,7 @@ TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
 TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
                                              float input_scale,
                                              const float* weight_scales,
-                                             int number_of_dimension,
-                                             int dimension_index);
+                                             int number_of_dimension);
 
 // Quantize weight with or without per channel.
 TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,
diff --git a/tensorflow/lite/tools/optimize/quantization_utils_test.cc b/tensorflow/lite/tools/optimize/quantization_utils_test.cc
index d30df1a47ff..c50566ab8e8 100644
--- a/tensorflow/lite/tools/optimize/quantization_utils_test.cc
+++ b/tensorflow/lite/tools/optimize/quantization_utils_test.cc
@@ -386,7 +386,7 @@ TEST(QuantizationUtilsTest, SymmetricPerChannelBiasQuantize) {
   // Call and verify.
   EXPECT_EQ(SymmetricPerChannelBiasQuantize(
                 model.get(), model->subgraphs[0]->tensors[0].get(), input_scale,
-                weight_scales.data(), 2, 0),
+                weight_scales.data(), 2),
             kTfLiteOk);
   EXPECT_THAT(model->buffers[model->subgraphs[0]->tensors[0]->buffer]->data,
               ElementsAreArray({16, 0, 0, 0, 2, 0, 0, 0}));
diff --git a/tensorflow/lite/tools/optimize/quantize_model.cc b/tensorflow/lite/tools/optimize/quantize_model.cc
index 6b6ed6cd9dd..7966f3a0763 100644
--- a/tensorflow/lite/tools/optimize/quantize_model.cc
+++ b/tensorflow/lite/tools/optimize/quantize_model.cc
@@ -69,7 +69,7 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
     }
     return utils::SymmetricPerChannelBiasQuantize(
         model, bias_tensor, input_tensor->quantization->scale[0],
-        weight_scales.data(), channel_dim_size, channel_dim_index);
+        weight_scales.data(), channel_dim_size);
   } else {
     if (weight_scales.size() != 1) {
       error_reporter->Report(
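For context, the sketch below distills the rule the `ParseQuantization` change enforces: `quantized_dimension` must index into the tensor's shape, and the number of scales must be 1 (per-layer quantization) or equal to the size of the quantized dimension (per-axis quantization). It is a minimal standalone illustration, not TFLite code; the `QuantSpec` struct and `ValidateQuantization` helper are hypothetical names introduced here.

```cpp
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the scale/axis data read from the flatbuffer's
// QuantizationParameters.
struct QuantSpec {
  std::vector<float> scales;  // One entry (per-layer) or one per channel.
  int quantized_dimension;    // Axis of the tensor shape the scales apply to.
};

// Mirrors the two checks added to ParseQuantization: the quantized dimension
// must index into `dims`, and the scale count must be 1 (per-layer) or
// dims[quantized_dimension] (per-axis). As in the patch, an empty `dims`
// skips the shape-dependent checks.
bool ValidateQuantization(const QuantSpec& spec, const std::vector<int>& dims) {
  const size_t num_scales = spec.scales.size();
  if (spec.quantized_dimension < 0 ||
      (!dims.empty() &&
       static_cast<size_t>(spec.quantized_dimension) >= dims.size())) {
    std::fprintf(stderr,
                 "quantized_dimension must be in range [0, %zu). Was %d.\n",
                 dims.size(), spec.quantized_dimension);
    return false;
  }
  if (num_scales != 1 && !dims.empty() &&
      num_scales != static_cast<size_t>(dims[spec.quantized_dimension])) {
    std::fprintf(stderr,
                 "num_scales must be 1 for per-layer quantization, or %d for "
                 "per-axis quantization, but got %zu.\n",
                 dims[spec.quantized_dimension], num_scales);
    return false;
  }
  return true;
}

int main() {
  // Shape of a hypothetical conv filter with 3 output channels on axis 0.
  const std::vector<int> dims = {3, 4, 4, 2};
  const QuantSpec per_layer{{0.5f}, 0};                // OK: one scale.
  const QuantSpec per_axis{{0.5f, 0.25f, 0.125f}, 0};  // OK: 3 scales, 3 channels.
  const QuantSpec bad{{0.5f, 0.25f}, 0};               // Rejected: 2 scales, 3 channels.
  std::printf("%d %d %d\n", ValidateQuantization(per_layer, dims),
              ValidateQuantization(per_axis, dims),
              ValidateQuantization(bad, dims));  // Prints "1 1 0".
}
```

The same rule explains the `SymmetricPerChannelBiasQuantize` change: a bias vector is one-dimensional, so its quantized dimension is always 0, which is why the `dimension_index` parameter could be dropped and `AddQuantizationParams` called with a hardcoded 0.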