diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/mul.cc b/tensorflow/lite/micro/kernels/cmsis-nn/mul.cc index b11fffefacf..d746166ebd9 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/mul.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/mul.cc @@ -50,14 +50,16 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, TF_LITE_ENSURE_EQ(context, input1->type, input2->type); - TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized( - context, params->activation, output, &data->output_activation_min, - &data->output_activation_max)); + if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) { + TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized( + context, params->activation, output, &data->output_activation_min, + &data->output_activation_max)); - double real_multiplier = - input1->params.scale * input2->params.scale / output->params.scale; - QuantizeMultiplier(real_multiplier, &data->output_multiplier, - &data->output_shift); + double real_multiplier = + input1->params.scale * input2->params.scale / output->params.scale; + QuantizeMultiplier(real_multiplier, &data->output_multiplier, + &data->output_shift); + } return kTfLiteOk; } diff --git a/tensorflow/lite/micro/kernels/mul.cc b/tensorflow/lite/micro/kernels/mul.cc index e89b58a3f62..fb47728a1a4 100644 --- a/tensorflow/lite/micro/kernels/mul.cc +++ b/tensorflow/lite/micro/kernels/mul.cc @@ -50,11 +50,11 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, TF_LITE_ENSURE_EQ(context, input1->type, input2->type); - TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized( - context, params->activation, output, &data->output_activation_min, - &data->output_activation_max)); - if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) { + TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized( + context, params->activation, output, &data->output_activation_min, + &data->output_activation_max)); + double real_multiplier = static_cast(input1->params.scale) * static_cast(input2->params.scale) / static_cast(output->params.scale); @@ -138,7 +138,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInput2Tensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - CalculateOpData(context, node, params, &data); + TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, &data)); switch (input1->type) { case kTfLiteUInt8: