Allow rounding errors in the numeric_verify check

PiperOrigin-RevId: 285905819
Change-Id: I0b5fd98d0f76e93e4ccc399b6ebe84f229a7364a
This commit is contained in:
Feng Liu 2019-12-16 20:48:39 -08:00 committed by TensorFlower Gardener
parent 8e42b57fc4
commit 43fe1b1493

View File

@ -142,15 +142,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
float dequant = GetTensorData<float>(dequantized)[i];
float reference = GetTensorData<float>(op_context.ref)[i];
if (std::abs(reference - dequant) / (reference + 1e-8) >
op_data->tolerance) {
float diff = std::abs(reference - dequant);
float error = diff / (reference + 1e-8);
// It is fine if the error is introduced by rounding so the diff will be
// smaller than `scale`.
if (diff > op_context.input->params.scale && error > op_data->tolerance) {
context->ReportError(context,
"Mismatch: %f is quantized to %d with (%f, %d). "
"abs((%f - %f) / %f) > %f (tolerance).\n",
"abs((%f - %f) / %f) = %f > %f (tolerance).\n",
reference, value, op_context.input->params.scale,
op_context.input->params.zero_point, reference,
dequant, reference, op_data->tolerance);
dequant, reference, error, op_data->tolerance);
return kTfLiteError;
}
}