Merge pull request #39031 from andrewstevens-infineon:IFX/PR_tfl_converter_QAT_fixes

PiperOrigin-RevId: 314248957
Change-Id: Ieafc9d85b6c24e0f19feb95661384ecfd527357a
This commit is contained in:
TensorFlower Gardener 2020-06-01 19:12:33 -07:00
commit ef41a8e100
6 changed files with 1350 additions and 25 deletions

File diff suppressed because it is too large

View File

@ -180,7 +180,7 @@ int main(int argc, char **argv) {
if (!module.ok()) return kTrFailure;
mlir::PassManager pm(&context);
applyPassManagerCLOptions(pm);
mlir::applyPassManagerCLOptions(pm);
// Set the quantization specifications from the command line flags.
mlir::TFL::QuantizationSpecs quant_specs;

View File

@ -206,6 +206,28 @@ DenseElementsAttr GetShape(Value output_val) {
llvm::makeArrayRef(shape));
}
static Type GetShapeStrippedType(TypeAttr type_attr) {
auto type = type_attr.getValue();
auto shaped_type = type.dyn_cast<ShapedType>();
if (shaped_type) {
return shaped_type.getElementType();
} else {
return type;
}
}
bool NotFromQuantOpDifferentQuant(Value val, TypeAttr qtype_attr) {
auto val_defn_op = val.getDefiningOp();
TFL::QuantizeOp q_op = llvm::dyn_cast_or_null<TFL::QuantizeOp>(val_defn_op);
if (!q_op) return true;
// Ignore shape details - we're really only trying to
// check if the quantization is the same.
auto stripped_src_qtype = GetShapeStrippedType(q_op.qtypeAttr());
auto stripped_qtype = GetShapeStrippedType(qtype_attr);
return stripped_src_qtype == stripped_qtype;
}
#include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
// Fuse Add with preceding FullyConnected.
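A minimal standalone sketch of the idea behind the helpers added above (GetShapeStrippedType and NotFromQuantOpDifferentQuant): the redundant dequantize/quantize pair may only be removed when the producing quantize op and the enclosing quantize op carry the same quantization parameters, where a shape difference alone does not count as "different quantization". The structs below are hypothetical stand-ins, not MLIR's real TypeAttr/ShapedType API.

#include <iostream>
#include <vector>

// Hypothetical stand-in types; the real helpers operate on mlir::TypeAttr and
// mlir::ShapedType, not on these structs.
struct QuantParams {
  double scale;
  int zero_point;
  bool operator==(const QuantParams &other) const {
    return scale == other.scale && zero_point == other.zero_point;
  }
};

struct QuantizedTensorType {
  std::vector<int> shape;  // ignored by the comparison, like GetShapeStrippedType
  QuantParams params;      // the element-level quantization information we keep
};

// Shape-stripped comparison: the DQ->Q elimination is allowed only when the
// producing quantize op and the enclosing quantize op agree on the
// quantization parameters; shape differences alone do not block it.
bool SameQuantIgnoringShape(const QuantizedTensorType &producer,
                            const QuantizedTensorType &consumer) {
  return producer.params == consumer.params;
}

int main() {
  QuantizedTensorType a{{1, 8}, {0.5, 128}};
  QuantizedTensorType b{{8}, {0.5, 128}};   // same params, different shape
  QuantizedTensorType c{{8}, {0.25, 128}};  // different scale
  std::cout << SameQuantIgnoringShape(a, b) << "\n";  // 1: redundant pair, removable
  std::cout << SameQuantIgnoringShape(a, c) << "\n";  // 0: real requantization, keep
}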

View File

@ -33,6 +33,10 @@ def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
class HasRankAtMost<int n> : Constraint<
CPred<"$0.getType().cast<ShapedType>().getRank() <= " # n>>;
// Checks that the value is not produced by a TFL_Quantize op with a
// different quantization attribute.
def NotFromQuantOpDifferentQuant : Constraint<CPred<"NotFromQuantOpDifferentQuant($0,$1)">>;
//===----------------------------------------------------------------------===//
// Ternary ops patterns.
//===----------------------------------------------------------------------===//
@ -164,7 +168,10 @@ foreach BinaryOp = [TFL_DivOp, TFL_MulOp] in
// This pattern applies when the same quantize/dequantize have been used twice
// with the same scale. We want to remove the redundancy.
// TODO(fengliuai): move this to the sanity check of pre-quantize pass.
def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>;
def eliminate_dq_q_pairs : Pat<
(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt),
(replaceWithValue $in),
[(NotFromQuantOpDifferentQuant $in, $qt)]>;
// Constraint that makes sure both operands are the same operands.

View File

@ -210,6 +210,7 @@ void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) {
// one is returned directly, we decide to return the quantized result instead,
// so this op can be quantized. This is only applied on the returned result
// because the error will not be accumulated.
func.walk([&](ReturnOp ret) {
int i = 0;
for (Value returned : ret.operands()) {
@ -237,6 +238,51 @@ void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) {
"Missing quantization parameter on the output might introduce "
"quantization error!");
});
// Check for (Quant (Dequant $in), $qA) "qdq" pairs that couldn't be
// eliminated at this point. This only occurs for the pattern
// (Quant (Dequant (Quant $in, $qB)), $qA) where $qB != $qA,
// i.e. the qdq pair denotes a non-trivial requantization of an
// already-quantized value. Since this makes little sense (directly quantizing
// as (Quant $in, $qA) would introduce less quantization noise), the likely
// cause is a minor error in constructing the original network model that
// introduced back-to-back Fake Quantization operations. Hence: emit a
// warning. N.b. at this point we're (temporarily) in the quantization dialect
// (presumably to enable reuse in XLA etc.), so it is a quant::*QuantizeCastOp
// that we are matching here.
//
func.walk([&](quant::QuantizeCastOp q_op) {
// Look for a (Quant (Dequant ...)) pair rooted at this quantize op.
auto dq_op = dyn_cast_or_null<quant::DequantizeCastOp>(
q_op.getOperand().getDefiningOp());
if (!dq_op) {
return;
}
auto dq_arg = dq_op.getOperand();
if (!dq_arg.hasOneUse()) {
// The initial quantization is used someplace else as well, so it might be
// reasonable for it to be requantized for another purpose.
// TODO: ideally we would still want to check whether the requantization
// narrows rather than widens the representation.
return;
}
// Invariant:
// isa<quant::QuantizeCastOp>(dq_arg.getDefiningOp()) -->
// dq_arg.getType() != q_op.getResult().getType()
//
// as otherwise the qdq pair would have been optimized away.
auto qd_arg_def_q_op =
dyn_cast_or_null<quant::QuantizeCastOp>(dq_arg.getDefiningOp());
if (!qd_arg_def_q_op) {
return;
}
qd_arg_def_q_op.emitWarning()
<< " quantizer's output has another quantizer (" << q_op.getLoc()
<< ") as consumer - intentional?";
});
}
using PrepareQuantStats =
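A small self-contained numeric sketch of the claim in the warning above (the scales below are arbitrary and unrelated to the pass): requantizing an already-quantized value through a dequantize/quantize pair accumulates more rounding error than quantizing the original value directly, which is why the pass flags back-to-back quantizers instead of silently accepting them.

#include <cmath>
#include <cstdio>

// Uniform quantization with zero point 0, purely for illustration.
long Quantize(double x, double scale) { return std::lround(x / scale); }
double Dequantize(long q, double scale) { return q * scale; }

int main() {
  const double s1 = 0.037, s2 = 0.05;  // two arbitrary, different scales
  double err_requant = 0.0, err_direct = 0.0;
  int n = 0;
  for (double x = -1.0; x <= 1.0; x += 0.001, ++n) {
    // Back-to-back: quantize with s1, dequantize, then requantize with s2.
    double x1 = Dequantize(Quantize(x, s1), s1);
    double requant = Dequantize(Quantize(x1, s2), s2);
    // Direct: quantize the original value with s2 only.
    double direct = Dequantize(Quantize(x, s2), s2);
    err_requant += (x - requant) * (x - requant);
    err_direct += (x - direct) * (x - direct);
  }
  // The requantized path shows the larger RMS error.
  std::printf("RMS error, requantized: %f\n", std::sqrt(err_requant / n));
  std::printf("RMS error, direct:      %f\n", std::sqrt(err_direct / n));
}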

View File

@ -82,13 +82,48 @@ class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
bool unfold_batch_matmul_;
};
template <class TFFakeQuantOp>
struct FetchConstantMinMaxInputs {
using AttrType = DenseFPElementsAttr;
bool operator()(TFFakeQuantOp tf_op, AttrType &min_value,
AttrType &max_value) const {
Value min = tf_op.min(), max = tf_op.max();
// TODO: incomplete; neither IdentityN ops nor chains of Identity* ops
// (which are not rare) are handled.
if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp()))
min = id1.input();
if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp()))
max = id2.input();
if (!matchPattern(min, m_Constant(&min_value))) {
return false;
}
if (!matchPattern(max, m_Constant(&max_value))) {
return false;
}
return true;  // Successfully matched and fetched.
}
};
template <class TFFakeQuantOp>
struct FetchMinMaxAttrs {
using AttrType = FloatAttr;
bool operator()(TFFakeQuantOp tf_op, AttrType &min_value,
AttrType &max_value) const {
min_value = tf_op.minAttr();
max_value = tf_op.maxAttr();
return true;  // Successfully matched and fetched.
}
};
// TODO(fengliuai): move this rule to PreparePatterns.td
// TODO(fengliuai): reuse the quantization/tensorflow/tf_to_quant pass.
// TODO(b/140968741): propagate the sign from the command line. Currently all
// FakeQuant ops are assumed to target UINT8, but the per-channel kernel is
// actually INT8.
// Inserts a "tfl.quantize" and "tfl.dequantize" op pair (QDQs) after the
// "tf.FakeQuantWithMinMaxVarsOp" to be constant folded. Since the constant
// tf.FakeQuantWithMinMax{Vars|VarsPerChannel|Args}Op
// to be constant folded. Since the constant
// folding logic will use a "std.constant" op to replace the
// "tf.FakeQuantWithMinMaxVarsOp", the "tfl.quantize" op is used to preserve
// the quantization parameters as a TypeAttr and "tfl.dequantize" op used to
@ -112,33 +147,49 @@ class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
// |
// tf.dequantize
// |
template <typename TFFakeQuantOp, bool PerAxis>
//
//
// Warns if the (most likely unwanted, currently not quite correctly handled)
// case of back-to-back tf.FakeQuant occurs
//
// tf.FakeQuant*
// |
// tf.FakeQuant*
//
// tf.identity / tf.IdentityN ops between the tf.FakeQuant* ops need no
// special treatment: they are already eliminated before the rewrites / check
// is applied.
//
template <typename TFFakeQuantOp, bool PerAxis, class FetchMinMax>
struct InsertTFLQuantOpsAfterTFFakeQuantOp
: public OpRewritePattern<TFFakeQuantOp> {
using BaseType = InsertTFLQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis>;
using BaseType =
InsertTFLQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis, FetchMinMax>;
explicit InsertTFLQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis>(
MLIRContext *ctx)
explicit InsertTFLQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis,
FetchMinMax>(MLIRContext *ctx)
: OpRewritePattern<TFFakeQuantOp>(ctx) {}
FetchMinMax fetchMinMax;
using FetchAttrType = typename FetchMinMax::AttrType;
LogicalResult matchAndRewrite(TFFakeQuantOp tf_op,
PatternRewriter &rewriter) const override {
// We don't want to insert quantize/dequantize if the quantize op exists.
auto res = tf_op.outputs();
if (!res.hasOneUse() || isa<QuantizeOp>(*res.user_begin()))
if (!res.hasOneUse() || isa<QuantizeOp>(*res.user_begin())) {
return failure();
}
// Extract the min/max constant values from the operands. We also consider
// a special case in which there are tf.Identity ops between the min/max
// constants and the tf.FakeQuantWithMinMaxVarsOp.
Value min = tf_op.min(), max = tf_op.max();
DenseFPElementsAttr min_value, max_value;
if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp()))
min = id1.input();
if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp()))
max = id2.input();
if (!matchPattern(min, m_Constant(&min_value))) return failure();
if (!matchPattern(max, m_Constant(&max_value))) return failure();
FetchAttrType min_value, max_value;
if (!fetchMinMax(tf_op, min_value, max_value)) {
return failure();
}
int quant_dim = -1;
if (PerAxis) {
@ -155,7 +206,9 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
TypeAttr qtype = quant::GetQuantizedTypeAttr(
rewriter, res_type, min_value, max_value, quant_dim, num_bits,
narrow_range, /*is_signed=*/false);
if (!qtype) failure();
if (!qtype) {
return failure();
}
// Finally, use the quantization parameter to create the quantize and
// dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
@ -172,12 +225,22 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
}
};
using PreparePerTensorFakeQuant =
InsertTFLQuantOpsAfterTFFakeQuantOp<TF::FakeQuantWithMinMaxVarsOp, false>;
//
// Three instances of the rule to cover the three different types of
// TF::FakeQuant operators
//
using PreparePerTensorFakeQuant = InsertTFLQuantOpsAfterTFFakeQuantOp<
TF::FakeQuantWithMinMaxVarsOp, /*PerAxis=*/false,
FetchConstantMinMaxInputs<TF::FakeQuantWithMinMaxVarsOp>>;
using PreparePerChannelFakeQuant =
InsertTFLQuantOpsAfterTFFakeQuantOp<TF::FakeQuantWithMinMaxVarsPerChannelOp,
true>;
using PreparePerChannelFakeQuant = InsertTFLQuantOpsAfterTFFakeQuantOp<
TF::FakeQuantWithMinMaxVarsPerChannelOp, /*PerAxis=*/true,
FetchConstantMinMaxInputs<TF::FakeQuantWithMinMaxVarsPerChannelOp>>;
using PreparePerTensorFakeQuantWithMinMaxArgs =
InsertTFLQuantOpsAfterTFFakeQuantOp<
TF::FakeQuantWithMinMaxArgsOp, /*PerAxis=*/false,
FetchMinMaxAttrs<TF::FakeQuantWithMinMaxArgsOp>>;
// Templated class for declaring a converter from some TensorFlow convolution
// op into its counterpart in TensorFlow Lite.
@ -644,9 +707,10 @@ void PrepareTFPass::runOnFunction() {
// This pattern was intended to use TFL QDQs to preserve the quantization
// parameters from the TF Quant ops, thus this pattern should run with the
// first `applyPatternsAndFoldGreedily` method, which would otherwise removes
// the TF FakeQuant ops by the constant folding.
patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
// first `applyPatternsGreedily` method, which would otherwise remove the
// TF FakeQuant ops by the constant folding.
patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant,
PreparePerTensorFakeQuantWithMinMaxArgs>(ctx);
// This pattern will try to identify and optimize for dilated convolution.
// e.g. Patterns like "SpaceToBatchND -> Conv2D -> BatchToSpaceND" will be
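A note on the structure introduced in prepare_tf.cc above: InsertTFLQuantOpsAfterTFFakeQuantOp is now parameterized by a small policy functor (FetchConstantMinMaxInputs or FetchMinMaxAttrs) that knows how to obtain the min/max range for one FakeQuant variant. A minimal standalone sketch of that shape, using hypothetical stand-in types rather than the real MLIR ops and attributes:

#include <iostream>

// Hypothetical stand-in "ops"; the real code matches TF::FakeQuant* MLIR ops.
struct FakeQuantWithVars { double min_operand, max_operand; };
struct FakeQuantWithArgs { double min_attr, max_attr; };

// Policy A: read min/max from (constant) operands.
struct FetchFromOperands {
  using AttrType = double;
  bool operator()(const FakeQuantWithVars &op, AttrType &mn, AttrType &mx) const {
    mn = op.min_operand;
    mx = op.max_operand;
    return true;  // successfully fetched
  }
};

// Policy B: read min/max from op attributes.
struct FetchFromAttrs {
  using AttrType = double;
  bool operator()(const FakeQuantWithArgs &op, AttrType &mn, AttrType &mx) const {
    mn = op.min_attr;
    mx = op.max_attr;
    return true;  // successfully fetched
  }
};

// The rewrite skeleton depends only on the policy's interface, so a single
// template covers every FakeQuant variant.
template <class OpT, class FetchMinMax>
struct InsertQdq {
  FetchMinMax fetch;
  void apply(const OpT &op) const {
    typename FetchMinMax::AttrType mn, mx;
    if (!fetch(op, mn, mx)) return;  // range could not be determined
    std::cout << "insert QDQ with range [" << mn << ", " << mx << "]\n";
  }
};

int main() {
  InsertQdq<FakeQuantWithVars, FetchFromOperands>{}.apply({-1.0, 1.0});
  InsertQdq<FakeQuantWithArgs, FetchFromAttrs>{}.apply({0.0, 6.0});
}

Swapping the policy is what lets the same rewrite cover tf.FakeQuantWithMinMaxArgs, whose range lives in attributes rather than in constant operands.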