Merge pull request from nyadla-sys:tfulite_upload

PiperOrigin-RevId: 302928087
Change-Id: I2efd158b43a490189b765ba735baa0d02ffe990b
TensorFlower Gardener 2020-03-25 11:12:27 -07:00
commit 26a3d1c92d
16 changed files with 3601 additions and 0 deletions


@@ -0,0 +1,241 @@
/******************************************************************************
* Copyright (C) 2019 Cadence Design Systems, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to use this Software with Cadence processor cores only and
* not with any other processors and platforms, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
******************************************************************************/
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/micro_utils.h"
#include "xtensa_tf_micro_common.h"
namespace tflite {
namespace ops {
namespace micro {
namespace activations {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
template <typename Q>
inline void ReluQuantized(int32_t lower, const RuntimeShape& input_shape,
const Q* input_data, const RuntimeShape& output_shape,
Q* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const Q val = input_data[i];
const Q clamped = val < lower ? lower : val;
output_data[i] = clamped;
}
}
inline void ReluFloat(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const float val = input_data[i];
const float lower = 0.0f;
const float clamped = val < lower ? lower : val;
output_data[i] = clamped;
}
}
inline void Relu6Float(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const float val = input_data[i];
const float upper = 6.0f;
const float lower = 0.0f;
const float clamped = val > upper ? upper : val < lower ? lower : val;
output_data[i] = clamped;
}
}
template <typename Q>
inline void Relu6Quantized(Q lower, Q upper, const RuntimeShape& input_shape,
const Q* input_data,
const RuntimeShape& output_shape, Q* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const Q val = input_data[i];
const Q clamped = val > upper ? upper : val < lower ? lower : val;
output_data[i] = clamped;
}
}
TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (input->type) {
case kTfLiteFloat32: {
int err;
const float* inp_data_ptr;
float* out_data_ptr;
const RuntimeShape& input_shape = GetTensorShape(input);
const RuntimeShape& output_shape = GetTensorShape(output);
const int flat_size = MatchingFlatSize(input_shape, output_shape);
inp_data_ptr = GetTensorData<float>(input);
out_data_ptr = GetTensorData<float>(output);
// Upper bound for plain ReLU: positive infinity. (The hex literal 0x7F800000
// is the IEEE-754 bit pattern of +inf, not its float value.)
const float f32_pos_inf = std::numeric_limits<float>::infinity();
err = xa_nn_vec_relu_f32_f32(out_data_ptr, inp_data_ptr, f32_pos_inf,
flat_size);
CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_relu_f32_f32 failed");
return kTfLiteOk;
}
case kTfLiteInt8: {
ReluQuantized<int8_t>(input->params.zero_point, GetTensorShape(input),
GetTensorData<int8_t>(input),
GetTensorShape(output),
GetTensorData<int8_t>(output));
return kTfLiteOk;
}
case kTfLiteUInt8: {
int err;
const uint8_t* inp_data_ptr;
uint8_t* out_data_ptr;
const RuntimeShape& input_shape = GetTensorShape(input);
const RuntimeShape& output_shape = GetTensorShape(output);
const int flat_size = MatchingFlatSize(input_shape, output_shape);
inp_data_ptr = GetTensorData<uint8_t>(input);
out_data_ptr = GetTensorData<uint8_t>(output);
err = xa_nn_vec_activation_min_max_asym8_asym8(
out_data_ptr, inp_data_ptr, 0, 255, flat_size);
// Clamps to the full uint8 range [0, 255]; for a true ReLU the lower bound
// should arguably be the input zero point.
CHECK_ERR_HIFI_NNLIB_KER(err,
"xa_nn_vec_activation_min_max_asym8_asym8 failed");
return kTfLiteOk;
}
default: {
TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}
}
TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (input->type) {
case kTfLiteFloat32: {
int err;
const float* inp_data_ptr;
float* out_data_ptr;
const RuntimeShape& input_shape = GetTensorShape(input);
const RuntimeShape& output_shape = GetTensorShape(output);
const int flat_size = MatchingFlatSize(input_shape, output_shape);
inp_data_ptr = GetTensorData<float>(input);
out_data_ptr = GetTensorData<float>(output);
err = xa_nn_vec_relu6_f32_f32(out_data_ptr, inp_data_ptr, flat_size);
CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_relu1_f32_f32 failed");
return kTfLiteOk;
}
case kTfLiteInt8: {
const int8_t six = FloatToAsymmetricQuantizedInt8(
6.0f, input->params.scale, input->params.zero_point);
const int8_t zero = input->params.zero_point;
Relu6Quantized<int8_t>(
zero, six, GetTensorShape(input), GetTensorData<int8_t>(input),
GetTensorShape(output), GetTensorData<int8_t>(output));
return kTfLiteOk;
}
case kTfLiteUInt8: {
const uint8_t six = FloatToAsymmetricQuantizedUInt8(
6.0f, input->params.scale, input->params.zero_point);
const uint8_t zero = input->params.zero_point;
int err;
const uint8_t* inp_data_ptr;
uint8_t* out_data_ptr;
const RuntimeShape& input_shape = GetTensorShape(input);
const RuntimeShape& output_shape = GetTensorShape(output);
const int flat_size = MatchingFlatSize(input_shape, output_shape);
inp_data_ptr = GetTensorData<uint8_t>(input);
out_data_ptr = GetTensorData<uint8_t>(output);
err = xa_nn_vec_activation_min_max_asym8_asym8(out_data_ptr, inp_data_ptr,
zero, six, flat_size);
CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_activation_min_max_8_8 failed");
return kTfLiteOk;
}
default: {
TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}
}
} // namespace activations
TfLiteRegistration* Register_RELU() {
static TfLiteRegistration r = {};
r.prepare = activations::ReluPrepare;
r.invoke = activations::ReluEval;
return &r;
}
TfLiteRegistration* Register_RELU6() {
static TfLiteRegistration r = {};
r.prepare = activations::Relu6Prepare;
r.invoke = activations::Relu6Eval;
return &r;
}
} // namespace micro
} // namespace ops
} // namespace tflite
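
Note on the quantized ReLU6 bounds above: the int8/uint8 'zero' and 'six' values come from the standard TfLite asymmetric quantization rule q = zero_point + round(x / scale), which is what FloatToAsymmetricQuantizedInt8/UInt8 are expected to compute. A minimal standalone sketch of that mapping (illustrative only, not the library routine):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Quantize a real value to int8 with the given scale and zero point,
// clamping to the representable range: q = zero_point + round(x / scale).
inline int8_t QuantizeToInt8(float x, float scale, int32_t zero_point) {
  const int32_t q =
      zero_point + static_cast<int32_t>(std::round(x / scale));
  return static_cast<int8_t>(
      std::min<int32_t>(127, std::max<int32_t>(-128, q)));
}

// Example: with scale = 6.0f / 255.0f and zero_point = -128, 0.0f maps to
// -128 and 6.0f maps to 127, which is exactly the [zero, six] clamp range
// used by Relu6Eval for int8 tensors.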


@@ -0,0 +1,549 @@
/******************************************************************************
* Copyright (C) 2019 Cadence Design Systems, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to use this Software with Cadence processor cores only and
* not with any other processors and platforms, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
******************************************************************************/
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "xtensa_tf_micro_common.h"
namespace tflite {
namespace ops {
namespace micro {
namespace conv {
constexpr int kInputTensor = 0;
constexpr int kFilterTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
constexpr int kMaxChannels = 256;
// Conv is quantized along dimension 0:
// https://www.tensorflow.org/lite/performance/quantization_spec
constexpr int kConvQuantizedDimension = 0;
// This file has two implementations of Conv: an optimized Cadence NNLib path
// and the portable reference fallback.
struct OpData {
TfLitePaddingValues padding;
// The scaling factor from input to output (aka the 'real multiplier') can
// be represented as a fixed point multiplier plus a left shift.
int32_t output_multiplier;
int output_shift;
// Per channel output multiplier and shift.
// TODO(b/141139247): Allocate these dynamically when possible.
int32_t per_channel_output_multiplier[kMaxChannels];
int32_t per_channel_output_shift[kMaxChannels];
// The range of the fused activation layer. For example for kNone and
// uint8_t these would be 0 and 255.
int32_t output_activation_min;
int32_t output_activation_max;
};
inline PaddingType RuntimePaddingType(TfLitePadding padding) {
switch (padding) {
case TfLitePadding::kTfLitePaddingSame:
return PaddingType::kSame;
case TfLitePadding::kTfLitePaddingValid:
return PaddingType::kValid;
case TfLitePadding::kTfLitePaddingUnknown:
default:
return PaddingType::kNone;
}
}
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, int width, int height,
int filter_width, int filter_height, int out_width,
int out_height, const TfLiteType data_type,
OpData* data) {
bool has_bias = node->inputs->size == 3;
// Check number of inputs/outputs
TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
// Matching GetWindowedOutputSize in TensorFlow.
auto padding = params->padding;
data->padding = ComputePaddingHeightWidth(
params->stride_height, params->stride_width,
params->dilation_height_factor, params->dilation_width_factor, height,
width, filter_height, filter_width, padding, &out_height, &out_width);
// Note that quantized inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
if (data_type != kTfLiteFloat32) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
const TfLiteTensor* bias =
GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
int output_channels = filter->dims->data[kConvQuantizedDimension];
TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
context, input, filter, bias, output, params->activation,
&data->output_multiplier, &data->output_shift,
&data->output_activation_min, &data->output_activation_max,
data->per_channel_output_multiplier,
reinterpret_cast<int*>(data->per_channel_output_shift),
output_channels));
}
return kTfLiteOk;
}
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return nullptr;
}
void Free(TfLiteContext* context, void* buffer) {}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, OpData* data,
const TfLiteTensor* input,
const TfLiteTensor* filter, const TfLiteTensor* bias,
TfLiteTensor* im2col, TfLiteTensor* hwcn_weights,
TfLiteTensor* output) {
const int32_t input_offset = -input->params.zero_point;
const int32_t filter_offset = -filter->params.zero_point;
const int32_t output_offset = output->params.zero_point;
if ((params->dilation_width_factor == 1) &&
(params->dilation_height_factor == 1)) {
const uint8 *input_data, *filter_data;
const int32_t* bias_data;
uint8* output_data;
const RuntimeShape& input_shape = GetTensorShape(input);
const RuntimeShape& filter_shape = GetTensorShape(filter);
const RuntimeShape& output_shape = GetTensorShape(output);
const RuntimeShape& bias_shape = GetTensorShape(bias);
input_data = GetTensorData<uint8_t>(input);
filter_data = GetTensorData<uint8_t>(filter);
bias_data = GetTensorData<int32_t>(bias);
output_data = GetTensorData<uint8_t>(output);
const int stride_width = params->stride_width;
const int stride_height = params->stride_height;
const int dilation_width_factor = 1;
const int dilation_height_factor = 1;
const int pad_width = data->padding.width;
const int pad_height = data->padding.height;
const int32 output_activation_min = data->output_activation_min;
const int32 output_activation_max = data->output_activation_max;
const int32 output_multiplier = data->output_multiplier;
const int output_shift = -data->output_shift;
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
if (bias_data) {
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
}
const int input_height = input_shape.Dims(1);
const int input_width = input_shape.Dims(2);
const int filter_height = filter_shape.Dims(1);
const int filter_width = filter_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const int filter_depth = filter_shape.Dims(3);
int err, output_data_format = 0;
void* p_scratch;
uint8 *p_filter, *p_out_scratch;
// Round filter_depth up to the next multiple of 4.
int filter_depth_padded = (filter_depth + 3) & (~3);
int out_length = output_height * output_width * output_depth;
int required_scratch, input_precision = PREC_ASYM8;
int h, w, c;
required_scratch = xa_nn_conv2d_std_getsize(
input_height, input_depth, filter_height, filter_width, stride_height,
pad_height, output_height, input_precision);
if (required_scratch <= 0) {
TF_LITE_KERNEL_LOG(context,
"conv2d_std_asym8: xa_nn_conv2d_std_getsize failed");
return kTfLiteError;
}
ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
p_scratch = xtensa_nnlib_scratch_buf;
p_filter = (uint8*)p_scratch;
p_out_scratch =
(p_filter +
ALIGNED_SIZE((sizeof(uint8_t) * filter_height * filter_width *
filter_depth_padded * output_depth),
8));
required_scratch +=
ALIGNED_SIZE((sizeof(uint8_t) * filter_height * filter_width *
filter_depth_padded * output_depth),
8);
p_scratch =
(uint8*)(p_out_scratch + ALIGNED_SIZE(sizeof(uint8_t) * out_length, 8));
required_scratch += ALIGNED_SIZE(sizeof(uint8_t) * out_length, 8);
if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
TF_LITE_KERNEL_LOG(context,
"conv2d_std_asym8: insufficient scratch memory");
return kTfLiteError;
}
// Pad the filter coefficients along the depth dimension.
for (h = 0; h < filter_height * filter_width * output_depth; h++) {
for (c = 0; c < filter_depth; c++) {
p_filter[h * filter_depth_padded + c] =
filter_data[h * filter_depth + c];
}
for (c = input_depth; c < filter_depth_padded; c++) {
p_filter[h * filter_depth_padded + c] =
-filter_offset;  // Pad with the filter zero point.
}
}
for (int batch = 0; batch < batches; ++batch) {
uint8* p_out_temp;
p_out_temp = (uint8*)&p_out_scratch[0];
p_out_temp = (uint8*)ALIGN_PTR(p_out_temp, 8);
err = xa_nn_conv2d_std_asym8xasym8(
p_out_temp,
&input_data[batch * input_height * input_width * input_depth],
p_filter, // filter_data,
bias_data, input_height, input_width, input_depth, filter_height,
filter_width, output_depth, stride_width, stride_height, pad_width,
pad_height, output_height, output_width, input_offset, filter_offset,
output_multiplier, output_shift, output_offset, output_data_format,
p_scratch);
CHECK_ERR_HIFI_NNLIB_KER(
err, "conv2d_std_asym8: xa_nn_conv2d_std_asym8xasym8 failed");
for (int i = 0; i < out_length; i++) {
uint8* p_temp;
p_temp = &output_data[batch * out_length];
ACTIVATION_MIN_MAX_ASYM8(p_temp[i], p_out_temp[i],
output_activation_min, output_activation_max)
}
}
} else {
ConvParams op_params;
op_params.padding_type = RuntimePaddingType(params->padding);
op_params.padding_values.width = data->padding.width;
op_params.padding_values.height = data->padding.height;
op_params.stride_width = params->stride_width;
op_params.stride_height = params->stride_height;
op_params.dilation_width_factor = params->dilation_width_factor;
op_params.dilation_height_factor = params->dilation_height_factor;
op_params.input_offset = input_offset;
op_params.weights_offset = filter_offset;
op_params.output_offset = output_offset;
op_params.output_multiplier = data->output_multiplier;
op_params.output_shift = -data->output_shift;
op_params.quantized_activation_min = data->output_activation_min;
op_params.quantized_activation_max = data->output_activation_max;
reference_ops::Conv(op_params, GetTensorShape(input),
GetTensorData<uint8_t>(input), GetTensorShape(filter),
GetTensorData<uint8_t>(filter), GetTensorShape(bias),
GetTensorData<int32_t>(bias), GetTensorShape(output),
GetTensorData<uint8_t>(output), GetTensorShape(im2col),
GetTensorData<uint8_t>(im2col), nullptr);
}
return kTfLiteOk;
}
void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, OpData* data,
const TfLiteTensor* input,
const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* output,
TfLiteTensor* im2col) {
ConvParams op_params;
op_params.input_offset = -input->params.zero_point;
op_params.output_offset = output->params.zero_point;
op_params.stride_height = params->stride_height;
op_params.stride_width = params->stride_width;
op_params.dilation_height_factor = params->dilation_height_factor;
op_params.dilation_width_factor = params->dilation_width_factor;
op_params.padding_values.height = data->padding.height;
op_params.padding_values.width = data->padding.width;
op_params.quantized_activation_min = data->output_activation_min;
op_params.quantized_activation_max = data->output_activation_max;
reference_integer_ops::ConvPerChannel(
op_params, data->per_channel_output_multiplier,
data->per_channel_output_shift, GetTensorShape(input),
GetTensorData<int8>(input), GetTensorShape(filter),
GetTensorData<int8>(filter), GetTensorShape(bias),
GetTensorData<int32>(bias), GetTensorShape(output),
GetTensorData<int8>(output));
}
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, OpData* data,
const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* im2col,
TfLiteTensor* hwcn_weights, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
if ((params->dilation_width_factor == 1) &&
(params->dilation_height_factor == 1)) {
const float *input_data, *filter_data;
const float* bias_data;
float* output_data;
const RuntimeShape& input_shape = GetTensorShape(input);
const RuntimeShape& filter_shape = GetTensorShape(filter);
const RuntimeShape& output_shape = GetTensorShape(output);
const RuntimeShape& bias_shape = GetTensorShape(bias);
input_data = GetTensorData<float>(input);
filter_data = GetTensorData<float>(filter);
bias_data = GetTensorData<float>(bias);
output_data = GetTensorData<float>(output);
const int stride_width = params->stride_width;
const int stride_height = params->stride_height;
const int dilation_width_factor = 1;
const int dilation_height_factor = 1;
const int pad_width = data->padding.width;
const int pad_height = data->padding.height;
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
if (bias_data) {
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
}
const int input_height = input_shape.Dims(1);
const int input_width = input_shape.Dims(2);
const int filter_height = filter_shape.Dims(1);
const int filter_width = filter_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const int filter_depth = filter_shape.Dims(3);
int err, output_data_format = 0;
void* p_scratch;
float *p_filter, *p_out_scratch;
// Round filter_depth up to the next multiple of 2.
int filter_depth_padded = (filter_depth + 1) & (~1);
int out_length = output_height * output_width * output_depth;
int required_scratch, input_precision = PREC_F32;
int h, w, c;
required_scratch = xa_nn_conv2d_std_getsize(
input_height, input_depth, filter_height, filter_width, stride_height,
pad_height, output_height, input_precision);
if (required_scratch <= 0) {
TF_LITE_KERNEL_LOG(context,
"conv2d_std_f32: xa_nn_conv2d_std_getsize failed");
return kTfLiteError;
}
ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
p_scratch = xtensa_nnlib_scratch_buf;
p_filter = (float*)p_scratch;
p_out_scratch =
(float*)((uint8_t*)p_filter +
ALIGNED_SIZE((sizeof(float) * filter_height * filter_width *
filter_depth_padded * output_depth),
8));
required_scratch +=
ALIGNED_SIZE((sizeof(float) * filter_height * filter_width *
filter_depth_padded * output_depth),
8);
p_scratch = (float*)((uint8_t*)p_out_scratch +
ALIGNED_SIZE(sizeof(float) * out_length, 8));
required_scratch += ALIGNED_SIZE(sizeof(float) * out_length, 8);
if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
TF_LITE_KERNEL_LOG(context,
"conv2d_std_f32: insufficient scratch memory");
return kTfLiteError;
}
// Pad the filter coefficients along the depth dimension.
for (h = 0; h < filter_height * filter_width * output_depth; h++) {
for (c = 0; c < filter_depth; c++) {
p_filter[h * filter_depth_padded + c] =
filter_data[h * filter_depth + c];
}
for (c = input_depth; c < filter_depth_padded; c++) {
p_filter[h * filter_depth_padded + c] = 0;
}
}
for (int batch = 0; batch < batches; ++batch) {
float* p_out_temp;
p_out_temp = (float*)&p_out_scratch[0];
p_out_temp = (float*)ALIGN_PTR(p_out_temp, 8);
err = xa_nn_conv2d_std_f32(
p_out_temp,
&input_data[batch * input_height * input_width * input_depth],
p_filter, bias_data, input_height, input_width, input_depth,
filter_height, filter_width, output_depth, stride_width,
stride_height, pad_width, pad_height, output_height, output_width,
output_data_format, p_scratch);
CHECK_ERR_HIFI_NNLIB_KER(
err, "conv2d_std_f32: xa_nn_conv2d_std_f32 failed");
for (int i = 0; i < out_length; i++) {
float* p_temp;
p_temp = &output_data[batch * out_length];
ACTIVATION_MIN_MAX(float, p_temp[i], p_out_temp[i],
output_activation_min, output_activation_max)
}
}
} else {
ConvParams op_params;
op_params.padding_type = RuntimePaddingType(params->padding);
op_params.padding_values.width = data->padding.width;
op_params.padding_values.height = data->padding.height;
op_params.stride_width = params->stride_width;
op_params.stride_height = params->stride_height;
op_params.dilation_width_factor = params->dilation_width_factor;
op_params.dilation_height_factor = params->dilation_height_factor;
op_params.float_activation_min = output_activation_min;
op_params.float_activation_max = output_activation_max;
reference_ops::Conv(op_params, GetTensorShape(input),
GetTensorData<float>(input), GetTensorShape(filter),
GetTensorData<float>(filter), GetTensorShape(bias),
GetTensorData<float>(bias), GetTensorShape(output),
GetTensorData<float>(output), GetTensorShape(im2col),
GetTensorData<float>(im2col));
}
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
int input_width = input->dims->data[2];
int input_height = input->dims->data[1];
int filter_width = filter->dims->data[2];
int filter_height = filter->dims->data[1];
int output_width = output->dims->data[2];
int output_height = output->dims->data[1];
OpData data;
// All per-channel quantized tensors need valid zero point and scale arrays.
if (input->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, filter->quantization.type,
kTfLiteAffineQuantization);
const auto* affine_quantization =
reinterpret_cast<TfLiteAffineQuantization*>(
filter->quantization.params);
TF_LITE_ENSURE(context, affine_quantization);
TF_LITE_ENSURE(context, affine_quantization->scale);
TF_LITE_ENSURE(context, affine_quantization->zero_point);
TF_LITE_ENSURE(context,
affine_quantization->scale->size == 1 ||
affine_quantization->scale->size ==
filter->dims->data[kConvQuantizedDimension]);
TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
affine_quantization->zero_point->size);
}
TF_LITE_ENSURE_STATUS(CalculateOpData(
context, node, params, input_width, input_height, filter_width,
filter_height, output_width, output_height, input->type, &data));
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
EvalFloat(context, node, params, &data, input, filter, bias, nullptr,
nullptr, output);
break;
case kTfLiteInt8:
EvalQuantizedPerChannel(context, node, params, &data, input, filter, bias,
output, nullptr);
break;
case kTfLiteUInt8:
EvalQuantized(context, node, params, &data, input, filter, bias, nullptr,
nullptr, output);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace conv
TfLiteRegistration* Register_CONV_2D() {
static TfLiteRegistration r = {};
r.prepare = conv::Prepare;
r.invoke = conv::Eval;
return &r;
}
} // namespace micro
} // namespace ops
} // namespace tflite
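
Note on the "fixed point multiplier plus a left shift" mentioned in the OpData comment above: the real multiplier (input_scale * filter_scale / output_scale) is decomposed into a Q31 integer and an exponent, along the lines of TfLite's QuantizeMultiplier. A rough sketch under that assumption (names here are illustrative, not the library routine):

#include <cmath>
#include <cstdint>

// Decompose a positive real multiplier into a Q31 fixed-point value and an
// exponent such that multiplier ~= quantized_multiplier * 2^(exponent - 31).
void DecomposeMultiplier(double multiplier, int32_t* quantized_multiplier,
                         int* exponent) {
  if (multiplier == 0.0) {
    *quantized_multiplier = 0;
    *exponent = 0;
    return;
  }
  // std::frexp returns a fraction in [0.5, 1) and the matching power of two.
  const double fraction = std::frexp(multiplier, exponent);
  int64_t q = static_cast<int64_t>(std::round(fraction * (1LL << 31)));
  if (q == (1LL << 31)) {  // Rounding can push the fraction up to exactly 1.0.
    q /= 2;
    ++*exponent;
  }
  *quantized_multiplier = static_cast<int32_t>(q);
}

// A negative exponent corresponds to a net right shift at runtime, which is
// what the "+ve-means-left" comments and the -data->output_shift negations in
// the kernels above are reconciling.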


@@ -0,0 +1,560 @@
/******************************************************************************
* Copyright (C) 2019 Cadence Design Systems, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to use this Software with Cadence processor cores only and
* not with any other processors and platforms, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
******************************************************************************/
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "xtensa_tf_micro_common.h"
namespace tflite {
namespace ops {
namespace micro {
namespace depthwise_conv {
namespace {
constexpr int kInputTensor = 0;
constexpr int kFilterTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
constexpr int kMaxChannels = 256;
// Depthwise conv is quantized along dimension 3:
// https://www.tensorflow.org/lite/performance/quantization_spec
constexpr int kDepthwiseConvQuantizedDimension = 3;
struct OpData {
TfLitePaddingValues padding;
// The scaling factor from input to output (aka the 'real multiplier') can
// be represented as a fixed point multiplier plus a left shift.
int32_t output_multiplier;
int output_shift;
// Per channel output multiplier and shift.
// TODO(b/141139247): Allocate these dynamically when possible.
int32_t per_channel_output_multiplier[kMaxChannels];
int32_t per_channel_output_shift[kMaxChannels];
// The range of the fused activation layer. For example for kNone and
// uint8_t these would be 0 and 255.
int32_t output_activation_min;
int32_t output_activation_max;
};
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
TfLiteDepthwiseConvParams* params, int width,
int height, int filter_width, int filter_height,
const TfLiteType data_type, OpData* data) {
bool has_bias = node->inputs->size == 3;
// Check number of inputs/outputs
TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
int unused_output_height, unused_output_width;
data->padding = ComputePaddingHeightWidth(
params->stride_height, params->stride_width, 1, 1, height, width,
filter_height, filter_width, params->padding, &unused_output_height,
&unused_output_width);
// Note that quantized inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
if (data_type != kTfLiteFloat32) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
const TfLiteTensor* bias =
GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
context, input, filter, bias, output, params->activation,
&data->output_multiplier, &data->output_shift,
&data->output_activation_min, &data->output_activation_max,
data->per_channel_output_multiplier,
reinterpret_cast<int*>(data->per_channel_output_shift), num_channels));
}
return kTfLiteOk;
}
} // namespace
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return nullptr;
}
void Free(TfLiteContext* context, void* buffer) {}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteDepthwiseConvParams* params, OpData* data,
const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
if ((params->dilation_width_factor == 1) &&
(params->dilation_height_factor == 1)) {
const float *input_data, *filter_data, *bias_data;
float* output_data;
const RuntimeShape& input_shape = GetTensorShape(input);
const RuntimeShape& filter_shape = GetTensorShape(filter);
const RuntimeShape& output_shape = GetTensorShape(output);
const RuntimeShape& bias_shape = GetTensorShape(bias);
input_data = GetTensorData<float>(input);
filter_data = GetTensorData<float>(filter);
bias_data = GetTensorData<float>(bias);
output_data = GetTensorData<float>(output);
const int stride_width = params->stride_width;
const int stride_height = params->stride_height;
// This path is only taken when both dilation factors are 1.
const int dilation_width_factor = 1;
const int dilation_height_factor = 1;
const int pad_width = data->padding.width;
const int pad_height = data->padding.height;
const int depth_multiplier = params->depth_multiplier;
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
const int input_height = input_shape.Dims(1);
const int input_width = input_shape.Dims(2);
const int input_depth = input_shape.Dims(3);
const int filter_height = filter_shape.Dims(1);
const int filter_width = filter_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const int filter_depth = filter_shape.Dims(3);
TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
int32_t err, input_data_format = 0, output_data_format = 0;
void* p_scratch;
float* p_filter;
int filter_depth_padded, filter_size_padded, required_scratch;
int input_precision = PREC_F32;
int h, c, i;
ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
p_scratch = xtensa_nnlib_scratch_buf;
filter_depth_padded = (filter_depth + 1) & (~1);
filter_size_padded = filter_height * filter_width * filter_depth_padded;
required_scratch = xa_nn_conv2d_depthwise_getsize(
input_height, input_width, input_depth, filter_height, filter_width,
depth_multiplier, stride_width, stride_height, pad_width, pad_height,
output_height, output_width, input_precision, input_data_format);
if (required_scratch <= 0) {
TF_LITE_KERNEL_LOG(
context, "DepthwiseConvFloat: xa_nn_conv2d_depthwise_getsize failed");
return kTfLiteError;
}
required_scratch += ALIGNED_SIZE(sizeof(float) * filter_size_padded, 8);
if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
TF_LITE_KERNEL_LOG(context,
"DepthwiseConvFloat: insufficient scratch memory");
return kTfLiteError;
}
p_filter = (float*)p_scratch;
p_scratch = (void*)((uint8_t*)p_filter +
ALIGNED_SIZE(sizeof(float) * filter_size_padded, 8));
for (h = 0; h < filter_height * filter_width; h++) {
for (c = 0; c < filter_depth; c++) {
p_filter[h * filter_depth_padded + c] =
filter_data[h * filter_depth + c];
}
for (c = filter_depth; c < filter_depth_padded; c++) {
p_filter[h * filter_depth_padded + c] = 0;
}
}
for (i = 0; i < batches; i++) {
err = xa_nn_conv2d_depthwise_f32(
&output_data[i * output_height * output_width * output_depth],
p_filter, // filter_data,
&input_data[i * input_height * input_width * input_depth], bias_data,
input_height, input_width, input_depth, filter_height, filter_width,
depth_multiplier, stride_width, stride_height, pad_width, pad_height,
output_height, output_width, input_data_format, output_data_format,
p_scratch);
CHECK_ERR_HIFI_NNLIB_KER(
err, "DepthwiseConvFloat: xa_nn_conv2d_depthwise_f32 failed");
}
// Handle the leading unaligned output elements here so that the vectorized
// activation call below starts from an 8-byte aligned pointer.
int out_length = batches * output_height * output_width * output_depth;
uint32 p_unalign_val = (uint32)output_data, p_align_val;
p_align_val = (p_unalign_val + 7) & (~7);
int pre_loop_count = p_align_val - p_unalign_val;
pre_loop_count = MIN(pre_loop_count, out_length);
for (i = 0; i < pre_loop_count; i++) {
ACTIVATION_MIN_MAX(float, output_data[i], output_data[i],
output_activation_min, output_activation_max)
}
out_length = out_length - pre_loop_count;
if (out_length) {
err = xa_nn_vec_activation_min_max_f32_f32(
&output_data[i], &output_data[i], output_activation_min,
output_activation_max, out_length);
CHECK_ERR_HIFI_NNLIB_KER(
err,
"DepthwiseConvFloat: xa_nn_vec_activation_min_max_f32_f32 failed");
}
} else {
tflite::DepthwiseParams op_params;
// Padding type is ignored, but still set.
op_params.padding_type = PaddingType::kSame;
op_params.padding_values.width = data->padding.width;
op_params.padding_values.height = data->padding.height;
op_params.stride_width = params->stride_width;
op_params.stride_height = params->stride_height;
op_params.dilation_width_factor = params->dilation_width_factor;
op_params.dilation_height_factor = params->dilation_height_factor;
op_params.depth_multiplier = params->depth_multiplier;
op_params.float_activation_min = output_activation_min;
op_params.float_activation_max = output_activation_max;
tflite::reference_ops::DepthwiseConv(
op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(filter), GetTensorData<float>(filter),
GetTensorShape(bias), GetTensorData<float>(bias),
GetTensorShape(output), GetTensorData<float>(output));
}
return kTfLiteOk;
}
void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
TfLiteDepthwiseConvParams* params, OpData* data,
const TfLiteTensor* input,
const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* output) {
DepthwiseParams op_params;
op_params.padding_type = PaddingType::kSame;
op_params.padding_values.width = data->padding.width;
op_params.padding_values.height = data->padding.height;
op_params.stride_width = params->stride_width;
op_params.stride_height = params->stride_height;
op_params.dilation_width_factor = params->dilation_width_factor;
op_params.dilation_height_factor = params->dilation_height_factor;
op_params.depth_multiplier = params->depth_multiplier;
op_params.input_offset = -input->params.zero_point;
op_params.weights_offset = 0;
op_params.output_offset = output->params.zero_point;
// TODO(b/130439627): Use calculated value for clamping.
op_params.quantized_activation_min = std::numeric_limits<int8_t>::min();
op_params.quantized_activation_max = std::numeric_limits<int8_t>::max();
reference_integer_ops::DepthwiseConvPerChannel(
op_params, data->per_channel_output_multiplier,
data->per_channel_output_shift, GetTensorShape(input),
GetTensorData<int8>(input), GetTensorShape(filter),
GetTensorData<int8>(filter), GetTensorShape(bias),
GetTensorData<int32>(bias), GetTensorShape(output),
GetTensorData<int8>(output));
}
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteDepthwiseConvParams* params, OpData* data,
const TfLiteTensor* input,
const TfLiteTensor* filter, const TfLiteTensor* bias,
TfLiteTensor* output) {
const int32_t input_offset = -input->params.zero_point;
const int32_t filter_offset = -filter->params.zero_point;
const int32_t output_offset = output->params.zero_point;
if ((params->dilation_width_factor == 1) &&
(params->dilation_height_factor == 1)) {
const uint8 *input_data, *filter_data;
const int32_t* bias_data;
uint8* output_data;
const RuntimeShape& input_shape = GetTensorShape(input);
const RuntimeShape& filter_shape = GetTensorShape(filter);
const RuntimeShape& output_shape = GetTensorShape(output);
const RuntimeShape& bias_shape = GetTensorShape(bias);
input_data = GetTensorData<uint8_t>(input);
filter_data = GetTensorData<uint8_t>(filter);
bias_data = GetTensorData<int32_t>(bias);
output_data = GetTensorData<uint8_t>(output);
const int stride_width = params->stride_width;
const int stride_height = params->stride_height;
// This path is only taken when both dilation factors are 1.
const int dilation_width_factor = 1;
const int dilation_height_factor = 1;
const int pad_width = data->padding.width;
const int pad_height = data->padding.height;
const int depth_multiplier = params->depth_multiplier;
const int32 output_activation_min = data->output_activation_min;
const int32 output_activation_max = data->output_activation_max;
const int32 output_multiplier = data->output_multiplier;
// Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
const int output_shift = -data->output_shift;
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
const int input_height = input_shape.Dims(1);
const int input_width = input_shape.Dims(2);
const int input_depth = input_shape.Dims(3);
const int filter_height = filter_shape.Dims(1);
const int filter_width = filter_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const int filter_depth = filter_shape.Dims(3);
TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
int32_t err, i, input_data_format = 0, output_data_format = 0;
void* p_scratch;
uint8* p_filter;
int filter_depth_padded, filter_size_padded, required_scratch;
int input_precision = PREC_ASYM8;
int h, c;
ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
p_scratch = xtensa_nnlib_scratch_buf;
required_scratch = xa_nn_conv2d_depthwise_getsize(
input_height, input_width, input_depth, filter_height, filter_width,
depth_multiplier, stride_width, stride_height, pad_width, pad_height,
output_height, output_width, input_precision, input_data_format);
if (required_scratch <= 0) {
TF_LITE_KERNEL_LOG(
context, "DepthwiseConvAsym8: xa_nn_conv2d_depthwise_getsize failed");
return kTfLiteError;
}
filter_depth_padded = (filter_depth + 3) & (~3);
filter_size_padded = filter_height * filter_width * filter_depth_padded;
required_scratch += ALIGNED_SIZE(sizeof(uint8_t) * filter_size_padded, 8);
if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
TF_LITE_KERNEL_LOG(context,
"DepthwiseConvAsym8: insufficient scratch memory");
return kTfLiteError;
}
p_filter = (uint8*)p_scratch;
p_scratch = (void*)(p_filter +
ALIGNED_SIZE(sizeof(uint8_t) * filter_size_padded, 8));
for (h = 0; h < filter_height * filter_width; h++) {
for (c = 0; c < filter_depth; c++) {
p_filter[h * filter_depth_padded + c] =
filter_data[h * filter_depth + c];
}
for (c = filter_depth; c < filter_depth_padded; c++) {
p_filter[h * filter_depth_padded + c] = -filter_offset;
}
}
for (i = 0; i < batches; i++) {
err = xa_nn_conv2d_depthwise_asym8xasym8(
&output_data[i * output_height * output_width * output_depth],
p_filter, // filter_data,
&input_data[i * input_height * input_width * input_depth], bias_data,
input_height, input_width, input_depth, filter_height, filter_width,
depth_multiplier, stride_width, stride_height, pad_width, pad_height,
output_height, output_width, input_offset, filter_offset,
output_multiplier, output_shift, output_offset, input_data_format,
output_data_format, p_scratch);
CHECK_ERR_HIFI_NNLIB_KER(
err, "DepthwiseConvAsym8: xa_nn_conv2d_depthwise_asym8xasym8 failed");
}
// Handle the leading unaligned output elements here so that the vectorized
// activation call below starts from an 8-byte aligned pointer.
int out_length = batches * output_height * output_width * output_depth;
uint32 p_unalign_val = (uint32)output_data, p_align_val;
p_align_val = (p_unalign_val + 7) & (~7);
int pre_loop_count = p_align_val - p_unalign_val;
pre_loop_count = MIN(pre_loop_count, out_length);
for (i = 0; i < pre_loop_count; i++) {
ACTIVATION_MIN_MAX_ASYM8(output_data[i], output_data[i],
output_activation_min, output_activation_max)
}
out_length = out_length - pre_loop_count;
if (out_length > 0) {
err = xa_nn_vec_activation_min_max_asym8_asym8(
&output_data[i], &output_data[i], output_activation_min,
output_activation_max, out_length);
CHECK_ERR_HIFI_NNLIB_KER(
err,
"DepthwiseConvAsym8: xa_nn_vec_activation_min_max_asym8_asym8 "
"failed");
}
} else {
tflite::DepthwiseParams op_params;
// Padding type is ignored, but still set.
op_params.padding_type = PaddingType::kSame;
op_params.padding_values.width = data->padding.width;
op_params.padding_values.height = data->padding.height;
op_params.stride_width = params->stride_width;
op_params.stride_height = params->stride_height;
op_params.dilation_width_factor = params->dilation_width_factor;
op_params.dilation_height_factor = params->dilation_height_factor;
op_params.depth_multiplier = params->depth_multiplier;
op_params.quantized_activation_min = data->output_activation_min;
op_params.quantized_activation_max = data->output_activation_max;
op_params.input_offset = input_offset;
op_params.weights_offset = filter_offset;
op_params.output_offset = output_offset;
op_params.output_multiplier = data->output_multiplier;
// Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
op_params.output_shift = -data->output_shift;
tflite::reference_ops::DepthwiseConv(
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<uint8_t>(output));
}
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
const TfLiteTensor* bias =
(NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr;
const TfLiteType data_type = input->type;
int width = SizeOfDimension(input, 2);
int height = SizeOfDimension(input, 1);
int filter_width = SizeOfDimension(filter, 2);
int filter_height = SizeOfDimension(filter, 1);
OpData data;
// All per-channel quantized tensors need valid zero point and scale arrays.
if (input->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, filter->quantization.type,
kTfLiteAffineQuantization);
const auto* affine_quantization =
reinterpret_cast<TfLiteAffineQuantization*>(
filter->quantization.params);
TF_LITE_ENSURE(context, affine_quantization);
TF_LITE_ENSURE(context, affine_quantization->scale);
TF_LITE_ENSURE(context, affine_quantization->zero_point);
TF_LITE_ENSURE(
context, affine_quantization->scale->size == 1 ||
affine_quantization->scale->size ==
filter->dims->data[kDepthwiseConvQuantizedDimension]);
TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
affine_quantization->zero_point->size);
}
TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, width, height,
filter_width, filter_height, data_type,
&data));
// TODO(aselle): Consider whether float conv and quantized conv should be
// separate ops to avoid dispatch overhead here.
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
EvalFloat(context, node, params, &data, input, filter, bias, output);
break;
case kTfLiteInt8:
EvalQuantizedPerChannel(context, node, params, &data, input, filter, bias,
output);
break;
case kTfLiteUInt8:
EvalQuantized(context, node, params, &data, input, filter, bias, output);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace depthwise_conv
TfLiteRegistration* Register_DEPTHWISE_CONV_2D() {
static TfLiteRegistration r = {};
r.init = depthwise_conv::Init;
r.free = depthwise_conv::Free;
r.prepare = depthwise_conv::Prepare;
r.invoke = depthwise_conv::Eval;
return &r;
}
} // namespace micro
} // namespace ops
} // namespace tflite
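
Note on the scratch-buffer bookkeeping in the conv and depthwise kernels above: channel counts are rounded up to a multiple of 4 (uint8) or 2 (float), and each region is rounded up to an 8-byte boundary before being carved out of the shared scratch buffer. A small sketch of that arithmetic, assuming ALIGNED_SIZE in xtensa_tf_micro_common.h is the usual power-of-two round-up (its definition is not part of this diff):

#include <cstddef>
#include <cstdint>

// Round value up to the next multiple of a power-of-two alignment.
// (filter_depth + 3) & ~3 in the kernels is this with alignment == 4.
constexpr size_t RoundUp(size_t value, size_t alignment) {
  return (value + alignment - 1) & ~(alignment - 1);
}

// Scratch layout mirroring the uint8 conv path: padded filter first, then an
// 8-byte aligned per-batch output scratch, then the NNLib's own workspace.
struct ConvScratchPlan {
  size_t filter_bytes;
  size_t out_scratch_bytes;
  size_t total_bytes;  // Compare against XTENSA_NNLIB_MAX_SCRATCH_SIZE.
};

inline ConvScratchPlan PlanConvScratch(int filter_height, int filter_width,
                                       int filter_depth, int output_depth,
                                       int out_length, size_t nnlib_bytes) {
  const size_t depth_padded = RoundUp(filter_depth, 4);
  ConvScratchPlan plan;
  plan.filter_bytes = RoundUp(
      sizeof(uint8_t) * filter_height * filter_width * depth_padded *
          output_depth,
      8);
  plan.out_scratch_bytes = RoundUp(sizeof(uint8_t) * out_length, 8);
  plan.total_bytes = plan.filter_bytes + plan.out_scratch_bytes + nnlib_bytes;
  return plan;
}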


@@ -0,0 +1,81 @@
/******************************************************************************
* Copyright (C) 2019 Cadence Design Systems, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to use this Software with Cadence processor cores only and
* not with any other processors and platforms, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
******************************************************************************/
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/floor.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "xtensa_tf_micro_common.h"
namespace tflite {
namespace ops {
namespace micro {
namespace floor {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
int err;
const float* inp_data_ptr;
float* out_data_ptr;
const RuntimeShape& input_shape = GetTensorShape(input);
const RuntimeShape& output_shape = GetTensorShape(output);
const int flat_size = MatchingFlatSize(input_shape, output_shape);
inp_data_ptr = GetTensorData<float>(input);
out_data_ptr = GetTensorData<float>(output);
err = xa_nn_elm_floor_f32_f32(out_data_ptr, inp_data_ptr, flat_size);
CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_elm_floor_f32_f32 failed");
return kTfLiteOk;
}
} // namespace floor
TfLiteRegistration* Register_FLOOR() {
static TfLiteRegistration r = {};
r.invoke = floor::Eval;
return &r;
}
} // namespace micro
} // namespace ops
} // namespace tflite
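
CHECK_ERR_HIFI_NNLIB_KER and the other helper macros used throughout these kernels are defined in xtensa_tf_micro_common.h, which is not part of this diff. For readers following the control flow, a plausible shape for the error-check macro, assuming it simply logs and bails out on a non-zero NNLib status (hypothetical; the real definition may differ):

// Hypothetical sketch only; the actual macro lives in xtensa_tf_micro_common.h.
// It relies on a TfLiteContext* named `context` being in scope, as it is in
// every kernel above.
#define CHECK_ERR_HIFI_NNLIB_KER(err, err_msg)        \
  do {                                                \
    if ((err) != 0) {                                 \
      TF_LITE_KERNEL_LOG(context, "%s", (err_msg));   \
      return kTfLiteError;                            \
    }                                                 \
  } while (0)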


@@ -0,0 +1,277 @@
/******************************************************************************
* Copyright (C) 2019 Cadence Design Systems, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to use this Software with Cadence processor cores only and
* not with any other processors and platforms, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
******************************************************************************/
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "xtensa_tf_micro_common.h"
namespace tflite {
namespace ops {
namespace micro {
namespace fully_connected {
namespace {
struct OpData {
// The scaling factor from input to output (aka the 'real multiplier') can
// be represented as a fixed point multiplier plus a left shift.
int32_t output_multiplier;
int output_shift;
// The range of the fused activation layer. For example for kNone and
// uint8_t these would be 0 and 255.
int32_t output_activation_min;
int32_t output_activation_max;
// The index of the temporary tensor where the quantized inputs are cached.
int input_quantized_index;
};
constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
TfLiteStatus CalculateOpData(TfLiteContext* context,
TfLiteFullyConnectedParams* params,
TfLiteType data_type, const TfLiteTensor* input,
const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* output,
OpData* data) {
TfLiteStatus status = kTfLiteOk;
if (data_type != kTfLiteFloat32) {
double real_multiplier = 0.0;
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
context, input, filter, bias, output, &real_multiplier));
int exponent;
QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
data->output_shift = -exponent;
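// Worked example (descriptive note, not part of the original source):
// QuantizeMultiplier() returns a Q31 fixed-point multiplier in [2^30, 2^31)
// plus a power-of-two exponent, so a real multiplier of 1.5 becomes
// (1610612736, exponent 1) and is stored here with the sign of the shift
// flipped, i.e. output_shift = -1.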
TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
context, params->activation, output, &data->output_activation_min,
&data->output_activation_max));
}
return status;
}
} // namespace
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return nullptr;
}
void Free(TfLiteContext* context, void* buffer) {}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
TfLiteFullyConnectedParams* params, OpData* data,
const TfLiteTensor* input,
const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* output) {
FullyConnectedParams op_params;
op_params.input_offset = -input->params.zero_point;
op_params.weights_offset = -filter->params.zero_point;
op_params.output_offset = output->params.zero_point;
op_params.output_multiplier = data->output_multiplier;
// (b/138810107): Figure out whether output shift should be inverted
op_params.output_shift = -data->output_shift;
op_params.quantized_activation_min = data->output_activation_min;
op_params.quantized_activation_max = data->output_activation_max;
reference_integer_ops::FullyConnected(
op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
GetTensorShape(filter), GetTensorData<int8_t>(filter),
GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<int8_t>(output));
return kTfLiteOk;
}
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteFullyConnectedParams* params, OpData* data,
const TfLiteTensor* input,
const TfLiteTensor* filter, const TfLiteTensor* bias,
TfLiteTensor* output) {
const int32_t input_offset = -input->params.zero_point;
const int32_t filter_offset = -filter->params.zero_point;
const int32_t output_offset = output->params.zero_point;
tflite::FullyConnectedParams op_params;
op_params.input_offset = input_offset;
op_params.weights_offset = filter_offset;
op_params.output_offset = output_offset;
op_params.output_multiplier = data->output_multiplier;
// Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
op_params.output_shift = -data->output_shift;
op_params.quantized_activation_min = data->output_activation_min;
op_params.quantized_activation_max = data->output_activation_max;
#define TF_LITE_FULLY_CONNECTED(output_data_type) \
reference_ops::FullyConnected( \
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
GetTensorShape(bias), GetTensorData<int32_t>(bias), \
GetTensorShape(output), GetTensorData<output_data_type>(output))
switch (output->type) {
case kTfLiteUInt8: {
int ret, b, weight_depth, out_depth, batches;
uint8_t* p_out = GetTensorData<uint8_t>(output);
weight_depth = GetTensorShape(filter).Dims(
GetTensorShape(filter).DimensionsCount() - 1);
out_depth = GetTensorShape(output).Dims(
GetTensorShape(output).DimensionsCount() - 1);
batches = FlatSizeSkipDim(GetTensorShape(output),
GetTensorShape(output).DimensionsCount() - 1);
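// The filter is treated as an [out_depth, weight_depth] matrix and the
// (flattened) input as `batches` rows of length weight_depth; the NNLib
// kernel is invoked once per row, and the fused activation is applied in a
// separate clamping pass below.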
for (b = 0; b < batches; b++) {
ret = xa_nn_fully_connected_asym8xasym8_asym8(
(GetTensorData<uint8_t>(output) + b * out_depth),
GetTensorData<uint8_t>(filter),
(GetTensorData<uint8_t>(input) + b * weight_depth),
GetTensorData<int32_t>(bias), weight_depth, out_depth,
op_params.input_offset, op_params.weights_offset,
op_params.output_multiplier, op_params.output_shift,
op_params.output_offset);
CHECK_ERR_HIFI_NNLIB_KER(
ret, "xa_nn_fully_connected_asym8xasym8_asym8 failed");
}
for (int i = 0; i < batches * out_depth; i++) {
ACTIVATION_MIN_MAX_ASYM8(p_out[i], p_out[i],
data->output_activation_min,
data->output_activation_max)
}
break;
}
case kTfLiteInt16:
TF_LITE_FULLY_CONNECTED(int16_t);
break;
default:
TF_LITE_KERNEL_LOG(
context,
"Quantized FullyConnected expects output data type uint8 or int16");
return kTfLiteError;
}
return kTfLiteOk;
}
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteFullyConnectedParams* params, OpData* data,
const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
tflite::FullyConnectedParams op_params;
op_params.float_activation_min = output_activation_min;
op_params.float_activation_max = output_activation_max;
int ret, b, weight_depth, out_depth, batches;
weight_depth =
GetTensorShape(filter).Dims(GetTensorShape(filter).DimensionsCount() - 1);
out_depth =
GetTensorShape(output).Dims(GetTensorShape(output).DimensionsCount() - 1);
batches = FlatSizeSkipDim(GetTensorShape(output),
GetTensorShape(output).DimensionsCount() - 1);
for (b = 0; b < batches; b++) {
ret = xa_nn_fully_connected_f32(
(GetTensorData<float>(output) + b * out_depth),
GetTensorData<float>(filter),
(GetTensorData<float>(input) + b * weight_depth),
GetTensorData<float>(bias), weight_depth, out_depth);
CHECK_ERR_HIFI_NNLIB_KER(ret, "xa_nn_fully_connected_f32 failed.");
}
float* p_out = GetTensorData<float>(output);
for (int i = 0; i < batches * out_depth; i++) {
ACTIVATION_MIN_MAX(float, p_out[i], p_out[i], output_activation_min,
output_activation_max)
}
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteType data_type = input->type;
OpData local_data_object;
OpData* data = &local_data_object;
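// Prepare() is a no-op in this port, so the quantization parameters are
// recomputed via CalculateOpData() on every invocation.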
TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input,
filter, bias, output, data));
switch (filter->type) { // Already know in/out types are same.
case kTfLiteFloat32:
return EvalFloat(context, node, params, data, input, filter, bias,
output);
case kTfLiteInt8:
return EvalQuantizedInt8(context, node, params, data, input, filter, bias,
output);
case kTfLiteUInt8:
return EvalQuantized(context, node, params, data, input, filter, bias,
output);
default:
TF_LITE_KERNEL_LOG(context, "Type %d not currently supported.",
filter->type);
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace fully_connected
TfLiteRegistration* Register_FULLY_CONNECTED() {
static TfLiteRegistration r = {};
r.init = fully_connected::Init;
r.free = fully_connected::Free;
r.prepare = fully_connected::Prepare;
r.invoke = fully_connected::Eval;
return &r;
}
} // namespace micro
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,125 @@
/******************************************************************************
* Copyright (C) 2019 Cadence Design Systems, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to use this Software with Cadence processor cores only and
* not with any other processors and platforms, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
******************************************************************************/
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/logistic.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "xtensa_tf_micro_common.h"
namespace tflite {
namespace ops {
namespace micro {
namespace activations {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (input->type == kTfLiteFloat32) {
switch (output->type) {
case kTfLiteFloat32: {
int err;
const float* inp_data_ptr;
float* out_data_ptr;
const RuntimeShape& input_shape = GetTensorShape(input);
const RuntimeShape& output_shape = GetTensorShape(output);
const int flat_size = MatchingFlatSize(input_shape, output_shape);
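// Sigmoid is applied elementwise, so input and output are processed as one
// flat vector of flat_size elements in a single NNLib call.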
inp_data_ptr = GetTensorData<float>(input);
out_data_ptr = GetTensorData<float>(output);
err = xa_nn_vec_sigmoid_f32_f32(out_data_ptr, inp_data_ptr, flat_size);
CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_sigmoid_f32_f32 failed");
return kTfLiteOk;
}
default:
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
} else if (input->type == kTfLiteInt8) {
switch (output->type) {
case kTfLiteInt8: {
reference_ops::Logistic(
GetTensorShape(input), GetTensorData<int8_t>(input),
input->params.scale, input->params.zero_point,
GetTensorShape(output), GetTensorData<int8_t>(output),
output->params.scale, output->params.zero_point);
return kTfLiteOk;
}
default:
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
} else {
// (b/141211002): Also support other data types once we have supported
// temporary tensors in TFLM.
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace activations
TfLiteRegistration* Register_LOGISTIC() {
static TfLiteRegistration r = {};
r.prepare = activations::Prepare;
r.invoke = activations::Eval;
return &r;
}
} // namespace micro
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,580 @@
/******************************************************************************
* Copyright (C) 2019 Cadence Design Systems, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to use this Software with Cadence processor cores only and
* not with any other processors and platforms, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
******************************************************************************/
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/pooling.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "xtensa_tf_micro_common.h"
namespace tflite {
namespace ops {
namespace micro {
namespace pooling {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
struct OpData {
TfLitePaddingValues padding;
};
TfLiteStatus CalculateOpData(const TfLiteContext* context,
const TfLitePoolParams* params,
const TfLiteTensor* input,
const TfLiteTensor* output, OpData* data) {
// input: batch, height, width, channel
int height = SizeOfDimension(input, 1);
int width = SizeOfDimension(input, 2);
int out_height, out_width;
data->padding = ComputePaddingHeightWidth(
params->stride_height, params->stride_width,
/*dilation_rate_height=*/1,
/*dilation_rate_width=*/1, height, width, params->filter_height,
params->filter_width, params->padding, &out_height, &out_width);
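// Example: a 10x10 input with a 3x3 filter, stride 2 and SAME padding gives
// out_height = out_width = 5 and a top/left padding of
// ((5 - 1) * 2 + 3 - 10) / 2 = 0; the odd remaining pixel is accounted for
// on the bottom/right via the padding offset.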
return kTfLiteOk;
}
TfLiteStatus AverageEvalFloat(TfLiteContext* context, const TfLiteNode* node,
const TfLitePoolParams* params,
const OpData* data, const TfLiteTensor* input,
TfLiteTensor* output) {
float activation_min, activation_max;
CalculateActivationRange(params->activation, &activation_min,
&activation_max);
const int stride_height = params->stride_height;
const int stride_width = params->stride_width;
const int pad_width = data->padding.width;
const int pad_height = data->padding.height;
const int kernel_height = params->filter_height;
const int kernel_width = params->filter_width;
const RuntimeShape& input_shape = GetTensorShape(input);
const RuntimeShape& output_shape = GetTensorShape(output);
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
const int depth = MatchingDim(input_shape, 3, output_shape, 3);
const int input_height = input_shape.Dims(1);
const int input_width = input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const float* inp_data_ptr;
float* out_data_ptr;
int inp_data_format = 0, out_data_format = 0, out_length;
int inp_precision = PREC_F32, out_precision = PREC_F32;
void* p_scratch;
int err, required_scratch = 0;
ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
p_scratch = (void*)xtensa_nnlib_scratch_buf;
required_scratch = xa_nn_avgpool_getsize(
depth, inp_precision, out_precision, input_height, input_width,
kernel_height, kernel_width,
stride_width, // x_stride,
stride_height, // y_stride,
pad_width, // x_padding,
pad_height, // y_padding,
output_height, output_width, inp_data_format, out_data_format);
if (required_scratch <= 0) {
TF_LITE_KERNEL_LOG(context,
"AveragepoolFloat: xa_nn_avgpool_getsize failed");
return kTfLiteError;
}
if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
TF_LITE_KERNEL_LOG(context,
"AveragepoolFloat: insufficient scratch memory");
return kTfLiteError;
}
inp_data_ptr = GetTensorData<float>(input);
out_data_ptr = GetTensorData<float>(output);
for (int batch = 0; batch < batches; ++batch) {
err = xa_nn_avgpool_f32(
&out_data_ptr[output_height * output_width * depth * batch],
&inp_data_ptr[input_height * input_width * depth * batch],
input_height, input_width, depth, kernel_height, kernel_width,
stride_width, stride_height, pad_width, pad_height, output_height,
output_width, inp_data_format, out_data_format, p_scratch);
CHECK_ERR_HIFI_NNLIB_KER(err, "AveragepoolFloat: xa_nn_avgpool_f32 failed");
}
out_length = batches * output_height * output_width * depth;
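// The fused activation is applied as a post-pass: a scalar pre-loop clamps
// the few leading elements before the next 8-byte-aligned address
// (presumably so the NNLib vector kernel below operates on aligned data),
// and the remaining elements are clamped by the vector kernel.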
uint32 p_unalign_val = (uint32)out_data_ptr, p_align_val;
p_align_val = (p_unalign_val + 7) & (~7);
// pre loop for activation_min_max
int pre_loop_count = p_align_val - p_unalign_val;
pre_loop_count = MIN(pre_loop_count, out_length);
for (int i = 0; i < pre_loop_count; i++) {
ACTIVATION_MIN_MAX(float, out_data_ptr[i], out_data_ptr[i], activation_min,
activation_max)
}
out_length = out_length - pre_loop_count;
if (out_length > 0) {
err = xa_nn_vec_activation_min_max_f32_f32(
&out_data_ptr[pre_loop_count], &out_data_ptr[pre_loop_count],
activation_min, activation_max, out_length);
CHECK_ERR_HIFI_NNLIB_KER(
err, "AveragepoolFloat: xa_nn_vec_activation_min_max_f32_f32 failed");
}
return kTfLiteOk;
}
TfLiteStatus AverageEvalQuantized(TfLiteContext* context,
const TfLiteNode* node,
const TfLitePoolParams* params,
const OpData* data, const TfLiteTensor* input,
TfLiteTensor* output) {
TFLITE_DCHECK(input->type == kTfLiteUInt8 || input->type == kTfLiteInt8);
int32_t activation_min, activation_max;
(void)CalculateActivationRangeQuantized(context, params->activation, output,
&activation_min, &activation_max);
if (input->type == kTfLiteUInt8) {
const int stride_height = params->stride_height;
const int stride_width = params->stride_width;
const int pad_width = data->padding.width;
const int pad_height = data->padding.height;
const int kernel_height = params->filter_height;
const int kernel_width = params->filter_width;
const RuntimeShape& input_shape = GetTensorShape(input);
const RuntimeShape& output_shape = GetTensorShape(output);
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
const int depth = MatchingDim(input_shape, 3, output_shape, 3);
const int input_height = input_shape.Dims(1);
const int input_width = input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const uint8* inp_data_ptr;
uint8* out_data_ptr;
int inp_data_format = 0, out_data_format = 0, out_length;
int inp_precision = PREC_ASYM8, out_precision = PREC_ASYM8;
void* p_scratch;
int err, required_scratch = 0;
ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
p_scratch = (void*)xtensa_nnlib_scratch_buf;
required_scratch = xa_nn_avgpool_getsize(
depth, inp_precision, out_precision, input_height, input_width,
kernel_height, kernel_width,
stride_width, // x_stride,
stride_height, // y_stride,
pad_width, // x_padding,
pad_height, // y_padding,
output_height, output_width, inp_data_format, out_data_format);
if (required_scratch <= 0) {
TF_LITE_KERNEL_LOG(context,
"AveragepoolAsym8: xa_nn_avgpool_getsize failed");
return kTfLiteError;
}
if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
TF_LITE_KERNEL_LOG(context,
"AveragepoolAsym8: insufficient scratch memory");
return kTfLiteError;
}
inp_data_ptr = GetTensorData<uint8_t>(input);
out_data_ptr = GetTensorData<uint8_t>(output);
for (int batch = 0; batch < batches; ++batch) {
err = xa_nn_avgpool_asym8(
&out_data_ptr[output_height * output_width * depth * batch],
&inp_data_ptr[input_height * input_width * depth * batch],
input_height, input_width, depth, kernel_height, kernel_width,
stride_width, stride_height, pad_width, pad_height, output_height,
output_width, inp_data_format, out_data_format, p_scratch);
CHECK_ERR_HIFI_NNLIB_KER(err,
"AveragepoolAsym8: xa_nn_avgpool_asym8 failed");
}
out_length = batches * output_height * output_width * depth;
uint32 p_unalign_val = (uint32)out_data_ptr, p_align_val;
p_align_val = (p_unalign_val + 7) & (~7);
// pre loop for activation_min_max
int pre_loop_count = p_align_val - p_unalign_val;
pre_loop_count = MIN(pre_loop_count, out_length);
for (int i = 0; i < pre_loop_count; i++) {
ACTIVATION_MIN_MAX_ASYM8(out_data_ptr[i], out_data_ptr[i], activation_min,
activation_max)
}
out_length = out_length - pre_loop_count;
if (out_length > 0) {
err = xa_nn_vec_activation_min_max_asym8_asym8(
&out_data_ptr[pre_loop_count], &out_data_ptr[pre_loop_count],
activation_min, activation_max, out_length);
CHECK_ERR_HIFI_NNLIB_KER(
err,
"AveragepoolAsym8: xa_nn_vec_activation_min_max_asym8_asym8 failed");
}
} else {
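// int8 inputs fall back to the portable reference kernel; only the uint8
// (asym8) path above is routed through the HiFi NNLib.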
PoolParams op_params;
op_params.stride_height = params->stride_height;
op_params.stride_width = params->stride_width;
op_params.filter_height = params->filter_height;
op_params.filter_width = params->filter_width;
op_params.padding_values.height = data->padding.height;
op_params.padding_values.width = data->padding.width;
op_params.quantized_activation_min = activation_min;
op_params.quantized_activation_max = activation_max;
reference_integer_ops::AveragePool(
op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
GetTensorShape(output), GetTensorData<int8_t>(output));
}
return kTfLiteOk;
}
TfLiteStatus MaxEvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLitePoolParams* params, OpData* data,
const TfLiteTensor* input, TfLiteTensor* output) {
float activation_min, activation_max;
CalculateActivationRange(params->activation, &activation_min,
&activation_max);
const int stride_height = params->stride_height;
const int stride_width = params->stride_width;
const int pad_width = data->padding.width;
const int pad_height = data->padding.height;
const int kernel_height = params->filter_height;
const int kernel_width = params->filter_width;
const RuntimeShape& input_shape = GetTensorShape(input);
const RuntimeShape& output_shape = GetTensorShape(output);
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
const int depth = MatchingDim(input_shape, 3, output_shape, 3);
const int input_height = input_shape.Dims(1);
const int input_width = input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const float* inp_data_ptr;
float* out_data_ptr;
int inp_data_format = 0, out_data_format = 0, out_length;
int inp_precision = PREC_F32, out_precision = PREC_F32;
void* p_scratch;
int err, required_scratch = 0;
ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
p_scratch = (void*)xtensa_nnlib_scratch_buf;
required_scratch = xa_nn_maxpool_getsize(
depth, inp_precision, out_precision, input_height, input_width,
kernel_height, kernel_width,
stride_width, // x_stride,
stride_height, // y_stride,
pad_width, // x_padding,
pad_height, // y_padding,
output_height, output_width, inp_data_format, out_data_format);
if (required_scratch <= 0) {
TF_LITE_KERNEL_LOG(context, "MaxpoolFloat: xa_nn_maxpool_getsize failed");
return kTfLiteError;
}
if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
TF_LITE_KERNEL_LOG(context, "MaxpoolFloat: insufficient scratch memory");
return kTfLiteError;
}
inp_data_ptr = GetTensorData<float>(input);
out_data_ptr = GetTensorData<float>(output);
for (int batch = 0; batch < batches; ++batch) {
err = xa_nn_maxpool_f32(
&out_data_ptr[output_height * output_width * depth * batch],
&inp_data_ptr[input_height * input_width * depth * batch],
input_height, input_width, depth, kernel_height, kernel_width,
stride_width, stride_height, pad_width, pad_height, output_height,
output_width, inp_data_format, out_data_format, p_scratch);
CHECK_ERR_HIFI_NNLIB_KER(err, "MaxpoolFloat: xa_nn_maxpool_f32 failed");
}
out_length = batches * output_height * output_width * depth;
uint32 p_unalign_val = (uint32)out_data_ptr, p_align_val;
p_align_val = (p_unalign_val + 7) & (~7);
// pre loop for activation_min_max
int pre_loop_count = p_align_val - p_unalign_val;
pre_loop_count = MIN(pre_loop_count, out_length);
for (int i = 0; i < pre_loop_count; i++) {
ACTIVATION_MIN_MAX(float, out_data_ptr[i], out_data_ptr[i], activation_min,
activation_max)
}
out_length = out_length - pre_loop_count;
if (out_length > 0) {
err = xa_nn_vec_activation_min_max_f32_f32(
&out_data_ptr[pre_loop_count], &out_data_ptr[pre_loop_count],
activation_min, activation_max, out_length);
CHECK_ERR_HIFI_NNLIB_KER(
err, "MaxpoolFloat: xa_nn_vec_activation_min_max_f32_f32 failed");
}
return kTfLiteOk;
}
TfLiteStatus MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLitePoolParams* params, OpData* data,
const TfLiteTensor* input, TfLiteTensor* output) {
TFLITE_DCHECK(input->type == kTfLiteUInt8 || input->type == kTfLiteInt8);
int32_t activation_min, activation_max;
(void)CalculateActivationRangeQuantized(context, params->activation, output,
&activation_min, &activation_max);
if (input->type == kTfLiteUInt8) {
const int stride_height = params->stride_height;
const int stride_width = params->stride_width;
const int pad_width = data->padding.width;
const int pad_height = data->padding.height;
const int kernel_height = params->filter_height;
const int kernel_width = params->filter_width;
const RuntimeShape& input_shape = GetTensorShape(input);
const RuntimeShape& output_shape = GetTensorShape(output);
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
const int depth = MatchingDim(input_shape, 3, output_shape, 3);
const int input_height = input_shape.Dims(1);
const int input_width = input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const uint8* inp_data_ptr;
uint8* out_data_ptr;
int inp_data_format = 0, out_data_format = 0, out_length;
int inp_precision = PREC_ASYM8, out_precision = PREC_ASYM8;
void* p_scratch;
int err, required_scratch = 0;
ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
p_scratch = (void*)xtensa_nnlib_scratch_buf;
required_scratch = xa_nn_maxpool_getsize(
depth, inp_precision, out_precision, input_height, input_width,
kernel_height, kernel_width,
stride_width, // x_stride,
stride_height, // y_stride,
pad_width, // x_padding,
pad_height, // y_padding,
output_height, output_width, inp_data_format, out_data_format);
if (required_scratch <= 0) {
TF_LITE_KERNEL_LOG(context, "MaxpoolAsym8: xa_nn_maxpool_getsize failed");
return kTfLiteError;
}
if (required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
TF_LITE_KERNEL_LOG(context, "MaxpoolAsym8: insufficient scratch memory");
return kTfLiteError;
}
inp_data_ptr = GetTensorData<uint8_t>(input);
out_data_ptr = GetTensorData<uint8_t>(output);
for (int batch = 0; batch < batches; ++batch) {
err = xa_nn_maxpool_asym8(
&out_data_ptr[output_height * output_width * depth * batch],
&inp_data_ptr[input_height * input_width * depth * batch],
input_height, input_width, depth, kernel_height, kernel_width,
stride_width, stride_height, pad_width, pad_height, output_height,
output_width, inp_data_format, out_data_format, p_scratch);
CHECK_ERR_HIFI_NNLIB_KER(err, "MaxpoolAsym8: xa_nn_maxpool_asym8 failed");
}
out_length = batches * output_height * output_width * depth;
uint32 p_unalign_val = (uint32)out_data_ptr, p_align_val;
p_align_val = (p_unalign_val + 7) & (~7);
// pre loop for activation_min_max
int pre_loop_count = p_align_val - p_unalign_val;
pre_loop_count = MIN(pre_loop_count, out_length);
for (int i = 0; i < pre_loop_count; i++) {
ACTIVATION_MIN_MAX_ASYM8(out_data_ptr[i], out_data_ptr[i], activation_min,
activation_max)
}
out_length = out_length - pre_loop_count;
if (out_length > 0) {
err = xa_nn_vec_activation_min_max_asym8_asym8(
&out_data_ptr[pre_loop_count], &out_data_ptr[pre_loop_count],
activation_min, activation_max, out_length);
CHECK_ERR_HIFI_NNLIB_KER(
err, "MaxpoolAsym8: xa_nn_vec_activation_min_max_asym8_asym8 failed");
}
} else {
tflite::PoolParams op_params;
op_params.stride_height = params->stride_height;
op_params.stride_width = params->stride_width;
op_params.filter_height = params->filter_height;
op_params.filter_width = params->filter_width;
op_params.padding_values.height = data->padding.height;
op_params.padding_values.width = data->padding.width;
op_params.quantized_activation_min = activation_min;
op_params.quantized_activation_max = activation_max;
reference_integer_ops::MaxPool(
op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
GetTensorShape(output), GetTensorData<int8_t>(output));
}
return kTfLiteOk;
}
} // namespace
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return nullptr;
}
void Free(TfLiteContext* context, void* buffer) {}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
OpData data;
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, &data));
// Inputs and outputs share the same type, guaranteed by the converter.
switch (input->type) {
case kTfLiteFloat32:
AverageEvalFloat(context, node, params, &data, input, output);
break;
case kTfLiteUInt8:
case kTfLiteInt8:
AverageEvalQuantized(context, node, params, &data, input, output);
break;
default:
TF_LITE_KERNEL_LOG(context, "Input type %s is not currently supported",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
return kTfLiteOk;
}
TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
OpData data;
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, &data));
switch (input->type) {
case kTfLiteFloat32:
MaxEvalFloat(context, node, params, &data, input, output);
break;
case kTfLiteUInt8:
case kTfLiteInt8:
MaxEvalQuantized(context, node, params, &data, input, output);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace pooling
TfLiteRegistration* Register_AVERAGE_POOL_2D() {
static TfLiteRegistration r = {};
r.init = pooling::Init;
r.free = pooling::Free;
r.prepare = pooling::Prepare;
r.invoke = pooling::AverageEval;
return &r;
}
TfLiteRegistration* Register_MAX_POOL_2D() {
static TfLiteRegistration r = {};
r.init = pooling::Init;
r.free = pooling::Free;
r.prepare = pooling::Prepare;
r.invoke = pooling::MaxEval;
return &r;
}
} // namespace micro
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,320 @@
/******************************************************************************
* Copyright (C) 2019 Cadence Design Systems, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to use this Software with Cadence processor cores only and
* not with any other processors and platforms, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
******************************************************************************/
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "xtensa_tf_micro_common.h"
namespace tflite {
namespace ops {
namespace micro {
namespace activations {
namespace {
struct OpData {
int32_t input_multiplier = 0;
int input_left_shift = 0;
int32_t input_range_radius = 0;
int diff_min = 0;
};
TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
const TfLiteTensor* input,
TfLiteTensor* output,
const TfLiteSoftmaxParams* params,
OpData* data) {
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
if (input->type == kTfLiteUInt8) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
} else {
if (output->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768);
// NOTE: Current int16 softmax output does not require symmetric scaling
// - so no need to verify scale here.
} else {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
}
}
static const int kScaledDiffIntegerBits = 5;
tflite::PreprocessSoftmaxScaling(
static_cast<double>(params->beta),
static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
&data->input_multiplier, &data->input_left_shift);
data->diff_min = -1.0 * tflite::CalculateInputRadius(
kScaledDiffIntegerBits, data->input_left_shift);
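// PreprocessSoftmaxScaling folds beta and the input scale into a single
// fixed-point input multiplier; diff_min is the most negative (x - max)
// difference that still contributes to the softmax sum, so smaller
// differences can be treated as exp() == 0 by the quantized kernels.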
}
return kTfLiteOk;
}
} // namespace
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return nullptr;
}
void Free(TfLiteContext* context, void* buffer) {}
TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
// Takes a 1D tensor and performs softmax along it.
void Softmax1DFloat(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params) {
const int input_size = input->dims->data[0];
tflite::reference_ops::Softmax(input->data.f, input_size, 1, params->beta,
output->data.f);
}
// Takes a 2D tensor and performs softmax along the last dimension.
TfLiteStatus Softmax2DFloat(TfLiteContext* context, const TfLiteTensor* input,
TfLiteTensor* output, TfLiteSoftmaxParams* params) {
const int batch_size = input->dims->data[0];
const int input_size = input->dims->data[1];
ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
float* p_scratch = (float*)xtensa_nnlib_scratch_buf;
if (batch_size * input_size * sizeof(float) > XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
TF_LITE_KERNEL_LOG(context, "Softmax: insufficient scratch memory");
return kTfLiteError;
}
for (int i = 0; i < batch_size * input_size; ++i) {
p_scratch[i] = input->data.f[i] * params->beta;
}
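// The NNLib float softmax does not appear to take a beta parameter, so the
// input is pre-scaled by beta into the scratch buffer and each row is then
// normalized by xa_nn_vec_softmax_f32_f32.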
for (int i = 0; i < batch_size; ++i) {
int err = xa_nn_vec_softmax_f32_f32(&output->data.f[i * input_size],
&p_scratch[i * input_size], input_size);
CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_softmax_f32_f32 failed");
}
return kTfLiteOk;
}
void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
// (ahentz): this is arguably a dirty trick. Since the implementation
// always traverses the last dimension of a 4D tensor, we will pretend our 1D
// tensor is 4D in a special way. We will convert a (Y) shape into a (1,
// 1, 1, Y) shape.
const int input_size = input->dims->data[0];
const int32_t shape_data[4] = {1, 1, 1, input_size};
RuntimeShape shape(4, shape_data);
SoftmaxParams op_params;
op_params.input_multiplier = data->input_multiplier;
op_params.input_left_shift = data->input_left_shift;
op_params.diff_min = data->diff_min;
if (input->type == kTfLiteUInt8) {
tflite::reference_ops::Softmax(op_params, shape,
GetTensorData<uint8_t>(input), shape,
GetTensorData<uint8_t>(output));
} else {
if (output->type == kTfLiteInt16) {
tflite::reference_integer_ops::Softmax(
op_params, shape, GetTensorData<int8_t>(input), shape,
GetTensorData<int16_t>(output));
} else {
tflite::reference_integer_ops::Softmax(
op_params, shape, GetTensorData<int8_t>(input), shape,
GetTensorData<int8_t>(output));
}
}
}
TfLiteStatus Softmax2DQuantized(TfLiteContext* context,
const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
// (ahentz): this is arguably a dirty trick. Since the implementation
// always traverses the last dimension of a 4D tensor, we will pretend our 2D
// tensor is 4D in a special way. We will convert a (X, Y) shape into a (X,
// 1, 1, Y) shape.
const int batch_size = input->dims->data[0];
const int input_size = input->dims->data[1];
const int32_t shape_data[4] = {batch_size, 1, 1, input_size};
RuntimeShape shape(4, shape_data);
SoftmaxParams op_params;
op_params.input_multiplier = data->input_multiplier;
op_params.input_left_shift = data->input_left_shift;
op_params.diff_min = data->diff_min;
if (input->type == kTfLiteUInt8) {
ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
void* p_scratch = (void*)xtensa_nnlib_scratch_buf;
if (get_softmax_scratch_size(PREC_ASYM8, PREC_ASYM8, input_size) >
XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
TF_LITE_KERNEL_LOG(context, "Softmax: insufficient scratch memory");
return kTfLiteError;
}
for (int i = 0; i < batch_size; ++i) {
int err = xa_nn_vec_softmax_asym8_asym8(
&output->data.uint8[i * input_size],
&input->data.uint8[i * input_size], op_params.diff_min,
op_params.input_left_shift, op_params.input_multiplier, input_size,
p_scratch);
CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_softmax_asym8_asym8 failed");
}
} else {
if (output->type == kTfLiteInt16) {
tflite::reference_integer_ops::Softmax(
op_params, shape, GetTensorData<int8_t>(input), shape,
GetTensorData<int16_t>(output));
} else {
tflite::reference_integer_ops::Softmax(
op_params, shape, GetTensorData<int8_t>(input), shape,
GetTensorData<int8_t>(output));
}
}
return kTfLiteOk;
}
// Takes a 4D tensor and performs softmax along the fourth dimension.
void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params) {
SoftmaxParams op_params;
op_params.beta = static_cast<double>(params->beta);
tflite::reference_ops::Softmax(
op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(output), GetTensorData<float>(output));
}
void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
SoftmaxParams op_params;
op_params.input_multiplier = data->input_multiplier;
op_params.input_left_shift = data->input_left_shift;
op_params.diff_min = data->diff_min;
if (input->type == kTfLiteUInt8) {
tflite::reference_ops::Softmax(
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(output), GetTensorData<uint8_t>(output));
} else {
if (output->type == kTfLiteInt16) {
tflite::reference_integer_ops::Softmax(
op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
GetTensorShape(output), GetTensorData<int16_t>(output));
} else {
tflite::reference_integer_ops::Softmax(
op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
GetTensorShape(output), GetTensorData<int8_t>(output));
}
}
}
TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
OpData local_data_object;
OpData* data = &local_data_object;
TF_LITE_ENSURE_STATUS(
CalculateSoftmaxOpData(context, input, output, params, data));
// (ahentz): consider an implementation that works for many (all?)
// dimensions.
switch (input->type) {
case kTfLiteFloat32: {
if (NumDimensions(input) == 1) {
Softmax1DFloat(input, output, params);
return kTfLiteOk;
}
if (NumDimensions(input) == 2) {
return Softmax2DFloat(context, input, output, params);
}
if (NumDimensions(input) == 4) {
Softmax4DFloat(input, output, params);
return kTfLiteOk;
}
TF_LITE_KERNEL_LOG(
context, "Only 1D, 2D and 4D tensors supported currently, got %dD.",
NumDimensions(input));
return kTfLiteError;
}
case kTfLiteInt8:
case kTfLiteUInt8: {
if (NumDimensions(input) == 1) {
Softmax1DQuantized(input, output, params, data);
return kTfLiteOk;
}
if (NumDimensions(input) == 2) {
return Softmax2DQuantized(context, input, output, params, data);
}
if (NumDimensions(input) == 4) {
Softmax4DQuantized(input, output, params, data);
return kTfLiteOk;
}
TF_LITE_KERNEL_LOG(context,
"Only 1D, 2D and 4D tensors supported currently, got %dD.",
NumDimensions(input));
return kTfLiteError;
}
default:
TF_LITE_KERNEL_LOG(
context,
"Only float32, uint8_t and int8_t supported currently, got %d.",
input->type);
return kTfLiteError;
}
}
} // namespace activations
TfLiteRegistration* Register_SOFTMAX() {
static TfLiteRegistration r = {};
r.init = activations::Init;
r.free = activations::Free;
r.prepare = activations::SoftmaxPrepare;
r.invoke = activations::SoftmaxEval;
return &r;
}
} // namespace micro
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,579 @@
/******************************************************************************
* Copyright (C) 2019 Cadence Design Systems, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to use this Software with Cadence processor cores only and
* not with any other processors and platforms, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
******************************************************************************/
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <math.h>
#include <algorithm>
#include <limits>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/activation_utils.h"
#include "tensorflow/lite/micro/micro_utils.h"
#include "xtensa_tf_micro_common.h"
namespace tflite {
namespace ops {
namespace micro {
namespace svdf {
namespace {
// These constants are specific to the hotword "OK G" model.
// They exist until (b/132070898) is fixed.
constexpr int kScratchTensorMaxSize = 64;
struct OpData {
int32 effective_scale_1_a;
int32 effective_scale_2_a;
// b versions of each scale are kept at int since the numbers are just the
// shift value - typically between [-32, 32].
int effective_scale_1_b;
int effective_scale_2_b;
};
/**
* This version of SVDF is specific to TFLite Micro. It differs from the
* TFLite version in the following ways:
*
* 1.) Scratch tensor allocation - scratch tensors must be known ahead of time
* for the Micro interpreter.
* 2.) Output dimensions - the TFLite version determines the output size at
* runtime and resizes the output tensor. The Micro runtime does not support
* tensor resizing.
*/
static inline TfLiteStatus ApplyTimeWeightsBiasAndActivation(
TfLiteContext* context, int batch_size, int memory_size, int num_filters,
int num_units, int rank, const TfLiteTensor* weights_time,
const TfLiteTensor* bias, TfLiteFusedActivation activation,
TfLiteTensor* activation_state, TfLiteTensor* scratch,
TfLiteTensor* output) {
float* scratch_bias = GetTensorData<float>(scratch);
if (bias) {
const float* bias_data = GetTensorData<float>(bias);
for (int j = 0; j < num_units; ++j) {
scratch_bias[j] = *bias_data++;
}
} else {
for (int j = 0; j < num_units; ++j) {
scratch_bias[j] = 0.0f;
}
}
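// For unit j the time weights of its `rank` filters are stored contiguously,
// so a single 1 x (memory_size * rank) matXvec against the matching slice of
// the activation state both applies the time weights and reduces over the
// rank, with scratch_bias supplying the (possibly zero) bias.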
int err = 0;
for (int b = 0; b < batch_size; ++b) {
const float* weights_time_vec = GetTensorData<float>(weights_time);
const float* mat_ptr =
GetTensorData<float>(activation_state) + b * memory_size * num_filters;
float* output_ptr_batch = GetTensorData<float>(output) + b * num_units;
for (int j = 0; j < num_units; j++) {
err = xa_nn_matXvec_f32xf32_f32(
output_ptr_batch, mat_ptr, NULL, weights_time_vec, NULL, scratch_bias,
1, memory_size * rank, 0, memory_size * rank, 0);
CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_matXvec_f32xf32_f32 failed");
output_ptr_batch++;
mat_ptr += memory_size * rank;
weights_time_vec += memory_size * rank;
}
}
// Apply activation.
for (int b = 0; b < batch_size; ++b) {
float* output_ptr_batch = GetTensorData<float>(output) + b * num_units;
for (int i = 0; i < num_units; ++i) {
*output_ptr_batch = ActivationValFloat(activation, *output_ptr_batch);
++output_ptr_batch;
}
}
// Left shift the activation_state to make room for next cycle's activation.
// (alanchiao): explore collapsing this into a single loop.
for (int b = 0; b < batch_size; ++b) {
float* state_ptr_batch =
GetTensorData<float>(activation_state) + b * memory_size * num_filters;
for (int f = 0; f < num_filters; ++f) {
// Shift the vector left:
float* batch_ptr = state_ptr_batch;
float* batch_start = state_ptr_batch + 1;
float* batch_end = state_ptr_batch + memory_size;
while (batch_start != batch_end) {
*batch_ptr++ = *batch_start++;
}
state_ptr_batch[memory_size - 1] = 0.0f;
state_ptr_batch += memory_size;
}
}
return kTfLiteOk;
}
inline TfLiteStatus EvalFloatSVDF(
TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input,
const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time,
const TfLiteTensor* bias, const TfLiteSVDFParams* params,
TfLiteTensor* scratch, TfLiteTensor* activation_state,
TfLiteTensor* output) {
const int rank = params->rank;
const int batch_size = input->dims->data[0];
const int input_size = input->dims->data[1];
const int num_filters = weights_feature->dims->data[0];
const int num_units = num_filters / rank;
const int memory_size = weights_time->dims->data[1];
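// Shapes: weights_feature is [num_filters, input_size], weights_time is
// [num_filters, memory_size], and activation_state holds, per batch, a
// [num_filters, memory_size] history of feature activations whose newest
// entry lives in column memory_size - 1.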
// Clear the activation (activation_state's leftmost column).
// (ghodrat): Add a test which initialize activation_state with invalid
// values in leftmost column and make sure it passes.
for (int b = 0; b < batch_size; ++b) {
float* state_ptr_batch =
GetTensorData<float>(activation_state) + b * memory_size * num_filters;
for (int c = 0; c < num_filters; ++c) {
state_ptr_batch[c * memory_size + memory_size - 1] = 0.0f;
}
}
// Compute conv1d(inputs, weights_feature).
// The activation_state's rightmost column is used to save current cycle
// activation. This is achieved by starting at
// GetTensorData<float>(activation_state)[memory_size - 1] and having the
// stride equal to memory_size.
const float* matrix = GetTensorData<float>(weights_feature);
const float* vector = GetTensorData<float>(input);
float* out_scratch = GetTensorData<float>(scratch);
/* The NNLib matXvec kernel needs a bias buffer. To avoid extra memory, the
output buffer is reused for it: its size is batch * num_units floats and
batch is at least 1, so the first num_units entries are always available. */
float* bias_scratch = GetTensorData<float>(output);
float* result = &GetTensorData<float>(activation_state)[memory_size - 1];
float* result_in_batch = result;
for (int i = 0; i < num_units; i++) bias_scratch[i] = 0.0f;
int err = 0;
for (int i = 0; i < batch_size; i++) {
/* The output buffer doubles as the bias buffer for the NNLib kernel, so
only num_units entries are guaranteed; hence the rank loop below calls
matXvec for num_units rows at a time. */
for (int j = 0; j < rank; j++) {
err = xa_nn_matXvec_f32xf32_f32(
&out_scratch[j * num_units], &matrix[j * input_size * num_units],
NULL, &vector[i * input_size], NULL, bias_scratch, num_units,
input_size, 0, input_size, 0);
CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_matXvec_f32xf32_f32 failed");
}
for (int j = 0; j < num_filters; ++j) {
*result_in_batch = out_scratch[j];
result_in_batch += memory_size;
}
}
return ApplyTimeWeightsBiasAndActivation(
context, batch_size, memory_size, num_filters, num_units, rank,
weights_time, bias, params->activation, activation_state, scratch,
output);
}
void EvalIntegerSVDF(
TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input_tensor,
const TfLiteTensor* weights_feature_tensor,
const TfLiteTensor* weights_time_tensor, const TfLiteTensor* bias_tensor,
const TfLiteSVDFParams* params, TfLiteTensor* activation_state_tensor,
TfLiteTensor* output_tensor, int32_t scale_1_a, int scale_1_b,
int32_t scale_2_a, int scale_2_b, int32_t input_zp, int32_t output_zp) {
const int n_rank = params->rank;
const int n_batch = input_tensor->dims->data[0];
const int n_input = input_tensor->dims->data[1];
const int n_filter = weights_feature_tensor->dims->data[0];
const int n_unit = n_filter / n_rank;
const int n_memory = weights_time_tensor->dims->data[1];
// (b/132070898): Move these temp variables to the new scratch buffer API
// when ready.
int32_t scratch_tensor[kScratchTensorMaxSize];
int32_t scratch_output_tensor[kScratchTensorMaxSize];
// Rewrite last bit of state.
{
for (int b = 0; b < n_batch; ++b) {
int16_t* state_ptr_batch =
GetTensorData<int16_t>(activation_state_tensor) +
b * n_memory * n_filter;
for (int c = 0; c < n_filter; ++c) {
int16_t* state_ptr = state_ptr_batch + c * n_memory;
state_ptr[n_memory - 1] = 0;
}
}
}
// Feature matmul.
{
int16_t* state = GetTensorData<int16_t>(activation_state_tensor);
const int8_t* input = GetTensorData<int8_t>(input_tensor);
const int8_t* weight_feature =
GetTensorData<int8_t>(weights_feature_tensor);
const int32_t output_max = std::numeric_limits<int16_t>::max();
const int32_t output_min = std::numeric_limits<int16_t>::min();
int16_t* result_in_batch = state + (n_memory - 1);
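// For every batch and filter row: accumulate
// sum_c weight_feature[r][c] * (input[b][c] - input_zp) in int32, rescale it
// by the effective scale (scale_1_a, scale_1_b) via
// MultiplyByQuantizedMultiplier, clamp to the int16 range and store it with
// stride n_memory as the newest column of the state buffer.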
for (int b = 0; b < n_batch; b++) {
const int8_t* matrix_ptr = weight_feature;
for (int r = 0; r < n_filter; r++) {
int32_t dot_prod = 0;
const int8_t* vector_in_batch = input + b * n_input;
for (int c = 0; c < n_input; c++) {
dot_prod += *matrix_ptr++ * (*vector_in_batch++ - input_zp);
}
dot_prod =
MultiplyByQuantizedMultiplier(dot_prod, scale_1_a, scale_1_b);
dot_prod = std::min(std::max(output_min, dot_prod), output_max);
*result_in_batch = dot_prod;
result_in_batch += n_memory;
}
}
}
// Time.
{
for (int b = 0; b < n_batch; ++b) {
int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
// Perform batched vector dot product:
const int16_t* vector1_ptr = GetTensorData<int16_t>(weights_time_tensor);
const int16_t* vector2_ptr =
GetTensorData<int16_t>(activation_state_tensor) +
b * n_memory * n_filter;
for (int i = 0; i < n_filter; i++) {
*scratch_ptr_batch = 0;
for (int j = 0; j < n_memory; j++) {
*scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++;
}
scratch_ptr_batch++;
}
}
}
// Reduce, add bias, rescale, activation.
{
// Add bias.
if (bias_tensor) {
// Vector batch assign:
const int32_t* bias_data = GetTensorData<int32_t>(bias_tensor);
for (int i = 0; i < n_batch; ++i) {
int32_t* output_ptr = scratch_output_tensor + i * n_unit;
const int32_t* bias_ptr = bias_data;
for (int j = 0; j < n_unit; ++j) {
*output_ptr++ = *bias_ptr++;
}
}
} else {
int32_t* output_ptr = scratch_output_tensor;
for (int i = 0; i < n_batch * n_unit; ++i) {
*output_ptr++ = 0;
}
}
// Reduce.
for (int b = 0; b < n_batch; ++b) {
int32_t* output_temp_ptr = scratch_output_tensor + b * n_unit;
int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
// Reduction sum vector
for (int i = 0; i < n_unit; ++i) {
for (int j = 0; j < n_rank; ++j) {
output_temp_ptr[i] += *scratch_ptr_batch++;
}
}
}
// Rescale.
const int32_t output_max = std::numeric_limits<int8_t>::max();
const int32_t output_min = std::numeric_limits<int8_t>::min();
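// MultiplyByQuantizedMultiplier applies the effective output scale
// (scale_2_a, scale_2_b); the output zero point is then added and the result
// is saturated to the int8 range.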
for (int i = 0; i < n_batch * n_unit; ++i) {
int32_t x1 = scratch_output_tensor[i];
int32_t x2 = MultiplyByQuantizedMultiplier(x1, scale_2_a, scale_2_b);
int32_t x3 = x2 + output_zp;
int32_t x4 = std::min(std::max(output_min, x3), output_max);
GetTensorData<int8_t>(output_tensor)[i] = static_cast<int8_t>(x4);
}
}
// Shift state.
{
for (int b = 0; b < n_batch; ++b) {
int16_t* state_ptr_batch =
GetTensorData<int16_t>(activation_state_tensor) +
b * n_memory * n_filter;
for (int f = 0; f < n_filter; ++f) {
// Shift the vector left:
int16_t* batch_ptr = state_ptr_batch;
int16_t* batch_start = state_ptr_batch + 1;
int16_t* batch_end = state_ptr_batch + n_memory;
while (batch_start != batch_end) {
*batch_ptr++ = *batch_start++;
}
state_ptr_batch[n_memory - 1] = 0;
state_ptr_batch += n_memory;
}
}
}
}
} // namespace
// Input tensors.
constexpr int kInputTensor = 0;
constexpr int kWeightsFeatureTensor = 1;
constexpr int kWeightsTimeTensor = 2;
constexpr int kBiasTensor = 3;
// This is a variable tensor, and will be modified by this op.
constexpr int kInputActivationStateTensor = 4;
// Output tensor.
constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return nullptr;
}
void Free(TfLiteContext* context, void* buffer) {}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
// Validate Tensor Inputs (dtype depends on quantization):
// [0] = Input, {2, batch_size, input_size}
// [1] = Weights Feature, {2, num_filters, input_size}
// [2] = Weights Time, {2, num_filters, memory_size}
// [3] = Bias (optional), {1, num_units}
// [4] = Activation State (variable),
// {2, batch_size, memory_size * num_filters}
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* weights_feature =
GetInput(context, node, kWeightsFeatureTensor);
const TfLiteTensor* weights_time =
GetInput(context, node, kWeightsTimeTensor);
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
const TfLiteTensor* activation_state =
GetInput(context, node, kInputActivationStateTensor);
// Define input constants based on input tensor definition above:
const int rank = params->rank;
const int input_size = input->dims->data[1];
const int batch_size = input->dims->data[0];
const int num_filters = weights_feature->dims->data[0];
TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
const int num_units = num_filters / rank;
const int memory_size = weights_time->dims->data[1];
const bool is_full_integer = input->type == kTfLiteInt8;
// Validate Input Tensor:
TF_LITE_ENSURE(context,
input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
// Validate Tensor Output:
// [0] = float/int8, {2, batch_size, num_units}
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);
// Validate Weights Feature Input Tensor:
TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2);
TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size);
// Validate Weights Time Input Tensor:
TF_LITE_ENSURE_EQ(context, NumDimensions(weights_time), 2);
TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters);
TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size);
// Validate Optional Bias Input Tensor:
if (bias) {
TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
}
// Validate Activation State Input Tensor:
TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size);
TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],
memory_size * num_filters);
if (is_full_integer) {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16);
if (bias) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
}
TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
// Validate Scratch Tensors:
// [0] = (shared - see float block below for usage)
// [1] = Output Temp, int8_t, {2, num_units, batch_size}
// (b/132070898): Scratch values are used as stack variables in
// EvalIntegerSVDF().
// Validate output tensor:
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt8);
} else {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 6);
// Validate Input Tensor dtypes:
TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
if (bias) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
}
// Validate shared Scratch Tensor:
// [0] = Holds dot-product of time-forward calculations in
// ApplyTimeWeightsBiasAndActivation():
// float/int32, {2, batch_size, num_filters}
// (b/132070898): Use input tensor as variable until scratch tensor
// allocation has been implemented.
// TfLiteTensor* scratch_tensor = GetTemporary(context, node, 0);
TfLiteTensor* scratch_tensor = &context->tensors[node->inputs->data[5]];
TF_LITE_ENSURE_EQ(context, scratch_tensor->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, NumDimensions(scratch_tensor), 2);
TF_LITE_ENSURE_EQ(context, scratch_tensor->dims->data[0], batch_size);
TF_LITE_ENSURE_EQ(context, scratch_tensor->dims->data[1], num_filters);
// Full-float SVDF only uses the one shared scratch tensor (see above for
// usage).
// (b/132070898): Use input tensor as variable until scratch tensor
// allocation has been implemented.
// TF_LITE_ENSURE_EQ(context, node->temporaries->size, 1);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
}
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* weights_feature =
GetInput(context, node, kWeightsFeatureTensor);
const TfLiteTensor* weights_time =
GetInput(context, node, kWeightsTimeTensor);
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* activation_state =
GetVariableInput(context, node, kInputActivationStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const bool is_full_integer = input->type == kTfLiteInt8;
switch (weights_feature->type) {
case kTfLiteFloat32: {
// (b/132070898): Use input tensor as variable until scratch tensor
// allocation has been implemented.
// TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch = &context->tensors[node->inputs->data[5]];
return EvalFloatSVDF(context, node, input, weights_feature, weights_time,
bias, params, scratch, activation_state, output);
break;
}
case kTfLiteInt8: {
if (is_full_integer) {
// (b/132070898): Store these values in ::Prepare() instead of
// ::Eval():
// Calculate effective scales.
OpData op_data;
auto* input_params = reinterpret_cast<TfLiteAffineQuantization*>(
input->quantization.params);
auto* weights_feature_params =
reinterpret_cast<TfLiteAffineQuantization*>(
weights_feature->quantization.params);
auto* state_params = reinterpret_cast<TfLiteAffineQuantization*>(
activation_state->quantization.params);
auto* weight_time_params = reinterpret_cast<TfLiteAffineQuantization*>(
weights_time->quantization.params);
auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
output->quantization.params);
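// Fold the tensor scales into two effective scales: effective_scale_1
// rescales input * feature-weight products into the int16 activation-state
// domain, and effective_scale_2 rescales state * time-weight accumulations
// (plus bias) into the int8 output domain. QuantizeMultiplier converts each
// into the fixed-point multiplier/shift pair consumed by EvalIntegerSVDF().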
const double effective_scale_1 =
static_cast<double>(input_params->scale->data[0] *
weights_feature_params->scale->data[0] /
state_params->scale->data[0]);
const double effective_scale_2 = static_cast<double>(
state_params->scale->data[0] * weight_time_params->scale->data[0] /
output_params->scale->data[0]);
QuantizeMultiplier(effective_scale_1, &op_data.effective_scale_1_a,
&op_data.effective_scale_1_b);
QuantizeMultiplier(effective_scale_2, &op_data.effective_scale_2_a,
&op_data.effective_scale_2_b);
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu);
EvalIntegerSVDF(
context, node, input, weights_feature, weights_time, bias, params,
activation_state, output, op_data.effective_scale_1_a,
op_data.effective_scale_1_b, op_data.effective_scale_2_a,
op_data.effective_scale_2_b, input->params.zero_point,
output->params.zero_point);
return kTfLiteOk;
}
break;
}
default:
TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
TfLiteTypeGetName(weights_feature->type));
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace svdf
TfLiteRegistration* Register_SVDF() {
static TfLiteRegistration r = {};
r.init = svdf::Init;
r.free = svdf::Free;
r.prepare = svdf::Prepare;
r.invoke = svdf::Eval;
return &r;
}
} // namespace micro
} // namespace ops
} // namespace tflite
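// Note: Register_SVDF() hands back a statically allocated TfLiteRegistration
// whose prepare/invoke hooks point at the Xtensa SVDF implementation above;
// Init()/Free() are intentionally no-ops, and the quantization parameters are
// currently recomputed on every Eval() (see the b/132070898 comments).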

View File

@ -0,0 +1,80 @@
/******************************************************************************
* Copyright (C) 2019 Cadence Design Systems, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to use this Software with Cadence processor cores only and
* not with any other processors and platforms, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
******************************************************************************/
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef __XTENSA_TF_MICRO_COMMON__
#define __XTENSA_TF_MICRO_COMMON__
#include "xa_nnlib_api.h"
#include "xa_nnlib_standards.h"
#define CHECK_ERR_HIFI_NNLIB_KER(ret, err_msg) \
if (ret != 0) { \
TF_LITE_KERNEL_LOG(context, err_msg); \
return kTfLiteError; \
}
#ifndef XTENSA_NNLIB_MAX_SCRATCH_SIZE
#define XTENSA_NNLIB_MAX_SCRATCH_SIZE (70 * 1024)
#endif
#define ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM \
uint8_t xtensa_nnlib_scratch_buf[XTENSA_NNLIB_MAX_SCRATCH_SIZE];
#define MIN(a, b) (a) < (b) ? (a) : (b);
#define MAX(a, b) (a) > (b) ? (a) : (b);
#define ACTIVATION_MIN_MAX(data_type, out, inp, min, max) \
{ \
data_type temp = MAX(inp, min); \
out = MIN(temp, max); \
}
#define ACTIVATION_MIN_MAX_F32(out, inp, min, max) \
{ \
float temp = MAX(inp, min); \
out = MIN(temp, max); \
}
#define ACTIVATION_MIN_MAX_ASYM8(out, inp, min, max) \
{ \
int32_t temp = MAX((int32_t)inp, min); \
out = (uint8_t)MIN(temp, max); \
}
#define ALIGNED_SIZE(x, bytes) (((x) + (bytes - 1)) & (~(bytes - 1)))
#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1)))
#endif /* __XTENSA_TF_MICRO_COMMON__ */
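For reference, the snippet below is an illustrative, standalone sketch (not part of this header) of what the alignment and clamping helpers above compute; the inputs and printed values are examples only.
#include <cstdint>
#include <cstdio>
// Same round-up-to-boundary computation as ALIGNED_SIZE(x, bytes).
static uint32_t AlignedSize(uint32_t x, uint32_t bytes) {
  return (x + (bytes - 1)) & ~(bytes - 1);
}
int main() {
  // ALIGNED_SIZE(70, 8) rounds a 70-byte request up to the next 8-byte
  // boundary: 72.
  std::printf("%u\n", AlignedSize(70, 8));
  // ACTIVATION_MIN_MAX_F32(out, inp, 0.0f, 6.0f) clamps an activation into
  // [0, 6] (a ReLU6-style clamp); 7.5 becomes 6.0.
  float inp = 7.5f;
  float out = inp > 0.0f ? inp : 0.0f;
  out = out < 6.0f ? out : 6.0f;
  std::printf("%f\n", out);
  return 0;
}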

View File

@ -0,0 +1,59 @@
#!/bin/bash -e
# ==============================================================================
# Copyright (C) 2019 Cadence Design Systems, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to use this Software with Cadence processor cores only and
# not with any other processors and platforms, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==============================================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Tests an Xtensa binary by parsing the log output.
#
# First argument is the binary location.
# Second argument is a regular expression that's required to be in the output
# logs for the test to pass.
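# Example invocation (the binary path is a placeholder; TF Lite Micro tests
# typically report success with the string '~~~ALL TESTS PASSED~~~'):
#   tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh \
#     <path/to/xtensa_test_binary> '~~~ALL TESTS PASSED~~~'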
declare -r ROOT_DIR=`pwd`
declare -r TEST_TMPDIR=/tmp/test_xtensa_hifi_binary/
declare -r MICRO_LOG_PATH=${TEST_TMPDIR}/$1
declare -r MICRO_LOG_FILENAME=${MICRO_LOG_PATH}/logs.txt
mkdir -p ${MICRO_LOG_PATH}
xt-run $1 2>&1 | tee ${MICRO_LOG_FILENAME}
if grep -q "$2" ${MICRO_LOG_FILENAME}
then
echo "$1: PASS"
exit 0
else
echo "$1: FAIL - '$2' not found in logs."
exit 1
fi

View File

@ -0,0 +1,67 @@
ifneq ($(filter xtensa_hifi, $(ALL_TAGS)),)
XTENSA_PATH = $(MAKEFILE_DIR)/../../kernels/xtensa_hifi
ifneq (,$(filter hifi4%, $(TARGET_ARCH)))
CCFLAGS += -DNNLIB_V2 \
-DXTENSA_NNLIB_MAX_SCRATCH_SIZE=70*1024
CXXFLAGS += -DNNLIB_V2 \
-DXTENSA_NNLIB_MAX_SCRATCH_SIZE=70*1024
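# XTENSA_NNLIB_MAX_SCRATCH_SIZE sizes the stack-allocated scratch buffer used
# by the HiFi-optimized kernels; xtensa_tf_micro_common.h falls back to the
# same 70 KB default when this define is absent.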
MICROLITE_CC_SRCS += \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi4/xa_nn_activations_f32_f32.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi4/xa_nn_activations_asym8_asym8.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi4/xa_nn_activations_32_16.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi4/xa_nn_activations_32_8.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi4/xa_nn_softmax_asym8_asym8.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/basic/hifi4/xa_nn_floor_f32.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_circ_buf.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_asym8xasym8.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_f32.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_matXvec_asym8xasym8_asym8_circ.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_matXvec_f32_circ.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_conv2d_depthwise.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_conv2d_depthwise_f32.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_conv2d_depthwise_asym8xasym8.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_circ_buf.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/fc/hifi4/xa_nn_fully_connected.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/matXvec/hifi4/xa_nn_matXvec_f32.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/matXvec/hifi4/xa_nn_matXvec_16x16.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/matXvec/hifi4/xa_nn_matXvec_8x16.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/matXvec/hifi4/xa_nn_matXvec_8x8.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/matXvec/hifi4/xa_nn_matXvec_asym8xasym8.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_avgpool.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_avgpool_f32.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_avgpool_asym8.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_maxpool.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_maxpool_f32.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_maxpool_asym8.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_avgpool_f32_nhwc.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_avgpool_asym8_nhwc.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_maxpool_f32_nhwc.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_maxpool_asym8_nhwc.c \
$(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_inv_256_tbl.c \
$(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/vec_sigmoidf_hifi4.c \
$(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/vec_tanhf_hifi4.c \
$(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/vec_reluf_hifi4.c \
$(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/vec_softmaxf_hifi4.c \
$(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/vec_alognf_hifi4.c \
$(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/scl_sigmoidf_hifi4.c \
$(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/scl_tanhf_hifi4.c \
$(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/expf_tbl.c \
$(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/pow2f_tbl.c \
$(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/inff_tbl.c \
$(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/tanhf_tbl.c \
$(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/nanf_tbl.c
INCLUDES += -I$(XTENSA_PATH)/xa_nnlib/algo/kernels/ \
-I$(XTENSA_PATH)/xa_nnlib/include/nnlib/ \
-I$(XTENSA_PATH)/xa_nnlib/include/ \
-I$(XTENSA_PATH)/xa_nnlib/algo/common/include/ \
-I$(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/include/
endif
endif

View File

@ -0,0 +1,35 @@
# Building TensorFlow Lite for Microcontrollers for Cadence Tensilica HiFi DSPs
This document describes the steps to build and run TensorFlow Lite for
Microcontrollers on Cadence HiFi DSPs.
## Prerequisites
The Xtensa development tools and the target processor configurations should be
installed on the system. Please check https://tensilicatools.com for more
information about downloading and installing the required tools.
The `PATH` variable should be set to include the `<xtensa_tools_root>/bin`
directory. The `XTENSA_SYSTEM` and `XTENSA_CORE` environment variables should
be set to the required tools version and processor configuration, respectively.
## Building for HiFi Processors
To build the code using the Xtensa tools for the processor configuration
selected by `XTENSA_CORE`, set `TARGET=xtensa_hifi`. Additionally, `TARGET_ARCH`
can be used to select optimized HiFi NN kernels specific to that processor
configuration. Currently, HiFi4 NN kernels are provided; they can be enabled as
follows:

    make -f tensorflow/lite/micro/tools/make/Makefile test_micro_speech_test TARGET=xtensa_hifi TARGET_ARCH=hifi4

Xtensa-specific TF Lite Micro kernels are implemented in this folder:
`tensorflow/lite/micro/kernels/xtensa_hifi/`

A scratch memory allocation is needed for the HiFi-optimized kernels. This
allocation is currently done on the stack, and its size can be controlled by
defining `XTENSA_NNLIB_MAX_SCRATCH_SIZE` appropriately in
`tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifi_nn_library.inc`, as
sketched below.
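The sketch below is illustrative only (the function name is a placeholder, not
an API from this port): `ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM` from
`xtensa_tf_micro_common.h` expands to a local `uint8_t` buffer of
`XTENSA_NNLIB_MAX_SCRATCH_SIZE` bytes that a kernel can hand to the NN library
routines.

    // Illustrative sketch of the stack scratch-buffer pattern.
    #include <cstdint>

    #ifndef XTENSA_NNLIB_MAX_SCRATCH_SIZE
    #define XTENSA_NNLIB_MAX_SCRATCH_SIZE (70 * 1024)  // same default as the header
    #endif

    void SomeHiFiKernelEval() {  // hypothetical kernel entry point
      // This is what ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM expands to: a stack
      // buffer handed to xa_nnlib kernels that need temporary work space.
      uint8_t xtensa_nnlib_scratch_buf[XTENSA_NNLIB_MAX_SCRATCH_SIZE];
      (void)xtensa_nnlib_scratch_buf;
    }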
The files containing the HiFi-optimized NN kernels are in this folder:
`tensorflow/lite/micro/kernels/xtensa_hifi/xa_nnlib/`

View File

@ -0,0 +1,42 @@
# Settings for Xtensa toolchain.
# Derived from xtensa_xpg_makefile.inc
# The Xtensa environment variables should be configured externally (XTENSA_CORE, XTENSA_SYSTEM)
ifeq ($(TARGET), xtensa_hifi)
TARGET_ARCH := hifi3_bd5
PLATFORM_ARGS = \
-mno-mul16 \
-mno-mul32 \
-mno-div32 \
-fsigned-char \
-fno-exceptions \
-mlongcalls \
-INLINE:requested \
-mcoproc \
-fno-zero-initialized-in-bss \
-mtext-section-literals \
-fno-unsafe-math-optimizations
TF_LITE_MICRO_FLAGS = \
-DTF_LITE_STATIC_MEMORY
TARGET_TOOLCHAIN_PREFIX := xt-
CXX_TOOL := clang++
CC_TOOL := clang
CXXFLAGS = -O0 $(PLATFORM_ARGS) -std=c++11 $(TF_LITE_MICRO_FLAGS)
# TODO: Use -std=c11?
CCFLAGS = -O3 $(PLATFORM_ARGS) $(TF_LITE_MICRO_FLAGS)
TEST_SCRIPT := tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh
# These are microcontroller-specific rules for converting the ELF output
# of the linker into a binary image that can be loaded directly.
OBJCOPY := $(TARGET_TOOLCHAIN_PREFIX)objcopy
$(BINDIR)/%.bin: $(BINDIR)/%
echo "here"
@mkdir -p $(dir $@)
$(OBJCOPY) $< $@ -O binary
endif

View File

@ -7,6 +7,8 @@
ifeq ($(TARGET), xtensa-xpg)
TARGET_ARCH := xtensa-xpg
$(eval $(call add_third_party_download,$(XTENSA_HIFI4_URL),$(XTENSA_HIFI4_MD5),xa_nnlib,))
PLATFORM_ARGS = \
-DTF_LITE_MCU_DEBUG_LOG \
--xtensa-core=$(XTENSA_CORE) \

View File

@ -59,3 +59,7 @@ EMBARC_OSP_MD5 := "9eaf7b3a1ed05872a03da9796672a776"
EMBARC_MLI_URL := "https://github.com/foss-for-synopsys-dwc-arc-processors/embarc_mli/archive/6316034d421cbbb59756239908d7c9a99075a3bb.zip"
EMBARC_MLI_MD5 := "db0910cf0e07e43f74ae7a31de485d56"
XTENSA_HIFI4_URL :="https://github.com/foss-xtensa/nnlib-hifi4/blob/master/archive/xa_nnlib.zip"
XTENSA_HIFI4_MD5 :="a517b653a75b96d0271e1b99ee2a8c14"