diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc
index c2c2c86fe81..c8da67b5af8 100644
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc
@@ -33,13 +33,12 @@ namespace micro {
 namespace xtensa {
 namespace hifimini {
 
-// Int8 optimized:
-inline void FullyConnected(
-    const FullyConnectedParams& params, const RuntimeShape& input_shape,
-    const int8_t* input_data, const RuntimeShape& filter_shape,
-    const int8_t* filter_data, const RuntimeShape& bias_shape,
-    const int32* bias_data, const RuntimeShape& output_shape,
-    int8_t* output_data) {
+void FullyConnected(const FullyConnectedParams& params,
+                    const RuntimeShape& input_shape, const int8_t* input_data,
+                    const RuntimeShape& filter_shape, const int8_t* filter_data,
+                    const RuntimeShape& bias_shape, const int32* bias_data,
+                    const RuntimeShape& output_shape, int8_t* output_data) {
+  // TODO(b/154032858): Investigate removing extra copies.
   const int32 input_offset = params.input_offset;
   const int32 filter_offset = params.weights_offset;
   const int32 output_offset = params.output_offset;
@@ -142,72 +141,69 @@ constexpr int kWeightsTensor = 1;
 constexpr int kBiasTensor = 2;
 constexpr int kOutputTensor = 0;
 
-// This size will work for both the hotword (5) and ambient music (2):
-constexpr int kMaxOpDataSize = 7;
-static int op_data_counter = 0;
-static OpData kStaticOpData[kMaxOpDataSize];
-
 TfLiteStatus CalculateOpData(TfLiteContext* context,
-                             TfLiteFullyConnectedParams* params,
+                             TfLiteFusedActivation activation,
                              TfLiteType data_type, const TfLiteTensor* input,
                              const TfLiteTensor* filter,
                              const TfLiteTensor* bias, TfLiteTensor* output,
                              OpData* data) {
-  TfLiteStatus status = kTfLiteOk;
-  if (data_type != kTfLiteFloat32) {
-    double real_multiplier = 0.0;
-    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
-        context, input, filter, bias, output, &real_multiplier));
-    int exponent;
-    xtensa::hifimini::QuantizeMultiplier(real_multiplier,
-                                         &data->output_multiplier, &exponent);
-    data->output_shift = -exponent;
-    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
-        context, params->activation, output, &data->output_activation_min,
-        &data->output_activation_max));
-  }
-  return status;
+  TFLITE_DCHECK(data_type != kTfLiteFloat32);
+
+  double real_multiplier = 0.0;
+  TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
+      context, input, filter, bias, output, &real_multiplier));
+  int exponent;
+  xtensa::hifimini::QuantizeMultiplier(real_multiplier,
+                                       &data->output_multiplier, &exponent);
+  data->output_shift = -exponent;
+  return CalculateActivationRangeQuantized(context, activation, output,
+                                           &data->output_activation_min,
+                                           &data->output_activation_max);
 }
 
 }  // namespace
 
-void Free(TfLiteContext* context, void* buffer) { op_data_counter = 0; }
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  void* data = nullptr;
+  if (context->AllocatePersistentBuffer(context, sizeof(OpData), &data) ==
+      kTfLiteError) {
+    return nullptr;
+  }
+  return data;
+}
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  auto* params =
+  TFLITE_DCHECK(node->user_data != nullptr);
+  TFLITE_DCHECK(node->builtin_data != nullptr);
+
+  OpData* data = static_cast<OpData*>(node->user_data);
+  const auto* params =
       reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
 
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
   const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
   const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
 
-  TfLiteType data_type = input->type;
-  // TODO(b/132070898): Use statically slotted OpData structures until a
-  // scratch memory API is ready.
-  OpData* op_data = &kStaticOpData[op_data_counter++];
-  node->user_data = op_data;
-
-  TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input,
-                                        filter, bias, output, op_data));
-
-  return kTfLiteOk;
+  return CalculateOpData(context, params->activation, input->type, input,
+                         filter, bias, output, data);
 }
 
 TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
-                               TfLiteFullyConnectedParams* params, OpData* data,
-                               const TfLiteTensor* input,
+                               const OpData& data, const TfLiteTensor* input,
                                const TfLiteTensor* filter,
                                const TfLiteTensor* bias, TfLiteTensor* output) {
+  // TODO(b/154032858): Investigate removing extra copies.
   FullyConnectedParams op_params;
   op_params.input_offset = -input->params.zero_point;
   op_params.weights_offset = -filter->params.zero_point;
   op_params.output_offset = output->params.zero_point;
-  op_params.output_multiplier = data->output_multiplier;
+  op_params.output_multiplier = data.output_multiplier;
   // TODO(b/138810107): Figure out whether output shift should be inverted
-  op_params.output_shift = -data->output_shift;
-  op_params.quantized_activation_min = data->output_activation_min;
-  op_params.quantized_activation_max = data->output_activation_max;
+  op_params.output_shift = -data.output_shift;
+  op_params.quantized_activation_min = data.output_activation_min;
+  op_params.quantized_activation_max = data.output_activation_max;
 
   xtensa::hifimini::FullyConnected(
       op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
@@ -218,33 +214,23 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
 }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  auto* params =
-      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
-  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+  TFLITE_DCHECK(node->user_data != nullptr);
+  const OpData& data = *(static_cast<const OpData*>(node->user_data));
 
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
   const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
   const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
 
-  switch (filter->type) {  // Already know in/out types are same.
-    case kTfLiteInt8:
-      return EvalQuantizedInt8(context, node, params, op_data, input, filter,
-                               bias, output);
-
-    default:
-      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
-                         TfLiteTypeGetName(filter->type), filter->type);
-      return kTfLiteError;
-  }
-  return kTfLiteOk;
+  TFLITE_DCHECK(filter->type == kTfLiteInt8);
+  return EvalQuantizedInt8(context, node, data, input, filter, bias, output);
 }
 
 }  // namespace fully_connected
 
 TfLiteRegistration* Register_FULLY_CONNECTED() {
-  static TfLiteRegistration r = {/*init=*/nullptr,
-                                 /*free=*/fully_connected::Free,
+  static TfLiteRegistration r = {/*init=*/fully_connected::Init,
+                                 /*free=*/nullptr,
                                  /*prepare=*/fully_connected::Prepare,
                                  /*invoke=*/fully_connected::Eval,
                                  /*profiling_string=*/nullptr,
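
Reviewer note: the substance of this change is swapping the fixed kStaticOpData pool (capped at 7 fully-connected ops per model and reset via the old Free hook) for per-node state allocated from the interpreter's persistent arena. Below is a minimal sketch of the resulting lifecycle, assuming the hooks wired up in Register_FULLY_CONNECTED() above; the wrapper function and num_inferences are illustrative, not part of this change:

```cpp
// Sketch: the order in which the TFLM interpreter drives the three hooks.
TfLiteStatus RunFullyConnectedNode(TfLiteContext* context, TfLiteNode* node,
                                   int num_inferences) {
  TfLiteRegistration* registration =
      tflite::ops::micro::Register_FULLY_CONNECTED();

  // Init runs once per node: it allocates OpData from the persistent arena,
  // and the interpreter stores the returned pointer in node->user_data.
  node->user_data =
      registration->init(context, /*buffer=*/nullptr, /*length=*/0);
  TF_LITE_ENSURE(context, node->user_data != nullptr);

  // Prepare runs once before inference: CalculateOpData fills in the
  // quantized multiplier, shift, and activation range.
  TF_LITE_ENSURE_STATUS(registration->prepare(context, node));

  // Eval only reads OpData, so the hot path performs no allocation and there
  // is no global op_data_counter left to reset (hence /*free=*/nullptr).
  for (int i = 0; i < num_inferences; ++i) {
    TF_LITE_ENSURE_STATUS(registration->invoke(context, node));
  }
  return kTfLiteOk;
}
```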
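A second note on the quantization math Prepare now runs: CalculateOpData reduces the real-valued rescale factor (input_scale * filter_scale / output_scale) to an integer multiplier plus a power-of-two exponent, then stores output_shift = -exponent. For intuition, here is a minimal self-contained sketch of the standard TFLite reference decomposition; the hifimini QuantizeMultiplier called above targets the DSP's datapath and differs in width and rounding details:

```cpp
#include <cmath>
#include <cstdint>

// Reference-style decomposition: real_multiplier ~= q_fixed * 2^(exponent - 31),
// where q_fixed is a Q0.31 value in [2^30, 2^31). Sketch only.
void QuantizeMultiplierSketch(double real_multiplier,
                              int32_t* quantized_multiplier, int* exponent) {
  if (real_multiplier == 0.0) {
    *quantized_multiplier = 0;
    *exponent = 0;
    return;
  }
  // frexp: real_multiplier = q * 2^exponent, with q in [0.5, 1).
  const double q = std::frexp(real_multiplier, exponent);
  int64_t q_fixed = static_cast<int64_t>(std::round(q * (1LL << 31)));
  if (q_fixed == (1LL << 31)) {
    // Rounding pushed q up to exactly 1.0; renormalize.
    q_fixed /= 2;
    ++*exponent;
  }
  *quantized_multiplier = static_cast<int32_t>(q_fixed);
}
```

Given that decomposition, the two negations cancel: CalculateOpData stores -exponent, and EvalQuantizedInt8 passes -data.output_shift, so the kernel ultimately receives the original exponent back. That double inversion is exactly what the b/138810107 TODO flags.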