diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc
index e5bdf3e9a1e..5b4e8a8d479 100644
--- a/tensorflow/lite/kernels/lstm_eval.cc
+++ b/tensorflow/lite/kernels/lstm_eval.cc
@@ -976,7 +976,7 @@ inline void LstmStepHybrid(
 // Scalar values:
 //   quantized_cell_clip: quantized clip value for cell.
 //   quantized_proj_clip: quantized clip value for projection.
-//   cell_scale: the power of two scale for cell state.
+//   cell_state_scale: the power of two scale for cell state.
 //
 // Zero points:
 //   output_state_zp: zero point of output state
@@ -1043,9 +1043,10 @@ inline void LstmStepInteger(
     int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
     const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
     const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
-    int16_t quantized_cell_clip, int8_t quantized_proj_clip, int32_t cell_scale,
-    int32_t input_variance_guard, int32_t forget_variance_guard,
-    int32_t cell_variance_guard, int32_t output_variance_guard,
+    int16_t quantized_cell_clip, int8_t quantized_proj_clip,
+    int32_t cell_state_scale, int32_t input_variance_guard,
+    int32_t forget_variance_guard, int32_t cell_variance_guard,
+    int32_t output_variance_guard,
     const int32_t* input_to_forget_effective_bias,
     const int32_t* recurrent_to_forget_effective_bias,
     const int32_t* input_to_cell_effective_bias,
@@ -1056,7 +1057,7 @@ inline void LstmStepInteger(
     const int32_t* recurrent_to_input_effective_bias,
     const int32_t* projection_effective_bias, int n_batch, int n_cell,
     int n_input, int n_output, int8_t* output_state_ptr,
-    int32_t output_state_zp, int16_t* cell_ptr, int8_t* output_ptr,
+    int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
     int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
     int8_t* scratch4, int32_t* scratch5, CpuBackendContext* context) {
   ruy::profiler::ScopeLabel label("LstmStepInteger");
@@ -1106,7 +1107,7 @@ inline void LstmStepInteger(
       scratch5, forget_gate_scratch, context);
   if (use_peephole) {
     tensor_utils::VectorBatchVectorCwiseProductAccumulate(
-        cell_to_forget_weight_ptr, n_output, cell_ptr, n_batch,
+        cell_to_forget_weight_ptr, n_output, cell_state_ptr, n_batch,
         effective_cell_to_forget_scale_a, effective_cell_to_forget_scale_b,
         forget_gate_scratch);
   }
@@ -1160,7 +1161,7 @@ inline void LstmStepInteger(
      scratch5, input_gate_scratch, context);
   if (use_peephole) {
     tensor_utils::VectorBatchVectorCwiseProductAccumulate(
-        cell_to_input_weight_ptr, n_output, cell_ptr, n_batch,
+        cell_to_input_weight_ptr, n_output, cell_state_ptr, n_batch,
         effective_cell_to_input_scale_a, effective_cell_to_input_scale_b,
         input_gate_scratch);
   }
@@ -1175,18 +1176,19 @@ inline void LstmStepInteger(
                                input_gate_scratch);
   }
 
-  // New cell.
-  tensor_utils::CwiseMul(forget_gate_scratch, cell_ptr, n_batch, n_cell, 15,
-                         forget_gate_scratch);
+  // New cell state.
+  tensor_utils::CwiseMul(forget_gate_scratch, cell_state_ptr, n_batch, n_cell,
+                         15, forget_gate_scratch);
 
   tensor_utils::CwiseMul(input_gate_scratch, cell_gate_scratch, n_batch, n_cell,
-                         30 + cell_scale, cell_gate_scratch);
+                         30 + cell_state_scale, cell_gate_scratch);
 
   tensor_utils::CwiseAdd(forget_gate_scratch, cell_gate_scratch, n_batch,
-                         n_cell, cell_ptr);
+                         n_cell, cell_state_ptr);
 
   if (quantized_cell_clip > 0) {
-    tensor_utils::CwiseClipping(cell_ptr, quantized_cell_clip, n_batch, n_cell);
+    tensor_utils::CwiseClipping(cell_state_ptr, quantized_cell_clip, n_batch,
+                                n_cell);
   }
 
   // Ouptut gate.
@@ -1202,7 +1204,7 @@ inline void LstmStepInteger(
       scratch5, output_gate_scratch, context);
   if (use_peephole) {
     tensor_utils::VectorBatchVectorCwiseProductAccumulate(
-        cell_to_output_weight_ptr, n_output, cell_ptr, n_batch,
+        cell_to_output_weight_ptr, n_output, cell_state_ptr, n_batch,
         effective_cell_to_output_scale_a, effective_cell_to_output_scale_b,
         output_gate_scratch);
   }
@@ -1218,8 +1220,8 @@ inline void LstmStepInteger(
                              output_gate_scratch);
 
   // Hidden.
-  tensor_utils::ApplyTanh(15 + cell_scale, cell_ptr, n_batch, n_cell,
-                          input_gate_scratch);
+  tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state_ptr, n_batch,
+                          n_cell, input_gate_scratch);
 
   tensor_utils::CwiseMul(output_gate_scratch, input_gate_scratch,
                          effective_hidden_scale_a, effective_hidden_scale_b,
@@ -1312,7 +1314,7 @@ inline void LstmStepInteger(
 // Scalar values:
 //   quantized_cell_clip: quantized clip value for cell.
 //   quantized_proj_clip: quantized clip value for projection.
-//   cell_scale: the power of two scale for cell state.
+//   cell_state_scale: the power of two scale for cell state.
 //
 // Zero points:
 //   output_state_zp: zero point of output state.
@@ -1385,7 +1387,7 @@ void LstmStepInteger(
     const int32_t* intermediate_zp, int16_t quantized_cell_clip,
     int8_t quantized_proj_clip, int n_batch, int n_cell, int n_input,
     int n_output, int output_batch_leading_dim, int8_t* output_state_ptr,
-    int32_t output_state_zp, int16_t* cell_ptr, int8_t* output_ptr,
+    int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
     int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
     int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
     int16_t* scratch7) {
@@ -1486,20 +1488,21 @@ void LstmStepInteger(
                              input_gate_scratch);
 
   // New cell.
-  tensor_utils::CwiseMul(forget_gate_scratch, cell_ptr, n_batch, n_cell,
+  tensor_utils::CwiseMul(forget_gate_scratch, cell_state_ptr, n_batch, n_cell,
                          15 + 15 - 15, scratch6);
   tensor_utils::CwiseMul(input_gate_scratch, cell_gate_scratch, n_batch, n_cell,
                          15 + 15 - 15, scratch7);
-  tensor_utils::CwiseAdd(scratch6, scratch7, n_batch, n_cell, cell_ptr);
+  tensor_utils::CwiseAdd(scratch6, scratch7, n_batch, n_cell, cell_state_ptr);
   if (quantized_cell_clip > 0) {
-    tensor_utils::CwiseClipping(cell_ptr, quantized_cell_clip, n_batch, n_cell);
+    tensor_utils::CwiseClipping(cell_state_ptr, quantized_cell_clip, n_batch,
+                                n_cell);
   }
 
   // Cell to hidden.
-  tensor_utils::ApplyTanhFloat(cell_ptr, n_batch, n_cell, -15,
+  tensor_utils::ApplyTanhFloat(cell_state_ptr, n_batch, n_cell, -15,
                                forget_gate_scratch);
 
   std::vector<float> hidden(n_batch * n_cell);
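
Note on the `30 + cell_state_scale` shift in the hunks above (editor's sketch, not part of the patch): the sigmoid/tanh gate outputs are Q0.15 integers, so the raw product of two gates carries scale 2^-30, while the cell state is stored with a power-of-two scale of 2^cell_state_scale (a negative exponent, e.g. -11 for a Q4.11 cell state). A rounding right shift by 30 + cell_state_scale therefore brings the product onto the cell state's scale. Below is a minimal C++ sketch of such a rescaled element-wise multiply; CwiseMulSketch and the Q4.11 example are illustrative assumptions, not TFLite's tensor_utils::CwiseMul implementation.

#include <algorithm>
#include <cstdint>

// Hypothetical sketch: multiply two Q0.15 vectors element-wise and rescale
// the Q0.30 product to the cell state's power-of-two scale via a rounding
// right shift, saturating to the int16 range. This mirrors the role of the
// `30 + cell_state_scale` argument in the diff; it is not TFLite's code.
void CwiseMulSketch(const int16_t* a, const int16_t* b, int size, int shift,
                    int16_t* out) {
  for (int i = 0; i < size; ++i) {
    // Product of two values with scale 2^-15 has scale 2^-30.
    const int32_t product = static_cast<int32_t>(a[i]) * b[i];
    // Rounding right shift; shift = 30 + cell_state_scale, e.g. 19 when the
    // cell state is Q4.11 (cell_state_scale == -11). Assumes shift >= 1.
    const int32_t rescaled = (product + (1 << (shift - 1))) >> shift;
    // Saturate to the int16 range.
    out[i] = static_cast<int16_t>(std::min(32767, std::max(-32768, rescaled)));
  }
}

For example, with a[i] == b[i] == 32767 (roughly 1.0 in Q0.15) and shift == 19, the result is 2048, i.e. 1.0 in Q4.11, as expected for the cell state's scale.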