Merge pull request #39031 from andrewstevens-infineon:IFX/PR_tfl_converter_QAT_fixes
PiperOrigin-RevId: 314248957
Change-Id: Ieafc9d85b6c24e0f19feb95661384ecfd527357a
commit ef41a8e100
File diff suppressed because it is too large
@@ -180,7 +180,7 @@ int main(int argc, char **argv) {
  if (!module.ok()) return kTrFailure;

  mlir::PassManager pm(&context);
-  applyPassManagerCLOptions(pm);
+  mlir::applyPassManagerCLOptions(pm);

  // Set the quantization specifications from the command line flags.
  mlir::TFL::QuantizationSpecs quant_specs;
@@ -206,6 +206,28 @@ DenseElementsAttr GetShape(Value output_val) {
      llvm::makeArrayRef(shape));
}

+static Type GetShapeStrippedType(TypeAttr type_attr) {
+  auto type = type_attr.getValue();
+  auto shaped_type = type.dyn_cast<ShapedType>();
+  if (shaped_type) {
+    return shaped_type.getElementType();
+  } else {
+    return type;
+  }
+}
+
+bool NotFromQuantOpDifferentQuant(Value val, TypeAttr qtype_attr) {
+  auto val_defn_op = val.getDefiningOp();
+  TFL::QuantizeOp q_op = llvm::dyn_cast_or_null<TFL::QuantizeOp>(val_defn_op);
+  if (!q_op) return true;
+
+  // Ignore shape details - we're really only trying to
+  // check if quantization is the same.
+  auto stripped_src_qtype = GetShapeStrippedType(q_op.qtypeAttr());
+  auto stripped_qtype = GetShapeStrippedType(qtype_attr);
+  return stripped_src_qtype == stripped_qtype;
+}

#include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"

// Fuse Add with preceding FullyConnected.
@@ -33,6 +33,10 @@ def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
class HasRankAtMost<int n> : Constraint<
    CPred<"$0.getType().cast<ShapedType>().getRank() <= " # n>>;

+// Checks that the value is not produced by a TFL_QuantizeOp with a
+// different quantization attribute.
+def NotFromQuantOpDifferentQuant : Constraint<CPred<"NotFromQuantOpDifferentQuant($0,$1)">>;

//===----------------------------------------------------------------------===//
// Ternary ops patterns.
//===----------------------------------------------------------------------===//
@@ -164,7 +168,10 @@ foreach BinaryOp = [TFL_DivOp, TFL_MulOp] in
// This pattern applies when the same quantize/dequantize have been used twice
// with the same scale. We want to remove the redundancy.
// TODO(fengliuai): move this to the sanity check of pre-quantize pass.
-def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>;
+def eliminate_dq_q_pairs : Pat<
+  (TFL_QuantizeOp (TFL_DequantizeOp $in), $qt),
+  (replaceWithValue $in),
+  [(NotFromQuantOpDifferentQuant $in, $qt)]>;


// Constraint that makes sure both operands are the same operands.
@@ -210,6 +210,7 @@ void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) {
  // one is returned directly, we decide to return the quantized result instead,
  // so this op can be quantized. This is only applied on the returned result
  // because the error will not be accumulated.

  func.walk([&](ReturnOp ret) {
    int i = 0;
    for (Value returned : ret.operands()) {
@@ -237,6 +238,51 @@ void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) {
            "Missing quantization parameter on the output might introduce "
            "quantization error!");
  });

+  // Check for (Quant (Dequant $in), $qA) "qdq" pairs that couldn't be
+  // eliminated at this point. This only occurs for the pattern
+  //     (Quant (Dequant (Quant $in, $qB)), $qA)   with $qB != $qA
+  // where the qdq pair denotes a non-trivial requantization of an
+  // already-quantized value. Since this makes little sense (directly quantizing
+  // (Quant $in, $qA) would introduce less quantization noise) the likely cause
+  // is a minor error in constructing the original network model that
+  // introduced back-to-back Fake Quantization operations. Hence: emit a
+  // warning. N.b. at this point we're (temporarily) in the quantization dialect
+  // (presumably to enable re-use in xla etc.); it is quant::*QuantizeCastOp
+  // we're matching here.
+  //
+  func.walk([&](quant::QuantizeCastOp q_op) {
+    // Look for the dequantize op feeding this quantize op, if any.
+    auto dq_op = dyn_cast_or_null<quant::DequantizeCastOp>(
+        q_op.getOperand().getDefiningOp());
+    if (!dq_op) {
+      return;
+    }
+    auto dq_arg = dq_op.getOperand();
+
+    if (!dq_arg.hasOneUse()) {
+      // The initial quantization is used someplace else ... so it might be
+      // reasonable for it to be requantized for another purpose.
+      // TODO: ideally we would still want to check whether the requantization
+      // narrows rather than widens the representation.
+      return;
+    }
+
+    // Invariant:
+    //   isa<quant::QuantizeCastOp>(dq_arg.getDefiningOp()) -->
+    //   dq_arg.getType() != q_op.getResult().getType()
+    //
+    // as otherwise the qdq pair would have been optimized away.
+    auto qd_arg_def_q_op =
+        dyn_cast_or_null<quant::QuantizeCastOp>(dq_arg.getDefiningOp());
+    if (!qd_arg_def_q_op) {
+      return;
+    }
+
+    qd_arg_def_q_op.emitWarning()
+        << " quantizer's output has another quantizer (" << q_op.getLoc()
+        << ") as consumer - intentional?";
+  });
}

using PrepareQuantStats =
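The warning added above rests on the reasoning that requantizing an already-quantized value (two rounding steps) generally loses more information than quantizing the original value once. A small standalone illustration of that effect, as a sketch only (plain C++ with made-up scale/zero-point values; not part of the patch):

#include <cmath>
#include <cstdint>
#include <cstdio>

struct QuantParams { float scale; int zero_point; };

// Uniform affine quantization to uint8 and its inverse.
uint8_t Quantize(float x, QuantParams p) {
  int q = static_cast<int>(std::lround(x / p.scale)) + p.zero_point;
  if (q < 0) q = 0;
  if (q > 255) q = 255;
  return static_cast<uint8_t>(q);
}
float Dequantize(uint8_t q, QuantParams p) {
  return (static_cast<int>(q) - p.zero_point) * p.scale;
}

int main() {
  QuantParams qB{0.013f, 117};  // inner (producing) quantizer
  QuantParams qA{0.021f, 131};  // outer (re-)quantizer

  float x = 0.3456f;
  // Quantizing the original value directly with qA: one rounding step.
  float direct = Dequantize(Quantize(x, qA), qA);
  // Back-to-back: quantize with qB, dequantize, then requantize with qA.
  // This is the (Quant (Dequant (Quant $in, $qB)), $qA) shape the pass warns
  // about; the two rounding steps compound for this sample value.
  float requant = Dequantize(Quantize(Dequantize(Quantize(x, qB), qB), qA), qA);

  std::printf("direct error:  %f\n", std::fabs(x - direct));   // ~0.0096
  std::printf("requant error: %f\n", std::fabs(x - requant));  // ~0.0114
  return 0;
}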
@@ -82,13 +82,48 @@ class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
  bool unfold_batch_matmul_;
};

+template <class TFFakeQuantOp>
+struct FetchConstantMinMaxInputs {
+  using AttrType = DenseFPElementsAttr;
+  bool operator()(TFFakeQuantOp tf_op, AttrType &min_value,
+                  AttrType &max_value) const {
+    Value min = tf_op.min(), max = tf_op.max();
+
+    // TODO: incomplete; neither IdentityN ops nor chains of
+    // Identity* (not rare) are handled.
+    if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp()))
+      min = id1.input();
+    if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp()))
+      max = id2.input();
+    if (!matchPattern(min, m_Constant(&min_value))) {
+      return false;
+    }
+    if (!matchPattern(max, m_Constant(&max_value))) {
+      return false;
+    }
+    return true;  // Successfully matched and fetched.
+  }
+};
+
+template <class TFFakeQuantOp>
+struct FetchMinMaxAttrs {
+  using AttrType = FloatAttr;
+  bool operator()(TFFakeQuantOp tf_op, AttrType &min_value,
+                  AttrType &max_value) const {
+    min_value = tf_op.minAttr();
+    max_value = tf_op.maxAttr();
+    return true;  // Successfully matched and fetched.
+  }
+};
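The two fetcher structs above act as compile-time policies: the rewrite pattern further down is templated on a FetchMinMax functor, so the same pattern body works whether min/max arrive as constant operands (FakeQuantWithMinMaxVars*) or as attributes (FakeQuantWithMinMaxArgs). A minimal standalone sketch of that idiom (plain C++, hypothetical stand-in types, independent of MLIR and of this patch):

#include <iostream>
#include <string>

// Stand-ins for the two kinds of fake-quant "ops" (hypothetical, not TF types).
struct OpWithAttrs { float min_attr = -1.0f, max_attr = 1.0f; };
struct OpWithInputs { float min_input = -6.0f, max_input = 6.0f; };

// Policy structs: each knows how to pull min/max out of "its" op kind.
struct FetchFromAttrs {
  bool operator()(const OpWithAttrs &op, float &min, float &max) const {
    min = op.min_attr;
    max = op.max_attr;
    return true;
  }
};
struct FetchFromInputs {
  bool operator()(const OpWithInputs &op, float &min, float &max) const {
    min = op.min_input;
    max = op.max_input;
    return true;
  }
};

// One rewrite body, parameterized by the op type and the fetch policy.
template <typename OpT, typename FetchMinMax>
struct Rewrite {
  FetchMinMax fetch;
  void operator()(const OpT &op, const std::string &name) const {
    float min, max;
    if (!fetch(op, min, max)) return;
    std::cout << name << ": min=" << min << " max=" << max << "\n";
  }
};

int main() {
  Rewrite<OpWithAttrs, FetchFromAttrs>{}(OpWithAttrs{}, "args-style");
  Rewrite<OpWithInputs, FetchFromInputs>{}(OpWithInputs{}, "vars-style");
  return 0;
}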
// TODO(fengliuai): move this rule to PreparePatterns.td
// TODO(fengliuai): reuse the quantization/tensorflow/tf_to_quant pass.
// TODO(b/140968741): propagate the sign from the command line. Currently all
// the FakeQuant ops are assumed to be targeting UINT8, but the per-channel
// kernel is actually INT8.
// Inserts a "tfl.quantize" and "tfl.dequantize" op pair (QDQs) after the
-// "tf.FakeQuantWithMinMaxVarsOp" to be constant folded. Since the constant
+// tf.FakeQuantWithMinMax{Vars|VarsPerChannel|Args}Op
+// to be constant folded. Since the constant
// folding logic will use a "std.constant" op to replace the
// "tf.FakeQuantWithMinMaxVarsOp", the "tfl.quantize" op is used to preserve
// the quantization parameters as a TypeAttr and "tfl.dequantize" op used to
@@ -112,33 +147,49 @@ class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
//      |
//   tf.dequantize
//      |
-template <typename TFFakeQuantOp, bool PerAxis>
+//
+//
+// Warns if the (most likely unwanted, currently not quite correctly handled)
+// case of back-to-back tf.FakeQuant occurs
+//
+//   tf.FakeQuant*
+//      |
+//   tf.FakeQuant*
+//
+// tf.identity / tf.IdentityN ops between the tf.FakeQuant* ops need no special
+// treatment; they are already eliminated before the rewrites / checks are
+// applied.
+//
+
+template <typename TFFakeQuantOp, bool PerAxis, class FetchMinMax>
struct InsertTFLQuantOpsAfterTFFakeQuantOp
    : public OpRewritePattern<TFFakeQuantOp> {
-  using BaseType = InsertTFLQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis>;
+  using BaseType =
+      InsertTFLQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis, FetchMinMax>;

-  explicit InsertTFLQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis>(
-      MLIRContext *ctx)
+  explicit InsertTFLQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis,
+                                               FetchMinMax>(MLIRContext *ctx)
      : OpRewritePattern<TFFakeQuantOp>(ctx) {}

+  FetchMinMax fetchMinMax;
+
+  using FetchAttrType = typename FetchMinMax::AttrType;
  LogicalResult matchAndRewrite(TFFakeQuantOp tf_op,
                                PatternRewriter &rewriter) const override {
    // We don't want to insert quantize/dequantize if the quantize op exists.
    auto res = tf_op.outputs();
-    if (!res.hasOneUse() || isa<QuantizeOp>(*res.user_begin()))
+    if (!res.hasOneUse() || isa<QuantizeOp>(*res.user_begin())) {
      return failure();
+    }

-    // Extract the min/max constant values from the operands. We also consider
-    // a special case that there are tf.Identity ops between the min/max
-    // constants and the tf.FakeQuantWithMinMaxVarsOp.
-    Value min = tf_op.min(), max = tf_op.max();
-    DenseFPElementsAttr min_value, max_value;
-    if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp()))
-      min = id1.input();
-    if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp()))
-      max = id2.input();
-    if (!matchPattern(min, m_Constant(&min_value))) return failure();
-    if (!matchPattern(max, m_Constant(&max_value))) return failure();
+    FetchAttrType min_value, max_value;
+    if (!fetchMinMax(tf_op, min_value, max_value)) {
+      return failure();
+    }

    int quant_dim = -1;
    if (PerAxis) {
@@ -155,7 +206,9 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
    TypeAttr qtype = quant::GetQuantizedTypeAttr(
        rewriter, res_type, min_value, max_value, quant_dim, num_bits,
        narrow_range, /*is_signed=*/false);
-    if (!qtype) failure();
+    if (!qtype) {
+      return failure();
+    }

    // Finally, use the quantization parameter to create the quantize and
    // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
@@ -172,12 +225,22 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
  }
};

-using PreparePerTensorFakeQuant =
-    InsertTFLQuantOpsAfterTFFakeQuantOp<TF::FakeQuantWithMinMaxVarsOp, false>;
+//
+// Three instances of the rule to cover the three different types of
+// TF::FakeQuant operators.
+//
+using PreparePerTensorFakeQuant = InsertTFLQuantOpsAfterTFFakeQuantOp<
+    TF::FakeQuantWithMinMaxVarsOp, /*PerAxis=*/false,
+    FetchConstantMinMaxInputs<TF::FakeQuantWithMinMaxVarsOp>>;

-using PreparePerChannelFakeQuant =
-    InsertTFLQuantOpsAfterTFFakeQuantOp<TF::FakeQuantWithMinMaxVarsPerChannelOp,
-                                        true>;
+using PreparePerChannelFakeQuant = InsertTFLQuantOpsAfterTFFakeQuantOp<
+    TF::FakeQuantWithMinMaxVarsPerChannelOp, /*PerAxis=*/true,
+    FetchConstantMinMaxInputs<TF::FakeQuantWithMinMaxVarsPerChannelOp>>;
+
+using PreparePerTensorFakeQuantWithMinMaxArgs =
+    InsertTFLQuantOpsAfterTFFakeQuantOp<
+        TF::FakeQuantWithMinMaxArgsOp, /*PerAxis=*/false,
+        FetchMinMaxAttrs<TF::FakeQuantWithMinMaxArgsOp>>;

// Templated class for declaring a converter from some TensorFlow convolution
// op into its counterpart in TensorFlow Lite.
@@ -644,9 +707,10 @@ void PrepareTFPass::runOnFunction() {

  // This pattern was intended to use TFL QDQs to preserve the quantization
  // parameters from the TF Quant ops, thus this pattern should run with the
-  // first `applyPatternsAndFoldGreedily` method, which would otherwise removes
-  // the TF FakeQuant ops by the constant folding.
-  patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
+  // first `applyPatternsGreedily` method, which would otherwise remove the
+  // TF FakeQuant ops by the constant folding.
+  patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant,
+                  PreparePerTensorFakeQuantWithMinMaxArgs>(ctx);

  // This pattern will try to identify and optimize for dilated convolution.
  // e.g. Patterns like "SpaceToBatchND -> Conv2D -> BatchToSpaceND" will be