Remove the redundant tfl.quantizet/tfl.dequantize pair in the post-quantize

pass

This patch changes the propagation pass so a special attribute is attached to
the quantize op being added by propagation. These quantize and the paired
dequantize ops can be removed wihtout losing accuracy. So the post-quantize
pass removes these pairs if they are not fused to quantized ops.

PiperOrigin-RevId: 313703390
Change-Id: I002d001147839a7f180071ae7df2a56325e60094
This commit is contained in:
Feng Liu 2020-05-28 19:08:15 -07:00 committed by TensorFlower Gardener
parent b51979e15e
commit 6115edabbc
9 changed files with 86 additions and 43 deletions

View File

@ -23,6 +23,7 @@ package_group(
exports_files([
"quantization_traits.h",
"quantization_config.h",
"quantization_utils.h",
])
filegroup(

View File

@ -49,14 +49,15 @@ cc_library(
],
hdrs = [
"tfl_to_std.h",
"//tensorflow/compiler/mlir/lite/quantization:quantization_utils.h",
],
deps = [
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
namespace mlir {
namespace TFL {
@ -47,12 +48,18 @@ void ConvertMlirQuantOpsToTFLQuantOps(FuncOp func) {
auto dcast = b.create<DequantizeOp>(dq.getLoc(), dq.getResult().getType(),
dq.arg());
dq.getResult().replaceAllUsesWith(dcast);
if (auto extra_attr = op->getAttr(mlir::quant::kVolatileOpAttrName)) {
dcast.setAttr(mlir::quant::kVolatileOpAttrName, extra_attr);
}
dq.erase();
} else if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(op)) {
auto out_type = q.getResult().getType();
auto qcast = b.create<QuantizeOp>(q.getLoc(), out_type, q.arg(),
TypeAttr::get(out_type));
q.getResult().replaceAllUsesWith(qcast);
if (auto extra_attr = op->getAttr(mlir::quant::kVolatileOpAttrName)) {
qcast.setAttr(mlir::quant::kVolatileOpAttrName, extra_attr);
}
q.erase();
}
});

View File

@ -494,6 +494,13 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
auto quantize = builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
auto dequantize = builder_.create<quant::DequantizeCastOp>(
loc, expressed_type, quantize.getResult());
// This attribute is set to distinguish the quantize ops being added by the
// quantization pass. These ops can be removed without losing original
// program accuracy.
// TODO(fengliuai): make the attribute being part of op definition.
quantize.setAttr(kVolatileOpAttrName, builder_.getUnitAttr());
// `original_result` has a use to `quantize`, so this will replace that use
// by the result of `dequantize`. Remember to reset that use afterwards
value.replaceAllUsesWith(dequantize);

View File

@ -42,6 +42,11 @@ limitations under the License.
namespace mlir {
namespace quant {
// A unit attribute can be attached to the quantize/dequantize ops which are
// added by the quantization passes. These ops can be removed erased without
// losing accuracy.
constexpr char kVolatileOpAttrName[] = "volatile";
using QuantParams = quant::QuantizedType;
using SignedInteger = std::pair<unsigned, unsigned>; // bitwidth and sign
using QuantParamsForResults = llvm::SmallVector<QuantParams, 4>;

View File

@ -63,7 +63,7 @@ func @prepareAdd(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
return %add : tensor<2x2xf32>
// CHECK: %[[cst:.*]] = constant dense<[{{\[}}0.000000e+00, 1.000000e+00], [2.000000e+00, 2.550000e+02]]>
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<2x2x!quant.uniform<i8:f32, 1.000000e+00:-128>>}
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<2x2x!quant.uniform<i8:f32, 1.000000e+00:-128>>, volatile}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: %[[add:.*]] = tfl.add %arg0, %[[dq]]
// CHECK: return %[[add]]
@ -83,7 +83,7 @@ func @prepareConv2DSplat(%arg0: tensor<1x5x5x3xf32>) -> tensor<1x5x5x3xf32> {
// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq]]
// PerTensor: %[[cst:.*]] = constant dense<1.270000e+02> : tensor<3x3x3x3xf32>
// PerTensor: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<3x3x3x3x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>}
// PerTensor: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<3x3x3x3x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>, volatile}
// PerTensor: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// PerTensor: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq]]
}
@ -97,7 +97,7 @@ func @prepareConv2D(%arg0: tensor<1x5x5x1xf32>) -> tensor<1x5x5x3xf32> {
// CHECK: %[[cst:.*]] = constant dense<[{{\[\[\[}}0.000000e+00]]], [{{\[\[}}1.270000e+02]]], [{{\[\[}}-1.270000e+02]]]]>
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<3x1x1x1x!quant.uniform<i8<-127:127>:f32:0,
// CHECK-SAME: {0.0078740157480314959,1.000000e+00,1.000000e+00}>>}
// CHECK-SAME: {0.0078740157480314959,1.000000e+00,1.000000e+00}>>, volatile}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq]]
@ -134,12 +134,12 @@ func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112
return %fc : tensor<1x112x112x32xf32>
// CHECK: %[[cst:.*]] = constant dense<1.270000e+02> : tensor<32x12xf32>
// CHECK: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>} : (tensor<32x12xf32>) -> tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>
// CHECK: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>, volatile}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<32x12xf32>
// CHECK: "tfl.fully_connected"(%arg0, %[[dq]]
// PerTensor: %[[cst:.*]] = constant dense<1.270000e+02> : tensor<32x12xf32>
// PerTensor: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>} : (tensor<32x12xf32>) -> tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>
// PerTensor: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>, volatile}
// PerTensor: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<32x12xf32>
// PerTensor: "tfl.fully_connected"(%arg0, %[[dq]]
}

View File

@ -67,7 +67,7 @@ func @QuantizeConv2DPerChannel(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32,
return %conv : tensor<1x112x112x32xf32>
// CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform<i32:f32:0, {1.500000e+00,3.000000e+00,4.500000e+00}>>}
// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform<i32:f32:0, {1.500000e+00,3.000000e+00,4.500000e+00}>>, volatile}
// CHECK-NEXT: %[[bias:.*]] = "tfl.dequantize"(%[[qbias]])
// CHECK-NEXT: %[[in:.*]] = "tfl.dequantize"(%arg0)
// CHECK-NEXT: %[[w:.*]] = "tfl.dequantize"(%arg1)
@ -87,7 +87,7 @@ func @QuantizeConv2DPerChannels(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32:
return %conv : tensor<1x112x112x32xf32>
// CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform<i32:f32:0, {1.000000e+00,4.000000e+00,9.000000e+00}>>}
// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform<i32:f32:0, {1.000000e+00,4.000000e+00,9.000000e+00}>>, volatile}
// CHECK-NEXT: %[[bias:.*]] = "tfl.dequantize"(%[[qbias]])
// CHECK-NEXT: %[[in:.*]] = "tfl.dequantize"(%arg0)
// CHECK-NEXT: %[[w:.*]] = "tfl.dequantize"(%arg1)
@ -107,7 +107,7 @@ func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>
return %6 : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
// CHECK: %cst = constant dense<-1.23697901> : tensor<32xf32>
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>}
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>)
// CHECK: %2 = "tfl.dequantize"(%arg0)
// CHECK: %3 = "tfl.pseudo_qconst"()
@ -129,7 +129,7 @@ func @QuantizeFullyConnected(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e
return %6 : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
// CHECK: %cst = constant dense<-1.23697901> : tensor<32xf32>
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>}
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>)
// CHECK: %2 = "tfl.dequantize"(%arg0)
// CHECK: %3 = "tfl.pseudo_qconst"()
@ -151,7 +151,7 @@ func @QuantizeDepthwiseConv2D(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500
return %6 : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
// CHECK: %cst = constant dense<-1.23697901> : tensor<32xf32>
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>}
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>)
// CHECK: %2 = "tfl.dequantize"(%arg0)
// CHECK: %3 = "tfl.pseudo_qconst"()
@ -232,7 +232,7 @@ func @QuantizeStridedSlice(tensor<12x2x2x5x!quant.uniform<u8:f32, 0.1>>, tensor<
// CHECK: %0 = "tfl.dequantize"(%arg0)
// CHECK: %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg3)
// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x2x2x5x!quant.uniform<u8:f32, 1.000000e-01>>}
// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x2x2x5x!quant.uniform<u8:f32, 1.000000e-01>>, volatile}
// CHECK: %3 = "tfl.dequantize"(%2)
// CHECK: return %3 : tensor<1x2x2x5xf32>
}
@ -277,7 +277,7 @@ func @QuantizeReshape2D(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>
// CHECK: %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>)
// CHECK: %1 = "tfl.reshape"(%0, %{{.*}}) : (tensor<1x6x6x16xf32>, tensor<3xi32>) -> tensor<1x36x16xf32>
// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x36x16x!quant.uniform<u8:f32, 7.812500e-03:128>>}
// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x36x16x!quant.uniform<u8:f32, 7.812500e-03:128>>, volatile}
// CHECK: %3 = "tfl.dequantize"(%2) : (tensor<1x36x16x!quant.uniform<u8:f32, 7.812500e-03:128>>)
// CHECK: return %3 : tensor<1x36x16xf32>
}
@ -291,7 +291,7 @@ func @QuantizeSoftmax(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>)
// CHECK: %0 = "tfl.dequantize"(%arg0)
// CHECK: %1 = "tfl.softmax"(%0) {beta = 1.000000e+00 : f32} : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x6x6x16x!quant.uniform<u8:f32, 3.906250e-03>>}
// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x6x6x16x!quant.uniform<u8:f32, 3.906250e-03>>, volatile}
// CHECK: %3 = "tfl.dequantize"(%2)
// CHECK: return %3 : tensor<1x6x6x16xf32>
}
@ -305,7 +305,7 @@ func @QuantizeLogistic(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>
// CHECK: %0 = "tfl.dequantize"(%arg0)
// CHECK: %1 = "tfl.logistic"(%0) : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x6x6x16x!quant.uniform<u8:f32, 3.906250e-03>>}
// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x6x6x16x!quant.uniform<u8:f32, 3.906250e-03>>, volatile}
// CHECK: %3 = "tfl.dequantize"(%2) : (tensor<1x6x6x16x!quant.uniform<u8:f32, 3.906250e-03>>) -> tensor<1x6x6x16xf32>
// CHECK: return %3 : tensor<1x6x6x16xf32>
}
@ -327,7 +327,7 @@ func @QuantizeL2Norm(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 1.0>>) -> ten
// CHECK: %[[in:.*]] = "tfl.dequantize"(%arg0)
// CHECK: %[[l2:.*]] = "tfl.l2_normalization"(%[[in]])
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[l2]]) {qtype = tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>}
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[l2]]) {qtype = tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>, volatile}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: return %[[dq]] : tensor<1x6x6x16xf32>
}
@ -350,13 +350,13 @@ func @QuantizeConcatOperand0ToAll(tensor<1x2x!quant.uniform<u8:f32, 0.1:128>>, t
%1 = "tfl.concatenation"(%0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>, volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
// CHECK: %2 = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
// CHECK: %3 = "tfl.concatenation"(%2, %1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
// CHECK: %4 = "tfl.quantize"(%3) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %4 = "tfl.quantize"(%3) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>, volatile}
// CHECK: %5 = "tfl.dequantize"(%4) : (tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<2x2xf32>
// CHeCK: return %5 : tensor<2x2xf32>
// CHECK: return %5 : tensor<2x2xf32>
}
// CHECK-LABEL: QuantizeConcatOperand1ToAll
@ -366,11 +366,11 @@ func @QuantizeConcatOperand1ToAll(tensor<1x2xf32>, tensor<1x2x!quant.uniform<u8:
%1 = "tfl.concatenation"(%arg0, %0) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
// CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>, volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
// CHECK: %2 = "tfl.dequantize"(%arg1) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
// CHECK: %3 = "tfl.concatenation"(%1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
// CHECK: %4 = "tfl.quantize"(%3) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %4 = "tfl.quantize"(%3) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>, volatile}
// CHECK: %5 = "tfl.dequantize"(%4) : (tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<2x2xf32>
// CHECK: return %5 : tensor<2x2xf32>
}
@ -382,9 +382,9 @@ func @QuantizeConcatResToAll(tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2x!qu
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
return %1 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>, volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
// CHECK: %2 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %2 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>, volatile}
// CHECK: %3 = "tfl.dequantize"(%2) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
// CHECK: %4 = "tfl.concatenation"(%3, %1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
// CHECK: %5 = "tfl.quantize"(%4) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
@ -399,7 +399,7 @@ func @QuantizeConcatResToAllNoRequantize(tensor<1x2x!quant.uniform<u8:f32, 0.1:1
%2 = "tfl.quantize"(%1) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
return %2 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>, volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
// CHECK: %2 = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
// CHECK: %3 = "tfl.concatenation"(%2, %1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
@ -416,7 +416,7 @@ func @QuantizeConcatResToAllRequantize(tensor<1x2xf32>, tensor<1x2xf32>) -> tens
%3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
return %3 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>, volatile}
// CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
// CHECK: %[[Q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>
// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%[[Q0]]) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
@ -434,7 +434,7 @@ func @QuantizeConcatResToAllRequantizeArg(tensor<1x2x!quant.uniform<u8:f32, 2.0:
%3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
return %3 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>, volatile}
// CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %[[DQ0:.*]] = "tfl.dequantize"(%[[RQ0]]) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
@ -475,22 +475,22 @@ func @QuantizeChain(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>
return %10 : tensor<1x36x16xf32>
// CHECK: %cst = constant dense<-1.23697901> : tensor<32xf32>
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>}
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>)
// CHECK: %2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>)
// CHECK: %3 = "tfl.pseudo_qconst"()
// CHECK: %4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>)
// CHECK: %5 = "tfl.average_pool_2d"(%2)
// CHECK: %6 = "tfl.quantize"(%5) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>}
// CHECK: %6 = "tfl.quantize"(%5) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>, volatile}
// CHECK: %7 = "tfl.dequantize"(%6) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>)
// CHECK: %8 = "tfl.conv_2d"(%7, %4, %1)
// CHECK: %9 = "tfl.quantize"(%8) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>}
// CHECK: %10 = "tfl.dequantize"(%9) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>)
// CHECK: %11 = "tfl.reshape"(%10, %{{.*}})
// CHECK: %12 = "tfl.quantize"(%11) {qtype = tensor<1x36x16x!quant.uniform<u8:f32, 0.023528476789885875>>}
// CHECK: %12 = "tfl.quantize"(%11) {qtype = tensor<1x36x16x!quant.uniform<u8:f32, 0.023528476789885875>>, volatile}
// CHECK: %13 = "tfl.dequantize"(%12) : (tensor<1x36x16x!quant.uniform<u8:f32, 0.023528476789885875>>)
// CHECK: %14 = "tfl.softmax"(%13)
// CHECK: %15 = "tfl.quantize"(%14) {qtype = tensor<1x36x16x!quant.uniform<u8:f32, 3.906250e-03>>}
// CHECK: %15 = "tfl.quantize"(%14) {qtype = tensor<1x36x16x!quant.uniform<u8:f32, 3.906250e-03>>, volatile}
// CHECK: %16 = "tfl.dequantize"(%15) : (tensor<1x36x16x!quant.uniform<u8:f32, 3.906250e-03>>)
// CHECK: return %16 : tensor<1x36x16xf32>
}
@ -501,7 +501,7 @@ func @QuantizeConstant() -> tensor<2x3xf32> {
return %cst : tensor<2x3xf32>
// CHECK: %cst = constant dense{{.*}}tensor<2x3xf32>
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<2x3x!quant.uniform<u8:f32, 0.023529411764705882:128>>}
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<2x3x!quant.uniform<u8:f32, 0.023529411764705882:128>>, volatile}
// CHECK: %1 = "tfl.dequantize"(%0)
// CHECK: return %1 : tensor<2x3xf32>
}
@ -521,7 +521,7 @@ func @QuantizeZeroSplat() -> tensor<2x3xf32> {
return %cst : tensor<2x3xf32>
// CHECK-NEXT: %[[cst:.*]] = constant dense<0.000000e+00> : tensor<2x3xf32>
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform<u8:f32, 1.000000e+00>>}
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform<u8:f32, 1.000000e+00>>, volatile}
}
// CHECK-LABEL: QuantizeZeroScalar
@ -530,7 +530,7 @@ func @QuantizeZeroScalar() -> tensor<f32> {
return %cst : tensor<f32>
// CHECK-NEXT: %[[cst:.*]] = constant dense<0.000000e+00> : tensor<f32>
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<!quant.uniform<u8:f32, 1.000000e+00>>}
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<!quant.uniform<u8:f32, 1.000000e+00>>, volatile}
}
// CHECK-LABEL: QuantizePositiveSplat
@ -539,7 +539,7 @@ func @QuantizePositiveSplat() -> tensor<2x3xf32> {
return %cst : tensor<2x3xf32>
// CHECK-NEXT: %[[cst:.*]] = constant dense<2.540000e+01> : tensor<2x3xf32>
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform<u8:f32, 0.099607841641295186>>}
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform<u8:f32, 0.099607841641295186>>, volatile}
}
// CHECK-LABEL: QuantizePositiveScalar
@ -548,7 +548,7 @@ func @QuantizePositiveScalar() -> tensor<f32> {
return %cst : tensor<f32>
// CHECK-NEXT: %[[cst:.*]] = constant dense<2.540000e+00> : tensor<f32>
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<!quant.uniform<u8:f32, 0.0099607841641295193>>}
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<!quant.uniform<u8:f32, 0.0099607841641295193>>, volatile}
}
// CHECK-LABEL: QuantizeNegativeSplat
@ -557,7 +557,7 @@ func @QuantizeNegativeSplat() -> tensor<2x3xf32> {
return %cst : tensor<2x3xf32>
// CHECK-NEXT: %[[cst:.*]] = constant dense<-2.540000e+00> : tensor<2x3xf32>
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform<u8:f32, 0.0099607841641295193:255>>}
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform<u8:f32, 0.0099607841641295193:255>>, volatile}
}
// CHECK-LABEL: QuantizeNegativeScalar
@ -566,7 +566,7 @@ func @QuantizeNegativeScalar() -> tensor<f32> {
return %cst : tensor<f32>
// CHECK-NEXT: %[[cst:.*]] = constant dense<-2.540000e+01> : tensor<f32>
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<!quant.uniform<u8:f32, 0.099607841641295186:255>>}
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<!quant.uniform<u8:f32, 0.099607841641295186:255>>, volatile}
}
// CHECK-LABEL: QuantizeSharedBiases
@ -617,7 +617,7 @@ func @QuantizeSharedBiases2(
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]])
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: %[[cst_0:.*]] = constant dense<0.000000e+00> : tensor<32xf32>
// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<32x!quant.uniform<u8:f32, 1.000000e+00>>}
// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<32x!quant.uniform<u8:f32, 1.000000e+00>>, volatile}
// CHECK: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]])
// CHECK: %{{.*}} = tfl.add %{{.*}}, %[[dq_0]]
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])

View File

@ -195,8 +195,8 @@ func @QuantizeConcat(tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2x!quant.unif
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
return %1 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>}
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>}
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>, volatile}
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>, volatile}
// CHECK: %[[cc:.*]] = "tfl.concatenation"(%[[q1]], %[[q0]]) {axis = 0 : i32, fused_activation_function = "NONE"}
// CHECK: return %[[cc]] : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
}
@ -209,8 +209,8 @@ func @QuantizeConcatRequantize(tensor<1x2x!quant.uniform<u8:f32, 2.0:128>>, tens
%3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
return %3 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>}
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>, volatile}
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>}
// CHECK: %[[cc:.*]] = "tfl.concatenation"(%[[q0]], %[[q1]]) {axis = 0 : i32, fused_activation_function = "NONE"}
// CHECK: return %[[cc]] : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
}

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -118,6 +119,24 @@ void RemoveQuantizationAdaptorOps(FuncOp func) {
func.setType(new_func_type);
}
// Remove the back-to-back quantize and dequantize ops with volatile attribute.
struct RemoveVolatileOps : public OpRewritePattern<DequantizeOp> {
explicit RemoveVolatileOps(MLIRContext* context)
: OpRewritePattern<DequantizeOp>(context, 1) {}
LogicalResult matchAndRewrite(DequantizeOp op,
PatternRewriter& rewriter) const override {
auto input_op = op.input().getDefiningOp();
if (auto q = llvm::dyn_cast_or_null<QuantizeOp>(input_op)) {
if (!q.getAttr(mlir::quant::kVolatileOpAttrName)) return failure();
op.replaceAllUsesWith(q.input());
return success();
}
return failure();
}
};
#include "tensorflow/compiler/mlir/lite/transforms/generated_post_quantize.inc"
void PostQuantizePass::runOnFunction() {
@ -131,6 +150,9 @@ void PostQuantizePass::runOnFunction() {
if (!emit_quant_adaptor_ops_) {
RemoveQuantizationAdaptorOps(getFunction());
}
patterns.insert<RemoveVolatileOps>(ctx);
applyPatternsAndFoldGreedily(func, patterns);
}
} // namespace