Remove trivial quantize op
PiperOrigin-RevId: 312221307
Change-Id: Ibed5b449cedf5268f675a9fb09807e429f8a254a
parent 3da4ead13d
commit 97aed8f72e
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
@@ -35,6 +36,7 @@ limitations under the License.
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"

namespace mlir {
@@ -363,6 +365,54 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
  }
};

// Folds an extra requantize op if the preceding op has a free scale
// requirement.
template <typename RQ>
struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
  explicit FoldTrivalRequantizeOp(MLIRContext* context)
      : OpRewritePattern<RQ>(context, 1) {}

  LogicalResult matchAndRewrite(RQ op,
                                PatternRewriter& rewriter) const override {
    Value pre_quantized = op.input();
    auto pre_quantized_type =
        quant::QuantizedType::getQuantizedElementType(pre_quantized.getType());
    if (!pre_quantized_type) return failure();

    Operation* def = pre_quantized.getDefiningOp();
    if (!def) return failure();
    if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() ||
        def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
      return failure();
    }

    op.emitWarning("Remove trivial `rescale` op. Please fix the source graph.");

    // Collect the new result types for the preceding op: the result consumed
    // solely by this requantize op adopts the requantize target type; every
    // other result keeps its current type.
    llvm::SmallVector<Type, 4> new_output_types;
    for (auto result : def->getResults()) {
      if (result.hasOneUse() && *result.getUsers().begin() == op) {
        new_output_types.push_back(op.qtype());
      } else {
        new_output_types.push_back(result.getType());
      }
    }

    // Remove this rescale op.
    rewriter.replaceOp(op, {pre_quantized});

    // Replace the output scale of the preceding op.
    rewriter.setInsertionPointAfter(def);
    OperationState new_state(def->getLoc(), def->getName().getStringRef(),
                             def->getOperands(), new_output_types,
                             def->getAttrs());
    Operation* new_op = rewriter.createOperation(new_state);

    rewriter.replaceOp(def, new_op->getResults());
    return success();
  }
};

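The two trait checks above are what make the fold safe: a producer that must keep identical scales on its operands and results (or that has no quantizable result) cannot simply adopt the requantize op's output type. As a rough illustration of how an op opts into the first trait, here is a minimal sketch with a hypothetical op; real TFL ops acquire the trait through their ODS definitions rather than hand-written classes like this:

#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"

// Hypothetical op, for illustration only. Because it declares
// SameOperandsAndResultsScale, the hasTrait<> check in
// FoldTrivalRequantizeOp fires and the pattern leaves any requantize op
// that follows it in place.
class MyTransposeOp
    : public mlir::Op<MyTransposeOp,
                      mlir::OpTrait::quant::SameOperandsAndResultsScale> {
 public:
  using Op::Op;
  static llvm::StringRef getOperationName() { return "mydialect.transpose"; }
};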
// Given a quantized type `input`, magnifies its scales by the factor stored in
// `factor`. If `input` isn't a quantized type, or `factor` doesn't match the
// dimension size of `input` or isn't floating-point, nullptr will be returned.
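The declaration this comment documents falls outside the hunk. Based purely on the comment's wording, it plausibly reads as follows; the name and signature here are an assumption, not something shown in this diff:

// Hypothetical reconstruction; the actual declaration is elided by the hunk.
TypeAttr RescaleQuantizedType(Type input, Attribute factor);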
@@ -19,6 +19,16 @@ func @RemoveUnused(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> (tensor<2xf32>,t
  // CHECK-NEXT: return %[[split]]#0, %[[split]]#1
}

// CHECK-LABEL: RemoveTrival
func @RemoveTrival(%arg0: tensor<384x512x!quant.uniform<i8:f32, 1.0:-128>>, %arg1: tensor<128x512x!quant.uniform<i8<-127:127>:f32, 1.0>>, %arg2: none) -> tensor<384x128x!quant.uniform<i8:f32, 2.0>> {
  %1 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<384x512x!quant.uniform<i8:f32, 1.0:-128>>, tensor<128x512x!quant.uniform<i8<-127:127>:f32, 1.0>>, none) -> tensor<384x128x!quant.uniform<i8:f32, 1.0>>
  %2 = "tfl.quantize"(%1) {qtype = tensor<384x128x!quant.uniform<i8:f32, 2.0>>} : (tensor<384x128x!quant.uniform<i8:f32, 1.0>>) -> tensor<384x128x!quant.uniform<i8:f32, 2.0>>
  return %2 : tensor<384x128x!quant.uniform<i8:f32, 2.0>>

// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"{{.*}} -> tensor<384x128x!quant.uniform<i8:f32, 2.000000e+00>>
// CHECK-NEXT: return %[[fc]]
}

func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
  %cst = constant dense<[1, 1001]> : tensor<2xi32>
  %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>
@@ -125,6 +125,7 @@ void PostQuantizePass::runOnFunction() {
  auto func = getFunction();
  auto* ctx = func.getContext();
  TFL::populateWithGenerated(ctx, &patterns);
  patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
  applyPatternsAndFoldGreedily(func, patterns);

  if (!emit_quant_adaptor_ops_) {
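For experimenting with the fold in isolation, here is a minimal sketch of a standalone function pass. The pass itself is hypothetical (the commit's real integration point is PostQuantizePass above), but it sticks to the MLIR APIs already visible in this hunk:

#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"

namespace {
// Hypothetical pass for illustration: registers only FoldTrivalRequantizeOp
// and runs the greedy pattern rewrite driver over a single function.
struct FoldTrivialRequantizePass
    : public mlir::PassWrapper<FoldTrivialRequantizePass, mlir::FunctionPass> {
  void runOnFunction() override {
    mlir::OwningRewritePatternList patterns;
    patterns.insert<mlir::quant::FoldTrivalRequantizeOp<mlir::TFL::QuantizeOp>>(
        &getContext());
    mlir::applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};
}  // namespace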