Merge pull request #47098 from advaitjain:fusion-f1-svdf
PiperOrigin-RevId: 357813167 Change-Id: I20771b069117c8fdbe0644b3cbcb8cf2a0371abc
This commit is contained in:
commit
c8e9451ab4
@ -51,14 +51,14 @@ constexpr int kOutputTensor = 0;
|
||||
* Note: passing OpData by value might seem like an oversight but it helps
|
||||
* reduce the latency. See b/155656675 for more details.
|
||||
*/
|
||||
void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
|
||||
const TfLiteEvalTensor* input_tensor,
|
||||
const TfLiteEvalTensor* weights_feature_tensor,
|
||||
const TfLiteEvalTensor* weights_time_tensor,
|
||||
const TfLiteEvalTensor* bias_tensor,
|
||||
const TfLiteSVDFParams* params,
|
||||
TfLiteEvalTensor* activation_state_tensor,
|
||||
TfLiteEvalTensor* output_tensor, OpData data) {
|
||||
void EvalIntegerSvdfHifimini(TfLiteContext* context, TfLiteNode* node,
|
||||
const TfLiteEvalTensor* input_tensor,
|
||||
const TfLiteEvalTensor* weights_feature_tensor,
|
||||
const TfLiteEvalTensor* weights_time_tensor,
|
||||
const TfLiteEvalTensor* bias_tensor,
|
||||
const TfLiteSVDFParams* params,
|
||||
TfLiteEvalTensor* activation_state_tensor,
|
||||
TfLiteEvalTensor* output_tensor, OpData data) {
|
||||
const int n_rank = params->rank;
|
||||
const int n_batch = input_tensor->dims->data[0];
|
||||
const int n_input = input_tensor->dims->data[1];
|
||||
@ -243,7 +243,76 @@ void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#elif defined(FUSION_F1)
|
||||
|
||||
// Integer SVDF evaluation for FUSION_F1 targets, built on Cadence NNLib
// kernels (int8 input/weights/output, int16 activation state and time
// weights, int32 bias).
//
// Steps, in order:
//   1. Move the activation state one int16 element towards the front with
//      xa_nn_memmove_16.
//   2. Per batch, call xa_nn_matXvec_out_stride_sym8sxasym8s_16 on the input
//      and feature weights, writing into the state starting at element
//      (n_memory - 1) with an output stride of n_memory.
//   3. Per batch, call xa_nn_dot_prod_16x16_asym8s on the time weights and
//      that batch's state, passing the bias, the second effective scale and
//      the output zero point, producing n_unit int8 outputs.
//
// Parameters:
//   context                  used for DCHECKs and error reporting.
//   node                     not read by this function.
//   input_tensor             dims->data[0] is read as the batch count,
//                            dims->data[1] as the input size.
//   weights_feature_tensor   dims->data[0] is read as the filter count.
//   weights_time_tensor      dims->data[1] is read as the memory size.
//   bias_tensor              read as int32 data.
//   params                   supplies `rank`.
//   activation_state_tensor  int16 state, updated in place.
//   output_tensor            int8 output, n_unit values per batch.
//   data                     quantization parameters computed elsewhere.
//
// Returns kTfLiteOk on success. If an NNLib kernel reports a non-zero status,
// TF_LITE_ENSURE_EQ makes this function return early with an error status.
TfLiteStatus EvalIntegerSvdfHifi4(
    TfLiteContext* context, TfLiteNode* node,
    const TfLiteEvalTensor* input_tensor,
    const TfLiteEvalTensor* weights_feature_tensor,
    const TfLiteEvalTensor* weights_time_tensor,
    const TfLiteEvalTensor* bias_tensor, const TfLiteSVDFParams* params,
    TfLiteEvalTensor* activation_state_tensor, TfLiteEvalTensor* output_tensor,
    const OpData& data) {
  const int n_rank = params->rank;
  const int n_batch = input_tensor->dims->data[0];
  const int n_input = input_tensor->dims->data[1];
  const int n_filter = weights_feature_tensor->dims->data[0];
  const int n_unit = n_filter / n_rank;
  const int n_memory = weights_time_tensor->dims->data[1];

  TFLITE_DCHECK(context != nullptr);
  // NOTE(review): no scratch buffer is requested in this function; the check
  // is kept so debug-build behavior is unchanged.
  TFLITE_DCHECK(context->GetScratchBuffer != nullptr);

  // Shift states.
  int16_t* const state_ptr =
      tflite::micro::GetTensorData<int16_t>(activation_state_tensor);

  // Left shift the activation_state by one element: every element except the
  // first is moved one slot down.
  const int state_elements = n_batch * n_filter * n_memory;
  const int num_bytes =
      static_cast<int>(sizeof(*state_ptr)) * (state_elements - 1);
  xa_nn_memmove_16(state_ptr, state_ptr + 1, num_bytes);

  // Note: no need to clear the latest activation, matmul is not accumulative.

  // Feature matmul.
  const int8_t* input = tflite::micro::GetTensorData<int8_t>(input_tensor);
  const int8_t* weight_feature =
      tflite::micro::GetTensorData<int8_t>(weights_feature_tensor);
  int16_t* result_in_batch = state_ptr + (n_memory - 1);

  for (int b = 0; b < n_batch; b++) {
    // Capture the status first: TF_LITE_ENSURE_EQ expands its arguments again
    // when reporting a mismatch, which would re-run the kernel on failure.
    const int matmul_status = xa_nn_matXvec_out_stride_sym8sxasym8s_16(
        &result_in_batch[b * n_filter * n_memory], weight_feature,
        &input[b * n_input], NULL, n_filter, n_input, n_input, n_memory,
        -data.input_zero_point, (data.effective_scale_1_a),
        data.effective_scale_1_b);
    TF_LITE_ENSURE_EQ(context, matmul_status, 0);
  }

  // Time weights dot product + activation.
  // The time weights and bias do not depend on the batch index, so they are
  // looked up once.
  const int16_t* const weights_time_ptr =
      tflite::micro::GetTensorData<int16_t>(weights_time_tensor);
  const int32_t* const bias_ptr =
      tflite::micro::GetTensorData<int32_t>(bias_tensor);
  int8_t* const output_base =
      tflite::micro::GetTensorData<int8_t>(output_tensor);

  for (int b = 0; b < n_batch; ++b) {
    const int16_t* batch_state_ptr = state_ptr + b * n_memory * n_filter;
    int8_t* output_ptr = output_base + b * n_unit;

    // Same reasoning as above: evaluate the kernel exactly once.
    const int dot_prod_status = xa_nn_dot_prod_16x16_asym8s(
        output_ptr, weights_time_ptr, batch_state_ptr, bias_ptr,
        n_memory * n_rank, (data.effective_scale_2_a),
        data.effective_scale_2_b, data.output_zero_point, n_unit);
    TF_LITE_ENSURE_EQ(context, dot_prod_status, 0);
  }
  return kTfLiteOk;
}
|
||||
#endif // defined(FUSION_F1) || defined(HIFIMINI)
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context != nullptr);
|
||||
@ -274,11 +343,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const int rank = params->rank;
|
||||
const int input_size = input->dims->data[1];
|
||||
const int batch_size = input->dims->data[0];
|
||||
|
||||
#if defined(HIFIMINI)
|
||||
// Ensure the input size is a multiple of two. This is necessary since
|
||||
// optimized kernels access the memory in chunks of two, and all accesses
|
||||
// must be aligned to 16 bits.
|
||||
// TODO(b/153202598): Remove when padding is allowed in TFLite tensors.
|
||||
TF_LITE_ENSURE_EQ(context, input_size % 2, 0);
|
||||
#endif // defined(HIFIMINI)
|
||||
|
||||
const int num_filters = weights_feature->dims->data[0];
|
||||
TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
|
||||
@ -339,9 +411,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
static_cast<double>(activation_state->params.scale *
|
||||
weights_time->params.scale / output->params.scale);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, static_cast<double>(bias->params.scale),
|
||||
static_cast<double>(activation_state->params.scale *
|
||||
weights_time->params.scale));
|
||||
TF_LITE_ENSURE_NEAR(context, static_cast<double>(bias->params.scale),
|
||||
static_cast<double>(activation_state->params.scale *
|
||||
weights_time->params.scale),
|
||||
1e-5);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
OpData* data = static_cast<OpData*>(node->user_data);
|
||||
@ -396,13 +469,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const OpData& data = *(static_cast<const OpData*>(node->user_data));
|
||||
|
||||
#if defined(HIFIMINI)
|
||||
EvalIntegerSVDF(context, node, input, weights_feature, weights_time, bias,
|
||||
params, activation_state, output, data);
|
||||
EvalIntegerSvdfHifimini(context, node, input, weights_feature, weights_time,
|
||||
bias, params, activation_state, output, data);
|
||||
return kTfLiteOk;
|
||||
#elif defined(FUSION_F1)
|
||||
return EvalIntegerSvdfHifi4(context, node, input, weights_feature,
|
||||
weights_time, bias, params, activation_state,
|
||||
output, data);
|
||||
#else
|
||||
EvalIntegerSvdfReference(context, node, input, weights_feature, weights_time,
|
||||
bias, params, activation_state, output, data);
|
||||
#endif
|
||||
return kTfLiteOk;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user