diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc
index b2f3d77912b..e5bdf3e9a1e 100644
--- a/tensorflow/lite/kernels/lstm_eval.cc
+++ b/tensorflow/lite/kernels/lstm_eval.cc
@@ -216,9 +216,8 @@ inline void LstmStepFloat(
     const float* projection_weights_ptr, const float* projection_bias_ptr,
     const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
     int n_aux_input, int n_output, int output_batch_leading_dim,
-    float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
-    float* forget_gate_scratch, float* cell_gate_scratch,
-    float* output_gate_scratch, float* output_ptr) {
+    float* output_state_ptr, float* cell_state_ptr, float* scratch0,
+    float* scratch1, float* scratch2, float* scratch3, float* output_ptr) {
   ruy::profiler::ScopeLabel label("LstmStepFloat");
   // Since we have already checked that weights are all there or none, we can
   // check the existence of only one to the get the condition.
@@ -226,6 +225,12 @@ inline void LstmStepFloat(
   const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
   const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);
 
+  // Make named scratch buffers for the different gates.
+  float* input_gate_scratch = scratch0;
+  float* forget_gate_scratch = scratch1;
+  float* cell_gate_scratch = scratch2;
+  float* output_gate_scratch = scratch3;
+
   // Initialize scratch buffers with bias for regular lstm or initialize with
   // zero for layer norm lstm.
   if (use_layer_norm) {
@@ -531,9 +536,8 @@ inline void LstmStepHybrid(
     const int8_t* projection_weights_ptr, float projection_weights_scale,
     const float* projection_bias_ptr, const TfLiteLSTMParams* params,
     int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
-    int output_batch_leading_dim, float* input_gate_scratch,
-    float* forget_gate_scratch, float* cell_gate_scratch,
-    float* output_gate_scratch, float* scaling_factors,
+    int output_batch_leading_dim, float* scratch0, float* scratch1,
+    float* scratch2, float* scratch3, float* scaling_factors,
     float* scaling_factors_scratch, float* recovered_cell_weights,
     int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
     int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
@@ -548,6 +552,12 @@ inline void LstmStepHybrid(
   const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
   const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);
 
+  // Make named scratch buffers for the different gates.
+  float* input_gate_scratch = scratch0;
+  float* forget_gate_scratch = scratch1;
+  float* cell_gate_scratch = scratch2;
+  float* output_gate_scratch = scratch3;
+
   // Initialize scratch buffers with bias for regular lstm or initialize with
   // zero for layer norm lstm.
   if (use_layer_norm) {
@@ -974,12 +984,12 @@ inline void LstmStepHybrid(
 //
 // Temporary pre-allocated storage for the calculation. Each is of size n_cell *
 // n_batch.
-// scratch_0
-// scratch_1
-// scratch_2
-// scratch_3
-// scratch_4
-// scratch_5: this scratch buffer is created purely for optimizing the
+// scratch0
+// scratch1
+// scratch2
+// scratch3
+// scratch4
+// scratch5: this scratch buffer is created purely for optimizing the
 //           MatrixBatchVectorMultiplyAccumulate.
 //
 // Outputs:
@@ -1047,10 +1057,15 @@ inline void LstmStepInteger(
     const int32_t* projection_effective_bias, int n_batch, int n_cell,
     int n_input, int n_output, int8_t* output_state_ptr,
     int32_t output_state_zp, int16_t* cell_ptr, int8_t* output_ptr,
-    int16_t* scratch_0_ptr, int16_t* scratch_1_ptr, int16_t* scratch_2_ptr,
-    int16_t* scratch_3_ptr, int8_t* scratch_4_ptr, int32_t* scratch_5_ptr,
-    CpuBackendContext* context) {
+    int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
+    int8_t* scratch4, int32_t* scratch5, CpuBackendContext* context) {
   ruy::profiler::ScopeLabel label("LstmStepInteger");
+  // Make named scratch buffers for the different gates.
+  int16_t* input_gate_scratch = scratch0;
+  int16_t* forget_gate_scratch = scratch1;
+  int16_t* cell_gate_scratch = scratch2;
+  int16_t* output_gate_scratch = scratch3;
+
   // Get hyper parameters.
   const bool use_cifg = (input_to_input_weight_ptr == nullptr);
   const bool use_peephole = (cell_to_output_weight_ptr != nullptr);
@@ -1072,99 +1087,103 @@ inline void LstmStepInteger(
   // Set scratch to 0.
   if (!use_cifg) {
-    std::fill_n(scratch_0_ptr, n_batch * n_cell, 0);
+    std::fill_n(input_gate_scratch, n_batch * n_cell, 0);
   }
-  std::fill_n(scratch_1_ptr, n_batch * n_cell, 0);
-  std::fill_n(scratch_2_ptr, n_batch * n_cell, 0);
-  std::fill_n(scratch_3_ptr, n_batch * n_cell, 0);
+  std::fill_n(forget_gate_scratch, n_batch * n_cell, 0);
+  std::fill_n(cell_gate_scratch, n_batch * n_cell, 0);
+  std::fill_n(output_gate_scratch, n_batch * n_cell, 0);
 
   // Forget gate.
   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
       input_ptr, input_to_forget_effective_bias, input_to_forget_weight_ptr,
       effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
-      n_batch, n_input, n_cell, 0, scratch_5_ptr, scratch_1_ptr, context);
+      n_batch, n_input, n_cell, 0, scratch5, forget_gate_scratch, context);
   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
       output_state_ptr, recurrent_to_forget_effective_bias,
       recurrent_to_forget_weight_ptr, effective_recurrent_to_forget_scale_a,
       effective_recurrent_to_forget_scale_b, n_batch, n_output, n_cell, 0,
-      scratch_5_ptr, scratch_1_ptr, context);
+      scratch5, forget_gate_scratch, context);
   if (use_peephole) {
     tensor_utils::VectorBatchVectorCwiseProductAccumulate(
         cell_to_forget_weight_ptr, n_output, cell_ptr, n_batch,
         effective_cell_to_forget_scale_a, effective_cell_to_forget_scale_b,
-        scratch_1_ptr);
+        forget_gate_scratch);
   }
   if (use_layer_norm) {
     tensor_utils::ApplyLayerNorm(
-        scratch_1_ptr, layer_norm_forget_weight_ptr, forget_gate_bias_ptr,
+        forget_gate_scratch, layer_norm_forget_weight_ptr, forget_gate_bias_ptr,
         layer_norm_forget_scale_a, layer_norm_forget_scale_b,
-        forget_variance_guard, n_batch, n_cell, scratch_1_ptr);
+        forget_variance_guard, n_batch, n_cell, forget_gate_scratch);
   }
-  tensor_utils::ApplySigmoid(scratch_1_ptr, n_batch, n_cell, scratch_1_ptr);
+  tensor_utils::ApplySigmoid(forget_gate_scratch, n_batch, n_cell,
+                             forget_gate_scratch);
 
-  // Modulation gate.
+  // Cell gate.
   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
       input_ptr, input_to_cell_effective_bias, input_to_cell_weight_ptr,
       effective_input_to_cell_scale_a, effective_input_to_cell_scale_b, n_batch,
-      n_input, n_cell, 0, scratch_5_ptr, scratch_2_ptr, context);
+      n_input, n_cell, 0, scratch5, cell_gate_scratch, context);
   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
       output_state_ptr, recurrent_to_cell_effective_bias,
       recurrent_to_cell_weight_ptr, effective_recurrent_to_cell_scale_a,
       effective_recurrent_to_cell_scale_b, n_batch, n_output, n_cell, 0,
-      scratch_5_ptr, scratch_2_ptr, context);
+      scratch5, cell_gate_scratch, context);
   if (use_layer_norm) {
-    tensor_utils::ApplyLayerNorm(scratch_2_ptr, layer_norm_cell_weight_ptr,
+    tensor_utils::ApplyLayerNorm(cell_gate_scratch, layer_norm_cell_weight_ptr,
                                  cell_gate_bias_ptr, layer_norm_cell_scale_a,
                                  layer_norm_cell_scale_b, cell_variance_guard,
-                                 n_batch, n_cell, scratch_2_ptr);
+                                 n_batch, n_cell, cell_gate_scratch);
   }
-  tensor_utils::ApplyTanh(3, scratch_2_ptr, n_batch, n_cell, scratch_2_ptr);
+  tensor_utils::ApplyTanh(3, cell_gate_scratch, n_batch, n_cell,
+                          cell_gate_scratch);
 
   // Input gate.
   if (use_cifg) {
-    tensor_utils::Sub1Vector(scratch_1_ptr, n_batch * n_cell, scratch_0_ptr);
+    tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+                             input_gate_scratch);
   } else {
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         input_ptr, input_to_input_effective_bias, input_to_input_weight_ptr,
         effective_input_to_input_scale_a, effective_input_to_input_scale_b,
-        n_batch, n_input, n_cell, 0, scratch_5_ptr, scratch_0_ptr, context);
+        n_batch, n_input, n_cell, 0, scratch5, input_gate_scratch, context);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         output_state_ptr, recurrent_to_input_effective_bias,
         recurrent_to_input_weight_ptr, effective_recurrent_to_input_scale_a,
        effective_recurrent_to_input_scale_b, n_batch, n_output, n_cell, 0,
-        scratch_5_ptr, scratch_0_ptr, context);
+        scratch5, input_gate_scratch, context);
     if (use_peephole) {
       tensor_utils::VectorBatchVectorCwiseProductAccumulate(
           cell_to_input_weight_ptr, n_output, cell_ptr, n_batch,
           effective_cell_to_input_scale_a, effective_cell_to_input_scale_b,
-          scratch_0_ptr);
+          input_gate_scratch);
     }
     if (use_layer_norm) {
       tensor_utils::ApplyLayerNorm(
-          scratch_0_ptr, layer_norm_input_weight_ptr, input_gate_bias_ptr,
+          input_gate_scratch, layer_norm_input_weight_ptr, input_gate_bias_ptr,
           layer_norm_input_scale_a, layer_norm_input_scale_b,
-          input_variance_guard, n_batch, n_cell, scratch_0_ptr);
+          input_variance_guard, n_batch, n_cell, input_gate_scratch);
     }
-    tensor_utils::ApplySigmoid(scratch_0_ptr, n_batch, n_cell, scratch_0_ptr);
+    tensor_utils::ApplySigmoid(input_gate_scratch, n_batch, n_cell,
+                               input_gate_scratch);
   }
 
   // New cell.
-  tensor_utils::CwiseMul(scratch_1_ptr, cell_ptr, n_batch, n_cell, 15,
-                         scratch_1_ptr);
+  tensor_utils::CwiseMul(forget_gate_scratch, cell_ptr, n_batch, n_cell, 15,
+                         forget_gate_scratch);
 
-  tensor_utils::CwiseMul(scratch_0_ptr, scratch_2_ptr, n_batch, n_cell,
-                         30 + cell_scale, scratch_2_ptr);
+  tensor_utils::CwiseMul(input_gate_scratch, cell_gate_scratch, n_batch, n_cell,
+                         30 + cell_scale, cell_gate_scratch);
 
-  tensor_utils::CwiseAdd(scratch_1_ptr, scratch_2_ptr, n_batch, n_cell,
-                         cell_ptr);
+  tensor_utils::CwiseAdd(forget_gate_scratch, cell_gate_scratch, n_batch,
+                         n_cell, cell_ptr);
 
   if (quantized_cell_clip > 0) {
     tensor_utils::CwiseClipping(cell_ptr, quantized_cell_clip, n_batch,
                                 n_cell);
@@ -1174,49 +1193,50 @@ inline void LstmStepInteger(
   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
       input_ptr, input_to_output_effective_bias, input_to_output_weight_ptr,
       effective_input_to_output_scale_a, effective_input_to_output_scale_b,
-      n_batch, n_input, n_cell, 0, scratch_5_ptr, scratch_3_ptr, context);
+      n_batch, n_input, n_cell, 0, scratch5, output_gate_scratch, context);
   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
       output_state_ptr, recurrent_to_output_effective_bias,
       recurrent_to_output_weight_ptr, effective_recurrent_to_output_scale_a,
       effective_recurrent_to_output_scale_b, n_batch, n_output, n_cell, 0,
-      scratch_5_ptr, scratch_3_ptr, context);
+      scratch5, output_gate_scratch, context);
   if (use_peephole) {
     tensor_utils::VectorBatchVectorCwiseProductAccumulate(
         cell_to_output_weight_ptr, n_output, cell_ptr, n_batch,
         effective_cell_to_output_scale_a, effective_cell_to_output_scale_b,
-        scratch_3_ptr);
+        output_gate_scratch);
   }
   if (use_layer_norm) {
     tensor_utils::ApplyLayerNorm(
-        scratch_3_ptr, layer_norm_output_weight_ptr, output_gate_bias_ptr,
+        output_gate_scratch, layer_norm_output_weight_ptr, output_gate_bias_ptr,
         layer_norm_output_scale_a, layer_norm_output_scale_b,
-        output_variance_guard, n_batch, n_cell, scratch_3_ptr);
+        output_variance_guard, n_batch, n_cell, output_gate_scratch);
   }
-  tensor_utils::ApplySigmoid(scratch_3_ptr, n_batch, n_cell, scratch_3_ptr);
+  tensor_utils::ApplySigmoid(output_gate_scratch, n_batch, n_cell,
+                             output_gate_scratch);
 
   // Hidden.
   tensor_utils::ApplyTanh(15 + cell_scale, cell_ptr, n_batch, n_cell,
-                          scratch_0_ptr);
+                          input_gate_scratch);
 
-  tensor_utils::CwiseMul(scratch_3_ptr, scratch_0_ptr, effective_hidden_scale_a,
-                         effective_hidden_scale_b, n_batch, n_cell, hidden_zp,
-                         scratch_4_ptr);
+  tensor_utils::CwiseMul(output_gate_scratch, input_gate_scratch,
+                         effective_hidden_scale_a, effective_hidden_scale_b,
+                         n_batch, n_cell, hidden_zp, scratch4);
 
   // Projection.
   if (use_projection) {
     std::fill_n(output_ptr, n_batch * n_output, 0);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        scratch_4_ptr, projection_effective_bias, projection_weight_ptr,
+        scratch4, projection_effective_bias, projection_weight_ptr,
         effective_proj_scale_a, effective_proj_scale_b, n_batch, n_cell,
-        n_output, output_state_zp, scratch_5_ptr, output_ptr, context);
+        n_output, output_state_zp, scratch5, output_ptr, context);
     if (quantized_proj_clip > 0) {
       tensor_utils::CwiseClipping(output_ptr, quantized_proj_clip, n_batch,
                                   n_output);
     }
   } else {
-    std::copy_n(scratch_4_ptr, n_batch * n_output, output_ptr);
+    std::copy_n(scratch4, n_batch * n_output, output_ptr);
   }
   std::copy_n(output_ptr, n_batch * n_output, output_state_ptr);
 }
@@ -1300,14 +1320,14 @@ inline void LstmStepInteger(
 //
 // Temporary pre-allocated storage for the calculation. Each is of size n_cell *
 // n_batch.
-// scratch_0
-// scratch_1
-// scratch_2
-// scratch_3
-// scratch_4
-// scratch_5
-// scratch_6
-// scratch_7
+// scratch0
+// scratch1
+// scratch2
+// scratch3
+// scratch4
+// scratch5
+// scratch6
+// scratch7
 //
 // Outputs:
 //   output_state_ptr - size 'n_batch * n_output'
@@ -1369,6 +1389,12 @@ void LstmStepInteger(
     int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
     int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
     int16_t* scratch7) {
+  // Make named scratch buffers for the different gates.
+  int16_t* input_gate_scratch = scratch5;
+  int16_t* forget_gate_scratch = scratch2;
+  int16_t* cell_gate_scratch = scratch3;
+  int16_t* output_gate_scratch = scratch4;
+
   // Forget gate.
   std::fill_n(scratch0, n_batch * n_cell, 0);
   std::fill_n(scratch1, n_batch * n_cell, 0);
@@ -1386,16 +1412,17 @@ void LstmStepInteger(
   tensor_utils::TwoGateSaturationgAdd(
       scratch0, intermediate_zp[4], scratch1, intermediate_zp[5],
       intermediate_scale_a[2], intermediate_scale_b[2], intermediate_scale_a[3],
-      intermediate_scale_b[3], n_batch, n_cell, scratch2);
+      intermediate_scale_b[3], n_batch, n_cell, forget_gate_scratch);
 
   // Forget gate layer norm.
   tensor_utils::ApplyLayerNormFloat(
-      scratch2, layer_norm_forget_weight_ptr, layer_norm_forget_scale_a,
-      layer_norm_forget_scale_b, forget_gate_bias_ptr, n_batch, n_cell,
-      scratch2);
+      forget_gate_scratch, layer_norm_forget_weight_ptr,
+      layer_norm_forget_scale_a, layer_norm_forget_scale_b,
+      forget_gate_bias_ptr, n_batch, n_cell, forget_gate_scratch);
 
   // Forget gate sigmoid.
-  tensor_utils::ApplySigmoidFloat(scratch2, n_batch, n_cell, scratch2);
+  tensor_utils::ApplySigmoidFloat(forget_gate_scratch, n_batch, n_cell,
+                                  forget_gate_scratch);
 
   // Update gate.
   std::fill_n(scratch0, n_batch * n_cell, 0);
@@ -1413,15 +1440,17 @@ void LstmStepInteger(
   tensor_utils::TwoGateSaturationgAdd(
       scratch0, intermediate_zp[7], scratch1, intermediate_zp[8],
       intermediate_scale_a[4], intermediate_scale_b[4], intermediate_scale_a[5],
-      intermediate_scale_b[5], n_batch, n_cell, scratch3);
+      intermediate_scale_b[5], n_batch, n_cell, cell_gate_scratch);
 
-  // Update gate with layer norm.
+  // Update gate layer norm.
   tensor_utils::ApplyLayerNormFloat(
-      scratch3, layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
-      layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_cell, scratch3);
+      cell_gate_scratch, layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
+      layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_cell,
+      cell_gate_scratch);
 
   // Update gate tanh.
-  tensor_utils::ApplyTanhFloat(scratch3, n_batch, n_cell, -12, scratch3);
+  tensor_utils::ApplyTanhFloat(cell_gate_scratch, n_batch, n_cell, -12,
+                               cell_gate_scratch);
 
   // Output gate.
   std::fill_n(scratch0, n_batch * n_cell, 0);
@@ -1440,26 +1469,28 @@ void LstmStepInteger(
   tensor_utils::TwoGateSaturationgAdd(
       scratch0, intermediate_zp[10], scratch1, intermediate_zp[11],
       intermediate_scale_a[6], intermediate_scale_b[6], intermediate_scale_a[7],
-      intermediate_scale_b[7], n_batch, n_cell, scratch4);
+      intermediate_scale_b[7], n_batch, n_cell, output_gate_scratch);
 
   // Output gate with layer norm.
   tensor_utils::ApplyLayerNormFloat(
-      scratch4, layer_norm_output_weight_ptr, layer_norm_output_scale_a,
-      layer_norm_output_scale_b, output_gate_bias_ptr, n_batch, n_cell,
-      scratch4);
+      output_gate_scratch, layer_norm_output_weight_ptr,
+      layer_norm_output_scale_a, layer_norm_output_scale_b,
+      output_gate_bias_ptr, n_batch, n_cell, output_gate_scratch);
 
   // Output gate sigmoid.
-  tensor_utils::ApplySigmoidFloat(scratch4, n_batch, n_cell, scratch4);
+  tensor_utils::ApplySigmoidFloat(output_gate_scratch, n_batch, n_cell,
+                                  output_gate_scratch);
 
   // Input gate with cifg
-  tensor_utils::Sub1Vector(scratch2, n_batch * n_cell, scratch5);
+  tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+                           input_gate_scratch);
 
   // New cell.
-  tensor_utils::CwiseMul(scratch2, cell_ptr, n_batch, n_cell, 15 + 15 - 15,
-                         scratch6);
+  tensor_utils::CwiseMul(forget_gate_scratch, cell_ptr, n_batch, n_cell,
+                         15 + 15 - 15, scratch6);
 
-  tensor_utils::CwiseMul(scratch5, scratch3, n_batch, n_cell, 15 + 15 - 15,
-                         scratch7);
+  tensor_utils::CwiseMul(input_gate_scratch, cell_gate_scratch, n_batch, n_cell,
+                         15 + 15 - 15, scratch7);
 
   tensor_utils::CwiseAdd(scratch6, scratch7, n_batch, n_cell, cell_ptr);
 
@@ -1468,15 +1499,16 @@ void LstmStepInteger(
   }
 
   // Cell to hidden.
-  tensor_utils::ApplyTanhFloat(cell_ptr, n_batch, n_cell, -15, scratch2);
+  tensor_utils::ApplyTanhFloat(cell_ptr, n_batch, n_cell, -15,
+                               forget_gate_scratch);
 
   std::vector<int8_t> hidden(n_batch * n_cell);
 
-  tensor_utils::CwiseMul(scratch4, scratch2, n_batch, n_cell, 15 + 15 - 15,
-                         scratch3);
+  tensor_utils::CwiseMul(output_gate_scratch, forget_gate_scratch, n_batch,
+                         n_cell, 15 + 15 - 15, cell_gate_scratch);
 
   // Projection.
   tensor_utils::MatrixBatchVectorMultiply(
-      scratch3, projection_weight_ptr, effective_proj_scale_a,
+      cell_gate_scratch, projection_weight_ptr, effective_proj_scale_a,
       effective_proj_scale_b, projection_bias_ptr, n_batch, n_cell, n_output,
       output_state_zp, output_ptr);
diff --git a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
index 50138442c25..09ce81c1d97 100644
--- a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
+++ b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
@@ -62,11 +62,16 @@ inline void LstmStepWithAuxInput(
     const float* projection_weights_ptr, const float* projection_bias_ptr,
     const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
     int n_aux_input, int n_output, int output_batch_leading_dim,
-    float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
-    float* forget_gate_scratch, float* cell_gate_scratch,
-    float* output_gate_scratch, float* output_ptr, Logger* logger,
-    const std::vector<int>& intermediate_tensor_indexes,
+    float* output_state_ptr, float* cell_state_ptr, float* scratch0,
+    float* scratch1, float* scratch2, float* scratch3, float* output_ptr,
+    Logger* logger, const std::vector<int>& intermediate_tensor_indexes,
     ErrorReporter* error_reporter) {
+  // Make named scratch buffers for the different gates.
+  float* input_gate_scratch = scratch0;
+  float* forget_gate_scratch = scratch1;
+  float* cell_gate_scratch = scratch2;
+  float* output_gate_scratch = scratch3;
+
   // Since we have already checked that weights are all there or none, we can
   // check the existence of only one to the get the condition.
   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
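
The whole patch applies one mechanical pattern to each LSTM step kernel: the gate-specific scratch parameters are collapsed into anonymous, interchangeable buffers in the signature, and each kernel immediately re-binds them to gate names at the top of its body. A minimal standalone sketch of that pattern, assuming four gate-sized scratch areas; FakeLstmStep and every name in it are illustrative only, not the TFLite API:

    #include <algorithm>
    #include <vector>

    void FakeLstmStep(int n_batch, int n_cell, float* scratch0, float* scratch1,
                      float* scratch2, float* scratch3) {
      // Bind the anonymous buffers to gate-specific names up front, as the
      // patch does, so the gate math that follows reads in terms of gates.
      float* input_gate_scratch = scratch0;
      float* forget_gate_scratch = scratch1;
      float* cell_gate_scratch = scratch2;
      float* output_gate_scratch = scratch3;
      std::fill_n(input_gate_scratch, n_batch * n_cell, 0.0f);
      std::fill_n(forget_gate_scratch, n_batch * n_cell, 0.0f);
      std::fill_n(cell_gate_scratch, n_batch * n_cell, 0.0f);
      std::fill_n(output_gate_scratch, n_batch * n_cell, 0.0f);
    }

    int main() {
      const int n_batch = 2;
      const int n_cell = 4;
      // The caller only sees interchangeable areas of n_batch * n_cell each.
      std::vector<float> scratch(4 * n_batch * n_cell);
      float* base = scratch.data();
      FakeLstmStep(n_batch, n_cell, base, base + n_batch * n_cell,
                   base + 2 * n_batch * n_cell, base + 3 * n_batch * n_cell);
      return 0;
    }

Note that the second 8-bit LstmStepInteger overload maps the gates non-sequentially (input_gate_scratch = scratch5, forget = scratch2, cell = scratch3, output = scratch4): the aliases preserve the buffer assignments the old code already used, which is what keeps the rename behavior-preserving.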
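
The shift expressions the rename leaves untouched, such as 15 + 15 - 15 in the CwiseMul calls, are Q-format bookkeeping: multiplying two Q15 fixed-point values yields a Q30 product, and shifting right by (15 + 15 - 15) = 15 bits rescales it back to Q15. A tiny standalone illustration of that arithmetic; MulQ15 is a hypothetical helper, not tensor_utils code, and it omits the rounding and saturation the production kernels perform:

    #include <cstdint>
    #include <cstdio>

    // Multiply two Q15 values; the Q30 product is shifted right by
    // (15 + 15 - 15) = 15 bits to bring it back to Q15.
    int16_t MulQ15(int16_t a, int16_t b) {
      const int32_t product = static_cast<int32_t>(a) * static_cast<int32_t>(b);
      return static_cast<int16_t>(product >> (15 + 15 - 15));
    }

    int main() {
      const int16_t half = 1 << 14;  // 0.5 in Q15.
      std::printf("0.5 * 0.5 = %f\n", MulQ15(half, half) / 32768.0);  // 0.250000
      return 0;
    }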