Fixed a bug where RemoveQuantizationAdaptorOps would remove a dequantize op that fed a return even if that op had other uses.
PiperOrigin-RevId: 319074130
Change-Id: I2b682bcfdfdc4b496711da3ab53bea5975cc8c54
parent 78e1d0f299 · commit 4e004c5e4a
@@ -63,3 +63,17 @@ func @main2(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>) -> tensor<2x4xf32> {
 // CHECK-NEXT: %[[add:.*]] = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
 // CHECK-NEXT: return %[[add]] : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
 // CHECK-NEXT:}
+
+// CHECK-LABEL: HandleReturnedDequantizeWithAnotherUse
+func @HandleReturnedDequantizeWithAnotherUse(%arg0: tensor<128x16xf32>) -> (tensor<128x16xf32>, tensor<128xi32>) {
+// CHECK-NEXT: %[[cst:.*]] = constant dense<1> : tensor<i32>
+  %cst = constant dense<1> : tensor<i32>
+// CHECK-NEXT: %[[softmax:.*]] = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<128x16xf32>) -> tensor<128x16xf32>
+  %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<128x16xf32>) -> tensor<128x16xf32>
+  %1 = "tfl.quantize"(%0) {qtype = tensor<128x16x!quant.uniform<u8:f32, 3.906250e-03>>, volatile} : (tensor<128x16xf32>) -> tensor<128x16x!quant.uniform<u8:f32, 3.906250e-03>>
+  %2 = "tfl.dequantize"(%1) : (tensor<128x16x!quant.uniform<u8:f32, 3.906250e-03>>) -> tensor<128x16xf32>
+// CHECK-NEXT: %[[argmax:.*]] = "tfl.arg_max"(%[[softmax]], %[[cst]]) : (tensor<128x16xf32>, tensor<i32>) -> tensor<128xi32>
+  %3 = "tfl.arg_max"(%2, %cst) : (tensor<128x16xf32>, tensor<i32>) -> tensor<128xi32>
+// CHECK-NEXT: return %[[softmax]], %[[argmax]] : tensor<128x16xf32>, tensor<128xi32>
+  return %2, %3 : tensor<128x16xf32>, tensor<128xi32>
+}
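Taken together, the CHECK directives in the new test describe the IR expected once the pass has run: because the dequantized value also feeds tfl.arg_max, the dequantize is no longer folded into the return, and the volatile quantize/dequantize pair disappears so that both the return and tfl.arg_max read the softmax result directly. Assembled from those CHECK lines, the expected function looks roughly like this (a sketch, not verbatim pass output):

func @HandleReturnedDequantizeWithAnotherUse(%arg0: tensor<128x16xf32>) -> (tensor<128x16xf32>, tensor<128xi32>) {
  %cst = constant dense<1> : tensor<i32>
  %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<128x16xf32>) -> tensor<128x16xf32>
  %1 = "tfl.arg_max"(%0, %cst) : (tensor<128x16xf32>, tensor<i32>) -> tensor<128xi32>
  return %0, %1 : tensor<128x16xf32>, tensor<128xi32>
}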
@@ -53,7 +53,6 @@ class PostQuantizePass : public PassWrapper<PostQuantizePass, FunctionPass> {
 void RemoveQuantizationAdaptorOps(FuncOp func) {
   mlir::OpBuilder builder(func.getBody());
   auto& bb = func.front();
-  auto* terminator = bb.getTerminator();

   int num_args = bb.getNumArguments();
   llvm::SmallVector<Type, 4> input_types;
@@ -99,13 +98,15 @@ void RemoveQuantizationAdaptorOps(FuncOp func) {
   }

   // Edit the return ops and remove the dequantize ops in place.
+  auto* terminator = bb.getTerminator();
   int num_return_operands = terminator->getNumOperands();
   llvm::SmallVector<Type, 4> output_types;
   output_types.reserve(num_return_operands);
   for (int i = 0; i != num_return_operands; ++i) {
     auto returned_value = terminator->getOperand(i);
     Operation* returned_op = returned_value.getDefiningOp();
-    if (returned_op && llvm::isa<DequantizeOp>(returned_op)) {
+    if (returned_op && returned_op->hasOneUse() &&
+        llvm::isa<DequantizeOp>(returned_op)) {
       auto dequantize_op = llvm::cast<DequantizeOp>(returned_op);
       Value dequantized_result = dequantize_op.input();
       output_types.push_back(dequantized_result.getType());
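For contrast, the branch guarded by hasOneUse() above keeps the old behavior for the single-use case: when the return is the dequantize's only user, output_types receives the quantized input type, the return is rewired to that input, and the dequantize is erased. A hypothetical before/after sketch of that case (the function name and types here are invented for illustration, and the exact IR the pass emits may differ in detail):

// Before the pass: the dequantize feeds only the return.
func @returned_dequantize_single_use(%arg0: tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4xf32> {
  %0 = "tfl.dequantize"(%arg0) : (tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4xf32>
  return %0 : tensor<2x4xf32>
}

// After the pass: the dequantize is gone and the signature returns the quantized type.
func @returned_dequantize_single_use(%arg0: tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>> {
  return %arg0 : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
}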