Added reference kernel fallback for conv

This adds a reference kernel fallback to the conv kernel's EvalQuantizedPerChannel routine, as suggested by @freddan80.
Author: Basit Ayantunde
Date:   2019-11-11 20:04:03 +01:00 (committed by GitHub)
Parent: 38a34e2dd3
Commit: d4e4bbc1a9


@@ -153,7 +153,6 @@ TfLiteStatus EvalQuantizedPerChannel(
     TfLiteContext* context, TfLiteNode* node, TfLiteConvParams* params,
     OpData* data, const TfLiteTensor* input, const TfLiteTensor* filter,
     const TfLiteTensor* bias, TfLiteTensor* output, TfLiteTensor* im2col) {
-#if defined(ARM_MATH_DSP) && defined(ARM_MATH_LOOPUNROLL)
   ConvParams op_params;
   op_params.input_offset = -input->params.zero_point;
   op_params.output_offset = output->params.zero_point;
@@ -164,6 +163,8 @@ TfLiteStatus EvalQuantizedPerChannel(
   op_params.padding_values.height = data->padding.height;
   op_params.padding_values.width = data->padding.width;
+#if defined(ARM_MATH_DSP) && defined(ARM_MATH_LOOPUNROLL)
   RuntimeShape filter_shape = GetTensorShape(filter);
   RuntimeShape input_shape = GetTensorShape(input);
   RuntimeShape output_shape = GetTensorShape(output);
@@ -235,7 +236,15 @@ TfLiteStatus EvalQuantizedPerChannel(
     }
   }
 #else
-#error ARM_MATH_DSP and ARM_MATH_LOOPUNROLL must be set
+  reference_integer_ops::ConvPerChannel(
+      op_params, data->per_channel_output_multiplier,
+      data->per_channel_output_shift, GetTensorShape(input),
+      GetTensorData<int8>(input), GetTensorShape(filter),
+      GetTensorData<int8>(filter), GetTensorShape(bias),
+      GetTensorData<int32>(bias), GetTensorShape(output),
+      GetTensorData<int8>(output));
 #endif
   return kTfLiteOk;
 }
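
For readers skimming the hunks above, the resulting control flow of EvalQuantizedPerChannel looks roughly like the sketch below. This is a condensed outline rather than the exact file contents: the CMSIS-NN path and most of the op_params setup are elided, and only identifiers that already appear in the diff are used.

TfLiteStatus EvalQuantizedPerChannel(
    TfLiteContext* context, TfLiteNode* node, TfLiteConvParams* params,
    OpData* data, const TfLiteTensor* input, const TfLiteTensor* filter,
    const TfLiteTensor* bias, TfLiteTensor* output, TfLiteTensor* im2col) {
  // op_params is now populated unconditionally so that both branches can
  // consume it.
  ConvParams op_params;
  op_params.input_offset = -input->params.zero_point;
  op_params.output_offset = output->params.zero_point;
  // ... remaining op_params fields (padding, strides, etc.) are set here,
  // as in the surrounding code not shown in the diff ...

#if defined(ARM_MATH_DSP) && defined(ARM_MATH_LOOPUNROLL)
  // Optimized path: dispatch to the CMSIS-NN convolution routines (elided).
#else
  // New fallback: the portable per-channel reference kernel replaces the
  // previous hard build failure (#error).
  reference_integer_ops::ConvPerChannel(
      op_params, data->per_channel_output_multiplier,
      data->per_channel_output_shift, GetTensorShape(input),
      GetTensorData<int8>(input), GetTensorShape(filter),
      GetTensorData<int8>(filter), GetTensorShape(bias),
      GetTensorData<int32>(bias), GetTensorShape(output),
      GetTensorData<int8>(output));
#endif
  return kTfLiteOk;
}

The practical effect is that targets built without ARM_MATH_DSP and ARM_MATH_LOOPUNROLL now compile and run the slower portable reference kernel instead of failing at build time.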