Refactor conv to share code between reference and optimized kernels

Nat Jeffries 2021-02-11 18:17:55 -08:00
parent dce9ee5e26
commit 3aa4d9952c
9 changed files with 445 additions and 534 deletions
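Taken together, the diff below moves the conv OpData struct, the tensor index constants, and the ConvParams translation helpers into a shared conv.h / conv_common.cc pair. As a rough sketch (not part of this commit), a platform-specific int8 Eval written against those shared pieces would look like the following, with only the final ConvPerChannel call swapped out for an optimized routine; the EvalInt8 name is made up for illustration:

#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/micro/kernels/conv.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace {

// Hypothetical kernel variant: all parameter plumbing comes from the shared
// helpers (OpDataConv, ConvParamsQuantized, kConv*Tensor); only the call at
// the end would change for an optimized backend.
TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
  const auto& data = *(static_cast<const OpDataConv*>(node->user_data));

  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kConvInputTensor);
  const TfLiteEvalTensor* filter =
      tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
  const TfLiteEvalTensor* bias =
      tflite::micro::GetEvalInput(context, node, kConvBiasTensor);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);

  reference_integer_ops::ConvPerChannel(
      ConvParamsQuantized(params, data), data.per_channel_output_multiplier,
      data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
      tflite::micro::GetTensorData<int8_t>(input),
      tflite::micro::GetTensorShape(filter),
      tflite::micro::GetTensorData<int8_t>(filter),
      tflite::micro::GetTensorShape(bias),
      tflite::micro::GetTensorData<int32_t>(bias),
      tflite::micro::GetTensorShape(output),
      tflite::micro::GetTensorData<int8_t>(output));
  return kTfLiteOk;
}

}  // namespace
}  // namespace tflite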

View File

@@ -107,6 +107,7 @@ cc_library(
"//tensorflow/lite/core/api",
"//tensorflow/lite/kernels:op_macros",
"//tensorflow/lite/kernels/internal:compatibility",
"//tensorflow/lite/micro/kernels:conv",
"//tensorflow/lite/micro/kernels:ethosu",
"//tensorflow/lite/micro/kernels:fully_connected",
"//tensorflow/lite/micro/kernels:micro_ops",

View File

@@ -136,6 +136,43 @@ cc_library(
}),
)
cc_library(
name = "conv",
srcs = [
"conv_common.cc",
] + select({
"//conditions:default": [
"conv.cc",
],
":xtensa_hifimini": [
"xtensa/conv.cc",
],
}),
hdrs = ["conv.h"],
copts = micro_copts(),
visibility = [
# Kernel variants need to be visible to the examples and benchmarks.
":micro",
],
deps = [
":fixedpoint_utils",
":kernel_util",
":xtensa",
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels/internal:common",
"//tensorflow/lite/kernels/internal:quantization_util",
"//tensorflow/lite/kernels/internal:reference_base",
"//tensorflow/lite/kernels/internal:tensor",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels:padding",
] + select({
"//conditions:default": [],
":xtensa_hifimini": [
#"//third_party/xtensa/cstub64s:hifi_mini",
],
}),
)
cc_library(
name = "kernel_runner",
srcs = [
@@ -211,14 +248,12 @@ cc_library(
"zeros_like.cc",
] + select({
"//conditions:default": [
"conv.cc",
"depthwise_conv.cc",
"quantize.cc",
"softmax.cc",
"svdf.cc",
],
":xtensa_hifimini": [
"xtensa/conv.cc",
"xtensa/depthwise_conv.cc",
"xtensa/quantize.cc",
"xtensa/softmax.cc",

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/micro/kernels/conv.h"
#include "CMSIS/NN/Include/arm_nn_types.h"
#include "CMSIS/NN/Include/arm_nnfunctions.h"
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@@ -30,93 +31,9 @@ limitations under the License.
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kFilterTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
// Conv is quantized along dimension 0:
// https://www.tensorflow.org/lite/performance/quantization_spec
constexpr int kConvQuantizedDimension = 0;
struct OpData {
TfLitePaddingValues padding;
// Cached tensor zero point values for quantized operations.
int32_t input_zero_point;
int32_t filter_zero_point;
int32_t output_zero_point;
// The scaling factor from input to output (aka the 'real multiplier') can
// be represented as a fixed point multiplier plus a left shift.
int32_t output_multiplier;
int output_shift;
// Per channel output multiplier and shift.
int32_t* per_channel_output_multiplier;
int32_t* per_channel_output_shift;
// The range of the fused activation layer. For example for kNone and
// uint8_t these would be 0 and 255.
int32_t output_activation_min;
int32_t output_activation_max;
// Index to buffer for optimizations if applicable.
int buffer_idx;
};
inline PaddingType RuntimePaddingType(TfLitePadding padding) {
switch (padding) {
case TfLitePadding::kTfLitePaddingSame:
return PaddingType::kSame;
case TfLitePadding::kTfLitePaddingValid:
return PaddingType::kValid;
case TfLitePadding::kTfLitePaddingUnknown:
default:
return PaddingType::kNone;
}
}
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
const TfLiteConvParams* params, int width,
int height, int filter_width, int filter_height,
int out_width, int out_height,
const TfLiteType data_type, OpData* data) {
bool has_bias = node->inputs->size == 3;
// Check number of inputs/outputs
TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
// Matching GetWindowedOutputSize in TensorFlow.
auto padding = params->padding;
data->padding = ComputePaddingHeightWidth(
params->stride_height, params->stride_width,
params->dilation_height_factor, params->dilation_width_factor, height,
width, filter_height, filter_width, padding, &out_height, &out_width);
// Note that quantized inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
if (data_type != kTfLiteFloat32) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
const TfLiteTensor* bias =
GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
int num_channels = filter->dims->data[kConvQuantizedDimension];
TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
context, input, filter, bias, output, params->activation,
&data->output_multiplier, &data->output_shift,
&data->output_activation_min, &data->output_activation_max,
data->per_channel_output_multiplier,
reinterpret_cast<int*>(data->per_channel_output_shift), num_channels));
}
return kTfLiteOk;
}
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData));
return context->AllocatePersistentBuffer(context, sizeof(OpDataConv));
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
@@ -125,11 +42,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int32_t buf_size = 0;
const auto params = static_cast<const TfLiteConvParams*>(node->builtin_data);
OpData* data = static_cast<OpData*>(node->user_data);
OpDataConv* data = static_cast<OpDataConv*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
const TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
const TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
RuntimeShape input_shape = GetTensorShape(input);
RuntimeShape output_shape = GetTensorShape(output);
@@ -168,7 +85,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
static_cast<int32_t*>(context->AllocatePersistentBuffer(
context, num_channels * sizeof(int32_t)));
TF_LITE_ENSURE_STATUS(CalculateOpData(
TF_LITE_ENSURE_STATUS(CalculateOpDataConv(
context, node, params, input_dims.w, input_dims.h, filter_dims.w,
filter_dims.h, output_dims.w, output_dims.h, input->type, data));
@@ -203,49 +120,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, const OpData& data,
const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter,
const TfLiteEvalTensor* bias,
TfLiteEvalTensor* im2col,
TfLiteEvalTensor* hwcn_weights,
TfLiteEvalTensor* output) {
const int32_t input_offset = -data.input_zero_point;
const int32_t filter_offset = -data.filter_zero_point;
const int32_t output_offset = data.output_zero_point;
ConvParams op_params;
op_params.padding_type = RuntimePaddingType(params->padding);
op_params.padding_values.width = data.padding.width;
op_params.padding_values.height = data.padding.height;
op_params.stride_width = params->stride_width;
op_params.stride_height = params->stride_height;
op_params.dilation_width_factor = params->dilation_width_factor;
op_params.dilation_height_factor = params->dilation_height_factor;
op_params.input_offset = input_offset;
op_params.weights_offset = filter_offset;
op_params.output_offset = output_offset;
op_params.output_multiplier = data.output_multiplier;
op_params.output_shift = -data.output_shift;
op_params.quantized_activation_min = data.output_activation_min;
op_params.quantized_activation_max = data.output_activation_max;
reference_ops::Conv(op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<uint8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<uint8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<uint8_t>(output),
tflite::micro::GetTensorShape(im2col),
tflite::micro::GetTensorData<uint8_t>(im2col), nullptr);
return kTfLiteOk;
}
TfLiteStatus EvalQuantizedPerChannel(
TfLiteContext* context, TfLiteNode* node, TfLiteConvParams* params,
const OpData& data, const TfLiteEvalTensor* input,
const OpDataConv& data, const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output, TfLiteEvalTensor* im2col) {
cmsis_nn_conv_params conv_params;
@@ -340,21 +217,8 @@ TfLiteStatus EvalQuantizedPerChannel(
tflite::micro::GetTensorData<int8_t>(output)),
ARM_MATH_SUCCESS);
} else {
// TODO(b/154032858): Investigate removing extra copies.
ConvParams op_params;
op_params.input_offset = -data.input_zero_point;
op_params.output_offset = data.output_zero_point;
op_params.stride_height = params->stride_height;
op_params.stride_width = params->stride_width;
op_params.dilation_height_factor = params->dilation_height_factor;
op_params.dilation_width_factor = params->dilation_width_factor;
op_params.padding_values.height = data.padding.height;
op_params.padding_values.width = data.padding.width;
op_params.quantized_activation_min = data.output_activation_min;
op_params.quantized_activation_max = data.output_activation_max;
reference_integer_ops::ConvPerChannel(
op_params, data.per_channel_output_multiplier,
ConvParamsQuantized(params, data), data.per_channel_output_multiplier,
data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
@@ -367,75 +231,59 @@ TfLiteStatus EvalQuantizedPerChannel(
return kTfLiteOk;
}
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, const OpData& data,
const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter,
const TfLiteEvalTensor* bias, TfLiteEvalTensor* im2col,
TfLiteEvalTensor* hwcn_weights,
TfLiteEvalTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
// TODO(b/154032858): Investigate removing extra copies.
ConvParams op_params;
op_params.padding_type = RuntimePaddingType(params->padding);
op_params.padding_values.width = data.padding.width;
op_params.padding_values.height = data.padding.height;
op_params.stride_width = params->stride_width;
op_params.stride_height = params->stride_height;
op_params.dilation_width_factor = params->dilation_width_factor;
op_params.dilation_height_factor = params->dilation_height_factor;
op_params.float_activation_min = output_activation_min;
op_params.float_activation_max = output_activation_max;
reference_ops::Conv(op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<float>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output),
tflite::micro::GetTensorShape(im2col),
tflite::micro::GetTensorData<float>(im2col));
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
tflite::micro::GetEvalInput(context, node, kConvInputTensor);
const TfLiteEvalTensor* filter =
tflite::micro::GetEvalInput(context, node, kFilterTensor);
tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
const TfLiteEvalTensor* bias =
(NumInputs(node) == 3)
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
? tflite::micro::GetEvalInput(context, node, kConvBiasTensor)
: nullptr;
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
TFLITE_DCHECK(node->user_data != nullptr);
const OpData& data = *(static_cast<const OpData*>(node->user_data));
const OpDataConv& data = *(static_cast<const OpDataConv*>(node->user_data));
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
"Hybrid models are not supported on TFLite Micro.");
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
EvalFloat(context, node, params, data, input, filter, bias, nullptr,
nullptr, output);
case kTfLiteFloat32: {
tflite::reference_ops::Conv(
ConvParamsFloat(params, data), tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<float>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output),
tflite::micro::GetTensorShape(nullptr), nullptr);
break;
}
case kTfLiteInt8:
return EvalQuantizedPerChannel(context, node, params, data, input, filter,
bias, output, nullptr);
break;
case kTfLiteUInt8:
return EvalQuantized(context, node, params, data, input, filter, bias,
nullptr, nullptr, output);
case kTfLiteUInt8: {
reference_ops::Conv(ConvParamsQuantized(params, data),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<uint8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<uint8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<uint8_t>(output),
tflite::micro::GetTensorShape(nullptr), nullptr,
nullptr);
break;
}
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);

View File

@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/micro/kernels/conv.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@@ -28,110 +29,23 @@ limitations under the License.
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kFilterTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
// Conv is quantized along dimension 0:
// https://www.tensorflow.org/lite/performance/quantization_spec
constexpr int kConvQuantizedDimension = 0;
// This file has 2 implementations of Conv.
struct OpData {
TfLitePaddingValues padding;
// Cached tensor zero point values for quantized operations.
int32_t input_zero_point;
int32_t filter_zero_point;
int32_t output_zero_point;
// The scaling factor from input to output (aka the 'real multiplier') can
// be represented as a fixed point multiplier plus a left shift.
int32_t output_multiplier;
int output_shift;
// Per channel output multiplier and shift.
int32_t* per_channel_output_multiplier;
int32_t* per_channel_output_shift;
// The range of the fused activation layer. For example for kNone and
// uint8_t these would be 0 and 255.
int32_t output_activation_min;
int32_t output_activation_max;
};
inline PaddingType RuntimePaddingType(TfLitePadding padding) {
switch (padding) {
case TfLitePadding::kTfLitePaddingSame:
return PaddingType::kSame;
case TfLitePadding::kTfLitePaddingValid:
return PaddingType::kValid;
case TfLitePadding::kTfLitePaddingUnknown:
default:
return PaddingType::kNone;
}
}
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
const TfLiteConvParams* params, int width,
int height, int filter_width, int filter_height,
int out_width, int out_height,
const TfLiteType data_type, OpData* data) {
bool has_bias = node->inputs->size == 3;
// Check number of inputs/outputs
TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
// Matching GetWindowedOutputSize in TensorFlow.
auto padding = params->padding;
data->padding = ComputePaddingHeightWidth(
params->stride_height, params->stride_width,
params->dilation_height_factor, params->dilation_width_factor, height,
width, filter_height, filter_width, padding, &out_height, &out_width);
// Note that quantized inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
if (data_type != kTfLiteFloat32) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
TF_LITE_ENSURE(context, filter != nullptr);
const TfLiteTensor* bias =
GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
int output_channels = filter->dims->data[kConvQuantizedDimension];
TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
context, input, filter, bias, output, params->activation,
&data->output_multiplier, &data->output_shift,
&data->output_activation_min, &data->output_activation_max,
data->per_channel_output_multiplier,
reinterpret_cast<int*>(data->per_channel_output_shift),
output_channels));
}
return kTfLiteOk;
}
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData));
return context->AllocatePersistentBuffer(context, sizeof(OpDataConv));
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
TFLITE_DCHECK(node->builtin_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);
OpDataConv* data = static_cast<OpDataConv*>(node->user_data);
const auto params = static_cast<const TfLiteConvParams*>(node->builtin_data);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
TF_LITE_ENSURE(context, filter != nullptr);
int input_width = input->dims->data[2];
@@ -169,7 +83,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
affine_quantization->zero_point->size);
}
TF_LITE_ENSURE_STATUS(CalculateOpData(
TF_LITE_ENSURE_STATUS(CalculateOpDataConv(
context, node, params, input_width, input_height, filter_width,
filter_height, output_width, output_height, input->type, data));
@@ -178,144 +92,70 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
data->output_zero_point = output->params.zero_point;
return kTfLiteOk;
} // namespace conv
void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, const OpData& data,
const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
TfLiteEvalTensor* im2col, TfLiteEvalTensor* hwcn_weights,
TfLiteEvalTensor* output) {
const int32_t input_offset = -data.input_zero_point;
const int32_t filter_offset = -data.filter_zero_point;
const int32_t output_offset = data.output_zero_point;
// TODO(b/154032858): Investigate removing extra copies.
ConvParams op_params;
op_params.padding_type = RuntimePaddingType(params->padding);
op_params.padding_values.width = data.padding.width;
op_params.padding_values.height = data.padding.height;
op_params.stride_width = params->stride_width;
op_params.stride_height = params->stride_height;
op_params.dilation_width_factor = params->dilation_width_factor;
op_params.dilation_height_factor = params->dilation_height_factor;
op_params.input_offset = input_offset;
op_params.weights_offset = filter_offset;
op_params.output_offset = output_offset;
op_params.output_multiplier = data.output_multiplier;
op_params.output_shift = -data.output_shift;
op_params.quantized_activation_min = data.output_activation_min;
op_params.quantized_activation_max = data.output_activation_max;
reference_ops::Conv(op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<uint8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<uint8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<uint8_t>(output),
tflite::micro::GetTensorShape(im2col),
tflite::micro::GetTensorData<uint8_t>(im2col), nullptr);
}
void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, const OpData& data,
const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter,
const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output,
TfLiteEvalTensor* im2col) {
// TODO(b/154032858): Investigate removing extra copies.
ConvParams op_params;
op_params.input_offset = -data.input_zero_point;
op_params.output_offset = data.output_zero_point;
op_params.stride_height = params->stride_height;
op_params.stride_width = params->stride_width;
op_params.dilation_height_factor = params->dilation_height_factor;
op_params.dilation_width_factor = params->dilation_width_factor;
op_params.padding_values.height = data.padding.height;
op_params.padding_values.width = data.padding.width;
op_params.quantized_activation_min = data.output_activation_min;
op_params.quantized_activation_max = data.output_activation_max;
reference_integer_ops::ConvPerChannel(
op_params, data.per_channel_output_multiplier,
data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
}
void EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, const OpData& data,
const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter,
const TfLiteEvalTensor* bias, TfLiteEvalTensor* im2col,
TfLiteEvalTensor* hwcn_weights, TfLiteEvalTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
// TODO(b/154032858): Investigate removing extra copies.
ConvParams op_params;
op_params.padding_type = RuntimePaddingType(params->padding);
op_params.padding_values.width = data.padding.width;
op_params.padding_values.height = data.padding.height;
op_params.stride_width = params->stride_width;
op_params.stride_height = params->stride_height;
op_params.dilation_width_factor = params->dilation_width_factor;
op_params.dilation_height_factor = params->dilation_height_factor;
op_params.float_activation_min = output_activation_min;
op_params.float_activation_max = output_activation_max;
reference_ops::Conv(op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<float>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output),
tflite::micro::GetTensorShape(im2col),
tflite::micro::GetTensorData<float>(im2col));
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
tflite::micro::GetEvalInput(context, node, kConvInputTensor);
const TfLiteEvalTensor* filter =
tflite::micro::GetEvalInput(context, node, kFilterTensor);
tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
const TfLiteEvalTensor* bias =
(NumInputs(node) == 3)
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
? tflite::micro::GetEvalInput(context, node, kConvBiasTensor)
: nullptr;
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
TFLITE_DCHECK(node->user_data != nullptr);
const OpData& data = *(static_cast<const OpData*>(node->user_data));
const auto& data = *(static_cast<const OpDataConv*>(node->user_data));
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
"Hybrid models are not supported on TFLite Micro.");
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
EvalFloat(context, node, params, data, input, filter, bias, nullptr,
nullptr, output);
case kTfLiteFloat32: {
tflite::reference_ops::Conv(
ConvParamsFloat(params, data), tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<float>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output),
tflite::micro::GetTensorShape(nullptr), nullptr);
break;
case kTfLiteInt8:
EvalQuantizedPerChannel(context, node, params, data, input, filter, bias,
output, nullptr);
}
case kTfLiteInt8: {
reference_integer_ops::ConvPerChannel(
ConvParamsQuantized(params, data), data.per_channel_output_multiplier,
data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
break;
case kTfLiteUInt8:
EvalQuantized(context, node, params, data, input, filter, bias, nullptr,
nullptr, output);
}
case kTfLiteUInt8: {
reference_ops::Conv(ConvParamsQuantized(params, data),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<uint8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<uint8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<uint8_t>(output),
tflite::micro::GetTensorShape(nullptr), nullptr,
nullptr);
break;
}
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);

View File

@@ -0,0 +1,77 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
#include <cstdint>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
struct OpDataConv {
TfLitePaddingValues padding;
// Cached tensor zero point values for quantized operations.
int32_t input_zero_point;
int32_t filter_zero_point;
int32_t output_zero_point;
// The scaling factor from input to output (aka the 'real multiplier') can
// be represented as a fixed point multiplier plus a left shift.
int32_t output_multiplier;
int output_shift;
// Per channel output multiplier and shift.
int32_t* per_channel_output_multiplier;
int32_t* per_channel_output_shift;
// The range of the fused activation layer. For example for kNone and
// uint8_t these would be 0 and 255.
int32_t output_activation_min;
int32_t output_activation_max;
// Index to buffer for optimizations if applicable.
int buffer_idx;
};
extern const int kConvInputTensor;
extern const int kConvWeightsTensor;
extern const int kConvBiasTensor;
extern const int kConvOutputTensor;
extern const int kConvQuantizedDimension;
// Returns a ConvParams struct with all the parameters needed for a
// float computation.
ConvParams ConvParamsFloat(TfLiteConvParams* params, const OpDataConv& data);
// Returns a ConvParams struct with all the parameters needed for a
// quantized computation.
ConvParams ConvParamsQuantized(TfLiteConvParams* params,
const OpDataConv& data);
TfLiteStatus CalculateOpDataConv(TfLiteContext* context, TfLiteNode* node,
const TfLiteConvParams* params, int width,
int height, int filter_width,
int filter_height, int out_width,
int out_height, const TfLiteType data_type,
OpDataConv* data);
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
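
The OpDataConv comment above says the input-to-output scaling factor can be represented as a fixed-point multiplier plus a left shift. A minimal, self-contained sketch of that decomposition (illustrative only; TFLite's quantization_util performs the equivalent conversion, and the 0.0235 scale below is an arbitrary example):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Decompose a real-valued scale so that real ~= multiplier * 2^(shift - 31),
// with the multiplier stored as a Q31 fixed-point value.
void DecomposeMultiplier(double real, int32_t* multiplier, int* shift) {
  if (real == 0.0) {
    *multiplier = 0;
    *shift = 0;
    return;
  }
  const double fraction = std::frexp(real, shift);  // real = fraction * 2^shift
  int64_t fixed = static_cast<int64_t>(std::round(fraction * (1LL << 31)));
  if (fixed == (1LL << 31)) {  // rounding pushed the fraction up to 1.0
    fixed /= 2;
    ++(*shift);
  }
  *multiplier = static_cast<int32_t>(fixed);
}

int main() {
  // Example: effective scale = input_scale * filter_scale / output_scale.
  int32_t multiplier = 0;
  int shift = 0;
  DecomposeMultiplier(0.0235, &multiplier, &shift);
  // Prints roughly multiplier=1614907703 shift=-5, i.e. 0.752 * 2^-5 ~= 0.0235.
  std::printf("multiplier=%d shift=%d\n", static_cast<int>(multiplier), shift);
  return 0;
}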

View File

@@ -0,0 +1,184 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/conv.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
const int kConvInputTensor = 0;
const int kConvWeightsTensor = 1;
const int kConvBiasTensor = 2;
const int kConvOutputTensor = 0;
// Conv is quantized along dimension 0:
// https://www.tensorflow.org/lite/performance/quantization_spec
const int kConvQuantizedDimension = 0;
// Returns a ConvParams struct with all the parameters needed for a
// float computation.
ConvParams ConvParamsFloat(TfLiteConvParams* params, const OpDataConv& data) {
ConvParams op_params;
CalculateActivationRange(params->activation, &op_params.float_activation_min,
&op_params.float_activation_max);
op_params.padding_type = tflite::micro::RuntimePaddingType(params->padding);
op_params.padding_values.width = data.padding.width;
op_params.padding_values.height = data.padding.height;
op_params.stride_width = params->stride_width;
op_params.stride_height = params->stride_height;
op_params.dilation_width_factor = params->dilation_width_factor;
op_params.dilation_height_factor = params->dilation_height_factor;
return op_params;
}
// Returns a ConvParams struct with all the parameters needed for a
// quantized computation.
ConvParams ConvParamsQuantized(TfLiteConvParams* params,
const OpDataConv& data) {
ConvParams op_params;
op_params.input_offset = -data.input_zero_point;
op_params.weights_offset = -data.filter_zero_point;
op_params.output_offset = data.output_zero_point;
op_params.output_multiplier = data.output_multiplier;
op_params.output_shift = -data.output_shift;
op_params.padding_type = tflite::micro::RuntimePaddingType(params->padding);
op_params.padding_values.height = data.padding.height;
op_params.padding_values.width = data.padding.width;
op_params.stride_height = params->stride_height;
op_params.stride_width = params->stride_width;
op_params.dilation_height_factor = params->dilation_height_factor;
op_params.dilation_width_factor = params->dilation_width_factor;
op_params.quantized_activation_min = data.output_activation_min;
op_params.quantized_activation_max = data.output_activation_max;
return op_params;
}
TfLiteStatus CalculateOpDataConv(TfLiteContext* context, TfLiteNode* node,
const TfLiteConvParams* params, int width,
int height, int filter_width,
int filter_height, int out_width,
int out_height, const TfLiteType data_type,
OpDataConv* data) {
bool has_bias = node->inputs->size == 3;
// Check number of inputs/outputs
TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
// Matching GetWindowedOutputSize in TensorFlow.
auto padding = params->padding;
data->padding = ComputePaddingHeightWidth(
params->stride_height, params->stride_width,
params->dilation_height_factor, params->dilation_width_factor, height,
width, filter_height, filter_width, padding, &out_height, &out_width);
// Note that quantized inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
if (data_type != kTfLiteFloat32) {
const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
TF_LITE_ENSURE(context, filter != nullptr);
const TfLiteTensor* bias =
GetOptionalInputTensor(context, node, kConvBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
int output_channels = filter->dims->data[kConvQuantizedDimension];
TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
context, input, filter, bias, output, params->activation,
&data->output_multiplier, &data->output_shift,
&data->output_activation_min, &data->output_activation_max,
data->per_channel_output_multiplier,
reinterpret_cast<int*>(data->per_channel_output_shift),
output_channels));
}
return kTfLiteOk;
}
void* InitConv(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpDataConv));
}
TfLiteStatus PrepareConv(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
TFLITE_DCHECK(node->builtin_data != nullptr);
OpDataConv* data = static_cast<OpDataConv*>(node->user_data);
const auto params = static_cast<const TfLiteConvParams*>(node->builtin_data);
TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
TF_LITE_ENSURE(context, filter != nullptr);
int input_width = input->dims->data[2];
int input_height = input->dims->data[1];
int filter_width = filter->dims->data[2];
int filter_height = filter->dims->data[1];
int output_width = output->dims->data[2];
int output_height = output->dims->data[1];
// Dynamically allocate per-channel quantization parameters.
const int num_channels = filter->dims->data[kConvQuantizedDimension];
data->per_channel_output_multiplier =
static_cast<int32_t*>(context->AllocatePersistentBuffer(
context, num_channels * sizeof(int32_t)));
data->per_channel_output_shift =
static_cast<int32_t*>(context->AllocatePersistentBuffer(
context, num_channels * sizeof(int32_t)));
// All per-channel quantized tensors need valid zero point and scale arrays.
if (input->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, filter->quantization.type,
kTfLiteAffineQuantization);
const auto* affine_quantization =
static_cast<TfLiteAffineQuantization*>(filter->quantization.params);
TF_LITE_ENSURE(context, affine_quantization);
TF_LITE_ENSURE(context, affine_quantization->scale);
TF_LITE_ENSURE(context, affine_quantization->zero_point);
TF_LITE_ENSURE(context,
affine_quantization->scale->size == 1 ||
affine_quantization->scale->size ==
filter->dims->data[kConvQuantizedDimension]);
TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
affine_quantization->zero_point->size);
}
TF_LITE_ENSURE_STATUS(CalculateOpDataConv(
context, node, params, input_width, input_height, filter_width,
filter_height, output_width, output_height, input->type, data));
data->input_zero_point = input->params.zero_point;
data->filter_zero_point = filter->params.zero_point;
data->output_zero_point = output->params.zero_point;
return kTfLiteOk;
}
} // namespace tflite
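
conv_common.cc records that Conv is quantized along dimension 0, so the per-channel arrays allocated in PrepareConv hold one multiplier/shift pair per output channel. As a sketch (not TFLite's actual code), the per-channel effective scales that PopulateConvolutionQuantizationParams encodes into those arrays amount to:

#include <cstddef>
#include <vector>

// Sketch only: one effective rescale factor per output channel,
// input_scale * filter_scale[c] / output_scale, each of which is then
// encoded as a fixed-point multiplier plus shift.
std::vector<double> EffectivePerChannelScales(
    double input_scale, const std::vector<double>& filter_scales,
    double output_scale) {
  std::vector<double> effective(filter_scales.size());
  for (std::size_t c = 0; c < filter_scales.size(); ++c) {
    effective[c] = input_scale * filter_scales[c] / output_scale;
  }
  return effective;
}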

View File

@@ -18,6 +18,7 @@ limitations under the License.
#include <cstdint>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/types.h"
@@ -69,6 +70,18 @@ const RuntimeShape GetTensorShape(const TfLiteEvalTensor* tensor);
bool HaveSameShapes(const TfLiteEvalTensor* input1,
const TfLiteEvalTensor* input2);
inline PaddingType RuntimePaddingType(TfLitePadding padding) {
switch (padding) {
case TfLitePadding::kTfLitePaddingSame:
return PaddingType::kSame;
case TfLitePadding::kTfLitePaddingValid:
return PaddingType::kValid;
case TfLitePadding::kTfLitePaddingUnknown:
default:
return PaddingType::kNone;
}
}
} // namespace micro
} // namespace tflite

View File

@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/micro/kernels/conv.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@@ -30,36 +31,6 @@ limitations under the License.
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kFilterTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
// Conv is quantized along dimension 0:
// https://www.tensorflow.org/lite/performance/quantization_spec
constexpr int kConvQuantizedDimension = 0;
struct OpData {
TfLitePaddingValues padding;
// The scaling factor from input to output (aka the 'real multiplier') can
// be represented as a fixed point multiplier plus a left shift.
int32_t output_multiplier;
int output_shift;
// Cached tensor zero point values for quantized operations.
int32_t input_zero_point;
int32_t output_zero_point;
// Per channel output multiplier and shift.
int32_t* per_channel_output_multiplier;
int32_t* per_channel_output_shift;
// The range of the fused activation layer. For example for kNone and
// uint8_t these would be 0 and 255.
int32_t output_activation_min;
int32_t output_activation_max;
};
#if defined(HIFIMINI)
void ConvPerChannel(const ConvParams& params, const int32_t* output_multiplier,
const int32_t* output_shift,
@@ -263,47 +234,9 @@ inline void Conv1x32Input32x32Filter(
}
#endif
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, int width, int height,
int filter_width, int filter_height, int out_width,
int out_height, const TfLiteType data_type,
OpData* data) {
bool has_bias = node->inputs->size == 3;
// Check number of inputs/outputs
TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
// Matching GetWindowedOutputSize in TensorFlow.
auto padding = params->padding;
data->padding = ComputePaddingHeightWidth(
params->stride_height, params->stride_width,
params->dilation_height_factor, params->dilation_width_factor, height,
width, filter_height, filter_width, padding, &out_height, &out_width);
// Note that quantized inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
if (data_type != kTfLiteFloat32) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
const TfLiteTensor* bias =
GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
int output_channels = filter->dims->data[kConvQuantizedDimension];
return tflite::PopulateConvolutionQuantizationParams(
context, input, filter, bias, output, params->activation,
&data->output_multiplier, &data->output_shift,
&data->output_activation_min, &data->output_activation_max,
data->per_channel_output_multiplier,
reinterpret_cast<int*>(data->per_channel_output_shift),
output_channels);
}
return kTfLiteOk;
}
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData));
return context->AllocatePersistentBuffer(context, sizeof(OpDataConv));
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
@@ -311,11 +244,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->builtin_data != nullptr);
auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
auto* op_data = reinterpret_cast<OpDataConv*>(node->user_data);
int input_width = input->dims->data[2];
int input_height = input->dims->data[1];
@@ -356,71 +289,26 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
affine_quantization->zero_point->size);
}
return CalculateOpData(context, node, params, input_width, input_height,
filter_width, filter_height, output_width,
output_height, input->type, op_data);
}
void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, OpData* data,
const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter,
const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output,
TfLiteEvalTensor* im2col) {
// TODO(b/154032858): Investigate removing extra copies.
ConvParams op_params;
op_params.input_offset = -data->input_zero_point;
op_params.output_offset = data->output_zero_point;
op_params.stride_height = params->stride_height;
op_params.stride_width = params->stride_width;
op_params.dilation_height_factor = params->dilation_height_factor;
op_params.dilation_width_factor = params->dilation_width_factor;
op_params.padding_values.height = data->padding.height;
op_params.padding_values.width = data->padding.width;
op_params.quantized_activation_min = data->output_activation_min;
op_params.quantized_activation_max = data->output_activation_max;
#if defined(HIFIMINI)
ConvPerChannel(op_params, data->per_channel_output_multiplier,
data->per_channel_output_shift,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
#else
reference_integer_ops::ConvPerChannel(
op_params, data->per_channel_output_multiplier,
data->per_channel_output_shift, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
#endif
return CalculateOpDataConv(context, node, params, input_width, input_height,
filter_width, filter_height, output_width,
output_height, input->type, op_data);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
TFLITE_DCHECK(node->builtin_data != nullptr);
auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
auto* op_data = reinterpret_cast<OpDataConv*>(node->user_data);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
tflite::micro::GetEvalInput(context, node, kConvInputTensor);
const TfLiteEvalTensor* filter =
tflite::micro::GetEvalInput(context, node, kFilterTensor);
tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
const TfLiteEvalTensor* bias =
(NumInputs(node) == 3)
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
? tflite::micro::GetEvalInput(context, node, kConvBiasTensor)
: nullptr;
#if defined(HIFIMINI)
@@ -446,10 +334,34 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
#endif
switch (input->type) {
case kTfLiteInt8:
EvalQuantizedPerChannel(context, node, params, op_data, input, filter,
bias, output, nullptr);
case kTfLiteInt8: {
#if defined(HIFIMINI)
ConvPerChannel(ConvParamsQuantized(params, *op_data),
op_data->per_channel_output_multiplier,
op_data->per_channel_output_shift,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
#else
reference_integer_ops::ConvPerChannel(
ConvParamsQuantized(params, *op_data),
op_data->per_channel_output_multiplier, op_data->per_channel_output_shift,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
#endif
break;
}
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);

View File

@@ -316,6 +316,7 @@ tensorflow/lite/micro/kernels/circular_buffer.cc \
tensorflow/lite/micro/kernels/comparisons.cc \
tensorflow/lite/micro/kernels/concatenation.cc \
tensorflow/lite/micro/kernels/conv.cc \
tensorflow/lite/micro/kernels/conv_common.cc \
tensorflow/lite/micro/kernels/conv_test_common.cc \
tensorflow/lite/micro/kernels/depthwise_conv.cc \
tensorflow/lite/micro/kernels/dequantize.cc \