Build for TARGET_ARCH=fusion_f1 via reference kernel fallbacks.
This change adds reference fallbacks to the optimized xtensa kernels for the case when TARGET_ARCH is anything other than hifimini. This sets the stage for a baseline from which we can incrementally optimize for architectures other than hifimini. The goal is to have a starting point where all the unit tests pass for `TARGET_ARCH=hifimini` (which will use the optimized implementations) or any other `TARGET_ARCH` (with reference fallback). Tested for `TARGET_ARCH=fusion_f1` with: ``` make -f tensorflow/lite/micro/tools/make/Makefile -j8 TARGET=xtensa OPTIMIZED_KERNEL_DIR=xtensa TARGET_ARCH=fusion_f1 XTENSA_CORE=Google_F1 test ``` With the following profiling results: ``` make -f tensorflow/lite/micro/tools/make/Makefile -j8 TARGET=xtensa OPTIMIZED_KERNEL_DIR=xtensa TARGET_ARCH=fusion_f1 XTENSA_CORE=Google_F1 test_keyword_benchmark InitializeKeywordRunner() took 239061 ticks (239 ms) KeywordRunNIerations(1) took 168564 ticks (168 ms) KeywordRunNIerations(10) took 1685111 ticks (1685 ms) ``` ``` make -f tensorflow/lite/micro/tools/make/Makefile -j8 TARGET=xtensa OPTIMIZED_KERNEL_DIR=xtensa TARGET_ARCH=fusion_f1 XTENSA_CORE=Google_F1 keyword_benchmark BUILD_TYPE=release xt-size tensorflow/lite/micro/tools/make/gen/xtensa_fusion_f1/bin/keyword_benchmark text data bss dec hex filename 48256 40132 24952 113340 1babc tensorflow/lite/micro/tools/make/gen/xtensa_fusion_f1/bin/keyword_benchmark ``` After this change, we can: * add a continuous build for Hifi4 * add optimizations for Hifi4 on a per-kernel basis and keep profiling the impact of these optimizations on the keyword benchmark cycles and binary size. 
Also tested that `TARGET_ARCH=hifimini` is unaffected: ``` make -f tensorflow/lite/micro/tools/make/Makefile -j8 TARGET=xtensa OPTIMIZED_KERNEL_DIR=xtensa TARGET_ARCH=hifimini XTENSA_CORE=mini1m1m_RG test_keyword_benchmark InitializeKeywordRunner() took 1392788 ticks (1392 ms) KeywordRunNIerations(1) took 89195 ticks (89 ms) KeywordRunNIerations(10) took 891509 ticks (891 ms) ``` ``` make -f tensorflow/lite/micro/tools/make/Makefile -j8 TARGET=xtensa OPTIMIZED_KERNEL_DIR=xtensa TARGET_ARCH=hifimini XTENSA_CORE=mini1m1m_RG keyword_benchmark BUILD_TYPE=release xt-size tensorflow/lite/micro/tools/make/gen/xtensa_hifimini/bin/keyword_benchmark text data bss dec hex filename 46080 40204 24952 111236 1b284 tensorflow/lite/micro/tools/make/gen/xtensa_hifimini/bin/keyword_benchmark ```
This commit is contained in:
parent
be7b8874d0
commit
00f5e3ce05
@ -126,6 +126,7 @@ cc_library(
|
||||
"split_v.cc",
|
||||
"strided_slice.cc",
|
||||
"sub.cc",
|
||||
"svdf_common.cc",
|
||||
"tanh.cc",
|
||||
"unpack.cc",
|
||||
] + select({
|
||||
@ -144,7 +145,10 @@ cc_library(
|
||||
"xtensa/svdf.cc",
|
||||
],
|
||||
}),
|
||||
hdrs = ["micro_ops.h"],
|
||||
hdrs = [
|
||||
"micro_ops.h",
|
||||
"svdf.h",
|
||||
],
|
||||
copts = micro_copts(),
|
||||
visibility = [
|
||||
# Needed for micro:op_resolvers but visibility can not be finer-grained
|
||||
|
@ -24,26 +24,12 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/kernels/activation_utils.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/svdf.h"
|
||||
#include "tensorflow/lite/micro/micro_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
struct OpData {
|
||||
int32_t effective_scale_1_a;
|
||||
int32_t effective_scale_2_a;
|
||||
// b versions of each scale are kept at int since the numbers are just the
|
||||
// shift value - typically between [-32, 32].
|
||||
int effective_scale_1_b;
|
||||
int effective_scale_2_b;
|
||||
int scratch_tensor_index;
|
||||
int scratch_output_tensor_index;
|
||||
|
||||
// Cached tensor zero point values for quantized operations.
|
||||
int input_zero_point;
|
||||
int output_zero_point;
|
||||
};
|
||||
|
||||
// Input tensors.
|
||||
constexpr int kInputTensor = 0;
|
||||
constexpr int kWeightsFeatureTensor = 1;
|
||||
@ -200,150 +186,6 @@ inline void EvalFloatSVDF(
|
||||
bias_ptr, params->activation, state_ptr, scratch_ptr, output_ptr);
|
||||
}
|
||||
|
||||
void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
|
||||
const TfLiteEvalTensor* input_tensor,
|
||||
const TfLiteEvalTensor* weights_feature_tensor,
|
||||
const TfLiteEvalTensor* weights_time_tensor,
|
||||
const TfLiteEvalTensor* bias_tensor,
|
||||
const TfLiteSVDFParams* params,
|
||||
TfLiteEvalTensor* activation_state_tensor,
|
||||
TfLiteEvalTensor* output_tensor, const OpData& data) {
|
||||
const int n_rank = params->rank;
|
||||
const int n_batch = input_tensor->dims->data[0];
|
||||
const int n_input = input_tensor->dims->data[1];
|
||||
const int n_filter = weights_feature_tensor->dims->data[0];
|
||||
const int n_unit = n_filter / n_rank;
|
||||
const int n_memory = weights_time_tensor->dims->data[1];
|
||||
|
||||
TFLITE_DCHECK(context != nullptr);
|
||||
TFLITE_DCHECK(context->GetScratchBuffer != nullptr);
|
||||
|
||||
int32_t* scratch_tensor = static_cast<int32_t*>(
|
||||
context->GetScratchBuffer(context, data.scratch_tensor_index));
|
||||
int32_t* scratch_output_tensor = static_cast<int32_t*>(
|
||||
context->GetScratchBuffer(context, data.scratch_output_tensor_index));
|
||||
|
||||
// Shift states.
|
||||
int16_t* const state_ptr =
|
||||
tflite::micro::GetTensorData<int16_t>(activation_state_tensor);
|
||||
|
||||
// Left shift the activation_state.
|
||||
{
|
||||
int16_t* new_state_start = state_ptr;
|
||||
const int16_t* old_state_start = state_ptr + 1;
|
||||
const int16_t* old_state_end = state_ptr + n_batch * n_filter * n_memory;
|
||||
while (old_state_start != old_state_end) {
|
||||
*new_state_start++ = *old_state_start++;
|
||||
}
|
||||
}
|
||||
|
||||
// Note: no need to clear the latest activation, matmul is not accumulative.
|
||||
|
||||
// Feature matmul.
|
||||
{
|
||||
int16_t* state =
|
||||
tflite::micro::GetTensorData<int16_t>(activation_state_tensor);
|
||||
const int8_t* input = tflite::micro::GetTensorData<int8_t>(input_tensor);
|
||||
const int8_t* weight_feature =
|
||||
tflite::micro::GetTensorData<int8_t>(weights_feature_tensor);
|
||||
const int32_t output_max = std::numeric_limits<int16_t>::max();
|
||||
const int32_t output_min = std::numeric_limits<int16_t>::min();
|
||||
int16_t* result_in_batch = state + (n_memory - 1);
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
const int8_t* matrix_ptr = weight_feature;
|
||||
for (int r = 0; r < n_filter; r++) {
|
||||
int32_t dot_prod = 0;
|
||||
const int8_t* vector_in_batch = input + b * n_input;
|
||||
for (int c = 0; c < n_input; c++) {
|
||||
dot_prod +=
|
||||
*matrix_ptr++ * (*vector_in_batch++ - data.input_zero_point);
|
||||
}
|
||||
dot_prod = MultiplyByQuantizedMultiplier(
|
||||
dot_prod, data.effective_scale_1_a, data.effective_scale_1_b);
|
||||
dot_prod = std::min(std::max(output_min, dot_prod), output_max);
|
||||
// This assumes state is symmetrically quantized. Otherwise last bit of
|
||||
// state should be initialized to its zero point and accumulate the
|
||||
// dot_prod.
|
||||
// Equivalent as the following:
|
||||
// result_in_batch = zero point, which happens to be zero.
|
||||
// result_in_batch += dot_prod_56.
|
||||
*result_in_batch = dot_prod;
|
||||
result_in_batch += n_memory;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Time.
|
||||
{
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
|
||||
|
||||
// Perform batched vector dot product:
|
||||
const int16_t* vector1_ptr =
|
||||
tflite::micro::GetTensorData<int16_t>(weights_time_tensor);
|
||||
const int16_t* vector2_ptr =
|
||||
tflite::micro::GetTensorData<int16_t>(activation_state_tensor) +
|
||||
b * n_memory * n_filter;
|
||||
|
||||
for (int i = 0; i < n_filter; i++) {
|
||||
*scratch_ptr_batch = 0;
|
||||
for (int j = 0; j < n_memory; j++) {
|
||||
*scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++;
|
||||
}
|
||||
scratch_ptr_batch++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce, add bias, rescale, activation.
|
||||
{
|
||||
// Add bias.
|
||||
if (bias_tensor) {
|
||||
// Vector batch assign:
|
||||
const int32_t* bias_data =
|
||||
tflite::micro::GetTensorData<int32_t>(bias_tensor);
|
||||
for (int i = 0; i < n_batch; ++i) {
|
||||
int32_t* output_ptr = scratch_output_tensor + i * n_unit;
|
||||
const int32_t* bias_ptr = bias_data;
|
||||
for (int j = 0; j < n_unit; ++j) {
|
||||
*output_ptr++ = *bias_ptr++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
int32_t* output_ptr = scratch_output_tensor;
|
||||
for (int i = 0; i < n_batch * n_unit; ++i) {
|
||||
*output_ptr++ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce.
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
int32_t* output_temp_ptr = scratch_output_tensor + b * n_unit;
|
||||
int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
|
||||
|
||||
// Reduction sum vector
|
||||
for (int i = 0; i < n_unit; ++i) {
|
||||
for (int j = 0; j < n_rank; ++j) {
|
||||
output_temp_ptr[i] += *scratch_ptr_batch++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Rescale.
|
||||
const int32_t output_max = std::numeric_limits<int8_t>::max();
|
||||
const int32_t output_min = std::numeric_limits<int8_t>::min();
|
||||
for (int i = 0; i < n_batch * n_unit; ++i) {
|
||||
int32_t x1 = scratch_output_tensor[i];
|
||||
int32_t x2 = MultiplyByQuantizedMultiplier(x1, data.effective_scale_2_a,
|
||||
data.effective_scale_2_b);
|
||||
int32_t x3 = x2 + data.output_zero_point;
|
||||
int32_t x4 = std::min(std::max(output_min, x3), output_max);
|
||||
tflite::micro::GetTensorData<int8_t>(output_tensor)[i] =
|
||||
static_cast<int8_t>(x4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(OpData));
|
||||
@ -517,8 +359,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
case kTfLiteInt8: {
|
||||
EvalIntegerSVDF(context, node, input, weights_feature, weights_time, bias,
|
||||
params, activation_state, output, data);
|
||||
EvalIntegerSvdfReference(context, node, input, weights_feature,
|
||||
weights_time, bias, params, activation_state,
|
||||
output, data);
|
||||
return kTfLiteOk;
|
||||
break;
|
||||
}
|
||||
|
51
tensorflow/lite/micro/kernels/svdf.h
Normal file
51
tensorflow/lite/micro/kernels/svdf.h
Normal file
@ -0,0 +1,51 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_SVDF_H_
|
||||
#define TENSORFLOW_LITE_MICRO_KERNELS_SVDF_H_
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// Per-node parameters for the integer SVDF kernel. An instance is passed by
// const reference to EvalIntegerSvdfReference, which only reads it.
struct OpData {
|
||||
// Quantized multiplier applied (together with effective_scale_1_b) to each
// feature dot product before it is clamped and stored in the int16 state.
int32_t effective_scale_1_a;
|
||||
// Quantized multiplier applied (together with effective_scale_2_b) to each
// reduced accumulator before the output zero point is added.
int32_t effective_scale_2_a;
|
||||
// b versions of each scale are kept at int since the numbers are just the
|
||||
// shift value - typically between [-32, 32].
|
||||
int effective_scale_1_b;
|
||||
int effective_scale_2_b;
|
||||
// Index handed to context->GetScratchBuffer() for the int32 buffer holding
// the per-filter results of the time-weights step.
int scratch_tensor_index;
|
||||
// Index handed to context->GetScratchBuffer() for the int32 buffer holding
// the per-unit accumulators (bias plus rank reduction).
int scratch_output_tensor_index;
|
||||
|
||||
// Cached tensor zero point values for quantized operations.
|
||||
// Subtracted from every int8 input value in the feature matmul.
int input_zero_point;
|
||||
// Added to every rescaled accumulator before clamping to int8.
int output_zero_point;
|
||||
};
|
||||
|
||||
|
||||
// TensorFlow Lite Micro-specific reference implementation for integer SVDF.
//
// Reads int8 input and feature weights, int16 time weights and activation
// state, and an optional int32 bias (bias_tensor may be null, which is treated
// as a zero bias). The activation state is updated in place and the int8
// result is written to output_tensor. Two int32 scratch buffers are obtained
// through context->GetScratchBuffer() using the indices stored in `data`;
// `node` is not read by the reference implementation.
|
||||
void EvalIntegerSvdfReference(TfLiteContext* context, TfLiteNode* node,
|
||||
const TfLiteEvalTensor* input_tensor,
|
||||
const TfLiteEvalTensor* weights_feature_tensor,
|
||||
const TfLiteEvalTensor* weights_time_tensor,
|
||||
const TfLiteEvalTensor* bias_tensor,
|
||||
const TfLiteSVDFParams* params,
|
||||
TfLiteEvalTensor* activation_state_tensor,
|
||||
TfLiteEvalTensor* output_tensor,
|
||||
const OpData& data);
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_KERNELS_SVDF_H_
|
163
tensorflow/lite/micro/kernels/svdf_common.cc
Normal file
163
tensorflow/lite/micro/kernels/svdf_common.cc
Normal file
@ -0,0 +1,163 @@
|
||||
#include <math.h>
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/kernels/activation_utils.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/svdf.h"
|
||||
#include "tensorflow/lite/micro/micro_utils.h"
|
||||
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// Portable reference evaluation of the fully integer SVDF kernel.
//
// Tensor element types used here: int8 input, int8 feature weights, int16 time
// weights, int16 activation state, optional int32 bias, int8 output.
//
// The computation proceeds in four steps:
//   1. Shift the whole activation state left by one element, discarding the
//      oldest value.
//   2. Feature matmul: for every (batch, filter) pair, write the rescaled and
//      clamped dot product of the input with the feature weights into the
//      newest slot of that filter's memory.
//   3. Time step: dot product of each filter's memory with its time weights,
//      stored in the first scratch buffer.
//   4. Seed the second scratch buffer with the bias (or zero), sum every
//      `rank` consecutive filter results into one unit, rescale, add the
//      output zero point, clamp to int8 and store.
//
// Args:
//   context: provides GetScratchBuffer(); must not be null.
//   node: unused by this implementation.
//   bias_tensor: may be null, in which case a zero bias is used.
//   activation_state_tensor: updated in place.
//   output_tensor: receives n_batch * n_unit int8 values.
//   data: precomputed multipliers, zero points and scratch buffer indices.
void EvalIntegerSvdfReference(TfLiteContext* context, TfLiteNode* node,
                              const TfLiteEvalTensor* input_tensor,
                              const TfLiteEvalTensor* weights_feature_tensor,
                              const TfLiteEvalTensor* weights_time_tensor,
                              const TfLiteEvalTensor* bias_tensor,
                              const TfLiteSVDFParams* params,
                              TfLiteEvalTensor* activation_state_tensor,
                              TfLiteEvalTensor* output_tensor,
                              const OpData& data) {
  TFLITE_DCHECK(context != nullptr);
  TFLITE_DCHECK(context->GetScratchBuffer != nullptr);

  const int n_rank = params->rank;
  // n_rank is used as a divisor below; a zero rank would be a division by
  // zero, so catch it in debug builds.
  TFLITE_DCHECK(n_rank > 0);
  const int n_batch = input_tensor->dims->data[0];
  const int n_input = input_tensor->dims->data[1];
  const int n_filter = weights_feature_tensor->dims->data[0];
  const int n_unit = n_filter / n_rank;
  const int n_memory = weights_time_tensor->dims->data[1];

  // Holds n_batch * n_filter time-step results.
  int32_t* scratch_tensor = static_cast<int32_t*>(
      context->GetScratchBuffer(context, data.scratch_tensor_index));
  // Holds n_batch * n_unit accumulators.
  int32_t* scratch_output_tensor = static_cast<int32_t*>(
      context->GetScratchBuffer(context, data.scratch_output_tensor_index));

  // The state is indexed below as [n_batch][n_filter][n_memory].
  int16_t* const state_ptr =
      tflite::micro::GetTensorData<int16_t>(activation_state_tensor);

  // Left shift the activation_state by one element. Iterating by index with a
  // `<` bound makes an empty state a no-op instead of walking past the end of
  // the buffer.
  {
    const int state_size = n_batch * n_filter * n_memory;
    for (int i = 1; i < state_size; ++i) {
      state_ptr[i - 1] = state_ptr[i];
    }
  }

  // Note: no need to clear the latest activation, matmul is not accumulative.

  // Feature matmul.
  {
    const int8_t* const input_data =
        tflite::micro::GetTensorData<int8_t>(input_tensor);
    const int8_t* const weights_feature_data =
        tflite::micro::GetTensorData<int8_t>(weights_feature_tensor);
    const int32_t state_max = std::numeric_limits<int16_t>::max();
    const int32_t state_min = std::numeric_limits<int16_t>::min();

    // Points at the newest (last) memory slot of the current filter; advances
    // by one filter's worth of memory after every dot product and runs
    // straight through all batches.
    int16_t* newest_state = state_ptr + (n_memory - 1);
    for (int b = 0; b < n_batch; ++b) {
      const int8_t* weight_ptr = weights_feature_data;
      for (int f = 0; f < n_filter; ++f) {
        const int8_t* input_ptr = input_data + b * n_input;
        int32_t dot_prod = 0;
        for (int c = 0; c < n_input; ++c) {
          dot_prod += *weight_ptr++ * (*input_ptr++ - data.input_zero_point);
        }
        dot_prod = MultiplyByQuantizedMultiplier(
            dot_prod, data.effective_scale_1_a, data.effective_scale_1_b);
        dot_prod = std::min(std::max(state_min, dot_prod), state_max);
        // This assumes the state is symmetrically quantized (zero point of
        // zero). Otherwise the newest slot would have to be initialized to the
        // state zero point and the dot product accumulated onto it.
        *newest_state = static_cast<int16_t>(dot_prod);
        newest_state += n_memory;
      }
    }
  }

  // Time.
  {
    const int16_t* const weights_time_data =
        tflite::micro::GetTensorData<int16_t>(weights_time_tensor);
    for (int b = 0; b < n_batch; ++b) {
      int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;

      // Batched vector dot product: the time weights restart for every batch
      // while the state pointer covers this batch's slice.
      const int16_t* weight_ptr = weights_time_data;
      const int16_t* state_in_batch = state_ptr + b * n_memory * n_filter;

      for (int f = 0; f < n_filter; ++f) {
        *scratch_ptr_batch = 0;
        for (int m = 0; m < n_memory; ++m) {
          *scratch_ptr_batch += *weight_ptr++ * *state_in_batch++;
        }
        ++scratch_ptr_batch;
      }
    }
  }

  // Reduce, add bias, rescale, activation.
  {
    // Seed the accumulators with the bias, or with zero when there is none.
    if (bias_tensor) {
      const int32_t* bias_data =
          tflite::micro::GetTensorData<int32_t>(bias_tensor);
      for (int b = 0; b < n_batch; ++b) {
        int32_t* accum_ptr = scratch_output_tensor + b * n_unit;
        const int32_t* bias_ptr = bias_data;
        for (int u = 0; u < n_unit; ++u) {
          *accum_ptr++ = *bias_ptr++;
        }
      }
    } else {
      int32_t* accum_ptr = scratch_output_tensor;
      for (int i = 0; i < n_batch * n_unit; ++i) {
        *accum_ptr++ = 0;
      }
    }

    // Reduce: each unit is the sum of n_rank consecutive filter results.
    for (int b = 0; b < n_batch; ++b) {
      int32_t* accum_in_batch = scratch_output_tensor + b * n_unit;
      const int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
      for (int u = 0; u < n_unit; ++u) {
        for (int r = 0; r < n_rank; ++r) {
          accum_in_batch[u] += *scratch_ptr_batch++;
        }
      }
    }

    // Rescale, add the output zero point and clamp to the int8 range.
    const int32_t output_max = std::numeric_limits<int8_t>::max();
    const int32_t output_min = std::numeric_limits<int8_t>::min();
    // Fetched once; the pointer does not change across iterations.
    int8_t* const output_data =
        tflite::micro::GetTensorData<int8_t>(output_tensor);
    for (int i = 0; i < n_batch * n_unit; ++i) {
      const int32_t scaled = MultiplyByQuantizedMultiplier(
          scratch_output_tensor[i], data.effective_scale_2_a,
          data.effective_scale_2_b);
      const int32_t shifted = scaled + data.output_zero_point;
      const int32_t clamped =
          std::min(std::max(output_min, shifted), output_max);
      output_data[i] = static_cast<int8_t>(clamped);
    }
  }
}
|
||||
|
||||
} // namespace tflite
|
@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/reference/conv.h"
|
||||
|
||||
#include <xtensa/tie/xt_hifi2.h>
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/conv.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/padding.h"
|
||||
@ -60,6 +60,7 @@ struct OpData {
|
||||
int32_t output_activation_max;
|
||||
};
|
||||
|
||||
#if defined(HIFIMINI)
|
||||
void ConvPerChannel(const ConvParams& params, const int32_t* output_multiplier,
|
||||
const int32_t* output_shift,
|
||||
const RuntimeShape& input_shape, const int8_t* input_data,
|
||||
@ -260,6 +261,7 @@ inline void Conv1x32Input32x32Filter(
|
||||
output_data[ch] = static_cast<int8_t>(AE_TRUNCA32Q48(acc_56));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteConvParams* params, int width, int height,
|
||||
@ -379,6 +381,7 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
|
||||
op_params.quantized_activation_min = data->output_activation_min;
|
||||
op_params.quantized_activation_max = data->output_activation_max;
|
||||
|
||||
#if defined(HIFIMINI)
|
||||
ConvPerChannel(op_params, data->per_channel_output_multiplier,
|
||||
data->per_channel_output_shift,
|
||||
tflite::micro::GetTensorShape(input),
|
||||
@ -389,6 +392,18 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
#else
|
||||
reference_integer_ops::ConvPerChannel(
|
||||
op_params, data->per_channel_output_multiplier,
|
||||
data->per_channel_output_shift, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
#endif
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
@ -408,6 +423,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
|
||||
: nullptr;
|
||||
|
||||
#if defined(HIFIMINI)
|
||||
int* input_dims = input->dims->data;
|
||||
int* filter_dims = filter->dims->data;
|
||||
if (input_dims[0] == 1 && input_dims[1] == 1 && input_dims[2] == 1 &&
|
||||
@ -427,6 +443,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
#endif
|
||||
|
||||
switch (input->type) {
|
||||
case kTfLiteInt8:
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/padding.h"
|
||||
@ -61,6 +62,7 @@ struct OpData {
|
||||
int32_t output_activation_max;
|
||||
};
|
||||
|
||||
#if defined(HIFIMINI)
|
||||
inline void DepthwiseConvPerChannel(
|
||||
const DepthwiseParams& params, const int32_t* output_multiplier,
|
||||
const int32_t* output_shift, const RuntimeShape& input_shape,
|
||||
@ -304,6 +306,7 @@ inline void DepthwiseConv4x32MatchingInputAndFilter(
|
||||
output_data[ch_1] = static_cast<int8_t>(AE_TRUNCA32Q48(block_1_acc));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteDepthwiseConvParams* params, int width,
|
||||
@ -331,7 +334,7 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
|
||||
int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
|
||||
|
||||
// TODO(b/148610881): Consider calculating quantized params at int24
|
||||
// calculations:
|
||||
// calculations for hifimini.
|
||||
TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
|
||||
context, input, filter, bias, output, params->activation,
|
||||
&data->output_multiplier, &data->output_shift,
|
||||
@ -424,6 +427,7 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
|
||||
op_params.quantized_activation_min = std::numeric_limits<int8_t>::min();
|
||||
op_params.quantized_activation_max = std::numeric_limits<int8_t>::max();
|
||||
|
||||
#if defined(HIFIMINI)
|
||||
DepthwiseConvPerChannel(op_params, data->per_channel_output_multiplier,
|
||||
data->per_channel_output_shift,
|
||||
tflite::micro::GetTensorShape(input),
|
||||
@ -434,6 +438,18 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
#else
|
||||
reference_integer_ops::DepthwiseConvPerChannel(
|
||||
op_params, data->per_channel_output_multiplier,
|
||||
data->per_channel_output_shift, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
#endif
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
@ -454,6 +470,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
|
||||
: nullptr;
|
||||
|
||||
#if defined(HIFIMINI)
|
||||
// Handle special case for streaming model.
|
||||
int* input_dims = input->dims->data;
|
||||
int* filter_dims = filter->dims->data;
|
||||
@ -474,6 +491,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
#endif
|
||||
|
||||
switch (input->type) { // Already know in/out types are same.
|
||||
case kTfLiteInt8:
|
||||
EvalQuantizedPerChannel(context, node, params, op_data, input, filter,
|
||||
|
@ -54,6 +54,7 @@ constexpr int kWeightsTensor = 1;
|
||||
constexpr int kBiasTensor = 2;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
#if defined(HIFIMINI)
|
||||
void FullyConnected(const FullyConnectedParams& params,
|
||||
const RuntimeShape& input_shape, const int8_t* input_data,
|
||||
const RuntimeShape& filter_shape, const int8_t* filter_data,
|
||||
@ -137,6 +138,7 @@ void FullyConnected(const FullyConnectedParams& params,
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
TfLiteStatus CalculateOpData(TfLiteContext* context,
|
||||
TfLiteFusedActivation activation,
|
||||
@ -147,8 +149,12 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
|
||||
double real_multiplier = 0.0;
|
||||
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
|
||||
context, input, filter, bias, output, &real_multiplier));
|
||||
#if defined(HIFIMINI)
|
||||
QuantizeMultiplierForInt24(real_multiplier, &data->output_multiplier,
|
||||
&data->output_shift);
|
||||
#else
|
||||
QuantizeMultiplier(real_multiplier, &data->output_multiplier, &data->output_shift);
|
||||
#endif
|
||||
return CalculateActivationRangeQuantized(context, activation, output,
|
||||
&data->output_activation_min,
|
||||
&data->output_activation_max);
|
||||
@ -206,6 +212,7 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
|
||||
op_params.quantized_activation_min = data.output_activation_min;
|
||||
op_params.quantized_activation_max = data.output_activation_max;
|
||||
|
||||
#if defined(HIFIMINI)
|
||||
FullyConnected(op_params, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
@ -214,6 +221,18 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
#else
|
||||
reference_integer_ops::FullyConnected(
|
||||
op_params, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
#endif
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
@ -25,26 +25,12 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/kernels/activation_utils.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/svdf.h"
|
||||
#include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
struct OpData {
|
||||
int32_t effective_scale_1_a;
|
||||
int32_t effective_scale_2_a;
|
||||
// b versions of each scale are kept at int since the numbers are just the
|
||||
// shift value - typically between [-32, 32].
|
||||
int effective_scale_1_b;
|
||||
int effective_scale_2_b;
|
||||
int scratch_tensor_index;
|
||||
int scratch_output_tensor_index;
|
||||
|
||||
// Cached tensor zero point values for quantized operations.
|
||||
int input_zero_point;
|
||||
int output_zero_point;
|
||||
};
|
||||
|
||||
// Input tensors.
|
||||
constexpr int kInputTensor = 0;
|
||||
constexpr int kWeightsFeatureTensor = 1;
|
||||
@ -56,6 +42,7 @@ constexpr int kInputActivationStateTensor = 4;
|
||||
// Output tensor.
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
#if defined(HIFIMINI)
|
||||
/**
|
||||
* This version of SVDF is specific to TFLite Micro. It contains only a full
|
||||
* integer recipe with optimizations for the Xtensa HiFiMini platform.
|
||||
@ -255,6 +242,7 @@ void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context != nullptr);
|
||||
@ -357,10 +345,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
OpData* data = static_cast<OpData*>(node->user_data);
|
||||
|
||||
#if defined(HIFIMINI)
|
||||
QuantizeMultiplierForInt24(effective_scale_1, &data->effective_scale_1_a,
|
||||
&data->effective_scale_1_b);
|
||||
QuantizeMultiplierForInt24(effective_scale_2, &data->effective_scale_2_a,
|
||||
&data->effective_scale_2_b);
|
||||
#else
|
||||
QuantizeMultiplier(effective_scale_1, &(data->effective_scale_1_a),
|
||||
&(data->effective_scale_1_b));
|
||||
QuantizeMultiplier(effective_scale_2, &(data->effective_scale_2_a),
|
||||
&(data->effective_scale_2_b));
|
||||
#endif
|
||||
|
||||
data->input_zero_point = input->params.zero_point;
|
||||
data->output_zero_point = output->params.zero_point;
|
||||
@ -399,8 +394,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
const OpData& data = *(static_cast<const OpData*>(node->user_data));
|
||||
|
||||
#if defined(HIFIMINI)
|
||||
EvalIntegerSVDF(context, node, input, weights_feature, weights_time, bias,
|
||||
params, activation_state, output, data);
|
||||
#else
|
||||
EvalIntegerSvdfReference(context, node, input, weights_feature, weights_time, bias,
|
||||
params, activation_state, output, data);
|
||||
#endif
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
@ -336,6 +336,7 @@ tensorflow/lite/micro/kernels/split_v.cc \
|
||||
tensorflow/lite/micro/kernels/strided_slice.cc \
|
||||
tensorflow/lite/micro/kernels/sub.cc \
|
||||
tensorflow/lite/micro/kernels/svdf.cc \
|
||||
tensorflow/lite/micro/kernels/svdf_common.cc \
|
||||
tensorflow/lite/micro/kernels/tanh.cc \
|
||||
tensorflow/lite/micro/kernels/unpack.cc
|
||||
|
||||
|
@ -22,14 +22,23 @@ ifndef XTENSA_CORE
|
||||
$(error XTENSA_CORE is undefined)
|
||||
endif
|
||||
|
||||
ifeq ($(TARGET_ARCH), )
|
||||
$(error TARGET_ARCH must be specified on the command line)
|
||||
endif
|
||||
|
||||
# Create a cflag based on the specified TARGET_ARCH. For example:
|
||||
# TARGET_ARCH=hifimini --> -DHIFIMINI
|
||||
# TARGET_ARCH=fusion_f1 --> -DFUSION_F1
|
||||
# The tr operands are quoted so the shell cannot glob-expand the bracket
# expressions against file names in the current directory.
TARGET_ARCH_DEFINES := -D$(shell echo $(TARGET_ARCH) | tr '[:lower:]' '[:upper:]')
|
||||
|
||||
PLATFORM_FLAGS = \
|
||||
-DTF_LITE_MCU_DEBUG_LOG \
|
||||
-DTF_LITE_USE_CTIME \
|
||||
--xtensa-core=$(XTENSA_CORE) \
|
||||
-mcoproc \
|
||||
-DXTENSA \
|
||||
-DMAX_RFFT_PWR=9 \
|
||||
-DMIN_RFFT_PWR=MAX_RFFT_PWR
|
||||
-DMIN_RFFT_PWR=MAX_RFFT_PWR \
|
||||
$(TARGET_ARCH_DEFINES)
|
||||
|
||||
ifeq ($(BUILD_TYPE), release)
|
||||
PLATFORM_FLAGS += -Wno-unused-private-field
|
||||
|
Loading…
x
Reference in New Issue
Block a user