Move the GetTensorData calls directly into the LstmStep call in the int8x8_8 version.

PiperOrigin-RevId: 317378494
Change-Id: I880de60ee5c963468b2eb47d8cca7a3b9765f57c
This commit is contained in:
Robert David 2020-06-19 14:17:36 -07:00 committed by TensorFlower Gardener
parent f60f6f0c1f
commit 030e65acd5

View File

@@ -2091,50 +2091,6 @@ TfLiteStatus EvalInteger8x8_8(
const int n_cell = input_to_output_weights->dims->data[0];
const int n_output = recurrent_to_output_weights->dims->data[1];
// Weights and states.
const int8_t* input_to_input_weight_ptr =
GetTensorData<int8_t>(input_to_input_weights);
const int8_t* recurrent_to_input_weight_ptr =
GetTensorData<int8_t>(recurrent_to_input_weights);
const int8_t* cell_to_input_weight_ptr =
GetTensorData<int8_t>(cell_to_input_weights);
const int8_t* input_to_forget_weight_ptr =
GetTensorData<int8_t>(input_to_forget_weights);
const int8_t* recurrent_to_forget_weight_ptr =
GetTensorData<int8_t>(recurrent_to_forget_weights);
const int8_t* cell_to_forget_weight_ptr =
GetTensorData<int8_t>(cell_to_forget_weights);
const int8_t* input_to_cell_weight_ptr =
GetTensorData<int8_t>(input_to_cell_weights);
const int8_t* recurrent_to_cell_weight_ptr =
GetTensorData<int8_t>(recurrent_to_cell_weights);
const int8_t* input_to_output_weight_ptr =
GetTensorData<int8_t>(input_to_output_weights);
const int8_t* recurrent_to_output_weight_ptr =
GetTensorData<int8_t>(recurrent_to_output_weights);
const int8_t* cell_to_output_weight_ptr =
GetTensorData<int8_t>(cell_to_output_weights);
const int8_t* projection_weight_ptr =
GetTensorData<int8_t>(projection_weights);
const int16_t* layer_norm_input_weight_ptr =
GetTensorData<int16_t>(input_layer_norm_coefficients);
const int16_t* layer_norm_forget_weight_ptr =
GetTensorData<int16_t>(forget_layer_norm_coefficients);
const int16_t* layer_norm_cell_weight_ptr =
GetTensorData<int16_t>(cell_layer_norm_coefficients);
const int16_t* layer_norm_output_weight_ptr =
GetTensorData<int16_t>(output_layer_norm_coefficients);
const int32_t* input_gate_bias_ptr = GetTensorData<int32_t>(input_gate_bias);
const int32_t* forget_gate_bias_ptr =
GetTensorData<int32_t>(forget_gate_bias);
const int32_t* cell_gate_bias_ptr = GetTensorData<int32_t>(cell_gate_bias);
const int32_t* output_gate_bias_ptr =
GetTensorData<int32_t>(output_gate_bias);
const int32_t* projection_bias_ptr = GetTensorData<int32_t>(projection_bias);
int16_t* cell_ptr = GetTensorData<int16_t>(cell_state);
int8_t* output_state_ptr = GetTensorData<int8_t>(output_state);
int8_t* output_ptr = nullptr;
const int32_t input_zp = input->params.zero_point;
const int32_t output_state_zp = output_state->params.zero_point;
@@ -2146,89 +2102,93 @@ TfLiteStatus EvalInteger8x8_8(
for (int t = 0; t < max_time; t++) {
const int t_rel = t;
output_ptr = output->data.int8 + t_rel * output_step;
int8_t* output_ptr = GetTensorData<int8_t>(output) + t_rel * output_step;
// Input can be int8 asymmetric or int16 symmetric.
const int8_t* input_ptr = input->data.int8 + t_rel * input_step;
const int8_t* input_ptr = GetTensorData<int8_t>(input) + t_rel * input_step;
lstm_eval::LstmStepInteger(
input_ptr, input_zp,
input_to_input_weight_ptr,
GetTensorData<int8_t>(input_to_input_weights),
integer_lstm_param->effective_input_to_input_scale_a,
integer_lstm_param->effective_input_to_input_scale_b,
input_to_forget_weight_ptr,
GetTensorData<int8_t>(input_to_forget_weights),
integer_lstm_param->effective_input_to_forget_scale_a,
integer_lstm_param->effective_input_to_forget_scale_b,
input_to_cell_weight_ptr,
GetTensorData<int8_t>(input_to_cell_weights),
integer_lstm_param->effective_input_to_cell_scale_a,
integer_lstm_param->effective_input_to_cell_scale_b,
input_to_output_weight_ptr,
GetTensorData<int8_t>(input_to_output_weights),
integer_lstm_param->effective_input_to_output_scale_a,
integer_lstm_param->effective_input_to_output_scale_b,
recurrent_to_input_weight_ptr,
GetTensorData<int8_t>(recurrent_to_input_weights),
integer_lstm_param->effective_recurrent_to_input_scale_a,
integer_lstm_param->effective_recurrent_to_input_scale_b,
recurrent_to_forget_weight_ptr,
GetTensorData<int8_t>(recurrent_to_forget_weights),
integer_lstm_param->effective_recurrent_to_forget_scale_a,
integer_lstm_param->effective_recurrent_to_forget_scale_b,
recurrent_to_cell_weight_ptr,
GetTensorData<int8_t>(recurrent_to_cell_weights),
integer_lstm_param->effective_recurrent_to_cell_scale_a,
integer_lstm_param->effective_recurrent_to_cell_scale_b,
recurrent_to_output_weight_ptr,
GetTensorData<int8_t>(recurrent_to_output_weights),
integer_lstm_param->effective_recurrent_to_output_scale_a,
integer_lstm_param->effective_recurrent_to_output_scale_b,
cell_to_input_weight_ptr,
GetTensorData<int8_t>(cell_to_input_weights),
integer_lstm_param->effective_cell_to_input_scale_a,
integer_lstm_param->effective_cell_to_input_scale_b,
cell_to_forget_weight_ptr,
GetTensorData<int8_t>(cell_to_forget_weights),
integer_lstm_param->effective_cell_to_forget_scale_a,
integer_lstm_param->effective_cell_to_forget_scale_b,
cell_to_output_weight_ptr,
GetTensorData<int8_t>(cell_to_output_weights),
integer_lstm_param->effective_cell_to_output_scale_a,
integer_lstm_param->effective_cell_to_output_scale_b,
projection_weight_ptr, integer_lstm_param->effective_proj_scale_a,
GetTensorData<int8_t>(projection_weights),
integer_lstm_param->effective_proj_scale_a,
integer_lstm_param->effective_proj_scale_b,
layer_norm_input_weight_ptr,
GetTensorData<int16_t>(input_layer_norm_coefficients),
integer_lstm_param->layer_norm_input_scale_a,
integer_lstm_param->layer_norm_input_scale_b,
layer_norm_forget_weight_ptr,
GetTensorData<int16_t>(forget_layer_norm_coefficients),
integer_lstm_param->layer_norm_forget_scale_a,
integer_lstm_param->layer_norm_forget_scale_b,
layer_norm_cell_weight_ptr, integer_lstm_param->layer_norm_cell_scale_a,
GetTensorData<int16_t>(cell_layer_norm_coefficients),
integer_lstm_param->layer_norm_cell_scale_a,
integer_lstm_param->layer_norm_cell_scale_b,
layer_norm_output_weight_ptr,
GetTensorData<int16_t>(output_layer_norm_coefficients),
integer_lstm_param->layer_norm_output_scale_a,
integer_lstm_param->layer_norm_output_scale_b,
input_gate_bias_ptr, forget_gate_bias_ptr, cell_gate_bias_ptr,
output_gate_bias_ptr, projection_bias_ptr,
GetTensorData<int32_t>(input_gate_bias),
GetTensorData<int32_t>(forget_gate_bias),
GetTensorData<int32_t>(cell_gate_bias),
GetTensorData<int32_t>(output_gate_bias),
GetTensorData<int32_t>(projection_bias),
params, integer_lstm_param->intermediate_scale_a,
integer_lstm_param->intermediate_scale_b,
integer_lstm_param->intermediate_zp,
integer_lstm_param->quantized_cell_clip,
integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
n_output, output_batch_leading_dim, output_state_ptr, output_state_zp,
cell_ptr, output_ptr, GetTensorData<int8_t>(scratch0),
GetTensorData<int8_t>(scratch1), GetTensorData<int16_t>(scratch2),
GetTensorData<int16_t>(scratch3), GetTensorData<int16_t>(scratch4),
GetTensorData<int16_t>(scratch5), GetTensorData<int16_t>(scratch6),
GetTensorData<int16_t>(scratch7));
n_output, output_batch_leading_dim, GetTensorData<int8_t>(output_state),
output_state_zp, GetTensorData<int16_t>(cell_state), output_ptr,
GetTensorData<int8_t>(scratch0), GetTensorData<int8_t>(scratch1),
GetTensorData<int16_t>(scratch2), GetTensorData<int16_t>(scratch3),
GetTensorData<int16_t>(scratch4), GetTensorData<int16_t>(scratch5),
GetTensorData<int16_t>(scratch6), GetTensorData<int16_t>(scratch7));
}
return kTfLiteOk;