Set the calibrated quantization parameters to be volatile
The QDQ will be removed if they are not fused after quantization. PiperOrigin-RevId: 354638972 Change-Id: I1a0cd1514f302ec47d1f64fbba4232a354fc9fb6
This commit is contained in:
parent
23d9a2b49d
commit
c6de267c22
@ -152,6 +152,8 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
|
||||
rewriter.setInsertionPointAfter(op.getOperation());
|
||||
Type result_type = quant_type.castFromExpressedType(op.getType());
|
||||
auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg());
|
||||
q->setAttr(kVolatileOpAttrName, rewriter.getUnitAttr());
|
||||
|
||||
auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
|
||||
op.getResult().replaceAllUsesWith(dq);
|
||||
q.getOperation()->replaceUsesOfWith(dq, op.arg());
|
||||
|
@ -112,7 +112,7 @@ func @QuantizeWithoutNorm(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {t
|
||||
// CHECK-SAME: %[[input_9]], %[[input_10]], %[[input_11]], %[[input_12]], %[[input_13]], %[[input_14]], %[[input_15]], %[[input_16]], %[[input_17]], %[[input_18]], %[[input_19]]
|
||||
// CHECK-SAME: effective_hidden_scale_intermediate = tensor<!quant.uniform<i8:f32, 0.0039215686274509803:-1>>
|
||||
|
||||
// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>}
|
||||
// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>, volatile}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeLstmCifg
|
||||
@ -194,7 +194,7 @@ func @QuantizeLstmCifg(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.e
|
||||
// CHECK-SAME: input_to_forget_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 4.8829615161595508E-4>>
|
||||
// CHECK-SAME: input_to_output_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 3.0518509475997192E-5>>
|
||||
|
||||
// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>}
|
||||
// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>, volatile}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeUnidirectionalLstmFull
|
||||
@ -285,7 +285,7 @@ func @QuantizeUnidirectionalLstmFull(%arg0: tensor<1x5xf32>) -> tensor<*xf32> at
|
||||
// CHECK-SAME: input_to_input_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 9.7659230323191015E-4>>
|
||||
// CHECK-SAME: input_to_output_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 3.0518509475997192E-5>>
|
||||
|
||||
// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>}
|
||||
// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>, volatile}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeLstmFull
|
||||
@ -377,7 +377,7 @@ func @QuantizeLstmFull(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.e
|
||||
// CHECK-SAME: input_to_input_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 9.7659230323191015E-4>>
|
||||
// CHECK-SAME: input_to_output_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 3.0518509475997192E-5>>
|
||||
|
||||
// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>}
|
||||
// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>, volatile}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeSVDF
|
||||
@ -398,7 +398,7 @@ func @QuantizeSVDF(%arg0: tensor<1x3xf32>) -> tensor<1x2xf32> {
|
||||
// CHECK-DAG: %[[input_3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 1.3900876031311922E-5>>)
|
||||
// CHECK-DAG: %[[input_4:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x4x!quant.uniform<i16<-32767:32767>:f32, 0.0037514108011770368>>)
|
||||
// CHECK: %[[svdf:.*]] = "tfl.svdf"(%[[input_0]], %[[input_1]], %[[input_2]], %[[input_3]], %[[input_4]])
|
||||
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[svdf]]) {qtype = tensor<1x2x!quant.uniform<i8:f32, 0.12954867493872549:-128>>}
|
||||
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[svdf]]) {qtype = tensor<1x2x!quant.uniform<i8:f32, 0.12954867493872549:-128>>, volatile}
|
||||
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%11)
|
||||
// CHECK: return %[[dq]]
|
||||
}
|
||||
|
@ -49,9 +49,9 @@ func @prepareStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
||||
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
|
||||
return %1 : tensor<8x4x3xf32>
|
||||
|
||||
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform<i8:f32, 0.0078431372549019607:-1>>}
|
||||
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform<i8:f32, 0.0078431372549019607:-1>>, volatile}
|
||||
// CHECK: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]])
|
||||
// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform<i8:f32:2, {0.0078431372549019607:-1,0.062745098039215685:-1,0.0039215686274509803:-1}>>}
|
||||
// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform<i8:f32:2, {0.0078431372549019607:-1,0.062745098039215685:-1,0.0039215686274509803:-1}>>, volatile}
|
||||
// CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]])
|
||||
// CHECK: return %[[dq2]]
|
||||
}
|
||||
@ -71,9 +71,9 @@ func @prepareStatisticsNudge(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
||||
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
|
||||
return %1 : tensor<8x4x3xf32>
|
||||
|
||||
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>}
|
||||
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>, volatile}
|
||||
// CHECK: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]])
|
||||
// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform<i8:f32:2, {0.0078431372549019607:-1,0.031372549019607843:127,0.0039215686274509803:-1}>>}
|
||||
// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform<i8:f32:2, {0.0078431372549019607:-1,0.031372549019607843:127,0.0039215686274509803:-1}>>, volatile}
|
||||
// CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]])
|
||||
// CHECK: return %[[dq2]]
|
||||
}
|
||||
|
@ -48,9 +48,9 @@ func @prepareStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
||||
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
|
||||
return %1 : tensor<8x4x3xf32>
|
||||
|
||||
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
|
||||
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform<u8:f32, 0.0078431372549019607:128>>, volatile}
|
||||
// CHECK: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]])
|
||||
// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform<u8:f32:2, {0.0078431372549019607:128,0.062745098039215685:128,0.0039215686274509803:128}>>}
|
||||
// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform<u8:f32:2, {0.0078431372549019607:128,0.062745098039215685:128,0.0039215686274509803:128}>>, volatile}
|
||||
// CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]])
|
||||
// CHECK: return %[[dq2]]
|
||||
}
|
||||
|
@ -132,6 +132,12 @@ struct RemoveVolatileOps : public OpRewritePattern<DequantizeOp> {
|
||||
if (auto q = llvm::dyn_cast_or_null<QuantizeOp>(input_op)) {
|
||||
if (!q->getAttr(mlir::quant::kVolatileOpAttrName)) return failure();
|
||||
|
||||
// Don't remove leading and tailing QDQ for PQT workflow, so the io
|
||||
// modifying lib can work correctly.
|
||||
if (!q.input().getDefiningOp()) return failure();
|
||||
if (op->hasOneUse() && op->user_begin()->isKnownTerminator())
|
||||
return failure();
|
||||
|
||||
op.replaceAllUsesWith(q.input());
|
||||
return success();
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user