Stop CalculateActivationRangeQuantized being called for float MUL op.

This error was being masked by not using TF_LITE_ENSURE_STATUS on the output of mul::CalculateOpData.

PiperOrigin-RevId: 314959685
Change-Id: I458787eac0548a76590b1885c90eaa4e6517fc74
This commit is contained in:
A. Unique TensorFlower 2020-06-05 11:06:57 -07:00 committed by TensorFlower Gardener
parent faa0eac8a8
commit 1c7fb569b4
2 changed files with 14 additions and 12 deletions

View File

@ -50,14 +50,16 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
context, params->activation, output, &data->output_activation_min,
&data->output_activation_max));
if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
context, params->activation, output, &data->output_activation_min,
&data->output_activation_max));
double real_multiplier =
input1->params.scale * input2->params.scale / output->params.scale;
QuantizeMultiplier(real_multiplier, &data->output_multiplier,
&data->output_shift);
double real_multiplier =
input1->params.scale * input2->params.scale / output->params.scale;
QuantizeMultiplier(real_multiplier, &data->output_multiplier,
&data->output_shift);
}
return kTfLiteOk;
}

View File

@ -50,11 +50,11 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
context, params->activation, output, &data->output_activation_min,
&data->output_activation_max));
if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
context, params->activation, output, &data->output_activation_min,
&data->output_activation_max));
double real_multiplier = static_cast<double>(input1->params.scale) *
static_cast<double>(input2->params.scale) /
static_cast<double>(output->params.scale);
@ -138,7 +138,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInput2Tensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
CalculateOpData(context, node, params, &data);
TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, &data));
switch (input1->type) {
case kTfLiteUInt8: