diff --git a/tensorflow/lite/micro/kernels/xtensa/conv.cc b/tensorflow/lite/micro/kernels/xtensa/conv.cc
new file mode 100644
index 00000000000..de9820b82d9
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/xtensa/conv.cc
@@ -0,0 +1,456 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/conv.h"
+
+#include <xtensa/tie/xt_hifi2.h>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/padding.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
+
+namespace tflite {
+namespace {
+
+constexpr int kInputTensor = 0;
+constexpr int kFilterTensor = 1;
+constexpr int kBiasTensor = 2;
+constexpr int kOutputTensor = 0;
+
+// Conv is quantized along dimension 0:
+// https://www.tensorflow.org/lite/performance/quantization_spec
+constexpr int kConvQuantizedDimension = 0;
+
+struct OpData {
+  TfLitePaddingValues padding;
+  // The scaling factor from input to output (aka the 'real multiplier') can
+  // be represented as a fixed point multiplier plus a left shift.
+  int32_t output_multiplier;
+  int output_shift;
+
+  // Cached tensor zero point values for quantized operations.
+  int32_t input_zero_point;
+  int32_t output_zero_point;
+
+  // Per channel output multiplier and shift.
+  int32_t* per_channel_output_multiplier;
+  int32_t* per_channel_output_shift;
+
+  // The range of the fused activation layer. For example for kNone and
+  // uint8_t these would be 0 and 255.
+  int32_t output_activation_min;
+  int32_t output_activation_max;
+};
+
+void ConvPerChannel(const ConvParams& params, const int32_t* output_multiplier,
+                    const int32_t* output_shift,
+                    const RuntimeShape& input_shape, const int8_t* input_data,
+                    const RuntimeShape& filter_shape,
+                    const int8_t* filter_data, const RuntimeShape& bias_shape,
+                    const int32_t* bias_data, const RuntimeShape& output_shape,
+                    int8_t* output_data) {
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int dilation_width_factor = params.dilation_width_factor;
+  const int dilation_height_factor = params.dilation_height_factor;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  const int32_t input_offset = params.input_offset;
+  const int32_t output_offset = params.output_offset;
+  const int32_t output_activation_min = params.quantized_activation_min;
+  const int32_t output_activation_max = params.quantized_activation_max;
+
+  const int batches = input_shape.Dims(0);
+
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int input_depth = input_shape.Dims(3);
+
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int filter_depth = filter_shape.Dims(3);
+
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  const int output_depth = output_shape.Dims(3);
+
+  ae_p24x2s input_offset_24x2 = AE_MOVPA24(input_offset);
+  ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset);
+  ae_q56s output_activation_min_56 = AE_CVTQ48A32S(output_activation_min);
+  ae_q56s output_activation_max_56 = AE_CVTQ48A32S(output_activation_max);
+
+  for (int batch = 0; batch < batches; ++batch) {
+    for (int out_y = 0; out_y < output_height; ++out_y) {
+      const int in_y_origin = (out_y * stride_height) - pad_height;
+      for (int out_x = 0; out_x < output_width; ++out_x) {
+        const int in_x_origin = (out_x * stride_width) - pad_width;
+        for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
+          ae_q56s acc_56 = AE_ZEROQ56();
+
+          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+              const int in_x = in_x_origin + dilation_width_factor * filter_x;
+              const int in_y = in_y_origin + dilation_height_factor * filter_y;
+              const bool is_point_inside_image =
+                  (in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
+                  (in_y < input_height);
+              if (is_point_inside_image) {
+                // Find current input index, minus 2 for Xtensa load
+                // alignments:
+                // TODO(b/147322595): Consider doing these offset calculations
+                // with intrinsics:
+                int input_idx =
+                    ((batch * input_height + in_y) * input_width + in_x) *
+                        input_depth -
+                    2;
+                const int8_t* input_vals_offset_ptr = input_data + input_idx;
+                for (int i = 0; i < input_depth; i += 2) {
+                  // Load signed 2x 8bit values and right shift into 24bit
+                  // alignment:
+                  ae_p24x2s input_vals_24x2;
+                  AE_LP8X2F_IU(input_vals_24x2, input_vals_offset_ptr, 2);
+                  input_vals_24x2 = AE_P24X2S_SRAI(input_vals_24x2, 16);
+
+                  // Add input offset (24bit aligned):
+                  input_vals_24x2 =
+                      AE_P24S_ADDS_P24X2S(input_vals_24x2, input_offset_24x2);
+
+                  // Find current filter index, minus 2 for Xtensa load
+                  // alignments:
+                  int filter_idx =
+                      ((out_channel * filter_height + filter_y) * filter_width +
+                       filter_x) *
+                          filter_depth +
+                      i - 2;
+                  const int8_t* filter_vals_offset_ptr =
+                      filter_data + filter_idx;
+
+                  // Load signed 2x 8bit values and right shift into 24bit
+                  // alignment:
+                  ae_p24x2s filter_vals_24x2;
+                  AE_LP8X2F_IU(filter_vals_24x2, filter_vals_offset_ptr, 2);
+                  filter_vals_24x2 = AE_P24X2S_SRAI(filter_vals_24x2, 16);
+
+                  // Multiply and accumulate into 48 bit space:
+                  AE_MULAAP24S_HH_LL(acc_56, filter_vals_24x2,
+                                     input_vals_24x2);
+                }
+              }
+            }
+          }
+
+          // Left shift from 48bit alignment to 32bit:
+          acc_56 = AE_Q56S_SLAI(acc_56, 16);
+
+          if (bias_data) {
+            // Load and add bias at 32bit alignment:
+            ae_q56s bias_56 = AE_CVTQ48A32S(bias_data[out_channel]);
+            acc_56 = AE_ADDQ56(acc_56, bias_56);
+          }
+
+          // Shift from 32bit alignment to 24bit alignment and place back on
+          // the PR register:
+          acc_56 = AE_Q56S_SLAI(acc_56, 8);
+          ae_p24x2s acc_24x2 = AE_TRUNCP24Q48(acc_56);
+
+          // Apply quantized multiplier and accumulate result at 48bit
+          // alignment. Convert the (unsigned) 32-bit multiplier down to a
+          // 24-bit multiplier.
+          acc_56 = MultiplyByQuantizedMultiplier(
+              acc_24x2, output_multiplier[out_channel] >> 8,
+              output_shift[out_channel]);
+
+          // Add output offset, cap activation, and assign to the output:
+          acc_56 = AE_ADDQ56(acc_56, output_offset_56);
+          acc_56 = AE_MINQ56S(acc_56, output_activation_max_56);
+          acc_56 = AE_MAXQ56S(acc_56, output_activation_min_56);
+
+          int output_idx =
+              ((batch * output_height + out_y) * output_width + out_x) *
+                  output_depth +
+              out_channel;
+          output_data[output_idx] =
+              static_cast<int8_t>(AE_TRUNCA32Q48(acc_56));
+        }
+      }
+    }
+  }
+}
+
+// TODO(b/154240772): Move shared code into common methods.
+inline void Conv1x32Input32x32Filter(
+    const int input_offset, const int output_offset,
+    const int quantized_activation_min, const int quantized_activation_max,
+    const int32_t* output_multiplier, const int32_t* output_shift,
+    const RuntimeShape& input_shape, const int8_t* input_data,
+    const RuntimeShape& filter_shape, const int8_t* filter_data,
+    const RuntimeShape& bias_shape, const int32_t* bias_data,
+    const RuntimeShape& output_shape, int8_t* output_data) {
+  ae_p24x2s input_offset_24x2 = AE_MOVPA24(input_offset);
+  ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset);
+  ae_q56s output_activation_max_56 = AE_CVTQ48A32S(quantized_activation_max);
+  ae_q56s output_activation_min_56 = AE_CVTQ48A32S(quantized_activation_min);
+
+  constexpr int kChannels = 32;
+  constexpr int kFilterDepth = 32;
+  for (int ch = 0; ch < kChannels; ch++) {
+    ae_q56s acc_56 = AE_ZEROQ56();
+    const int8_t* input_vals_ptr = input_data - 2;
+    for (int i = 0; i < kFilterDepth; i += 2) {
+      // Load signed 2x 8bit values and right shift into 24bit
+      // alignment:
+      ae_p24x2s input_vals_24x2;
+      AE_LP8X2F_IU(input_vals_24x2, input_vals_ptr, 2);
+      input_vals_24x2 = AE_P24X2S_SRAI(input_vals_24x2, 16);
+
+      // Add input offset (24bit aligned):
+      input_vals_24x2 =
+          AE_P24S_ADDS_P24X2S(input_vals_24x2, input_offset_24x2);
+      // Find current filter index, minus 2 for Xtensa load
+      // alignments:
+      const int filter_idx = ch * kFilterDepth + i - 2;
+      const int8_t* filter_vals_offset_ptr = filter_data + filter_idx;
+
+      // Load signed 2x 8bit values and right shift into 24bit
+      // alignment:
+      ae_p24x2s filter_vals_24x2;
+      AE_LP8X2F_IU(filter_vals_24x2, filter_vals_offset_ptr, 2);
+      filter_vals_24x2 = AE_P24X2S_SRAI(filter_vals_24x2, 16);
+
+      // Multiply and accumulate into 48 bit space:
+      AE_MULAAP24S_HH_LL(acc_56, filter_vals_24x2, input_vals_24x2);
+    }
+    // Left shift from 48bit alignment to 32bit:
+    acc_56 = AE_Q56S_SLAI(acc_56, 16);
+    if (bias_data) {
+      // Load and add bias at 32bit alignment:
+      ae_q56s bias_56 = AE_CVTQ48A32S(bias_data[ch]);
+      acc_56 = AE_ADDQ56(acc_56, bias_56);
+    }
+
+    // Shift from 32bit alignment to 24bit alignment and place back on
+    // the PR register:
+    acc_56 = AE_Q56S_SLAI(acc_56, 8);
+    ae_p24x2s acc_24x2 = AE_TRUNCP24Q48(acc_56);
+
+    // Apply quantized multiplier and accumulate result at 48bit alignment.
+    // Convert the (unsigned) 32-bit multiplier down to a 24-bit multiplier.
+    acc_56 = MultiplyByQuantizedMultiplier(acc_24x2,
+                                           output_multiplier[ch] >> 8,
+                                           output_shift[ch]);
+
+    // Add output offset, cap activation, and assign to the output:
+    acc_56 = AE_ADDQ56(acc_56, output_offset_56);
+    acc_56 = AE_MINQ56S(acc_56, output_activation_max_56);
+    acc_56 = AE_MAXQ56S(acc_56, output_activation_min_56);
+
+    output_data[ch] = static_cast<int8_t>(AE_TRUNCA32Q48(acc_56));
+  }
+}
+
+TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
+                             TfLiteConvParams* params, int width, int height,
+                             int filter_width, int filter_height, int out_width,
+                             int out_height, const TfLiteType data_type,
+                             OpData* data) {
+  bool has_bias = node->inputs->size == 3;
+  // Check number of inputs/outputs
+  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+  // Matching GetWindowedOutputSize in TensorFlow.
+  auto padding = params->padding;
+  data->padding = ComputePaddingHeightWidth(
+      params->stride_height, params->stride_width,
+      params->dilation_height_factor, params->dilation_width_factor, height,
+      width, filter_height, filter_width, padding, &out_height, &out_width);
+
+  // Note that quantized inference requires that all tensors have their
+  // parameters set. This is usually done during quantized training.
+  if (data_type != kTfLiteFloat32) {
+    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+    const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+    const TfLiteTensor* bias =
+        GetOptionalInputTensor(context, node, kBiasTensor);
+    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+    int output_channels = filter->dims->data[kConvQuantizedDimension];
+
+    return tflite::PopulateConvolutionQuantizationParams(
+        context, input, filter, bias, output, params->activation,
+        &data->output_multiplier, &data->output_shift,
+        &data->output_activation_min, &data->output_activation_max,
+        data->per_channel_output_multiplier,
+        reinterpret_cast<int*>(data->per_channel_output_shift),
+        output_channels);
+  }
+  return kTfLiteOk;
+}
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(OpData));
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  TFLITE_DCHECK(node->builtin_data != nullptr);
+  auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
+
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+
+  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+  int input_width = input->dims->data[2];
+  int input_height = input->dims->data[1];
+  int filter_width = filter->dims->data[2];
+  int filter_height = filter->dims->data[1];
+  int output_width = output->dims->data[2];
+  int output_height = output->dims->data[1];
+
+  // Per channel quantization is only needed for int8_t inference. For other
+  // quantized types, only a single scale and zero point is needed.
+  const int num_channels = filter->dims->data[kConvQuantizedDimension];
+  // Dynamically allocate per-channel quantization parameters.
+  op_data->per_channel_output_multiplier =
+      reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
+          context, num_channels * sizeof(int32_t)));
+  op_data->per_channel_output_shift =
+      reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
+          context, num_channels * sizeof(int32_t)));
+  op_data->input_zero_point = input->params.zero_point;
+  op_data->output_zero_point = output->params.zero_point;
+  // All per-channel quantized tensors need valid zero point and scale arrays.
+  if (input->type == kTfLiteInt8) {
+    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
+                      kTfLiteAffineQuantization);
+
+    const auto* affine_quantization =
+        reinterpret_cast<TfLiteAffineQuantization*>(
+            filter->quantization.params);
+    TF_LITE_ENSURE(context, affine_quantization);
+    TF_LITE_ENSURE(context, affine_quantization->scale);
+    TF_LITE_ENSURE(context, affine_quantization->zero_point);
+
+    TF_LITE_ENSURE(context,
+                   affine_quantization->scale->size == 1 ||
+                       affine_quantization->scale->size ==
+                           filter->dims->data[kConvQuantizedDimension]);
+    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
+                      affine_quantization->zero_point->size);
+  }
+
+  return CalculateOpData(context, node, params, input_width, input_height,
+                         filter_width, filter_height, output_width,
+                         output_height, input->type, op_data);
+}
+
+void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
+                             TfLiteConvParams* params, OpData* data,
+                             const TfLiteEvalTensor* input,
+                             const TfLiteEvalTensor* filter,
+                             const TfLiteEvalTensor* bias,
+                             TfLiteEvalTensor* output,
+                             TfLiteEvalTensor* im2col) {
+  // TODO(b/154032858): Investigate removing extra copies.
+  ConvParams op_params;
+  op_params.input_offset = -data->input_zero_point;
+  op_params.output_offset = data->output_zero_point;
+  op_params.stride_height = params->stride_height;
+  op_params.stride_width = params->stride_width;
+  op_params.dilation_height_factor = params->dilation_height_factor;
+  op_params.dilation_width_factor = params->dilation_width_factor;
+  op_params.padding_values.height = data->padding.height;
+  op_params.padding_values.width = data->padding.width;
+  op_params.quantized_activation_min = data->output_activation_min;
+  op_params.quantized_activation_max = data->output_activation_max;
+
+  ConvPerChannel(op_params, data->per_channel_output_multiplier,
+                 data->per_channel_output_shift,
+                 tflite::micro::GetTensorShape(input),
+                 tflite::micro::GetTensorData<int8_t>(input),
+                 tflite::micro::GetTensorShape(filter),
+                 tflite::micro::GetTensorData<int8_t>(filter),
+                 tflite::micro::GetTensorShape(bias),
+                 tflite::micro::GetTensorData<int32_t>(bias),
+                 tflite::micro::GetTensorShape(output),
+                 tflite::micro::GetTensorData<int8_t>(output));
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  TFLITE_DCHECK(node->builtin_data != nullptr);
+  auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
+  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+  const TfLiteEvalTensor* input =
+      tflite::micro::GetEvalInput(context, node, kInputTensor);
+  const TfLiteEvalTensor* filter =
+      tflite::micro::GetEvalInput(context, node, kFilterTensor);
+  const TfLiteEvalTensor* bias =
+      (NumInputs(node) == 3)
+          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
+          : nullptr;
+
+  int* input_dims = input->dims->data;
+  int* filter_dims = filter->dims->data;
+  if (input_dims[0] == 1 && input_dims[1] == 1 && input_dims[2] == 1 &&
+      input_dims[3] == 32 && filter_dims[0] == 32 && filter_dims[1] == 1 &&
+      filter_dims[2] == 1 && filter_dims[3] == 32) {
+    Conv1x32Input32x32Filter(
+        -op_data->input_zero_point, op_data->output_zero_point,
+        op_data->output_activation_min, op_data->output_activation_max,
+        op_data->per_channel_output_multiplier,
+        op_data->per_channel_output_shift,
+        tflite::micro::GetTensorShape(input),
+        tflite::micro::GetTensorData<int8_t>(input),
+        tflite::micro::GetTensorShape(filter),
+        tflite::micro::GetTensorData<int8_t>(filter),
+        tflite::micro::GetTensorShape(bias),
+        tflite::micro::GetTensorData<int32_t>(bias),
+        tflite::micro::GetTensorShape(output),
+        tflite::micro::GetTensorData<int8_t>(output));
+    return kTfLiteOk;
+  }
+
+  switch (input->type) {
+    case kTfLiteInt8:
+      EvalQuantizedPerChannel(context, node, params, op_data, input, filter,
+                              bias, output, nullptr);
+      break;
+    default:
+      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
+                         TfLiteTypeGetName(input->type), input->type);
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+}  // namespace
+
+TfLiteRegistration Register_CONV_2D() {
+  return {/*init=*/Init,
+          /*free=*/nullptr,
+          /*prepare=*/Prepare,
+          /*invoke=*/Eval,
+          /*profiling_string=*/nullptr,
+          /*builtin_code=*/0,
+          /*custom_name=*/nullptr,
+          /*version=*/0};
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc
new file mode 100644
index 00000000000..12410a94456
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc
@@ -0,0 +1,503 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <xtensa/tie/xt_hifi2.h>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
+#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/padding.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
+
+namespace tflite {
+namespace {
+
+constexpr int kInputTensor = 0;
+constexpr int kFilterTensor = 1;
+constexpr int kBiasTensor = 2;
+constexpr int kOutputTensor = 0;
+
+// Depthwise conv is quantized along dimension 3:
+// https://www.tensorflow.org/lite/performance/quantization_spec
+constexpr int kDepthwiseConvQuantizedDimension = 3;
+
+struct OpData {
+  TfLitePaddingValues padding;
+  // The scaling factor from input to output (aka the 'real multiplier') can
+  // be represented as a fixed point multiplier plus a left shift.
+  int32_t output_multiplier;
+  int output_shift;
+
+  // Cached tensor zero point values for quantized operations.
+  int32_t input_zero_point;
+  int32_t output_zero_point;
+
+  // Per channel output multiplier and shift.
+  // TODO(b/141139247): Allocate these dynamically when possible.
+  int32_t* per_channel_output_multiplier;
+  int32_t* per_channel_output_shift;
+
+  // The range of the fused activation layer. For example for kNone and
+  // uint8_t these would be 0 and 255.
+  int32_t output_activation_min;
+  int32_t output_activation_max;
+};
+
+inline void DepthwiseConvPerChannel(
+    const DepthwiseParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const RuntimeShape& input_shape,
+    const int8_t* input_data, const RuntimeShape& filter_shape,
+    const int8_t* filter_data, const RuntimeShape& bias_shape,
+    const int32_t* bias_data, const RuntimeShape& output_shape,
+    int8_t* output_data) {
+  // TODO(b/154032858): Investigate removing extra copies.
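+  // Note: this path does one multiply-accumulate per (input, filter) pair on
+  // the HH lane only; with an arbitrary depth_multiplier the data layout does
+  // not guarantee the adjacent 8-bit pairs that the dual-MAC loads in conv.cc
+  // rely on.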
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int dilation_width_factor = params.dilation_width_factor;
+  const int dilation_height_factor = params.dilation_height_factor;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  const int depth_multiplier = params.depth_multiplier;
+  const int32_t input_offset = params.input_offset;
+  const int32_t output_offset = params.output_offset;
+  const int32_t output_activation_min = params.quantized_activation_min;
+  const int32_t output_activation_max = params.quantized_activation_max;
+
+  const int batches = input_shape.Dims(0);
+
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int input_depth = input_shape.Dims(3);
+
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int filter_depth = filter_shape.Dims(3);
+
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  const int output_depth = output_shape.Dims(3);
+
+  ae_p24x2s input_offset_24x2 = AE_MOVPA24(input_offset);
+  ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset);
+  ae_q56s output_activation_min_56 = AE_CVTQ48A32S(output_activation_min);
+  ae_q56s output_activation_max_56 = AE_CVTQ48A32S(output_activation_max);
+
+  for (int batch = 0; batch < batches; ++batch) {
+    for (int out_y = 0; out_y < output_height; ++out_y) {
+      const int in_y_origin = (out_y * stride_height) - pad_height;
+      for (int out_x = 0; out_x < output_width; ++out_x) {
+        const int in_x_origin = (out_x * stride_width) - pad_width;
+        for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+          for (int m = 0; m < depth_multiplier; ++m) {
+            const int output_channel = m + in_channel * depth_multiplier;
+            ae_q56s acc_56 = AE_ZEROQ56();
+            for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+              const int in_y = in_y_origin + dilation_height_factor * filter_y;
+              for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+                const int in_x =
+                    in_x_origin + dilation_width_factor * filter_x;
+                // Zero padding by omitting the areas outside the image.
+                const bool is_point_inside_image =
+                    (in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
+                    (in_y < input_height);
+
+                if (is_point_inside_image) {
+                  // Find the current input index:
+                  // TODO(b/147322595): Consider doing these offset
+                  // calculations with intrinsics:
+                  int input_idx =
+                      ((batch * input_height + in_y) * input_width + in_x) *
+                          input_depth +
+                      in_channel;
+                  int32_t input_val = input_data[input_idx];
+
+                  // Find the current filter index:
+                  int filter_idx =
+                      (filter_y * filter_width + filter_x) * filter_depth +
+                      output_channel;
+                  int32_t filter_val = filter_data[filter_idx];
+
+                  // Load the 8bit value as an int32_t into a 24x24 register.
+                  // Note: the value is duplicated in the HH and LL halves,
+                  // but all calculations below are done on the HH side.
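+                  // For example, input_val = -5 with input_offset = 128
+                  // becomes 123 in both 24-bit halves after the saturating
+                  // add below.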
+                  ae_p24x2s input_val_24x2 = AE_MOVPA24(input_val);
+
+                  // Add input offset (24bit aligned):
+                  input_val_24x2 =
+                      AE_P24S_ADDS_P24X2S(input_val_24x2, input_offset_24x2);
+
+                  // Load filter 8bit value into 24bit alignment:
+                  ae_p24x2s filter_val_24x2 = AE_MOVPA24(filter_val);
+
+                  // Multiply and accumulate the HH side of each 24x24 PR
+                  // register:
+                  AE_MULAS56P24S_HH(acc_56, filter_val_24x2, input_val_24x2);
+                }
+              }
+            }
+
+            // Left shift from 48bit alignment to 32bit:
+            acc_56 = AE_Q56S_SLAI(acc_56, 16);
+
+            if (bias_data) {
+              // Load and add bias at 32bit alignment:
+              ae_q56s bias_56 = AE_CVTQ48A32S(bias_data[output_channel]);
+              acc_56 = AE_ADDQ56(acc_56, bias_56);
+            }
+
+            // Shift from 32bit alignment to 24bit alignment and place back on
+            // the PR register:
+            acc_56 = AE_Q56S_SLAI(acc_56, 8);
+            ae_p24x2s acc_24x2 = AE_TRUNCP24Q48(acc_56);
+
+            // Apply quantized multiplier and accumulate result at 48bit
+            // alignment:
+            acc_56 = MultiplyByQuantizedMultiplier(
+                acc_24x2, output_multiplier[output_channel],
+                output_shift[output_channel]);
+
+            // Add output offset, cap activation, and assign to the output:
+            acc_56 = AE_ADDQ56(acc_56, output_offset_56);
+            acc_56 = AE_MINQ56S(acc_56, output_activation_max_56);
+            acc_56 = AE_MAXQ56S(acc_56, output_activation_min_56);
+
+            int output_idx =
+                ((batch * output_height + out_y) * output_width + out_x) *
+                    output_depth +
+                output_channel;
+            output_data[output_idx] =
+                static_cast<int8_t>(AE_TRUNCA32Q48(acc_56));
+          }
+        }
+      }
+    }
+  }
+}
+
+constexpr int kConvolutionalKernelWidth = 4;
+constexpr int kConvolutionalKernelDepth = 32;
+inline void DepthwiseConv4x32MatchingInputAndFilter(
+    const int input_offset, const int output_offset,
+    const int quantized_activation_min, const int quantized_activation_max,
+    const int32_t* output_multiplier, const int32_t* output_shift,
+    const RuntimeShape& input_shape, const int8_t* input_data,
+    const RuntimeShape& filter_shape, const int8_t* filter_data,
+    const RuntimeShape& bias_shape, const int32_t* bias_data,
+    const RuntimeShape& output_shape, int8_t* output_data) {
+  // Convert the (unsigned) 32-bit multiplier down to a 24-bit multiplier.
+  const int32_t mult = output_multiplier[0] >> 8;
+  const int32_t shift = output_shift[0];
+  ae_p24x2s input_offset_24x2 = AE_MOVPA24(input_offset);
+  ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset);
+  ae_q56s output_activation_min_56 = AE_CVTQ48A32S(quantized_activation_min);
+  ae_q56s output_activation_max_56 = AE_CVTQ48A32S(quantized_activation_max);
+
+  const int num_blocks =
+      kConvolutionalKernelDepth / 2;  // Based on the 24x2 register size.
+  const int stride_elements =
+      (kConvolutionalKernelDepth / kConvolutionalKernelWidth);
+
+  const int8_t* input_0_ptr = (const int8_t*)(input_data - 2);
+  const int8_t* weight_0_ptr = (const int8_t*)(filter_data - 2);
+  // Apply the kernels in blocks of 4 for all the channels.
+  const int8_t* input_1_ptr = input_0_ptr + stride_elements * 4;
+  const int8_t* input_2_ptr = input_1_ptr + stride_elements * 4;
+  const int8_t* input_3_ptr = input_2_ptr + stride_elements * 4;
+
+  const int8_t* weight_1_ptr = weight_0_ptr + stride_elements * 4;
+  const int8_t* weight_2_ptr = weight_1_ptr + stride_elements * 4;
+  const int8_t* weight_3_ptr = weight_2_ptr + stride_elements * 4;
+
+  for (int i = 0; i < num_blocks; ++i) {
+    ae_q56s block_0_acc = AE_ZEROQ56();
+    ae_q56s block_1_acc = AE_ZEROQ56();
+
+    // Load all the weights.
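+    // Each AE_LP8X2F_IU below advances its pointer by 2 before loading a
+    // pair of 8-bit values into the upper 8 bits of each 24-bit half. The
+    // weights stay left-aligned while the inputs are shifted down to 8-bit
+    // alignment, so each product accumulates at 32-bit alignment.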
+    ae_p24x2s weight_0, weight_1, weight_2, weight_3;
+    AE_LP8X2F_IU(weight_0, weight_0_ptr, 2);
+    AE_LP8X2F_IU(weight_1, weight_1_ptr, 2);
+    AE_LP8X2F_IU(weight_2, weight_2_ptr, 2);
+    AE_LP8X2F_IU(weight_3, weight_3_ptr, 2);
+
+    // Load all the inputs.
+    ae_p24x2s input_0, input_1, input_2, input_3;
+    AE_LP8X2F_IU(input_0, input_0_ptr, 2);
+    AE_LP8X2F_IU(input_1, input_1_ptr, 2);
+    AE_LP8X2F_IU(input_2, input_2_ptr, 2);
+    AE_LP8X2F_IU(input_3, input_3_ptr, 2);
+
+    // Shift inputs to 8 bit alignment and add offsets.
+    input_0 = AE_P24X2S_SRAI(input_0, 16);
+    input_1 = AE_P24X2S_SRAI(input_1, 16);
+    input_2 = AE_P24X2S_SRAI(input_2, 16);
+    input_3 = AE_P24X2S_SRAI(input_3, 16);
+
+    input_0 = AE_P24S_ADDS_P24X2S(input_0, input_offset_24x2);
+    input_1 = AE_P24S_ADDS_P24X2S(input_1, input_offset_24x2);
+    input_2 = AE_P24S_ADDS_P24X2S(input_2, input_offset_24x2);
+    input_3 = AE_P24S_ADDS_P24X2S(input_3, input_offset_24x2);
+
+    // Do the multiplies across all channels. Resulting accumulators are 32bit
+    // aligned (24 bit aligned weights * 8 bit aligned inputs).
+    AE_MULAS56P24S_HH(block_0_acc, input_0, weight_0);
+    AE_MULAS56P24S_HH(block_0_acc, input_1, weight_1);
+    AE_MULAS56P24S_HH(block_0_acc, input_2, weight_2);
+    AE_MULAS56P24S_HH(block_0_acc, input_3, weight_3);
+
+    AE_MULAS56P24S_LL(block_1_acc, input_0, weight_0);
+    AE_MULAS56P24S_LL(block_1_acc, input_1, weight_1);
+    AE_MULAS56P24S_LL(block_1_acc, input_2, weight_2);
+    AE_MULAS56P24S_LL(block_1_acc, input_3, weight_3);
+
+    int ch_0 = i * 2;
+    int ch_1 = i * 2 + 1;
+
+    // Load and add bias at 32bit alignment:
+    ae_q56s bias_56_0 = AE_CVTQ48A32S(bias_data[ch_0]);
+    ae_q56s bias_56_1 = AE_CVTQ48A32S(bias_data[ch_1]);
+    block_0_acc = AE_ADDQ56(block_0_acc, bias_56_0);
+    block_1_acc = AE_ADDQ56(block_1_acc, bias_56_1);
+
+    // Shift from 32bit alignment to 24bit alignment and place back on
+    // the PR register:
+    block_0_acc = AE_Q56S_SLAI(block_0_acc, 8);
+    block_1_acc = AE_Q56S_SLAI(block_1_acc, 8);
+    ae_p24x2s acc_24x2_0 = AE_TRUNCP24Q48(block_0_acc);
+    ae_p24x2s acc_24x2_1 = AE_TRUNCP24Q48(block_1_acc);
+
+    // Apply the quantized multiplier to each block; results land at 48bit
+    // alignment:
+    block_0_acc = MultiplyByQuantizedMultiplier(acc_24x2_0, mult, shift);
+    block_1_acc = MultiplyByQuantizedMultiplier(acc_24x2_1, mult, shift);
+
+    // Add output offset, cap activation, and assign to the output:
+    block_0_acc = AE_ADDQ56(block_0_acc, output_offset_56);
+    block_1_acc = AE_ADDQ56(block_1_acc, output_offset_56);
+    block_0_acc = AE_MINQ56S(block_0_acc, output_activation_max_56);
+    block_1_acc = AE_MINQ56S(block_1_acc, output_activation_max_56);
+    block_0_acc = AE_MAXQ56S(block_0_acc, output_activation_min_56);
+    block_1_acc = AE_MAXQ56S(block_1_acc, output_activation_min_56);
+
+    output_data[ch_0] = static_cast<int8_t>(AE_TRUNCA32Q48(block_0_acc));
+    output_data[ch_1] = static_cast<int8_t>(AE_TRUNCA32Q48(block_1_acc));
+  }
+}
+
+TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
+                             TfLiteDepthwiseConvParams* params, int width,
+                             int height, int filter_width, int filter_height,
+                             const TfLiteType data_type, OpData* data) {
+  bool has_bias = node->inputs->size == 3;
+  // Check number of inputs/outputs
+  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+  int unused_output_height, unused_output_width;
+  data->padding = ComputePaddingHeightWidth(
+      params->stride_height, params->stride_width, 1, 1, height, width,
+      filter_height, filter_width, params->padding, &unused_output_height,
+      &unused_output_width);
+
+  // Note that quantized inference requires that all tensors have their
+  // parameters set. This is usually done during quantized training.
+  if (data_type != kTfLiteFloat32) {
+    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+    const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+    const TfLiteTensor* bias =
+        GetOptionalInputTensor(context, node, kBiasTensor);
+    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+    int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
+
+    // TODO(b/148610881): Consider calculating quantized params at int24
+    // calculations:
+    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
+        context, input, filter, bias, output, params->activation,
+        &data->output_multiplier, &data->output_shift,
+        &data->output_activation_min, &data->output_activation_max,
+        data->per_channel_output_multiplier,
+        reinterpret_cast<int*>(data->per_channel_output_shift),
+        num_channels));
+  }
+  return kTfLiteOk;
+}
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(OpData));
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  TFLITE_DCHECK(node->builtin_data != nullptr);
+  auto* params =
+      reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
+
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+  const TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+  const TfLiteType data_type = input->type;
+  int width = SizeOfDimension(input, 2);
+  int height = SizeOfDimension(input, 1);
+  int filter_width = SizeOfDimension(filter, 2);
+  int filter_height = SizeOfDimension(filter, 1);
+
+  // Per channel quantization is only needed for int8_t inference. For other
+  // quantized types, only a single scale and zero point is needed.
+  const int num_channels =
+      filter->dims->data[kDepthwiseConvQuantizedDimension];
+  // Dynamically allocate per-channel quantization parameters.
+  op_data->per_channel_output_multiplier =
+      reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
+          context, num_channels * sizeof(int32_t)));
+  op_data->per_channel_output_shift =
+      reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
+          context, num_channels * sizeof(int32_t)));
+
+  op_data->input_zero_point = input->params.zero_point;
+  op_data->output_zero_point = output->params.zero_point;
+
+  // All per-channel quantized tensors need valid zero point and scale arrays.
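+  // Per the quantization spec, int8_t filters carry either one scale for the
+  // whole tensor or one scale per channel along dimension 3; the checks below
+  // enforce the matching array sizes.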
+  if (input->type == kTfLiteInt8) {
+    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
+                      kTfLiteAffineQuantization);
+
+    const auto* affine_quantization =
+        reinterpret_cast<TfLiteAffineQuantization*>(
+            filter->quantization.params);
+    TF_LITE_ENSURE(context, affine_quantization);
+    TF_LITE_ENSURE(context, affine_quantization->scale);
+    TF_LITE_ENSURE(context, affine_quantization->zero_point);
+    TF_LITE_ENSURE(
+        context, affine_quantization->scale->size == 1 ||
+                     affine_quantization->scale->size ==
+                         filter->dims->data[kDepthwiseConvQuantizedDimension]);
+    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
+                      affine_quantization->zero_point->size);
+  }
+
+  return CalculateOpData(context, node, params, width, height, filter_width,
+                         filter_height, data_type, op_data);
+}
+
+void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
+                             TfLiteDepthwiseConvParams* params, OpData* data,
+                             const TfLiteEvalTensor* input,
+                             const TfLiteEvalTensor* filter,
+                             const TfLiteEvalTensor* bias,
+                             TfLiteEvalTensor* output) {
+  DepthwiseParams op_params;
+  op_params.padding_type = PaddingType::kSame;
+  op_params.padding_values.width = data->padding.width;
+  op_params.padding_values.height = data->padding.height;
+  op_params.stride_width = params->stride_width;
+  op_params.stride_height = params->stride_height;
+  op_params.dilation_width_factor = params->dilation_width_factor;
+  op_params.dilation_height_factor = params->dilation_height_factor;
+  op_params.depth_multiplier = params->depth_multiplier;
+  op_params.input_offset = -data->input_zero_point;
+  op_params.weights_offset = 0;
+  op_params.output_offset = data->output_zero_point;
+  // TODO(b/130439627): Use calculated value for clamping.
+  op_params.quantized_activation_min = std::numeric_limits<int8_t>::min();
+  op_params.quantized_activation_max = std::numeric_limits<int8_t>::max();
+
+  DepthwiseConvPerChannel(op_params, data->per_channel_output_multiplier,
+                          data->per_channel_output_shift,
+                          tflite::micro::GetTensorShape(input),
+                          tflite::micro::GetTensorData<int8_t>(input),
+                          tflite::micro::GetTensorShape(filter),
+                          tflite::micro::GetTensorData<int8_t>(filter),
+                          tflite::micro::GetTensorShape(bias),
+                          tflite::micro::GetTensorData<int32_t>(bias),
+                          tflite::micro::GetTensorShape(output),
+                          tflite::micro::GetTensorData<int8_t>(output));
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  TFLITE_DCHECK(node->builtin_data != nullptr);
+  auto* params =
+      reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
+  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+  const TfLiteEvalTensor* input =
+      tflite::micro::GetEvalInput(context, node, kInputTensor);
+  const TfLiteEvalTensor* filter =
+      tflite::micro::GetEvalInput(context, node, kFilterTensor);
+  const TfLiteEvalTensor* bias =
+      (NumInputs(node) == 3)
+          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
+          : nullptr;
+
+  // Handle special case for streaming model.
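+  // A 1x4x1x32 input convolved with a 1x4x1x32 filter touches each input
+  // element exactly once, which is what lets
+  // DepthwiseConv4x32MatchingInputAndFilter above skip the generic padding
+  // and indexing logic.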
+  int* input_dims = input->dims->data;
+  int* filter_dims = filter->dims->data;
+  if (input_dims[0] == 1 && input_dims[1] == 4 && input_dims[2] == 1 &&
+      input_dims[3] == 32 && filter_dims[0] == 1 && filter_dims[1] == 4 &&
+      filter_dims[2] == 1 && filter_dims[3] == 32) {
+    DepthwiseConv4x32MatchingInputAndFilter(
+        -op_data->input_zero_point, op_data->output_zero_point,
+        std::numeric_limits<int8_t>::min(),
+        std::numeric_limits<int8_t>::max(),
+        op_data->per_channel_output_multiplier,
+        op_data->per_channel_output_shift,
+        tflite::micro::GetTensorShape(input),
+        tflite::micro::GetTensorData<int8_t>(input),
+        tflite::micro::GetTensorShape(filter),
+        tflite::micro::GetTensorData<int8_t>(filter),
+        tflite::micro::GetTensorShape(bias),
+        tflite::micro::GetTensorData<int32_t>(bias),
+        tflite::micro::GetTensorShape(output),
+        tflite::micro::GetTensorData<int8_t>(output));
+    return kTfLiteOk;
+  }
+  switch (input->type) {  // Already know in/out types are same.
+    case kTfLiteInt8:
+      EvalQuantizedPerChannel(context, node, params, op_data, input, filter,
+                              bias, output);
+      break;
+    default:
+      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
+                         TfLiteTypeGetName(input->type), input->type);
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace
+
+TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
+  return {/*init=*/Init,
+          /*free=*/nullptr,
+          /*prepare=*/Prepare,
+          /*invoke=*/Eval,
+          /*profiling_string=*/nullptr,
+          /*builtin_code=*/0,
+          /*custom_name=*/nullptr,
+          /*version=*/0};
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h b/tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h
new file mode 100644
index 00000000000..a1d14df1352
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h
@@ -0,0 +1,137 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_HIFIMINI_FIXEDPOINT_UTILS_H_
+#define TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_HIFIMINI_FIXEDPOINT_UTILS_H_
+
+#include <xtensa/tie/xt_hifi2.h>
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+
+namespace tflite {
+
+// INT24 MIN/MAX
+#define INT24_MIN -8388608
+#define INT24_MAX 8388607
+
+// Multiplies a 24bit value by a quantized multiplier (w/ shift) and returns a
+// 48bit aligned value in the QR register.
+inline ae_q56s MultiplyByQuantizedMultiplier(ae_p24x2s x_24x2,
+                                             int32_t quantized_multiplier,
+                                             int shift) {
+  // A value with 1 sign bit, N integer bits and M fractional bits is
+  // represented as QN+1.M since the sign bit is included in the integer bits.
+  //
+  // The Q notation in this method explains the values represented in each
+  // variable, along with an implicit division since the quantized_multiplier
+  // represents a value between 0.5 and 1.0 (Q1.X-1 where X is the bit
+  // precision of the type).
+  //
+  // Load the quantized multiplier into the PR register.
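+  // (For example, QuantizeMultiplierForInt24 below turns a real multiplier
+  // of 0.75 into quantized_multiplier = round(0.75 * 2^23) = 6291456 with
+  // shift = 0.)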
+  // NOTE: This method assumes that this param has been calculated for 24bit
+  // space - not 32bits.
+  // Q32.0 / 2^23 -> Q24.0 / 2^23 representing a Q1.23 multiplier.
+  ae_p24x2s quantized_multiplier_24x2 = AE_MOVPA24(quantized_multiplier);
+  // Shift right by 23 - 16 bits minus the specified shift. This is because we
+  // keep 16 fractional bits until the end to perform rounding. Subtract shift
+  // since shift is a left shift, and the 23-16 is a right shift.
+  int shift_amount = 7 - shift;
+
+  // Find the product of x and the quantized_multiplier.
+  // Q24.0 / 2^23 * Q24.0 = Q48.0 / 2^23
+  // Q48.0 / 2^23 >> 7 = Q48.0 / 2^16
+  ae_q56s result_56 = AE_MULP24S_HH(x_24x2, quantized_multiplier_24x2);
+
+  // Shift right if shift amount is positive, left if shift amount is negative.
+  if (shift_amount >= 0) {
+    result_56 = AE_Q56S_SRA(result_56, shift_amount);
+  } else {
+    result_56 = AE_Q56S_SLA(result_56, -shift_amount);
+  }
+
+  // Round off the bottom 16 bits.
+  // Q48.0 / 2^16 -> Q32.0 aligned to 48 bits.
+  result_56 = AE_ROUNDSQ32SYM(result_56);
+  return result_56;
+}
+
+// Multiplies a 32bit value by a quantized multiplier (w/ shift) and returns a
+// 48bit aligned value in the QR register.
+inline ae_q56s MultiplyByQuantizedMultiplierResult48Bit(
+    int32_t x, int32_t quantized_multiplier, int shift) {
+  // Convert x into a 2x24bit PR register file. If x is outside the numerical
+  // limits of a 24bit integer, the "fractional" or lower 8bits are discarded.
+  // If x is within the range of a 24 bit integer, the "signed" or upper 8bits
+  // are discarded.
+  ae_p24x2s x_24x2;
+  if (x > INT24_MIN && x < INT24_MAX) {
+    x_24x2 = AE_MOVPA24(x);
+  } else {
+    x_24x2 = static_cast<ae_p24x2s>(*reinterpret_cast<ae_p24s*>(&x));
+    shift += 8;
+  }
+
+  return MultiplyByQuantizedMultiplier(x_24x2, quantized_multiplier, shift);
+}
+
+// Calculate quantization params for 24bit runtimes.
+inline void QuantizeMultiplierForInt24(float multiplier,
+                                       int32_t* quantized_multiplier,
+                                       int* shift) {
+  if (multiplier == 0.0f) {
+    *quantized_multiplier = 0;
+    *shift = 0;
+    return;
+  }
+
+  // Special cased to 24bit:
+  const float q = std::frexp(multiplier, shift);
+  auto q_fixed = static_cast<int64_t>(std::round(q * (1 << 23)));
+
+  TFLITE_CHECK(q_fixed <= (1 << 23));
+  if (q_fixed == (1 << 23)) {
+    q_fixed /= 2;
+    ++*shift;
+  }
+  TFLITE_CHECK_LE(q_fixed, INT24_MAX);
+
+  // Ensure shift does not exceed 24-bit range.
+  TFLITE_CHECK_LE(*shift, 23);
+  if (*shift < -23) {
+    *shift = 0;
+    q_fixed = 0;
+  }
+  *quantized_multiplier = static_cast<int32_t>(q_fixed);
+}
+
+// Convert a floating point number to a Q representation for 24 bit integers.
+inline int CreateQConstantForInt24(int integer_bits, float f) {
+  const float min_bounds = static_cast<float>(INT24_MIN);
+  const float max_bounds = static_cast<float>(INT24_MAX);
+
+  int fractional_bits = 23 - integer_bits;
+  float raw = std::round(f * static_cast<float>(1 << fractional_bits));
+  raw = std::max(raw, min_bounds);
+  raw = std::min(raw, max_bounds);
+  return static_cast<int>(raw);
+}
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_HIFIMINI_FIXEDPOINT_UTILS_H_
diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
new file mode 100644
index 00000000000..30a5b6a602a
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
@@ -0,0 +1,252 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
+
+#include <xtensa/tie/xt_hifi2.h>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
+
+namespace tflite {
+namespace {
+
+struct OpData {
+  // The scaling factor from input to output (aka the 'real multiplier') can
+  // be represented as a fixed point multiplier plus a left shift.
+  int32_t output_multiplier;
+  int output_shift;
+
+  // Cached tensor zero point values for quantized operations.
+  int32_t input_zero_point;
+  int32_t filter_zero_point;
+  int32_t output_zero_point;
+
+  // The range of the fused activation layer. For example for kNone and
+  // uint8_t these would be 0 and 255.
+  int32_t output_activation_min;
+  int32_t output_activation_max;
+  // The index of the temporary tensor where the quantized inputs are cached.
+  int input_quantized_index;
+};
+
+constexpr int kInputTensor = 0;
+constexpr int kWeightsTensor = 1;
+constexpr int kBiasTensor = 2;
+constexpr int kOutputTensor = 0;
+
+void FullyConnected(const FullyConnectedParams& params,
+                    const RuntimeShape& input_shape, const int8_t* input_data,
+                    const RuntimeShape& filter_shape,
+                    const int8_t* filter_data, const RuntimeShape& bias_shape,
+                    const int32_t* bias_data, const RuntimeShape& output_shape,
+                    int8_t* output_data) {
+  // TODO(b/154032858): Investigate removing extra copies.
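+  // Note: the inner loop below walks the accumulation dimension two int8_t
+  // values at a time with a dual MAC, so accum_depth is assumed to be even;
+  // an odd trailing element would be dropped by accum_depth / 2.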
+  const int32_t input_offset = params.input_offset;
+  const int32_t filter_offset = params.weights_offset;
+  const int32_t output_offset = params.output_offset;
+  const int32_t output_multiplier = params.output_multiplier;
+  const int output_shift = params.output_shift;
+  const int32_t output_activation_min = params.quantized_activation_min;
+  const int32_t output_activation_max = params.quantized_activation_max;
+
+  const int filter_dim_count = filter_shape.DimensionsCount();
+  const int batches = output_shape.Dims(0);
+  const int output_depth = output_shape.Dims(1);
+  const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
+  const int accum_depth_iters = accum_depth / 2;
+
+  ae_p24x2s offsets_input_24x2 = AE_MOVPA24(input_offset);
+  ae_p24x2s offsets_filter_24x2 = AE_MOVPA24(filter_offset);
+  ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset);
+  ae_q56s output_activation_max_56 = AE_CVTQ48A32S(output_activation_max);
+  ae_q56s output_activation_min_56 = AE_CVTQ48A32S(output_activation_min);
+
+  for (int b = 0; b < batches; ++b) {
+    for (int out_c = 0; out_c < output_depth; ++out_c) {
+      // The load intrinsics advance the pointer before loading, so back off
+      // the data pointers by two before the loop:
+      const int8_t* input_ptr = (input_data + b * accum_depth) - 2;
+      const int8_t* filter_ptr = (filter_data + out_c * accum_depth) - 2;
+
+      // Main accumulator register entry for loop:
+      ae_q56s sum_56 = AE_ZEROQ56();
+
+      for (int d = 0; d < accum_depth_iters; d++) {
+        // Load the signed 8bit values into the PR register:
+        ae_p24x2s input_24x2;
+        ae_p24x2s filter_24x2;
+        AE_LP8X2F_IU(input_24x2, input_ptr, 2);
+        AE_LP8X2F_IU(filter_24x2, filter_ptr, 2);
+
+        // Right shift the signed 8bit values to expand to signed 24bit
+        // values:
+        input_24x2 = AE_P24X2S_SRAI(input_24x2, 16);
+        filter_24x2 = AE_P24X2S_SRAI(filter_24x2, 16);
+
+        // Add offsets to data values (24 bit aligned):
+        input_24x2 = AE_P24S_ADDS_P24X2S(offsets_input_24x2, input_24x2);
+        filter_24x2 = AE_P24S_ADDS_P24X2S(offsets_filter_24x2, filter_24x2);
+
+        // 24x2 signed integer dual MAC w/ addition into 56bit accumulator (48
+        // bit aligned):
+        AE_MULAAP24S_HH_LL(sum_56, input_24x2, filter_24x2);
+      }
+
+      // Left shift to get back into 32bit space (right padded to 48bit):
+      sum_56 = AE_Q56S_SLAI(sum_56, 16);
+
+      // Add bias data if needed:
+      if (bias_data) {
+        ae_q56s bias_56 = AE_CVTQ48A32S(bias_data[out_c]);
+        sum_56 = AE_ADDQ56(sum_56, bias_56);
+      }
+
+      // Shift left into 24bit space and place back on PR register:
+      sum_56 = AE_Q56S_SLAI(sum_56, 8);
+      ae_p24x2s sum_24x2 = AE_TRUNCP24Q48(sum_56);
+
+      // MultiplyByQuantizedMultiplier returns a 48bit aligned value:
+      sum_56 = MultiplyByQuantizedMultiplier(sum_24x2, output_multiplier,
+                                             output_shift);
+
+      // Add output_offset and cap min/max values:
+      sum_56 = AE_ADDQ56(sum_56, output_offset_56);
+      sum_56 = AE_MINQ56S(sum_56, output_activation_max_56);
+      sum_56 = AE_MAXQ56S(sum_56, output_activation_min_56);
+
+      output_data[out_c + output_depth * b] =
+          static_cast<int8_t>(AE_TRUNCA32Q48(sum_56));
+    }
+  }
+}
+
+TfLiteStatus CalculateOpData(TfLiteContext* context,
+                             TfLiteFusedActivation activation,
+                             TfLiteType data_type, const TfLiteTensor* input,
+                             const TfLiteTensor* filter,
+                             const TfLiteTensor* bias, TfLiteTensor* output,
+                             OpData* data) {
+  double real_multiplier = 0.0;
+  TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
+      context, input, filter, bias, output, &real_multiplier));
+  QuantizeMultiplierForInt24(real_multiplier, &data->output_multiplier,
+                             &data->output_shift);
+  return CalculateActivationRangeQuantized(context, activation, output,
+                                           &data->output_activation_min,
+                                           &data->output_activation_max);
+}
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(OpData));
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  TFLITE_DCHECK(node->builtin_data != nullptr);
+
+  OpData* data = static_cast<OpData*>(node->user_data);
+  const auto* params =
+      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
+
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
+  const TfLiteTensor* bias =
+      GetOptionalInputTensor(context, node, kBiasTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  if (input->type != kTfLiteInt8) {
+    TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
+                       TfLiteTypeGetName(input->type), input->type);
+    return kTfLiteError;
+  }
+
+  data->input_zero_point = input->params.zero_point;
+  data->filter_zero_point = filter->params.zero_point;
+  data->output_zero_point = output->params.zero_point;
+
+  return CalculateOpData(context, params->activation, input->type, input,
+                         filter, bias, output, data);
+}
+
+TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
+                               const OpData& data,
+                               const TfLiteEvalTensor* input,
+                               const TfLiteEvalTensor* filter,
+                               const TfLiteEvalTensor* bias,
+                               TfLiteEvalTensor* output) {
+  // TODO(b/154032858): Investigate removing extra copies, and also passing by
+  // value.
+  //
+  // TODO(b/155656675): Consider passing OpData by value once it is also
+  // passed to the FullyConnected function. Until it is copied to a local
+  // op_param variable, we do not get any latency improvements from passing by
+  // value.
+  FullyConnectedParams op_params;
+  op_params.input_offset = -data.input_zero_point;
+  op_params.weights_offset = -data.filter_zero_point;
+  op_params.output_offset = data.output_zero_point;
+  op_params.output_multiplier = data.output_multiplier;
+  op_params.output_shift = data.output_shift;
+  op_params.quantized_activation_min = data.output_activation_min;
+  op_params.quantized_activation_max = data.output_activation_max;
+
+  FullyConnected(op_params, tflite::micro::GetTensorShape(input),
+                 tflite::micro::GetTensorData<int8_t>(input),
+                 tflite::micro::GetTensorShape(filter),
+                 tflite::micro::GetTensorData<int8_t>(filter),
+                 tflite::micro::GetTensorShape(bias),
+                 tflite::micro::GetTensorData<int32_t>(bias),
+                 tflite::micro::GetTensorShape(output),
+                 tflite::micro::GetTensorData<int8_t>(output));
+  return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  const OpData& data = *(static_cast<const OpData*>(node->user_data));
+
+  const TfLiteEvalTensor* input =
+      tflite::micro::GetEvalInput(context, node, kInputTensor);
+  const TfLiteEvalTensor* filter =
+      tflite::micro::GetEvalInput(context, node, kWeightsTensor);
+  const TfLiteEvalTensor* bias =
+      (NumInputs(node) == 3)
+          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
+          : nullptr;
+
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+
+  return EvalQuantizedInt8(context, node, data, input, filter, bias, output);
+}
+
+}  // namespace
+
+TfLiteRegistration Register_FULLY_CONNECTED() {
+  return {/*init=*/Init,
+          /*free=*/nullptr,
+          /*prepare=*/Prepare,
+          /*invoke=*/Eval,
+          /*profiling_string=*/nullptr,
+          /*builtin_code=*/0,
+          /*custom_name=*/nullptr,
+          /*version=*/0};
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa/quantize.cc b/tensorflow/lite/micro/kernels/xtensa/quantize.cc
new file mode 100644
index 00000000000..b867e70d98b
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/xtensa/quantize.cc
@@ -0,0 +1,161 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/quantize.h"
+
+#include <xtensa/tie/xt_hifi2.h>
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
+
+namespace tflite {
+namespace {
+
+struct OpData {
+  int32_t zero_point = 0;
+  int scale_multiplier = 0;
+};
+
+void AffineQuantize(int scale_multiplier, const int32_t zero_point,
+                    const RuntimeShape& input_shape,
+                    const int16_t* input_data,
+                    const RuntimeShape& output_shape, int8_t* output_data) {
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+  ae_q56s min_val_56 = AE_CVTQ48A32S(INT16_MIN);
+  ae_q56s max_val_56 = AE_CVTQ48A32S(INT16_MAX);
+  ae_q56s zero_point_56 = AE_CVTQ48A32S(zero_point);
+
+  const ae_p16x2s* input_data_ptr = (const ae_p16x2s*)(input_data - 2);
+
+  ae_p24x2s scale_multiplier_24x2 = AE_MOVPA24(scale_multiplier);
+
+  int iters = flat_size / 2;
+  for (int i = 0; i < iters; i++) {
+    // Load two 16bit pairs into the 2x24bit register PR:
+    // Values need to be right shifted 8 bits to align from upper 16bits to a
+    // 24bit value:
+    ae_p24x2s inputs_24x2;
+    AE_LP16X2F_IU(inputs_24x2, input_data_ptr, 4);
+    inputs_24x2 = AE_P24X2S_SRAI(inputs_24x2, 8);
+
+    // Q0.23 * Q16.0 == Q16.23
+    {
+      ae_q56s sum_56 = AE_MULP24S_HH(scale_multiplier_24x2, inputs_24x2);
+
+      // Q16.23 -> Q16.0
+      // Shift right only 7 bits (23 - 16). This truncated shift aligns the
+      // 16bit value at the truncation line for 32bit in the QR register. The
+      // lower 16 bits will be used for rounding in AE_ROUNDSQ32SYM.
+      sum_56 = AE_Q56S_SRAI(sum_56, 7);
+
+      // Round and truncate 32 bits:
+      sum_56 = AE_ROUNDSQ32SYM(sum_56);
+
+      // Add the offset (zero_point_56 is already aligned at 32 bits):
+      sum_56 = AE_ADDQ56(sum_56, zero_point_56);
+
+      // Saturate:
+      sum_56 = AE_MINQ56S(sum_56, max_val_56);
+      sum_56 = AE_MAXQ56S(sum_56, min_val_56);
+
+      output_data[i * 2] = static_cast<int8_t>(AE_TRUNCA32Q48(sum_56));
+    }
+    {
+      ae_q56s sum_56 = AE_MULP24S_LL(scale_multiplier_24x2, inputs_24x2);
+
+      // Q16.23 -> Q16.0
+      // Shift right only 7 bits (23 - 16). This truncated shift aligns the
+      // 16bit value at the truncation line for 32bit in the QR register. The
+      // lower 16 bits will be used for rounding in AE_ROUNDSQ32SYM.
+      sum_56 = AE_Q56S_SRAI(sum_56, 7);
+
+      // Round and truncate 32 bits:
+      sum_56 = AE_ROUNDSQ32SYM(sum_56);
+
+      // Add the offset (zero_point_56 is already aligned at 32 bits):
+      sum_56 = AE_ADDQ56(sum_56, zero_point_56);
+
+      // Saturate:
+      sum_56 = AE_MINQ56S(sum_56, max_val_56);
+      sum_56 = AE_MAXQ56S(sum_56, min_val_56);
+
+      output_data[i * 2 + 1] = static_cast<int8_t>(AE_TRUNCA32Q48(sum_56));
+    }
+  }
+}
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(OpData));
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  auto* op_data = static_cast<OpData*>(node->user_data);
+
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input = GetInput(context, node, 0);
+
+  // TODO(b/155682734): Fix dangerous input/output scale ratio assumptions.
+  op_data->scale_multiplier =
+      CreateQConstantForInt24(0, input->params.scale / output->params.scale);
+
+  op_data->zero_point = output->params.zero_point;
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  auto* op_data = static_cast<OpData*>(node->user_data);
+
+  const TfLiteEvalTensor* input =
+      tflite::micro::GetEvalInput(context, node, 0);
+  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
+
+  tflite::QuantizationParams op_params;
+  op_params.zero_point = op_data->zero_point;
+
+  if (input->type != kTfLiteInt16 || output->type != kTfLiteInt8) {
+    TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
+                       TfLiteTypeGetName(input->type),
+                       TfLiteTypeGetName(output->type));
+    return kTfLiteError;
+  }
+
+  AffineQuantize(op_data->scale_multiplier, op_data->zero_point,
+                 tflite::micro::GetTensorShape(input),
+                 tflite::micro::GetTensorData<int16_t>(input),
+                 tflite::micro::GetTensorShape(output),
+                 tflite::micro::GetTensorData<int8_t>(output));
+  return kTfLiteOk;
+}
+
+}  // namespace
+
+TfLiteRegistration Register_QUANTIZE() {
+  return {/*init=*/Init,
+          /*free=*/nullptr,
+          /*prepare=*/Prepare,
+          /*invoke=*/Eval,
+          /*profiling_string=*/nullptr,
+          /*builtin_code=*/0,
+          /*custom_name=*/nullptr,
+          /*version=*/0};
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa/softmax.cc b/tensorflow/lite/micro/kernels/xtensa/softmax.cc
new file mode 100644
index 00000000000..75eb2838034
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/xtensa/softmax.cc
@@ -0,0 +1,208 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+  op_data->scale_multiplier =
+      CreateQConstantForInt24(0, input->params.scale / output->params.scale);
+
+  op_data->zero_point = output->params.zero_point;
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  auto* op_data = static_cast<OpData*>(node->user_data);
+
+  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
+  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
+
+  // Only int16 -> int8 requantization is supported here, so reject any other
+  // input/output type combination.
+  if (input->type != kTfLiteInt16 || output->type != kTfLiteInt8) {
+    TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
+                       TfLiteTypeGetName(input->type),
+                       TfLiteTypeGetName(output->type));
+    return kTfLiteError;
+  }
+
+  AffineQuantize(op_data->scale_multiplier, op_data->zero_point,
+                 tflite::micro::GetTensorShape(input),
+                 tflite::micro::GetTensorData<int16_t>(input),
+                 tflite::micro::GetTensorShape(output),
+                 tflite::micro::GetTensorData<int8_t>(output));
+  return kTfLiteOk;
+}
+
+}  // namespace
+
+TfLiteRegistration Register_QUANTIZE() {
+  return {/*init=*/Init,
+          /*free=*/nullptr,
+          /*prepare=*/Prepare,
+          /*invoke=*/Eval,
+          /*profiling_string=*/nullptr,
+          /*builtin_code=*/0,
+          /*custom_name=*/nullptr,
+          /*version=*/0};
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa/softmax.cc b/tensorflow/lite/micro/kernels/xtensa/softmax.cc
new file mode 100644
index 00000000000..75eb2838034
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/xtensa/softmax.cc
@@ -0,0 +1,208 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/softmax.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <limits>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+
+namespace tflite {
+namespace {
+
+struct OpData {
+  uint16_t* exp_lut;
+};
+
+// Number of unique int8_t and int16_t values. Used in exponent lookup table
+// computation.
+constexpr int kInt8Range = std::numeric_limits<int8_t>::max() -
+                           std::numeric_limits<int8_t>::min() + 1;
+constexpr int kInt16Range = std::numeric_limits<int16_t>::max() -
+                            std::numeric_limits<int16_t>::min() + 1;
+// Each 16-bit precalculated exponent is expressed as a Q0.16 fixedpoint
+// value.
+constexpr int kExpFractionalBits = 16;
+// e^0 == 1.0 needs one integer bit, so it cannot be stored in the Q0.16
+// uint16_t table and is handled as a special case.
+constexpr int kMaxExponentValue = (1 << kExpFractionalBits);
+
+// Quantized softmax with int8_t input and int16_t output.
+// Passing OpData by value does not save much in this op, but we follow that
+// pattern as a best practice, at least for the xtensa kernels. See b/155656675
+// for more details.
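+// In scalar terms, each output row below is computed as (illustrative
+// sketch, with e(d) = exp(input_scale * beta * d) read from exp_lut):
+//   sum_of_exps = sum over c of e(input[c] - max_in_row)
+//   output[c]   = e(input[c] - max_in_row) * kInt16Range / sum_of_exps
+//                 + INT16_MIN, rounded to nearest.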
+TfLiteStatus Softmax(OpData op_data, const RuntimeShape& input_shape,
+                     const int8_t* input_data,
+                     const RuntimeShape& output_shape, int16_t* output_data) {
+  // The last dimension is depth. Outer size is the total input size
+  // divided by depth.
+  const int trailing_dim = input_shape.DimensionsCount() - 1;
+  const int outer_size =
+      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+  const int depth =
+      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+
+  for (int i = 0; i < outer_size; ++i) {
+    int8_t max_in_row = std::numeric_limits<int8_t>::min();
+    for (int c = 0; c < depth; ++c) {
+      max_in_row = std::max(max_in_row, input_data[i * depth + c]);
+    }
+
+    uint32_t sum_of_exps = 0;
+    for (int c = 0; c < depth; ++c) {
+      TFLITE_DCHECK(max_in_row >= input_data[i * depth + c]);
+      uint8_t input_diff = max_in_row - input_data[i * depth + c];
+
+      sum_of_exps +=
+          input_diff == 0 ? kMaxExponentValue : op_data.exp_lut[input_diff];
+    }
+
+    // Ensure we cannot overflow the full_range_output value. Each exponent is
+    // at most kMaxExponentValue, so with sum_of_exps >= kMaxExponentValue the
+    // quotient exp * kInt16Range / sum_of_exps never exceeds kInt16Range.
+    TFLITE_DCHECK(sum_of_exps >= kMaxExponentValue);
+
+    for (int c = 0; c < depth; ++c) {
+      uint8_t input_diff = max_in_row - input_data[i * depth + c];
+      // Special case for diff == 0:
+      uint32_t unscaled_output =
+          input_diff == 0 ? kMaxExponentValue : op_data.exp_lut[input_diff];
+      int64_t scaled_output = static_cast<int64_t>(unscaled_output) *
+                              static_cast<int64_t>(kInt16Range);
+      int32_t full_range_output =
+          scaled_output / sum_of_exps + std::numeric_limits<int16_t>::min();
+      // Round up if remainder exceeds half of the divider value.
+      uint32_t remainder = scaled_output % sum_of_exps;
+      if (remainder * 2 >= sum_of_exps) {
+        full_range_output++;
+      }
+      output_data[i * depth + c] = static_cast<int16_t>(std::max(
+          std::min(full_range_output,
+                   static_cast<int32_t>(std::numeric_limits<int16_t>::max())),
+          static_cast<int32_t>(std::numeric_limits<int16_t>::min())));
+    }
+  }
+  return kTfLiteOk;
+}
+
+TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
+                                    const TfLiteTensor* input,
+                                    TfLiteTensor* output,
+                                    const TfLiteSoftmaxParams* params,
+                                    OpData* op_data) {
+  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
+    if (input->type == kTfLiteUInt8) {
+      TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+    } else {
+      if (output->type == kTfLiteInt16) {
+        TF_LITE_ENSURE_EQ(context, output->params.zero_point,
+                          std::numeric_limits<int16_t>::min());
+        // NOTE: Current int16_t softmax output does not require symmetric
+        // scaling - so no need to verify scale here.
+      } else {
+        TF_LITE_ENSURE_EQ(context, output->params.zero_point,
+                          std::numeric_limits<int8_t>::min());
+        TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
+      }
+    }
+
+    // Precompute e^(-x * input_scale * beta) for every possible int8_t input.
+    // This computation is used for every iteration of Softmax. We must compute
+    // using pre-scaled inputs to avoid introducing additional error, while
+    // restricting our input range to the int8_t range. This is valid since
+    // beta and input scale are constant for a given op in the graph. Skip
+    // index 0 since that is a special case which requires 1 integer bit
+    // instead of 0. The largest possible diff is kInt8Range - 1, so stay
+    // strictly below kInt8Range to keep every write inside the table.
+    for (int i = 1; i < kInt8Range; i++) {
+      float scaled_input = i * input->params.scale;
+      float exp_value =
+          std::exp((-scaled_input) * static_cast<float>(params->beta));
+
+      float exponent_scaled =
+          std::round(exp_value * static_cast<float>(1 << kExpFractionalBits));
+      op_data->exp_lut[i] = static_cast<uint16_t>(exponent_scaled);
+    }
+  }
+  return kTfLiteOk;
+}
+
+void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(OpData));
+}
+
+TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  const TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
+
+  TFLITE_DCHECK(node->user_data != nullptr);
+  OpData* op_data = static_cast<OpData*>(node->user_data);
+
+  // Allocate an array to precompute exponents over all int8_t inputs, applying
+  // the scale and beta before calculating exp. It is mandatory to apply beta
+  // and scale here, since each softmax op may have different beta and scale
+  // values. Beta and scale will remain constant for a given softmax op.
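+  // Illustrative sketch of the resulting table (example numbers, not part of
+  // this kernel): with input scale 0.1 and beta 1.0,
+  //   exp_lut[i] = round(exp(-0.1 * i) * 2^16)   for i in [1, kInt8Range),
+  // so a lookup at input_diff == d approximates exp(-0.1 * d) in Q0.16.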
+  op_data->exp_lut = static_cast<uint16_t*>(context->AllocatePersistentBuffer(
+      context, kInt8Range * sizeof(uint16_t)));
+  TF_LITE_ENSURE(context, op_data->exp_lut != nullptr);
+
+  TF_LITE_ENSURE_STATUS(
+      CalculateSoftmaxOpData(context, input, output, params, op_data));
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
+  auto* op_data = static_cast<OpData*>(node->user_data);
+
+  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
+  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
+
+  if (input->type == kTfLiteInt8 && output->type == kTfLiteInt16) {
+    return Softmax(*op_data, tflite::micro::GetTensorShape(input),
+                   tflite::micro::GetTensorData<int8_t>(input),
+                   tflite::micro::GetTensorShape(output),
+                   tflite::micro::GetTensorData<int16_t>(output));
+  } else {
+    TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
+                       TfLiteTypeGetName(input->type), input->type);
+    return kTfLiteError;
+  }
+}
+
+}  // namespace
+
+TfLiteRegistration Register_SOFTMAX() {
+  return {/*init=*/SoftmaxInit,
+          /*free=*/nullptr,
+          /*prepare=*/SoftmaxPrepare,
+          /*invoke=*/SoftmaxEval,
+          /*profiling_string=*/nullptr,
+          /*builtin_code=*/0,
+          /*custom_name=*/nullptr,
+          /*version=*/0};
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa/svdf.cc b/tensorflow/lite/micro/kernels/xtensa/svdf.cc
new file mode 100644
index 00000000000..28f8f1e1af0
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/xtensa/svdf.cc
@@ -0,0 +1,420 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <xtensa/tie/xt_hifi2.h>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/kernels/activation_utils.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
+
+namespace tflite {
+namespace {
+
+struct OpData {
+  int32_t effective_scale_1_a;
+  int32_t effective_scale_2_a;
+  // The b versions of each scale are kept as int since they are just shift
+  // values, typically in the range [-32, 32].
+  int effective_scale_1_b;
+  int effective_scale_2_b;
+  int scratch_tensor_index;
+  int scratch_output_tensor_index;
+
+  // Cached tensor zero point values for quantized operations.
+  int input_zero_point;
+  int output_zero_point;
+};
+
+// Input tensors.
+constexpr int kInputTensor = 0;
+constexpr int kWeightsFeatureTensor = 1;
+constexpr int kWeightsTimeTensor = 2;
+constexpr int kBiasTensor = 3;
+// This is a variable tensor, and will be modified by this op.
+constexpr int kInputActivationStateTensor = 4;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
+
+/**
+ * This version of SVDF is specific to TFLite Micro. It contains only a full
+ * integer recipe with optimizations for the Xtensa HiFiMini platform.
+ *
+ * Note: passing OpData by value might seem like an oversight but it helps
+ * reduce the latency. See b/155656675 for more details.
+ */
+void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
+                     const TfLiteEvalTensor* input_tensor,
+                     const TfLiteEvalTensor* weights_feature_tensor,
+                     const TfLiteEvalTensor* weights_time_tensor,
+                     const TfLiteEvalTensor* bias_tensor,
+                     const TfLiteSVDFParams* params,
+                     TfLiteEvalTensor* activation_state_tensor,
+                     TfLiteEvalTensor* output_tensor, OpData data) {
+  const int n_rank = params->rank;
+  const int n_batch = input_tensor->dims->data[0];
+  const int n_input = input_tensor->dims->data[1];
+  const int n_filter = weights_feature_tensor->dims->data[0];
+  const int n_unit = n_filter / n_rank;
+  const int n_memory = weights_time_tensor->dims->data[1];
+
+  TFLITE_DCHECK(context != nullptr);
+  TFLITE_DCHECK(context->GetScratchBuffer != nullptr);
+
+  int32_t* scratch_tensor = static_cast<int32_t*>(
+      context->GetScratchBuffer(context, data.scratch_tensor_index));
+  TFLITE_DCHECK(scratch_tensor != nullptr);
+  int32_t* scratch_output_tensor = static_cast<int32_t*>(
+      context->GetScratchBuffer(context, data.scratch_output_tensor_index));
+  TFLITE_DCHECK(scratch_output_tensor != nullptr);
+
+  // Shift states.
+  int16_t* const state_ptr =
+      tflite::micro::GetTensorData<int16_t>(activation_state_tensor);
+
+  // Left shift the activation_state.
+  {
+    int16_t* new_state_start = state_ptr;
+    const int16_t* old_state_start = state_ptr + 1;
+    const int16_t* old_state_end = state_ptr + n_batch * n_filter * n_memory;
+    while (old_state_start != old_state_end) {
+      *new_state_start++ = *old_state_start++;
+    }
+  }
+
+  // Note: no need to clear the latest activation, matmul is not accumulative.
+
+  // Feature matmul.
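+  // Scalar equivalent of the vectorized feature matmul below (illustrative
+  // sketch; Rescale stands in for MultiplyByQuantizedMultiplier plus the
+  // saturation intrinsics, and state refers to the activation_state buffer):
+  //   for (int b = 0; b < n_batch; ++b)
+  //     for (int r = 0; r < n_filter; ++r) {
+  //       int32_t dot = 0;
+  //       for (int i = 0; i < n_input; ++i)
+  //         dot += weight_feature[r * n_input + i] *
+  //                (input[b * n_input + i] - input_zero_point);
+  //       state[(b * n_filter + r) * n_memory + n_memory - 1] =
+  //           Clamp(Rescale(dot), INT16_MIN, INT16_MAX);
+  //     }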
+  {
+    const int8_t* input = tflite::micro::GetTensorData<int8_t>(input_tensor);
+    const int8_t* weight_feature =
+        tflite::micro::GetTensorData<int8_t>(weights_feature_tensor);
+    int16_t* result_in_batch = state_ptr + (n_memory - 1);
+
+    ae_q56s output_int16_max_56 = AE_CVTQ48A32S(INT16_MAX);
+    ae_q56s output_int16_min_56 = AE_CVTQ48A32S(INT16_MIN);
+    ae_p24x2s input_zp_24x2 = AE_MOVPA24(data.input_zero_point);
+
+    for (int b = 0; b < n_batch; b++) {
+      const int8_t* weight_feature_ptr = weight_feature - 2;
+
+      for (int r = 0; r < n_filter; r++) {
+        ae_q56s dot_prod_56 = AE_ZEROQ56();
+
+        const int8_t* input_batch_ptr = input + b * n_input;
+        const int8_t* offset_input_batch_ptr = input_batch_ptr - 2;
+
+        int num_iters = n_input / 2;
+        for (int c = 0; c < num_iters; c++) {
+          // Load 2 sets of values:
+          ae_p24x2s weight_feature_ptr_24x2;
+          ae_p24x2s input_batch_ptr_24x2;
+          AE_LP8X2F_IU(weight_feature_ptr_24x2, weight_feature_ptr, 2);
+          AE_LP8X2F_IU(input_batch_ptr_24x2, offset_input_batch_ptr, 2);
+
+          // Right shift the signed 8bit values to expand to signed 24bit
+          // values:
+          weight_feature_ptr_24x2 =
+              AE_P24X2S_SRAI(weight_feature_ptr_24x2, 16);
+          input_batch_ptr_24x2 = AE_P24X2S_SRAI(input_batch_ptr_24x2, 16);
+
+          // First subtract input_zp from input_batch_ptr_24x2:
+          input_batch_ptr_24x2 =
+              AE_SUBSP24S(input_batch_ptr_24x2, input_zp_24x2);
+
+          // Multiply accum:
+          AE_MULAAP24S_HH_LL(dot_prod_56, weight_feature_ptr_24x2,
+                             input_batch_ptr_24x2);
+        }
+
+        // Left shift 48bit value into 24bit space and place on the PR
+        // register:
+        dot_prod_56 = AE_Q56S_SLAI(dot_prod_56, 24);
+        ae_p24x2s dot_prod_24x2 = AE_TRUNCP24Q48(dot_prod_56);
+
+        dot_prod_56 = MultiplyByQuantizedMultiplier(
+            dot_prod_24x2, data.effective_scale_1_a, data.effective_scale_1_b);
+
+        // Cap min/max and convert to int32_t:
+        dot_prod_56 = AE_MAXQ56S(dot_prod_56, output_int16_min_56);
+        dot_prod_56 = AE_MINQ56S(dot_prod_56, output_int16_max_56);
+        // Truncate immediately since the QR register is already 32 bit
+        // aligned. This assumes state is symmetrically quantized. Otherwise
+        // the last entry of state should be initialized to its zero point and
+        // accumulate the dot_prod. Equivalent to the following:
+        //   *result_in_batch = zero point, which happens to be zero.
+        //   *result_in_batch += dot_prod_56.
+        *result_in_batch = AE_TRUNCA32Q48(dot_prod_56);
+        result_in_batch += n_memory;
+      }
+    }
+  }
+
+  // Time.
+  {
+    for (int b = 0; b < n_batch; ++b) {
+      int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
+
+      // Perform batched vector dot product:
+      const int16_t* vector1_ptr =
+          tflite::micro::GetTensorData<int16_t>(weights_time_tensor);
+      const int16_t* vector2_ptr = state_ptr + b * n_memory * n_filter;
+
+      const ae_p16x2s* offset_vector1 =
+          reinterpret_cast<const ae_p16x2s*>(vector1_ptr - 2);
+      const ae_p16x2s* offset_vector2 =
+          reinterpret_cast<const ae_p16x2s*>(vector2_ptr - 2);
+
+      for (int i = 0; i < n_filter; i++) {
+        ae_q56s sum_56 = AE_ZEROQ56();
+        int num_iters = n_memory / 2;
+        for (int j = 0; j < num_iters; j++) {
+          ae_p24x2s vector1_24x2;
+          ae_p24x2s vector2_24x2;
+          AE_LP16X2F_IU(vector1_24x2, offset_vector1, 4);
+          AE_LP16X2F_IU(vector2_24x2, offset_vector2, 4);
+          AE_MULAAP24S_HH_LL(sum_56, vector1_24x2, vector2_24x2);
+        }
+        // Truncate directly since values are already 32bit aligned:
+        *scratch_ptr_batch = AE_TRUNCA32Q48(sum_56);
+        scratch_ptr_batch++;
+      }
+    }
+  }
+
+  // Reduce, add bias, rescale, activation.
+  {
+    // Add bias.
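+    // The scratch output accumulators are seeded with the bias when present
+    // (zeros otherwise) and then receive the rank-reduction sums below; in
+    // scalar form (illustrative):
+    //   out[b][u] = bias[u] + sum over r < n_rank of
+    //               scratch[b][u * n_rank + r]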
+    if (bias_tensor) {
+      // Vector batch assign:
+      const int32_t* bias_data =
+          tflite::micro::GetTensorData<int32_t>(bias_tensor);
+      for (int i = 0; i < n_batch; ++i) {
+        int32_t* output_ptr = scratch_output_tensor + i * n_unit;
+        const int32_t* bias_ptr = bias_data;
+        for (int j = 0; j < n_unit; ++j) {
+          *output_ptr++ = *bias_ptr++;
+        }
+      }
+    } else {
+      int32_t* output_ptr = scratch_output_tensor;
+      for (int i = 0; i < n_batch * n_unit; ++i) {
+        *output_ptr++ = 0;
+      }
+    }
+
+    // Reduce.
+    for (int b = 0; b < n_batch; ++b) {
+      int32_t* output_temp_ptr = scratch_output_tensor + b * n_unit;
+      int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;
+
+      // Reduction sum vector:
+      for (int i = 0; i < n_unit; ++i) {
+        for (int j = 0; j < n_rank; ++j) {
+          output_temp_ptr[i] += *scratch_ptr_batch++;
+        }
+      }
+    }
+
+    // Rescale.
+    ae_q56s output_int8_max_56 = AE_CVTQ48A32S(INT8_MAX);
+    ae_q56s output_int8_min_56 = AE_CVTQ48A32S(INT8_MIN);
+    ae_q56s output_zp_56 = AE_CVTQ48A32S(data.output_zero_point);
+    for (int i = 0; i < n_batch * n_unit; ++i) {
+      ae_q56s x_56 = MultiplyByQuantizedMultiplierResult48Bit(
+          scratch_output_tensor[i], data.effective_scale_2_a,
+          data.effective_scale_2_b);
+      // Add output adjustment:
+      x_56 = AE_ADDQ56(x_56, output_zp_56);
+      // Cap min/max and convert to int32_t (already aligned to 32bit):
+      x_56 = AE_MAXQ56S(x_56, output_int8_min_56);
+      x_56 = AE_MINQ56S(x_56, output_int8_max_56);
+      tflite::micro::GetTensorData<int8_t>(output_tensor)[i] =
+          static_cast<int8_t>(AE_TRUNCA32Q48(x_56));
+    }
+  }
+}
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  TFLITE_DCHECK(context != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(OpData));
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->builtin_data != nullptr);
+  const auto* params =
+      static_cast<const TfLiteSVDFParams*>(node->builtin_data);
+
+  // Validate Tensor Inputs (dtype depends on quantization):
+  // [0] = Input, {2, batch_size, input_size}
+  // [1] = Weights Feature, {2, num_filters, input_size}
+  // [2] = Weights Time, {2, num_filters, memory_size}
+  // [3] = Bias (optional), {1, num_units}
+  // [4] = Activation State (variable),
+  //       {2, batch_size, memory_size * num_filters}
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* weights_feature =
+      GetInput(context, node, kWeightsFeatureTensor);
+  const TfLiteTensor* weights_time =
+      GetInput(context, node, kWeightsTimeTensor);
+  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+  const TfLiteTensor* activation_state =
+      GetInput(context, node, kInputActivationStateTensor);
+
+  // Define input constants based on input tensor definition above:
+  const int rank = params->rank;
+  const int input_size = input->dims->data[1];
+  const int batch_size = input->dims->data[0];
+  // Ensure the input size is a multiple of two. This is necessary since
+  // optimized kernels access the memory in chunks of two, and all accesses
+  // must be aligned to 16 bits.
+  // TODO(b/153202598): Remove when padding is allowed in TFLite tensors.
+  TF_LITE_ENSURE_EQ(context, input_size % 2, 0);
+
+  const int num_filters = weights_feature->dims->data[0];
+  TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
+  const int num_units = num_filters / rank;
+  const int memory_size = weights_time->dims->data[1];
+
+  if (input->type != kTfLiteInt8) {
+    TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
+                       TfLiteTypeGetName(input->type), input->type);
+    return kTfLiteError;
+  }
+
+  // Validate Input Tensor:
+  TF_LITE_ENSURE(context, input->type == kTfLiteInt8);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
+
+  // Validate Tensor Output:
+  // [0] = float/int8_t, {2, batch_size, num_units}
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
+  TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
+  TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);
+
+  // Validate Weights Feature Input Tensor:
+  TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2);
+  TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size);
+
+  // Validate Weights Time Input Tensor:
+  TF_LITE_ENSURE_EQ(context, NumDimensions(weights_time), 2);
+  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters);
+  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size);
+
+  // Validate Optional Bias Input Tensor:
+  if (bias != nullptr) {
+    TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
+    TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
+  }
+
+  // Validate Activation State Input Tensor:
+  TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
+  TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size);
+  TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],
+                    memory_size * num_filters);
+
+  TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
+  TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
+  TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16);
+  TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
+
+  // Validate output tensor:
+  TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
+
+  const double effective_scale_1 =
+      static_cast<double>(input->params.scale * weights_feature->params.scale /
+                          activation_state->params.scale);
+  const double effective_scale_2 =
+      static_cast<double>(activation_state->params.scale *
+                          weights_time->params.scale / output->params.scale);
+
+  // The bias is optional, so only check its scale when it is present.
+  if (bias != nullptr) {
+    TF_LITE_ENSURE_EQ(context, static_cast<double>(bias->params.scale),
+                      static_cast<double>(activation_state->params.scale *
+                                          weights_time->params.scale));
+  }
+
+  TFLITE_DCHECK(node->user_data != nullptr);
+  OpData* data = static_cast<OpData*>(node->user_data);
+
+  QuantizeMultiplierForInt24(effective_scale_1, &data->effective_scale_1_a,
+                             &data->effective_scale_1_b);
+  QuantizeMultiplierForInt24(effective_scale_2, &data->effective_scale_2_a,
+                             &data->effective_scale_2_b);
+
+  data->input_zero_point = input->params.zero_point;
+  data->output_zero_point = output->params.zero_point;
+
+  const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
+      context, batch_size * num_filters * sizeof(int32_t),
+      &(data->scratch_tensor_index));
+  TF_LITE_ENSURE_OK(context, scratch_status);
+  const TfLiteStatus scratch_output_status =
+      context->RequestScratchBufferInArena(
+          context, batch_size * num_units * sizeof(int32_t),
+          &(data->scratch_output_tensor_index));
+  TF_LITE_ENSURE_OK(context, scratch_output_status);
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = static_cast<TfLiteSVDFParams*>(node->builtin_data);
+
+  const TfLiteEvalTensor* input =
+      tflite::micro::GetEvalInput(context, node, kInputTensor);
+  const TfLiteEvalTensor* weights_feature =
+      tflite::micro::GetEvalInput(context, node, kWeightsFeatureTensor);
+  const TfLiteEvalTensor* weights_time =
+      tflite::micro::GetEvalInput(context, node, kWeightsTimeTensor);
+  const TfLiteEvalTensor* bias =
+      (NumInputs(node) == 5)
+          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
+          : nullptr;
+  TfLiteEvalTensor* activation_state = tflite::micro::GetMutableEvalInput(
+      context, node, kInputActivationStateTensor);
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+
+  TFLITE_DCHECK(node->user_data != nullptr);
+  const OpData& data = *(static_cast<const OpData*>(node->user_data));
+
+  EvalIntegerSVDF(context, node, input, weights_feature, weights_time, bias,
+                  params, activation_state, output, data);
+  return kTfLiteOk;
+}
+
+}  // namespace
+
+TfLiteRegistration Register_SVDF() {
+  return {/*init=*/Init,
+          /*free=*/nullptr,
+          /*prepare=*/Prepare,
+          /*invoke=*/Eval,
+          /*profiling_string=*/nullptr,
+          /*builtin_code=*/0,
+          /*custom_name=*/nullptr,
+          /*version=*/0};
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/testing/test_xtensa_binary.sh b/tensorflow/lite/micro/testing/test_xtensa_binary.sh
new file mode 100755
index 00000000000..fb9ca9cd48d
--- /dev/null
+++ b/tensorflow/lite/micro/testing/test_xtensa_binary.sh
@@ -0,0 +1,39 @@
+#!/bin/bash -e
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Tests an Xtensa binary by parsing the log output.
+#
+# First argument is the binary location.
+#
+# Second argument is a regular expression that's required to be in the output
+# logs for the test to pass.
+
+declare -r ROOT_DIR=`pwd`
+declare -r TEST_TMPDIR=/tmp/test_xtensa_binary/
+declare -r MICRO_LOG_PATH=${TEST_TMPDIR}/$1
+declare -r MICRO_LOG_FILENAME=${MICRO_LOG_PATH}/logs.txt
+mkdir -p ${MICRO_LOG_PATH}
+
+xt-run $1 2>&1 | tee ${MICRO_LOG_FILENAME}
+
+if grep -q "$2" ${MICRO_LOG_FILENAME}
+then
+  echo "$1: PASS"
+  exit 0
+else
+  echo "$1: FAIL - '$2' not found in logs."
+  exit 1
+fi
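+
+# Example invocation (hypothetical binary path and pattern, shown only to
+# illustrate the two arguments; substitute the test binary and the pass
+# string your build actually emits):
+#   tensorflow/lite/micro/testing/test_xtensa_binary.sh \
+#     tensorflow/lite/micro/tools/make/gen/xtensa_hifimini/bin/foo_test \
+#     '~~~ALL TESTS PASSED~~~'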
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/xtensa.inc b/tensorflow/lite/micro/tools/make/ext_libs/xtensa.inc
new file mode 100644
index 00000000000..47a41d71ad4
--- /dev/null
+++ b/tensorflow/lite/micro/tools/make/ext_libs/xtensa.inc
@@ -0,0 +1,3 @@
+# Every optimized kernel implementation directory (i.e.
+# micro/kernels/<optimized_kernel_dir>/) must have a corresponding
+# micro/tools/make/ext_libs/<optimized_kernel_dir>.inc
diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc
new file mode 100644
index 00000000000..41d2dd47b3b
--- /dev/null
+++ b/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc
@@ -0,0 +1,58 @@
+# Settings for the Xtensa toolchain for the hifimini kernels.
+# REQUIRED:
+#   Environment variables:
+#     - XTENSA_BASE must be set to the location of the Xtensa developer tools
+#       installation directory.
+#   Command line arguments:
+#     - XTENSA_TOOLS_VERSION: For example: RI-2019.2-linux
+#     - XTENSA_CORE: The name of the Xtensa core to use.
+#       For example: hifimini
+
+TARGET_ARCH :=
+
+ifndef XTENSA_BASE
+  $(error XTENSA_BASE is undefined)
+endif
+
+ifndef XTENSA_TOOLS_VERSION
+  $(error XTENSA_TOOLS_VERSION is undefined)
+endif
+
+ifndef XTENSA_CORE
+  $(error XTENSA_CORE is undefined)
+endif
+
+PLATFORM_FLAGS = \
+  -DTF_LITE_MCU_DEBUG_LOG \
+  -DTF_LITE_USE_CTIME \
+  --xtensa-core=$(XTENSA_CORE) \
+  -mcoproc \
+  -DXTENSA \
+  -DMAX_RFFT_PWR=9 \
+  -DMIN_RFFT_PWR=MAX_RFFT_PWR
+
+export PATH := $(XTENSA_BASE)/tools/$(XTENSA_TOOLS_VERSION)/XtensaTools/bin:$(PATH)
+TARGET_TOOLCHAIN_PREFIX := xt-
+CXX_TOOL := clang++
+CC_TOOL := clang
+
+CXXFLAGS += $(PLATFORM_FLAGS)
+CCFLAGS += $(PLATFORM_FLAGS)
+
+# TODO(b/150240249): Stop filtering out -fno-rtti once it works with the
+# Xtensa toolchain.
+CXXFLAGS := $(filter-out -fno-rtti, $(CXXFLAGS))
+
+TEST_SCRIPT := tensorflow/lite/micro/testing/test_xtensa_binary.sh
+
+# TODO(b/156962140): This manually maintained list of excluded examples is
+# quite error prone.
+EXCLUDED_EXAMPLE_TESTS := \
+  tensorflow/lite/micro/examples/image_recognition_experimental/Makefile.inc \
+  tensorflow/lite/micro/examples/magic_wand/Makefile.inc \
+  tensorflow/lite/micro/examples/micro_speech/Makefile.inc \
+  tensorflow/lite/micro/examples/network_tester/Makefile.inc \
+  tensorflow/lite/micro/examples/person_detection/Makefile.inc \
+  tensorflow/lite/micro/examples/person_detection_experimental/Makefile.inc
+MICRO_LITE_EXAMPLE_TESTS := $(filter-out $(EXCLUDED_EXAMPLE_TESTS), $(MICRO_LITE_EXAMPLE_TESTS))
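+
+# Example invocation (hypothetical values; TARGET=xtensa selects this file via
+# the <target>_makefile.inc convention, and the remaining variables follow the
+# REQUIRED list above):
+#   make -f tensorflow/lite/micro/tools/make/Makefile TARGET=xtensa \
+#     XTENSA_BASE=/opt/xtensa XTENSA_TOOLS_VERSION=RI-2019.2-linux \
+#     XTENSA_CORE=hifimini test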