NFC: Uses hasTrait interface to query the op quantization spec
Two of the spec fields can be migrated directly. Two other op traits for the quantization spec carry more information and could not be migrated to the hasTrait interface.

PiperOrigin-RevId: 269724215
This commit is contained in:
parent 8ab4a0a7bb
commit d30ba84437
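In short, the change replaces lookups of boolean flags on the generated OpQuantSpec with direct trait queries on the operation. A minimal sketch of the two styles, using only identifiers that appear in the diff below (the surrounding driver code is assumed):

    // Before: consult a flag on the generated quantization spec.
    auto spec = GetQuantSpec(op);
    if (spec->requires_same_scale) {
      // ... handle ops that require matching operand/result scales ...
    }

    // After: ask the operation directly whether it carries the trait.
    if (op->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>()) {
      // ... same handling, no spec lookup needed ...
    }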
@@ -36,6 +36,7 @@ limitations under the License.
 #include "mlir/Support/LLVM.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/core/platform/logging.h"
 
@@ -565,7 +566,7 @@ void QuantizationDriver::PreprocessConstantOps() {
       // The user doesn't use this value as a bias operand or require same
       // scale, then this constant is considered to be a weight.
       if (biases.find(operand_num) == biases.end() &&
-          !spec->requires_same_scale) {
+          !user->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>()) {
         used_as_weight = true;
       } else {
         bias_users.push_back({user, operand_num});
@@ -593,8 +594,9 @@ void QuantizationDriver::SetupAllStates() {
   llvm::DenseMap<Value *, int> value_to_state;
 
   fn_.walk([&](Operation *op) {
-    if (op->isKnownTerminator()) return;
-    if (!GetQuantSpec(op)->is_quantizable) return;
+    if (op->isKnownTerminator() ||
+        op->hasTrait<OpTrait::quant::NoQuantizableResult>())
+      return;
     work_list_.push_back(op);
 
     for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
@@ -653,12 +655,6 @@ bool QuantizationDriver::PropagateParams() {
     if (llvm::is_contained(quantized_, op)) continue;
     quantized_.insert(op);
 
-    auto spec = GetQuantSpec(op);
-
-    // If the op has no quantizable result, the quantization parameters will not
-    // be propagated to the results.
-    if (!spec->is_quantizable) continue;
-
     if (auto cst = llvm::dyn_cast<ConstantOp>(op)) {
       // If it isn't a weight or has been quantized, skip.
       if (!IsWeight(cst) || IsQuantized(op)) continue;
@@ -669,7 +665,7 @@ bool QuantizationDriver::PropagateParams() {
       continue;
     }
 
-    if (spec->requires_same_scale) {
+    if (op->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>()) {
       auto params = GetQuantParamsForSameScaleConstraint(op);
       // The quantization parameters haven't been propagated to any operands or
       // results. Skip this node for now.
@@ -688,6 +684,7 @@ bool QuantizationDriver::PropagateParams() {
     }
 
     // TODO(fengliuai): make the bit width configurable.
+    auto spec = GetQuantSpec(op);
     auto key = std::make_pair(8, is_signed_);
     auto &restricted_outputs = spec->restricted_output_params[key];
     for (int i = 0, e = restricted_outputs.size(); i != e; ++i) {
@@ -114,10 +114,7 @@ class AccumulatorUniformScale {
 //
 template <typename ConcreteType>
 class NoQuantizableResult
-    : public QuantizationSpecTraitBase<ConcreteType, NoQuantizableResult> {
- public:
-  static bool IsQuantizable() { return false; }
-};
+    : public QuantizationSpecTraitBase<ConcreteType, NoQuantizableResult> {};
 
 }  // namespace quant
 }  // namespace OpTrait
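With the trait body now empty, the IsQuantizable() hook is gone; generic code tests for the trait via Operation::hasTrait, as in the SetupAllStates hunk above. For context, an op would opt into the trait by listing it among its trait template parameters; the op class below is a hypothetical illustration (TFLite ops actually attach traits through their ODS definitions):

    // Hypothetical op carrying the NoQuantizableResult trait.
    class MyNoQuantOp
        : public Op<MyNoQuantOp, OpTrait::quant::NoQuantizableResult> {
      // ... usual op boilerplate elided ...
    };

    // Generic passes can then skip it without knowing the concrete class:
    //   if (op->hasTrait<OpTrait::quant::NoQuantizableResult>()) return;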
@@ -28,6 +28,7 @@ limitations under the License.
 #include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
 #include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
 #include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
 
 namespace mlir {
 namespace TFL {
@@ -40,14 +41,6 @@ using AccumulatorScaleFunc =
 
 // Quantization spec of an op, driving the quantization algorithm.
 struct OpQuantSpec {
-  // Whether the op has quantizable result. This flag is set to false if the op
-  // has "TFL::NoQuantizableResult" trait.
-  bool is_quantizable = true;
-
-  // Whether it requires same inputs and result scale. This flag is set to true
-  // if the op has "TFL::SameOperandsAndResultScale" trait.
-  bool requires_same_scale = false;
-
   // Maps the operand index of a bias input to its quantization specifications,
   // including the non-bias operand indexes and the method retrieving
   // quantization parameters from list of parameters of the non-bias operands.
@@ -58,15 +58,6 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
 
       OUT(2) << "if (auto tfl = llvm::dyn_cast<" << op.getQualCppClassName()
              << ">(op)) {\n";
-
-      // There is a "NoQuantizableResult" trait, set the flag.
-      if (trait.equals("NoQuantizableResult")) {
-        OUT(4) << "spec->is_quantizable = false;\n";
-      }
-      // There is a "SameOperandsAndResultScale" trait, set the flag.
-      if (trait.equals("SameOperandsAndResultsScale")) {
-        OUT(4) << "spec->requires_same_scale = true;\n";
-      }
       // There is a "FixedResultUniformScale" trait, set the type for result.
       auto trait_str = opTrait->getTrait().str();
       if (fixed_uniform_trait_regex.match(trait_str, &matches)) {
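For illustration, the generator above emits a spec getter of roughly this shape (the op class name here is hypothetical); after this change the two boolean-flag assignments are simply no longer emitted, since callers query the traits directly:

    // Sketch of the generated GetOpQuantSpec body; TFL::AddOp is a
    // hypothetical example of a matched op class.
    if (auto tfl = llvm::dyn_cast<TFL::AddOp>(op)) {
      // Previously also emitted here, now dropped:
      //   spec->is_quantizable = false;      // for NoQuantizableResult ops
      //   spec->requires_same_scale = true;  // for SameOperandsAndResultsScale ops
      // The generator still emits entries such as restricted_output_params
      // for ops with the FixedResultUniformScale trait.
    }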