From c6de267c224890673c1e9300ed718834c335830a Mon Sep 17 00:00:00 2001
From: Feng Liu
Date: Fri, 29 Jan 2021 17:02:48 -0800
Subject: [PATCH] Set the calibrated quantization parameters to be volatile

The QDQ pairs will be removed if they are not fused after quantization.

PiperOrigin-RevId: 354638972
Change-Id: I1a0cd1514f302ec47d1f64fbba4232a354fc9fb6
---
 .../mlir/lite/quantization/quantization_utils.h    |  2 ++
 .../lite/tests/prepare-quantize-post-training.mlir | 10 +++++-----
 .../mlir/lite/tests/prepare-quantize-signed.mlir   |  8 ++++----
 .../compiler/mlir/lite/tests/prepare-quantize.mlir |  4 ++--
 .../compiler/mlir/lite/transforms/post_quantize.cc |  6 ++++++
 5 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
index 1d9c0f06178..01356a35b16 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
@@ -152,6 +152,8 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
     rewriter.setInsertionPointAfter(op.getOperation());
     Type result_type = quant_type.castFromExpressedType(op.getType());
     auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg());
+    q->setAttr(kVolatileOpAttrName, rewriter.getUnitAttr());
+
     auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
     op.getResult().replaceAllUsesWith(dq);
     q.getOperation()->replaceUsesOfWith(dq, op.arg());
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training.mlir
index 46a9876cd03..a79b139639d 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training.mlir
@@ -112,7 +112,7 @@ func @QuantizeWithoutNorm(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {t
 // CHECK-SAME: %[[input_9]], %[[input_10]], %[[input_11]], %[[input_12]], %[[input_13]], %[[input_14]], %[[input_15]], %[[input_16]], %[[input_17]], %[[input_18]], %[[input_19]]
 // CHECK-SAME: effective_hidden_scale_intermediate = tensor>
 
-// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform>}
+// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform>, volatile}
 }
 
 // CHECK-LABEL: QuantizeLstmCifg
@@ -194,7 +194,7 @@ func @QuantizeLstmCifg(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.e
 // CHECK-SAME: input_to_forget_intermediate = tensor:f32, 4.8829615161595508E-4>>
 // CHECK-SAME: input_to_output_intermediate = tensor:f32, 3.0518509475997192E-5>>
 
-// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform>}
+// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform>, volatile}
 }
 
 // CHECK-LABEL: QuantizeUnidirectionalLstmFull
@@ -285,7 +285,7 @@ func @QuantizeUnidirectionalLstmFull(%arg0: tensor<1x5xf32>) -> tensor<*xf32> at
 // CHECK-SAME: input_to_input_intermediate = tensor:f32, 9.7659230323191015E-4>>
 // CHECK-SAME: input_to_output_intermediate = tensor:f32, 3.0518509475997192E-5>>
 
-// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform>}
+// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform>, volatile}
 }
 
 // CHECK-LABEL: QuantizeLstmFull
@@ -377,7 +377,7 @@ func @QuantizeLstmFull(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.e
 // CHECK-SAME: input_to_input_intermediate = tensor:f32, 9.7659230323191015E-4>>
 // CHECK-SAME: input_to_output_intermediate = tensor:f32, 3.0518509475997192E-5>>
 
-// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform>}
+// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform>, volatile}
 }
 
 // CHECK-LABEL: QuantizeSVDF
@@ -398,7 +398,7 @@ func @QuantizeSVDF(%arg0: tensor<1x3xf32>) -> tensor<1x2xf32> {
 // CHECK-DAG: %[[input_3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform>)
 // CHECK-DAG: %[[input_4:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x4x!quant.uniform:f32, 0.0037514108011770368>>)
 // CHECK: %[[svdf:.*]] = "tfl.svdf"(%[[input_0]], %[[input_1]], %[[input_2]], %[[input_3]], %[[input_4]])
-// CHECK: %[[q:.*]] = "tfl.quantize"(%[[svdf]]) {qtype = tensor<1x2x!quant.uniform>}
+// CHECK: %[[q:.*]] = "tfl.quantize"(%[[svdf]]) {qtype = tensor<1x2x!quant.uniform>, volatile}
 // CHECK: %[[dq:.*]] = "tfl.dequantize"(%11)
 // CHECK: return %[[dq]]
 }
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir
index 6288bd1213c..23c837f9dd0 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir
@@ -49,9 +49,9 @@ func @prepareStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %1 : tensor<8x4x3xf32>
 
-// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform>}
+// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform>, volatile}
 // CHECK: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]])
-// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform>}
+// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform>, volatile}
 // CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]])
 // CHECK: return %[[dq2]]
 }
@@ -71,9 +71,9 @@ func @prepareStatisticsNudge(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %1 : tensor<8x4x3xf32>
 
-// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform>}
+// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform>, volatile}
 // CHECK: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]])
-// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform>}
+// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform>, volatile}
 // CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]])
 // CHECK: return %[[dq2]]
 }
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
index c134192f8d2..34db4cf7028 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
@@ -48,9 +48,9 @@ func @prepareStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %1 : tensor<8x4x3xf32>
 
-// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform>}
+// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform>, volatile}
 // CHECK: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]])
-// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform>}
+// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform>, volatile}
 // CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]])
 // CHECK: return %[[dq2]]
 }
diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
index 8ee5556afb1..649c4332a7c 100644
--- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
@@ -132,6 +132,12 @@ struct RemoveVolatileOps : public OpRewritePattern<DequantizeOp> {
     if (auto q = llvm::dyn_cast_or_null<QuantizeOp>(input_op)) {
       if (!q->getAttr(mlir::quant::kVolatileOpAttrName)) return failure();
 
+      // Don't remove leading and trailing QDQ for the PQT workflow, so the
+      // IO modifying library can work correctly.
+      if (!q.input().getDefiningOp()) return failure();
+      if (op->hasOneUse() && op->user_begin()->isKnownTerminator())
+        return failure();
+
       op.replaceAllUsesWith(q.input());
       return success();
     }
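
For illustration, the effect of the quantization_utils.h change can be sketched on a toy example. This is an invented sketch, not one of the tests touched above: the tensor shape, the layerStats values, and the derived scale/zero point are all made up; only the stats-to-QDQ rewrite with the new `volatile` tag reflects the patch.

// Input to prepare-quantize: calibration statistics recorded on a value.
%0 = "quant.stats"(%arg0) {
  layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>

// After ConvertStatsToQDQs with this patch: the quantize op of the
// materialized QDQ pair carries the `volatile` unit attribute
// (kVolatileOpAttrName), marking its parameters as calibrated-only.
%q = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform<u8:f32, 0.0078431372549019607:128>>, volatile}
    : (tensor<8x4x3xf32>) -> tensor<8x4x3x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "tfl.dequantize"(%q)
    : (tensor<8x4x3x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<8x4x3xf32>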
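
On the post_quantize.cc side, RemoveVolatileOps folds away a volatile QDQ pair that was never fused into a quantized op, while the two new checks keep the leading and trailing pairs alive for the post-training-quantization IO flow. A sketch of the intended behavior, again with invented ops, types, and values:

// An internal, unfused volatile QDQ pair: its quantize has a defining
// producer (%0) and its dequantize does not feed the terminator, so
// RemoveVolatileOps bypasses the pair and its users read %0 directly.
%0 = "tfl.relu"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
%q = "tfl.quantize"(%0) {qtype = tensor<4x!quant.uniform<u8:f32, 0.1:128>>, volatile}
    : (tensor<4xf32>) -> tensor<4x!quant.uniform<u8:f32, 0.1:128>>
%dq = "tfl.dequantize"(%q) : (tensor<4x!quant.uniform<u8:f32, 0.1:128>>) -> tensor<4xf32>
%1 = "tfl.relu"(%dq) : (tensor<4xf32>) -> tensor<4xf32>

// By contrast, a quantize whose input is a block argument (a leading QDQ) or
// a dequantize whose only user is the terminator (a trailing QDQ) now makes
// the pattern return failure(), so those pairs survive post_quantize.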