Remove a wrong check.

PiperOrigin-RevId: 251688407
This commit is contained in:
Yunlu Li 2019-06-05 11:56:45 -07:00 committed by TensorFlower Gardener
parent 8b5b49f7d2
commit 799448646a
6 changed files with 27 additions and 14 deletions

View File

@ -327,7 +327,7 @@ TfLiteStatus InterpreterBuilder::ParseNodes(
TfLiteStatus InterpreterBuilder::ParseQuantization(
const QuantizationParameters* src_quantization,
TfLiteQuantization* quantization) {
TfLiteQuantization* quantization, const std::vector<int>& dims) {
quantization->type = kTfLiteNoQuantization;
if (!src_quantization || !src_quantization->scale() ||
src_quantization->scale()->size() == 0) {
@ -353,14 +353,29 @@ TfLiteStatus InterpreterBuilder::ParseQuantization(
// Affine-quantization.
quantization->type = kTfLiteAffineQuantization;
const size_t num_scales = src_quantization->scale()->size();
// Ensure that the quantization dimension is valid.
if (src_quantization->quantized_dimension() < 0 ||
src_quantization->quantized_dimension() >= num_scales) {
(!dims.empty() &&
src_quantization->quantized_dimension() >= dims.size())) {
error_reporter_->Report(
"quantized_dimension must be in range [0, %d). Was %d.", num_scales,
"quantized_dimension must be in range [0, %d). Was %d.", dims.size(),
src_quantization->quantized_dimension());
return kTfLiteError;
}
// Ensure that the number of scales is 1 for per-layer quantization, and
// matches number of quantization dimensions for per-axis quantization.
if (num_scales != 1 &&
(!dims.empty() &&
num_scales != dims[src_quantization->quantized_dimension()])) {
error_reporter_->Report(
"num_scales must be 1 for per-layer quantization, or %d for per-axis "
"quantization, but got %d.",
dims[src_quantization->quantized_dimension()], num_scales);
return kTfLiteError;
}
auto* affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
malloc(sizeof(TfLiteAffineQuantization)));
affine_quantization->scale = TfLiteFloatArrayCreate(num_scales);
@ -429,7 +444,7 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
const auto* src_quantization = tensor->quantization();
TfLiteQuantization quantization;
if (ParseQuantization(src_quantization, &quantization) != kTfLiteOk) {
if (ParseQuantization(src_quantization, &quantization, dims) != kTfLiteOk) {
status = kTfLiteError;
continue;
}

View File

@ -210,7 +210,8 @@ class InterpreterBuilder {
Subgraph* subgraph);
TfLiteStatus ApplyDelegates(Interpreter* interpreter);
TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization,
TfLiteQuantization* quantization);
TfLiteQuantization* quantization,
const std::vector<int>& dims);
const ::tflite::Model* model_;
const OpResolver& op_resolver_;

View File

@ -322,8 +322,7 @@ TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
float input_scale,
const float* weight_scales,
int number_of_dimension,
int dimension_index) {
int number_of_dimension) {
// Compute scales.
std::vector<float> scales(number_of_dimension);
for (size_t i = 0; i < number_of_dimension; i++) {
@ -352,9 +351,8 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(final_buffer.data());
size_t buffer_size = num_elements * sizeof(int32_t);
std::vector<int64_t> zero_point(scales.size(), 0);
return AddQuantizationParams(scales, zero_point, dimension_index,
uint8_buffer, buffer_size, TensorType_INT32,
model, tensor);
return AddQuantizationParams(scales, zero_point, 0, uint8_buffer, buffer_size,
TensorType_INT32, model, tensor);
}
TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,

View File

@ -89,8 +89,7 @@ TfLiteStatus SymmetricPerLayerBiasQuantize(ModelT* model, TensorT* tensor,
TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
float input_scale,
const float* weight_scales,
int number_of_dimension,
int dimension_index);
int number_of_dimension);
// Quantize weight with or without per channel.
TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,

View File

@ -386,7 +386,7 @@ TEST(QuantizationUtilsTest, SymmetricPerChannelBiasQuantize) {
// Call and verify.
EXPECT_EQ(SymmetricPerChannelBiasQuantize(
model.get(), model->subgraphs[0]->tensors[0].get(), input_scale,
weight_scales.data(), 2, 0),
weight_scales.data(), 2),
kTfLiteOk);
EXPECT_THAT(model->buffers[model->subgraphs[0]->tensors[0]->buffer]->data,
ElementsAreArray({16, 0, 0, 0, 2, 0, 0, 0}));

View File

@ -69,7 +69,7 @@ TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
}
return utils::SymmetricPerChannelBiasQuantize(
model, bias_tensor, input_tensor->quantization->scale[0],
weight_scales.data(), channel_dim_size, channel_dim_index);
weight_scales.data(), channel_dim_size);
} else {
if (weight_scales.size() != 1) {
error_reporter->Report(