Remove a wrong check.

ParseQuantization previously rejected any tensor whose quantized_dimension was >= the
number of scales, conflating the axis index with the scale count. Validate instead that
quantized_dimension names a real axis of the tensor, and that the number of scales is 1
(per-layer quantization) or matches the size of the quantized dimension (per-axis
quantization).

PiperOrigin-RevId: 251688407
Yunlu Li, 2019-06-05 11:56:45 -07:00, committed by TensorFlower Gardener
parent 8b5b49f7d2
commit 799448646a
6 changed files with 27 additions and 14 deletions

@@ -327,7 +327,7 @@ TfLiteStatus InterpreterBuilder::ParseNodes(
 TfLiteStatus InterpreterBuilder::ParseQuantization(
     const QuantizationParameters* src_quantization,
-    TfLiteQuantization* quantization) {
+    TfLiteQuantization* quantization, const std::vector<int>& dims) {
   quantization->type = kTfLiteNoQuantization;
   if (!src_quantization || !src_quantization->scale() ||
       src_quantization->scale()->size() == 0) {
@@ -353,14 +353,29 @@ TfLiteStatus InterpreterBuilder::ParseQuantization(
   // Affine-quantization.
   quantization->type = kTfLiteAffineQuantization;
   const size_t num_scales = src_quantization->scale()->size();
+  // Ensure that the quantization dimension is valid.
   if (src_quantization->quantized_dimension() < 0 ||
-      src_quantization->quantized_dimension() >= num_scales) {
+      (!dims.empty() &&
+       src_quantization->quantized_dimension() >= dims.size())) {
     error_reporter_->Report(
-        "quantized_dimension must be in range [0, %d). Was %d.", num_scales,
+        "quantized_dimension must be in range [0, %d). Was %d.", dims.size(),
         src_quantization->quantized_dimension());
     return kTfLiteError;
   }
+  // Ensure that the number of scales is 1 for per-layer quantization, and
+  // matches number of quantization dimensions for per-axis quantization.
+  if (num_scales != 1 &&
+      (!dims.empty() &&
+       num_scales != dims[src_quantization->quantized_dimension()])) {
+    error_reporter_->Report(
+        "num_scales must be 1 for per-layer quantization, or %d for per-axis "
+        "quantization, but got %d.",
+        dims[src_quantization->quantized_dimension()], num_scales);
+    return kTfLiteError;
+  }
   auto* affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
       malloc(sizeof(TfLiteAffineQuantization)));
   affine_quantization->scale = TfLiteFloatArrayCreate(num_scales);
@@ -429,7 +444,7 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
     const auto* src_quantization = tensor->quantization();
     TfLiteQuantization quantization;
-    if (ParseQuantization(src_quantization, &quantization) != kTfLiteOk) {
+    if (ParseQuantization(src_quantization, &quantization, dims) != kTfLiteOk) {
       status = kTfLiteError;
       continue;
     }
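
The two checks added above are easier to see in isolation. Below is a minimal standalone
sketch of the contract they enforce; QuantizationIsValid, its printf reporting, and the
bare main() are illustrative stand-ins, not TFLite's API.

#include <cstdio>
#include <vector>

// Sketch: `dims` is the tensor shape, `num_scales` the number of scales in
// the flatbuffer, `quantized_dimension` the per-axis channel index.
bool QuantizationIsValid(const std::vector<int>& dims, size_t num_scales,
                         int quantized_dimension) {
  // The quantized dimension must name a real axis of the tensor.
  if (quantized_dimension < 0 ||
      (!dims.empty() &&
       static_cast<size_t>(quantized_dimension) >= dims.size())) {
    std::printf("quantized_dimension must be in range [0, %zu). Was %d.\n",
                dims.size(), quantized_dimension);
    return false;
  }
  // One scale means per-layer quantization; otherwise the scale count must
  // equal the size of the quantized axis (per-axis quantization).
  if (num_scales != 1 && !dims.empty() &&
      num_scales != static_cast<size_t>(dims[quantized_dimension])) {
    std::printf("num_scales must be 1 or %d, but got %zu.\n",
                dims[quantized_dimension], num_scales);
    return false;
  }
  return true;
}

int main() {
  // A [1, 1, 1, 2] tensor quantized per-axis along axis 3 with 2 scales is
  // perfectly valid, but the removed check (quantized_dimension < num_scales,
  // i.e. 3 < 2) rejected it -- hence "remove a wrong check".
  std::printf("%s\n", QuantizationIsValid({1, 1, 1, 2}, 2, 3) ? "ok" : "bad");
  return 0;
}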

@@ -210,7 +210,8 @@ class InterpreterBuilder {
                          Subgraph* subgraph);
   TfLiteStatus ApplyDelegates(Interpreter* interpreter);
   TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization,
-                                 TfLiteQuantization* quantization);
+                                 TfLiteQuantization* quantization,
+                                 const std::vector<int>& dims);
 
   const ::tflite::Model* model_;
   const OpResolver& op_resolver_;

@@ -322,8 +322,7 @@ TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
 TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
                                              float input_scale,
                                              const float* weight_scales,
-                                             int number_of_dimension,
-                                             int dimension_index) {
+                                             int number_of_dimension) {
   // Compute scales.
   std::vector<float> scales(number_of_dimension);
   for (size_t i = 0; i < number_of_dimension; i++) {
@@ -352,9 +351,8 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
   uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(final_buffer.data());
   size_t buffer_size = num_elements * sizeof(int32_t);
   std::vector<int64_t> zero_point(scales.size(), 0);
-  return AddQuantizationParams(scales, zero_point, dimension_index,
-                               uint8_buffer, buffer_size, TensorType_INT32,
-                               model, tensor);
+  return AddQuantizationParams(scales, zero_point, 0, uint8_buffer, buffer_size,
+                               TensorType_INT32, model, tensor);
 }
 
 TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,
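
dimension_index could be replaced with a literal 0 here because a per-channel bias is a
1-D tensor with one element, and one scale, per output channel, so its quantized
dimension is always axis 0. A minimal sketch of the scale computation (not the TFLite
source), assuming the usual convention that each bias scale is the input scale times
that channel's weight scale:

#include <vector>

// Per-channel bias scales: the bias feeding a quantized conv is stored as
// int32 with scale input_scale * weight_scales[i] for output channel i.
// The resulting 1-D tensor is quantized along axis 0 by construction.
std::vector<float> PerChannelBiasScales(float input_scale,
                                        const float* weight_scales,
                                        int number_of_dimension) {
  std::vector<float> scales(number_of_dimension);
  for (int i = 0; i < number_of_dimension; ++i) {
    scales[i] = input_scale * weight_scales[i];
  }
  return scales;
}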

@@ -89,8 +89,7 @@ TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
 TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
                                              float input_scale,
                                              const float* weight_scales,
-                                             int number_of_dimension,
-                                             int dimension_index);
+                                             int number_of_dimension);
 
 // Quantize weight with or without per channel.
 TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,

@@ -386,7 +386,7 @@ TEST(QuantizationUtilsTest, SymmetricPerChannelBiasQuantize) {
   // Call and verify.
   EXPECT_EQ(SymmetricPerChannelBiasQuantize(
                 model.get(), model->subgraphs[0]->tensors[0].get(), input_scale,
-                weight_scales.data(), 2, 0),
+                weight_scales.data(), 2),
             kTfLiteOk);
   EXPECT_THAT(model->buffers[model->subgraphs[0]->tensors[0]->buffer]->data,
               ElementsAreArray({16, 0, 0, 0, 2, 0, 0, 0}));
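
The expected byte array is just the little-endian encoding of the two quantized int32
bias values, 16 and 2 (buffer_size = num_elements * sizeof(int32_t) = 8 bytes for a
two-channel bias). A quick self-contained check:

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // Reinterpret the expected buffer from the test as two int32 values.
  const uint8_t bytes[8] = {16, 0, 0, 0, 2, 0, 0, 0};
  int32_t vals[2];
  std::memcpy(vals, bytes, sizeof(vals));
  std::printf("%d %d\n", vals[0], vals[1]);  // prints "16 2" on little-endian hosts
  return 0;
}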

@@ -69,7 +69,7 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
     }
     return utils::SymmetricPerChannelBiasQuantize(
         model, bias_tensor, input_tensor->quantization->scale[0],
-        weight_scales.data(), channel_dim_size, channel_dim_index);
+        weight_scales.data(), channel_dim_size);
   } else {
     if (weight_scales.size() != 1) {
       error_reporter->Report(
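
The call-site change mirrors the new signature: since a bias tensor's quantized
dimension is always axis 0, QuantizeBias now passes only the channel count, and the
channel_dim_index it previously forwarded (presumably still used on the weight side to
select the scales) no longer reaches the bias quantizer.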