diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
index 4e3fda7771e..535d1a9d2d9 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
@@ -36,6 +36,7 @@ limitations under the License.
 #include "mlir/Support/LLVM.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/core/platform/logging.h"
@@ -565,7 +566,7 @@ void QuantizationDriver::PreprocessConstantOps() {
       // The user doesn't use this value as a bias operand or require same
       // scale, then this constant is considered to be a weight.
       if (biases.find(operand_num) == biases.end() &&
-          !spec->requires_same_scale) {
+          !user->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>()) {
         used_as_weight = true;
       } else {
         bias_users.push_back({user, operand_num});
       }
@@ -593,8 +594,9 @@ void QuantizationDriver::SetupAllStates() {
   llvm::DenseMap<Value *, int> value_to_state;

   fn_.walk([&](Operation *op) {
-    if (op->isKnownTerminator()) return;
-    if (!GetQuantSpec(op)->is_quantizable) return;
+    if (op->isKnownTerminator() ||
+        op->hasTrait<OpTrait::quant::NoQuantizableResult>())
+      return;
     work_list_.push_back(op);

     for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
@@ -653,12 +655,6 @@ bool QuantizationDriver::PropagateParams() {
     if (llvm::is_contained(quantized_, op)) continue;
     quantized_.insert(op);

-    auto spec = GetQuantSpec(op);
-
-    // If the op has no quantizable result, the quantization parameters will not
-    // be propagated to the results.
-    if (!spec->is_quantizable) continue;
-
     if (auto cst = llvm::dyn_cast<ConstantOp>(op)) {
       // If it isn't a weight or has been quantized, skip.
       if (!IsWeight(cst) || IsQuantized(op)) continue;
@@ -669,7 +665,7 @@ bool QuantizationDriver::PropagateParams() {
       continue;
     }

-    if (spec->requires_same_scale) {
+    if (op->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>()) {
       auto params = GetQuantParamsForSameScaleConstraint(op);
       // The quantization parameters haven't been propagated to any operands or
       // results. Skip this node for now.
@@ -688,6 +684,7 @@ bool QuantizationDriver::PropagateParams() {
     }

     // TODO(fengliuai): make the bit width configurable.
+    auto spec = GetQuantSpec(op);
     auto key = std::make_pair(8, is_signed_);
     auto &restricted_outputs = spec->restricted_output_params[key];
     for (int i = 0, e = restricted_outputs.size(); i != e; ++i) {
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h
index b64776ddee7..6d5d6555ff0 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h
@@ -114,10 +114,7 @@ class AccumulatorUniformScale {
 //
 template <typename ConcreteType>
 class NoQuantizableResult
-    : public QuantizationSpecTraitBase<ConcreteType, NoQuantizableResult> {
- public:
-  static bool IsQuantizable() { return false; }
-};
+    : public QuantizationSpecTraitBase<ConcreteType, NoQuantizableResult> {};

 }  // namespace quant
 }  // namespace OpTrait
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
index f06a7fee3d0..ade3feca855 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
@@ -28,6 +28,7 @@ limitations under the License.
 #include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
 #include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
 #include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"

 namespace mlir {
 namespace TFL {
@@ -40,14 +41,6 @@ using AccumulatorScaleFunc =

 // Quantization spec of an op, driving the quantization algorithm.
 struct OpQuantSpec {
-  // Whether the op has quantizable result. This flag is set to false if the op
-  // has "TFL::NoQuantizableResult" trait.
-  bool is_quantizable = true;
-
-  // Whether it requires same inputs and result scale. This flag is set to true
-  // if the op has "TFL::SameOperandsAndResultScale" trait.
-  bool requires_same_scale = false;
-
   // Maps the operand index of a bias input to its quantization specifications,
   // including the non-bias operand indexes and the method retrieving
   // quantization parameters from list of parameters of the non-bias operands.
diff --git a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc
index b381a5fa898..090f9713cc3 100644
--- a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc
@@ -58,15 +58,6 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
       OUT(2) << "if (auto tfl = llvm::dyn_cast<" << op.getQualCppClassName()
              << ">(op)) {\n";
-
-      // There is a "NoQuantizableResult" trait, set the flag.
-      if (trait.equals("NoQuantizableResult")) {
-        OUT(4) << "spec->is_quantizable = false;\n";
-      }
-      // There is a "SameOperandsAndResultScale" trait, set the flag.
-      if (trait.equals("SameOperandsAndResultsScale")) {
-        OUT(4) << "spec->requires_same_scale = true;\n";
-      }
       // There is a "FixedResultUniformScale" trait, set the type for result.
       auto trait_str = opTrait->getTrait().str();
       if (fixed_uniform_trait_regex.match(trait_str, &matches)) {