Add enable_numeric_verify in TFLite converter for quantization numeric verification.
This CL adds an experimental enable_numeric_verify option to the TFLite converter's mlir_quantize() for quantization numeric verification (see the usage sketch below). When the option is enabled, the quantization pass inserts NumericVerify ops for debugging during full-integer quantization; each NumericVerify op outputs the differences between the original float values and the dequantized values of its quantized input. The main characteristic of the NumericVerify op is that its output tensor name has the following format:
NumericVerify/{tensor_name}:{tensor_id}
where tensor_name and tensor_id come from the original quantized op's activation tensor (the first input tensor of the NumericVerify op). This naming gives users an easy way to search for the corresponding tensors in debugging tools.
If log_if_failed is turned on in a NumericVerify op, the op outputs logs and throws an error whenever any difference exceeds the tolerance.
PiperOrigin-RevId: 348740984
Change-Id: Ifb965d2237e827a747d1888cbd21af311d42e4a1
Commit: ca3e7f55d9 (parent: 917ebe0008)
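
For illustration, a minimal sketch of the new flow from Python (not part of this CL; it mirrors the converter test added below, and assumes `converter` is a TFLiteConverterV2 already configured for full-integer quantization with optimizations, a representative dataset, and the new MLIR quantizer):

```python
from tensorflow.lite.python.convert import mlir_quantize

# Assumed: `converter` is set up for full-integer quantization, as in the
# Python test added in this CL (optimizations, representative_dataset,
# _experimental_new_quantizer = True).
converter._experimental_calibrate_only = True
calibrated = converter.convert()  # calibrated, but still-float, model
debug_model = mlir_quantize(calibrated, enable_numeric_verify=True)
```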
@@ -944,7 +944,22 @@ BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
     mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
     const std::vector<int32_t>& results) {
   float tolerance = op.tolerance().convertToFloat();
-  return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results);
+  bool log_if_failed = op.log_if_failed();
+  auto fbb = absl::make_unique<flexbuffers::Builder>();
+  fbb->Map([&]() {
+    fbb->Float("tolerance", tolerance);
+    fbb->Bool("log_if_failed", log_if_failed);
+  });
+  fbb->Finish();
+  auto f = std::unique_ptr<flexbuffers::Builder>(fbb.release());
+  auto custom_option = f->GetBuffer();
+  auto opcode_index =
+      GetOpcodeIndex("NumericVerify", tflite::BuiltinOperator_CUSTOM);
+  return tflite::CreateOperator(
+      builder_, opcode_index, builder_.CreateVector(operands),
+      builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
+      /*builtin_options=*/0, builder_.CreateVector<uint8_t>(custom_option),
+      tflite::CustomOptionsFormat_FLEXBUFFERS);
 }
 
 BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
@@ -1408,6 +1423,17 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
 
     for (auto val : inst.getResults()) {
       std::string name = UniqueName(val);
+      // For "tfl.numeric_verify" op, the name is used to find out the original
+      // activation tensor rather than its own unique name in the visualization
+      // or debugging tools.
+      auto builtin_code = GetBuiltinOpCode(&inst);
+      if (!builtin_code && dyn_cast<mlir::TFL::NumericVerifyOp>(&inst)) {
+        // The first operand is the quantized activation, the target of this
+        // NumericVerify op.
+        auto quantized_op_val = inst.getOperands().front();
+        name = "NumericVerify/" + UniqueName(quantized_op_val) + ":" +
+               std::to_string(tensor_index_map[quantized_op_val]);
+      }
       if (!build_tensor_and_buffer(val, name)) return llvm::None;
     }
 
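
Because the exporter names every NumericVerify output after the activation tensor it checks, debugging tools can recover the pairing from tensor names alone. A hypothetical sketch of that lookup (reusing `debug_model` from the sketch above; it relies on get_tensor_details() being ordered by tensor index):

```python
from tensorflow.lite.python.interpreter import Interpreter

interpreter = Interpreter(model_content=debug_model)
interpreter.allocate_tensors()
details = interpreter.get_tensor_details()
for d in details:
    if d['name'].startswith('NumericVerify/'):
        # "NumericVerify/{tensor_name}:{tensor_id}": the id indexes the
        # quantized activation tensor that this op verifies.
        original, _, tensor_id = d['name'][len('NumericVerify/'):].rpartition(':')
        assert details[int(tensor_id)]['name'] == original
```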
@@ -4257,6 +4257,11 @@ def TFL_NumericVerifyOp : Op<TFL_Dialect, "NumericVerify", [
   let description = [{
     The NumericVerify op is a debugging op to verify the numericals of the two
     activations. It is a custom op in TFLite.
+    If log_if_failed is true, the NumericVerify op calculates statistics on
+    differences between float and quantized activations, outputs logs,
+    sets the differences to the output tensor, and throws an error if errors
+    above tolerance exist. If log_if_failed is false, the op does not throw
+    errors.
   }];
 
   let arguments = (ins
@@ -4264,10 +4269,11 @@ def TFL_NumericVerifyOp : Op<TFL_Dialect, "NumericVerify", [
     TFL_TensorOf<[F32]>:$ref,
 
     // Attributes
-    DefaultValuedAttr<F32Attr, "0.1">:$tolerance
+    DefaultValuedAttr<F32Attr, "0.1">:$tolerance,
+    DefaultValuedAttr<BoolAttr, "false">:$log_if_failed
   );
 
-  let results = (outs);
+  let results = (outs TFL_FpTensor:$output);
 }
 
 // SVDF op.
@@ -44,7 +44,7 @@ TfLiteStatus QuantizeModel(
     const std::unordered_set<std::string>& operator_names,
     bool disable_per_channel, bool fully_quantize,
     flatbuffers::FlatBufferBuilder* builder,
-    tflite::ErrorReporter* error_reporter) {
+    tflite::ErrorReporter* error_reporter, bool verify_numeric) {
   // TODO(b/142502494): remove this restriction by improving the `emit_adaptor`
   // flag
   if (input_type != output_type) {
@@ -91,8 +91,10 @@ TfLiteStatus QuantizeModel(
     quant_specs.inference_type = input_tf_type;
   }
 
+  quant_specs.verify_numeric = verify_numeric;
+
   pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs));
-  pm.addPass(TFL::CreateQuantizePass());
+  pm.addPass(TFL::CreateQuantizePass(verify_numeric));
   pm.addPass(TFL::CreatePostQuantizePass(emit_adaptor));
 
   if (failed(pm.run(module.get()))) {
@@ -36,7 +36,7 @@ TfLiteStatus QuantizeModel(
     const std::unordered_set<std::string>& operator_names,
     bool disable_per_channel, bool fully_quantize,
     flatbuffers::FlatBufferBuilder* builder,
-    tflite::ErrorReporter* error_reporter);
+    tflite::ErrorReporter* error_reporter, bool verify_numeric = false);
 }  // namespace lite
 }  // namespace mlir
 
@@ -124,6 +124,10 @@ struct QuantizationSpecs {
       return 0;
     }
   }
 
+  // Whether to add the NumericVerify ops to verify numbers before and after
+  // quantization.
+  bool verify_numeric = false;
+
 };
 
 // Parses the command line flag strings to the quantization specification for
@@ -175,12 +175,14 @@ struct QuantizationPattern : public RewritePattern {
   using BaseType = QuantizationPattern<ConcretTy, Q, DQ, VERIFIER>;
 
   explicit QuantizationPattern(MLIRContext* context, bool enable_verify,
-                               float error_tolerance, bool single_layer_verify)
+                               float error_tolerance, bool single_layer_verify,
+                               bool log_if_failed = false)
       // Set the score to a large number so it is always preferred.
       : RewritePattern(DQ::getOperationName(), 300, context),
         enable_verify(enable_verify),
         error_tolerance(error_tolerance),
-        single_layer_verify(single_layer_verify) {}
+        single_layer_verify(single_layer_verify),
+        log_if_failed(log_if_failed) {}
 
   LogicalResult matchAndRewrite(Operation* op,
                                 PatternRewriter& rewriter) const override {
@@ -312,10 +314,11 @@ struct QuantizationPattern : public RewritePattern {
         }
         rewriter.setInsertionPointAfter(new_op);
         FloatAttr tolerance = rewriter.getF32FloatAttr(error_tolerance);
+        BoolAttr log = rewriter.getBoolAttr(log_if_failed);
         // Verify the quantized value by sending the result to the verifier.
-        rewriter.create<VERIFIER>(quantized_op->getLoc(),
-                                  new_op->getResult(i),
-                                  quantized_op->getResult(i), tolerance);
+        rewriter.create<VERIFIER>(
+            quantized_op->getLoc(), new_op->getResult(i).getType(),
+            new_op->getResult(i), quantized_op->getResult(i), tolerance, log);
 
         if (single_layer_verify) continue;
 
@@ -341,6 +344,7 @@ struct QuantizationPattern : public RewritePattern {
   bool enable_verify;
   float error_tolerance;
   bool single_layer_verify;
+  bool log_if_failed;
 };
 
 // Converts quantize ops with unsigned quantized types to these with signed
@@ -24,13 +24,20 @@
 // CHECK-NEXT:       scale: [ 0.1 ],
 // CHECK-NEXT:       zero_point: [ 0 ]
 // CHECK-NEXT:     }
+// CHECK-NEXT:   }, {
+// CHECK-NEXT:     shape: [ 4 ],
+// CHECK-NEXT:     buffer: 3,
+// CHECK-NEXT:     name: "NumericVerify/arg1:1",
+// CHECK-NEXT:     quantization: {
+// CHECK-EMPTY:
+// CHECK-NEXT:     }
 // CHECK-NEXT:   } ],
 // CHECK-NEXT:   inputs: [ 0, 1 ],
 // CHECK-NEXT:   outputs: [ 0 ],
 // CHECK-NEXT:   operators: [ {
 // CHECK-NEXT:     inputs: [ 1, 0 ],
-// CHECK-NEXT:     outputs: [  ],
-// CHECK-NEXT:     custom_options: [ 205, 204, 204, 61 ]
+// CHECK-NEXT:     outputs: [ 2 ],
+// CHECK-NEXT:     custom_options:
 // CHECK-NEXT:   } ],
 // CHECK-NEXT:   name: "main"
 // CHECK-NEXT: } ],
@@ -42,16 +49,18 @@
 // CHECK-NEXT:   }, {
 // CHECK-EMPTY:
 // CHECK-NEXT:   }, {
+// CHECK-EMPTY:
+// CHECK-NEXT:   }, {
 // CHECK-NEXT:     data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
 // CHECK-NEXT:   } ],
 // CHECK-NEXT:   metadata: [ {
 // CHECK-NEXT:     name: "min_runtime_version",
-// CHECK-NEXT:     buffer: 3
+// CHECK-NEXT:     buffer: 4
 // CHECK-NEXT:   } ]
 // CHECK-NEXT: signature_defs: [ ]
 // CHECK-NEXT:}
 
 func @main(%arg0: tensor<4xf32>, %arg1: tensor<4x!quant.uniform<u8:f32, 0.1>>) -> tensor<4xf32> {
-  "tfl.NumericVerify"(%arg1, %arg0) {tolerance = 0.1 : f32} : (tensor<4x!quant.uniform<u8:f32, 0.1>>, tensor<4xf32>) -> ()
+  "tfl.NumericVerify"(%arg1, %arg0) {tolerance = 0.1 : f32} : (tensor<4x!quant.uniform<u8:f32, 0.1>>, tensor<4xf32>) -> (tensor<4xf32>)
   return %arg0 : tensor<4xf32>
 }
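
An aside on the custom_options change above: the old CHECK line pinned the exact bytes [ 205, 204, 204, 61 ], which are simply the default tolerance 0.1 encoded as a raw little-endian float. A quick stdlib check (illustrative only):

```python
import struct

# 0x3DCCCCCD in little-endian byte order is 0.1 as a 32-bit float,
# i.e. the old single-float custom options payload.
(tolerance,) = struct.unpack('<f', bytes([205, 204, 204, 61]))
print(tolerance)  # 0.10000000149011612
```

With the new FlexBuffers map carrying both tolerance and log_if_failed, the byte sequence is longer and no longer worth hard-coding, so the test now checks only for the custom_options field.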
@@ -1,5 +1,5 @@
 // RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize | FileCheck %s
-// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize -tfl-numeric-verify | FileCheck --check-prefix=DEBUG %s
+// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize -tfl-numeric-verify -tfl-log-if-failed | FileCheck --check-prefix=DEBUG %s
 
 // CHECK-LABEL: QuantizeFloatConst
 func @QuantizeFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
@@ -76,7 +76,7 @@ func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>
 // DEBUG: %[[act:.*]] = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x224x224x3xf32>
 // DEBUG: %[[f_conv:.*]] = "tfl.conv_2d"(%[[act]], %[[wt]], %[[bias]])
 // DEBUG: %[[q_conv:.*]] = "tfl.conv_2d"
-// DEBUG: "tfl.NumericVerify"(%[[q_conv]], %[[f_conv]]) {tolerance = 5.000000e+00 : f32}
+// DEBUG: "tfl.NumericVerify"(%[[q_conv]], %[[f_conv]]) {log_if_failed = true, tolerance = 5.000000e+00 : f32}
 // DEBUG: return %[[q_conv]] : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
 }
@@ -54,7 +54,8 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
         quant_specs.default_ranges.second.getValueOr(0.0),
         quant_specs.IsSignedInferenceType()));
   }
-  pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateQuantizePass());
+  pass_manager->addNestedPass<mlir::FuncOp>(
+      mlir::TFL::CreateQuantizePass(quant_specs.verify_numeric));
   bool emit_quant_adaptor_ops =
       quant_specs.inference_type != quant_specs.inference_input_type;
   pass_manager->addNestedPass<mlir::FuncOp>(
@@ -48,7 +48,8 @@ std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
 std::unique_ptr<OperationPass<ModuleOp>> CreateLowerStaticTensorListPass();
 
 // Creates an instance of the TensorFlow Lite dialect Quantize pass.
-std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass();
+std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass(
+    bool verify_numeric = false);
 
 // Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
 std::unique_ptr<OperationPass<FuncOp>> CreatePrepareQuantizePass(
@@ -55,6 +55,13 @@ static llvm::cl::opt<bool> enable_single_layer_verify(
                    "`-tfl-numeric-verify` is set."),
     llvm::cl::init(true));
 
+// NOLINTNEXTLINE
+static llvm::cl::opt<bool> enable_log_if_failed(
+    "tfl-log-if-failed", llvm::cl::value_desc("bool"),
+    llvm::cl::desc("Whether verify numericals with thresholding "
+                   "tolerance. Valid when `-tfl-numeric-verify` is set."),
+    llvm::cl::init(false));
+
 namespace mlir {
 namespace TFL {
 
@@ -67,16 +74,26 @@ namespace {
 struct TFLFullQuantization
     : public quant::QuantizationPattern<TFLFullQuantization, QuantizeOp,
                                         DequantizeOp, NumericVerifyOp> {
-  explicit TFLFullQuantization(MLIRContext* ctx, bool verify_numeric,
-                               float tolerance, bool verify_single_layer)
-      : BaseType(ctx, verify_numeric, tolerance, verify_single_layer) {}
+  explicit TFLFullQuantization(MLIRContext* ctx, bool verify_numeric_flag,
+                               float tolerance, bool verify_single_layer,
+                               bool log_if_failed_flag = false)
+      : BaseType(ctx, verify_numeric_flag, tolerance, verify_single_layer,
+                 log_if_failed_flag) {}
   static bool AllowHybridOperand() { return false; }
   static bool AllowHybridResult() { return false; }
 };
 
 // Applies quantization on the model in TFL dialect.
 struct QuantizePass : public PassWrapper<QuantizePass, FunctionPass> {
+ public:
+  // Constructor used by manually creating the pass.
+  explicit QuantizePass(bool verify_numeric_flag = false)
+      : verify_numeric(verify_numeric_flag) {}
+
   void runOnFunction() override;
+
+ private:
+  bool verify_numeric;
 };
 
 #include "tensorflow/compiler/mlir/lite/transforms/generated_quantize.inc"
@@ -87,14 +104,15 @@ void QuantizePass::runOnFunction() {
   auto* ctx = func.getContext();
   TFL::populateWithGenerated(ctx, patterns);
   patterns.insert<TFLFullQuantization>(
-      ctx, enable_numeric_verify, error_tolerance, enable_single_layer_verify);
+      ctx, enable_numeric_verify || verify_numeric, error_tolerance,
+      enable_single_layer_verify, enable_log_if_failed);
   applyPatternsAndFoldGreedily(func, std::move(patterns));
 }
 }  // namespace
 
 // Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass.
-std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass() {
-  return std::make_unique<QuantizePass>();
+std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass(bool verify_numeric) {
+  return std::make_unique<QuantizePass>(verify_numeric);
 }
 
 static PassRegistration<QuantizePass> pass(
@@ -38,8 +38,9 @@ namespace custom {
 namespace numeric_verify {
 
 static constexpr const char kToleranceStr[] = "tolerance";
-static constexpr const char kDebugModeStr[] = "debug_mode";
+static constexpr const char kLogIfFailedStr[] = "log_if_failed";
 static constexpr const int kTemporaryDequantizedTensor = 0;
+static constexpr const int kOutputTensor = 0;
 
 struct OpContext {
   OpContext(TfLiteContext* context, TfLiteNode* node) {
@@ -61,7 +62,7 @@ struct OpData {
   bool float_input_initialized;
   int cache_tensor_id = kTensorNotAllocated;
   // This boolean value is for controlling the behavior of numeric verify op.
-  bool debug_mode;
+  bool log_if_failed;
 };
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -71,9 +72,9 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
   const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
   const float tolerance = m[kToleranceStr].AsFloat();
-  const bool debug_mode = m[kDebugModeStr].AsBool();
+  const bool log_if_failed = m[kLogIfFailedStr].AsBool();
   op_data->tolerance = tolerance;
-  op_data->debug_mode = debug_mode;
+  op_data->log_if_failed = log_if_failed;
 
   return op_data;
 }
@@ -84,13 +85,11 @@ void Free(TfLiteContext* context, void* buffer) {
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
   OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
 
   OpContext op_context(context, node);
 
-  const int num_output = (op_data->debug_mode) ? 1 : 0;
-  TF_LITE_ENSURE_EQ(context, NumOutputs(node), num_output);
-
   TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 ||
                               op_context.input->type == kTfLiteInt8 ||
                               op_context.input->type == kTfLiteInt16 ||
@@ -118,15 +117,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                      context, dequantized,
                                      TfLiteIntArrayCopy(op_context.input->dims)));
 
-  if (op_data->debug_mode) {
-    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, num_output - 1,
-                                             &op_context.output));
-    op_context.output->type = kTfLiteFloat32;
-    op_context.output->allocation_type = kTfLiteArenaRwPersistent;
-    return context->ResizeTensor(context, op_context.output,
-                                 TfLiteIntArrayCopy(op_context.input->dims));
-  }
-  return kTfLiteOk;
+  TF_LITE_ENSURE_OK(
+      context, GetOutputSafe(context, node, kOutputTensor, &op_context.output));
+  op_context.output->type = kTfLiteFloat32;
+  op_context.output->allocation_type = kTfLiteArenaRwPersistent;
+  return context->ResizeTensor(context, op_context.output,
+                               TfLiteIntArrayCopy(op_context.input->dims));
 }
 
 static int32_t GetQuantizedValue(const OpContext& op_context, int index) {
@@ -165,22 +161,37 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     op_data->float_input_initialized = true;
   }
 
-  // If the debug_mode is on, we don't throw any errors.
-  // We just calculate difference between float and quantized values, letting
-  // python debugger deal with the information.
-  if (op_data->debug_mode || op_data->tolerance < 0.1) {
-    const int num_output = (op_data->debug_mode) ? 1 : 0;
-    const int n = NumElements(dequantized);
-    if (op_data->debug_mode) {
-      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, num_output - 1,
-                                               &op_context.output));
-      auto output_data = GetTensorData<float>(op_context.output);
+  TF_LITE_ENSURE_OK(
+      context, GetOutputSafe(context, node, kOutputTensor, &op_context.output));
+  auto output_data = GetTensorData<float>(op_context.output);
+
+  // If log_if_failed is on, calculate differences between float and
+  // quantized values, their statistics and output logs.
+  // Throw errors if any diff greater than tolerance exists.
+  const int n = NumElements(dequantized);
+  if (op_data->log_if_failed && op_data->tolerance >= 0.1) {
+    // Verify the dequantized output.
+    auto max_diff = op_data->tolerance * op_context.input->params.scale;
     for (int i = 0; i < n; ++i) {
+      int32_t value = GetQuantizedValue(op_context, i);
       float dequant = GetTensorData<float>(dequantized)[i];
       float reference = GetTensorData<float>(op_context.ref)[i];
       output_data[i] = dequant - reference;
+      float diff = std::abs(output_data[i]);
+      if (diff > max_diff) {
+        TF_LITE_KERNEL_LOG(
+            context,
+            "Mismatch: %f is quantized to %d with (%f, %d). "
+            "abs(%f - %f) = %f > %f (tolerance) range percentage %f.\n",
+            reference, value, op_context.input->params.scale,
+            op_context.input->params.zero_point, reference, dequant, diff,
+            max_diff, op_data->tolerance);
+        return kTfLiteError;
+      }
     }
+  } else {
+    // If tolerance is small or log_if_failed is off, then we only care about
+    // statistics.
     // These statistics logging was added to identify some errors in practice.
     std::vector<double> diffs, temp;
     diffs.reserve(n);
@@ -191,6 +202,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       float dequant = GetTensorData<float>(dequantized)[i];
       float reference = GetTensorData<float>(op_context.ref)[i];
       diffs[i] = static_cast<double>(dequant - reference);
+      output_data[i] = dequant - reference;
     }
     double mean =
         std::accumulate(diffs.begin(), diffs.end(), 0.0) / diffs.size();
@@ -208,26 +220,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
         "std: %f, mean: %f, max_diff: %f (scale: %f, zero_point: %d).\n", std,
         mean, max_diff, op_context.input->params.scale,
         op_context.input->params.zero_point);
-    return kTfLiteOk;
-  } else {
-    // Verify the dequantized output.
-    auto max_diff = op_data->tolerance * op_context.input->params.scale;
-    for (int i = 0; i < NumElements(op_context.ref); ++i) {
-      int32_t value = GetQuantizedValue(op_context, i);
-      float dequant = GetTensorData<float>(dequantized)[i];
-      float reference = GetTensorData<float>(op_context.ref)[i];
-      float diff = std::abs(reference - dequant);
-      if (diff > max_diff) {
-        TF_LITE_KERNEL_LOG(
-            context,
-            "Mismatch: %f is quantized to %d with (%f, %d). "
-            "abs(%f - %f) = %f > %f (tolerance) range percentage %f.\n",
-            reference, value, op_context.input->params.scale,
-            op_context.input->params.zero_point, reference, dequant, diff,
-            max_diff, op_data->tolerance);
-        return kTfLiteError;
-      }
-    }
   }
   return kTfLiteOk;
 }
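
To make the kernel's threshold concrete, a small numeric restatement (hypothetical values; the kernel flags an element when |dequantized - reference| exceeds tolerance x input scale):

```python
# Assumed values, matching the int8 test parameters below (scale = 0.5).
scale, zero_point = 0.5, -1   # quantization params of the input tensor
tolerance = 5.0               # the op's tolerance attribute
max_diff = tolerance * scale  # 2.5: absolute error budget per element

q = 7                                 # one quantized activation value
dequant = (q - zero_point) * scale    # (7 - (-1)) * 0.5 = 4.0
reference = 6.8                       # the float-graph activation
diff = abs(dequant - reference)       # 2.8
print(diff > max_diff)                # True -> logged and kTfLiteError,
                                      # but only when log_if_failed is set
```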
@@ -45,21 +45,19 @@ class NumericVerifyOpModel : public SingleOpModel {
  public:
   NumericVerifyOpModel(TensorType type, std::initializer_list<int> shape,
                        float scale, int32_t zero_point, int version,
-                       float tolerance = 5.0, bool debug_mode = false) {
+                       float tolerance = 5.0, bool log_if_failed = true) {
     const TensorData input_tensor_data = {type, shape, 0, 0, scale, zero_point};
     input_ = AddInput(input_tensor_data);
     ref_ = AddInput({TensorType_FLOAT32, shape});
-    if (debug_mode) {
-      // The output tensor has the same shape with that of the input tensor.
-      output_ = AddOutput({TensorType_FLOAT32, shape});
-    }
+    // The output tensor has the same shape with that of the input tensor.
+    output_ = AddOutput({TensorType_FLOAT32, shape});
 
     std::vector<uint8_t> custom_options(sizeof(float));
 
     flexbuffers::Builder fbb;
     fbb.Map([&]() {
       fbb.Float("tolerance", tolerance);
-      fbb.Bool("debug_mode", debug_mode);
+      fbb.Bool("log_if_failed", log_if_failed);
     });
     fbb.Finish();
 
@@ -135,7 +133,7 @@ TEST(NumericVerifyOpFailedTest, Int8) {
 
 TEST(NumericVerifyOpDebugModeTest, Int8) {
   // [-63.5, 64] -> scale=0.5, zero_point=1 for INT8
-  NumericVerifyOpModel m(TensorType_INT8, {2, 5}, 0.5, -1, 2, 5.0, true);
+  NumericVerifyOpModel m(TensorType_INT8, {2, 5}, 0.5, -1, 2, 5.0, false);
 
   // The 5th element is set to 0.
   m.SetInputs<int8_t>({-128, -127, -126, -125, -124, 0, 124, 125, 126, 127},
@@ -126,7 +126,8 @@ class ConverterError(Exception):
 def mlir_quantize(input_data_str,
                   disable_per_channel=False,
                   fully_quantize=False,
-                  inference_type=_types_pb2.INT8):
+                  inference_type=_types_pb2.INT8,
+                  enable_numeric_verify=False):
   """Quantize `input_data_str` with calibration results.
 
   Args:
@@ -137,6 +138,8 @@ def mlir_quantize(input_data_str,
     fully_quantize: Bool indicating whether to fully quantize the model. Besides
       model body, the input/output will be quantized as well.
     inference_type: Data type for the activations. The default value is int8.
+    enable_numeric_verify: Experimental. Subject to change. Bool indicating
+      whether to add NumericVerify ops into the debug mode quantized model.
 
   Returns:
     Quantized model in serialized form (e.g. a TFLITE model) with floating-point
@@ -145,7 +148,8 @@ def mlir_quantize(input_data_str,
   return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str,
                                                       disable_per_channel,
                                                       fully_quantize,
-                                                      inference_type)
+                                                      inference_type,
+                                                      enable_numeric_verify)
 
 
 def mlir_sparsify(input_data_str):
@@ -668,6 +668,82 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     self.assertEqual(output_details[0]['dtype'], expected_ceil_dtype)
     self.assertEqual(output_details[1]['dtype'], expected_dtype)
 
+  @test_util.run_v2_only
+  def testNewQuantizerNumericVerificationDebugMode(self):
+    """Test the model quantized by the new converter with numeric verify ops."""
+    func, calibration_gen = self._getIntegerQuantizeModel()
+
+    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
+    quantized_converter.target_spec.supported_ops = [
+        lite.OpsSet.TFLITE_BUILTINS_INT8
+    ]
+    quantized_converter.representative_dataset = calibration_gen
+
+    # Create a TFLite model with new quantizer.
+    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
+    quantized_converter._experimental_new_quantizer = True
+    production_tflite = quantized_converter.convert()
+    # Create a TFLite model with new quantizer and numeric verify ops.
+    quantized_converter._experimental_calibrate_only = True
+    calibrated = quantized_converter.convert()
+    debug_mode_tflite = mlir_quantize(calibrated, enable_numeric_verify=True)
+
+    # Check if adding debug mode should output a different flatbuffer.
+    self.assertNotEqual(production_tflite, debug_mode_tflite)
+
+    # Check if newly added ops are numeric verify ops.
+    input_data = tf.constant(
+        np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32))
+
+    def examine_tflite_model(tflite_content, input_data):
+      interpreter = Interpreter(model_content=tflite_content)
+      interpreter.allocate_tensors()
+      input_details = interpreter.get_input_details()
+      interpreter.set_tensor(input_details[0]['index'], input_data.numpy())
+      interpreter.invoke()
+      tensor_details = interpreter.get_tensor_details()
+      return {
+          details['name']: interpreter.get_tensor(details['index'])
+          for details in interpreter.get_tensor_details()
+      }, tensor_details
+
+    tflite_result, _ = examine_tflite_model(production_tflite, input_data)
+    debug_mode_tflite_result, debug_tensor_details = examine_tflite_model(
+        debug_mode_tflite, input_data)
+
+    # MLIR-based quantizer should output flatbuffer model with `tfl.quantize`.
+    num_production_quantize_ops = len([
+        None for output_tensor_name in tflite_result
+        if 'tfl.quantize' in output_tensor_name
+    ])
+    self.assertEqual(num_production_quantize_ops, 1)
+    # MLIR-based quantizer should output flatbuffer model with `tfl.quantize`.
+    num_debug_quantize_ops = len([
+        None for output_tensor_name in debug_mode_tflite_result
+        if 'tfl.quantize' in output_tensor_name
+    ])
+    # Two numbers should be equal.
+    self.assertEqual(num_production_quantize_ops, num_debug_quantize_ops)
+    # DebugMode TFLite flatbuffer should have NumericVerifyOps more than zero.
+    # The name has the prefix "NumericVerify/{name}:{id}"
+    # where {name} is the tensor name of the original quantized op's activation,
+    # and {id} is its tensor id.
+    num_debug_ops = 0
+    for output_tensor_name in debug_mode_tflite_result:
+      if 'NumericVerify' in output_tensor_name:
+        pos_end_prefix = len('NumericVerify/')
+        pos_colon = output_tensor_name.rfind(':')
+        self.assertEqual('NumericVerify/',
+                         output_tensor_name[:pos_end_prefix])
+        tensor_id = int(output_tensor_name[pos_colon+1:])
+        original_tensor_name = output_tensor_name[pos_end_prefix:pos_colon]
+        self.assertEqual(original_tensor_name,
+                         debug_tensor_details[tensor_id]['name'])
+        num_debug_ops += 1
+    self.assertEqual(num_debug_ops, 1)
+    # The number of debug ops should be equal to that of quantized ops.
+    self.assertEqual(num_debug_ops, num_debug_quantize_ops)
+
+
 class FromSavedModelTest(lite_v2_test_util.ModelTest):
 
@@ -44,12 +44,14 @@ def wrapped_get_potentially_supported_ops():
 
 
 def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel,
-                                       fully_quantize, inference_type):
+                                       fully_quantize, inference_type,
+                                       enable_numeric_verify):
   """Wraps experimental mlir quantize model."""
   return _pywrap_toco_api.ExperimentalMlirQuantizeModel(input_data_str,
                                                         disable_per_channel,
                                                         fully_quantize,
-                                                        inference_type)
+                                                        inference_type,
+                                                        enable_numeric_verify)
 
 
 def wrapped_experimental_mlir_sparsify(input_data_str):
@@ -236,7 +236,8 @@ PyObject* TocoGetPotentiallySupportedOps() {
 }
 
 PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
-                            bool fully_quantize, int inference_type) {
+                            bool fully_quantize, int inference_type,
+                            bool enable_numeric_verify) {
   using tflite::interpreter_wrapper::PythonErrorReporter;
   char* buf = nullptr;
   Py_ssize_t length;
@@ -276,7 +277,7 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
   auto status = mlir::lite::QuantizeModel(
       *tflite_model, inference_io_type, inference_io_type,
       inference_tensor_type, {}, disable_per_channel, fully_quantize, &builder,
-      error_reporter.get());
+      error_reporter.get(), enable_numeric_verify);
 
   if (status != kTfLiteOk) {
     error_reporter->exception();
@@ -44,7 +44,8 @@ PyObject* TocoGetPotentiallySupportedOps();
 // is specified by the calibration data are not sufficient to quantize the
 // model.
 PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
-                            bool fully_quantize, int inference_type);
+                            bool fully_quantize, int inference_type,
+                            bool enable_numeric_verify = false);
 
 // Sparsifies model to encode sparse tensors with proper format. Throws error if
 // sparsification fails.
@@ -57,13 +57,14 @@ PYBIND11_MODULE(_pywrap_toco_api, m) {
   m.def(
       "ExperimentalMlirQuantizeModel",
      [](py::object input_contents_txt_raw, bool disable_per_channel,
-         bool fully_quantize, int inference_type) {
+         bool fully_quantize, int inference_type, bool enable_numeric_verify) {
        return tensorflow::PyoOrThrow(toco::MlirQuantizeModel(
            input_contents_txt_raw.ptr(), disable_per_channel, fully_quantize,
-            inference_type));
+            inference_type, enable_numeric_verify));
      },
      py::arg("input_contents_txt_raw"), py::arg("disable_per_channel") = false,
      py::arg("fully_quantize") = true, py::arg("inference_type") = 9,
+      py::arg("enable_numeric_verify") = false,
      R"pbdoc(
        Returns a quantized model.
      )pbdoc");