Use persistent buffer in xtensa_hifimini/fully_connected.
PiperOrigin-RevId: 309264335
Change-Id: If15693b8cd590a947497c429d476c12109885cd3
parent 7148ced951
commit f81dadf228
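
In short: the kernel previously parked per-node state in a file-static array (static OpData kStaticOpData[7], with Free() resetting a slot counter), which caps a model at seven fully-connected ops and leaks state across interpreter instances. It now requests that storage from the interpreter's memory arena in Init() via TfLiteContext::AllocatePersistentBuffer, so each node owns an OpData for the interpreter's lifetime. For orientation, here is what OpData presumably holds, inferred from the accesses in the hunks below (the struct definition itself is outside this diff):

// Illustrative only; inferred from usages in CalculateOpData and
// EvalQuantizedInt8 below. The real definition is not part of this diff.
struct OpData {
  int32_t output_multiplier;      // fixed-point rescale factor
  int output_shift;               // shift paired with the multiplier
  int32_t output_activation_min;  // fused-activation clamp bounds
  int32_t output_activation_max;
};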
@@ -33,13 +33,12 @@ namespace micro {
 namespace xtensa {
 namespace hifimini {
 
-// Int8 optimized:
-inline void FullyConnected(
-    const FullyConnectedParams& params, const RuntimeShape& input_shape,
-    const int8_t* input_data, const RuntimeShape& filter_shape,
-    const int8_t* filter_data, const RuntimeShape& bias_shape,
-    const int32* bias_data, const RuntimeShape& output_shape,
-    int8_t* output_data) {
+void FullyConnected(const FullyConnectedParams& params,
+                    const RuntimeShape& input_shape, const int8_t* input_data,
+                    const RuntimeShape& filter_shape, const int8_t* filter_data,
+                    const RuntimeShape& bias_shape, const int32* bias_data,
+                    const RuntimeShape& output_shape, int8_t* output_data) {
+  // TODO(b/154032858): Investigate removing extra copies.
   const int32 input_offset = params.input_offset;
   const int32 filter_offset = params.weights_offset;
   const int32 output_offset = params.output_offset;
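
To make the offset/multiplier plumbing above concrete, here is a hypothetical scalar reference for a single output element (not from this commit; the real HiFi Mini kernel vectorizes this with Xtensa intrinsics). It mirrors TFLite's standard int8 fully-connected arithmetic; tflite::MultiplyByQuantizedMultiplier is the library's rounding fixed-point rescale, and output_shift follows the op_params.output_shift convention used further down in this diff:

#include <algorithm>
#include <cstdint>

#include "tensorflow/lite/kernels/internal/common.h"

// Sketch: computes one int8 output value of a fully-connected layer.
inline int8_t FullyConnectedOneOutput(
    const int8_t* input, const int8_t* filter_row, int32_t bias,
    int accum_depth, int32_t input_offset, int32_t filter_offset,
    int32_t output_offset, int32_t output_multiplier, int output_shift,
    int32_t activation_min, int32_t activation_max) {
  int32_t acc = bias;
  for (int d = 0; d < accum_depth; ++d) {
    // The offsets undo the int8 zero-points before the integer dot product.
    acc += (input[d] + input_offset) * (filter_row[d] + filter_offset);
  }
  // Rescale from the accumulator's scale to the output's scale.
  acc = tflite::MultiplyByQuantizedMultiplier(acc, output_multiplier,
                                              output_shift);
  acc += output_offset;
  // Clamp to the fused-activation range, then narrow to int8.
  acc = std::min(std::max(acc, activation_min), activation_max);
  return static_cast<int8_t>(acc);
}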
@@ -142,72 +141,69 @@ constexpr int kWeightsTensor = 1;
 constexpr int kBiasTensor = 2;
 constexpr int kOutputTensor = 0;
 
-// This size will work for both the hotword (5) and ambient music (2):
-constexpr int kMaxOpDataSize = 7;
-static int op_data_counter = 0;
-static OpData kStaticOpData[kMaxOpDataSize];
-
 TfLiteStatus CalculateOpData(TfLiteContext* context,
-                             TfLiteFullyConnectedParams* params,
+                             TfLiteFusedActivation activation,
                              TfLiteType data_type, const TfLiteTensor* input,
                              const TfLiteTensor* filter,
                              const TfLiteTensor* bias, TfLiteTensor* output,
                              OpData* data) {
-  TfLiteStatus status = kTfLiteOk;
-  if (data_type != kTfLiteFloat32) {
-    double real_multiplier = 0.0;
-    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
-        context, input, filter, bias, output, &real_multiplier));
-    int exponent;
-    xtensa::hifimini::QuantizeMultiplier(real_multiplier,
-                                         &data->output_multiplier, &exponent);
-    data->output_shift = -exponent;
-    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
-        context, params->activation, output, &data->output_activation_min,
-        &data->output_activation_max));
-  }
-  return status;
+  TFLITE_DCHECK(data_type != kTfLiteFloat32);
+
+  double real_multiplier = 0.0;
+  TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
+      context, input, filter, bias, output, &real_multiplier));
+  int exponent;
+  xtensa::hifimini::QuantizeMultiplier(real_multiplier,
+                                       &data->output_multiplier, &exponent);
+  data->output_shift = -exponent;
+  return CalculateActivationRangeQuantized(context, activation, output,
+                                           &data->output_activation_min,
+                                           &data->output_activation_max);
 }
 
 }  // namespace
 
-void Free(TfLiteContext* context, void* buffer) { op_data_counter = 0; }
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  void* data = nullptr;
+  if (context->AllocatePersistentBuffer(context, sizeof(OpData), &data) ==
+      kTfLiteError) {
+    return nullptr;
+  }
+  return data;
+}
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  auto* params =
+  TFLITE_DCHECK(node->user_data != nullptr);
+  TFLITE_DCHECK(node->builtin_data != nullptr);
+
+  OpData* data = static_cast<OpData*>(node->user_data);
+  const auto* params =
       reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
 
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
   const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
   const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  TfLiteType data_type = input->type;
 
-  // TODO(b/132070898): Use statically slotted OpData structures until a
-  // scratch memory API is ready.
-  OpData* op_data = &kStaticOpData[op_data_counter++];
-  node->user_data = op_data;
-
-  TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input,
-                                        filter, bias, output, op_data));
-
-  return kTfLiteOk;
+  return CalculateOpData(context, params->activation, input->type, input,
+                         filter, bias, output, data);
 }
 
 TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
-                               TfLiteFullyConnectedParams* params, OpData* data,
-                               const TfLiteTensor* input,
+                               const OpData& data, const TfLiteTensor* input,
                                const TfLiteTensor* filter,
                                const TfLiteTensor* bias, TfLiteTensor* output) {
   // TODO(b/154032858): Investigate removing extra copies.
   FullyConnectedParams op_params;
   op_params.input_offset = -input->params.zero_point;
   op_params.weights_offset = -filter->params.zero_point;
   op_params.output_offset = output->params.zero_point;
-  op_params.output_multiplier = data->output_multiplier;
+  op_params.output_multiplier = data.output_multiplier;
   // TODO(b/138810107): Figure out whether output shift should be inverted
-  op_params.output_shift = -data->output_shift;
-  op_params.quantized_activation_min = data->output_activation_min;
-  op_params.quantized_activation_max = data->output_activation_max;
+  op_params.output_shift = -data.output_shift;
+  op_params.quantized_activation_min = data.output_activation_min;
+  op_params.quantized_activation_max = data.output_activation_max;
 
   xtensa::hifimini::FullyConnected(
      op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
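
For orientation (not part of this commit): the micro framework calls the registration's init exactly once per node and stores the returned pointer in node->user_data before prepare and invoke run, which is why Init above can allocate the OpData and Prepare/Eval can assume node->user_data is populated. A simplified sketch of that contract, with a hypothetical RunNodeOnce driver standing in for the real MicroInterpreter sequencing:

// Illustrative only; the real sequencing lives inside MicroInterpreter.
TfLiteStatus RunNodeOnce(TfLiteContext* context, TfLiteNode* node,
                         const TfLiteRegistration* registration) {
  // init runs once per node; the framework stores the returned OpData*.
  node->user_data =
      registration->init(context, /*buffer=*/nullptr, /*length=*/0);
  // prepare fills the OpData; invoke (Eval) reads it on every inference.
  TF_LITE_ENSURE_STATUS(registration->prepare(context, node));
  return registration->invoke(context, node);
}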
@@ -218,33 +214,23 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
 }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  auto* params =
-      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
-  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+  TFLITE_DCHECK(node->user_data != nullptr);
+  const OpData& data = *(static_cast<const OpData*>(node->user_data));
 
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
   const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
   const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
 
-  switch (filter->type) {  // Already know in/out types are same.
-    case kTfLiteInt8:
-      return EvalQuantizedInt8(context, node, params, op_data, input, filter,
-                               bias, output);
-
-    default:
-      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
-                         TfLiteTypeGetName(filter->type), filter->type);
-      return kTfLiteError;
-  }
-  return kTfLiteOk;
+  TFLITE_DCHECK(filter->type == kTfLiteInt8);
+  return EvalQuantizedInt8(context, node, data, input, filter, bias, output);
 }
 
 }  // namespace fully_connected
 
 TfLiteRegistration* Register_FULLY_CONNECTED() {
-  static TfLiteRegistration r = {/*init=*/nullptr,
-                                 /*free=*/fully_connected::Free,
+  static TfLiteRegistration r = {/*init=*/fully_connected::Init,
+                                 /*free=*/nullptr,
                                  /*prepare=*/fully_connected::Prepare,
                                  /*invoke=*/fully_connected::Eval,
                                  /*profiling_string=*/nullptr,
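
For completeness, a hedged sketch of how an application selects this kernel, using the TFLM resolver API of this era (exact AddBuiltin signatures vary across TFLM versions; RegisterOps is a placeholder name). When the library is built with the xtensa_hifimini kernel directory, this registration resolves to the implementation changed above:

#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

// Sketch: wires the FULLY_CONNECTED builtin to the micro kernel registration.
void RegisterOps(tflite::MicroMutableOpResolver* resolver) {
  resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
                       tflite::ops::micro::Register_FULLY_CONNECTED());
}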