From f5f08e60c10d8f92b99371edfa63b9714e149c75 Mon Sep 17 00:00:00 2001 From: Niranjan Yadla Date: Fri, 14 Feb 2020 10:12:11 -0800 Subject: [PATCH] Cadence HiFi4 Neural Network (NN) Library Optimized HiFi4 NN library to run ops for TensorFlowLite for Micro Signed-off-by: Prasad Nikam Signed-off-by: Niranjan Yadla Signed-off-by: Pramod Kumar Surana Signed-off-by: Ranjit Kumar Voruganti --- .../micro/kernels/xtensa_hifi/activations.cc | 239 +++++++ .../lite/micro/kernels/xtensa_hifi/conv.cc | 561 +++++++++++++++ .../kernels/xtensa_hifi/depthwise_conv.cc | 597 ++++++++++++++++ .../lite/micro/kernels/xtensa_hifi/floor.cc | 81 +++ .../kernels/xtensa_hifi/fully_connected.cc | 280 ++++++++ .../micro/kernels/xtensa_hifi/logistic.cc | 125 ++++ .../lite/micro/kernels/xtensa_hifi/pooling.cc | 662 ++++++++++++++++++ .../lite/micro/kernels/xtensa_hifi/softmax.cc | 328 +++++++++ .../lite/micro/kernels/xtensa_hifi/svdf.cc | 599 ++++++++++++++++ .../xtensa_hifi/xtensa_tf_micro_common.h | 79 +++ .../micro/testing/test_xtensa_hifi_binary.sh | 59 ++ .../make/ext_libs/xtensa_hifi_nn_library.inc | 67 ++ .../tools/make/targets/xtensa_hifi/README.md | 31 + .../make/targets/xtensa_hifi_makefile.inc | 42 ++ .../make/targets/xtensa_xpg_makefile.inc | 2 + .../tools/make/third_party_downloads.inc | 4 + 16 files changed, 3756 insertions(+) create mode 100644 tensorflow/lite/micro/kernels/xtensa_hifi/activations.cc create mode 100755 tensorflow/lite/micro/kernels/xtensa_hifi/conv.cc create mode 100755 tensorflow/lite/micro/kernels/xtensa_hifi/depthwise_conv.cc create mode 100644 tensorflow/lite/micro/kernels/xtensa_hifi/floor.cc create mode 100644 tensorflow/lite/micro/kernels/xtensa_hifi/fully_connected.cc create mode 100644 tensorflow/lite/micro/kernels/xtensa_hifi/logistic.cc create mode 100755 tensorflow/lite/micro/kernels/xtensa_hifi/pooling.cc create mode 100755 tensorflow/lite/micro/kernels/xtensa_hifi/softmax.cc create mode 100644 tensorflow/lite/micro/kernels/xtensa_hifi/svdf.cc create mode 100755 tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h create mode 100755 tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh create mode 100644 tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifi_nn_library.inc create mode 100644 tensorflow/lite/micro/tools/make/targets/xtensa_hifi/README.md create mode 100644 tensorflow/lite/micro/tools/make/targets/xtensa_hifi_makefile.inc diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/activations.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/activations.cc new file mode 100644 index 00000000000..67f47102438 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa_hifi/activations.cc @@ -0,0 +1,239 @@ +/****************************************************************************** +* Copyright (C) 2019 Cadence Design Systems, Inc. +* +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files (the +* "Software"), to use this Software with Cadence processor cores only and +* not with any other processors and platforms, subject to +* the following conditions: +* +* The above copyright notice and this permission notice shall be included +* in all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
+* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+******************************************************************************/
+
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <limits>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/micro_utils.h"
+#include "xtensa_tf_micro_common.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace activations {
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+template <typename Q>
+inline void ReluQuantized(int32_t lower, const RuntimeShape& input_shape,
+                          const Q* input_data,
+                          const RuntimeShape& output_shape, Q* output_data) {
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+  for (int i = 0; i < flat_size; ++i) {
+    const Q val = input_data[i];
+    const Q clamped = val < lower ? lower : val;
+    output_data[i] = clamped;
+  }
+}
+
+inline void ReluFloat(const RuntimeShape& input_shape, const float* input_data,
+                      const RuntimeShape& output_shape, float* output_data) {
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+  for (int i = 0; i < flat_size; ++i) {
+    const float val = input_data[i];
+    const float lower = 0.0f;
+    const float clamped = val < lower ? lower : val;
+    output_data[i] = clamped;
+  }
+}
+
+inline void Relu6Float(const RuntimeShape& input_shape,
+                       const float* input_data,
+                       const RuntimeShape& output_shape, float* output_data) {
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+  for (int i = 0; i < flat_size; ++i) {
+    const float val = input_data[i];
+    const float upper = 6.0f;
+    const float lower = 0.0f;
+    const float clamped = val > upper ? upper : val < lower ? lower : val;
+    output_data[i] = clamped;
+  }
+}
+
+template <typename Q>
+inline void Relu6Quantized(Q lower, Q upper, const RuntimeShape& input_shape,
+                           const Q* input_data,
+                           const RuntimeShape& output_shape, Q* output_data) {
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+  for (int i = 0; i < flat_size; ++i) {
+    const Q val = input_data[i];
+    const Q clamped = val > upper ? upper : val < lower ? lower : val;
+    output_data[i] = clamped;
+  }
+}
+
+TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
+  return kTfLiteOk;
+}
+
+TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  switch (input->type) {
+    case kTfLiteFloat32: {
+      int err;
+      const float* inp_data_ptr;
+      float* out_data_ptr;
+      const RuntimeShape& input_shape = GetTensorShape(input);
+      const RuntimeShape& output_shape = GetTensorShape(output);
+      const int flat_size = MatchingFlatSize(input_shape, output_shape);
+
+      inp_data_ptr = GetTensorData<float>(input);
+      out_data_ptr = GetTensorData<float>(output);
+
+      // Relu has no upper bound, so pass +inf as the upper threshold.
+      const float f32_pos_inf = std::numeric_limits<float>::infinity();
+      err = xa_nn_vec_relu_f32_f32(out_data_ptr, inp_data_ptr, f32_pos_inf,
+                                   flat_size);
+
+      CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_relu_f32_f32 failed");
+      return kTfLiteOk;
+    }
+    case kTfLiteInt8: {
+      ReluQuantized<int8_t>(input->params.zero_point, GetTensorShape(input),
+                            GetTensorData<int8_t>(input),
+                            GetTensorShape(output),
+                            GetTensorData<int8_t>(output));
+      return kTfLiteOk;
+    }
+    case kTfLiteUInt8: {
+      int err;
+      const uint8_t* inp_data_ptr;
+      uint8_t* out_data_ptr;
+      const RuntimeShape& input_shape = GetTensorShape(input);
+      const RuntimeShape& output_shape = GetTensorShape(output);
+      const int flat_size = MatchingFlatSize(input_shape, output_shape);
+
+      inp_data_ptr = GetTensorData<uint8_t>(input);
+      out_data_ptr = GetTensorData<uint8_t>(output);
+
+      // Relu clamps below at the quantized representation of 0.0 (the zero
+      // point); 255 is the largest representable uint8 value, so the upper
+      // bound is a no-op.
+      err = xa_nn_vec_activation_min_max_asym8_asym8(
+          out_data_ptr, inp_data_ptr, input->params.zero_point, 255,
+          flat_size);
+
+      CHECK_ERR_HIFI_NNLIB_KER(
+          err, "xa_nn_vec_activation_min_max_asym8_asym8 failed");
+      return kTfLiteOk;
+    }
+    default: {
+      TF_LITE_KERNEL_LOG(context,
+                         "Only float32, int8 and uint8 are supported, got %s",
+                         TfLiteTypeGetName(input->type));
+      return kTfLiteError;
+    }
+  }
+}
+
+TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) {
+  return kTfLiteOk;
+}
+
+TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  switch (input->type) {
+    case kTfLiteFloat32: {
+      int err;
+      const float* inp_data_ptr;
+      float* out_data_ptr;
+      const RuntimeShape& input_shape = GetTensorShape(input);
+      const RuntimeShape& output_shape = GetTensorShape(output);
+      const int flat_size = MatchingFlatSize(input_shape, output_shape);
+
+      inp_data_ptr = GetTensorData<float>(input);
+      out_data_ptr = GetTensorData<float>(output);
+
+      err = xa_nn_vec_relu6_f32_f32(out_data_ptr, inp_data_ptr, flat_size);
+
+      CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_relu6_f32_f32 failed");
+      return kTfLiteOk;
+    }
+    case kTfLiteInt8: {
+      const int8_t six = FloatToAsymmetricQuantizedInt8(
+          6.0f, input->params.scale, input->params.zero_point);
+      const int8_t zero = input->params.zero_point;
+      Relu6Quantized<int8_t>(
+          zero, six, GetTensorShape(input), GetTensorData<int8_t>(input),
+          GetTensorShape(output), GetTensorData<int8_t>(output));
+      return kTfLiteOk;
+    }
+    case kTfLiteUInt8: {
+      const uint8_t six = FloatToAsymmetricQuantizedUInt8(
+          6.0f, input->params.scale, input->params.zero_point);
+      const uint8_t zero = input->params.zero_point;
+      int err;
+      const uint8_t* inp_data_ptr;
+      uint8_t* out_data_ptr;
+      const RuntimeShape& input_shape = GetTensorShape(input);
+      const RuntimeShape& output_shape = GetTensorShape(output);
+      const int flat_size = MatchingFlatSize(input_shape, output_shape);
+
+      inp_data_ptr = GetTensorData<uint8_t>(input);
+      out_data_ptr = GetTensorData<uint8_t>(output);
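+      // Note (illustrative, with hypothetical parameter values): 'zero' and
+      // 'six' are the uint8 representations of 0.0f and 6.0f under the
+      // input's quantization, i.e. round(x / scale) + zero_point. For
+      // instance, with scale ~0.047 and zero_point 0, six = round(6.0 /
+      // 0.047) = 128, so the kernel below clamps to [0, 128].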
+
+      err = xa_nn_vec_activation_min_max_asym8_asym8(
+          out_data_ptr, inp_data_ptr, zero, six, flat_size);
+
+      CHECK_ERR_HIFI_NNLIB_KER(
+          err, "xa_nn_vec_activation_min_max_asym8_asym8 failed");
+      return kTfLiteOk;
+    }
+    default: {
+      TF_LITE_KERNEL_LOG(context,
+                         "Only float32, int8 and uint8 are supported, got %s",
+                         TfLiteTypeGetName(input->type));
+      return kTfLiteError;
+    }
+  }
+}
+
+}  // namespace activations
+
+TfLiteRegistration* Register_RELU() {
+  static TfLiteRegistration r = {};
+  r.prepare = activations::ReluPrepare;
+  r.invoke = activations::ReluEval;
+  return &r;
+}
+
+TfLiteRegistration* Register_RELU6() {
+  static TfLiteRegistration r = {};
+  r.prepare = activations::Relu6Prepare;
+  r.invoke = activations::Relu6Eval;
+  return &r;
+}
+
+}  // namespace micro
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/conv.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/conv.cc
new file mode 100755
index 00000000000..ba1c7516722
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/xtensa_hifi/conv.cc
@@ -0,0 +1,561 @@
+/******************************************************************************
+* Copyright (C) 2019 Cadence Design Systems, Inc.
+*
+* Permission is hereby granted, free of charge, to any person obtaining
+* a copy of this software and associated documentation files (the
+* "Software"), to use this Software with Cadence processor cores only and
+* not with any other processors and platforms, subject to
+* the following conditions:
+*
+* The above copyright notice and this permission notice shall be included
+* in all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+******************************************************************************/
+
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/lite/kernels/internal/reference/conv.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/padding.h" + +#include "xtensa_tf_micro_common.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace conv { + +constexpr int kInputTensor = 0; +constexpr int kFilterTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; +constexpr int kMaxChannels = 256; + +// Conv is quantized along dimension 0: +// https://www.tensorflow.org/lite/performance/quantization_spec +constexpr int kConvQuantizedDimension = 0; + +// This file has 2 implementation of Conv. + +struct OpData { + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + + // Per channel output multiplier and shift. + // (b/141139247): Allocate these dynamically when possible. + int32_t per_channel_output_multiplier[kMaxChannels]; + int32_t per_channel_output_shift[kMaxChannels]; + + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; +}; + +inline PaddingType RuntimePaddingType(TfLitePadding padding) { + switch (padding) { + case TfLitePadding::kTfLitePaddingSame: + return PaddingType::kSame; + case TfLitePadding::kTfLitePaddingValid: + return PaddingType::kValid; + case TfLitePadding::kTfLitePaddingUnknown: + default: + return PaddingType::kNone; + } +} + +TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, int width, int height, + int filter_width, int filter_height, int out_width, + int out_height, const TfLiteType data_type, + OpData* data) { + bool has_bias = node->inputs->size == 3; + // Check number of inputs/outputs + TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + + // Matching GetWindowedOutputSize in TensorFlow. + auto padding = params->padding; + data->padding = ComputePaddingHeightWidth( + params->stride_height, params->stride_width, + params->dilation_height_factor, params->dilation_width_factor, height, + width, filter_height, filter_width, padding, &out_height, &out_width); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. 
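+  // The 'real multiplier' mentioned in the OpData comment is
+  // input_scale * filter_scale / output_scale, a small positive float.
+  // QuantizeMultiplier (invoked inside PopulateConvolutionQuantizationParams
+  // below) normalizes it into [0.5, 1) and stores it as a Q31 fixed-point
+  // multiplier plus a power-of-two shift; e.g. a real multiplier of 0.75 is
+  // stored as output_multiplier = round(0.75 * 2^31) with output_shift = 0.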
+ if (data_type != kTfLiteFloat32) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = + GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + int output_channels = filter->dims->data[kConvQuantizedDimension]; + + TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams( + context, input, filter, bias, output, params->activation, + &data->output_multiplier, &data->output_shift, + &data->output_activation_min, &data->output_activation_max, + data->per_channel_output_multiplier, + reinterpret_cast(data->per_channel_output_shift), + output_channels)); + } + return kTfLiteOk; +} + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return nullptr; +} + +void Free(TfLiteContext* context, void* buffer) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* im2col, + TfLiteTensor* hwcn_weights, TfLiteTensor* output) { + const int32_t input_offset = -input->params.zero_point; + const int32_t filter_offset = -filter->params.zero_point; + const int32_t output_offset = output->params.zero_point; + + if((params->dilation_width_factor == 1) && (params->dilation_height_factor == 1)) { + const uint8 *input_data, *filter_data; + const int32_t *bias_data; + uint8 *output_data; + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& filter_shape = GetTensorShape(filter); + const RuntimeShape& output_shape = GetTensorShape(output); + const RuntimeShape& bias_shape = GetTensorShape(bias); + + input_data = GetTensorData(input); + filter_data = GetTensorData(filter); + bias_data = GetTensorData(bias); + output_data = GetTensorData(output); + + const int stride_width = params->stride_width; + const int stride_height = params->stride_height; + const int dilation_width_factor = 1; + const int dilation_height_factor = 1; + const int pad_width = data->padding.width; + const int pad_height = data->padding.height; + const int32 output_activation_min = data->output_activation_min; + const int32 output_activation_max = data->output_activation_max; + const int32 output_multiplier = data->output_multiplier; + const int output_shift = -data->output_shift; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3); + const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3); + if (bias_data) { + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + } + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int filter_depth = filter_shape.Dims(3); + + int err, output_data_format=0; + void *p_scratch; + uint8 *p_filter, *p_out_scratch; + // Calculate filter_depth_padded as next 
near multiple of 4 + int filter_depth_padded = (filter_depth + 3) & (~3); + int out_length = output_height * output_width * output_depth; + int required_scratch, input_precision = PREC_ASYM8; + int h,w,c; + + required_scratch = xa_nn_conv2d_std_getsize( + input_height, + input_depth, + filter_height, + filter_width, + stride_height, + pad_height, + output_height, + input_precision); + + if(required_scratch <= 0){ + TF_LITE_KERNEL_LOG( + context, "conv2d_std_asym8: xa_nn_conv2d_std_getsize failed"); + return kTfLiteError; + } + + ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM; + p_scratch = xtensa_nnlib_scratch_buf; + + p_filter = (uint8 *)p_scratch; + p_out_scratch = (p_filter + ALIGNED_SIZE((sizeof(uint8_t)*filter_height*filter_width*filter_depth_padded*output_depth), 8)); + required_scratch += ALIGNED_SIZE((sizeof(uint8_t)*filter_height*filter_width*filter_depth_padded*output_depth), 8); + p_scratch = (uint8 *)(p_out_scratch + ALIGNED_SIZE(sizeof(uint8_t)*out_length, 8)); + required_scratch += ALIGNED_SIZE(sizeof(uint8_t)*out_length, 8); + + if(required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE){ + TF_LITE_KERNEL_LOG(context, "conv2d_std_asym8: insufficient scratch memory"); + return kTfLiteError; + } + + // Padding filter coefficients depthwise + for(h=0; hpadding); + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = data->output_multiplier; + op_params.output_shift = -data->output_shift; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + reference_ops::Conv(op_params, GetTensorShape(input), + GetTensorData(input), GetTensorShape(filter), + GetTensorData(filter), GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + GetTensorData(output), GetTensorShape(im2col), + GetTensorData(im2col), nullptr); + } + return kTfLiteOk; +} + +void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, OpData* data, + const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output, + TfLiteTensor* im2col) { + ConvParams op_params; + op_params.input_offset = -input->params.zero_point; + op_params.output_offset = output->params.zero_point; + op_params.stride_height = params->stride_height; + op_params.stride_width = params->stride_width; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.padding_values.height = data->padding.height; + op_params.padding_values.width = data->padding.width; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + + reference_integer_ops::ConvPerChannel( + op_params, data->per_channel_output_multiplier, + data->per_channel_output_shift, GetTensorShape(input), + GetTensorData(input), GetTensorShape(filter), + GetTensorData(filter), GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + GetTensorData(output)); +} + +TfLiteStatus EvalFloat(TfLiteContext* context, 
TfLiteNode* node, + TfLiteConvParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* im2col, + TfLiteTensor* hwcn_weights, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); + + if((params->dilation_width_factor == 1) && (params->dilation_height_factor == 1)) { + const float *input_data, *filter_data; + const float *bias_data; + float *output_data; + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& filter_shape = GetTensorShape(filter); + const RuntimeShape& output_shape = GetTensorShape(output); + const RuntimeShape& bias_shape = GetTensorShape(bias); + + input_data = GetTensorData(input); + filter_data = GetTensorData(filter); + bias_data = GetTensorData(bias); + output_data = GetTensorData(output); + + const int stride_width = params->stride_width; + const int stride_height = params->stride_height; + const int dilation_width_factor = 1; + const int dilation_height_factor = 1; + const int pad_width = data->padding.width; + const int pad_height = data->padding.height; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3); + const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3); + if (bias_data) { + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + } + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int filter_depth = filter_shape.Dims(3); + int err, output_data_format=0; + void *p_scratch; + float *p_filter, *p_out_scratch; + // Calculate filter_depth_padded as next near multiple of 2 + int filter_depth_padded = (filter_depth + 1) & (~1); + int out_length = output_height * output_width * output_depth; + int required_scratch, input_precision = PREC_F32; + int h,w,c; + + required_scratch = xa_nn_conv2d_std_getsize( + input_height, + input_depth, + filter_height, + filter_width, + stride_height, + pad_height, + output_height, + input_precision); + + if(required_scratch <= 0){ + TF_LITE_KERNEL_LOG( + context, "conv2d_std_f32: xa_nn_conv2d_std_getsize failed"); + return kTfLiteError; + } + + ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM; + p_scratch = xtensa_nnlib_scratch_buf; + + p_filter = (float *)p_scratch; + p_out_scratch = (float *)((uint8_t *)p_filter + ALIGNED_SIZE((sizeof(float)*filter_height*filter_width*filter_depth_padded*output_depth), 8)); + required_scratch += ALIGNED_SIZE((sizeof(float)*filter_height*filter_width*filter_depth_padded*output_depth), 8); + p_scratch = (float *)((uint8_t *)p_out_scratch + ALIGNED_SIZE(sizeof(float)*out_length, 8)); + required_scratch += ALIGNED_SIZE(sizeof(float)*out_length, 8); + + if(required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE){ + TF_LITE_KERNEL_LOG(context, "conv2d_std_f32: insufficient scratch memory"); + return kTfLiteError; + } + + // Padding filter coefficients depthwise + for(h=0; hpadding); + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = 
data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + + reference_ops::Conv(op_params, GetTensorShape(input), + GetTensorData(input), GetTensorShape(filter), + GetTensorData(filter), GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + GetTensorData(output), GetTensorShape(im2col), + GetTensorData(im2col)); + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + + int input_width = input->dims->data[2]; + int input_height = input->dims->data[1]; + int filter_width = filter->dims->data[2]; + int filter_height = filter->dims->data[1]; + int output_width = output->dims->data[2]; + int output_height = output->dims->data[1]; + + OpData data; + + // All per-channel quantized tensors need valid zero point and scale arrays. + if (input->type == kTfLiteInt8) { + TF_LITE_ENSURE_EQ(context, filter->quantization.type, + kTfLiteAffineQuantization); + + const auto* affine_quantization = + reinterpret_cast( + filter->quantization.params); + TF_LITE_ENSURE(context, affine_quantization); + TF_LITE_ENSURE(context, affine_quantization->scale); + TF_LITE_ENSURE(context, affine_quantization->zero_point); + + TF_LITE_ENSURE(context, + affine_quantization->scale->size == 1 || + affine_quantization->scale->size == + filter->dims->data[kConvQuantizedDimension]); + TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size, + affine_quantization->zero_point->size); + } + + TF_LITE_ENSURE_STATUS(CalculateOpData( + context, node, params, input_width, input_height, filter_width, + filter_height, output_width, output_height, input->type, &data)); + + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + EvalFloat(context, node, params, &data, input, filter, bias, nullptr, + nullptr, output); + break; + case kTfLiteInt8: + EvalQuantizedPerChannel(context, node, params, &data, input, filter, bias, + output, nullptr); + break; + case kTfLiteUInt8: + EvalQuantized(context, node, params, &data, input, filter, bias, nullptr, + nullptr, output); + break; + default: + TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace conv + +TfLiteRegistration* Register_CONV_2D() { + static TfLiteRegistration r = {}; + r.prepare = conv::Prepare; + r.invoke = conv::Eval; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/depthwise_conv.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/depthwise_conv.cc new file mode 100755 index 00000000000..4209b4c8a0b --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa_hifi/depthwise_conv.cc @@ -0,0 +1,597 @@ +/****************************************************************************** +* Copyright (C) 2019 Cadence Design Systems, Inc. 
+* +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files (the +* "Software"), to use this Software with Cadence processor cores only and +* not with any other processors and platforms, subject to +* the following conditions: +* +* The above copyright notice and this permission notice shall be included +* in all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +******************************************************************************/ + +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h" +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/padding.h" + +#include "xtensa_tf_micro_common.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace depthwise_conv { +namespace { + +constexpr int kInputTensor = 0; +constexpr int kFilterTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; +constexpr int kMaxChannels = 256; + +// Depthwise conv is quantized along dimension 3: +// https://www.tensorflow.org/lite/performance/quantization_spec +constexpr int kDepthwiseConvQuantizedDimension = 3; + +struct OpData { + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + + // Per channel output multiplier and shift. + // (b/141139247): Allocate these dynamically when possible. + int32_t per_channel_output_multiplier[kMaxChannels]; + int32_t per_channel_output_shift[kMaxChannels]; + + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. 
+ int32_t output_activation_min; + int32_t output_activation_max; +}; + +TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, int width, + int height, int filter_width, int filter_height, + const TfLiteType data_type, OpData* data) { + bool has_bias = node->inputs->size == 3; + // Check number of inputs/outputs + TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + + int unused_output_height, unused_output_width; + data->padding = ComputePaddingHeightWidth( + params->stride_height, params->stride_width, 1, 1, height, width, + filter_height, filter_width, params->padding, &unused_output_height, + &unused_output_width); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. + if (data_type != kTfLiteFloat32) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = + GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension]; + + TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams( + context, input, filter, bias, output, params->activation, + &data->output_multiplier, &data->output_shift, + &data->output_activation_min, &data->output_activation_max, + data->per_channel_output_multiplier, + reinterpret_cast(data->per_channel_output_shift), num_channels)); + } + return kTfLiteOk; +} + +} // namespace + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return nullptr; +} + +void Free(TfLiteContext* context, void* buffer) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); + + if((params->dilation_width_factor == 1) && (params->dilation_height_factor == 1)) { + const float *input_data, *filter_data, *bias_data; + float *output_data; + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& filter_shape = GetTensorShape(filter); + const RuntimeShape& output_shape = GetTensorShape(output); + const RuntimeShape& bias_shape = GetTensorShape(bias); + + input_data = GetTensorData(input); + filter_data = GetTensorData(filter); + bias_data = GetTensorData(bias); + output_data = GetTensorData(output); + + const int stride_width = params->stride_width; + const int stride_height = params->stride_height; + const int dilation_width_factor = 1; + const int dilation_height_factor = 1; + //const int dilation_width_factor = params->dilation_width_factor;; + //const int dilation_height_factor = params->dilation_height_factor; + const int pad_width = data->padding.width; + const int pad_height = data->padding.height; + const int depth_multiplier = params->depth_multiplier; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + + const int batches = MatchingDim(input_shape, 0, 
output_shape, 0); + const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int input_depth = input_shape.Dims(3); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int filter_depth = filter_shape.Dims(3); + TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier); + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + + int32_t err, input_data_format = 0, output_data_format = 0; + void *p_scratch; + float *p_filter; + int filter_depth_padded, filter_size_padded, required_scratch; + int input_precision = PREC_F32; + int h, c, i; + + ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM; + p_scratch = xtensa_nnlib_scratch_buf; + + filter_depth_padded = (filter_depth + 1) & (~1); + filter_size_padded = filter_height * filter_width * filter_depth_padded; + + required_scratch = xa_nn_conv2d_depthwise_getsize( + input_height, + input_width, + input_depth, + filter_height, + filter_width, + depth_multiplier, + stride_width, + stride_height, + pad_width, + pad_height, + output_height, + output_width, + input_precision, + input_data_format); + + if(required_scratch <= 0) { + TF_LITE_KERNEL_LOG(context, "DepthwiseConvFloat: xa_nn_conv2d_depthwise_getsize failed"); + return kTfLiteError; + } + + required_scratch += ALIGNED_SIZE(sizeof(float)* filter_size_padded, 8); + if(required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) { + TF_LITE_KERNEL_LOG(context, "DepthwiseConvFloat: insufficient scratch memory"); + return kTfLiteError; + } + + p_filter = (float *)p_scratch; + p_scratch = (void *)((uint8_t *)p_filter + ALIGNED_SIZE(sizeof(float)* filter_size_padded, 8)); + + for(h = 0; h < filter_height * filter_width; h++) { + for(c = 0; c < filter_depth; c++) { + p_filter[h * filter_depth_padded + c] = filter_data[h * filter_depth + c]; + } + for(c = filter_depth; c < filter_depth_padded; c++) { + p_filter[h * filter_depth_padded + c] = 0; + } + } + + for(i = 0; i < batches; i++) { + err = xa_nn_conv2d_depthwise_f32(&output_data[i*output_height*output_width*output_depth], + p_filter, //filter_data, + &input_data[i*input_height*input_width*input_depth], + bias_data, + input_height, + input_width, + input_depth, + filter_height, + filter_width, + depth_multiplier, + stride_width, + stride_height, + pad_width, + pad_height, + output_height, + output_width, + input_data_format, + output_data_format, + p_scratch); + + CHECK_ERR_HIFI_NNLIB_KER(err, "DepthwiseConvFloat: xa_nn_conv2d_depthwise_f32 failed"); + } + + //pre loop for activation_min_max to handle alignment + int out_length = batches * output_height * output_width * output_depth; + uint32 p_unalign_val = (uint32)output_data, p_align_val; + p_align_val = (p_unalign_val + 7) & (~7); + + int pre_loop_count = p_align_val - p_unalign_val; + pre_loop_count = MIN(pre_loop_count, out_length); + + for(i = 0; i < pre_loop_count; i++) { + ACTIVATION_MIN_MAX(float, output_data[i], + output_data[i], output_activation_min, output_activation_max) + } + + out_length = out_length - pre_loop_count; + + if(out_length) { + err = xa_nn_vec_activation_min_max_f32_f32( + &output_data[i], + &output_data[i], + output_activation_min, + output_activation_max, + out_length); + + CHECK_ERR_HIFI_NNLIB_KER(err, "DepthwiseConvFloat: xa_nn_vec_activation_min_max_f32_f32 failed"); + } + } else { + tflite::DepthwiseParams 
op_params; + // Padding type is ignored, but still set. + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.depth_multiplier = params->depth_multiplier; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + + tflite::reference_ops::DepthwiseConv( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), + GetTensorData(output)); + } + return kTfLiteOk; +} + +void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + DepthwiseParams op_params; + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.depth_multiplier = params->depth_multiplier; + op_params.input_offset = -input->params.zero_point; + op_params.weights_offset = 0; + op_params.output_offset = output->params.zero_point; + // (b/130439627): Use calculated value for clamping. 
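+  // Until then, the widest representable activation range is passed through,
+  // so the fused-activation bounds computed in CalculateOpData
+  // (data->output_activation_min/max) are not applied on this path.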
+ op_params.quantized_activation_min = std::numeric_limits::min(); + op_params.quantized_activation_max = std::numeric_limits::max(); + + reference_integer_ops::DepthwiseConvPerChannel( + op_params, data->per_channel_output_multiplier, + data->per_channel_output_shift, GetTensorShape(input), + GetTensorData(input), GetTensorShape(filter), + GetTensorData(filter), GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + GetTensorData(output)); +} + +TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + const int32_t input_offset = -input->params.zero_point; + const int32_t filter_offset = -filter->params.zero_point; + const int32_t output_offset = output->params.zero_point; + + if((params->dilation_width_factor == 1) && (params->dilation_height_factor == 1)) { + const uint8 *input_data, *filter_data; + const int32_t *bias_data; + uint8 *output_data; + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& filter_shape = GetTensorShape(filter); + const RuntimeShape& output_shape = GetTensorShape(output); + const RuntimeShape& bias_shape = GetTensorShape(bias); + + input_data = GetTensorData(input); + filter_data = GetTensorData(filter); + bias_data = GetTensorData(bias); + output_data = GetTensorData(output); + + const int stride_width = params->stride_width; + const int stride_height = params->stride_height; + const int dilation_width_factor = 1 ; + const int dilation_height_factor = 1; + //const int dilation_width_factor = params->dilation_width_factor; + //const int dilation_height_factor = params->dilation_height_factor; + const int pad_width = data->padding.width; + const int pad_height = data->padding.height; + const int depth_multiplier = params->depth_multiplier; + const int32 output_activation_min = data->output_activation_min; + const int32 output_activation_max = data->output_activation_max; + const int32 output_multiplier = data->output_multiplier; + // Legacy ops used mixed left and right shifts. Now all are +ve-means-left. 
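+    // data->output_shift is stored with the opposite (right-shift-positive)
+    // convention, so it is negated here before being handed to the NN
+    // library kernel.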
+ const int output_shift = -data->output_shift; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int input_depth = input_shape.Dims(3); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int filter_depth = filter_shape.Dims(3); + TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier); + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + + int32_t err, i, input_data_format = 0, output_data_format = 0; + void *p_scratch; + uint8 *p_filter; + int filter_depth_padded, filter_size_padded, required_scratch; + int input_precision = PREC_ASYM8; + int h,c; + + ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM; + p_scratch = xtensa_nnlib_scratch_buf; + + required_scratch = xa_nn_conv2d_depthwise_getsize( + input_height, + input_width, + input_depth, + filter_height, + filter_width, + depth_multiplier, + stride_width, + stride_height, + pad_width, + pad_height, + output_height, + output_width, + input_precision, + input_data_format); + + if(required_scratch <= 0) { + TF_LITE_KERNEL_LOG(context, "DepthwiseConvAsym8: xa_nn_conv2d_depthwise_getsize failed"); + return kTfLiteError; + } + + filter_depth_padded = (filter_depth + 3) & (~3); + filter_size_padded = filter_height * filter_width * filter_depth_padded; + required_scratch += ALIGNED_SIZE(sizeof(uint8_t)* filter_size_padded, 8); + + if(required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE) { + TF_LITE_KERNEL_LOG(context, "DepthwiseConvAsym8: insufficient scratch memory"); + return kTfLiteError; + } + + p_filter = (uint8 *)p_scratch; + p_scratch = (void *)(p_filter + ALIGNED_SIZE(sizeof(uint8_t) * filter_size_padded, 8)); + + for(h = 0; h < filter_height * filter_width; h++) { + for(c = 0; c < filter_depth; c++) { + p_filter[h * filter_depth_padded + c] = filter_data[h * filter_depth + c]; + } + for(c = filter_depth; c < filter_depth_padded; c++) { + p_filter[h * filter_depth_padded + c] = -filter_offset; + } + } + + for(i = 0; i < batches; i++) { + err = xa_nn_conv2d_depthwise_asym8xasym8(&output_data[i*output_height*output_width*output_depth], + p_filter, //filter_data, + &input_data[i*input_height*input_width*input_depth], + bias_data, + input_height, + input_width, + input_depth, + filter_height, + filter_width, + depth_multiplier, + stride_width, + stride_height, + pad_width, + pad_height, + output_height, + output_width, + input_offset, + filter_offset, + output_multiplier, + output_shift, + output_offset, + input_data_format, + output_data_format, + p_scratch); + + CHECK_ERR_HIFI_NNLIB_KER(err, "DepthwiseConvAsym8: xa_nn_conv2d_depthwise_asym8xasym8 failed"); + } + + //pre loop for activation_min_max to handle alignment + int out_length = batches*output_height*output_width*output_depth; + uint32 p_unalign_val = (uint32)output_data, p_align_val; + p_align_val = (p_unalign_val + 7) & (~7); + + int pre_loop_count = p_align_val - p_unalign_val; + pre_loop_count = MIN(pre_loop_count, out_length); + + for(i = 0; i < pre_loop_count; i++) { + 
ACTIVATION_MIN_MAX_ASYM8(output_data[i], + output_data[i], output_activation_min, output_activation_max) + } + + out_length = out_length - pre_loop_count; + + if(out_length > 0){ + err = xa_nn_vec_activation_min_max_asym8_asym8( + &output_data[i], + &output_data[i], + output_activation_min, + output_activation_max, + out_length); + + CHECK_ERR_HIFI_NNLIB_KER(err, "DepthwiseConvAsym8: xa_nn_vec_activation_min_max_asym8_asym8 failed"); + } + } else { + tflite::DepthwiseParams op_params; + // Padding type is ignored, but still set. + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.depth_multiplier = params->depth_multiplier; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = data->output_multiplier; + // Legacy ops used mixed left and right shifts. Now all are +ve-means-left. + op_params.output_shift = -data->output_shift; + + tflite::reference_ops::DepthwiseConv( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output)); + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = + (NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr; + + const TfLiteType data_type = input->type; + int width = SizeOfDimension(input, 2); + int height = SizeOfDimension(input, 1); + int filter_width = SizeOfDimension(filter, 2); + int filter_height = SizeOfDimension(filter, 1); + + OpData data; + + // All per-channel quantized tensors need valid zero point and scale arrays. + if (input->type == kTfLiteInt8) { + TF_LITE_ENSURE_EQ(context, filter->quantization.type, + kTfLiteAffineQuantization); + + const auto* affine_quantization = + reinterpret_cast( + filter->quantization.params); + TF_LITE_ENSURE(context, affine_quantization); + TF_LITE_ENSURE(context, affine_quantization->scale); + TF_LITE_ENSURE(context, affine_quantization->zero_point); + TF_LITE_ENSURE( + context, affine_quantization->scale->size == 1 || + affine_quantization->scale->size == + filter->dims->data[kDepthwiseConvQuantizedDimension]); + TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size, + affine_quantization->zero_point->size); + } + + TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, width, height, + filter_width, filter_height, data_type, + &data)); + + // (aselle): Consider whether float conv and quantized conv should be + // separate ops to avoid dispatch overhead here. + switch (input->type) { // Already know in/out types are same. 
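+    // Dispatch summary: float32 and uint8 take the HiFi-optimized paths
+    // above when dilation == 1 (reference kernels otherwise); int8
+    // per-channel always falls back to the reference kernel.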
+ case kTfLiteFloat32: + EvalFloat(context, node, params, &data, input, filter, bias, output); + break; + case kTfLiteInt8: + EvalQuantizedPerChannel(context, node, params, &data, input, filter, bias, + output); + break; + case kTfLiteUInt8: + EvalQuantized(context, node, params, &data, input, filter, bias, output); + break; + default: + TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace depthwise_conv + +TfLiteRegistration* Register_DEPTHWISE_CONV_2D() { + static TfLiteRegistration r = {}; + r.init = depthwise_conv::Init; + r.free = depthwise_conv::Free; + r.prepare = depthwise_conv::Prepare; + r.invoke = depthwise_conv::Eval; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/floor.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/floor.cc new file mode 100644 index 00000000000..da77785ab25 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa_hifi/floor.cc @@ -0,0 +1,81 @@ +/****************************************************************************** +* Copyright (C) 2019 Cadence Design Systems, Inc. +* +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files (the +* "Software"), to use this Software with Cadence processor cores only and +* not with any other processors and platforms, subject to +* the following conditions: +* +* The above copyright notice and this permission notice shall be included +* in all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +******************************************************************************/ + +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/reference/floor.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +#include "xtensa_tf_micro_common.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace floor { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + int err; + const float *inp_data_ptr; + float *out_data_ptr; + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& output_shape = GetTensorShape(output); + const int flat_size = MatchingFlatSize(input_shape, output_shape); + + inp_data_ptr = GetTensorData(input); + out_data_ptr = GetTensorData(output); + + err = xa_nn_elm_floor_f32_f32(out_data_ptr, inp_data_ptr, flat_size); + + CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_elm_floor_f32_f32 failed"); + return kTfLiteOk; +} +} // namespace floor + +TfLiteRegistration* Register_FLOOR() { + static TfLiteRegistration r = {}; + r.invoke = floor::Eval; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/fully_connected.cc new file mode 100644 index 00000000000..56fe5d4abdd --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa_hifi/fully_connected.cc @@ -0,0 +1,280 @@ +/****************************************************************************** +* Copyright (C) 2019 Cadence Design Systems, Inc. +* +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files (the +* "Software"), to use this Software with Cadence processor cores only and +* not with any other processors and platforms, subject to +* the following conditions: +* +* The above copyright notice and this permission notice shall be included +* in all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +******************************************************************************/ + +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +#include "xtensa_tf_micro_common.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace fully_connected { +namespace { + +struct OpData { + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; + // The index of the temporary tensor where the quantized inputs are cached. + int input_quantized_index; +}; + +constexpr int kInputTensor = 0; +constexpr int kWeightsTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; + +TfLiteStatus CalculateOpData(TfLiteContext* context, + TfLiteFullyConnectedParams* params, + TfLiteType data_type, const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output, + OpData* data) { + TfLiteStatus status = kTfLiteOk; + if (data_type != kTfLiteFloat32) { + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + int exponent; + QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); + data->output_shift = -exponent; + TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized( + context, params->activation, output, &data->output_activation_min, + &data->output_activation_max)); + } + return status; +} + +} // namespace + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return nullptr; +} + +void Free(TfLiteContext* context, void* buffer) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + FullyConnectedParams op_params; + op_params.input_offset = -input->params.zero_point; + op_params.weights_offset = -filter->params.zero_point; + op_params.output_offset = output->params.zero_point; + op_params.output_multiplier = data->output_multiplier; + // (b/138810107): Figure out whether output shift should be inverted + op_params.output_shift = -data->output_shift; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + + reference_integer_ops::FullyConnected( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output)); + return kTfLiteOk; +} + +TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + const TfLiteTensor* input, + const 
TfLiteTensor* filter, const TfLiteTensor* bias, + TfLiteTensor* output) { + const int32_t input_offset = -input->params.zero_point; + const int32_t filter_offset = -filter->params.zero_point; + const int32_t output_offset = output->params.zero_point; + + tflite::FullyConnectedParams op_params; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = data->output_multiplier; + // Legacy ops used mixed left and right shifts. Now all are +ve-means-left. + op_params.output_shift = -data->output_shift; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + +#define TF_LITE_FULLY_CONNECTED(output_data_type) \ + reference_ops::FullyConnected( \ + op_params, GetTensorShape(input), GetTensorData(input), \ + GetTensorShape(filter), GetTensorData(filter), \ + GetTensorShape(bias), GetTensorData(bias), \ + GetTensorShape(output), GetTensorData(output)) + switch (output->type) { + case kTfLiteUInt8: + { + int ret, b, weight_depth, out_depth, batches; + uint8_t * p_out = GetTensorData(output); + weight_depth = GetTensorShape(filter).Dims(GetTensorShape(filter).DimensionsCount()-1); + out_depth = GetTensorShape(output).Dims(GetTensorShape(output).DimensionsCount()-1); + batches = FlatSizeSkipDim(GetTensorShape(output), GetTensorShape(output).DimensionsCount()-1); + for(b = 0; b < batches; b++) { + ret = xa_nn_fully_connected_asym8xasym8_asym8( + (GetTensorData(output) + b*out_depth), + GetTensorData(filter), + (GetTensorData(input) + b*weight_depth), + GetTensorData(bias), + weight_depth, + out_depth, + op_params.input_offset, + op_params.weights_offset, + op_params.output_multiplier, + op_params.output_shift, + op_params.output_offset); + CHECK_ERR_HIFI_NNLIB_KER(ret, "xa_nn_fully_connected_asym8xasym8_asym8 failed"); + } + for(int i =0; i < batches * out_depth; i++) + { + ACTIVATION_MIN_MAX_ASYM8(p_out[i], + p_out[i], data->output_activation_min, data->output_activation_max) + } + break; + } + case kTfLiteInt16: + TF_LITE_FULLY_CONNECTED(int16_t); + break; + default: + TF_LITE_KERNEL_LOG( + context, + "Quantized FullyConnected expects output data type uint8 or int16"); + return kTfLiteError; + } + + return kTfLiteOk; +} + +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); + tflite::FullyConnectedParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + int ret, b, weight_depth, out_depth, batches; + weight_depth = GetTensorShape(filter).Dims(GetTensorShape(filter).DimensionsCount()-1); + out_depth = GetTensorShape(output).Dims(GetTensorShape(output).DimensionsCount()-1); + batches = FlatSizeSkipDim(GetTensorShape(output), GetTensorShape(output).DimensionsCount()-1); + + for(b = 0; b < batches; b++) { + ret = xa_nn_fully_connected_f32( + (GetTensorData(output) + b*out_depth), + GetTensorData(filter), + (GetTensorData(input) + b*weight_depth), + GetTensorData(bias), + weight_depth, + out_depth + ); + CHECK_ERR_HIFI_NNLIB_KER(ret, "xa_nn_fully_connected_f32 failed."); + } + float * p_out = GetTensorData(output); + for(int i 
=0; i < batches * out_depth; i++) + { + ACTIVATION_MIN_MAX(float, p_out[i], + p_out[i], output_activation_min, output_activation_max) + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TfLiteType data_type = input->type; + OpData local_data_object; + OpData* data = &local_data_object; + TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input, + filter, bias, output, data)); + + switch (filter->type) { // Already know in/out types are same. + case kTfLiteFloat32: + return EvalFloat(context, node, params, data, input, filter, bias, + output); + case kTfLiteInt8: + return EvalQuantizedInt8(context, node, params, data, input, filter, bias, + output); + + case kTfLiteUInt8: + return EvalQuantized(context, node, params, data, input, filter, bias, + output); + + default: + TF_LITE_KERNEL_LOG(context, "Type %d not currently supported.", + filter->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace fully_connected + +TfLiteRegistration* Register_FULLY_CONNECTED() { + static TfLiteRegistration r = {}; + r.init = fully_connected::Init; + r.free = fully_connected::Free; + r.prepare = fully_connected::Prepare; + r.invoke = fully_connected::Eval; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/logistic.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/logistic.cc new file mode 100644 index 00000000000..75bb6414ee5 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa_hifi/logistic.cc @@ -0,0 +1,125 @@ +/****************************************************************************** +* Copyright (C) 2019 Cadence Design Systems, Inc. +* +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files (the +* "Software"), to use this Software with Cadence processor cores only and +* not with any other processors and platforms, subject to +* the following conditions: +* +* The above copyright notice and this permission notice shall be included +* in all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +******************************************************************************/ + +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/kernels/internal/reference/logistic.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "xtensa_tf_micro_common.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace activations { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (input->type == kTfLiteFloat32) { + switch (output->type) { + case kTfLiteFloat32: { + int err; + const float *inp_data_ptr; + float *out_data_ptr; + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& output_shape = GetTensorShape(output); + const int flat_size = MatchingFlatSize(input_shape, output_shape); + + inp_data_ptr = GetTensorData(input); + out_data_ptr = GetTensorData(output); + + err = xa_nn_vec_sigmoid_f32_f32(out_data_ptr, inp_data_ptr,flat_size); + + CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_sigmoid_f32_f32 failed"); + return kTfLiteOk; + } + default: + TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", + TfLiteTypeGetName(input->type), + TfLiteTypeGetName(output->type)); + return kTfLiteError; + } + } else if (input->type == kTfLiteInt8) { + switch (output->type) { + case kTfLiteInt8: { + reference_ops::Logistic( + GetTensorShape(input), GetTensorData(input), + input->params.scale, input->params.zero_point, + GetTensorShape(output), GetTensorData(output), + output->params.scale, output->params.zero_point); + return kTfLiteOk; + } + default: + TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", + TfLiteTypeGetName(input->type), + TfLiteTypeGetName(output->type)); + return kTfLiteError; + } + } else { + // (b/141211002): Also support other data types once we have supported + // temporary tensors in TFLM. 
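+  // For reference, the float branch earlier in this function is numerically
+  // the standard logistic function; a scalar loop equivalent to the
+  // xa_nn_vec_sigmoid_f32_f32 call would be (an illustrative sketch only;
+  // the vectorized NNLib kernel is what actually runs):
+  //
+  //   for (int i = 0; i < flat_size; ++i) {
+  //     out_data_ptr[i] = 1.0f / (1.0f + expf(-inp_data_ptr[i]));
+  //   }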
+ TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", + TfLiteTypeGetName(input->type), + TfLiteTypeGetName(output->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace activations + +TfLiteRegistration* Register_LOGISTIC() { + static TfLiteRegistration r = {}; + r.prepare = activations::Prepare; + r.invoke = activations::Eval; + return &r; +} +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/pooling.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/pooling.cc new file mode 100755 index 00000000000..6288cce1618 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa_hifi/pooling.cc @@ -0,0 +1,662 @@ +/****************************************************************************** +* Copyright (C) 2019 Cadence Design Systems, Inc. +* +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files (the +* "Software"), to use this Software with Cadence processor cores only and +* not with any other processors and platforms, subject to +* the following conditions: +* +* The above copyright notice and this permission notice shall be included +* in all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +******************************************************************************/ + +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/kernels/internal/reference/pooling.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/padding.h" + +#include "xtensa_tf_micro_common.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace pooling { + +namespace { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +struct OpData { + TfLitePaddingValues padding; +}; + +TfLiteStatus CalculateOpData(const TfLiteContext* context, + const TfLitePoolParams* params, + const TfLiteTensor* input, + const TfLiteTensor* output, OpData* data) { + // input: batch, height, width, channel + int height = SizeOfDimension(input, 1); + int width = SizeOfDimension(input, 2); + + int out_height, out_width; + + data->padding = ComputePaddingHeightWidth( + params->stride_height, params->stride_width, + /*dilation_rate_height=*/1, + /*dilation_rate_width=*/1, height, width, params->filter_height, + params->filter_width, params->padding, &out_height, &out_width); + + return kTfLiteOk; +} + +TfLiteStatus AverageEvalFloat(TfLiteContext* context, const TfLiteNode* node, + const TfLitePoolParams* params, const OpData* data, + const TfLiteTensor* input, TfLiteTensor* output) { + float activation_min, activation_max; + CalculateActivationRange(params->activation, &activation_min, + &activation_max); + + const int stride_height = params->stride_height; + const int stride_width = params->stride_width; + const int pad_width = data->padding.width; + const int pad_height = data->padding.height; + const int kernel_height = params->filter_height; + const int kernel_width = params->filter_width; + + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& output_shape = GetTensorShape(output); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + + const float *inp_data_ptr; + float *out_data_ptr; + int inp_data_format = 0, out_data_format = 0, out_length; + int inp_precision = PREC_F32, out_precision = PREC_F32; + void *p_scratch; + int err, required_scratch=0; + + ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM; + p_scratch = (void *)xtensa_nnlib_scratch_buf; + + required_scratch = xa_nn_avgpool_getsize( + depth, + inp_precision, + out_precision, + input_height, + input_width, + kernel_height, + kernel_width, + stride_width, //x_stride, + stride_height, //y_stride, + pad_width, //x_padding, + pad_height, //y_padding, + output_height, + output_width, + inp_data_format, + out_data_format); + + if(required_scratch <= 0){ + TF_LITE_KERNEL_LOG( + context, "AveragepoolFloat: xa_nn_avgpool_getsize failed"); + return kTfLiteError; + } + + if(required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE){ + TF_LITE_KERNEL_LOG( + context, "AveragepoolFloat: insufficient scratch memory"); + return kTfLiteError; + } + + inp_data_ptr = GetTensorData(input); + out_data_ptr = GetTensorData(output); + + for (int batch = 0; batch < batches; ++batch) { 
+ err = xa_nn_avgpool_f32(&out_data_ptr[output_height*output_width*depth*batch], + &inp_data_ptr[output_height*output_width*depth*batch], + input_height, + input_width, + depth, + kernel_height, + kernel_width, + stride_width, + stride_height, + pad_width, + pad_height, + output_height, + output_width, + inp_data_format, + out_data_format, + p_scratch); + + CHECK_ERR_HIFI_NNLIB_KER(err, "AveragepoolFloat: xa_nn_avgpool_f32 failed"); + } + + out_length = batches*output_height*output_width*depth; + uint32 p_unalign_val = (uint32)out_data_ptr, p_align_val; + p_align_val = (p_unalign_val + 7) & (~7); + + //pre loop for activation_min_max + int pre_loop_count = p_align_val - p_unalign_val; + pre_loop_count = MIN(pre_loop_count, out_length); + + for(int i = 0; i < pre_loop_count; i++) { + ACTIVATION_MIN_MAX(float, out_data_ptr[i], + out_data_ptr[i], activation_min, activation_max) + } + + out_length = out_length - pre_loop_count; + + if(out_length){ + err = xa_nn_vec_activation_min_max_f32_f32( + out_data_ptr, + out_data_ptr, + activation_min, + activation_max, + out_length); + + CHECK_ERR_HIFI_NNLIB_KER(err, "AveragepoolFloat: xa_nn_vec_activation_min_max_f32_f32 failed"); + } + return kTfLiteOk; +} + +TfLiteStatus AverageEvalQuantized(TfLiteContext* context, const TfLiteNode* node, + const TfLitePoolParams* params, const OpData* data, + const TfLiteTensor* input, TfLiteTensor* output) { + TFLITE_DCHECK(input->type == kTfLiteUInt8 || input->type == kTfLiteInt8); + + int32_t activation_min, activation_max; + (void)CalculateActivationRangeQuantized(context, params->activation, output, + &activation_min, &activation_max); + + if (input->type == kTfLiteUInt8) { + const int stride_height = params->stride_height; + const int stride_width = params->stride_width; + const int pad_width = data->padding.width; + const int pad_height = data->padding.height; + const int kernel_height = params->filter_height; + const int kernel_width = params->filter_width; + + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& output_shape = GetTensorShape(output); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + + const uint8 *inp_data_ptr; + uint8 *out_data_ptr; + int inp_data_format = 0, out_data_format = 0, out_length; + int inp_precision = PREC_ASYM8, out_precision = PREC_ASYM8; + void *p_scratch; + int err, required_scratch=0; + + ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM; + p_scratch = (void *)xtensa_nnlib_scratch_buf; + + required_scratch = xa_nn_avgpool_getsize( + depth, + inp_precision, + out_precision, + input_height, + input_width, + kernel_height, + kernel_width, + stride_width, //x_stride, + stride_height, //y_stride, + pad_width, //x_padding, + pad_height, //y_padding, + output_height, + output_width, + inp_data_format, + out_data_format); + + if(required_scratch <= 0){ + TF_LITE_KERNEL_LOG( + context, "AveragepoolAsym8: xa_nn_avgpool_getsize failed"); + return kTfLiteError; + } + + if(required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE){ + TF_LITE_KERNEL_LOG( + context, "AveragepoolAsym8: insufficient scratch memory"); + return kTfLiteError; + } + + inp_data_ptr = GetTensorData(input); + out_data_ptr = 
GetTensorData(output); + + for (int batch = 0; batch < batches; ++batch) { + err = xa_nn_avgpool_asym8(&out_data_ptr[output_height*output_width*depth*batch], + &inp_data_ptr[output_height*output_width*depth*batch], + input_height, + input_width, + depth, + kernel_height, + kernel_width, + stride_width, + stride_height, + pad_width, + pad_height, + output_height, + output_width, + inp_data_format, + out_data_format, + p_scratch); + + CHECK_ERR_HIFI_NNLIB_KER(err, "AveragepoolAsym8: xa_nn_avgpool_asym8 failed"); + } + + out_length = batches*output_height*output_width*depth; + uint32 p_unalign_val = (uint32)out_data_ptr, p_align_val; + p_align_val = (p_unalign_val + 7) & (~7); + + //pre loop for activation_min_max + int pre_loop_count = p_align_val - p_unalign_val; + pre_loop_count = MIN(pre_loop_count, out_length); + + for(int i = 0; i < pre_loop_count; i++) { + ACTIVATION_MIN_MAX_ASYM8(out_data_ptr[i], + out_data_ptr[i], activation_min, activation_max) + } + + out_length = out_length - pre_loop_count; + + if(out_length > 0){ + err = xa_nn_vec_activation_min_max_asym8_asym8( + out_data_ptr, + out_data_ptr, + activation_min, + activation_max, + out_length); + + CHECK_ERR_HIFI_NNLIB_KER(err, "AveragepoolAsym8: xa_nn_vec_activation_min_max_asym8_asym8 failed"); + } + } else { + PoolParams op_params; + op_params.stride_height = params->stride_height; + op_params.stride_width = params->stride_width; + op_params.filter_height = params->filter_height; + op_params.filter_width = params->filter_width; + op_params.padding_values.height = data->padding.height; + op_params.padding_values.width = data->padding.width; + op_params.quantized_activation_min = activation_min; + op_params.quantized_activation_max = activation_max; + reference_integer_ops::AveragePool( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(output), GetTensorData(output)); + } + return kTfLiteOk; +} + +TfLiteStatus MaxEvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + const TfLiteTensor* input, TfLiteTensor* output) { + float activation_min, activation_max; + CalculateActivationRange(params->activation, &activation_min, + &activation_max); + + const int stride_height = params->stride_height; + const int stride_width = params->stride_width; + const int pad_width = data->padding.width; + const int pad_height = data->padding.height; + const int kernel_height = params->filter_height; + const int kernel_width = params->filter_width; + + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& output_shape = GetTensorShape(output); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + + const float *inp_data_ptr; + float *out_data_ptr; + int inp_data_format = 0, out_data_format = 0, out_length; + int inp_precision = PREC_F32, out_precision = PREC_F32; + void *p_scratch; + int err, required_scratch=0; + + ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM; + p_scratch = (void *)xtensa_nnlib_scratch_buf; + + required_scratch = xa_nn_maxpool_getsize( + depth, + inp_precision, + out_precision, + input_height, + input_width, + kernel_height, + kernel_width, + stride_width, //x_stride, + 
stride_height, //y_stride, + pad_width, //x_padding, + pad_height, //y_padding, + output_height, + output_width, + inp_data_format, + out_data_format); + + if(required_scratch <= 0){ + TF_LITE_KERNEL_LOG( + context, "MaxpoolFloat: xa_nn_maxpool_getsize failed"); + return kTfLiteError; + } + + if(required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE){ + TF_LITE_KERNEL_LOG( + context, "MaxpoolFloat: insufficient scratch memory"); + return kTfLiteError; + } + + inp_data_ptr = GetTensorData(input); + out_data_ptr = GetTensorData(output); + + for (int batch = 0; batch < batches; ++batch) { + err = xa_nn_maxpool_f32(&out_data_ptr[output_height*output_width*depth*batch], + &inp_data_ptr[output_height*output_width*depth*batch], + input_height, + input_width, + depth, + kernel_height, + kernel_width, + stride_width, + stride_height, + pad_width, + pad_height, + output_height, + output_width, + inp_data_format, + out_data_format, + p_scratch); + + CHECK_ERR_HIFI_NNLIB_KER(err, "MaxpoolFloat: xa_nn_maxpool_f32 failed"); + } + + out_length = batches*output_height*output_width*depth; + uint32 p_unalign_val = (uint32)out_data_ptr, p_align_val; + p_align_val = (p_unalign_val + 7) & (~7); + + //pre loop for activation_min_max + int pre_loop_count = p_align_val - p_unalign_val; + pre_loop_count = MIN(pre_loop_count, out_length); + + for(int i = 0; i < pre_loop_count; i++) { + ACTIVATION_MIN_MAX(float, out_data_ptr[i], + out_data_ptr[i], activation_min, activation_max) + } + + out_length = out_length - pre_loop_count; + + if(out_length > 0){ + err = xa_nn_vec_activation_min_max_f32_f32( + out_data_ptr, + out_data_ptr, + activation_min, + activation_max, + out_length); + + CHECK_ERR_HIFI_NNLIB_KER(err, "MaxpoolFloat: xa_nn_vec_activation_min_max_f32_f32 failed"); + } + return kTfLiteOk; +} + +TfLiteStatus MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + const TfLiteTensor* input, TfLiteTensor* output) { + TFLITE_DCHECK(input->type == kTfLiteUInt8 || input->type == kTfLiteInt8); + + int32_t activation_min, activation_max; + (void)CalculateActivationRangeQuantized(context, params->activation, output, + &activation_min, &activation_max); + + if (input->type == kTfLiteUInt8) { + const int stride_height = params->stride_height; + const int stride_width = params->stride_width; + const int pad_width = data->padding.width; + const int pad_height = data->padding.height; + const int kernel_height = params->filter_height; + const int kernel_width = params->filter_width; + + const RuntimeShape& input_shape = GetTensorShape(input); + const RuntimeShape& output_shape = GetTensorShape(output); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + + const uint8 *inp_data_ptr; + uint8 *out_data_ptr; + int inp_data_format = 0, out_data_format = 0, out_length; + int inp_precision = PREC_ASYM8, out_precision = PREC_ASYM8; + void *p_scratch; + int err, required_scratch=0; + + ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM; + p_scratch = (void *)xtensa_nnlib_scratch_buf; + + required_scratch = xa_nn_maxpool_getsize( + depth, + inp_precision, + out_precision, + input_height, + input_width, + kernel_height, + 
kernel_width, + stride_width, //x_stride, + stride_height, //y_stride, + pad_width, //x_padding, + pad_height, //y_padding, + output_height, + output_width, + inp_data_format, + out_data_format); + + if(required_scratch <= 0){ + TF_LITE_KERNEL_LOG( + context, "MaxpoolAsym8: xa_nn_maxpool_getsize failed"); + return kTfLiteError; + } + + if(required_scratch > (int)XTENSA_NNLIB_MAX_SCRATCH_SIZE){ + TF_LITE_KERNEL_LOG( + context, "MaxpoolAsym8: insufficient scratch memory"); + return kTfLiteError; + } + + inp_data_ptr = GetTensorData(input); + out_data_ptr = GetTensorData(output); + + for (int batch = 0; batch < batches; ++batch) { + err = xa_nn_maxpool_asym8(&out_data_ptr[output_height*output_width*depth*batch], + &inp_data_ptr[output_height*output_width*depth*batch], + input_height, + input_width, + depth, + kernel_height, + kernel_width, + stride_width, + stride_height, + pad_width, + pad_height, + output_height, + output_width, + inp_data_format, + out_data_format, + p_scratch); + + CHECK_ERR_HIFI_NNLIB_KER(err, "MaxpoolAsym8: xa_nn_maxpool_asym8 failed"); + } + + out_length = batches*output_height*output_width*depth; + uint32 p_unalign_val = (uint32)out_data_ptr, p_align_val; + p_align_val = (p_unalign_val + 7) & (~7); + + //pre loop for activation_min_max + int pre_loop_count = p_align_val - p_unalign_val; + pre_loop_count = MIN(pre_loop_count, out_length); + + for(int i = 0; i < pre_loop_count; i++) { + ACTIVATION_MIN_MAX_ASYM8(out_data_ptr[i], + out_data_ptr[i], activation_min, activation_max) + } + + out_length = out_length - pre_loop_count; + + if(out_length > 0){ + err = xa_nn_vec_activation_min_max_asym8_asym8( + out_data_ptr, + out_data_ptr, + activation_min, + activation_max, + out_length); + + CHECK_ERR_HIFI_NNLIB_KER(err, "MaxpoolAsym8: xa_nn_vec_activation_min_max_asym8_asym8 failed"); + } + } else { + tflite::PoolParams op_params; + op_params.stride_height = params->stride_height; + op_params.stride_width = params->stride_width; + op_params.filter_height = params->filter_height; + op_params.filter_width = params->filter_width; + op_params.padding_values.height = data->padding.height; + op_params.padding_values.width = data->padding.width; + op_params.quantized_activation_min = activation_min; + op_params.quantized_activation_max = activation_max; + reference_integer_ops::MaxPool( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(output), GetTensorData(output)); + } + return kTfLiteOk; +} + +} // namespace + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return nullptr; +} + +void Free(TfLiteContext* context, void* buffer) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData data; + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, &data)); + + // Inputs and outputs share the same type, guarenteed by the converter. 
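+  // The pooling helpers above lean on two pieces of xtensa_tf_micro_common.h.
+  // Their assumed shape is roughly the following (a sketch only; the header
+  // added by this patch is the authoritative definition):
+  //
+  //   // Static scratch arena shared by the HiFi kernels:
+  //   #define ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM \
+  //     static uint8_t xtensa_nnlib_scratch_buf[XTENSA_NNLIB_MAX_SCRATCH_SIZE];
+  //   // Log and bail out of the kernel on a non-zero NNLib return code:
+  //   #define CHECK_ERR_HIFI_NNLIB_KER(ret, err_msg) \
+  //     if ((ret) != 0) {                            \
+  //       TF_LITE_KERNEL_LOG(context, err_msg);      \
+  //       return kTfLiteError;                       \
+  //     }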
+ switch (input->type) { + case kTfLiteFloat32: + AverageEvalFloat(context, node, params, &data, input, output); + break; + case kTfLiteUInt8: + case kTfLiteInt8: + AverageEvalQuantized(context, node, params, &data, input, output); + break; + default: + TF_LITE_KERNEL_LOG(context, "Input type %s is not currently supported", + TfLiteTypeGetName(input->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData data; + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, &data)); + + switch (input->type) { + case kTfLiteFloat32: + MaxEvalFloat(context, node, params, &data, input, output); + break; + case kTfLiteUInt8: + case kTfLiteInt8: + MaxEvalQuantized(context, node, params, &data, input, output); + break; + default: + TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.", + TfLiteTypeGetName(input->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace pooling + +TfLiteRegistration* Register_AVERAGE_POOL_2D() { + static TfLiteRegistration r = {}; + r.init = pooling::Init; + r.free = pooling::Free; + r.prepare = pooling::Prepare; + r.invoke = pooling::AverageEval; + return &r; +} + +TfLiteRegistration* Register_MAX_POOL_2D() { + static TfLiteRegistration r = {}; + r.init = pooling::Init; + r.free = pooling::Free; + r.prepare = pooling::Prepare; + r.invoke = pooling::MaxEval; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/softmax.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/softmax.cc new file mode 100755 index 00000000000..0d5d4022e02 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa_hifi/softmax.cc @@ -0,0 +1,328 @@ +/****************************************************************************** +* Copyright (C) 2019 Cadence Design Systems, Inc. +* +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files (the +* "Software"), to use this Software with Cadence processor cores only and +* not with any other processors and platforms, subject to +* the following conditions: +* +* The above copyright notice and this permission notice shall be included +* in all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +******************************************************************************/ + +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/softmax.h"
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+
+#include "xtensa_tf_micro_common.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace activations {
+namespace {
+
+struct OpData {
+  int32_t input_multiplier = 0;
+  int input_left_shift = 0;
+  int32_t input_range_radius = 0;
+  int diff_min = 0;
+};
+
+TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
+                                    const TfLiteTensor* input,
+                                    TfLiteTensor* output,
+                                    const TfLiteSoftmaxParams* params,
+                                    OpData* data) {
+  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
+    if (input->type == kTfLiteUInt8) {
+      TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+    } else {
+      if (output->type == kTfLiteInt16) {
+        TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768);
+        // NOTE: Current int16 softmax output does not require symmetric
+        // scaling - so no need to verify scale here.
+      } else {
+        TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
+        TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
+      }
+    }
+
+    static const int kScaledDiffIntegerBits = 5;
+
+    tflite::PreprocessSoftmaxScaling(
+        static_cast<double>(params->beta),
+        static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
+        &data->input_multiplier, &data->input_left_shift);
+    data->diff_min = -1.0 * tflite::CalculateInputRadius(
+                                kScaledDiffIntegerBits, data->input_left_shift);
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  return nullptr;
+}
+
+void Free(TfLiteContext* context, void* buffer) {}
+
+TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
+  return kTfLiteOk;
+}
+
+// Takes a 1D tensor and performs softmax along it.
+void Softmax1DFloat(const TfLiteTensor* input, TfLiteTensor* output,
+                    TfLiteSoftmaxParams* params) {
+  const int input_size = input->dims->data[0];
+  tflite::reference_ops::Softmax(input->data.f, input_size, 1, params->beta,
+                                 output->data.f);
+}
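+// Worked example of what both float paths compute: softmax(x)_i =
+// exp(beta * x_i) / sum_j exp(beta * x_j). For a row {1, 2, 3} with beta = 1
+// this yields approximately {0.090, 0.245, 0.665}.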
+// Takes a 2D tensor and performs softmax along the last dimension.
+TfLiteStatus Softmax2DFloat(TfLiteContext* context, const TfLiteTensor* input,
+                            TfLiteTensor* output, TfLiteSoftmaxParams* params) {
+  const int batch_size = input->dims->data[0];
+  const int input_size = input->dims->data[1];
+
+  ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
+  float* p_scratch = (float*)xtensa_nnlib_scratch_buf;
+
+  // The scratch buffer holds the beta-scaled copy of the whole batch, so it
+  // must fit batch_size * input_size floats.
+  if (batch_size * input_size * sizeof(float) > XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
+    TF_LITE_KERNEL_LOG(context, "Softmax: insufficient scratch memory");
+    return kTfLiteError;
+  }
+
+  for (int i = 0; i < batch_size * input_size; ++i) {
+    p_scratch[i] = input->data.f[i] * params->beta;
+  }
+
+  for (int i = 0; i < batch_size; ++i) {
+    int err = xa_nn_vec_softmax_f32_f32(&output->data.f[i * input_size],
+                                        &p_scratch[i * input_size],
+                                        input_size);
+    CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_softmax_f32_f32 failed");
+  }
+  return kTfLiteOk;
+}
+
+void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+                        TfLiteSoftmaxParams* params, OpData* data) {
+  // (ahentz): this is arguably a dirty trick. Since the implementation
+  // always traverses the last dimension of a 4D tensor, we will pretend our 1D
+  // tensor is 4D in a special way. We will convert a (Y) shape into a
+  // (1, 1, 1, Y) shape.
+  const int input_size = input->dims->data[0];
+  const int32_t shape_data[4] = {1, 1, 1, input_size};
+  RuntimeShape shape(4, shape_data);
+  SoftmaxParams op_params;
+  op_params.input_multiplier = data->input_multiplier;
+  op_params.input_left_shift = data->input_left_shift;
+  op_params.diff_min = data->diff_min;
+  if (input->type == kTfLiteUInt8) {
+    tflite::reference_ops::Softmax(op_params, shape,
+                                   GetTensorData<uint8_t>(input), shape,
+                                   GetTensorData<uint8_t>(output));
+  } else {
+    if (output->type == kTfLiteInt16) {
+      tflite::reference_integer_ops::Softmax(
+          op_params, shape, GetTensorData<int8_t>(input), shape,
+          GetTensorData<int16_t>(output));
+    } else {
+      tflite::reference_integer_ops::Softmax(
+          op_params, shape, GetTensorData<int8_t>(input), shape,
+          GetTensorData<int8_t>(output));
+    }
+  }
+}
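+// The quantized parameters consumed above and below come from
+// CalculateSoftmaxOpData(). A usage sketch of the asym8 NNLib call for a
+// single row, with the parameter mapping made explicit (illustrative only):
+//
+//   SoftmaxParams p;                  // filled from OpData
+//   int err = xa_nn_vec_softmax_asym8_asym8(
+//       out_row, in_row,              // uint8 output/input, one row
+//       p.diff_min,                   // most negative diff still summed
+//       p.input_left_shift,           // shift of the fixed-point rescale
+//       p.input_multiplier,           // fixed-point beta * input scale
+//       input_size, scratch);         // row length, NNLib scratch buffer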
+TfLiteStatus Softmax2DQuantized(TfLiteContext* context,
+                                const TfLiteTensor* input,
+                                TfLiteTensor* output,
+                                TfLiteSoftmaxParams* params, OpData* data) {
+  // (ahentz): this is arguably a dirty trick. Since the implementation
+  // always traverses the last dimension of a 4D tensor, we will pretend our 2D
+  // tensor is 4D in a special way. We will convert a (X, Y) shape into a
+  // (X, 1, 1, Y) shape.
+  const int batch_size = input->dims->data[0];
+  const int input_size = input->dims->data[1];
+  const int32_t shape_data[4] = {batch_size, 1, 1, input_size};
+  RuntimeShape shape(4, shape_data);
+  SoftmaxParams op_params;
+  op_params.input_multiplier = data->input_multiplier;
+  op_params.input_left_shift = data->input_left_shift;
+  op_params.diff_min = data->diff_min;
+
+  if (input->type == kTfLiteUInt8) {
+    ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM;
+    void* p_scratch = (void*)xtensa_nnlib_scratch_buf;
+
+    if (get_softmax_scratch_size(PREC_ASYM8, PREC_ASYM8, input_size) >
+        XTENSA_NNLIB_MAX_SCRATCH_SIZE) {
+      TF_LITE_KERNEL_LOG(context, "Softmax: insufficient scratch memory");
+      return kTfLiteError;
+    }
+
+    for (int i = 0; i < batch_size; ++i) {
+      int err = xa_nn_vec_softmax_asym8_asym8(
+          &output->data.uint8[i * input_size],
+          &input->data.uint8[i * input_size], op_params.diff_min,
+          op_params.input_left_shift, op_params.input_multiplier, input_size,
+          p_scratch);
+      CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_softmax_asym8_asym8 failed");
+    }
+  } else {
+    if (output->type == kTfLiteInt16) {
+      tflite::reference_integer_ops::Softmax(
+          op_params, shape, GetTensorData<int8_t>(input), shape,
+          GetTensorData<int16_t>(output));
+    } else {
+      tflite::reference_integer_ops::Softmax(
+          op_params, shape, GetTensorData<int8_t>(input), shape,
+          GetTensorData<int8_t>(output));
+    }
+  }
+  return kTfLiteOk;
+}
+
+// Takes a 4D tensor and performs softmax along the fourth dimension.
+void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output,
+                    TfLiteSoftmaxParams* params) {
+  SoftmaxParams op_params;
+  op_params.beta = static_cast<double>(params->beta);
+  tflite::reference_ops::Softmax(
+      op_params, GetTensorShape(input), GetTensorData<float>(input),
+      GetTensorShape(output), GetTensorData<float>(output));
+}
+
+void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+                        TfLiteSoftmaxParams* params, OpData* data) {
+  SoftmaxParams op_params;
+  op_params.input_multiplier = data->input_multiplier;
+  op_params.input_left_shift = data->input_left_shift;
+  op_params.diff_min = data->diff_min;
+  if (input->type == kTfLiteUInt8) {
+    tflite::reference_ops::Softmax(
+        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+        GetTensorShape(output), GetTensorData<uint8_t>(output));
+  } else {
+    if (output->type == kTfLiteInt16) {
+      tflite::reference_integer_ops::Softmax(
+          op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
+          GetTensorShape(output), GetTensorData<int16_t>(output));
+    } else {
+      tflite::reference_integer_ops::Softmax(
+          op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
+          GetTensorShape(output), GetTensorData<int8_t>(output));
+    }
+  }
+}
+
+TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+
+  const TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+
+  OpData local_data_object;
+  OpData* data = &local_data_object;
+  TF_LITE_ENSURE_STATUS(
+      CalculateSoftmaxOpData(context, input, output, params, data));
+
+  // (ahentz): consider an implementation that works for many (all?)
+  // dimensions.
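+  // The dispatch below special-cases 1D, 2D and 4D shapes. A 3D input
+  // (batches, time, channels) could reuse the 2D row path by flattening the
+  // leading dimensions, e.g. (sketch):
+  //
+  //   const int rows = input->dims->data[0] * input->dims->data[1];
+  //   const int cols = input->dims->data[2];
+  //   // ...then run the per-row softmax exactly as the 2D helpers do.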
+  switch (input->type) {
+    case kTfLiteFloat32: {
+      if (NumDimensions(input) == 1) {
+        Softmax1DFloat(input, output, params);
+        return kTfLiteOk;
+      }
+      if (NumDimensions(input) == 2) {
+        return Softmax2DFloat(context, input, output, params);
+      }
+      if (NumDimensions(input) == 4) {
+        Softmax4DFloat(input, output, params);
+        return kTfLiteOk;
+      }
+      TF_LITE_KERNEL_LOG(
+          context, "Only 1D, 2D and 4D tensors supported currently, got %dD.",
+          NumDimensions(input));
+      return kTfLiteError;
+    }
+    case kTfLiteInt8:
+    case kTfLiteUInt8: {
+      if (NumDimensions(input) == 1) {
+        Softmax1DQuantized(input, output, params, data);
+        return kTfLiteOk;
+      }
+      if (NumDimensions(input) == 2) {
+        return Softmax2DQuantized(context, input, output, params, data);
+      }
+      if (NumDimensions(input) == 4) {
+        Softmax4DQuantized(input, output, params, data);
+        return kTfLiteOk;
+      }
+      TF_LITE_KERNEL_LOG(
+          context, "Only 1D, 2D and 4D tensors supported currently, got %dD.",
+          NumDimensions(input));
+      return kTfLiteError;
+    }
+    default:
+      TF_LITE_KERNEL_LOG(
+          context,
+          "Only float32, uint8_t and int8_t supported currently, got %d.",
+          input->type);
+      return kTfLiteError;
+  }
+}
+}  // namespace activations
+
+TfLiteRegistration* Register_SOFTMAX() {
+  static TfLiteRegistration r = {};
+  r.init = activations::Init;
+  r.free = activations::Free;
+  r.prepare = activations::SoftmaxPrepare;
+  r.invoke = activations::SoftmaxEval;
+  return &r;
+}
+
+}  // namespace micro
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/svdf.cc b/tensorflow/lite/micro/kernels/xtensa_hifi/svdf.cc
new file mode 100644
index 00000000000..f6f3dcd329b
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/xtensa_hifi/svdf.cc
@@ -0,0 +1,599 @@
+/******************************************************************************
+* Copyright (C) 2019 Cadence Design Systems, Inc.
+*
+* Permission is hereby granted, free of charge, to any person obtaining
+* a copy of this software and associated documentation files (the
+* "Software"), to use this Software with Cadence processor cores only and
+* not with any other processors and platforms, subject to
+* the following conditions:
+*
+* The above copyright notice and this permission notice shall be included
+* in all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+******************************************************************************/
+
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <math.h>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/kernels/activation_utils.h"
+#include "tensorflow/lite/micro/micro_utils.h"
+
+#include "xtensa_tf_micro_common.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace svdf {
+namespace {
+
+// These constants are specific to the hotword "OK G" model.
+// They exist until (b/132070898) is fixed.
+constexpr int kScratchTensorMaxSize = 64;
+
+struct OpData {
+  int32 effective_scale_1_a;
+  int32 effective_scale_2_a;
+  // b versions of each scale are kept as int since the numbers are just the
+  // shift value - typically between [-32, 32].
+  int effective_scale_1_b;
+  int effective_scale_2_b;
+};
+
+/**
+ * This version of SVDF is specific to TFLite Micro. It contains the following
+ * differences from the TFLite version:
+ *
+ * 1.) Scratch tensor allocation - scratch tensors must be known ahead of time
+ * for the Micro interpreter.
+ * 2.) Output dimensions - the TFLite version determines the output size at
+ * runtime and resizes the output tensor. The Micro runtime does not support
+ * tensor resizing.
+ */
+
+static inline TfLiteStatus ApplyTimeWeightsBiasAndActivation(
+    TfLiteContext* context, int batch_size, int memory_size, int num_filters,
+    int num_units, int rank, const TfLiteTensor* weights_time,
+    const TfLiteTensor* bias, TfLiteFusedActivation activation,
+    TfLiteTensor* activation_state, TfLiteTensor* scratch,
+    TfLiteTensor* output) {
+  float* scratch_bias = GetTensorData<float>(scratch);
+  if (bias) {
+    const float* bias_data = GetTensorData<float>(bias);
+    for (int j = 0; j < num_units; ++j) {
+      scratch_bias[j] = *bias_data++;
+    }
+  } else {
+    for (int j = 0; j < num_units; ++j) {
+      scratch_bias[j] = 0.0f;
+    }
+  }
+  int err = 0;
+  for (int b = 0; b < batch_size; ++b) {
+    const float* weights_time_vec = GetTensorData<float>(weights_time);
+    const float* mat_ptr = GetTensorData<float>(activation_state) +
+                           b * memory_size * num_filters;
+    float* output_ptr_batch = GetTensorData<float>(output) + b * num_units;
+    for (int j = 0; j < num_units; j++) {
+      err = xa_nn_matXvec_f32xf32_f32(
+          output_ptr_batch, mat_ptr, NULL, weights_time_vec, NULL,
+          scratch_bias, 1, memory_size * rank, 0, memory_size * rank, 0);
+      CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_matXvec_f32xf32_f32 failed");
+      output_ptr_batch++;
+      mat_ptr += memory_size * rank;
+      weights_time_vec += memory_size * rank;
+    }
+  }
+
+  // Apply activation.
+  for (int b = 0; b < batch_size; ++b) {
+    float* output_ptr_batch = GetTensorData<float>(output) + b * num_units;
+    for (int i = 0; i < num_units; ++i) {
+      *output_ptr_batch = ActivationValFloat(activation, *output_ptr_batch);
+      ++output_ptr_batch;
+    }
+  }
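+  // Each xa_nn_matXvec_f32xf32_f32 call above evaluates a one-row
+  // matrix * vector product over the unit's memory_size * rank history,
+  // conceptually (reference sketch, ignoring the bias-pointer plumbing):
+  //
+  //   float acc = scratch_bias[0];
+  //   for (int k = 0; k < memory_size * rank; ++k) {
+  //     acc += mat_ptr[k] * weights_time_vec[k];
+  //   }
+  //   *output_ptr_batch = acc;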
+  // Left shift the activation_state to make room for next cycle's activation.
+  // (alanchiao): explore collapsing this into a single loop.
+  for (int b = 0; b < batch_size; ++b) {
+    float* state_ptr_batch = GetTensorData<float>(activation_state) +
+                             b * memory_size * num_filters;
+    for (int f = 0; f < num_filters; ++f) {
+      // Shift the vector left:
+      float* batch_ptr = state_ptr_batch;
+      float* batch_start = state_ptr_batch + 1;
+      float* batch_end = state_ptr_batch + memory_size;
+      while (batch_start != batch_end) {
+        *batch_ptr++ = *batch_start++;
+      }
+      state_ptr_batch[memory_size - 1] = 0.0f;
+      state_ptr_batch += memory_size;
+    }
+  }
+  return kTfLiteOk;
+}
+
+inline TfLiteStatus EvalFloatSVDF(
+    TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input,
+    const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time,
+    const TfLiteTensor* bias, const TfLiteSVDFParams* params,
+    TfLiteTensor* scratch, TfLiteTensor* activation_state,
+    TfLiteTensor* output) {
+  const int rank = params->rank;
+  const int batch_size = input->dims->data[0];
+  const int input_size = input->dims->data[1];
+  const int num_filters = weights_feature->dims->data[0];
+  const int num_units = num_filters / rank;
+  const int memory_size = weights_time->dims->data[1];
+
+  // Clear the activation (activation_state's leftmost column).
+  // (ghodrat): Add a test which initializes activation_state with invalid
+  // values in the leftmost column and make sure it passes.
+  for (int b = 0; b < batch_size; ++b) {
+    float* state_ptr_batch = GetTensorData<float>(activation_state) +
+                             b * memory_size * num_filters;
+    for (int c = 0; c < num_filters; ++c) {
+      float* state_ptr = state_ptr_batch + c * memory_size;
+      state_ptr[memory_size - 1] = 0.0f;
+    }
+  }
+
+  // Compute conv1d(inputs, weights_feature).
+  // The activation_state's rightmost column is used to save current cycle
+  // activation. This is achieved by starting at
+  // GetTensorData<float>(activation_state)[memory_size - 1] and having the
+  // stride equal to memory_size.
+
+  const float* matrix = GetTensorData<float>(weights_feature);
+  const float* vector = GetTensorData<float>(input);
+  float* out_scratch = GetTensorData<float>(scratch);
+  /* NNLib matXvec needs a bias buffer; the output buffer is reused here to
+  avoid extra memory. Its size is batch * num_units and batch is at least 1,
+  so num_units entries are always available. */
+  float* bias_scratch = GetTensorData<float>(output);
+  float* result = &GetTensorData<float>(activation_state)[memory_size - 1];
+  float* result_in_batch = result;
+
+  for (int i = 0; i < num_units; i++)
+    bias_scratch[i] = 0.0f;
+
+  int err = 0;
+  for (int i = 0; i < batch_size; i++) {
+    /* Since the output buffer doubles as the bias buffer, only num_units
+    entries are guaranteed; hence the rank loop, calling matXvec for
+    num_units rows at a time. */
+    for (int j = 0; j < rank; j++) {
+      err = xa_nn_matXvec_f32xf32_f32(
+          &out_scratch[j * num_units], &matrix[j * input_size * num_units],
+          NULL, &vector[i * input_size], NULL, bias_scratch, num_units,
+          input_size, 0, input_size, 0);
+      CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_matXvec_f32xf32_f32 failed");
+    }
+    for (int j = 0; j < num_filters; ++j) {
+      *result_in_batch = out_scratch[j];
+      result_in_batch += memory_size;
+    }
+  }
+
+  return ApplyTimeWeightsBiasAndActivation(
+      context, batch_size, memory_size, num_filters, num_units, rank,
+      weights_time, bias, params->activation, activation_state, scratch,
+      output);
+}
+
+void EvalIntegerSVDF(
+    TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input_tensor,
+    const TfLiteTensor* weights_feature_tensor,
+    const TfLiteTensor* weights_time_tensor, const TfLiteTensor* bias_tensor,
+    const TfLiteSVDFParams* params, TfLiteTensor* activation_state_tensor,
+    TfLiteTensor* output_tensor, int32_t scale_1_a, int scale_1_b,
+    int32_t
scale_2_a, int scale_2_b, int32_t input_zp, int32_t output_zp) { + const int n_rank = params->rank; + const int n_batch = input_tensor->dims->data[0]; + const int n_input = input_tensor->dims->data[1]; + const int n_filter = weights_feature_tensor->dims->data[0]; + const int n_unit = n_filter / n_rank; + const int n_memory = weights_time_tensor->dims->data[1]; + + // (b/132070898): Move these temp variables to the new scratch buffer API + // when ready. + int32_t scratch_tensor[kScratchTensorMaxSize]; + int32_t scratch_output_tensor[kScratchTensorMaxSize]; + + // Rewrite last bit of state. + { + for (int b = 0; b < n_batch; ++b) { + int16_t* state_ptr_batch = + GetTensorData(activation_state_tensor) + + b * n_memory * n_filter; + for (int c = 0; c < n_filter; ++c) { + int16_t* state_ptr = state_ptr_batch + c * n_memory; + state_ptr[n_memory - 1] = 0; + } + } + } + + // Feature matmul. + { + int16_t* state = GetTensorData(activation_state_tensor); + const int8_t* input = GetTensorData(input_tensor); + const int8_t* weight_feature = + GetTensorData(weights_feature_tensor); + const int32_t output_max = std::numeric_limits::max(); + const int32_t output_min = std::numeric_limits::min(); + int16_t* result_in_batch = state + (n_memory - 1); + for (int b = 0; b < n_batch; b++) { + const int8_t* matrix_ptr = weight_feature; + for (int r = 0; r < n_filter; r++) { + int32_t dot_prod = 0; + const int8_t* vector_in_batch = input + b * n_input; + for (int c = 0; c < n_input; c++) { + dot_prod += *matrix_ptr++ * (*vector_in_batch++ - input_zp); + } + dot_prod = + MultiplyByQuantizedMultiplier(dot_prod, scale_1_a, scale_1_b); + dot_prod = std::min(std::max(output_min, dot_prod), output_max); + *result_in_batch = dot_prod; + result_in_batch += n_memory; + } + } + } + + // Time. + { + for (int b = 0; b < n_batch; ++b) { + int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter; + + // Perform batched vector dot product: + const int16_t* vector1_ptr = GetTensorData(weights_time_tensor); + const int16_t* vector2_ptr = + GetTensorData(activation_state_tensor) + + b * n_memory * n_filter; + + for (int i = 0; i < n_filter; i++) { + *scratch_ptr_batch = 0; + for (int j = 0; j < n_memory; j++) { + *scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++; + } + scratch_ptr_batch++; + } + } + } + + // Reduce, add bias, rescale, activation. + { + // Add bias. + if (bias_tensor) { + // Vector batch assign: + const int32_t* bias_data = GetTensorData(bias_tensor); + for (int i = 0; i < n_batch; ++i) { + int32_t* output_ptr = scratch_output_tensor + i * n_unit; + const int32_t* bias_ptr = bias_data; + for (int j = 0; j < n_unit; ++j) { + *output_ptr++ = *bias_ptr++; + } + } + } else { + int32_t* output_ptr = scratch_output_tensor; + for (int i = 0; i < n_batch * n_unit; ++i) { + *output_ptr++ = 0; + } + } + + // Reduce. + for (int b = 0; b < n_batch; ++b) { + int32_t* output_temp_ptr = scratch_output_tensor + b * n_unit; + int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter; + + // Reduction sum vector + for (int i = 0; i < n_unit; ++i) { + for (int j = 0; j < n_rank; ++j) { + output_temp_ptr[i] += *scratch_ptr_batch++; + } + } + } + + // Rescale. 
+    const int32_t output_max = std::numeric_limits<int8_t>::max();
+    const int32_t output_min = std::numeric_limits<int8_t>::min();
+    for (int i = 0; i < n_batch * n_unit; ++i) {
+      int32_t x1 = scratch_output_tensor[i];
+      int32_t x2 = MultiplyByQuantizedMultiplier(x1, scale_2_a, scale_2_b);
+      int32_t x3 = x2 + output_zp;
+      int32_t x4 = std::min(std::max(output_min, x3), output_max);
+      GetTensorData<int8_t>(output_tensor)[i] = static_cast<int8_t>(x4);
+    }
+  }
+
+  // Shift state.
+  {
+    for (int b = 0; b < n_batch; ++b) {
+      int16_t* state_ptr_batch =
+          GetTensorData<int16_t>(activation_state_tensor) +
+          b * n_memory * n_filter;
+      for (int f = 0; f < n_filter; ++f) {
+        // Shift the vector left:
+        int16_t* batch_ptr = state_ptr_batch;
+        int16_t* batch_start = state_ptr_batch + 1;
+        int16_t* batch_end = state_ptr_batch + n_memory;
+        while (batch_start != batch_end) {
+          *batch_ptr++ = *batch_start++;
+        }
+        state_ptr_batch[n_memory - 1] = 0;
+        state_ptr_batch += n_memory;
+      }
+    }
+  }
+}
+
+}  // namespace
+
+// Input tensors.
+constexpr int kInputTensor = 0;
+constexpr int kWeightsFeatureTensor = 1;
+constexpr int kWeightsTimeTensor = 2;
+constexpr int kBiasTensor = 3;
+// This is a variable tensor, and will be modified by this op.
+constexpr int kInputActivationStateTensor = 4;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  return nullptr;
+}
+
+void Free(TfLiteContext* context, void* buffer) {}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  const auto* params =
+      reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
+
+  // Validate Tensor Inputs (dtype depends on quantization):
+  // [0] = Input, {2, batch_size, input_size}
+  // [1] = Weights Feature, {2, num_filters, input_size}
+  // [2] = Weights Time, {2, num_filters, memory_size}
+  // [3] = Bias (optional), {1, num_units}
+  // [4] = Activation State (variable),
+  //       {2, batch_size, memory_size * num_filters}
+
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* weights_feature =
+      GetInput(context, node, kWeightsFeatureTensor);
+  const TfLiteTensor* weights_time =
+      GetInput(context, node, kWeightsTimeTensor);
+  const TfLiteTensor* bias =
+      GetOptionalInputTensor(context, node, kBiasTensor);
+  const TfLiteTensor* activation_state =
+      GetInput(context, node, kInputActivationStateTensor);
+
+  // Define input constants based on input tensor definition above:
+  const int rank = params->rank;
+  const int input_size = input->dims->data[1];
+  const int batch_size = input->dims->data[0];
+  const int num_filters = weights_feature->dims->data[0];
+  TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
+  const int num_units = num_filters / rank;
+  const int memory_size = weights_time->dims->data[1];
+
+  const bool is_full_integer = input->type == kTfLiteInt8;
+
+  // Validate Input Tensor:
+  TF_LITE_ENSURE(context,
+                 input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
+
+  // Validate Tensor Output:
+  // [0] = float/int8, {2, batch_size, num_units}
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
+  TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
+  TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);
+
+  // Validate Weights Feature Input Tensor:
+  TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2);
+  TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size);
+
+  // Validate Weights Time Input Tensor:
+  TF_LITE_ENSURE_EQ(context, NumDimensions(weights_time), 2);
+  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters);
+  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size);
+
+  // Validate Optional Bias Input Tensor:
+  if (bias) {
+    TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
+  }
+
+  // Validate Activation State Input Tensor:
+  TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
+  TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size);
+  TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],
+                    memory_size * num_filters);
+
+  if (is_full_integer) {
+    TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
+
+    TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
+    TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16);
+
+    if (bias) {
+      TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
+    }
+
+    TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
+
+    // Validate Scratch Tensors:
+    // [0] = (shared - see float block below for usage)
+    // [1] = Output Temp, int8_t, {2, num_units, batch_size}
+    // TODO(b/132070898): Scratch values are used as stack variables in
+    // EvalIntegerSVDF().
+
+    // Validate output tensor:
+    TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt8);
+  } else {
+    TF_LITE_ENSURE_EQ(context, node->inputs->size, 6);
+
+    // Validate Input Tensor dtypes:
+    TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32);
+    TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32);
+    TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
+
+    if (bias) {
+      TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
+    }
+
+    // Validate shared Scratch Tensor:
+    // [0] = Holds dot-product of time-forward calculations in
+    //       ApplyTimeWeightsBiasAndActivation():
+    //         float/int32, {2, batch_size, num_filters}
+    // TODO(b/132070898): Use input tensor as variable until scratch tensor
+    // allocation has been implemented.
+    // TfLiteTensor* scratch_tensor = GetTemporary(context, node, 0);
+    TfLiteTensor* scratch_tensor = &context->tensors[node->inputs->data[5]];
+    TF_LITE_ENSURE_EQ(context, scratch_tensor->type, kTfLiteFloat32);
+
+    TF_LITE_ENSURE_EQ(context, NumDimensions(scratch_tensor), 2);
+    TF_LITE_ENSURE_EQ(context, scratch_tensor->dims->data[0], batch_size);
+    TF_LITE_ENSURE_EQ(context, scratch_tensor->dims->data[1], num_filters);
+
+    // Full-float SVDF only uses the one shared scratch tensor (see above for
+    // usage).
+    // TODO(b/132070898): Use input tensor as variable until scratch tensor
+    // allocation has been implemented.
+    // TF_LITE_ENSURE_EQ(context, node->temporaries->size, 1);
+    TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+  }
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
+
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* weights_feature =
+      GetInput(context, node, kWeightsFeatureTensor);
+  const TfLiteTensor* weights_time =
+      GetInput(context, node, kWeightsTimeTensor);
+  const TfLiteTensor* bias =
+      GetOptionalInputTensor(context, node, kBiasTensor);
+  TfLiteTensor* activation_state =
+      GetVariableInput(context, node, kInputActivationStateTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  const bool is_full_integer = input->type == kTfLiteInt8;
+
+  switch (weights_feature->type) {
+    case kTfLiteFloat32: {
+      // TODO(b/132070898): Use input tensor as variable until scratch tensor
+      // allocation has been implemented.
+      // TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
+      TfLiteTensor* scratch = &context->tensors[node->inputs->data[5]];
+      return EvalFloatSVDF(context, node, input, weights_feature,
+                           weights_time, bias, params, scratch,
+                           activation_state, output);
+      break;
+    }
+
+    case kTfLiteInt8: {
+      if (is_full_integer) {
+        // TODO(b/132070898): Store these values in ::Prepare() instead of
+        // ::Eval():
+        // Calculate effective scales.
+        OpData op_data;
+        auto* input_params = reinterpret_cast<TfLiteAffineQuantization*>(
+            input->quantization.params);
+        auto* weights_feature_params =
+            reinterpret_cast<TfLiteAffineQuantization*>(
+                weights_feature->quantization.params);
+        auto* state_params = reinterpret_cast<TfLiteAffineQuantization*>(
+            activation_state->quantization.params);
+        auto* weight_time_params = reinterpret_cast<TfLiteAffineQuantization*>(
+            weights_time->quantization.params);
+        auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
+            output->quantization.params);
+        const double effective_scale_1 =
+            static_cast<double>(input_params->scale->data[0] *
+                                weights_feature_params->scale->data[0] /
+                                state_params->scale->data[0]);
+        const double effective_scale_2 = static_cast<double>(
+            state_params->scale->data[0] * weight_time_params->scale->data[0] /
+            output_params->scale->data[0]);
+        QuantizeMultiplier(effective_scale_1, &op_data.effective_scale_1_a,
+                           &op_data.effective_scale_1_b);
+        QuantizeMultiplier(effective_scale_2, &op_data.effective_scale_2_a,
+                           &op_data.effective_scale_2_b);
+
+        TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu);
+        EvalIntegerSVDF(
+            context, node, input, weights_feature, weights_time, bias, params,
+            activation_state, output, op_data.effective_scale_1_a,
+            op_data.effective_scale_1_b, op_data.effective_scale_2_a,
+            op_data.effective_scale_2_b, input->params.zero_point,
+            output->params.zero_point);
+        return kTfLiteOk;
+      }
+      break;
+    }
+
+    default:
+      TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
+                         TfLiteTypeGetName(weights_feature->type));
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace svdf
+
+TfLiteRegistration* Register_SVDF() {
+  static TfLiteRegistration r = {};
+  r.init = svdf::Init;
+  r.free = svdf::Free;
+  r.prepare = svdf::Prepare;
+  r.invoke = svdf::Eval;
+  return &r;
+}
+
+}  // namespace micro
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h b/tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h
new file mode 100755
index 00000000000..1cb80c5f46c
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/xtensa_hifi/xtensa_tf_micro_common.h
@@ -0,0 +1,79 @@
+/******************************************************************************
+* Copyright (C) 2019 Cadence Design Systems, Inc.
+*
+* Permission is hereby granted, free of charge, to any person obtaining
+* a copy of this software and associated documentation files (the
+* "Software"), to use this Software with Cadence processor cores only and
+* not with any other processors and platforms, subject to
+* the following conditions:
+*
+* The above copyright notice and this permission notice shall be included
+* in all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+******************************************************************************/
+
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef __XTENSA_TF_MICRO_COMMON__
+#define __XTENSA_TF_MICRO_COMMON__
+
+#include "xa_nnlib_api.h"
+#include "xa_nnlib_standards.h"
+
+/* Logs err_msg and returns from the calling function on a nonzero NNLib
+   return code. Expects a TfLiteContext* named 'context' to be in scope. */
+#define CHECK_ERR_HIFI_NNLIB_KER(ret, err_msg) \
+  do {                                         \
+    if ((ret) != 0) {                          \
+      TF_LITE_KERNEL_LOG(context, err_msg);    \
+      return kTfLiteError;                     \
+    }                                          \
+  } while (0)
+
+#ifndef XTENSA_NNLIB_MAX_SCRATCH_SIZE
+#define XTENSA_NNLIB_MAX_SCRATCH_SIZE (70 * 1024)
+#endif
+
+#define ALLOCATE_XTENSA_NNLIB_SCRATCH_MEM \
+  uint8_t xtensa_nnlib_scratch_buf[XTENSA_NNLIB_MAX_SCRATCH_SIZE];
+
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+#define ACTIVATION_MIN_MAX(data_type, out, inp, min, max) \
+  {                                                       \
+    data_type temp = MAX(inp, min);                       \
+    out = MIN(temp, max);                                 \
+  }
+
+#define ACTIVATION_MIN_MAX_F32(out, inp, min, max) \
+  {                                                \
+    float temp = MAX(inp, min);                    \
+    out = MIN(temp, max);                          \
+  }
+
+#define ACTIVATION_MIN_MAX_ASYM8(out, inp, min, max) \
+  {                                                  \
+    int32_t temp = MAX((int32_t)inp, min);           \
+    out = (uint8_t)MIN(temp, max);                   \
+  }
+
+/* Round a size or pointer up to a multiple of 'bytes' (a power of two).
+   ALIGN_PTR assumes 32-bit pointers, as on the supported Xtensa targets. */
+#define ALIGNED_SIZE(x, bytes) (((x) + (bytes)-1) & (~((bytes)-1)))
+#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes)-1) & (~((bytes)-1)))
+
+#endif /* __XTENSA_TF_MICRO_COMMON__ */
diff --git a/tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh b/tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh
new file mode 100755
index 00000000000..50415e7cf11
--- /dev/null
+++ b/tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh
@@ -0,0 +1,59 @@
+#!/bin/bash -e
+# ==============================================================================
+# Copyright (C) 2019 Cadence Design Systems, Inc.
+# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to use this Software with Cadence processor cores only and +# not with any other processors and platforms, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# ============================================================================== + +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Tests an Xtensa binary by parsing the log output. +# +# First argument is the binary location. +# Second argument is a regular expression that's required to be in the output +# logs for the test to pass. + +declare -r ROOT_DIR=`pwd` +declare -r TEST_TMPDIR=/tmp/test_xtensa_hifi_binary/ +declare -r MICRO_LOG_PATH=${TEST_TMPDIR}/$1 +declare -r MICRO_LOG_FILENAME=${MICRO_LOG_PATH}/logs.txt +mkdir -p ${MICRO_LOG_PATH} + +xt-run $1 2>&1 | tee ${MICRO_LOG_FILENAME} + +if grep -q "$2" ${MICRO_LOG_FILENAME} +then + echo "$1: PASS" + exit 0 +else + echo "$1: FAIL - '$2' not found in logs." 
+ exit 1 +fi diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifi_nn_library.inc b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifi_nn_library.inc new file mode 100644 index 00000000000..7dd2f4fc4e9 --- /dev/null +++ b/tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifi_nn_library.inc @@ -0,0 +1,67 @@ +ifneq ($(filter xtensa_hifi, $(ALL_TAGS)),) + + XTENSA_PATH = $(MAKEFILE_DIR)/../../kernels/xtensa_hifi + + ifneq (,$(filter hifi4%, $(TARGET_ARCH))) + + CCFLAGS += -DNNLIB_V2 \ + -DXTENSA_NNLIB_MAX_SCRATCH_SIZE=70*1024 + + CXXFLAGS += -DNNLIB_V2 \ + -DXTENSA_NNLIB_MAX_SCRATCH_SIZE=70*1024 + + MICROLITE_CC_SRCS += \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi4/xa_nn_activations_f32_f32.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi4/xa_nn_activations_asym8_asym8.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi4/xa_nn_activations_32_16.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi4/xa_nn_activations_32_8.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/activations/hifi4/xa_nn_softmax_asym8_asym8.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/basic/hifi4/xa_nn_floor_f32.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_circ_buf.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_asym8xasym8.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_conv2d_std_f32.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_matXvec_asym8xasym8_asym8_circ.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_matXvec_f32_circ.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_conv2d_depthwise.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_conv2d_depthwise_f32.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_conv2d_depthwise_asym8xasym8.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/cnn/hifi4/xa_nn_circ_buf.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/fc/hifi4/xa_nn_fully_connected.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/matXvec/hifi4/xa_nn_matXvec_f32.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/matXvec/hifi4/xa_nn_matXvec_16x16.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/matXvec/hifi4/xa_nn_matXvec_8x16.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/matXvec/hifi4/xa_nn_matXvec_8x8.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/matXvec/hifi4/xa_nn_matXvec_asym8xasym8.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_avgpool.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_avgpool_f32.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_avgpool_asym8.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_maxpool.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_maxpool_f32.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_maxpool_asym8.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_avgpool_f32_nhwc.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_avgpool_asym8_nhwc.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_maxpool_f32_nhwc.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_maxpool_asym8_nhwc.c \ + $(XTENSA_PATH)/xa_nnlib/algo/kernels/pool/hifi4/xa_nn_inv_256_tbl.c \ + $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/vec_sigmoidf_hifi4.c \ + $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/vec_tanhf_hifi4.c \ + $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/vec_reluf_hifi4.c \ + $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/vec_softmaxf_hifi4.c \ + $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/vec_alognf_hifi4.c \ + $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/scl_sigmoidf_hifi4.c \ + 
$(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/scl_tanhf_hifi4.c \
+    $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/expf_tbl.c \
+    $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/pow2f_tbl.c \
+    $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/inff_tbl.c \
+    $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/tanhf_tbl.c \
+    $(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/src/nanf_tbl.c \
+
+    INCLUDES += -I$(XTENSA_PATH)/xa_nnlib/algo/kernels/ \
+                -I$(XTENSA_PATH)/xa_nnlib/include/nnlib/ \
+                -I$(XTENSA_PATH)/xa_nnlib/include/ \
+                -I$(XTENSA_PATH)/xa_nnlib/algo/common/include/ \
+                -I$(XTENSA_PATH)/xa_nnlib/algo/ndsp/hifi4/include/ \
+
+  endif
+
+endif
diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_hifi/README.md b/tensorflow/lite/micro/tools/make/targets/xtensa_hifi/README.md
new file mode 100644
index 00000000000..1dc75412575
--- /dev/null
+++ b/tensorflow/lite/micro/tools/make/targets/xtensa_hifi/README.md
@@ -0,0 +1,31 @@
+
+# Building TensorFlow Lite for Microcontrollers for Cadence Tensilica HiFi DSPs
+
+This document describes the steps to build and run TensorFlow Lite for
+Microcontrollers on Cadence HiFi DSPs.
+
+## Prerequisites
+
+The Xtensa development tools and the target processor configurations should be
+installed on the system. Please check [https://tensilicatools.com] for more
+information about downloading and installing the required tools.
+
+The PATH variable should be set to include the Xtensa tools' bin directory.
+The XTENSA_SYSTEM and XTENSA_CORE environment variables should be set to the
+required tools version and the required processor configuration.
+
+## Building for HiFi Processors
+
+To build the code using the Xtensa tools for the processor configuration
+selected by XTENSA_CORE, set TARGET=xtensa_hifi. Additionally, TARGET_ARCH can
+be used to select optimized HiFi NN kernels specific to the processor
+configuration. Currently, HiFi4 NN kernels are provided, and they can be
+enabled as follows:
+
+make -f tensorflow/lite/micro/tools/make/Makefile test_micro_speech_test TARGET=xtensa_hifi TARGET_ARCH=hifi4
+
+Xtensa-specific TF Lite Micro kernels are implemented in this folder:
+tensorflow/lite/micro/kernels/xtensa_hifi/
+
+The HiFi optimized kernels need a scratch memory allocation. This allocation
+is currently done on the stack, and its size can be controlled by defining
+'XTENSA_NNLIB_MAX_SCRATCH_SIZE' appropriately in the file
+'tensorflow/lite/micro/tools/make/ext_libs/xtensa_hifi_nn_library.inc'.
+
+The files containing the HiFi optimized NN kernels are present in this folder:
+tensorflow/lite/micro/kernels/xtensa_hifi/xa_nnlib/
+
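+Putting it together, a typical environment setup and build could look like the
+following (the install path, tools version, and core name below are
+hypothetical and depend on your Xtensa installation):
+
+export XTENSA_SYSTEM=/opt/xtensa/XtDevTools/install/tools/RI-2019.2-linux/XtensaTools/config
+export XTENSA_CORE=my_hifi4_core
+export PATH=/opt/xtensa/XtDevTools/install/tools/RI-2019.2-linux/XtensaTools/bin:$PATH
+make -f tensorflow/lite/micro/tools/make/Makefile test_micro_speech_test TARGET=xtensa_hifi TARGET_ARCH=hifi4
+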
diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_hifi_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_hifi_makefile.inc
new file mode 100644
index 00000000000..e2db712a433
--- /dev/null
+++ b/tensorflow/lite/micro/tools/make/targets/xtensa_hifi_makefile.inc
@@ -0,0 +1,42 @@
+# Settings for the Xtensa toolchain.
+# Derived from xtensa_xpg_makefile.inc.
+# The Xtensa environment variables (XTENSA_CORE, XTENSA_SYSTEM) should be
+# configured externally.
+
+ifeq ($(TARGET), xtensa_hifi)
+  TARGET_ARCH := hifi3_bd5
+
+  PLATFORM_ARGS = \
+    -mno-mul16 \
+    -mno-mul32 \
+    -mno-div32 \
+    -fsigned-char \
+    -fno-exceptions \
+    -mlongcalls \
+    -INLINE:requested \
+    -mcoproc \
+    -fno-zero-initialized-in-bss \
+    -mtext-section-literals \
+    -fno-unsafe-math-optimizations
+
+  TF_LITE_MICRO_FLAGS = \
+    -DTF_LITE_STATIC_MEMORY
+
+  TARGET_TOOLCHAIN_PREFIX := xt-
+  CXX_TOOL := clang++
+  CC_TOOL := clang
+
+  CXXFLAGS = -O0 $(PLATFORM_ARGS) -std=c++11 $(TF_LITE_MICRO_FLAGS)
+  # TODO: Use -std=c11?
+  CCFLAGS = -O3 $(PLATFORM_ARGS) $(TF_LITE_MICRO_FLAGS)
+
+  TEST_SCRIPT := tensorflow/lite/micro/testing/test_xtensa_hifi_binary.sh
+
+  # These are microcontroller-specific rules for converting the ELF output
+  # of the linker into a binary image that can be loaded directly.
+  OBJCOPY := $(TARGET_TOOLCHAIN_PREFIX)objcopy
+
+  $(BINDIR)/%.bin: $(BINDIR)/%
+	@mkdir -p $(dir $@)
+	$(OBJCOPY) $< $@ -O binary
+endif
diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc
index d9545fc2116..db50d98732d 100644
--- a/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc
+++ b/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc
@@ -6,6 +6,8 @@ ifeq ($(TARGET), xtensa-xpg)
 TARGET_ARCH := xtensa-xpg
 
+$(eval $(call add_third_party_download,$(XTENSA_HIFI4_URL),$(XTENSA_HIFI4_MD5),xa_nnlib,))
+
 PLATFORM_ARGS = \
   -DTF_LITE_MICRO_TENSORS_PREPARED \
   -DTF_LITE_STATIC_MEMORY \
diff --git a/tensorflow/lite/micro/tools/make/third_party_downloads.inc b/tensorflow/lite/micro/tools/make/third_party_downloads.inc
index 4f5eecfce04..8ebaedcd402 100644
--- a/tensorflow/lite/micro/tools/make/third_party_downloads.inc
+++ b/tensorflow/lite/micro/tools/make/third_party_downloads.inc
@@ -59,3 +59,7 @@ EMBARC_OSP_MD5 := "9eaf7b3a1ed05872a03da9796672a776"
 
 EMBARC_MLI_URL := "https://github.com/foss-for-synopsys-dwc-arc-processors/embarc_mli/archive/6316034d421cbbb59756239908d7c9a99075a3bb.zip"
 EMBARC_MLI_MD5 := "db0910cf0e07e43f74ae7a31de485d56"
+
+XTENSA_HIFI4_URL :="https://github.com/foss-xtensa/nnlib-hifi4/blob/master/archive/xa_nnlib.zip"
+XTENSA_HIFI4_MD5 :="a517b653a75b96d0271e1b99ee2a8c14"
+