Add enable_numeric_verify in TFLite converter for quantization numeric verification.

This CL adds experimental enable_numeric_verify in TFLite converter.mlir_quantize() for quantization numeric verification. This pass adds NumericVerify ops for debugging during full-integer quantization, and NumericVerify op will output the differences between the original float & the dequantized values of the quantized values. The main characteristic of NumericVerify op is the output tensor name has the following format:

  NumericVerify/{tensor_name}:{tensor_id}

where tensor_name and tensor_id are from the original quantized op's activation tensor (the first input tensor of the NumericVerify op) for the purpose of debugging. It provides users the easy way to search for tensors in the debugging tools.

If you turn on log_if_failed in NumericVerify op, then it will output logs & throw errors when there exist any errors > tolerance.

PiperOrigin-RevId: 348740984
Change-Id: Ifb965d2237e827a747d1888cbd21af311d42e4a1
This commit is contained in:
Jae H. Yoo 2020-12-22 21:40:53 -08:00 committed by TensorFlower Gardener
parent 917ebe0008
commit ca3e7f55d9
19 changed files with 237 additions and 91 deletions

View File

@ -944,7 +944,22 @@ BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
float tolerance = op.tolerance().convertToFloat();
return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results);
bool log_if_failed = op.log_if_failed();
auto fbb = absl::make_unique<flexbuffers::Builder>();
fbb->Map([&]() {
fbb->Float("tolerance", tolerance);
fbb->Bool("log_if_failed", log_if_failed);
});
fbb->Finish();
auto f = std::unique_ptr<flexbuffers::Builder>(fbb.release());
auto custom_option = f->GetBuffer();
auto opcode_index =
GetOpcodeIndex("NumericVerify", tflite::BuiltinOperator_CUSTOM);
return tflite::CreateOperator(
builder_, opcode_index, builder_.CreateVector(operands),
builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
/*builtin_options=*/0, builder_.CreateVector<uint8_t>(custom_option),
tflite::CustomOptionsFormat_FLEXBUFFERS);
}
BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
@ -1408,6 +1423,17 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
for (auto val : inst.getResults()) {
std::string name = UniqueName(val);
// For "tfl.numeric_verify" op, the name is used to find out the original
// activation tensor rather than its own unique name in the visualization
// or debugging tools.
auto builtin_code = GetBuiltinOpCode(&inst);
if (!builtin_code && dyn_cast<mlir::TFL::NumericVerifyOp>(&inst)) {
// The first operand is the quantized activation, the target of this
// NumericVerify op.
auto quantized_op_val = inst.getOperands().front();
name = "NumericVerify/" + UniqueName(quantized_op_val) + ":" +
std::to_string(tensor_index_map[quantized_op_val]);
}
if (!build_tensor_and_buffer(val, name)) return llvm::None;
}

View File

@ -4257,6 +4257,11 @@ def TFL_NumericVerifyOp : Op<TFL_Dialect, "NumericVerify", [
let description = [{
The NumericVerify op is a debugging op to verify the numericals of the two
activations. It is a custom op in TFLite.
If log_if_failed is true, the NumericVerify op calculates statistics on
differences between float and quantized activations, output
logs, set differences to the output tensors, and throws an error if errors
above tolerance exist. If log_if_failed = false, then it doesn't care about
errors.
}];
let arguments = (ins
@ -4264,10 +4269,11 @@ def TFL_NumericVerifyOp : Op<TFL_Dialect, "NumericVerify", [
TFL_TensorOf<[F32]>:$ref,
// Attributes
DefaultValuedAttr<F32Attr, "0.1">:$tolerance
DefaultValuedAttr<F32Attr, "0.1">:$tolerance,
DefaultValuedAttr<BoolAttr, "false">:$log_if_failed
);
let results = (outs);
let results = (outs TFL_FpTensor:$output);
}
// SVDF op.

View File

@ -44,7 +44,7 @@ TfLiteStatus QuantizeModel(
const std::unordered_set<std::string>& operator_names,
bool disable_per_channel, bool fully_quantize,
flatbuffers::FlatBufferBuilder* builder,
tflite::ErrorReporter* error_reporter) {
tflite::ErrorReporter* error_reporter, bool verify_numeric) {
// TODO(b/142502494): remove this restriction by improving the `emit_adaptor`
// flag
if (input_type != output_type) {
@ -91,8 +91,10 @@ TfLiteStatus QuantizeModel(
quant_specs.inference_type = input_tf_type;
}
quant_specs.verify_numeric = verify_numeric;
pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs));
pm.addPass(TFL::CreateQuantizePass());
pm.addPass(TFL::CreateQuantizePass(verify_numeric));
pm.addPass(TFL::CreatePostQuantizePass(emit_adaptor));
if (failed(pm.run(module.get()))) {

View File

@ -36,7 +36,7 @@ TfLiteStatus QuantizeModel(
const std::unordered_set<std::string>& operator_names,
bool disable_per_channel, bool fully_quantize,
flatbuffers::FlatBufferBuilder* builder,
tflite::ErrorReporter* error_reporter);
tflite::ErrorReporter* error_reporter, bool verify_numeric = false);
} // namespace lite
} // namespace mlir

View File

@ -124,6 +124,10 @@ struct QuantizationSpecs {
return 0;
}
}
// Whether add the NumericVerify ops to verify numbers before and after
// quantization.
bool verify_numeric = false;
};
// Parses the command line flag strings to the quantization specification for

View File

@ -175,12 +175,14 @@ struct QuantizationPattern : public RewritePattern {
using BaseType = QuantizationPattern<ConcretTy, Q, DQ, VERIFIER>;
explicit QuantizationPattern(MLIRContext* context, bool enable_verify,
float error_tolerance, bool single_layer_verify)
float error_tolerance, bool single_layer_verify,
bool log_if_failed = false)
// Set the score to a large number so it is always preferred.
: RewritePattern(DQ::getOperationName(), 300, context),
enable_verify(enable_verify),
error_tolerance(error_tolerance),
single_layer_verify(single_layer_verify) {}
single_layer_verify(single_layer_verify),
log_if_failed(log_if_failed) {}
LogicalResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override {
@ -312,10 +314,11 @@ struct QuantizationPattern : public RewritePattern {
}
rewriter.setInsertionPointAfter(new_op);
FloatAttr tolerance = rewriter.getF32FloatAttr(error_tolerance);
BoolAttr log = rewriter.getBoolAttr(log_if_failed);
// Verify the quantized value by sending the result to the verifier.
rewriter.create<VERIFIER>(quantized_op->getLoc(),
new_op->getResult(i),
quantized_op->getResult(i), tolerance);
rewriter.create<VERIFIER>(
quantized_op->getLoc(), new_op->getResult(i).getType(),
new_op->getResult(i), quantized_op->getResult(i), tolerance, log);
if (single_layer_verify) continue;
@ -341,6 +344,7 @@ struct QuantizationPattern : public RewritePattern {
bool enable_verify;
float error_tolerance;
bool single_layer_verify;
bool log_if_failed;
};
// Converts quantize ops with unsigned quantized types to these with signed

View File

@ -24,13 +24,20 @@
// CHECK-NEXT: scale: [ 0.1 ],
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "NumericVerify/arg1:1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 0 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 1, 0 ],
// CHECK-NEXT: outputs: [ ],
// CHECK-NEXT: custom_options: [ 205, 204, 204, 61 ]
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: custom_options:
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
@ -42,16 +49,18 @@
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 3
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:}
func @main(%arg0: tensor<4xf32>, %arg1: tensor<4x!quant.uniform<u8:f32, 0.1>>) -> tensor<4xf32> {
"tfl.NumericVerify"(%arg1, %arg0) {tolerance = 0.1 : f32} : (tensor<4x!quant.uniform<u8:f32, 0.1>>, tensor<4xf32>) -> ()
"tfl.NumericVerify"(%arg1, %arg0) {tolerance = 0.1 : f32} : (tensor<4x!quant.uniform<u8:f32, 0.1>>, tensor<4xf32>) -> (tensor<4xf32>)
return %arg0 : tensor<4xf32>
}

View File

@ -1,5 +1,5 @@
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize | FileCheck %s
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize -tfl-numeric-verify | FileCheck --check-prefix=DEBUG %s
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize -tfl-numeric-verify -tfl-log-if-failed | FileCheck --check-prefix=DEBUG %s
// CHECK-LABEL: QuantizeFloatConst
func @QuantizeFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
@ -76,7 +76,7 @@ func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>
// DEBUG: %[[act:.*]] = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x224x224x3xf32>
// DEBUG: %[[f_conv:.*]] = "tfl.conv_2d"(%[[act]], %[[wt]], %[[bias]])
// DEBUG: %[[q_conv:.*]] = "tfl.conv_2d"
// DEBUG: "tfl.NumericVerify"(%[[q_conv]], %[[f_conv]]) {tolerance = 5.000000e+00 : f32}
// DEBUG: "tfl.NumericVerify"(%[[q_conv]], %[[f_conv]]) {log_if_failed = true, tolerance = 5.000000e+00 : f32}
// DEBUG: return %[[q_conv]] : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
}

View File

@ -54,7 +54,8 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
quant_specs.default_ranges.second.getValueOr(0.0),
quant_specs.IsSignedInferenceType()));
}
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateQuantizePass());
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateQuantizePass(quant_specs.verify_numeric));
bool emit_quant_adaptor_ops =
quant_specs.inference_type != quant_specs.inference_input_type;
pass_manager->addNestedPass<mlir::FuncOp>(

View File

@ -48,7 +48,8 @@ std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
std::unique_ptr<OperationPass<ModuleOp>> CreateLowerStaticTensorListPass();
// Creates an instance of the TensorFlow Lite dialect Quantize pass.
std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass();
std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass(
bool verify_numeric = false);
// Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareQuantizePass(

View File

@ -55,6 +55,13 @@ static llvm::cl::opt<bool> enable_single_layer_verify(
"`-tfl-numeric-verify` is set."),
llvm::cl::init(true));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> enable_log_if_failed(
"tfl-log-if-failed", llvm::cl::value_desc("bool"),
llvm::cl::desc("Whether verify numericals with thresholding "
"tolerance. Valid when `-tfl-numeric-verify` is set."),
llvm::cl::init(false));
namespace mlir {
namespace TFL {
@ -67,16 +74,26 @@ namespace {
struct TFLFullQuantization
: public quant::QuantizationPattern<TFLFullQuantization, QuantizeOp,
DequantizeOp, NumericVerifyOp> {
explicit TFLFullQuantization(MLIRContext* ctx, bool verify_numeric,
float tolerance, bool verify_single_layer)
: BaseType(ctx, verify_numeric, tolerance, verify_single_layer) {}
explicit TFLFullQuantization(MLIRContext* ctx, bool verify_numeric_flag,
float tolerance, bool verify_single_layer,
bool log_if_failed_flag = false)
: BaseType(ctx, verify_numeric_flag, tolerance, verify_single_layer,
log_if_failed_flag) {}
static bool AllowHybridOperand() { return false; }
static bool AllowHybridResult() { return false; }
};
// Applies quantization on the model in TFL dialect.
struct QuantizePass : public PassWrapper<QuantizePass, FunctionPass> {
public:
// Constructor used by manually creating the pass.
explicit QuantizePass(bool verify_numeric_flag = false)
: verify_numeric(verify_numeric_flag) {}
void runOnFunction() override;
private:
bool verify_numeric;
};
#include "tensorflow/compiler/mlir/lite/transforms/generated_quantize.inc"
@ -87,14 +104,15 @@ void QuantizePass::runOnFunction() {
auto* ctx = func.getContext();
TFL::populateWithGenerated(ctx, patterns);
patterns.insert<TFLFullQuantization>(
ctx, enable_numeric_verify, error_tolerance, enable_single_layer_verify);
ctx, enable_numeric_verify || verify_numeric, error_tolerance,
enable_single_layer_verify, enable_log_if_failed);
applyPatternsAndFoldGreedily(func, std::move(patterns));
}
} // namespace
// Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass.
std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass() {
return std::make_unique<QuantizePass>();
std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass(bool verify_numeric) {
return std::make_unique<QuantizePass>(verify_numeric);
}
static PassRegistration<QuantizePass> pass(

View File

@ -38,8 +38,9 @@ namespace custom {
namespace numeric_verify {
static constexpr const char kToleranceStr[] = "tolerance";
static constexpr const char kDebugModeStr[] = "debug_mode";
static constexpr const char kLogIfFailedStr[] = "log_if_failed";
static constexpr const int kTemporaryDequantizedTensor = 0;
static constexpr const int kOutputTensor = 0;
struct OpContext {
OpContext(TfLiteContext* context, TfLiteNode* node) {
@ -61,7 +62,7 @@ struct OpData {
bool float_input_initialized;
int cache_tensor_id = kTensorNotAllocated;
// This boolean value is for controlling the behavior of numeric verify op.
bool debug_mode;
bool log_if_failed;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@ -71,9 +72,9 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
const float tolerance = m[kToleranceStr].AsFloat();
const bool debug_mode = m[kDebugModeStr].AsBool();
const bool log_if_failed = m[kLogIfFailedStr].AsBool();
op_data->tolerance = tolerance;
op_data->debug_mode = debug_mode;
op_data->log_if_failed = log_if_failed;
return op_data;
}
@ -84,13 +85,11 @@ void Free(TfLiteContext* context, void* buffer) {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
OpContext op_context(context, node);
const int num_output = (op_data->debug_mode) ? 1 : 0;
TF_LITE_ENSURE_EQ(context, NumOutputs(node), num_output);
TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 ||
op_context.input->type == kTfLiteInt8 ||
op_context.input->type == kTfLiteInt16 ||
@ -118,15 +117,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context, dequantized,
TfLiteIntArrayCopy(op_context.input->dims)));
if (op_data->debug_mode) {
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, num_output - 1,
&op_context.output));
op_context.output->type = kTfLiteFloat32;
op_context.output->allocation_type = kTfLiteArenaRwPersistent;
return context->ResizeTensor(context, op_context.output,
TfLiteIntArrayCopy(op_context.input->dims));
}
return kTfLiteOk;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputTensor, &op_context.output));
op_context.output->type = kTfLiteFloat32;
op_context.output->allocation_type = kTfLiteArenaRwPersistent;
return context->ResizeTensor(context, op_context.output,
TfLiteIntArrayCopy(op_context.input->dims));
}
static int32_t GetQuantizedValue(const OpContext& op_context, int index) {
@ -165,22 +161,37 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
op_data->float_input_initialized = true;
}
// If the debug_mode is on, we don't throw any errors.
// We just calculate difference between float and quantized values, letting
// python debugger deal with the information.
if (op_data->debug_mode || op_data->tolerance < 0.1) {
const int num_output = (op_data->debug_mode) ? 1 : 0;
const int n = NumElements(dequantized);
if (op_data->debug_mode) {
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, num_output - 1,
&op_context.output));
auto output_data = GetTensorData<float>(op_context.output);
for (int i = 0; i < n; ++i) {
float dequant = GetTensorData<float>(dequantized)[i];
float reference = GetTensorData<float>(op_context.ref)[i];
output_data[i] = dequant - reference;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputTensor, &op_context.output));
auto output_data = GetTensorData<float>(op_context.output);
// If log_if_failed is on, calculate differences between float and
// quantized values, their statistics and output logs.
// Throw errors if any diff greater than tolerance exists.
const int n = NumElements(dequantized);
if (op_data->log_if_failed && op_data->tolerance >= 0.1) {
// Verify the dequantized output.
auto max_diff = op_data->tolerance * op_context.input->params.scale;
for (int i = 0; i < n; ++i) {
int32_t value = GetQuantizedValue(op_context, i);
float dequant = GetTensorData<float>(dequantized)[i];
float reference = GetTensorData<float>(op_context.ref)[i];
output_data[i] = dequant - reference;
float diff = std::abs(output_data[i]);
if (diff > max_diff) {
TF_LITE_KERNEL_LOG(
context,
"Mismatch: %f is quantized to %d with (%f, %d). "
"abs(%f - %f) = %f > %f (tolerance) range percentage %f.\n",
reference, value, op_context.input->params.scale,
op_context.input->params.zero_point, reference, dequant, diff,
max_diff, op_data->tolerance);
return kTfLiteError;
}
}
} else {
// If tolerance is small or log_if_failed is off, then we only care about
// statistics.
// These statistics logging was added to identify some errors in practice.
std::vector<double> diffs, temp;
diffs.reserve(n);
@ -191,6 +202,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
float dequant = GetTensorData<float>(dequantized)[i];
float reference = GetTensorData<float>(op_context.ref)[i];
diffs[i] = static_cast<double>(dequant - reference);
output_data[i] = dequant - reference;
}
double mean =
std::accumulate(diffs.begin(), diffs.end(), 0.0) / diffs.size();
@ -208,26 +220,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
"std: %f, mean: %f, max_diff: %f (scale: %f, zero_point: %d).\n", std,
mean, max_diff, op_context.input->params.scale,
op_context.input->params.zero_point);
return kTfLiteOk;
} else {
// Verify the dequantized output.
auto max_diff = op_data->tolerance * op_context.input->params.scale;
for (int i = 0; i < NumElements(op_context.ref); ++i) {
int32_t value = GetQuantizedValue(op_context, i);
float dequant = GetTensorData<float>(dequantized)[i];
float reference = GetTensorData<float>(op_context.ref)[i];
float diff = std::abs(reference - dequant);
if (diff > max_diff) {
TF_LITE_KERNEL_LOG(
context,
"Mismatch: %f is quantized to %d with (%f, %d). "
"abs(%f - %f) = %f > %f (tolerance) range percentage %f.\n",
reference, value, op_context.input->params.scale,
op_context.input->params.zero_point, reference, dequant, diff,
max_diff, op_data->tolerance);
return kTfLiteError;
}
}
}
return kTfLiteOk;
}

View File

@ -45,21 +45,19 @@ class NumericVerifyOpModel : public SingleOpModel {
public:
NumericVerifyOpModel(TensorType type, std::initializer_list<int> shape,
float scale, int32_t zero_point, int version,
float tolerance = 5.0, bool debug_mode = false) {
float tolerance = 5.0, bool log_if_failed = true) {
const TensorData input_tensor_data = {type, shape, 0, 0, scale, zero_point};
input_ = AddInput(input_tensor_data);
ref_ = AddInput({TensorType_FLOAT32, shape});
if (debug_mode) {
// The output tensor has the same shape with that of the input tensor.
output_ = AddOutput({TensorType_FLOAT32, shape});
}
// The output tensor has the same shape with that of the input tensor.
output_ = AddOutput({TensorType_FLOAT32, shape});
std::vector<uint8_t> custom_options(sizeof(float));
flexbuffers::Builder fbb;
fbb.Map([&]() {
fbb.Float("tolerance", tolerance);
fbb.Bool("debug_mode", debug_mode);
fbb.Bool("log_if_failed", log_if_failed);
});
fbb.Finish();
@ -135,7 +133,7 @@ TEST(NumericVerifyOpFailedTest, Int8) {
TEST(NumericVerifyOpDebugModeTest, Int8) {
// [-63.5, 64] -> scale=0.5, zero_point=1 for INT8
NumericVerifyOpModel m(TensorType_INT8, {2, 5}, 0.5, -1, 2, 5.0, true);
NumericVerifyOpModel m(TensorType_INT8, {2, 5}, 0.5, -1, 2, 5.0, false);
// The 5th element is set to 0.
m.SetInputs<int8_t>({-128, -127, -126, -125, -124, 0, 124, 125, 126, 127},

View File

@ -126,7 +126,8 @@ class ConverterError(Exception):
def mlir_quantize(input_data_str,
disable_per_channel=False,
fully_quantize=False,
inference_type=_types_pb2.INT8):
inference_type=_types_pb2.INT8,
enable_numeric_verify=False):
"""Quantize `input_data_str` with calibration results.
Args:
@ -137,6 +138,8 @@ def mlir_quantize(input_data_str,
fully_quantize: Bool indicating whether to fully quantize the model. Besides
model body, the input/output will be quantized as well.
inference_type: Data type for the activations. The default value is int8.
enable_numeric_verify: Experimental. Subject to change. Bool indicating
whether to add NumericVerify ops into the debug mode quantized model.
Returns:
Quantized model in serialized form (e.g. a TFLITE model) with floating-point
@ -145,7 +148,8 @@ def mlir_quantize(input_data_str,
return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str,
disable_per_channel,
fully_quantize,
inference_type)
inference_type,
enable_numeric_verify)
def mlir_sparsify(input_data_str):

View File

@ -668,6 +668,82 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
self.assertEqual(output_details[0]['dtype'], expected_ceil_dtype)
self.assertEqual(output_details[1]['dtype'], expected_dtype)
@test_util.run_v2_only
def testNewQuantizerNumericVerificationDebugMode(self):
"""Test the model quantized by the new converter with numeric verify ops."""
func, calibration_gen = self._getIntegerQuantizeModel()
quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
quantized_converter.target_spec.supported_ops = [
lite.OpsSet.TFLITE_BUILTINS_INT8
]
quantized_converter.representative_dataset = calibration_gen
# Create a TFLite model with new quantizer.
quantized_converter.optimizations = [lite.Optimize.DEFAULT]
quantized_converter._experimental_new_quantizer = True
production_tflite = quantized_converter.convert()
# Create a TFLite model with new quantizer and numeric verify ops.
quantized_converter._experimental_calibrate_only = True
calibrated = quantized_converter.convert()
debug_mode_tflite = mlir_quantize(calibrated, enable_numeric_verify=True)
# Check if adding debug mode should output a different flatbuffer.
self.assertNotEqual(production_tflite, debug_mode_tflite)
# Check if newly added ops are numeric verify ops.
input_data = tf.constant(
np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32))
def examine_tflite_model(tflite_content, input_data):
interpreter = Interpreter(model_content=tflite_content)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
interpreter.set_tensor(input_details[0]['index'], input_data.numpy())
interpreter.invoke()
tensor_details = interpreter.get_tensor_details()
return {
details['name']: interpreter.get_tensor(details['index'])
for details in interpreter.get_tensor_details()
}, tensor_details
tflite_result, _ = examine_tflite_model(production_tflite, input_data)
debug_mode_tflite_result, debug_tensor_details = examine_tflite_model(
debug_mode_tflite, input_data)
# MLIR-based quantizer should output flatbuffer model with `tfl.quantize`.
num_production_quantize_ops = len([
None for output_tensor_name in tflite_result
if 'tfl.quantize' in output_tensor_name
])
self.assertEqual(num_production_quantize_ops, 1)
# MLIR-based quantizer should output flatbuffer model with `tfl.quantize`.
num_debug_quantize_ops = len([
None for output_tensor_name in debug_mode_tflite_result
if 'tfl.quantize' in output_tensor_name
])
# Two numbers should be equal.
self.assertEqual(num_production_quantize_ops, num_debug_quantize_ops)
# DebugMode TFLite flatbuffer should have NumericVerifyOps more than zero.
# The name has the prefix "NumericVerify/{name}:{id}
# where {name} is the tensor name of the original quantized op's activation,
# and {id} is its tensor id.
num_debug_ops = 0
for output_tensor_name in debug_mode_tflite_result:
if 'NumericVerify' in output_tensor_name:
pos_end_prefix = len('NumericVerify/')
pos_colon = output_tensor_name.rfind(':')
self.assertEqual('NumericVerify/',
output_tensor_name[:pos_end_prefix])
tensor_id = int(output_tensor_name[pos_colon+1:])
original_tensor_name = output_tensor_name[pos_end_prefix:pos_colon]
self.assertEqual(original_tensor_name,
debug_tensor_details[tensor_id]['name'])
num_debug_ops += 1
self.assertEqual(num_debug_ops, 1)
# The number of debug ops should be equal to that of quantized ops.
self.assertEqual(num_debug_ops, num_debug_quantize_ops)
class FromSavedModelTest(lite_v2_test_util.ModelTest):

View File

@ -44,12 +44,14 @@ def wrapped_get_potentially_supported_ops():
def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel,
fully_quantize, inference_type):
fully_quantize, inference_type,
enable_numeric_verify):
"""Wraps experimental mlir quantize model."""
return _pywrap_toco_api.ExperimentalMlirQuantizeModel(input_data_str,
disable_per_channel,
fully_quantize,
inference_type)
inference_type,
enable_numeric_verify)
def wrapped_experimental_mlir_sparsify(input_data_str):

View File

@ -236,7 +236,8 @@ PyObject* TocoGetPotentiallySupportedOps() {
}
PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
bool fully_quantize, int inference_type) {
bool fully_quantize, int inference_type,
bool enable_numeric_verify) {
using tflite::interpreter_wrapper::PythonErrorReporter;
char* buf = nullptr;
Py_ssize_t length;
@ -276,7 +277,7 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
auto status = mlir::lite::QuantizeModel(
*tflite_model, inference_io_type, inference_io_type,
inference_tensor_type, {}, disable_per_channel, fully_quantize, &builder,
error_reporter.get());
error_reporter.get(), enable_numeric_verify);
if (status != kTfLiteOk) {
error_reporter->exception();

View File

@ -44,7 +44,8 @@ PyObject* TocoGetPotentiallySupportedOps();
// is specified by the calibration data are not sufficient to quantize the
// model.
PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
bool fully_quantize, int inference_type);
bool fully_quantize, int inference_type,
bool enable_numeric_verify = false);
// Sparsifies model to encode sparse tensors with proper format. Throws error if
// sparsification fails.

View File

@ -57,13 +57,14 @@ PYBIND11_MODULE(_pywrap_toco_api, m) {
m.def(
"ExperimentalMlirQuantizeModel",
[](py::object input_contents_txt_raw, bool disable_per_channel,
bool fully_quantize, int inference_type) {
bool fully_quantize, int inference_type, bool enable_numeric_verify) {
return tensorflow::PyoOrThrow(toco::MlirQuantizeModel(
input_contents_txt_raw.ptr(), disable_per_channel, fully_quantize,
inference_type));
inference_type, enable_numeric_verify));
},
py::arg("input_contents_txt_raw"), py::arg("disable_per_channel") = false,
py::arg("fully_quantize") = true, py::arg("inference_type") = 9,
py::arg("enable_numeric_verify") = false,
R"pbdoc(
Returns a quantized model.
)pbdoc");