Set the calibrated quantization parameters to be volatile
The QDQ will be removed if there are not fused after quantization. PiperOrigin-RevId: 354638972 Change-Id: I1a0cd1514f302ec47d1f64fbba4232a354fc9fb6
This commit is contained in:
		
							parent
							
								
									23d9a2b49d
								
							
						
					
					
						commit
						c6de267c22
					
				| @ -152,6 +152,8 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> { | |||||||
|     rewriter.setInsertionPointAfter(op.getOperation()); |     rewriter.setInsertionPointAfter(op.getOperation()); | ||||||
|     Type result_type = quant_type.castFromExpressedType(op.getType()); |     Type result_type = quant_type.castFromExpressedType(op.getType()); | ||||||
|     auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg()); |     auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg()); | ||||||
|  |     q->setAttr(kVolatileOpAttrName, rewriter.getUnitAttr()); | ||||||
|  | 
 | ||||||
|     auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q); |     auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q); | ||||||
|     op.getResult().replaceAllUsesWith(dq); |     op.getResult().replaceAllUsesWith(dq); | ||||||
|     q.getOperation()->replaceUsesOfWith(dq, op.arg()); |     q.getOperation()->replaceUsesOfWith(dq, op.arg()); | ||||||
|  | |||||||
| @ -112,7 +112,7 @@ func @QuantizeWithoutNorm(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {t | |||||||
| // CHECK-SAME: %[[input_9]], %[[input_10]], %[[input_11]], %[[input_12]], %[[input_13]], %[[input_14]], %[[input_15]], %[[input_16]], %[[input_17]], %[[input_18]], %[[input_19]] | // CHECK-SAME: %[[input_9]], %[[input_10]], %[[input_11]], %[[input_12]], %[[input_13]], %[[input_14]], %[[input_15]], %[[input_16]], %[[input_17]], %[[input_18]], %[[input_19]] | ||||||
| // CHECK-SAME: effective_hidden_scale_intermediate = tensor<!quant.uniform<i8:f32, 0.0039215686274509803:-1>> | // CHECK-SAME: effective_hidden_scale_intermediate = tensor<!quant.uniform<i8:f32, 0.0039215686274509803:-1>> | ||||||
| 
 | 
 | ||||||
| // CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>} | // CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>, volatile} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // CHECK-LABEL: QuantizeLstmCifg | // CHECK-LABEL: QuantizeLstmCifg | ||||||
| @ -194,7 +194,7 @@ func @QuantizeLstmCifg(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.e | |||||||
| // CHECK-SAME: input_to_forget_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 4.8829615161595508E-4>> | // CHECK-SAME: input_to_forget_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 4.8829615161595508E-4>> | ||||||
| // CHECK-SAME: input_to_output_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 3.0518509475997192E-5>> | // CHECK-SAME: input_to_output_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 3.0518509475997192E-5>> | ||||||
| 
 | 
 | ||||||
| // CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>} | // CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>, volatile} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // CHECK-LABEL: QuantizeUnidirectionalLstmFull | // CHECK-LABEL: QuantizeUnidirectionalLstmFull | ||||||
| @ -285,7 +285,7 @@ func @QuantizeUnidirectionalLstmFull(%arg0: tensor<1x5xf32>) -> tensor<*xf32> at | |||||||
| // CHECK-SAME: input_to_input_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 9.7659230323191015E-4>> | // CHECK-SAME: input_to_input_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 9.7659230323191015E-4>> | ||||||
| // CHECK-SAME: input_to_output_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 3.0518509475997192E-5>> | // CHECK-SAME: input_to_output_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 3.0518509475997192E-5>> | ||||||
| 
 | 
 | ||||||
| // CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>} | // CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>, volatile} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // CHECK-LABEL: QuantizeLstmFull | // CHECK-LABEL: QuantizeLstmFull | ||||||
| @ -377,7 +377,7 @@ func @QuantizeLstmFull(%arg0: tensor<1x5xf32>) -> tensor<*xf32> attributes {tf.e | |||||||
| // CHECK-SAME: input_to_input_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 9.7659230323191015E-4>> | // CHECK-SAME: input_to_input_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 9.7659230323191015E-4>> | ||||||
| // CHECK-SAME: input_to_output_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 3.0518509475997192E-5>> | // CHECK-SAME: input_to_output_intermediate = tensor<!quant.uniform<i16<-32767:32767>:f32, 3.0518509475997192E-5>> | ||||||
| 
 | 
 | ||||||
| // CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>} | // CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<*x!quant.uniform<i8:f32, 0.015686274509803921:-1>>, volatile} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // CHECK-LABEL: QuantizeSVDF | // CHECK-LABEL: QuantizeSVDF | ||||||
| @ -398,7 +398,7 @@ func @QuantizeSVDF(%arg0: tensor<1x3xf32>) -> tensor<1x2xf32>  { | |||||||
| // CHECK-DAG: %[[input_3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 1.3900876031311922E-5>>) | // CHECK-DAG: %[[input_3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x!quant.uniform<i32:f32, 1.3900876031311922E-5>>) | ||||||
| // CHECK-DAG: %[[input_4:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x4x!quant.uniform<i16<-32767:32767>:f32, 0.0037514108011770368>>) | // CHECK-DAG: %[[input_4:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x4x!quant.uniform<i16<-32767:32767>:f32, 0.0037514108011770368>>) | ||||||
| // CHECK: %[[svdf:.*]] = "tfl.svdf"(%[[input_0]], %[[input_1]], %[[input_2]], %[[input_3]], %[[input_4]]) | // CHECK: %[[svdf:.*]] = "tfl.svdf"(%[[input_0]], %[[input_1]], %[[input_2]], %[[input_3]], %[[input_4]]) | ||||||
| // CHECK: %[[q:.*]] = "tfl.quantize"(%[[svdf]]) {qtype = tensor<1x2x!quant.uniform<i8:f32, 0.12954867493872549:-128>>} | // CHECK: %[[q:.*]] = "tfl.quantize"(%[[svdf]]) {qtype = tensor<1x2x!quant.uniform<i8:f32, 0.12954867493872549:-128>>, volatile} | ||||||
| // CHECK: %[[dq:.*]] = "tfl.dequantize"(%11) | // CHECK: %[[dq:.*]] = "tfl.dequantize"(%11) | ||||||
| // CHECK: return %[[dq]] | // CHECK: return %[[dq]] | ||||||
| } | } | ||||||
|  | |||||||
| @ -49,9 +49,9 @@ func @prepareStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { | |||||||
|   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> |   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> | ||||||
|   return %1 : tensor<8x4x3xf32> |   return %1 : tensor<8x4x3xf32> | ||||||
| 
 | 
 | ||||||
| // CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform<i8:f32, 0.0078431372549019607:-1>>} | // CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform<i8:f32, 0.0078431372549019607:-1>>, volatile} | ||||||
| // CHECK: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) | // CHECK: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) | ||||||
| // CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform<i8:f32:2, {0.0078431372549019607:-1,0.062745098039215685:-1,0.0039215686274509803:-1}>>} | // CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform<i8:f32:2, {0.0078431372549019607:-1,0.062745098039215685:-1,0.0039215686274509803:-1}>>, volatile} | ||||||
| // CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) | // CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) | ||||||
| // CHECK: return %[[dq2]] | // CHECK: return %[[dq2]] | ||||||
| } | } | ||||||
| @ -71,9 +71,9 @@ func @prepareStatisticsNudge(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { | |||||||
|   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> |   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> | ||||||
|   return %1 : tensor<8x4x3xf32> |   return %1 : tensor<8x4x3xf32> | ||||||
| 
 | 
 | ||||||
| // CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>} | // CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>, volatile} | ||||||
| // CHECK: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) | // CHECK: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) | ||||||
| // CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform<i8:f32:2, {0.0078431372549019607:-1,0.031372549019607843:127,0.0039215686274509803:-1}>>} | // CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform<i8:f32:2, {0.0078431372549019607:-1,0.031372549019607843:127,0.0039215686274509803:-1}>>, volatile} | ||||||
| // CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) | // CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) | ||||||
| // CHECK: return %[[dq2]] | // CHECK: return %[[dq2]] | ||||||
| } | } | ||||||
|  | |||||||
| @ -48,9 +48,9 @@ func @prepareStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { | |||||||
|   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> |   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> | ||||||
|   return %1 : tensor<8x4x3xf32> |   return %1 : tensor<8x4x3xf32> | ||||||
| 
 | 
 | ||||||
| // CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform<u8:f32, 0.0078431372549019607:128>>} | // CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform<u8:f32, 0.0078431372549019607:128>>, volatile} | ||||||
| // CHECK: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) | // CHECK: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) | ||||||
| // CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform<u8:f32:2, {0.0078431372549019607:128,0.062745098039215685:128,0.0039215686274509803:128}>>} | // CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform<u8:f32:2, {0.0078431372549019607:128,0.062745098039215685:128,0.0039215686274509803:128}>>, volatile} | ||||||
| // CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) | // CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) | ||||||
| // CHECK: return %[[dq2]] | // CHECK: return %[[dq2]] | ||||||
| } | } | ||||||
|  | |||||||
| @ -132,6 +132,12 @@ struct RemoveVolatileOps : public OpRewritePattern<DequantizeOp> { | |||||||
|     if (auto q = llvm::dyn_cast_or_null<QuantizeOp>(input_op)) { |     if (auto q = llvm::dyn_cast_or_null<QuantizeOp>(input_op)) { | ||||||
|       if (!q->getAttr(mlir::quant::kVolatileOpAttrName)) return failure(); |       if (!q->getAttr(mlir::quant::kVolatileOpAttrName)) return failure(); | ||||||
| 
 | 
 | ||||||
|  |       // Don't remove leading and tailing QDQ for PQT workflow, so the io
 | ||||||
|  |       // modifying lib can work correctly.
 | ||||||
|  |       if (!q.input().getDefiningOp()) return failure(); | ||||||
|  |       if (op->hasOneUse() && op->user_begin()->isKnownTerminator()) | ||||||
|  |         return failure(); | ||||||
|  | 
 | ||||||
|       op.replaceAllUsesWith(q.input()); |       op.replaceAllUsesWith(q.input()); | ||||||
|       return success(); |       return success(); | ||||||
|     } |     } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user