Move the GetTensorData calls directly into the LstmStepInteger call in the int8x8_8 version.
PiperOrigin-RevId: 317378494 Change-Id: I880de60ee5c963468b2eb47d8cca7a3b9765f57c
This commit is contained in:
parent
f60f6f0c1f
commit
030e65acd5
@ -2091,50 +2091,6 @@ TfLiteStatus EvalInteger8x8_8(
|
|||||||
const int n_cell = input_to_output_weights->dims->data[0];
|
const int n_cell = input_to_output_weights->dims->data[0];
|
||||||
const int n_output = recurrent_to_output_weights->dims->data[1];
|
const int n_output = recurrent_to_output_weights->dims->data[1];
|
||||||
|
|
||||||
// Weights and states.
|
|
||||||
const int8_t* input_to_input_weight_ptr =
|
|
||||||
GetTensorData<int8_t>(input_to_input_weights);
|
|
||||||
const int8_t* recurrent_to_input_weight_ptr =
|
|
||||||
GetTensorData<int8_t>(recurrent_to_input_weights);
|
|
||||||
const int8_t* cell_to_input_weight_ptr =
|
|
||||||
GetTensorData<int8_t>(cell_to_input_weights);
|
|
||||||
const int8_t* input_to_forget_weight_ptr =
|
|
||||||
GetTensorData<int8_t>(input_to_forget_weights);
|
|
||||||
const int8_t* recurrent_to_forget_weight_ptr =
|
|
||||||
GetTensorData<int8_t>(recurrent_to_forget_weights);
|
|
||||||
const int8_t* cell_to_forget_weight_ptr =
|
|
||||||
GetTensorData<int8_t>(cell_to_forget_weights);
|
|
||||||
const int8_t* input_to_cell_weight_ptr =
|
|
||||||
GetTensorData<int8_t>(input_to_cell_weights);
|
|
||||||
const int8_t* recurrent_to_cell_weight_ptr =
|
|
||||||
GetTensorData<int8_t>(recurrent_to_cell_weights);
|
|
||||||
const int8_t* input_to_output_weight_ptr =
|
|
||||||
GetTensorData<int8_t>(input_to_output_weights);
|
|
||||||
const int8_t* recurrent_to_output_weight_ptr =
|
|
||||||
GetTensorData<int8_t>(recurrent_to_output_weights);
|
|
||||||
const int8_t* cell_to_output_weight_ptr =
|
|
||||||
GetTensorData<int8_t>(cell_to_output_weights);
|
|
||||||
const int8_t* projection_weight_ptr =
|
|
||||||
GetTensorData<int8_t>(projection_weights);
|
|
||||||
const int16_t* layer_norm_input_weight_ptr =
|
|
||||||
GetTensorData<int16_t>(input_layer_norm_coefficients);
|
|
||||||
const int16_t* layer_norm_forget_weight_ptr =
|
|
||||||
GetTensorData<int16_t>(forget_layer_norm_coefficients);
|
|
||||||
const int16_t* layer_norm_cell_weight_ptr =
|
|
||||||
GetTensorData<int16_t>(cell_layer_norm_coefficients);
|
|
||||||
const int16_t* layer_norm_output_weight_ptr =
|
|
||||||
GetTensorData<int16_t>(output_layer_norm_coefficients);
|
|
||||||
const int32_t* input_gate_bias_ptr = GetTensorData<int32_t>(input_gate_bias);
|
|
||||||
const int32_t* forget_gate_bias_ptr =
|
|
||||||
GetTensorData<int32_t>(forget_gate_bias);
|
|
||||||
const int32_t* cell_gate_bias_ptr = GetTensorData<int32_t>(cell_gate_bias);
|
|
||||||
const int32_t* output_gate_bias_ptr =
|
|
||||||
GetTensorData<int32_t>(output_gate_bias);
|
|
||||||
const int32_t* projection_bias_ptr = GetTensorData<int32_t>(projection_bias);
|
|
||||||
int16_t* cell_ptr = GetTensorData<int16_t>(cell_state);
|
|
||||||
int8_t* output_state_ptr = GetTensorData<int8_t>(output_state);
|
|
||||||
int8_t* output_ptr = nullptr;
|
|
||||||
|
|
||||||
const int32_t input_zp = input->params.zero_point;
|
const int32_t input_zp = input->params.zero_point;
|
||||||
const int32_t output_state_zp = output_state->params.zero_point;
|
const int32_t output_state_zp = output_state->params.zero_point;
|
||||||
|
|
||||||
@ -2146,89 +2102,93 @@ TfLiteStatus EvalInteger8x8_8(
|
|||||||
|
|
||||||
for (int t = 0; t < max_time; t++) {
|
for (int t = 0; t < max_time; t++) {
|
||||||
const int t_rel = t;
|
const int t_rel = t;
|
||||||
output_ptr = output->data.int8 + t_rel * output_step;
|
int8_t* output_ptr = GetTensorData<int8_t>(output) + t_rel * output_step;
|
||||||
|
|
||||||
// Input can be int8 asymmetric or int16 symmetric.
|
// Input can be int8 asymmetric or int16 symmetric.
|
||||||
const int8_t* input_ptr = input->data.int8 + t_rel * input_step;
|
const int8_t* input_ptr = GetTensorData<int8_t>(input) + t_rel * input_step;
|
||||||
lstm_eval::LstmStepInteger(
|
lstm_eval::LstmStepInteger(
|
||||||
input_ptr, input_zp,
|
input_ptr, input_zp,
|
||||||
|
|
||||||
input_to_input_weight_ptr,
|
GetTensorData<int8_t>(input_to_input_weights),
|
||||||
integer_lstm_param->effective_input_to_input_scale_a,
|
integer_lstm_param->effective_input_to_input_scale_a,
|
||||||
integer_lstm_param->effective_input_to_input_scale_b,
|
integer_lstm_param->effective_input_to_input_scale_b,
|
||||||
|
|
||||||
input_to_forget_weight_ptr,
|
GetTensorData<int8_t>(input_to_forget_weights),
|
||||||
integer_lstm_param->effective_input_to_forget_scale_a,
|
integer_lstm_param->effective_input_to_forget_scale_a,
|
||||||
integer_lstm_param->effective_input_to_forget_scale_b,
|
integer_lstm_param->effective_input_to_forget_scale_b,
|
||||||
|
|
||||||
input_to_cell_weight_ptr,
|
GetTensorData<int8_t>(input_to_cell_weights),
|
||||||
integer_lstm_param->effective_input_to_cell_scale_a,
|
integer_lstm_param->effective_input_to_cell_scale_a,
|
||||||
integer_lstm_param->effective_input_to_cell_scale_b,
|
integer_lstm_param->effective_input_to_cell_scale_b,
|
||||||
|
|
||||||
input_to_output_weight_ptr,
|
GetTensorData<int8_t>(input_to_output_weights),
|
||||||
integer_lstm_param->effective_input_to_output_scale_a,
|
integer_lstm_param->effective_input_to_output_scale_a,
|
||||||
integer_lstm_param->effective_input_to_output_scale_b,
|
integer_lstm_param->effective_input_to_output_scale_b,
|
||||||
|
|
||||||
recurrent_to_input_weight_ptr,
|
GetTensorData<int8_t>(recurrent_to_input_weights),
|
||||||
integer_lstm_param->effective_recurrent_to_input_scale_a,
|
integer_lstm_param->effective_recurrent_to_input_scale_a,
|
||||||
integer_lstm_param->effective_recurrent_to_input_scale_b,
|
integer_lstm_param->effective_recurrent_to_input_scale_b,
|
||||||
|
|
||||||
recurrent_to_forget_weight_ptr,
|
GetTensorData<int8_t>(recurrent_to_forget_weights),
|
||||||
integer_lstm_param->effective_recurrent_to_forget_scale_a,
|
integer_lstm_param->effective_recurrent_to_forget_scale_a,
|
||||||
integer_lstm_param->effective_recurrent_to_forget_scale_b,
|
integer_lstm_param->effective_recurrent_to_forget_scale_b,
|
||||||
|
|
||||||
recurrent_to_cell_weight_ptr,
|
GetTensorData<int8_t>(recurrent_to_cell_weights),
|
||||||
integer_lstm_param->effective_recurrent_to_cell_scale_a,
|
integer_lstm_param->effective_recurrent_to_cell_scale_a,
|
||||||
integer_lstm_param->effective_recurrent_to_cell_scale_b,
|
integer_lstm_param->effective_recurrent_to_cell_scale_b,
|
||||||
|
|
||||||
recurrent_to_output_weight_ptr,
|
GetTensorData<int8_t>(recurrent_to_output_weights),
|
||||||
integer_lstm_param->effective_recurrent_to_output_scale_a,
|
integer_lstm_param->effective_recurrent_to_output_scale_a,
|
||||||
integer_lstm_param->effective_recurrent_to_output_scale_b,
|
integer_lstm_param->effective_recurrent_to_output_scale_b,
|
||||||
|
|
||||||
cell_to_input_weight_ptr,
|
GetTensorData<int8_t>(cell_to_input_weights),
|
||||||
integer_lstm_param->effective_cell_to_input_scale_a,
|
integer_lstm_param->effective_cell_to_input_scale_a,
|
||||||
integer_lstm_param->effective_cell_to_input_scale_b,
|
integer_lstm_param->effective_cell_to_input_scale_b,
|
||||||
|
|
||||||
cell_to_forget_weight_ptr,
|
GetTensorData<int8_t>(cell_to_forget_weights),
|
||||||
integer_lstm_param->effective_cell_to_forget_scale_a,
|
integer_lstm_param->effective_cell_to_forget_scale_a,
|
||||||
integer_lstm_param->effective_cell_to_forget_scale_b,
|
integer_lstm_param->effective_cell_to_forget_scale_b,
|
||||||
|
|
||||||
cell_to_output_weight_ptr,
|
GetTensorData<int8_t>(cell_to_output_weights),
|
||||||
integer_lstm_param->effective_cell_to_output_scale_a,
|
integer_lstm_param->effective_cell_to_output_scale_a,
|
||||||
integer_lstm_param->effective_cell_to_output_scale_b,
|
integer_lstm_param->effective_cell_to_output_scale_b,
|
||||||
|
|
||||||
projection_weight_ptr, integer_lstm_param->effective_proj_scale_a,
|
GetTensorData<int8_t>(projection_weights),
|
||||||
|
integer_lstm_param->effective_proj_scale_a,
|
||||||
integer_lstm_param->effective_proj_scale_b,
|
integer_lstm_param->effective_proj_scale_b,
|
||||||
|
|
||||||
layer_norm_input_weight_ptr,
|
GetTensorData<int16_t>(input_layer_norm_coefficients),
|
||||||
integer_lstm_param->layer_norm_input_scale_a,
|
integer_lstm_param->layer_norm_input_scale_a,
|
||||||
integer_lstm_param->layer_norm_input_scale_b,
|
integer_lstm_param->layer_norm_input_scale_b,
|
||||||
|
|
||||||
layer_norm_forget_weight_ptr,
|
GetTensorData<int16_t>(forget_layer_norm_coefficients),
|
||||||
integer_lstm_param->layer_norm_forget_scale_a,
|
integer_lstm_param->layer_norm_forget_scale_a,
|
||||||
integer_lstm_param->layer_norm_forget_scale_b,
|
integer_lstm_param->layer_norm_forget_scale_b,
|
||||||
|
|
||||||
layer_norm_cell_weight_ptr, integer_lstm_param->layer_norm_cell_scale_a,
|
GetTensorData<int16_t>(cell_layer_norm_coefficients),
|
||||||
|
integer_lstm_param->layer_norm_cell_scale_a,
|
||||||
integer_lstm_param->layer_norm_cell_scale_b,
|
integer_lstm_param->layer_norm_cell_scale_b,
|
||||||
|
|
||||||
layer_norm_output_weight_ptr,
|
GetTensorData<int16_t>(output_layer_norm_coefficients),
|
||||||
integer_lstm_param->layer_norm_output_scale_a,
|
integer_lstm_param->layer_norm_output_scale_a,
|
||||||
integer_lstm_param->layer_norm_output_scale_b,
|
integer_lstm_param->layer_norm_output_scale_b,
|
||||||
|
|
||||||
input_gate_bias_ptr, forget_gate_bias_ptr, cell_gate_bias_ptr,
|
GetTensorData<int32_t>(input_gate_bias),
|
||||||
output_gate_bias_ptr, projection_bias_ptr,
|
GetTensorData<int32_t>(forget_gate_bias),
|
||||||
|
GetTensorData<int32_t>(cell_gate_bias),
|
||||||
|
GetTensorData<int32_t>(output_gate_bias),
|
||||||
|
GetTensorData<int32_t>(projection_bias),
|
||||||
|
|
||||||
params, integer_lstm_param->intermediate_scale_a,
|
params, integer_lstm_param->intermediate_scale_a,
|
||||||
integer_lstm_param->intermediate_scale_b,
|
integer_lstm_param->intermediate_scale_b,
|
||||||
integer_lstm_param->intermediate_zp,
|
integer_lstm_param->intermediate_zp,
|
||||||
integer_lstm_param->quantized_cell_clip,
|
integer_lstm_param->quantized_cell_clip,
|
||||||
integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
|
integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
|
||||||
n_output, output_batch_leading_dim, output_state_ptr, output_state_zp,
|
n_output, output_batch_leading_dim, GetTensorData<int8_t>(output_state),
|
||||||
cell_ptr, output_ptr, GetTensorData<int8_t>(scratch0),
|
output_state_zp, GetTensorData<int16_t>(cell_state), output_ptr,
|
||||||
GetTensorData<int8_t>(scratch1), GetTensorData<int16_t>(scratch2),
|
GetTensorData<int8_t>(scratch0), GetTensorData<int8_t>(scratch1),
|
||||||
GetTensorData<int16_t>(scratch3), GetTensorData<int16_t>(scratch4),
|
GetTensorData<int16_t>(scratch2), GetTensorData<int16_t>(scratch3),
|
||||||
GetTensorData<int16_t>(scratch5), GetTensorData<int16_t>(scratch6),
|
GetTensorData<int16_t>(scratch4), GetTensorData<int16_t>(scratch5),
|
||||||
GetTensorData<int16_t>(scratch7));
|
GetTensorData<int16_t>(scratch6), GetTensorData<int16_t>(scratch7));
|
||||||
}
|
}
|
||||||
|
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
|
Loading…
Reference in New Issue
Block a user