From ca3e7f55d9bb5aac89d6cdae3308d9f229a0f0db Mon Sep 17 00:00:00 2001 From: "Jae H. Yoo" Date: Tue, 22 Dec 2020 21:40:53 -0800 Subject: [PATCH] Add enable_numeric_verify in TFLite converter for quantization numeric verification. This CL adds experimental enable_numeric_verify in TFLite converter.mlir_quantize() for quantization numeric verification. This pass adds NumericVerify ops for debugging during full-integer quantization, and NumericVerify op will output the differences between the original float & the dequantized values of the quantized values. The main characteristic of NumericVerify op is the output tensor name has the following format: NumericVerify/{tensor_name}:{tensor_id} where tensor_name and tensor_id are from the original quantized op's activation tensor (the first input tensor of the NumericVerify op) for the purpose of debugging. It provides users the easy way to search for tensors in the debugging tools. If you turn on log_if_failed in NumericVerify op, then it will output logs & throw errors when there exist any errors > tolerance. PiperOrigin-RevId: 348740984 Change-Id: Ifb965d2237e827a747d1888cbd21af311d42e4a1 --- .../compiler/mlir/lite/flatbuffer_export.cc | 28 +++++- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 10 +- .../lite/quantization/lite/quantize_model.cc | 6 +- .../lite/quantization/lite/quantize_model.h | 2 +- .../lite/quantization/quantization_config.h | 4 + .../lite/quantization/quantization_utils.h | 14 ++- .../tests/mlir2flatbuffer/numeric_verify.mlir | 17 +++- .../compiler/mlir/lite/tests/quantize.mlir | 4 +- .../compiler/mlir/lite/tf_tfl_passes.cc | 3 +- .../compiler/mlir/lite/transforms/passes.h | 3 +- .../compiler/mlir/lite/transforms/quantize.cc | 30 ++++-- tensorflow/lite/kernels/numeric_verify.cc | 92 +++++++++---------- .../lite/kernels/numeric_verify_test.cc | 12 +-- tensorflow/lite/python/convert.py | 8 +- tensorflow/lite/python/lite_v2_test.py | 76 +++++++++++++++ tensorflow/lite/python/wrap_toco.py | 6 +- .../lite/toco/python/toco_python_api.cc | 5 +- tensorflow/lite/toco/python/toco_python_api.h | 3 +- .../python/lite/toco_python_api_wrapper.cc | 5 +- 19 files changed, 237 insertions(+), 91 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 5dc92382051..6192c988fd3 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -944,7 +944,22 @@ BufferOffset Translator::BuildNumericVerifyOperator( mlir::TFL::NumericVerifyOp op, const std::vector& operands, const std::vector& results) { float tolerance = op.tolerance().convertToFloat(); - return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results); + bool log_if_failed = op.log_if_failed(); + auto fbb = absl::make_unique(); + fbb->Map([&]() { + fbb->Float("tolerance", tolerance); + fbb->Bool("log_if_failed", log_if_failed); + }); + fbb->Finish(); + auto f = std::unique_ptr(fbb.release()); + auto custom_option = f->GetBuffer(); + auto opcode_index = + GetOpcodeIndex("NumericVerify", tflite::BuiltinOperator_CUSTOM); + return tflite::CreateOperator( + builder_, opcode_index, builder_.CreateVector(operands), + builder_.CreateVector(results), tflite::BuiltinOptions_NONE, + /*builtin_options=*/0, builder_.CreateVector(custom_option), + tflite::CustomOptionsFormat_FLEXBUFFERS); } BufferOffset Translator::BuildCustomOperator( @@ -1408,6 +1423,17 @@ Optional> Translator::BuildSubGraph( for (auto val : inst.getResults()) { std::string name = UniqueName(val); + // For "tfl.numeric_verify" op, the name is used to find out the original + // activation tensor rather than its own unique name in the visualization + // or debugging tools. + auto builtin_code = GetBuiltinOpCode(&inst); + if (!builtin_code && dyn_cast(&inst)) { + // The first operand is the quantized activation, the target of this + // NumericVerify op. + auto quantized_op_val = inst.getOperands().front(); + name = "NumericVerify/" + UniqueName(quantized_op_val) + ":" + + std::to_string(tensor_index_map[quantized_op_val]); + } if (!build_tensor_and_buffer(val, name)) return llvm::None; } diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 161c2e03ebd..25514345e65 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -4257,6 +4257,11 @@ def TFL_NumericVerifyOp : Op:$ref, // Attributes - DefaultValuedAttr:$tolerance + DefaultValuedAttr:$tolerance, + DefaultValuedAttr:$log_if_failed ); - let results = (outs); + let results = (outs TFL_FpTensor:$output); } // SVDF op. diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 15f1e7b2516..8b99c1d58e8 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -44,7 +44,7 @@ TfLiteStatus QuantizeModel( const std::unordered_set& operator_names, bool disable_per_channel, bool fully_quantize, flatbuffers::FlatBufferBuilder* builder, - tflite::ErrorReporter* error_reporter) { + tflite::ErrorReporter* error_reporter, bool verify_numeric) { // TODO(b/142502494): remove this restriction by improving the `emit_adaptor` // flag if (input_type != output_type) { @@ -91,8 +91,10 @@ TfLiteStatus QuantizeModel( quant_specs.inference_type = input_tf_type; } + quant_specs.verify_numeric = verify_numeric; + pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs)); - pm.addPass(TFL::CreateQuantizePass()); + pm.addPass(TFL::CreateQuantizePass(verify_numeric)); pm.addPass(TFL::CreatePostQuantizePass(emit_adaptor)); if (failed(pm.run(module.get()))) { diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h index d60df56b473..50f41cc477e 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h @@ -36,7 +36,7 @@ TfLiteStatus QuantizeModel( const std::unordered_set& operator_names, bool disable_per_channel, bool fully_quantize, flatbuffers::FlatBufferBuilder* builder, - tflite::ErrorReporter* error_reporter); + tflite::ErrorReporter* error_reporter, bool verify_numeric = false); } // namespace lite } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h index 0e2f4906a7a..50ddc4306c8 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h @@ -124,6 +124,10 @@ struct QuantizationSpecs { return 0; } } + + // Whether add the NumericVerify ops to verify numbers before and after + // quantization. + bool verify_numeric = false; }; // Parses the command line flag strings to the quantization specification for diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 12e1dc6ba74..0ee01b5ad45 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -175,12 +175,14 @@ struct QuantizationPattern : public RewritePattern { using BaseType = QuantizationPattern; explicit QuantizationPattern(MLIRContext* context, bool enable_verify, - float error_tolerance, bool single_layer_verify) + float error_tolerance, bool single_layer_verify, + bool log_if_failed = false) // Set the score to a large number so it is always preferred. : RewritePattern(DQ::getOperationName(), 300, context), enable_verify(enable_verify), error_tolerance(error_tolerance), - single_layer_verify(single_layer_verify) {} + single_layer_verify(single_layer_verify), + log_if_failed(log_if_failed) {} LogicalResult matchAndRewrite(Operation* op, PatternRewriter& rewriter) const override { @@ -312,10 +314,11 @@ struct QuantizationPattern : public RewritePattern { } rewriter.setInsertionPointAfter(new_op); FloatAttr tolerance = rewriter.getF32FloatAttr(error_tolerance); + BoolAttr log = rewriter.getBoolAttr(log_if_failed); // Verify the quantized value by sending the result to the verifier. - rewriter.create(quantized_op->getLoc(), - new_op->getResult(i), - quantized_op->getResult(i), tolerance); + rewriter.create( + quantized_op->getLoc(), new_op->getResult(i).getType(), + new_op->getResult(i), quantized_op->getResult(i), tolerance, log); if (single_layer_verify) continue; @@ -341,6 +344,7 @@ struct QuantizationPattern : public RewritePattern { bool enable_verify; float error_tolerance; bool single_layer_verify; + bool log_if_failed; }; // Converts quantize ops with unsigned quantized types to these with signed diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir index 60360c7ded6..f97959b1564 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir @@ -24,13 +24,20 @@ // CHECK-NEXT: scale: [ 0.1 ], // CHECK-NEXT: zero_point: [ 0 ] // CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 3, +// CHECK-NEXT: name: "NumericVerify/arg1:1", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } // CHECK-NEXT: } ], // CHECK-NEXT: inputs: [ 0, 1 ], // CHECK-NEXT: outputs: [ 0 ], // CHECK-NEXT: operators: [ { // CHECK-NEXT: inputs: [ 1, 0 ], -// CHECK-NEXT: outputs: [ ], -// CHECK-NEXT: custom_options: [ 205, 204, 204, 61 ] +// CHECK-NEXT: outputs: [ 2 ], +// CHECK-NEXT: custom_options: // CHECK-NEXT: } ], // CHECK-NEXT: name: "main" // CHECK-NEXT: } ], @@ -42,16 +49,18 @@ // CHECK-NEXT: }, { // CHECK-EMPTY: // CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { // CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] // CHECK-NEXT: } ], // CHECK-NEXT: metadata: [ { // CHECK-NEXT: name: "min_runtime_version", -// CHECK-NEXT: buffer: 3 +// CHECK-NEXT: buffer: 4 // CHECK-NEXT: } ] // CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} func @main(%arg0: tensor<4xf32>, %arg1: tensor<4x!quant.uniform>) -> tensor<4xf32> { - "tfl.NumericVerify"(%arg1, %arg0) {tolerance = 0.1 : f32} : (tensor<4x!quant.uniform>, tensor<4xf32>) -> () + "tfl.NumericVerify"(%arg1, %arg0) {tolerance = 0.1 : f32} : (tensor<4x!quant.uniform>, tensor<4xf32>) -> (tensor<4xf32>) return %arg0 : tensor<4xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/quantize.mlir b/tensorflow/compiler/mlir/lite/tests/quantize.mlir index 6f42ae6293d..a9e5663ed11 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize.mlir @@ -1,5 +1,5 @@ // RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize | FileCheck %s -// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize -tfl-numeric-verify | FileCheck --check-prefix=DEBUG %s +// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize -tfl-numeric-verify -tfl-log-if-failed | FileCheck --check-prefix=DEBUG %s // CHECK-LABEL: QuantizeFloatConst func @QuantizeFloatConst() -> tensor<2x2x!quant.uniform> { @@ -76,7 +76,7 @@ func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform // DEBUG: %[[act:.*]] = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3xf32> // DEBUG: %[[f_conv:.*]] = "tfl.conv_2d"(%[[act]], %[[wt]], %[[bias]]) // DEBUG: %[[q_conv:.*]] = "tfl.conv_2d" -// DEBUG: "tfl.NumericVerify"(%[[q_conv]], %[[f_conv]]) {tolerance = 5.000000e+00 : f32} +// DEBUG: "tfl.NumericVerify"(%[[q_conv]], %[[f_conv]]) {log_if_failed = true, tolerance = 5.000000e+00 : f32} // DEBUG: return %[[q_conv]] : tensor<1x112x112x32x!quant.uniform> } diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index a7a8af9e3bd..34cc014ec59 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -54,7 +54,8 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs, quant_specs.default_ranges.second.getValueOr(0.0), quant_specs.IsSignedInferenceType())); } - pass_manager->addNestedPass(mlir::TFL::CreateQuantizePass()); + pass_manager->addNestedPass( + mlir::TFL::CreateQuantizePass(quant_specs.verify_numeric)); bool emit_quant_adaptor_ops = quant_specs.inference_type != quant_specs.inference_input_type; pass_manager->addNestedPass( diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 58e7c929b73..29a223e60b7 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -48,7 +48,8 @@ std::unique_ptr> CreatePrepareTFPass( std::unique_ptr> CreateLowerStaticTensorListPass(); // Creates an instance of the TensorFlow Lite dialect Quantize pass. -std::unique_ptr> CreateQuantizePass(); +std::unique_ptr> CreateQuantizePass( + bool verify_numeric = false); // Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass. std::unique_ptr> CreatePrepareQuantizePass( diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc index e2cce058f88..f8c686b5a7f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc @@ -55,6 +55,13 @@ static llvm::cl::opt enable_single_layer_verify( "`-tfl-numeric-verify` is set."), llvm::cl::init(true)); +// NOLINTNEXTLINE +static llvm::cl::opt enable_log_if_failed( + "tfl-log-if-failed", llvm::cl::value_desc("bool"), + llvm::cl::desc("Whether verify numericals with thresholding " + "tolerance. Valid when `-tfl-numeric-verify` is set."), + llvm::cl::init(false)); + namespace mlir { namespace TFL { @@ -67,16 +74,26 @@ namespace { struct TFLFullQuantization : public quant::QuantizationPattern { - explicit TFLFullQuantization(MLIRContext* ctx, bool verify_numeric, - float tolerance, bool verify_single_layer) - : BaseType(ctx, verify_numeric, tolerance, verify_single_layer) {} + explicit TFLFullQuantization(MLIRContext* ctx, bool verify_numeric_flag, + float tolerance, bool verify_single_layer, + bool log_if_failed_flag = false) + : BaseType(ctx, verify_numeric_flag, tolerance, verify_single_layer, + log_if_failed_flag) {} static bool AllowHybridOperand() { return false; } static bool AllowHybridResult() { return false; } }; // Applies quantization on the model in TFL dialect. struct QuantizePass : public PassWrapper { + public: + // Constructor used by manually creating the pass. + explicit QuantizePass(bool verify_numeric_flag = false) + : verify_numeric(verify_numeric_flag) {} + void runOnFunction() override; + + private: + bool verify_numeric; }; #include "tensorflow/compiler/mlir/lite/transforms/generated_quantize.inc" @@ -87,14 +104,15 @@ void QuantizePass::runOnFunction() { auto* ctx = func.getContext(); TFL::populateWithGenerated(ctx, patterns); patterns.insert( - ctx, enable_numeric_verify, error_tolerance, enable_single_layer_verify); + ctx, enable_numeric_verify || verify_numeric, error_tolerance, + enable_single_layer_verify, enable_log_if_failed); applyPatternsAndFoldGreedily(func, std::move(patterns)); } } // namespace // Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass. -std::unique_ptr> CreateQuantizePass() { - return std::make_unique(); +std::unique_ptr> CreateQuantizePass(bool verify_numeric) { + return std::make_unique(verify_numeric); } static PassRegistration pass( diff --git a/tensorflow/lite/kernels/numeric_verify.cc b/tensorflow/lite/kernels/numeric_verify.cc index ce1e491b1d0..45771cbd9b4 100644 --- a/tensorflow/lite/kernels/numeric_verify.cc +++ b/tensorflow/lite/kernels/numeric_verify.cc @@ -38,8 +38,9 @@ namespace custom { namespace numeric_verify { static constexpr const char kToleranceStr[] = "tolerance"; -static constexpr const char kDebugModeStr[] = "debug_mode"; +static constexpr const char kLogIfFailedStr[] = "log_if_failed"; static constexpr const int kTemporaryDequantizedTensor = 0; +static constexpr const int kOutputTensor = 0; struct OpContext { OpContext(TfLiteContext* context, TfLiteNode* node) { @@ -61,7 +62,7 @@ struct OpData { bool float_input_initialized; int cache_tensor_id = kTensorNotAllocated; // This boolean value is for controlling the behavior of numeric verify op. - bool debug_mode; + bool log_if_failed; }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -71,9 +72,9 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { const uint8_t* buffer_t = reinterpret_cast(buffer); const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); const float tolerance = m[kToleranceStr].AsFloat(); - const bool debug_mode = m[kDebugModeStr].AsBool(); + const bool log_if_failed = m[kLogIfFailedStr].AsBool(); op_data->tolerance = tolerance; - op_data->debug_mode = debug_mode; + op_data->log_if_failed = log_if_failed; return op_data; } @@ -84,13 +85,11 @@ void Free(TfLiteContext* context, void* buffer) { TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); OpData* op_data = reinterpret_cast(node->user_data); OpContext op_context(context, node); - const int num_output = (op_data->debug_mode) ? 1 : 0; - TF_LITE_ENSURE_EQ(context, NumOutputs(node), num_output); - TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 || op_context.input->type == kTfLiteInt8 || op_context.input->type == kTfLiteInt16 || @@ -118,15 +117,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context, dequantized, TfLiteIntArrayCopy(op_context.input->dims))); - if (op_data->debug_mode) { - TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, num_output - 1, - &op_context.output)); - op_context.output->type = kTfLiteFloat32; - op_context.output->allocation_type = kTfLiteArenaRwPersistent; - return context->ResizeTensor(context, op_context.output, - TfLiteIntArrayCopy(op_context.input->dims)); - } - return kTfLiteOk; + TF_LITE_ENSURE_OK( + context, GetOutputSafe(context, node, kOutputTensor, &op_context.output)); + op_context.output->type = kTfLiteFloat32; + op_context.output->allocation_type = kTfLiteArenaRwPersistent; + return context->ResizeTensor(context, op_context.output, + TfLiteIntArrayCopy(op_context.input->dims)); } static int32_t GetQuantizedValue(const OpContext& op_context, int index) { @@ -165,22 +161,37 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { op_data->float_input_initialized = true; } - // If the debug_mode is on, we don't throw any errors. - // We just calculate difference between float and quantized values, letting - // python debugger deal with the information. - if (op_data->debug_mode || op_data->tolerance < 0.1) { - const int num_output = (op_data->debug_mode) ? 1 : 0; - const int n = NumElements(dequantized); - if (op_data->debug_mode) { - TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, num_output - 1, - &op_context.output)); - auto output_data = GetTensorData(op_context.output); - for (int i = 0; i < n; ++i) { - float dequant = GetTensorData(dequantized)[i]; - float reference = GetTensorData(op_context.ref)[i]; - output_data[i] = dequant - reference; + TF_LITE_ENSURE_OK( + context, GetOutputSafe(context, node, kOutputTensor, &op_context.output)); + auto output_data = GetTensorData(op_context.output); + + // If log_if_failed is on, calculate differences between float and + // quantized values, their statistics and output logs. + // Throw errors if any diff greater than tolerance exists. + const int n = NumElements(dequantized); + if (op_data->log_if_failed && op_data->tolerance >= 0.1) { + // Verify the dequantized output. + auto max_diff = op_data->tolerance * op_context.input->params.scale; + for (int i = 0; i < n; ++i) { + int32_t value = GetQuantizedValue(op_context, i); + float dequant = GetTensorData(dequantized)[i]; + float reference = GetTensorData(op_context.ref)[i]; + output_data[i] = dequant - reference; + float diff = std::abs(output_data[i]); + if (diff > max_diff) { + TF_LITE_KERNEL_LOG( + context, + "Mismatch: %f is quantized to %d with (%f, %d). " + "abs(%f - %f) = %f > %f (tolerance) range percentage %f.\n", + reference, value, op_context.input->params.scale, + op_context.input->params.zero_point, reference, dequant, diff, + max_diff, op_data->tolerance); + return kTfLiteError; } } + } else { + // If tolerance is small or log_if_failed is off, then we only care about + // statistics. // These statistics logging was added to identify some errors in practice. std::vector diffs, temp; diffs.reserve(n); @@ -191,6 +202,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { float dequant = GetTensorData(dequantized)[i]; float reference = GetTensorData(op_context.ref)[i]; diffs[i] = static_cast(dequant - reference); + output_data[i] = dequant - reference; } double mean = std::accumulate(diffs.begin(), diffs.end(), 0.0) / diffs.size(); @@ -208,26 +220,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { "std: %f, mean: %f, max_diff: %f (scale: %f, zero_point: %d).\n", std, mean, max_diff, op_context.input->params.scale, op_context.input->params.zero_point); - return kTfLiteOk; - } else { - // Verify the dequantized output. - auto max_diff = op_data->tolerance * op_context.input->params.scale; - for (int i = 0; i < NumElements(op_context.ref); ++i) { - int32_t value = GetQuantizedValue(op_context, i); - float dequant = GetTensorData(dequantized)[i]; - float reference = GetTensorData(op_context.ref)[i]; - float diff = std::abs(reference - dequant); - if (diff > max_diff) { - TF_LITE_KERNEL_LOG( - context, - "Mismatch: %f is quantized to %d with (%f, %d). " - "abs(%f - %f) = %f > %f (tolerance) range percentage %f.\n", - reference, value, op_context.input->params.scale, - op_context.input->params.zero_point, reference, dequant, diff, - max_diff, op_data->tolerance); - return kTfLiteError; - } - } } return kTfLiteOk; } diff --git a/tensorflow/lite/kernels/numeric_verify_test.cc b/tensorflow/lite/kernels/numeric_verify_test.cc index e26f5607bb7..9e83000bef1 100644 --- a/tensorflow/lite/kernels/numeric_verify_test.cc +++ b/tensorflow/lite/kernels/numeric_verify_test.cc @@ -45,21 +45,19 @@ class NumericVerifyOpModel : public SingleOpModel { public: NumericVerifyOpModel(TensorType type, std::initializer_list shape, float scale, int32_t zero_point, int version, - float tolerance = 5.0, bool debug_mode = false) { + float tolerance = 5.0, bool log_if_failed = true) { const TensorData input_tensor_data = {type, shape, 0, 0, scale, zero_point}; input_ = AddInput(input_tensor_data); ref_ = AddInput({TensorType_FLOAT32, shape}); - if (debug_mode) { - // The output tensor has the same shape with that of the input tensor. - output_ = AddOutput({TensorType_FLOAT32, shape}); - } + // The output tensor has the same shape with that of the input tensor. + output_ = AddOutput({TensorType_FLOAT32, shape}); std::vector custom_options(sizeof(float)); flexbuffers::Builder fbb; fbb.Map([&]() { fbb.Float("tolerance", tolerance); - fbb.Bool("debug_mode", debug_mode); + fbb.Bool("log_if_failed", log_if_failed); }); fbb.Finish(); @@ -135,7 +133,7 @@ TEST(NumericVerifyOpFailedTest, Int8) { TEST(NumericVerifyOpDebugModeTest, Int8) { // [-63.5, 64] -> scale=0.5, zero_point=1 for INT8 - NumericVerifyOpModel m(TensorType_INT8, {2, 5}, 0.5, -1, 2, 5.0, true); + NumericVerifyOpModel m(TensorType_INT8, {2, 5}, 0.5, -1, 2, 5.0, false); // The 5th element is set to 0. m.SetInputs({-128, -127, -126, -125, -124, 0, 124, 125, 126, 127}, diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index 300ce8434e3..4bfd2dc5792 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -126,7 +126,8 @@ class ConverterError(Exception): def mlir_quantize(input_data_str, disable_per_channel=False, fully_quantize=False, - inference_type=_types_pb2.INT8): + inference_type=_types_pb2.INT8, + enable_numeric_verify=False): """Quantize `input_data_str` with calibration results. Args: @@ -137,6 +138,8 @@ def mlir_quantize(input_data_str, fully_quantize: Bool indicating whether to fully quantize the model. Besides model body, the input/output will be quantized as well. inference_type: Data type for the activations. The default value is int8. + enable_numeric_verify: Experimental. Subject to change. Bool indicating + whether to add NumericVerify ops into the debug mode quantized model. Returns: Quantized model in serialized form (e.g. a TFLITE model) with floating-point @@ -145,7 +148,8 @@ def mlir_quantize(input_data_str, return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel, fully_quantize, - inference_type) + inference_type, + enable_numeric_verify) def mlir_sparsify(input_data_str): diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index b0cea7d1306..6d261c9b5ba 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -668,6 +668,82 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): self.assertEqual(output_details[0]['dtype'], expected_ceil_dtype) self.assertEqual(output_details[1]['dtype'], expected_dtype) + @test_util.run_v2_only + def testNewQuantizerNumericVerificationDebugMode(self): + """Test the model quantized by the new converter with numeric verify ops.""" + func, calibration_gen = self._getIntegerQuantizeModel() + + quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func]) + quantized_converter.target_spec.supported_ops = [ + lite.OpsSet.TFLITE_BUILTINS_INT8 + ] + quantized_converter.representative_dataset = calibration_gen + + # Create a TFLite model with new quantizer. + quantized_converter.optimizations = [lite.Optimize.DEFAULT] + quantized_converter._experimental_new_quantizer = True + production_tflite = quantized_converter.convert() + # Create a TFLite model with new quantizer and numeric verify ops. + quantized_converter._experimental_calibrate_only = True + calibrated = quantized_converter.convert() + debug_mode_tflite = mlir_quantize(calibrated, enable_numeric_verify=True) + + # Check if adding debug mode should output a different flatbuffer. + self.assertNotEqual(production_tflite, debug_mode_tflite) + + # Check if newly added ops are numeric verify ops. + input_data = tf.constant( + np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)) + + def examine_tflite_model(tflite_content, input_data): + interpreter = Interpreter(model_content=tflite_content) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + interpreter.set_tensor(input_details[0]['index'], input_data.numpy()) + interpreter.invoke() + tensor_details = interpreter.get_tensor_details() + return { + details['name']: interpreter.get_tensor(details['index']) + for details in interpreter.get_tensor_details() + }, tensor_details + + tflite_result, _ = examine_tflite_model(production_tflite, input_data) + debug_mode_tflite_result, debug_tensor_details = examine_tflite_model( + debug_mode_tflite, input_data) + + # MLIR-based quantizer should output flatbuffer model with `tfl.quantize`. + num_production_quantize_ops = len([ + None for output_tensor_name in tflite_result + if 'tfl.quantize' in output_tensor_name + ]) + self.assertEqual(num_production_quantize_ops, 1) + # MLIR-based quantizer should output flatbuffer model with `tfl.quantize`. + num_debug_quantize_ops = len([ + None for output_tensor_name in debug_mode_tflite_result + if 'tfl.quantize' in output_tensor_name + ]) + # Two numbers should be equal. + self.assertEqual(num_production_quantize_ops, num_debug_quantize_ops) + # DebugMode TFLite flatbuffer should have NumericVerifyOps more than zero. + # The name has the prefix "NumericVerify/{name}:{id} + # where {name} is the tensor name of the original quantized op's activation, + # and {id} is its tensor id. + num_debug_ops = 0 + for output_tensor_name in debug_mode_tflite_result: + if 'NumericVerify' in output_tensor_name: + pos_end_prefix = len('NumericVerify/') + pos_colon = output_tensor_name.rfind(':') + self.assertEqual('NumericVerify/', + output_tensor_name[:pos_end_prefix]) + tensor_id = int(output_tensor_name[pos_colon+1:]) + original_tensor_name = output_tensor_name[pos_end_prefix:pos_colon] + self.assertEqual(original_tensor_name, + debug_tensor_details[tensor_id]['name']) + num_debug_ops += 1 + self.assertEqual(num_debug_ops, 1) + # The number of debug ops should be equal to that of quantized ops. + self.assertEqual(num_debug_ops, num_debug_quantize_ops) + class FromSavedModelTest(lite_v2_test_util.ModelTest): diff --git a/tensorflow/lite/python/wrap_toco.py b/tensorflow/lite/python/wrap_toco.py index 60b33cea8fd..acba925c4c5 100644 --- a/tensorflow/lite/python/wrap_toco.py +++ b/tensorflow/lite/python/wrap_toco.py @@ -44,12 +44,14 @@ def wrapped_get_potentially_supported_ops(): def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel, - fully_quantize, inference_type): + fully_quantize, inference_type, + enable_numeric_verify): """Wraps experimental mlir quantize model.""" return _pywrap_toco_api.ExperimentalMlirQuantizeModel(input_data_str, disable_per_channel, fully_quantize, - inference_type) + inference_type, + enable_numeric_verify) def wrapped_experimental_mlir_sparsify(input_data_str): diff --git a/tensorflow/lite/toco/python/toco_python_api.cc b/tensorflow/lite/toco/python/toco_python_api.cc index edcc1f805b4..98aaafdcb6c 100644 --- a/tensorflow/lite/toco/python/toco_python_api.cc +++ b/tensorflow/lite/toco/python/toco_python_api.cc @@ -236,7 +236,8 @@ PyObject* TocoGetPotentiallySupportedOps() { } PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, - bool fully_quantize, int inference_type) { + bool fully_quantize, int inference_type, + bool enable_numeric_verify) { using tflite::interpreter_wrapper::PythonErrorReporter; char* buf = nullptr; Py_ssize_t length; @@ -276,7 +277,7 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, auto status = mlir::lite::QuantizeModel( *tflite_model, inference_io_type, inference_io_type, inference_tensor_type, {}, disable_per_channel, fully_quantize, &builder, - error_reporter.get()); + error_reporter.get(), enable_numeric_verify); if (status != kTfLiteOk) { error_reporter->exception(); diff --git a/tensorflow/lite/toco/python/toco_python_api.h b/tensorflow/lite/toco/python/toco_python_api.h index df9d6e11bcf..14b80fcb1c0 100644 --- a/tensorflow/lite/toco/python/toco_python_api.h +++ b/tensorflow/lite/toco/python/toco_python_api.h @@ -44,7 +44,8 @@ PyObject* TocoGetPotentiallySupportedOps(); // is specified by the calibration data are not sufficient to quantize the // model. PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, - bool fully_quantize, int inference_type); + bool fully_quantize, int inference_type, + bool enable_numeric_verify = false); // Sparsifies model to encode sparse tensors with proper format. Throws error if // sparsification fails. diff --git a/tensorflow/python/lite/toco_python_api_wrapper.cc b/tensorflow/python/lite/toco_python_api_wrapper.cc index c5a5f63b2ac..8c2e889c127 100644 --- a/tensorflow/python/lite/toco_python_api_wrapper.cc +++ b/tensorflow/python/lite/toco_python_api_wrapper.cc @@ -57,13 +57,14 @@ PYBIND11_MODULE(_pywrap_toco_api, m) { m.def( "ExperimentalMlirQuantizeModel", [](py::object input_contents_txt_raw, bool disable_per_channel, - bool fully_quantize, int inference_type) { + bool fully_quantize, int inference_type, bool enable_numeric_verify) { return tensorflow::PyoOrThrow(toco::MlirQuantizeModel( input_contents_txt_raw.ptr(), disable_per_channel, fully_quantize, - inference_type)); + inference_type, enable_numeric_verify)); }, py::arg("input_contents_txt_raw"), py::arg("disable_per_channel") = false, py::arg("fully_quantize") = true, py::arg("inference_type") = 9, + py::arg("enable_numeric_verify") = false, R"pbdoc( Returns a quantized model. )pbdoc");