Merge pull request #44618 from foss-for-synopsys-dwc-arc-processors:arc_mli_evaltensor_porting_depthwise_conv

PiperOrigin-RevId: 342178413
Change-Id: Id8aeedaeefb0db15148db43b2662a5f2615fb423
This commit is contained in:
TensorFlower Gardener 2020-11-12 19:03:31 -08:00
commit d9a52463b3

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/lite/micro/kernels/arc_mli/mli_tf_utils.h"
#include "tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.h"
#include "tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace {
@ -58,10 +59,21 @@ struct OpData {
// Per channel output multiplier and shift.
int32_t* per_channel_output_multiplier;
int32_t* per_channel_output_shift;
// The range of the fused activation layer. For example for kNone and
// uint8_t these would be 0 and 255.
int32_t output_activation_min;
int32_t output_activation_max;
// The result of checking if MLI optimized version of tensors can be used.
bool is_mli_applicable;
// Tensors in MLI format.
mli_tensor* mli_in;
mli_tensor* mli_weights;
mli_tensor* mli_bias;
mli_tensor* mli_out;
mli_conv2d_cfg* cfg;
};
bool IsMliApplicable(TfLiteContext* context, const TfLiteTensor* input,
@ -109,8 +121,7 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (data_type != kTfLiteFloat32 &&
!IsMliApplicable(context, input, filter, bias, params)) {
if (data_type != kTfLiteFloat32 && !data->is_mli_applicable) {
int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
return tflite::PopulateConvolutionQuantizationParams(
@ -140,6 +151,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
const TfLiteType data_type = input->type;
int width = SizeOfDimension(input, 2);
@ -158,6 +170,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
context, num_channels * sizeof(int32_t)));
data->is_mli_applicable =
IsMliApplicable(context, input, filter, bias, params);
// All per-channel quantized tensors need valid zero point and scale arrays.
if (input->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, filter->quantization.type,
@ -185,13 +200,67 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
data->filter_zero_point = filter->params.zero_point;
data->output_zero_point = output->params.zero_point;
if (data->is_mli_applicable) {
data->mli_in = static_cast<mli_tensor*>(
context->AllocatePersistentBuffer(context, sizeof(mli_tensor)));
data->mli_weights = static_cast<mli_tensor*>(
context->AllocatePersistentBuffer(context, sizeof(mli_tensor)));
data->mli_bias = static_cast<mli_tensor*>(
context->AllocatePersistentBuffer(context, sizeof(mli_tensor)));
data->mli_out = static_cast<mli_tensor*>(
context->AllocatePersistentBuffer(context, sizeof(mli_tensor)));
data->cfg = static_cast<mli_conv2d_cfg*>(
context->AllocatePersistentBuffer(context, sizeof(mli_conv2d_cfg)));
// reuse space allocated for OpData parameters
data->mli_weights->el_params.asym.scale.pi32 =
static_cast<int32_t*>(data->per_channel_output_multiplier);
data->mli_bias->el_params.asym.scale.pi32 =
static_cast<int32_t*>(data->per_channel_output_shift);
data->mli_weights->el_params.asym.zero_point.pi16 =
reinterpret_cast<int16_t*>(&data->filter_zero_point);
data->mli_bias->el_params.asym.zero_point.pi16 =
reinterpret_cast<int16_t*>(&data->filter_zero_point) + sizeof(int16_t);
ops::micro::ConvertToMliTensor(input, data->mli_in);
ops::micro::ConvertToMliTensorPerChannel(filter, data->mli_weights);
ops::micro::ConvertToMliTensorPerChannel(bias, data->mli_bias);
ops::micro::ConvertToMliTensor(output, data->mli_out);
if (params->activation == kTfLiteActRelu) {
data->cfg->relu.type = MLI_RELU_GEN;
} else if (params->activation == kTfLiteActRelu6) {
data->cfg->relu.type = MLI_RELU_6;
} else if (params->activation == kTfLiteActRelu1) {
data->cfg->relu.type = MLI_RELU_1;
} else {
data->cfg->relu.type = MLI_RELU_NONE;
}
data->cfg->stride_width = params->stride_width;
data->cfg->stride_height = params->stride_height;
if (params->padding == kTfLitePaddingValid) {
data->cfg->padding_left = 0;
data->cfg->padding_right = 0;
data->cfg->padding_top = 0;
data->cfg->padding_bottom = 0;
} else {
data->cfg->padding_left = data->padding.width;
data->cfg->padding_right =
data->padding.width + data->padding.width_offset;
data->cfg->padding_top = data->padding.height;
data->cfg->padding_bottom =
data->padding.height + data->padding.height_offset;
}
}
return kTfLiteOk;
}
void EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteDepthwiseConvParams* params, const OpData& data,
const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* output) {
const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter,
const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) {
#if !defined(TF_LITE_STRIP_REFERENCE_IMPL)
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
@ -211,10 +280,14 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
op_params.float_activation_max = output_activation_max;
tflite::reference_ops::DepthwiseConv(
op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(filter), GetTensorData<float>(filter),
GetTensorShape(bias), GetTensorData<float>(bias), GetTensorShape(output),
GetTensorData<float>(output));
op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<float>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
#else
TF_LITE_KERNEL_LOG(context,
"Type %s (%d) is not supported by ARC MLI Library.",
@ -223,188 +296,161 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
}
TfLiteStatus EvalMliQuantizedPerChannel(
TfLiteContext* context, TfLiteNode* node, TfLiteDepthwiseConvParams* params,
const OpData& data, const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* output) {
const OpData& data, const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output) {
// Run Depthwise Conv MLI kernel
mli_tensor mli_in = {};
mli_tensor mli_weights = {};
mli_tensor mli_bias = {};
mli_tensor mli_out = {};
mli_conv2d_cfg cfg = {};
// MLI optimized version only supports int8_t dataype and dilation factor of 1
if (data.is_mli_applicable) {
// Copy configuration data from external to local memory
mli_conv2d_cfg cfg_local = *data.cfg;
// reuse space allocated for OpData parameters
mli_weights.el_params.asym.scale.pi32 =
(int32_t*)data.per_channel_output_multiplier;
mli_bias.el_params.asym.scale.pi32 = (int32_t*)data.per_channel_output_shift;
ops::micro::MliTensorAttachBuffer<int8_t>(input, data.mli_in);
ops::micro::MliTensorAttachBuffer<int8_t>(filter, data.mli_weights);
ops::micro::MliTensorAttachBuffer<int32_t>(bias, data.mli_bias);
ops::micro::MliTensorAttachBuffer<int8_t>(output, data.mli_out);
int16_t filter_zero_point = 0;
int16_t bias_zero_point = 0;
mli_weights.el_params.asym.zero_point.pi16 = &filter_zero_point;
mli_bias.el_params.asym.zero_point.pi16 = &bias_zero_point;
// for height slicing
const int heightDimension = 1;
int inSliceHeight = 0;
int outSliceHeight = 0;
const int kernelHeight =
static_cast<int>(data.mli_weights->shape[KRNL_DW_H_DIM_HWC]);
const int overlap = kernelHeight - cfg_local.stride_height;
ops::micro::ConvertToMliTensor<int8_t>(input, &mli_in);
ops::micro::ConvertToMliTensorPerChannel<int8_t>(filter, &mli_weights);
ops::micro::ConvertToMliTensorPerChannel<int32_t>(bias, &mli_bias);
ops::micro::ConvertToMliTensor<int8_t>(output, &mli_out);
// for weight slicing (on output channels)
// HWCN layout for weights, output channel dimension is the first dimension.
const int weight_out_ch_dimension = 3;
// bias has only 1 dimension
const int bias_out_ch_dimension = 0;
// Batch-Height-Width-Channel layout means last dimension is output
// channels.
const int out_tensor_ch_dimension = 3;
const int32_t in_channels = data.mli_in->shape[out_tensor_ch_dimension];
const int32_t out_channels = data.mli_out->shape[out_tensor_ch_dimension];
int slice_channels =
static_cast<int>(data.mli_weights->shape[weight_out_ch_dimension]);
if (params->activation == kTfLiteActRelu) {
cfg.relu.type = MLI_RELU_GEN;
} else if (params->activation == kTfLiteActRelu6) {
cfg.relu.type = MLI_RELU_6;
} else if (params->activation == kTfLiteActRelu1) {
cfg.relu.type = MLI_RELU_1;
} else {
cfg.relu.type = MLI_RELU_NONE;
}
// Tensors for data in fast (local) memory
// and config to copy data from external to local memory
mli_tensor weights_local = *data.mli_weights;
mli_tensor bias_local = *data.mli_bias;
mli_tensor in_local = *data.mli_in;
mli_tensor out_local =
*data.mli_out; // this assumes that output shape
// is already filled in the tensor struct.
mli_mov_cfg_t copy_config;
mli_mov_cfg_for_copy(&copy_config);
cfg.stride_width = params->stride_width;
cfg.stride_height = params->stride_height;
if (params->padding == kTfLitePaddingValid) {
cfg.padding_left = 0;
cfg.padding_right = 0;
cfg.padding_top = 0;
cfg.padding_bottom = 0;
} else {
cfg.padding_left = data.padding.width;
cfg.padding_right = data.padding.width + data.padding.width_offset;
cfg.padding_top = data.padding.height;
cfg.padding_bottom = data.padding.height + data.padding.height_offset;
}
// for height slicing
const int heightDimension = 1;
int inSliceHeight = 0;
int outSliceHeight = 0;
const int kernelHeight =
static_cast<int>(mli_weights.shape[KRNL_DW_H_DIM_HWC]);
const int overlap = kernelHeight - cfg.stride_height;
// for weight slicing (on output channels)
// HWCN layout for weights, output channel dimension is the first dimension.
const int weight_out_ch_dimension = 3;
// bias has only 1 dimension
const int bias_out_ch_dimension = 0;
// Batch-Height-Width-Channel layout means last dimension is output channels.
const int out_tensor_ch_dimension = 3;
const int32_t in_channels = mli_in.shape[out_tensor_ch_dimension];
const int32_t out_channels = mli_out.shape[out_tensor_ch_dimension];
int slice_channels =
static_cast<int>(mli_weights.shape[weight_out_ch_dimension]);
// Tensors for data in fast (local) memory
// and config to copy data from external to local memory
mli_tensor weights_local = mli_weights;
mli_tensor bias_local = mli_bias;
mli_tensor in_local = mli_in;
mli_tensor out_local = mli_out; // this assumes that output shape
// is already filled in the tensor struct.
mli_mov_cfg_t copy_config;
mli_mov_cfg_for_copy(&copy_config);
TF_LITE_ENSURE_STATUS(ops::micro::get_arc_scratch_buffer_for_conv_tensors(
context, &in_local, &weights_local, &bias_local, &out_local));
/* is_local indicates that the tensor is already in local memory,
TF_LITE_ENSURE_STATUS(ops::micro::get_arc_scratch_buffer_for_conv_tensors(
context, &in_local, &weights_local, &bias_local, &out_local));
/* is_local indicates that the tensor is already in local memory,
so in that case the original tensor can be used,
and there is no need to copy it to the local tensor*/
const bool in_is_local = in_local.data == mli_in.data;
const bool out_is_local = out_local.data == mli_out.data;
const bool w_is_local = weights_local.data == mli_weights.data;
const bool b_is_local = bias_local.data == mli_bias.data;
const bool in_is_local = in_local.data == data.mli_in->data;
const bool out_is_local = out_local.data == data.mli_out->data;
const bool w_is_local = weights_local.data == data.mli_weights->data;
const bool b_is_local = bias_local.data == data.mli_bias->data;
TF_LITE_ENSURE_STATUS(ops::micro::arc_scratch_buffer_calc_slice_size_io(
&in_local, &out_local, kernelHeight, cfg.stride_height, cfg.padding_top,
cfg.padding_bottom, &inSliceHeight, &outSliceHeight));
TF_LITE_ENSURE_STATUS(ops::micro::arc_scratch_buffer_calc_slice_size_weights(
&weights_local, &bias_local, weight_out_ch_dimension, &slice_channels));
TF_LITE_ENSURE_STATUS(ops::micro::arc_scratch_buffer_calc_slice_size_io(
&in_local, &out_local, kernelHeight, cfg_local.stride_height,
cfg_local.padding_top, cfg_local.padding_bottom, &inSliceHeight,
&outSliceHeight));
TF_LITE_ENSURE_STATUS(
ops::micro::arc_scratch_buffer_calc_slice_size_weights(
&weights_local, &bias_local, weight_out_ch_dimension,
&slice_channels));
/* if input channels is not equal to output channels, a channel multiplier
is used. in this case the slice channels needs to be rounded down to a
multiple of the input channels */
if (in_channels != out_channels) {
slice_channels = (slice_channels / in_channels) * in_channels;
}
ops::micro::TensorSlicer b_slice(&mli_bias, bias_out_ch_dimension,
slice_channels);
ops::micro::TensorSlicer w_slice(&mli_weights, weight_out_ch_dimension,
slice_channels, 0, 0, 0, true);
ops::micro::TensorSlicer out_ch_slice(&mli_out, out_tensor_ch_dimension,
slice_channels, 0, 0, 0, true);
ops::micro::TensorSlicer in_ch_slice(&mli_in, out_tensor_ch_dimension,
slice_channels, 0, 0, 0, true);
mli_tensor* w_ptr = w_is_local ? w_slice.Sub() : &weights_local;
mli_tensor* b_ptr = b_is_local ? b_slice.Sub() : &bias_local;
void* input_buffer_ptr = NULL;
uint32_t input_buffer_size = 0;
int padding_top = cfg.padding_top;
int padding_bottom = cfg.padding_bottom;
while (!w_slice.Done()) {
mli_mov_tensor_sync(w_slice.Sub(), &copy_config, w_ptr);
mli_mov_tensor_sync(b_slice.Sub(), &copy_config, b_ptr);
/* input tensor is already sliced in the channel dimension.
out_ch_slice.Sub() is the tensor for the amount of channels of this
iteration of the weight slice loop. This tensor needs to be further
sliced over the batch and height dimension. in_ch_slice.Sub() tensor
contains batches of HWC tensors. so it is a 4 dimensional tensor. because
the mli kernel will process one HWC tensor at a time, the 4 dimensional
tensor needs to be sliced into nBatch 3 dimensional tensors. on top of
that there could be a need to also slice in the Height dimension. for that
the sliceHeight has been calculated. The tensor slicer is configured that
it will completely slice the nBatch dimension (0) and slice the height
dimension (1) in chunks of 'sliceHeight' */
ops::micro::TensorSlicer in_slice(in_ch_slice.Sub(), heightDimension,
inSliceHeight, padding_top,
padding_bottom, overlap);
/* output tensor is already sliced in the output channel dimension.
out_ch_slice.Sub() is the tensor for the amount of output channels of this
iteration of the weight slice loop. This tensor needs to be further
sliced over the batch and height dimension. */
ops::micro::TensorSlicer out_slice(out_ch_slice.Sub(), heightDimension,
outSliceHeight);
/* setup the pointers to the local or remote tensor to make the code
* inside the loop easier. */
mli_tensor* in_ptr = in_is_local ? in_slice.Sub() : &in_local;
mli_tensor* out_ptr = out_is_local ? out_slice.Sub() : &out_local;
while (!out_slice.Done()) {
TF_LITE_ENSURE(context, !in_slice.Done());
cfg.padding_top = in_slice.GetPaddingPre();
cfg.padding_bottom = in_slice.GetPaddingPost();
// if same input copy as previous iteration, skip the copy of input
if ((in_slice.Sub()->data != input_buffer_ptr) ||
(mli_hlp_count_elem_num(in_slice.Sub(), 0) != input_buffer_size)) {
mli_mov_tensor_sync(in_slice.Sub(), &copy_config, in_ptr);
input_buffer_ptr = in_slice.Sub()->data;
input_buffer_size = mli_hlp_count_elem_num(in_slice.Sub(), 0);
}
mli_krn_depthwise_conv2d_hwcn_sa8_sa8_sa32(in_ptr, w_ptr, b_ptr, &cfg,
out_ptr);
mli_mov_tensor_sync(out_ptr, &copy_config, out_slice.Sub());
in_slice.Next();
out_slice.Next();
/* if input channels is not equal to output channels, a channel multiplier
is used. in this case the slice channels needs to be rounded down to a
multiple of the input channels */
if (in_channels != out_channels) {
slice_channels = (slice_channels / in_channels) * in_channels;
}
ops::micro::TensorSlicer b_slice(data.mli_bias, bias_out_ch_dimension,
slice_channels);
ops::micro::TensorSlicer w_slice(data.mli_weights, weight_out_ch_dimension,
slice_channels, 0, 0, 0, true);
ops::micro::TensorSlicer out_ch_slice(data.mli_out, out_tensor_ch_dimension,
slice_channels, 0, 0, 0, true);
ops::micro::TensorSlicer in_ch_slice(data.mli_in, out_tensor_ch_dimension,
slice_channels, 0, 0, 0, true);
mli_tensor* w_ptr = w_is_local ? w_slice.Sub() : &weights_local;
mli_tensor* b_ptr = b_is_local ? b_slice.Sub() : &bias_local;
void* input_buffer_ptr = NULL;
uint32_t input_buffer_size = 0;
int padding_top = cfg_local.padding_top;
int padding_bottom = cfg_local.padding_bottom;
while (!w_slice.Done()) {
mli_mov_tensor_sync(w_slice.Sub(), &copy_config, w_ptr);
mli_mov_tensor_sync(b_slice.Sub(), &copy_config, b_ptr);
/* input tensor is already sliced in the channel dimension.
out_ch_slice.Sub() is the tensor for the amount of channels of this
iteration of the weight slice loop. This tensor needs to be further
sliced over the batch and height dimension. in_ch_slice.Sub() tensor
contains batches of HWC tensors. so it is a 4 dimensional tensor. because
the mli kernel will process one HWC tensor at a time, the 4 dimensional
tensor needs to be sliced into nBatch 3 dimensional tensors. on top of
that there could be a need to also slice in the Height dimension. for that
the sliceHeight has been calculated. The tensor slicer is configured that
it will completely slice the nBatch dimension (0) and slice the height
dimension (1) in chunks of 'sliceHeight' */
ops::micro::TensorSlicer in_slice(in_ch_slice.Sub(), heightDimension,
inSliceHeight, padding_top,
padding_bottom, overlap);
/* output tensor is already sliced in the output channel dimension.
out_ch_slice.Sub() is the tensor for the amount of output channels of this
iteration of the weight slice loop. This tensor needs to be further
sliced over the batch and height dimension. */
ops::micro::TensorSlicer out_slice(out_ch_slice.Sub(), heightDimension,
outSliceHeight);
/* setup the pointers to the local or remote tensor to make the code
* inside the loop easier. */
mli_tensor* in_ptr = in_is_local ? in_slice.Sub() : &in_local;
mli_tensor* out_ptr = out_is_local ? out_slice.Sub() : &out_local;
while (!out_slice.Done()) {
TF_LITE_ENSURE(context, !in_slice.Done());
cfg_local.padding_top = in_slice.GetPaddingPre();
cfg_local.padding_bottom = in_slice.GetPaddingPost();
// if same input copy as previous iteration, skip the copy of input
if ((in_slice.Sub()->data != input_buffer_ptr) ||
(mli_hlp_count_elem_num(in_slice.Sub(), 0) != input_buffer_size)) {
mli_mov_tensor_sync(in_slice.Sub(), &copy_config, in_ptr);
input_buffer_ptr = in_slice.Sub()->data;
input_buffer_size = mli_hlp_count_elem_num(in_slice.Sub(), 0);
}
mli_krn_depthwise_conv2d_hwcn_sa8_sa8_sa32(in_ptr, w_ptr, b_ptr,
&cfg_local, out_ptr);
mli_mov_tensor_sync(out_ptr, &copy_config, out_slice.Sub());
in_slice.Next();
out_slice.Next();
}
w_slice.Next();
b_slice.Next();
out_ch_slice.Next();
in_ch_slice.Next();
TF_LITE_ENSURE(context, in_slice.Done());
}
w_slice.Next();
b_slice.Next();
out_ch_slice.Next();
in_ch_slice.Next();
TF_LITE_ENSURE(context, in_slice.Done());
}
return kTfLiteOk;
}
void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
TfLiteDepthwiseConvParams* params,
const OpData& data, const TfLiteTensor* input,
const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* output) {
const OpData& data, const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter,
const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output) {
#if !defined(TF_LITE_STRIP_REFERENCE_IMPL)
DepthwiseParams op_params;
op_params.padding_type = PaddingType::kSame;
@ -423,11 +469,14 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
reference_integer_ops::DepthwiseConvPerChannel(
op_params, data.per_channel_output_multiplier,
data.per_channel_output_shift, GetTensorShape(input),
GetTensorData<int8_t>(input), GetTensorShape(filter),
GetTensorData<int8_t>(filter), GetTensorShape(bias),
GetTensorData<int32_t>(bias), GetTensorShape(output),
GetTensorData<int8_t>(output));
data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
#else
TF_LITE_KERNEL_LOG(context,
"Node configuration is not supported by ARC MLI Library.");
@ -436,8 +485,9 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteDepthwiseConvParams* params, const OpData& data,
const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* output) {
const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output) {
#if !defined(TF_LITE_STRIP_REFERENCE_IMPL)
const int32_t input_offset = -data.input_zero_point;
const int32_t filter_offset = -data.filter_zero_point;
@ -463,10 +513,14 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
op_params.output_shift = -data.output_shift;
tflite::reference_ops::DepthwiseConv(
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<uint8_t>(output));
op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<uint8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<uint8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<uint8_t>(output));
#else
TF_LITE_KERNEL_LOG(context,
"Type %s (%d) is not supported by ARC MLI Library.",
@ -482,18 +536,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
const OpData& data = *(static_cast<const OpData*>(node->user_data));
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
const TfLiteTensor* bias =
(NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr;
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
const TfLiteEvalTensor* filter =
tflite::micro::GetEvalInput(context, node, kFilterTensor);
const TfLiteEvalTensor* bias =
(NumInputs(node) == 3)
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
: nullptr;
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
EvalFloat(context, node, params, data, input, filter, bias, output);
break;
case kTfLiteInt8:
if (IsMliApplicable(context, input, filter, bias, params)) {
if (data.is_mli_applicable) {
EvalMliQuantizedPerChannel(context, node, params, data, input, filter,
bias, output);
} else {