Added reference kernel fallback for conv
I added the reference kernel fallback for the conv's EvalQuantizedPerChannel routine, as suggested by @freddan80
This commit is contained in:
parent
38a34e2dd3
commit
d4e4bbc1a9
@ -153,7 +153,6 @@ TfLiteStatus EvalQuantizedPerChannel(
|
||||
TfLiteContext* context, TfLiteNode* node, TfLiteConvParams* params,
|
||||
OpData* data, const TfLiteTensor* input, const TfLiteTensor* filter,
|
||||
const TfLiteTensor* bias, TfLiteTensor* output, TfLiteTensor* im2col) {
|
||||
#if defined(ARM_MATH_DSP) && defined(ARM_MATH_LOOPUNROLL)
|
||||
ConvParams op_params;
|
||||
op_params.input_offset = -input->params.zero_point;
|
||||
op_params.output_offset = output->params.zero_point;
|
||||
@ -164,6 +163,8 @@ TfLiteStatus EvalQuantizedPerChannel(
|
||||
op_params.padding_values.height = data->padding.height;
|
||||
op_params.padding_values.width = data->padding.width;
|
||||
|
||||
#if defined(ARM_MATH_DSP) && defined(ARM_MATH_LOOPUNROLL)
|
||||
|
||||
RuntimeShape filter_shape = GetTensorShape(filter);
|
||||
RuntimeShape input_shape = GetTensorShape(input);
|
||||
RuntimeShape output_shape = GetTensorShape(output);
|
||||
@ -235,7 +236,15 @@ TfLiteStatus EvalQuantizedPerChannel(
|
||||
}
|
||||
}
|
||||
#else
|
||||
#error ARM_MATH_DSP and ARM_MATH_LOOPUNROLL must be set
|
||||
|
||||
reference_integer_ops::ConvPerChannel(
|
||||
op_params, data->per_channel_output_multiplier,
|
||||
data->per_channel_output_shift, GetTensorShape(input),
|
||||
GetTensorData<int8>(input), GetTensorShape(filter),
|
||||
GetTensorData<int8>(filter), GetTensorShape(bias),
|
||||
GetTensorData<int32>(bias), GetTensorShape(output),
|
||||
GetTensorData<int8>(output));
|
||||
|
||||
#endif
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user