Port the xtensa_hifimini optimized kernels to the new TfLiteEvalTensor API. Unify strategy to get optional bias tensor with new tensor API.
PiperOrigin-RevId: 324089711 Change-Id: Ic46a1e737cc2f9dc5ba9f186290c25a477fab9aa
This commit is contained in:
parent
8dd314781e
commit
481957da0f
@ -286,7 +286,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteEvalTensor* filter =
|
||||
tflite::micro::GetEvalInput(context, node, kFilterTensor);
|
||||
const TfLiteEvalTensor* bias =
|
||||
tflite::micro::GetEvalInput(context, node, kBiasTensor);
|
||||
(NumInputs(node) == 3)
|
||||
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
|
||||
: nullptr;
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
|
||||
|
@ -41,7 +41,8 @@ struct OpData {
|
||||
int effective_scale_2_b;
|
||||
int scratch_tensor_index;
|
||||
int scratch_output_tensor_index;
|
||||
bool bias_provided;
|
||||
|
||||
// Cached tensor zero point values for quantized operations.
|
||||
int input_zero_point;
|
||||
int output_zero_point;
|
||||
};
|
||||
@ -421,7 +422,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
OpData* data = static_cast<OpData*>(node->user_data);
|
||||
data->bias_provided = (bias != nullptr);
|
||||
|
||||
if (input->type == kTfLiteInt8) {
|
||||
TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
|
||||
@ -498,7 +498,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteEvalTensor* weights_time =
|
||||
tflite::micro::GetEvalInput(context, node, kWeightsTimeTensor);
|
||||
const TfLiteEvalTensor* bias =
|
||||
data.bias_provided
|
||||
(NumInputs(node) == 5)
|
||||
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
|
||||
: nullptr;
|
||||
TfLiteEvalTensor* activation_state = tflite::micro::GetMutableEvalInput(
|
||||
|
@ -501,7 +501,13 @@ void ValidateSVDFGoldens(const int batch_size, const int num_units,
|
||||
micro::KernelRunner runner(registration, tensors, tensor_count, inputs_array,
|
||||
outputs_array, ¶ms, micro_test::reporter);
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
|
||||
TfLiteStatus init_and_prepare_status = runner.InitAndPrepare();
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, init_and_prepare_status);
|
||||
|
||||
// Abort early to make it clear init and prepare failed.
|
||||
if (init_and_prepare_status != kTfLiteOk) {
|
||||
return;
|
||||
}
|
||||
|
||||
int num_inputs = input_sequences_len / (input_size * batch_size);
|
||||
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/padding.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -253,6 +254,10 @@ struct OpData {
|
||||
int32_t output_multiplier;
|
||||
int output_shift;
|
||||
|
||||
// Cached tensor zero point values for quantized operations.
|
||||
int32_t input_zero_point;
|
||||
int32_t output_zero_point;
|
||||
|
||||
// Per channel output multiplier and shift.
|
||||
int32_t* per_channel_output_multiplier;
|
||||
int32_t* per_channel_output_shift;
|
||||
@ -334,7 +339,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
op_data->per_channel_output_shift =
|
||||
reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
|
||||
context, num_channels * sizeof(int32_t)));
|
||||
|
||||
op_data->input_zero_point = input->params.zero_point;
|
||||
op_data->output_zero_point = output->params.zero_point;
|
||||
// All per-channel quantized tensors need valid zero point and scale arrays.
|
||||
if (input->type == kTfLiteInt8) {
|
||||
TF_LITE_ENSURE_EQ(context, filter->quantization.type,
|
||||
@ -362,14 +368,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteConvParams* params, OpData* data,
|
||||
const TfLiteTensor* input,
|
||||
const TfLiteTensor* filter,
|
||||
const TfLiteTensor* bias, TfLiteTensor* output,
|
||||
TfLiteTensor* im2col) {
|
||||
const TfLiteEvalTensor* input,
|
||||
const TfLiteEvalTensor* filter,
|
||||
const TfLiteEvalTensor* bias,
|
||||
TfLiteEvalTensor* output,
|
||||
TfLiteEvalTensor* im2col) {
|
||||
// TODO(b/154032858): Investigate removing extra copies.
|
||||
ConvParams op_params;
|
||||
op_params.input_offset = -input->params.zero_point;
|
||||
op_params.output_offset = output->params.zero_point;
|
||||
op_params.input_offset = -data->input_zero_point;
|
||||
op_params.output_offset = data->output_zero_point;
|
||||
op_params.stride_height = params->stride_height;
|
||||
op_params.stride_width = params->stride_width;
|
||||
op_params.dilation_height_factor = params->dilation_height_factor;
|
||||
@ -381,11 +388,14 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
|
||||
|
||||
xtensa::hifimini::ConvPerChannel(
|
||||
op_params, data->per_channel_output_multiplier,
|
||||
data->per_channel_output_shift, GetTensorShape(input),
|
||||
GetTensorData<int8_t>(input), GetTensorShape(filter),
|
||||
GetTensorData<int8_t>(filter), GetTensorShape(bias),
|
||||
GetTensorData<int32_t>(bias), GetTensorShape(output),
|
||||
GetTensorData<int8_t>(output));
|
||||
data->per_channel_output_shift, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
@ -394,10 +404,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
|
||||
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
|
||||
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
const TfLiteEvalTensor* input =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensor);
|
||||
const TfLiteEvalTensor* filter =
|
||||
tflite::micro::GetEvalInput(context, node, kFilterTensor);
|
||||
const TfLiteEvalTensor* bias =
|
||||
(NumInputs(node) == 3)
|
||||
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
|
||||
: nullptr;
|
||||
|
||||
int* input_dims = input->dims->data;
|
||||
int* filter_dims = filter->dims->data;
|
||||
@ -405,14 +421,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
input_dims[3] == 32 && filter_dims[0] == 32 && filter_dims[1] == 1 &&
|
||||
filter_dims[2] == 1 && filter_dims[3] == 32) {
|
||||
xtensa::hifimini::Conv1x32Input32x32Filter(
|
||||
-input->params.zero_point, output->params.zero_point,
|
||||
-op_data->input_zero_point, op_data->output_zero_point,
|
||||
op_data->output_activation_min, op_data->output_activation_max,
|
||||
op_data->per_channel_output_multiplier,
|
||||
op_data->per_channel_output_shift, GetTensorShape(input),
|
||||
GetTensorData<int8_t>(input), GetTensorShape(filter),
|
||||
GetTensorData<int8_t>(filter), GetTensorShape(bias),
|
||||
GetTensorData<int32_t>(bias), GetTensorShape(output),
|
||||
GetTensorData<int8_t>(output));
|
||||
op_data->per_channel_output_shift, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/padding.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -300,6 +301,10 @@ struct OpData {
|
||||
int32_t output_multiplier;
|
||||
int output_shift;
|
||||
|
||||
// Cached tensor zero point values for quantized operations.
|
||||
int32_t input_zero_point;
|
||||
int32_t output_zero_point;
|
||||
|
||||
// Per channel output multiplier and shift.
|
||||
// TODO(b/141139247): Allocate these dynamically when possible.
|
||||
int32_t* per_channel_output_multiplier;
|
||||
@ -363,6 +368,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
|
||||
const TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
@ -383,6 +389,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
|
||||
context, num_channels * sizeof(int32_t)));
|
||||
|
||||
op_data->input_zero_point = input->params.zero_point;
|
||||
op_data->output_zero_point = output->params.zero_point;
|
||||
|
||||
// All per-channel quantized tensors need valid zero point and scale arrays.
|
||||
if (input->type == kTfLiteInt8) {
|
||||
TF_LITE_ENSURE_EQ(context, filter->quantization.type,
|
||||
@ -408,9 +417,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteDepthwiseConvParams* params, OpData* data,
|
||||
const TfLiteTensor* input,
|
||||
const TfLiteTensor* filter,
|
||||
const TfLiteTensor* bias, TfLiteTensor* output) {
|
||||
const TfLiteEvalTensor* input,
|
||||
const TfLiteEvalTensor* filter,
|
||||
const TfLiteEvalTensor* bias,
|
||||
TfLiteEvalTensor* output) {
|
||||
DepthwiseParams op_params;
|
||||
op_params.padding_type = PaddingType::kSame;
|
||||
op_params.padding_values.width = data->padding.width;
|
||||
@ -420,20 +430,23 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
|
||||
op_params.dilation_width_factor = params->dilation_width_factor;
|
||||
op_params.dilation_height_factor = params->dilation_height_factor;
|
||||
op_params.depth_multiplier = params->depth_multiplier;
|
||||
op_params.input_offset = -input->params.zero_point;
|
||||
op_params.input_offset = -data->input_zero_point;
|
||||
op_params.weights_offset = 0;
|
||||
op_params.output_offset = output->params.zero_point;
|
||||
op_params.output_offset = data->output_zero_point;
|
||||
// TODO(b/130439627): Use calculated value for clamping.
|
||||
op_params.quantized_activation_min = std::numeric_limits<int8_t>::min();
|
||||
op_params.quantized_activation_max = std::numeric_limits<int8_t>::max();
|
||||
|
||||
xtensa::hifimini::DepthwiseConvPerChannel(
|
||||
op_params, data->per_channel_output_multiplier,
|
||||
data->per_channel_output_shift, GetTensorShape(input),
|
||||
GetTensorData<int8_t>(input), GetTensorShape(filter),
|
||||
GetTensorData<int8_t>(filter), GetTensorShape(bias),
|
||||
GetTensorData<int32_t>(bias), GetTensorShape(output),
|
||||
GetTensorData<int8_t>(output));
|
||||
data->per_channel_output_shift, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
@ -443,11 +456,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
|
||||
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
|
||||
const TfLiteTensor* bias =
|
||||
(NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr;
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
const TfLiteEvalTensor* input =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensor);
|
||||
const TfLiteEvalTensor* filter =
|
||||
tflite::micro::GetEvalInput(context, node, kFilterTensor);
|
||||
const TfLiteEvalTensor* bias =
|
||||
(NumInputs(node) == 3)
|
||||
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
|
||||
: nullptr;
|
||||
|
||||
// Handle special case for streaming model.
|
||||
int* input_dims = input->dims->data;
|
||||
@ -456,14 +474,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
input_dims[3] == 32 && filter_dims[0] == 1 && filter_dims[1] == 4 &&
|
||||
filter_dims[2] == 1 && filter_dims[3] == 32) {
|
||||
xtensa::hifimini::DepthwiseConv4x32MatchingInputAndFilter(
|
||||
-input->params.zero_point, output->params.zero_point,
|
||||
-op_data->input_zero_point, op_data->output_zero_point,
|
||||
std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max(),
|
||||
op_data->per_channel_output_multiplier,
|
||||
op_data->per_channel_output_shift, GetTensorShape(input),
|
||||
GetTensorData<int8_t>(input), GetTensorShape(filter),
|
||||
GetTensorData<int8_t>(filter), GetTensorShape(bias),
|
||||
GetTensorData<int32_t>(bias), GetTensorShape(output),
|
||||
GetTensorData<int8_t>(output));
|
||||
op_data->per_channel_output_shift, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
switch (input->type) { // Already know in/out types are same.
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -128,6 +129,12 @@ struct OpData {
|
||||
// be represented as a fixed point multiplier plus a left shift.
|
||||
int32_t output_multiplier;
|
||||
int output_shift;
|
||||
|
||||
// Cached tensor zero point values for quantized operations.
|
||||
int32_t input_zero_point;
|
||||
int32_t filter_zero_point;
|
||||
int32_t output_zero_point;
|
||||
|
||||
// The range of the fused activation layer. For example for kNone and
|
||||
// uint8_t these would be 0 and 255.
|
||||
int32_t output_activation_min;
|
||||
@ -147,8 +154,6 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
|
||||
const TfLiteTensor* filter,
|
||||
const TfLiteTensor* bias, TfLiteTensor* output,
|
||||
OpData* data) {
|
||||
TFLITE_DCHECK(data_type != kTfLiteFloat32);
|
||||
|
||||
double real_multiplier = 0.0;
|
||||
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
|
||||
context, input, filter, bias, output, &real_multiplier));
|
||||
@ -179,33 +184,49 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
if (input->type != kTfLiteInt8) {
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
data->input_zero_point = input->params.zero_point;
|
||||
data->filter_zero_point = filter->params.zero_point;
|
||||
data->output_zero_point = output->params.zero_point;
|
||||
|
||||
return CalculateOpData(context, params->activation, input->type, input,
|
||||
filter, bias, output, data);
|
||||
}
|
||||
|
||||
TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
|
||||
const OpData& data, const TfLiteTensor* input,
|
||||
const TfLiteTensor* filter,
|
||||
const TfLiteTensor* bias, TfLiteTensor* output) {
|
||||
const OpData& data,
|
||||
const TfLiteEvalTensor* input,
|
||||
const TfLiteEvalTensor* filter,
|
||||
const TfLiteEvalTensor* bias,
|
||||
TfLiteEvalTensor* output) {
|
||||
// TODO(b/154032858): Investigate removing extra copies, and also passing by
|
||||
// value. TODO(b/155656675): Consider passing OpData by value once it is also
|
||||
// passed to the FullyConnected function. Until it is copied to a local
|
||||
// op_param variable, we do not get any latency improvements from passing by
|
||||
// value.
|
||||
FullyConnectedParams op_params;
|
||||
op_params.input_offset = -input->params.zero_point;
|
||||
op_params.weights_offset = -filter->params.zero_point;
|
||||
op_params.output_offset = output->params.zero_point;
|
||||
op_params.input_offset = -data.input_zero_point;
|
||||
op_params.weights_offset = -data.filter_zero_point;
|
||||
op_params.output_offset = data.output_zero_point;
|
||||
op_params.output_multiplier = data.output_multiplier;
|
||||
op_params.output_shift = data.output_shift;
|
||||
op_params.quantized_activation_min = data.output_activation_min;
|
||||
op_params.quantized_activation_max = data.output_activation_max;
|
||||
|
||||
xtensa::hifimini::FullyConnected(
|
||||
op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
|
||||
GetTensorShape(filter), GetTensorData<int8_t>(filter),
|
||||
GetTensorShape(bias), GetTensorData<int32_t>(bias),
|
||||
GetTensorShape(output), GetTensorData<int8_t>(output));
|
||||
op_params, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@ -213,12 +234,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
const OpData& data = *(static_cast<const OpData*>(node->user_data));
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
|
||||
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
const TfLiteEvalTensor* input =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensor);
|
||||
const TfLiteEvalTensor* filter =
|
||||
tflite::micro::GetEvalInput(context, node, kWeightsTensor);
|
||||
const TfLiteEvalTensor* bias =
|
||||
(NumInputs(node) == 3)
|
||||
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
|
||||
: nullptr;
|
||||
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
|
||||
TFLITE_DCHECK(filter->type == kTfLiteInt8);
|
||||
return EvalQuantizedInt8(context, node, data, input, filter, bias, output);
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -132,11 +133,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
auto* op_data = static_cast<OpData*>(node->user_data);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
TfLiteTensor* output = GetOutput(context, node, 0);
|
||||
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
|
||||
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
|
||||
|
||||
tflite::QuantizationParams op_params;
|
||||
op_params.zero_point = output->params.zero_point;
|
||||
op_params.zero_point = op_data->zero_point;
|
||||
|
||||
if (input->type != kTfLiteInt16 && output->type != kTfLiteInt8) {
|
||||
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
|
||||
@ -146,9 +147,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
xtensa::hifimini::AffineQuantize(
|
||||
op_data->scale_multiplier, op_data->zero_point, GetTensorShape(input),
|
||||
GetTensorData<int16_t>(input), GetTensorShape(output),
|
||||
GetTensorData<int8_t>(output));
|
||||
op_data->scale_multiplier, op_data->zero_point,
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int16_t>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
@ -181,13 +182,14 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
auto* op_data = static_cast<OpData*>(node->user_data);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
TfLiteTensor* output = GetOutput(context, node, 0);
|
||||
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
|
||||
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
|
||||
|
||||
if (input->type == kTfLiteInt8 && output->type == kTfLiteInt16) {
|
||||
return Softmax(*op_data, GetTensorShape(input),
|
||||
GetTensorData<int8_t>(input), GetTensorShape(output),
|
||||
GetTensorData<int16_t>(output));
|
||||
return Softmax(*op_data, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int16_t>(output));
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/kernels/activation_utils.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -41,6 +42,10 @@ struct OpData {
|
||||
int effective_scale_2_b;
|
||||
int scratch_tensor_index;
|
||||
int scratch_output_tensor_index;
|
||||
|
||||
// Cached tensor zero point values for quantized operations.
|
||||
int input_zero_point;
|
||||
int output_zero_point;
|
||||
};
|
||||
|
||||
// Input tensors.
|
||||
@ -62,14 +67,13 @@ constexpr int kOutputTensor = 0;
|
||||
* reduce the latency. See b/155656675 for more details.
|
||||
*/
|
||||
void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
|
||||
const TfLiteTensor* input_tensor,
|
||||
const TfLiteTensor* weights_feature_tensor,
|
||||
const TfLiteTensor* weights_time_tensor,
|
||||
const TfLiteTensor* bias_tensor,
|
||||
const TfLiteEvalTensor* input_tensor,
|
||||
const TfLiteEvalTensor* weights_feature_tensor,
|
||||
const TfLiteEvalTensor* weights_time_tensor,
|
||||
const TfLiteEvalTensor* bias_tensor,
|
||||
const TfLiteSVDFParams* params,
|
||||
TfLiteTensor* activation_state_tensor,
|
||||
TfLiteTensor* output_tensor, OpData data, int32_t input_zp,
|
||||
int32_t output_zp) {
|
||||
TfLiteEvalTensor* activation_state_tensor,
|
||||
TfLiteEvalTensor* output_tensor, OpData data) {
|
||||
const int n_rank = params->rank;
|
||||
const int n_batch = input_tensor->dims->data[0];
|
||||
const int n_input = input_tensor->dims->data[1];
|
||||
@ -88,7 +92,8 @@ void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
|
||||
TFLITE_DCHECK(scratch_output_tensor != nullptr);
|
||||
|
||||
// Shift states.
|
||||
int16_t* const state_ptr = GetTensorData<int16_t>(activation_state_tensor);
|
||||
int16_t* const state_ptr =
|
||||
tflite::micro::GetTensorData<int16_t>(activation_state_tensor);
|
||||
|
||||
// Left shift the activation_state.
|
||||
{
|
||||
@ -104,14 +109,14 @@ void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
|
||||
|
||||
// Feature matmul.
|
||||
{
|
||||
const int8_t* input = GetTensorData<int8_t>(input_tensor);
|
||||
const int8_t* input = tflite::micro::GetTensorData<int8_t>(input_tensor);
|
||||
const int8_t* weight_feature =
|
||||
GetTensorData<int8_t>(weights_feature_tensor);
|
||||
tflite::micro::GetTensorData<int8_t>(weights_feature_tensor);
|
||||
int16_t* result_in_batch = state_ptr + (n_memory - 1);
|
||||
|
||||
ae_q56s output_int16_max_56 = AE_CVTQ48A32S(INT16_MAX);
|
||||
ae_q56s output_int16_min_56 = AE_CVTQ48A32S(INT16_MIN);
|
||||
ae_p24x2s input_zp_24x2 = AE_MOVPA24(input_zp);
|
||||
ae_p24x2s input_zp_24x2 = AE_MOVPA24(data.input_zero_point);
|
||||
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
const int8_t* weight_feature_ptr = weight_feature - 2;
|
||||
@ -175,7 +180,8 @@ void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
|
||||
int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
|
||||
|
||||
// Perform batched vector dot product:
|
||||
const int16_t* vector1_ptr = GetTensorData<int16_t>(weights_time_tensor);
|
||||
const int16_t* vector1_ptr =
|
||||
tflite::micro::GetTensorData<int16_t>(weights_time_tensor);
|
||||
const int16_t* vector2_ptr = state_ptr + b * n_memory * n_filter;
|
||||
|
||||
const ae_p16x2s* offset_vector1 =
|
||||
@ -207,7 +213,8 @@ void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
|
||||
// Add bias.
|
||||
if (bias_tensor) {
|
||||
// Vector batch assign:
|
||||
const int32_t* bias_data = GetTensorData<int32_t>(bias_tensor);
|
||||
const int32_t* bias_data =
|
||||
tflite::micro::GetTensorData<int32_t>(bias_tensor);
|
||||
for (int i = 0; i < n_batch; ++i) {
|
||||
int32_t* output_ptr = scratch_output_tensor + i * n_unit;
|
||||
const int32_t* bias_ptr = bias_data;
|
||||
@ -238,7 +245,7 @@ void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
|
||||
// Rescale.
|
||||
ae_q56s output_int8_max_56 = AE_CVTQ48A32S(INT8_MAX);
|
||||
ae_q56s output_int8_min_56 = AE_CVTQ48A32S(INT8_MIN);
|
||||
ae_q56s output_zp_56 = AE_CVTQ48A32S(output_zp);
|
||||
ae_q56s output_zp_56 = AE_CVTQ48A32S(data.output_zero_point);
|
||||
for (int i = 0; i < n_batch * n_unit; ++i) {
|
||||
ae_q56s x_56 =
|
||||
tflite::ops::micro::xtensa::hifimini::MultiplyByQuantizedMultiplier(
|
||||
@ -249,7 +256,7 @@ void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
|
||||
// Cap min/max and convert to int32_t (already aligned to 32bit):
|
||||
x_56 = AE_MAXQ56S(x_56, output_int8_min_56);
|
||||
x_56 = AE_MINQ56S(x_56, output_int8_max_56);
|
||||
GetTensorData<int8_t>(output_tensor)[i] =
|
||||
tflite::micro::GetTensorData<int8_t>(output_tensor)[i] =
|
||||
static_cast<int8_t>(AE_TRUNCA32Q48(x_56));
|
||||
}
|
||||
}
|
||||
@ -365,6 +372,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
&data->effective_scale_2_a,
|
||||
&data->effective_scale_2_b);
|
||||
|
||||
data->input_zero_point = input->params.zero_point;
|
||||
data->output_zero_point = output->params.zero_point;
|
||||
|
||||
const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
|
||||
context, batch_size * num_filters * sizeof(int32_t),
|
||||
&(data->scratch_tensor_index));
|
||||
@ -381,22 +391,26 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
auto* params = static_cast<TfLiteSVDFParams*>(node->builtin_data);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* weights_feature =
|
||||
GetInput(context, node, kWeightsFeatureTensor);
|
||||
const TfLiteTensor* weights_time =
|
||||
GetInput(context, node, kWeightsTimeTensor);
|
||||
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
|
||||
TfLiteTensor* activation_state =
|
||||
GetVariableInput(context, node, kInputActivationStateTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
const TfLiteEvalTensor* input =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensor);
|
||||
const TfLiteEvalTensor* weights_feature =
|
||||
tflite::micro::GetEvalInput(context, node, kWeightsFeatureTensor);
|
||||
const TfLiteEvalTensor* weights_time =
|
||||
tflite::micro::GetEvalInput(context, node, kWeightsTimeTensor);
|
||||
const TfLiteEvalTensor* bias =
|
||||
(NumInputs(node) == 5)
|
||||
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
|
||||
: nullptr;
|
||||
TfLiteEvalTensor* activation_state = tflite::micro::GetMutableEvalInput(
|
||||
context, node, kInputActivationStateTensor);
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
const OpData& data = *(static_cast<const OpData*>(node->user_data));
|
||||
|
||||
EvalIntegerSVDF(context, node, input, weights_feature, weights_time, bias,
|
||||
params, activation_state, output, data,
|
||||
input->params.zero_point, output->params.zero_point);
|
||||
params, activation_state, output, data);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user