Move the GetTensorData calls directly into the LstmStepInteger call in the int8x8_8 version.

PiperOrigin-RevId: 317378494
Change-Id: I880de60ee5c963468b2eb47d8cca7a3b9765f57c
This commit is contained in:
parent
f60f6f0c1f
commit
030e65acd5
@ -2091,50 +2091,6 @@ TfLiteStatus EvalInteger8x8_8(
|
||||
const int n_cell = input_to_output_weights->dims->data[0];
|
||||
const int n_output = recurrent_to_output_weights->dims->data[1];
|
||||
|
||||
// Weights and states.
|
||||
const int8_t* input_to_input_weight_ptr =
|
||||
GetTensorData<int8_t>(input_to_input_weights);
|
||||
const int8_t* recurrent_to_input_weight_ptr =
|
||||
GetTensorData<int8_t>(recurrent_to_input_weights);
|
||||
const int8_t* cell_to_input_weight_ptr =
|
||||
GetTensorData<int8_t>(cell_to_input_weights);
|
||||
const int8_t* input_to_forget_weight_ptr =
|
||||
GetTensorData<int8_t>(input_to_forget_weights);
|
||||
const int8_t* recurrent_to_forget_weight_ptr =
|
||||
GetTensorData<int8_t>(recurrent_to_forget_weights);
|
||||
const int8_t* cell_to_forget_weight_ptr =
|
||||
GetTensorData<int8_t>(cell_to_forget_weights);
|
||||
const int8_t* input_to_cell_weight_ptr =
|
||||
GetTensorData<int8_t>(input_to_cell_weights);
|
||||
const int8_t* recurrent_to_cell_weight_ptr =
|
||||
GetTensorData<int8_t>(recurrent_to_cell_weights);
|
||||
const int8_t* input_to_output_weight_ptr =
|
||||
GetTensorData<int8_t>(input_to_output_weights);
|
||||
const int8_t* recurrent_to_output_weight_ptr =
|
||||
GetTensorData<int8_t>(recurrent_to_output_weights);
|
||||
const int8_t* cell_to_output_weight_ptr =
|
||||
GetTensorData<int8_t>(cell_to_output_weights);
|
||||
const int8_t* projection_weight_ptr =
|
||||
GetTensorData<int8_t>(projection_weights);
|
||||
const int16_t* layer_norm_input_weight_ptr =
|
||||
GetTensorData<int16_t>(input_layer_norm_coefficients);
|
||||
const int16_t* layer_norm_forget_weight_ptr =
|
||||
GetTensorData<int16_t>(forget_layer_norm_coefficients);
|
||||
const int16_t* layer_norm_cell_weight_ptr =
|
||||
GetTensorData<int16_t>(cell_layer_norm_coefficients);
|
||||
const int16_t* layer_norm_output_weight_ptr =
|
||||
GetTensorData<int16_t>(output_layer_norm_coefficients);
|
||||
const int32_t* input_gate_bias_ptr = GetTensorData<int32_t>(input_gate_bias);
|
||||
const int32_t* forget_gate_bias_ptr =
|
||||
GetTensorData<int32_t>(forget_gate_bias);
|
||||
const int32_t* cell_gate_bias_ptr = GetTensorData<int32_t>(cell_gate_bias);
|
||||
const int32_t* output_gate_bias_ptr =
|
||||
GetTensorData<int32_t>(output_gate_bias);
|
||||
const int32_t* projection_bias_ptr = GetTensorData<int32_t>(projection_bias);
|
||||
int16_t* cell_ptr = GetTensorData<int16_t>(cell_state);
|
||||
int8_t* output_state_ptr = GetTensorData<int8_t>(output_state);
|
||||
int8_t* output_ptr = nullptr;
|
||||
|
||||
const int32_t input_zp = input->params.zero_point;
|
||||
const int32_t output_state_zp = output_state->params.zero_point;
|
||||
|
||||
@ -2146,89 +2102,93 @@ TfLiteStatus EvalInteger8x8_8(
|
||||
|
||||
for (int t = 0; t < max_time; t++) {
|
||||
const int t_rel = t;
|
||||
output_ptr = output->data.int8 + t_rel * output_step;
|
||||
|
||||
int8_t* output_ptr = GetTensorData<int8_t>(output) + t_rel * output_step;
|
||||
// Input can be int8 asymmetric or int16 symmetric.
|
||||
const int8_t* input_ptr = input->data.int8 + t_rel * input_step;
|
||||
const int8_t* input_ptr = GetTensorData<int8_t>(input) + t_rel * input_step;
|
||||
lstm_eval::LstmStepInteger(
|
||||
input_ptr, input_zp,
|
||||
|
||||
input_to_input_weight_ptr,
|
||||
GetTensorData<int8_t>(input_to_input_weights),
|
||||
integer_lstm_param->effective_input_to_input_scale_a,
|
||||
integer_lstm_param->effective_input_to_input_scale_b,
|
||||
|
||||
input_to_forget_weight_ptr,
|
||||
GetTensorData<int8_t>(input_to_forget_weights),
|
||||
integer_lstm_param->effective_input_to_forget_scale_a,
|
||||
integer_lstm_param->effective_input_to_forget_scale_b,
|
||||
|
||||
input_to_cell_weight_ptr,
|
||||
GetTensorData<int8_t>(input_to_cell_weights),
|
||||
integer_lstm_param->effective_input_to_cell_scale_a,
|
||||
integer_lstm_param->effective_input_to_cell_scale_b,
|
||||
|
||||
input_to_output_weight_ptr,
|
||||
GetTensorData<int8_t>(input_to_output_weights),
|
||||
integer_lstm_param->effective_input_to_output_scale_a,
|
||||
integer_lstm_param->effective_input_to_output_scale_b,
|
||||
|
||||
recurrent_to_input_weight_ptr,
|
||||
GetTensorData<int8_t>(recurrent_to_input_weights),
|
||||
integer_lstm_param->effective_recurrent_to_input_scale_a,
|
||||
integer_lstm_param->effective_recurrent_to_input_scale_b,
|
||||
|
||||
recurrent_to_forget_weight_ptr,
|
||||
GetTensorData<int8_t>(recurrent_to_forget_weights),
|
||||
integer_lstm_param->effective_recurrent_to_forget_scale_a,
|
||||
integer_lstm_param->effective_recurrent_to_forget_scale_b,
|
||||
|
||||
recurrent_to_cell_weight_ptr,
|
||||
GetTensorData<int8_t>(recurrent_to_cell_weights),
|
||||
integer_lstm_param->effective_recurrent_to_cell_scale_a,
|
||||
integer_lstm_param->effective_recurrent_to_cell_scale_b,
|
||||
|
||||
recurrent_to_output_weight_ptr,
|
||||
GetTensorData<int8_t>(recurrent_to_output_weights),
|
||||
integer_lstm_param->effective_recurrent_to_output_scale_a,
|
||||
integer_lstm_param->effective_recurrent_to_output_scale_b,
|
||||
|
||||
cell_to_input_weight_ptr,
|
||||
GetTensorData<int8_t>(cell_to_input_weights),
|
||||
integer_lstm_param->effective_cell_to_input_scale_a,
|
||||
integer_lstm_param->effective_cell_to_input_scale_b,
|
||||
|
||||
cell_to_forget_weight_ptr,
|
||||
GetTensorData<int8_t>(cell_to_forget_weights),
|
||||
integer_lstm_param->effective_cell_to_forget_scale_a,
|
||||
integer_lstm_param->effective_cell_to_forget_scale_b,
|
||||
|
||||
cell_to_output_weight_ptr,
|
||||
GetTensorData<int8_t>(cell_to_output_weights),
|
||||
integer_lstm_param->effective_cell_to_output_scale_a,
|
||||
integer_lstm_param->effective_cell_to_output_scale_b,
|
||||
|
||||
projection_weight_ptr, integer_lstm_param->effective_proj_scale_a,
|
||||
GetTensorData<int8_t>(projection_weights),
|
||||
integer_lstm_param->effective_proj_scale_a,
|
||||
integer_lstm_param->effective_proj_scale_b,
|
||||
|
||||
layer_norm_input_weight_ptr,
|
||||
GetTensorData<int16_t>(input_layer_norm_coefficients),
|
||||
integer_lstm_param->layer_norm_input_scale_a,
|
||||
integer_lstm_param->layer_norm_input_scale_b,
|
||||
|
||||
layer_norm_forget_weight_ptr,
|
||||
GetTensorData<int16_t>(forget_layer_norm_coefficients),
|
||||
integer_lstm_param->layer_norm_forget_scale_a,
|
||||
integer_lstm_param->layer_norm_forget_scale_b,
|
||||
|
||||
layer_norm_cell_weight_ptr, integer_lstm_param->layer_norm_cell_scale_a,
|
||||
GetTensorData<int16_t>(cell_layer_norm_coefficients),
|
||||
integer_lstm_param->layer_norm_cell_scale_a,
|
||||
integer_lstm_param->layer_norm_cell_scale_b,
|
||||
|
||||
layer_norm_output_weight_ptr,
|
||||
GetTensorData<int16_t>(output_layer_norm_coefficients),
|
||||
integer_lstm_param->layer_norm_output_scale_a,
|
||||
integer_lstm_param->layer_norm_output_scale_b,
|
||||
|
||||
input_gate_bias_ptr, forget_gate_bias_ptr, cell_gate_bias_ptr,
|
||||
output_gate_bias_ptr, projection_bias_ptr,
|
||||
GetTensorData<int32_t>(input_gate_bias),
|
||||
GetTensorData<int32_t>(forget_gate_bias),
|
||||
GetTensorData<int32_t>(cell_gate_bias),
|
||||
GetTensorData<int32_t>(output_gate_bias),
|
||||
GetTensorData<int32_t>(projection_bias),
|
||||
|
||||
params, integer_lstm_param->intermediate_scale_a,
|
||||
integer_lstm_param->intermediate_scale_b,
|
||||
integer_lstm_param->intermediate_zp,
|
||||
integer_lstm_param->quantized_cell_clip,
|
||||
integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
|
||||
n_output, output_batch_leading_dim, output_state_ptr, output_state_zp,
|
||||
cell_ptr, output_ptr, GetTensorData<int8_t>(scratch0),
|
||||
GetTensorData<int8_t>(scratch1), GetTensorData<int16_t>(scratch2),
|
||||
GetTensorData<int16_t>(scratch3), GetTensorData<int16_t>(scratch4),
|
||||
GetTensorData<int16_t>(scratch5), GetTensorData<int16_t>(scratch6),
|
||||
GetTensorData<int16_t>(scratch7));
|
||||
n_output, output_batch_leading_dim, GetTensorData<int8_t>(output_state),
|
||||
output_state_zp, GetTensorData<int16_t>(cell_state), output_ptr,
|
||||
GetTensorData<int8_t>(scratch0), GetTensorData<int8_t>(scratch1),
|
||||
GetTensorData<int16_t>(scratch2), GetTensorData<int16_t>(scratch3),
|
||||
GetTensorData<int16_t>(scratch4), GetTensorData<int16_t>(scratch5),
|
||||
GetTensorData<int16_t>(scratch6), GetTensorData<int16_t>(scratch7));
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
|
Loading…
Reference in New Issue
Block a user