diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
index 27ccc7d2b22..d4512509f6b 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
@@ -22,6 +22,7 @@ limitations under the License.
 #include <string>
 
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
@@ -35,6 +36,7 @@ limitations under the License.
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
 
 namespace mlir {
@@ -363,6 +365,55 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
   }
 };
 
+// Folds an extra requantize op (RQ) when the preceding op has no fixed-scale
+// requirement: the requantized type is pushed into the preceding op's result
+// type and the requantize op itself is removed.
+template <typename RQ>
+struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
+  explicit FoldTrivalRequantizeOp(MLIRContext* context)
+      : OpRewritePattern<RQ>(context, 1) {}
+
+  LogicalResult matchAndRewrite(RQ op,
+                                PatternRewriter& rewriter) const override {
+    Value pre_quantized = op.input();
+    auto pre_quantized_type =
+        quant::QuantizedType::getQuantizedElementType(pre_quantized.getType());
+    if (!pre_quantized_type) return failure();
+
+    Operation* def = pre_quantized.getDefiningOp();
+    if (!def) return failure();
+    // Ops that pin their output scale (or aren't quantizable) can't absorb
+    // the requantized type, so leave the requantize op alone.
+    if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() ||
+        def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
+      return failure();
+    }
+
+    op.emitWarning("Remove trivial `rescale` op. Please fix the source graph.");
+
+    // Rebuild the result type list: only the result consumed solely by this
+    // requantize op takes the new quantized type.
+    llvm::SmallVector<Type, 4> new_output_types;
+    for (auto result : def->getResults()) {
+      if (result.hasOneUse() && *result.getUsers().begin() == op) {
+        new_output_types.push_back(op.qtype());
+      } else {
+        new_output_types.push_back(result.getType());
+      }
+    }
+
+    // Remove this rescale op.
+    rewriter.replaceOp(op, {pre_quantized});
+
+    // Replace the output scale of the preceding op by recreating it with the
+    // updated result types (same operands/attributes).
+    rewriter.setInsertionPointAfter(def);
+    OperationState new_state(def->getLoc(), def->getName().getStringRef(),
+                             def->getOperands(), new_output_types,
+                             def->getAttrs());
+    Operation* new_op = rewriter.createOperation(new_state);
+
+    rewriter.replaceOp(def, new_op->getResults());
+    return success();
+  }
+};
+
 // Given a quantized type `input`, magnifying its scales by the factor stored in
 // `factor`. If `input` isn't a quantized type or the `factor` doesn't match the
 // dimension size of `input` or isn't floating-point, nullptr will be returned.
diff --git a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir
index 5377c4fdb98..6573a2f1c36 100644
--- a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir
@@ -19,6 +19,16 @@ func @RemoveUnused(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> (tensor<2xf32>,t
 // CHECK-NEXT: return %[[split]]#0, %[[split]]#1
 }
 
+// CHECK-LABEL: RemoveTrival
+func @RemoveTrival(%arg0: tensor<384x512x!quant.uniform<i8:f32, 1.0>>, %arg1: tensor<128x512x!quant.uniform<i8<-127:127>:f32, 1.0>>, %arg2: none) -> tensor<384x128x!quant.uniform<i8:f32, 2.0>> {
+  %1 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<384x512x!quant.uniform<i8:f32, 1.0>>, tensor<128x512x!quant.uniform<i8<-127:127>:f32, 1.0>>, none) -> tensor<384x128x!quant.uniform<i8:f32, 1.0>>
+  %2 = "tfl.quantize"(%1) {qtype = tensor<384x128x!quant.uniform<i8:f32, 2.0>>} : (tensor<384x128x!quant.uniform<i8:f32, 1.0>>) -> tensor<384x128x!quant.uniform<i8:f32, 2.0>>
+  return %2 : tensor<384x128x!quant.uniform<i8:f32, 2.0>>
+
+// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"{{.*}} -> tensor<384x128x!quant.uniform<i8:f32, 2.0>>
+// CHECK-NEXT: return %[[fc]]
+}
+
 func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
   %cst = constant dense<[1, 1001]> : tensor<2xi32>
   %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>
diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
index 97b7d57dbf4..7954f72046a 100644
--- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
@@ -125,6 +125,7 @@ void PostQuantizePass::runOnFunction() {
   auto func = getFunction();
   auto* ctx = func.getContext();
   TFL::populateWithGenerated(ctx, &patterns);
+  patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
   applyPatternsAndFoldGreedily(func, patterns);
 
   if (!emit_quant_adaptor_ops_) {