Build for TARGET_ARCH=fusion_f1 via reference kernel fallbacks.

This change adds reference fallbacks to the optimized xtensa kernels for
the case when TARGET_ARCH is anything other than hifimini.

This sets the stage for a baseline from which we can incrementally
optimize for architectures other than hifimini.

The goal is to have a starting point where all the unit tests pass for
`TARGET_ARCH=hifimini` (which will use the optimized implementations) or
any other `TARGET_ARCH` (with reference fallback).
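The mechanism is a compile-time guard keyed off a per-architecture define
(the xtensa Makefile change at the end of this diff turns
`TARGET_ARCH=hifimini` into `-DHIFIMINI` and `TARGET_ARCH=fusion_f1` into
`-DFUSION_F1`). Below is a minimal, self-contained sketch of the pattern
using hypothetical names; the real kernels guard their optimized paths with
`#if defined(HIFIMINI)`, as shown in the diffs further down.

```
#include <cstdio>

// Hypothetical stand-ins for the reference and hifimini-optimized kernels.
namespace reference_ops {
void Eval() { std::printf("reference fallback\n"); }
}  // namespace reference_ops

#if defined(HIFIMINI)
namespace hifimini_ops {
void Eval() { std::printf("hifimini-optimized\n"); }
}  // namespace hifimini_ops
#endif

void EvalKernel() {
#if defined(HIFIMINI)
  // TARGET_ARCH=hifimini: keep using the existing optimized implementation.
  hifimini_ops::Eval();
#else
  // Any other TARGET_ARCH (e.g. fusion_f1): fall back to the reference
  // kernel so the build links and the unit tests pass.
  reference_ops::Eval();
#endif
}

int main() { EvalKernel(); }
```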

Tested for `TARGET_ARCH=fusion_f1` with:

```
make -f tensorflow/lite/micro/tools/make/Makefile -j8 TARGET=xtensa OPTIMIZED_KERNEL_DIR=xtensa TARGET_ARCH=fusion_f1 XTENSA_CORE=Google_F1 test
```

With the following profiling results:

```
make -f tensorflow/lite/micro/tools/make/Makefile -j8 TARGET=xtensa OPTIMIZED_KERNEL_DIR=xtensa TARGET_ARCH=fusion_f1 XTENSA_CORE=Google_F1 test_keyword_benchmark

InitializeKeywordRunner() took 239061 ticks (239 ms)
KeywordRunNIerations(1) took 168564 ticks (168 ms)
KeywordRunNIerations(10) took 1685111 ticks (1685 ms)
```

```
make -f tensorflow/lite/micro/tools/make/Makefile -j8 TARGET=xtensa OPTIMIZED_KERNEL_DIR=xtensa TARGET_ARCH=fusion_f1 XTENSA_CORE=Google_F1 keyword_benchmark BUILD_TYPE=release
xt-size tensorflow/lite/micro/tools/make/gen/xtensa_fusion_f1/bin/keyword_benchmark

   text	   data	    bss	    dec	    hex	filename
  48256	  40132	  24952	 113340	  1babc	tensorflow/lite/micro/tools/make/gen/xtensa_fusion_f1/bin/keyword_benchmark
```

After this change, we can:
 * add a continuous build for HiFi4
 * add optimizations for HiFi4 on a per-kernel basis and keep profiling
   the impact of these optimizations on the keyword benchmark cycles and
   binary size.

Also tested that `TARGET_ARCH=hifimini` is unaffected:

```
make -f tensorflow/lite/micro/tools/make/Makefile -j8 TARGET=xtensa OPTIMIZED_KERNEL_DIR=xtensa TARGET_ARCH=hifimini XTENSA_CORE=mini1m1m_RG test_keyword_benchmark

InitializeKeywordRunner() took 1392788 ticks (1392 ms)
KeywordRunNIerations(1) took 89195 ticks (89 ms)
KeywordRunNIerations(10) took 891509 ticks (891 ms)
```

```
make -f tensorflow/lite/micro/tools/make/Makefile -j8 TARGET=xtensa OPTIMIZED_KERNEL_DIR=xtensa TARGET_ARCH=hifimini XTENSA_CORE=mini1m1m_RG keyword_benchmark BUILD_TYPE=release
xt-size tensorflow/lite/micro/tools/make/gen/xtensa_hifimini/bin/keyword_benchmark
   text	   data	    bss	    dec	    hex	filename
  46080	  40204	  24952	 111236	  1b284	tensorflow/lite/micro/tools/make/gen/xtensa_hifimini/bin/keyword_benchmark
```
Advait Jain 2020-12-07 15:31:32 -08:00
parent be7b8874d0
commit 00f5e3ce05
10 changed files with 308 additions and 182 deletions


@ -126,6 +126,7 @@ cc_library(
"split_v.cc",
"strided_slice.cc",
"sub.cc",
"svdf_common.cc",
"tanh.cc",
"unpack.cc",
] + select({
@ -144,7 +145,10 @@ cc_library(
"xtensa/svdf.cc",
],
}),
hdrs = ["micro_ops.h"],
hdrs = [
"micro_ops.h",
"svdf.h",
],
copts = micro_copts(),
visibility = [
# Needed for micro:op_resolvers but visibility can not be finer-grained


@ -24,26 +24,12 @@ limitations under the License.
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/activation_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/svdf.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace {
struct OpData {
int32_t effective_scale_1_a;
int32_t effective_scale_2_a;
// b versions of each scale are kept at int since the numbers are just the
// shift value - typically between [-32, 32].
int effective_scale_1_b;
int effective_scale_2_b;
int scratch_tensor_index;
int scratch_output_tensor_index;
// Cached tensor zero point values for quantized operations.
int input_zero_point;
int output_zero_point;
};
// Input tensors.
constexpr int kInputTensor = 0;
constexpr int kWeightsFeatureTensor = 1;
@ -200,150 +186,6 @@ inline void EvalFloatSVDF(
bias_ptr, params->activation, state_ptr, scratch_ptr, output_ptr);
}
void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
const TfLiteEvalTensor* input_tensor,
const TfLiteEvalTensor* weights_feature_tensor,
const TfLiteEvalTensor* weights_time_tensor,
const TfLiteEvalTensor* bias_tensor,
const TfLiteSVDFParams* params,
TfLiteEvalTensor* activation_state_tensor,
TfLiteEvalTensor* output_tensor, const OpData& data) {
const int n_rank = params->rank;
const int n_batch = input_tensor->dims->data[0];
const int n_input = input_tensor->dims->data[1];
const int n_filter = weights_feature_tensor->dims->data[0];
const int n_unit = n_filter / n_rank;
const int n_memory = weights_time_tensor->dims->data[1];
TFLITE_DCHECK(context != nullptr);
TFLITE_DCHECK(context->GetScratchBuffer != nullptr);
int32_t* scratch_tensor = static_cast<int32_t*>(
context->GetScratchBuffer(context, data.scratch_tensor_index));
int32_t* scratch_output_tensor = static_cast<int32_t*>(
context->GetScratchBuffer(context, data.scratch_output_tensor_index));
// Shift states.
int16_t* const state_ptr =
tflite::micro::GetTensorData<int16_t>(activation_state_tensor);
// Left shift the activation_state.
{
int16_t* new_state_start = state_ptr;
const int16_t* old_state_start = state_ptr + 1;
const int16_t* old_state_end = state_ptr + n_batch * n_filter * n_memory;
while (old_state_start != old_state_end) {
*new_state_start++ = *old_state_start++;
}
}
// Note: no need to clear the latest activation, matmul is not accumulative.
// Feature matmul.
{
int16_t* state =
tflite::micro::GetTensorData<int16_t>(activation_state_tensor);
const int8_t* input = tflite::micro::GetTensorData<int8_t>(input_tensor);
const int8_t* weight_feature =
tflite::micro::GetTensorData<int8_t>(weights_feature_tensor);
const int32_t output_max = std::numeric_limits<int16_t>::max();
const int32_t output_min = std::numeric_limits<int16_t>::min();
int16_t* result_in_batch = state + (n_memory - 1);
for (int b = 0; b < n_batch; b++) {
const int8_t* matrix_ptr = weight_feature;
for (int r = 0; r < n_filter; r++) {
int32_t dot_prod = 0;
const int8_t* vector_in_batch = input + b * n_input;
for (int c = 0; c < n_input; c++) {
dot_prod +=
*matrix_ptr++ * (*vector_in_batch++ - data.input_zero_point);
}
dot_prod = MultiplyByQuantizedMultiplier(
dot_prod, data.effective_scale_1_a, data.effective_scale_1_b);
dot_prod = std::min(std::max(output_min, dot_prod), output_max);
// This assumes state is symmetrically quantized. Otherwise last bit of
// state should be initialized to its zero point and accumulate the
// dot_prod.
// Equivalent as the following:
// result_in_batch = zero point, which happens to be zero.
// result_in_batch += dot_prod_56.
*result_in_batch = dot_prod;
result_in_batch += n_memory;
}
}
}
// Time.
{
for (int b = 0; b < n_batch; ++b) {
int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
// Perform batched vector dot product:
const int16_t* vector1_ptr =
tflite::micro::GetTensorData<int16_t>(weights_time_tensor);
const int16_t* vector2_ptr =
tflite::micro::GetTensorData<int16_t>(activation_state_tensor) +
b * n_memory * n_filter;
for (int i = 0; i < n_filter; i++) {
*scratch_ptr_batch = 0;
for (int j = 0; j < n_memory; j++) {
*scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++;
}
scratch_ptr_batch++;
}
}
}
// Reduce, add bias, rescale, activation.
{
// Add bias.
if (bias_tensor) {
// Vector batch assign:
const int32_t* bias_data =
tflite::micro::GetTensorData<int32_t>(bias_tensor);
for (int i = 0; i < n_batch; ++i) {
int32_t* output_ptr = scratch_output_tensor + i * n_unit;
const int32_t* bias_ptr = bias_data;
for (int j = 0; j < n_unit; ++j) {
*output_ptr++ = *bias_ptr++;
}
}
} else {
int32_t* output_ptr = scratch_output_tensor;
for (int i = 0; i < n_batch * n_unit; ++i) {
*output_ptr++ = 0;
}
}
// Reduce.
for (int b = 0; b < n_batch; ++b) {
int32_t* output_temp_ptr = scratch_output_tensor + b * n_unit;
int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
// Reduction sum vector
for (int i = 0; i < n_unit; ++i) {
for (int j = 0; j < n_rank; ++j) {
output_temp_ptr[i] += *scratch_ptr_batch++;
}
}
}
// Rescale.
const int32_t output_max = std::numeric_limits<int8_t>::max();
const int32_t output_min = std::numeric_limits<int8_t>::min();
for (int i = 0; i < n_batch * n_unit; ++i) {
int32_t x1 = scratch_output_tensor[i];
int32_t x2 = MultiplyByQuantizedMultiplier(x1, data.effective_scale_2_a,
data.effective_scale_2_b);
int32_t x3 = x2 + data.output_zero_point;
int32_t x4 = std::min(std::max(output_min, x3), output_max);
tflite::micro::GetTensorData<int8_t>(output_tensor)[i] =
static_cast<int8_t>(x4);
}
}
}
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData));
@ -517,8 +359,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
case kTfLiteInt8: {
EvalIntegerSVDF(context, node, input, weights_feature, weights_time, bias,
params, activation_state, output, data);
EvalIntegerSvdfReference(context, node, input, weights_feature,
weights_time, bias, params, activation_state,
output, data);
return kTfLiteOk;
break;
}


@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_SVDF_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_SVDF_H_
#include "tensorflow/lite/c/common.h"
namespace tflite {
struct OpData {
int32_t effective_scale_1_a;
int32_t effective_scale_2_a;
// b versions of each scale are kept at int since the numbers are just the
// shift value - typically between [-32, 32].
int effective_scale_1_b;
int effective_scale_2_b;
int scratch_tensor_index;
int scratch_output_tensor_index;
// Cached tensor zero point values for quantized operations.
int input_zero_point;
int output_zero_point;
};
// TensorFlow Lite Micro-specific reference implementation for Integer SVDF.
void EvalIntegerSvdfReference(TfLiteContext* context, TfLiteNode* node,
const TfLiteEvalTensor* input_tensor,
const TfLiteEvalTensor* weights_feature_tensor,
const TfLiteEvalTensor* weights_time_tensor,
const TfLiteEvalTensor* bias_tensor,
const TfLiteSVDFParams* params,
TfLiteEvalTensor* activation_state_tensor,
TfLiteEvalTensor* output_tensor,
const OpData& data);
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_SVDF_H_


@ -0,0 +1,163 @@
#include <math.h>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/activation_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/svdf.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
void EvalIntegerSvdfReference(TfLiteContext* context, TfLiteNode* node,
const TfLiteEvalTensor* input_tensor,
const TfLiteEvalTensor* weights_feature_tensor,
const TfLiteEvalTensor* weights_time_tensor,
const TfLiteEvalTensor* bias_tensor,
const TfLiteSVDFParams* params,
TfLiteEvalTensor* activation_state_tensor,
TfLiteEvalTensor* output_tensor,
const OpData& data) {
const int n_rank = params->rank;
const int n_batch = input_tensor->dims->data[0];
const int n_input = input_tensor->dims->data[1];
const int n_filter = weights_feature_tensor->dims->data[0];
const int n_unit = n_filter / n_rank;
const int n_memory = weights_time_tensor->dims->data[1];
TFLITE_DCHECK(context != nullptr);
TFLITE_DCHECK(context->GetScratchBuffer != nullptr);
int32_t* scratch_tensor = static_cast<int32_t*>(
context->GetScratchBuffer(context, data.scratch_tensor_index));
int32_t* scratch_output_tensor = static_cast<int32_t*>(
context->GetScratchBuffer(context, data.scratch_output_tensor_index));
// Shift states.
int16_t* const state_ptr =
tflite::micro::GetTensorData<int16_t>(activation_state_tensor);
// Left shift the activation_state.
{
int16_t* new_state_start = state_ptr;
const int16_t* old_state_start = state_ptr + 1;
const int16_t* old_state_end = state_ptr + n_batch * n_filter * n_memory;
while (old_state_start != old_state_end) {
*new_state_start++ = *old_state_start++;
}
}
// Note: no need to clear the latest activation, matmul is not accumulative.
// Feature matmul.
{
int16_t* state =
tflite::micro::GetTensorData<int16_t>(activation_state_tensor);
const int8_t* input = tflite::micro::GetTensorData<int8_t>(input_tensor);
const int8_t* weight_feature =
tflite::micro::GetTensorData<int8_t>(weights_feature_tensor);
const int32_t output_max = std::numeric_limits<int16_t>::max();
const int32_t output_min = std::numeric_limits<int16_t>::min();
int16_t* result_in_batch = state + (n_memory - 1);
for (int b = 0; b < n_batch; b++) {
const int8_t* matrix_ptr = weight_feature;
for (int r = 0; r < n_filter; r++) {
int32_t dot_prod = 0;
const int8_t* vector_in_batch = input + b * n_input;
for (int c = 0; c < n_input; c++) {
dot_prod +=
*matrix_ptr++ * (*vector_in_batch++ - data.input_zero_point);
}
dot_prod = MultiplyByQuantizedMultiplier(
dot_prod, data.effective_scale_1_a, data.effective_scale_1_b);
dot_prod = std::min(std::max(output_min, dot_prod), output_max);
// This assumes state is symmetrically quantized. Otherwise last bit of
// state should be initialized to its zero point and accumulate the
// dot_prod.
// Equivalent as the following:
// result_in_batch = zero point, which happens to be zero.
// result_in_batch += dot_prod_56.
*result_in_batch = dot_prod;
result_in_batch += n_memory;
}
}
}
// Time.
{
for (int b = 0; b < n_batch; ++b) {
int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
// Perform batched vector dot product:
const int16_t* vector1_ptr =
tflite::micro::GetTensorData<int16_t>(weights_time_tensor);
const int16_t* vector2_ptr =
tflite::micro::GetTensorData<int16_t>(activation_state_tensor) +
b * n_memory * n_filter;
for (int i = 0; i < n_filter; i++) {
*scratch_ptr_batch = 0;
for (int j = 0; j < n_memory; j++) {
*scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++;
}
scratch_ptr_batch++;
}
}
}
// Reduce, add bias, rescale, activation.
{
// Add bias.
if (bias_tensor) {
// Vector batch assign:
const int32_t* bias_data =
tflite::micro::GetTensorData<int32_t>(bias_tensor);
for (int i = 0; i < n_batch; ++i) {
int32_t* output_ptr = scratch_output_tensor + i * n_unit;
const int32_t* bias_ptr = bias_data;
for (int j = 0; j < n_unit; ++j) {
*output_ptr++ = *bias_ptr++;
}
}
} else {
int32_t* output_ptr = scratch_output_tensor;
for (int i = 0; i < n_batch * n_unit; ++i) {
*output_ptr++ = 0;
}
}
// Reduce.
for (int b = 0; b < n_batch; ++b) {
int32_t* output_temp_ptr = scratch_output_tensor + b * n_unit;
int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
// Reduction sum vector
for (int i = 0; i < n_unit; ++i) {
for (int j = 0; j < n_rank; ++j) {
output_temp_ptr[i] += *scratch_ptr_batch++;
}
}
}
// Rescale.
const int32_t output_max = std::numeric_limits<int8_t>::max();
const int32_t output_min = std::numeric_limits<int8_t>::min();
for (int i = 0; i < n_batch * n_unit; ++i) {
int32_t x1 = scratch_output_tensor[i];
int32_t x2 = MultiplyByQuantizedMultiplier(x1, data.effective_scale_2_a,
data.effective_scale_2_b);
int32_t x3 = x2 + data.output_zero_point;
int32_t x4 = std::min(std::max(output_min, x3), output_max);
tflite::micro::GetTensorData<int8_t>(output_tensor)[i] =
static_cast<int8_t>(x4);
}
}
}
} // namespace tflite


@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include <xtensa/tie/xt_hifi2.h>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
@ -60,6 +60,7 @@ struct OpData {
int32_t output_activation_max;
};
#if defined(HIFIMINI)
void ConvPerChannel(const ConvParams& params, const int32_t* output_multiplier,
const int32_t* output_shift,
const RuntimeShape& input_shape, const int8_t* input_data,
@ -260,6 +261,7 @@ inline void Conv1x32Input32x32Filter(
output_data[ch] = static_cast<int8_t>(AE_TRUNCA32Q48(acc_56));
}
}
#endif
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, int width, int height,
@ -379,6 +381,7 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
op_params.quantized_activation_min = data->output_activation_min;
op_params.quantized_activation_max = data->output_activation_max;
#if defined(HIFIMINI)
ConvPerChannel(op_params, data->per_channel_output_multiplier,
data->per_channel_output_shift,
tflite::micro::GetTensorShape(input),
@ -389,6 +392,18 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
#else
reference_integer_ops::ConvPerChannel(
op_params, data->per_channel_output_multiplier,
data->per_channel_output_shift, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
#endif
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@ -408,6 +423,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
: nullptr;
#if defined(HIFIMINI)
int* input_dims = input->dims->data;
int* filter_dims = filter->dims->data;
if (input_dims[0] == 1 && input_dims[1] == 1 && input_dims[2] == 1 &&
@ -427,6 +443,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(output));
return kTfLiteOk;
}
#endif
switch (input->type) {
case kTfLiteInt8:


@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
@ -61,6 +62,7 @@ struct OpData {
int32_t output_activation_max;
};
#if defined(HIFIMINI)
inline void DepthwiseConvPerChannel(
const DepthwiseParams& params, const int32_t* output_multiplier,
const int32_t* output_shift, const RuntimeShape& input_shape,
@ -304,6 +306,7 @@ inline void DepthwiseConv4x32MatchingInputAndFilter(
output_data[ch_1] = static_cast<int8_t>(AE_TRUNCA32Q48(block_1_acc));
}
}
#endif
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
TfLiteDepthwiseConvParams* params, int width,
@ -331,7 +334,7 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
// TODO(b/148610881): Consider calculating quantized params at int24
// calculations:
// calculations for hifimini.
TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
context, input, filter, bias, output, params->activation,
&data->output_multiplier, &data->output_shift,
@ -424,6 +427,7 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
op_params.quantized_activation_min = std::numeric_limits<int8_t>::min();
op_params.quantized_activation_max = std::numeric_limits<int8_t>::max();
#if defined(HIFIMINI)
DepthwiseConvPerChannel(op_params, data->per_channel_output_multiplier,
data->per_channel_output_shift,
tflite::micro::GetTensorShape(input),
@ -434,6 +438,18 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
#else
reference_integer_ops::DepthwiseConvPerChannel(
op_params, data->per_channel_output_multiplier,
data->per_channel_output_shift, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
#endif
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@ -454,6 +470,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
: nullptr;
#if defined(HIFIMINI)
// Handle special case for streaming model.
int* input_dims = input->dims->data;
int* filter_dims = filter->dims->data;
@ -474,6 +491,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(output));
return kTfLiteOk;
}
#endif
switch (input->type) { // Already know in/out types are same.
case kTfLiteInt8:
EvalQuantizedPerChannel(context, node, params, op_data, input, filter,


@ -54,6 +54,7 @@ constexpr int kWeightsTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
#if defined(HIFIMINI)
void FullyConnected(const FullyConnectedParams& params,
const RuntimeShape& input_shape, const int8_t* input_data,
const RuntimeShape& filter_shape, const int8_t* filter_data,
@ -137,6 +138,7 @@ void FullyConnected(const FullyConnectedParams& params,
}
}
}
#endif
TfLiteStatus CalculateOpData(TfLiteContext* context,
TfLiteFusedActivation activation,
@ -147,8 +149,12 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
double real_multiplier = 0.0;
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
context, input, filter, bias, output, &real_multiplier));
#if defined(HIFIMINI)
QuantizeMultiplierForInt24(real_multiplier, &data->output_multiplier,
&data->output_shift);
#else
QuantizeMultiplier(real_multiplier, &data->output_multiplier, &data->output_shift);
#endif
return CalculateActivationRangeQuantized(context, activation, output,
&data->output_activation_min,
&data->output_activation_max);
@ -206,6 +212,7 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
op_params.quantized_activation_min = data.output_activation_min;
op_params.quantized_activation_max = data.output_activation_max;
#if defined(HIFIMINI)
FullyConnected(op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
@ -214,6 +221,18 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
#else
reference_integer_ops::FullyConnected(
op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
#endif
return kTfLiteOk;
}


@ -25,26 +25,12 @@ limitations under the License.
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/activation_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/svdf.h"
#include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h"
namespace tflite {
namespace {
struct OpData {
int32_t effective_scale_1_a;
int32_t effective_scale_2_a;
// b versions of each scale are kept at int since the numbers are just the
// shift value - typically between [-32, 32].
int effective_scale_1_b;
int effective_scale_2_b;
int scratch_tensor_index;
int scratch_output_tensor_index;
// Cached tensor zero point values for quantized operations.
int input_zero_point;
int output_zero_point;
};
// Input tensors.
constexpr int kInputTensor = 0;
constexpr int kWeightsFeatureTensor = 1;
@ -56,6 +42,7 @@ constexpr int kInputActivationStateTensor = 4;
// Output tensor.
constexpr int kOutputTensor = 0;
#if defined(HIFIMINI)
/**
* This version of SVDF is specific to TFLite Micro. It contains only a full
 * integer recipe with optimizations for the Xtensa HiFiMini platform.
@ -255,6 +242,7 @@ void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
}
}
}
#endif
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context != nullptr);
@ -357,10 +345,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);
#if defined(HIFIMINI)
QuantizeMultiplierForInt24(effective_scale_1, &data->effective_scale_1_a,
&data->effective_scale_1_b);
QuantizeMultiplierForInt24(effective_scale_2, &data->effective_scale_2_a,
&data->effective_scale_2_b);
#else
QuantizeMultiplier(effective_scale_1, &(data->effective_scale_1_a),
&(data->effective_scale_1_b));
QuantizeMultiplier(effective_scale_2, &(data->effective_scale_2_a),
&(data->effective_scale_2_b));
#endif
data->input_zero_point = input->params.zero_point;
data->output_zero_point = output->params.zero_point;
@ -399,8 +394,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
const OpData& data = *(static_cast<const OpData*>(node->user_data));
#if defined(HIFIMINI)
EvalIntegerSVDF(context, node, input, weights_feature, weights_time, bias,
params, activation_state, output, data);
#else
EvalIntegerSvdfReference(context, node, input, weights_feature, weights_time, bias,
params, activation_state, output, data);
#endif
return kTfLiteOk;
}


@ -336,6 +336,7 @@ tensorflow/lite/micro/kernels/split_v.cc \
tensorflow/lite/micro/kernels/strided_slice.cc \
tensorflow/lite/micro/kernels/sub.cc \
tensorflow/lite/micro/kernels/svdf.cc \
tensorflow/lite/micro/kernels/svdf_common.cc \
tensorflow/lite/micro/kernels/tanh.cc \
tensorflow/lite/micro/kernels/unpack.cc


@ -22,14 +22,23 @@ ifndef XTENSA_CORE
$(error XTENSA_CORE is undefined)
endif
ifeq ($(TARGET_ARCH), )
$(error TARGET_ARCH must be specified on the command line)
endif
# Create a cflag based on the specified TARGET_ARCH. For example:
# TARGET_ARCH=hifimini --> -DHIFIMINI
# TARGET_ARCH=fusion_f1 --> -DFUSION_F1
TARGET_ARCH_DEFINES := -D$(shell echo $(TARGET_ARCH) | tr [a-z] [A-Z])
PLATFORM_FLAGS = \
-DTF_LITE_MCU_DEBUG_LOG \
-DTF_LITE_USE_CTIME \
--xtensa-core=$(XTENSA_CORE) \
-mcoproc \
-DXTENSA \
-DMAX_RFFT_PWR=9 \
-DMIN_RFFT_PWR=MAX_RFFT_PWR
-DMIN_RFFT_PWR=MAX_RFFT_PWR \
$(TARGET_ARCH_DEFINES)
ifeq ($(BUILD_TYPE), release)
PLATFORM_FLAGS += -Wno-unused-private-field