NFC: Uses hasTrait interface to query the op quantization spec

Two of the specs can be migrated directly. Two other op traits for the
quantization spec carry more information and couldn't be migrated to the
hasTrait interface.

PiperOrigin-RevId: 269724215
This commit is contained in:
Feng Liu 2019-09-17 22:01:28 -07:00 committed by TensorFlower Gardener
parent 8ab4a0a7bb
commit d30ba84437
4 changed files with 9 additions and 31 deletions

View File

@ -36,6 +36,7 @@ limitations under the License.
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"
@ -565,7 +566,7 @@ void QuantizationDriver::PreprocessConstantOps() {
// The user doesn't use this value as a bias operand or require same
// scale, then this constant is considered to be a weight.
if (biases.find(operand_num) == biases.end() &&
!spec->requires_same_scale) {
!user->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>()) {
used_as_weight = true;
} else {
bias_users.push_back({user, operand_num});
@ -593,8 +594,9 @@ void QuantizationDriver::SetupAllStates() {
llvm::DenseMap<Value *, int> value_to_state;
fn_.walk([&](Operation *op) {
if (op->isKnownTerminator()) return;
if (!GetQuantSpec(op)->is_quantizable) return;
if (op->isKnownTerminator() ||
op->hasTrait<OpTrait::quant::NoQuantizableResult>())
return;
work_list_.push_back(op);
for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
@ -653,12 +655,6 @@ bool QuantizationDriver::PropagateParams() {
if (llvm::is_contained(quantized_, op)) continue;
quantized_.insert(op);
auto spec = GetQuantSpec(op);
// If the op has no quantizable result, the quantization parameters will not
// be propagated to the results.
if (!spec->is_quantizable) continue;
if (auto cst = llvm::dyn_cast<ConstantOp>(op)) {
// If it isn't a weight or has been quantized, skip.
if (!IsWeight(cst) || IsQuantized(op)) continue;
@ -669,7 +665,7 @@ bool QuantizationDriver::PropagateParams() {
continue;
}
if (spec->requires_same_scale) {
if (op->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>()) {
auto params = GetQuantParamsForSameScaleConstraint(op);
// The quantization parameters haven't been propagated to any operands or
// results. Skip this node for now.
@ -688,6 +684,7 @@ bool QuantizationDriver::PropagateParams() {
}
// TODO(fengliuai): make the bit width configurable.
auto spec = GetQuantSpec(op);
auto key = std::make_pair(8, is_signed_);
auto &restricted_outputs = spec->restricted_output_params[key];
for (int i = 0, e = restricted_outputs.size(); i != e; ++i) {

View File

@ -114,10 +114,7 @@ class AccumulatorUniformScale {
//
template <typename ConcreteType>
class NoQuantizableResult
: public QuantizationSpecTraitBase<ConcreteType, NoQuantizableResult> {
public:
static bool IsQuantizable() { return false; }
};
: public QuantizationSpecTraitBase<ConcreteType, NoQuantizableResult> {};
} // namespace quant
} // namespace OpTrait

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
namespace mlir {
namespace TFL {
@ -40,14 +41,6 @@ using AccumulatorScaleFunc =
// Quantization spec of an op, driving the quantization algorithm.
struct OpQuantSpec {
// Whether the op has quantizable result. This flag is set to false if the op
// has "TFL::NoQuantizableResult" trait.
bool is_quantizable = true;
// Whether it requires same inputs and result scale. This flag is set to true
// if the op has "TFL::SameOperandsAndResultScale" trait.
bool requires_same_scale = false;
// Maps the operand index of a bias input to its quantization specifications,
// including the non-bias operand indexes and the method retrieving
// quantization parameters from list of parameters of the non-bias operands.

View File

@ -58,15 +58,6 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
OUT(2) << "if (auto tfl = llvm::dyn_cast<" << op.getQualCppClassName()
<< ">(op)) {\n";
// There is a "NoQuantizableResult" trait, set the flag.
if (trait.equals("NoQuantizableResult")) {
OUT(4) << "spec->is_quantizable = false;\n";
}
// There is a "SameOperandsAndResultScale" trait, set the flag.
if (trait.equals("SameOperandsAndResultsScale")) {
OUT(4) << "spec->requires_same_scale = true;\n";
}
// There is a "FixedResultUniformScale" trait, set the type for result.
auto trait_str = opTrait->getTrait().str();
if (fixed_uniform_trait_regex.match(trait_str, &matches)) {