Fixed a bug where RemoveQuantizationAdaptorOps would remove a dequantize op that fed a return even if that dequantize had other uses.
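
The new HandleReturnedDequantizeWithAnotherUse test below exercises exactly this pattern: a tfl.dequantize result feeds both a tfl.arg_max and the function return. A minimal sketch of the input IR, taken from that test:

  %1 = "tfl.quantize"(%0) {qtype = tensor<128x16x!quant.uniform<u8:f32, 3.906250e-03>>, volatile} : (tensor<128x16xf32>) -> tensor<128x16x!quant.uniform<u8:f32, 3.906250e-03>>
  %2 = "tfl.dequantize"(%1) : (tensor<128x16x!quant.uniform<u8:f32, 3.906250e-03>>) -> tensor<128x16xf32>
  %3 = "tfl.arg_max"(%2, %cst) : (tensor<128x16xf32>, tensor<i32>) -> tensor<128xi32>
  return %2, %3 : tensor<128x16xf32>, tensor<128xi32>

With the added hasOneUse() check, the return-rewriting step in the pass skips a returned dequantize that still has other users; the test's expected output shows the volatile quantize/dequantize pair removed and both arg_max and the return using the softmax result directly.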

PiperOrigin-RevId: 319074130
Change-Id: I2b682bcfdfdc4b496711da3ab53bea5975cc8c54
A. Unique TensorFlower 2020-06-30 12:23:20 -07:00 committed by TensorFlower Gardener
parent 78e1d0f299
commit 4e004c5e4a
2 changed files with 17 additions and 2 deletions


@@ -63,3 +63,17 @@ func @main2(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>) -> tensor<2x4xf32> {
// CHECK-NEXT: %[[add:.*]] = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
// CHECK-NEXT: return %[[add]] : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
// CHECK-NEXT:}

// CHECK-LABEL: HandleReturnedDequantizeWithAnotherUse
func @HandleReturnedDequantizeWithAnotherUse(%arg0: tensor<128x16xf32>) -> (tensor<128x16xf32>, tensor<128xi32>) {
// CHECK-NEXT: %[[cst:.*]] = constant dense<1> : tensor<i32>
%cst = constant dense<1> : tensor<i32>
// CHECK-NEXT: %[[softmax:.*]] = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<128x16xf32>) -> tensor<128x16xf32>
%0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<128x16xf32>) -> tensor<128x16xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<128x16x!quant.uniform<u8:f32, 3.906250e-03>>, volatile} : (tensor<128x16xf32>) -> tensor<128x16x!quant.uniform<u8:f32, 3.906250e-03>>
%2 = "tfl.dequantize"(%1) : (tensor<128x16x!quant.uniform<u8:f32, 3.906250e-03>>) -> tensor<128x16xf32>
// CHECK-NEXT: %[[argmax:.*]] = "tfl.arg_max"(%[[softmax]], %[[cst]]) : (tensor<128x16xf32>, tensor<i32>) -> tensor<128xi32>
%3 = "tfl.arg_max"(%2, %cst) : (tensor<128x16xf32>, tensor<i32>) -> tensor<128xi32>
// CHECK-NEXT: return %[[softmax]], %[[argmax]] : tensor<128x16xf32>, tensor<128xi32>
return %2, %3 : tensor<128x16xf32>, tensor<128xi32>
}


@@ -53,7 +53,6 @@ class PostQuantizePass : public PassWrapper<PostQuantizePass, FunctionPass> {
void RemoveQuantizationAdaptorOps(FuncOp func) {
mlir::OpBuilder builder(func.getBody());
auto& bb = func.front();
auto* terminator = bb.getTerminator();
int num_args = bb.getNumArguments();
llvm::SmallVector<Type, 4> input_types;
@@ -99,13 +98,15 @@ void RemoveQuantizationAdaptorOps(FuncOp func) {
}
// Edit the return ops and remove the dequantize ops in place.
auto* terminator = bb.getTerminator();
int num_return_operands = terminator->getNumOperands();
llvm::SmallVector<Type, 4> output_types;
output_types.reserve(num_return_operands);
for (int i = 0; i != num_return_operands; ++i) {
auto returned_value = terminator->getOperand(i);
Operation* returned_op = returned_value.getDefiningOp();
if (returned_op && llvm::isa<DequantizeOp>(returned_op)) {
if (returned_op && returned_op->hasOneUse() &&
llvm::isa<DequantizeOp>(returned_op)) {
auto dequantize_op = llvm::cast<DequantizeOp>(returned_op);
Value dequantized_result = dequantize_op.input();
output_types.push_back(dequantized_result.getType());