Strengthen the IsQuantized rule to actually check for quantization parameters rather than just looking at the tensor type.

PiperOrigin-RevId: 253598425
Suharsh Sivakumar, 2019-06-17 09:41:00 -07:00, committed by TensorFlower Gardener
parent 568b95a874
commit ed21beb4a2
3 changed files with 3 additions and 18 deletions
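
The body of the new helper is not shown in this diff. As a minimal sketch, assuming the check requires populated scale and zero_point fields (the namespace layout is inferred from the utils:: qualifier in the hunks below; the real implementation may differ):

#include "tensorflow/lite/schema/schema_generated.h"  // tflite::TensorT

namespace tflite {
namespace optimize {
namespace utils {

// Sketch only: a tensor counts as quantized when real quantization
// parameters are present, not merely when its element type is integral.
bool QuantizationParametersExist(const TensorT* tensor) {
  return tensor->quantization != nullptr &&
         !tensor->quantization->scale.empty() &&
         !tensor->quantization->zero_point.empty();
}

}  // namespace utils
}  // namespace optimize
}  // namespace tflite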


@@ -114,10 +114,6 @@ bool HasBuffer(const ModelT* model, const SubGraphT* subgraph,
   return true;
 }
 
-bool IsQuantized(const SubGraphT* subgraph, int tensor_index) {
-  return subgraph->tensors[tensor_index]->type != TensorType_FLOAT32;
-}
-
 bool HasMinMax(const TensorT* tensor) {
   return tensor->quantization && !tensor->quantization->min.empty() &&
          !tensor->quantization->max.empty();
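
Why the type check was too weak: a tensor can have an integer element type while carrying no scale or zero_point at all, and the deleted rule would still report it as quantized. A hypothetical illustration (OldIsQuantized and DemoFalsePositive are our names, not from the tree):

#include "absl/memory/memory.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Copy of the deleted rule, for comparison.
bool OldIsQuantized(const tflite::SubGraphT* subgraph, int tensor_index) {
  return subgraph->tensors[tensor_index]->type != tflite::TensorType_FLOAT32;
}

void DemoFalsePositive() {
  tflite::SubGraphT subgraph;
  auto tensor = absl::make_unique<tflite::TensorT>();
  tensor->type = tflite::TensorType_UINT8;  // integer element type...
  tensor->quantization = nullptr;           // ...but no parameters at all
  subgraph.tensors.push_back(std::move(tensor));
  // Old rule: "quantized" -> the quantizer would skip this tensor even
  // though no kernel could dequantize it. A parameter-based check returns
  // false here, so the tensor still gets processed.
  bool misleading = OldIsQuantized(&subgraph, /*tensor_index=*/0);  // true
  (void)misleading;
}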


@@ -55,16 +55,6 @@ TEST(ModelUtilsTest, HasBuffer) {
   EXPECT_TRUE(HasBuffer(&model, model.subgraphs[0].get(), 0));
 }
 
-TEST(ModelUtilsTest, IsQuantized) {
-  tflite::SubGraphT subgraph;
-  auto tensor = absl::make_unique<tflite::TensorT>();
-  tensor->type = TensorType_UINT8;
-  subgraph.tensors.push_back(std::move(tensor));
-  EXPECT_TRUE(IsQuantized(&subgraph, 0));
-  subgraph.tensors[0]->type = TensorType_FLOAT32;
-  EXPECT_FALSE(IsQuantized(&subgraph, 0));
-}
-
 TEST(ModelUtilsTest, HasMinMax) {
   TensorT tensor;
   tensor.quantization = absl::make_unique<QuantizationParametersT>();
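
The deleted test gets no direct replacement in this commit. A sketch of a parameter-based test, assuming the QuantizationParametersExist signature used in the hunks below and the helper sketch above:

TEST(ModelUtilsTest, QuantizationParametersExist) {
  tflite::TensorT tensor;
  tensor.type = tflite::TensorType_UINT8;
  // An integer element type alone no longer counts as quantized.
  EXPECT_FALSE(QuantizationParametersExist(&tensor));

  tensor.quantization = absl::make_unique<tflite::QuantizationParametersT>();
  tensor.quantization->scale = {0.5f};
  tensor.quantization->zero_point = {128};
  EXPECT_TRUE(QuantizationParametersExist(&tensor));
}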


@@ -382,7 +382,7 @@ TfLiteStatus QuantizeOpInput(
   }
   const int32_t tensor_idx = op->inputs[input_idx];
   TensorT* tensor = subgraph->tensors[tensor_idx].get();
-  const bool is_input_quantized = utils::IsQuantized(subgraph, tensor_idx);
+  const bool is_input_quantized = utils::QuantizationParametersExist(tensor);
   if (property.quantizable && !is_input_quantized) {
     // The operation is quantizable, but the input isn't yet quantized.
     if (utils::HasBuffer(model, subgraph, tensor_idx)) {
@@ -607,10 +607,9 @@ TfLiteStatus QuantizeBiases(ModelT* model, ErrorReporter* error_reporter) {
       }
       // Quantize if it is not quantized already as the
       // output of another op or input of another op.
-      if (!utils::IsQuantized(subgraph, op->inputs[bias_idx])) {
+      TensorT* bias_tensor = subgraph->tensors[op->inputs[bias_idx]].get();
+      if (!utils::QuantizationParametersExist(bias_tensor)) {
         if (utils::HasBuffer(model, subgraph, op->inputs[bias_idx])) {
-          TensorT* bias_tensor =
-              subgraph->tensors[op->inputs[bias_idx]].get();
           if (property.inputs.size() != 2) {
             error_reporter->Report(
                 "Expect the input length of "