From ed557008d681b6bd612f0ea6b1aa056c4cfa744b Mon Sep 17 00:00:00 2001
From: Robert David
Date: Tue, 16 Jun 2020 11:19:41 -0700
Subject: [PATCH] All LSTM implementations: Rename cell_scratch to
 cell_gate_scratch, and cell_bias_ptr to cell_gate_bias_ptr to better reflect
 what those arrays are.

Note that this is not the same thing as the LSTM cell "state": it is the gate
that computes the candidate cell update. The cell state depends on the input,
forget, and cell gates; these arrays hold the output of, and the bias for,
that last gate.

PiperOrigin-RevId: 316720132
Change-Id: I71c370dabd27f776987e061b9393022c775589c9
---
 tensorflow/lite/kernels/lstm_eval.cc          | 156 +++++++++---------
 .../calibration/builtin_logging_ops/lstm.cc   |  61 +++----
 2 files changed, 114 insertions(+), 103 deletions(-)

diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc
index b285ed1030f..b4d43414d89 100644
--- a/tensorflow/lite/kernels/lstm_eval.cc
+++ b/tensorflow/lite/kernels/lstm_eval.cc
@@ -212,13 +212,13 @@ inline void LstmStepFloat(
     const float* cell_layer_norm_coefficients_ptr,
     const float* output_layer_norm_coefficients_ptr,
     const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
-    const float* cell_bias_ptr, const float* output_gate_bias_ptr,
+    const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
     const float* projection_weights_ptr, const float* projection_bias_ptr,
     const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
     int n_aux_input, int n_output, int output_batch_leading_dim,
     float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
-    float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
-    float* output_ptr) {
+    float* forget_gate_scratch, float* cell_gate_scratch,
+    float* output_gate_scratch, float* output_ptr) {
   ruy::profiler::ScopeLabel label("LstmStepFloat");
   // Since we have already checked that weights are all there or none, we can
   // check the existence of only one to the get the condition.
@@ -233,7 +233,7 @@ inline void LstmStepFloat(
       std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
     }
     std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
-    std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
+    std::fill_n(cell_gate_scratch, n_cell * n_batch, 0.0f);
     std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
   } else {
     if (!use_cifg) {
@@ -242,8 +242,8 @@ inline void LstmStepFloat(
     }
     tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
                                           forget_gate_scratch);
-    tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
-                                          cell_scratch);
+    tensor_utils::VectorBatchVectorAssign(cell_gate_bias_ptr, n_cell, n_batch,
+                                          cell_gate_scratch);
     tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
                                           output_gate_scratch);
   }
@@ -262,7 +262,7 @@ inline void LstmStepFloat(
         forget_gate_scratch);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         input_to_cell_weights_ptr, n_cell, n_input, input_ptr, n_batch,
-        cell_scratch);
+        cell_gate_scratch);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         input_to_output_weights_ptr, n_cell, n_input, input_ptr, n_batch,
         output_gate_scratch);
@@ -283,7 +283,7 @@ inline void LstmStepFloat(
         n_batch, forget_gate_scratch);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
-        n_batch, cell_scratch);
+        n_batch, cell_gate_scratch);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         aux_input_to_output_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
         n_batch, output_gate_scratch);
@@ -300,7 +300,7 @@ inline void LstmStepFloat(
         n_batch, forget_gate_scratch);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
-        n_batch, cell_scratch);
+        n_batch, cell_gate_scratch);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
         n_batch, output_gate_scratch);
@@ -347,24 +347,26 @@ inline void LstmStepFloat(
   tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
                                          n_batch * n_cell, cell_state_ptr);
   if (use_layer_norm) {
-    tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
-                                          n_batch);
+    tensor_utils::MeanStddevNormalization(cell_gate_scratch, cell_gate_scratch,
+                                          n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
-        cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
-        cell_scratch);
-    tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
-                                       cell_scratch);
+        cell_layer_norm_coefficients_ptr, n_cell, cell_gate_scratch, n_batch,
+        cell_gate_scratch);
+    tensor_utils::VectorBatchVectorAdd(cell_gate_bias_ptr, n_cell, n_batch,
+                                       cell_gate_scratch);
   }
-  tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
-                                        params->activation, cell_scratch);
+  tensor_utils::ApplyActivationToVector(cell_gate_scratch, n_batch * n_cell,
+                                        params->activation, cell_gate_scratch);
   if (use_cifg) {
     tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                              forget_gate_scratch);
     tensor_utils::VectorVectorCwiseProductAccumulate(
-        cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+        cell_gate_scratch, forget_gate_scratch, n_batch * n_cell,
+        cell_state_ptr);
   } else {
     tensor_utils::VectorVectorCwiseProductAccumulate(
-        cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+        cell_gate_scratch, input_gate_scratch, n_batch * n_cell,
+        cell_state_ptr);
   }
   if (params->cell_clip > 0.0) {
     tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
@@ -389,8 +391,8 @@ inline void LstmStepFloat(
   tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                      output_gate_scratch);
   tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
-                                        params->activation, cell_scratch);
-  tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+                                        params->activation, cell_gate_scratch);
+  tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_gate_scratch,
                                          n_batch * n_cell, output_gate_scratch);
 
   const bool use_projection_weight = (projection_weights_ptr != nullptr);
@@ -525,19 +527,19 @@ inline void LstmStepHybrid(
     const float* cell_layer_norm_coefficients_ptr,
     const float* output_layer_norm_coefficients_ptr,
     const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
-    const float* cell_bias_ptr, const float* output_gate_bias_ptr,
+    const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
     const int8_t* projection_weights_ptr, float projection_weights_scale,
     const float* projection_bias_ptr, const TfLiteLSTMParams* params,
     int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
     int output_batch_leading_dim, float* input_gate_scratch,
-    float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
-    float* scaling_factors, float* scaling_factors_scratch,
-    float* recovered_cell_weights, int8_t* quantized_input_ptr,
-    int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr,
-    int8_t* quantized_cell_state_ptr, float* output_state_ptr,
-    float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr,
-    int32_t* zero_points, int32_t* row_sums, int row_sums_size,
-    bool* compute_row_sums, bool asymmetric_quantize_inputs,
+    float* forget_gate_scratch, float* cell_gate_scratch,
+    float* output_gate_scratch, float* scaling_factors,
+    float* scaling_factors_scratch, float* recovered_cell_weights,
+    int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
+    int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
+    float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr,
+    float* output_ptr, int32_t* zero_points, int32_t* row_sums,
+    int row_sums_size, bool* compute_row_sums, bool asymmetric_quantize_inputs,
     CpuBackendContext* context) {
   ruy::profiler::ScopeLabel label("LstmStepHybrid");
   // Since we have already checked that weights are all there or none, we
@@ -553,7 +555,7 @@ inline void LstmStepHybrid(
       std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
     }
     std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
-    std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
+    std::fill_n(cell_gate_scratch, n_cell * n_batch, 0.0f);
     std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
   } else {
     if (!use_cifg) {
@@ -562,8 +564,8 @@ inline void LstmStepHybrid(
     }
     tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
                                           forget_gate_scratch);
-    tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
-                                          cell_scratch);
+    tensor_utils::VectorBatchVectorAssign(cell_gate_bias_ptr, n_cell, n_batch,
+                                          cell_gate_scratch);
     tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
                                           output_gate_scratch);
   }
@@ -657,7 +659,8 @@ inline void LstmStepHybrid(
 
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr,
-        input_to_cell_weights_scale, scaling_factors, n_batch, cell_scratch,
+        input_to_cell_weights_scale, scaling_factors, n_batch,
+        cell_gate_scratch,
         /*per_channel_scale=*/nullptr, zero_points,
         accum_scratch_ptr, input_to_cell_row_sums, compute_row_sums,
         scaling_factors_scratch, context);
@@ -699,9 +702,10 @@ inline void LstmStepHybrid(
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
         quantized_aux_input_ptr, aux_input_to_cell_weights_scale,
-        scaling_factors, n_batch, cell_scratch, /*per_channel_scale=*/nullptr,
-        zero_points, accum_scratch_ptr, aux_input_to_cell_row_sums,
-        compute_row_sums, scaling_factors_scratch, context);
+        scaling_factors, n_batch, cell_gate_scratch,
+        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        aux_input_to_cell_row_sums, compute_row_sums, scaling_factors_scratch,
+        context);
 
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         aux_input_to_output_weights_ptr, n_cell, n_aux_input,
@@ -739,9 +743,10 @@ inline void LstmStepHybrid(
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         recurrent_to_cell_weights_ptr, n_cell, n_output,
         quantized_output_state_ptr, recurrent_to_cell_weights_scale,
-        scaling_factors, n_batch, cell_scratch, /*per_channel_scale=*/nullptr,
-        zero_points, accum_scratch_ptr, recurrent_to_cell_row_sums,
-        compute_row_sums, scaling_factors_scratch, context);
+        scaling_factors, n_batch, cell_gate_scratch,
+        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        recurrent_to_cell_row_sums, compute_row_sums, scaling_factors_scratch,
+        context);
 
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         recurrent_to_output_weights_ptr, n_cell, n_output,
@@ -800,24 +805,26 @@ inline void LstmStepHybrid(
   tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
                                          n_batch * n_cell, cell_state_ptr);
   if (use_layer_norm) {
-    tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
-                                          n_batch);
+    tensor_utils::MeanStddevNormalization(cell_gate_scratch, cell_gate_scratch,
+                                          n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
-        cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
-        cell_scratch);
-    tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
-                                       cell_scratch);
+        cell_layer_norm_coefficients_ptr, n_cell, cell_gate_scratch, n_batch,
+        cell_gate_scratch);
+    tensor_utils::VectorBatchVectorAdd(cell_gate_bias_ptr, n_cell, n_batch,
+                                       cell_gate_scratch);
   }
-  tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
-                                        params->activation, cell_scratch);
+  tensor_utils::ApplyActivationToVector(cell_gate_scratch, n_batch * n_cell,
+                                        params->activation, cell_gate_scratch);
   if (use_cifg) {
     tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                              forget_gate_scratch);
     tensor_utils::VectorVectorCwiseProductAccumulate(
-        cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+        cell_gate_scratch, forget_gate_scratch, n_batch * n_cell,
+        cell_state_ptr);
   } else {
     tensor_utils::VectorVectorCwiseProductAccumulate(
-        cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+        cell_gate_scratch, input_gate_scratch, n_batch * n_cell,
+        cell_state_ptr);
   }
   if (params->cell_clip > 0.0) {
     tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
@@ -845,8 +852,8 @@ inline void LstmStepHybrid(
   tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                      output_gate_scratch);
   tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
-                                        params->activation, cell_scratch);
-  tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+                                        params->activation, cell_gate_scratch);
+  tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_gate_scratch,
                                          n_batch * n_cell, output_gate_scratch);
 
   const bool use_projection_weight = (projection_weights_ptr != nullptr);
@@ -940,7 +947,7 @@ inline void LstmStepHybrid(
 // Gate biases of size 'n_cell':
 //   input_bias_ptr - optional
 //   forget_bias_ptr
-//   cell_bias_ptr
+//   cell_gate_bias_ptr
 //   output_bias_ptr
 //
 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
@@ -1028,7 +1035,7 @@ inline void LstmStepInteger(
     const int16_t* layer_norm_output_weight_ptr,
     int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
     const int32_t* input_bias_ptr, const int32_t* forget_bias_ptr,
-    const int32_t* cell_bias_ptr, const int32_t* output_bias_ptr,
+    const int32_t* cell_gate_bias_ptr, const int32_t* output_bias_ptr,
     int16_t quantized_cell_clip, int8_t quantized_proj_clip,
     int32_t cell_scale, int32_t input_variance_guard,
     int32_t forget_variance_guard, int32_t cell_variance_guard,
     int32_t output_variance_guard,
@@ -1115,7 +1122,7 @@ inline void LstmStepInteger(
 
   if (use_layer_norm) {
     tensor_utils::ApplyLayerNorm(scratch_2_ptr, layer_norm_cell_weight_ptr,
-                                 cell_bias_ptr, layer_norm_cell_scale_a,
+                                 cell_gate_bias_ptr, layer_norm_cell_scale_a,
                                  layer_norm_cell_scale_b, cell_variance_guard,
                                  n_batch, n_cell, scratch_2_ptr);
   }
@@ -1266,7 +1273,7 @@ inline void LstmStepInteger(
 // Gate biases of size 'n_cell':
 //   input_bias_ptr - optional
 //   forget_bias_ptr
-//   cell_bias_ptr
+//   cell_gate_bias_ptr
 //   output_bias_ptr
 //
 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
@@ -1355,7 +1362,7 @@ void LstmStepInteger(
     const int16_t* layer_norm_output_weight_ptr,
     int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
     const int32_t* input_bias_ptr, const int32_t* forget_bias_ptr,
-    const int32_t* cell_bias_ptr, const int32_t* output_bias_ptr,
+    const int32_t* cell_gate_bias_ptr, const int32_t* output_bias_ptr,
     const int32_t* proj_bias_ptr, const TfLiteLSTMParams* params,
     const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
     const int32_t* intermediate_zp, int32 quantized_cell_clip,
@@ -1413,7 +1420,7 @@ void LstmStepInteger(
   // Update gate with layer norm.
   tensor_utils::ApplyLayerNormFloat(
       scratch3, layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
-      layer_norm_cell_scale_b, cell_bias_ptr, n_batch, n_cell, scratch3);
+      layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_cell, scratch3);
 
   // Update gate tanh.
   tensor_utils::ApplyTanhFloat(scratch3, n_batch, n_cell, -12, scratch3);
@@ -1538,16 +1545,16 @@ TfLiteStatus EvalFloat(
   // Index the scratch buffers pointers to the global scratch buffer.
   float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
   float* input_gate_scratch = nullptr;
-  float* cell_scratch = nullptr;
+  float* cell_gate_scratch = nullptr;
   float* forget_gate_scratch = nullptr;
   float* output_gate_scratch = nullptr;
   if (use_cifg) {
-    cell_scratch = scratch_buffer_ptr;
+    cell_gate_scratch = scratch_buffer_ptr;
     forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
     output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
   } else {
     input_gate_scratch = scratch_buffer_ptr;
-    cell_scratch = scratch_buffer_ptr + n_cell * n_batch;
+    cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
     forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
     output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
   }
@@ -1599,7 +1606,8 @@ TfLiteStatus EvalFloat(
           n_input, aux_input_size, n_output, output_batch_leading_dim,
           GetTensorData<float>(activation_state),
           GetTensorData<float>(cell_state), input_gate_scratch,
-          forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr);
+          forget_gate_scratch, cell_gate_scratch, output_gate_scratch,
+          output_ptr);
     }
   } else {
     for (int b = 0; b < n_batch; b++) {
@@ -1628,7 +1636,7 @@ TfLiteStatus EvalFloat(
       float* input_gate_scratch_ptr =
           input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
       float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
-      float* cell_scratch_ptr = cell_scratch + b * n_cell;
+      float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
       float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
 
       LstmStepFloat(
@@ -1659,8 +1667,8 @@ TfLiteStatus EvalFloat(
           GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
           n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
           activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
-          forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
-          output_ptr);
+          forget_gate_scratch_ptr, cell_gate_scratch_ptr,
+          output_gate_scratch_ptr, output_ptr);
     }
   }
 }
@@ -1723,16 +1731,16 @@ TfLiteStatus EvalHybrid(
 
   float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
   float* input_gate_scratch = nullptr;
-  float* cell_scratch = nullptr;
+  float* cell_gate_scratch = nullptr;
   float* forget_gate_scratch = nullptr;
   float* output_gate_scratch = nullptr;
   if (use_cifg) {
-    cell_scratch = scratch_buffer_ptr;
+    cell_gate_scratch = scratch_buffer_ptr;
     forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
     output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
   } else {
     input_gate_scratch = scratch_buffer_ptr;
-    cell_scratch = scratch_buffer_ptr + n_cell * n_batch;
+    cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
     forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
     output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
   }
@@ -1805,7 +1813,7 @@ TfLiteStatus EvalHybrid(
           GetTensorScale(projection_weights),
           GetTensorData<float>(projection_bias), params, n_batch, n_cell,
           n_input, aux_input_size, n_output, output_batch_leading_dim,
-          input_gate_scratch, forget_gate_scratch, cell_scratch,
+          input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
           output_gate_scratch, GetTensorData<float>(scaling_factors),
          GetTensorData<float>(prod_scaling_factors),
           GetTensorData<float>(recovered_cell_weights),
@@ -1845,7 +1853,7 @@ TfLiteStatus EvalHybrid(
       float* input_gate_scratch_ptr =
          input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
       float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
-      float* cell_scratch_ptr = cell_scratch + b * n_cell;
+      float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
       float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
 
       LstmStepHybrid(
@@ -1892,8 +1900,8 @@ TfLiteStatus EvalHybrid(
           GetTensorData<float>(projection_bias), params,
           /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
           output_batch_leading_dim, input_gate_scratch_ptr,
-          forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
-          GetTensorData<float>(scaling_factors),
+          forget_gate_scratch_ptr, cell_gate_scratch_ptr,
+          output_gate_scratch_ptr, GetTensorData<float>(scaling_factors),
           GetTensorData<float>(prod_scaling_factors),
           GetTensorData<float>(recovered_cell_weights),
           GetTensorData<int8_t>(input_quantized),
@@ -2119,7 +2127,7 @@ TfLiteStatus EvalInteger8x8_8(
       GetTensorData<int16_t>(output_layer_norm_coefficients);
   const int32_t* input_bias_ptr = GetTensorData<int32_t>(input_gate_bias);
   const int32_t* forget_bias_ptr = GetTensorData<int32_t>(forget_gate_bias);
-  const int32_t* cell_bias_ptr = GetTensorData<int32_t>(cell_bias);
+  const int32_t* cell_gate_bias_ptr = GetTensorData<int32_t>(cell_bias);
   const int32_t* output_bias_ptr = GetTensorData<int32_t>(output_gate_bias);
   const int32_t* proj_bias_ptr = GetTensorData<int32_t>(projection_bias);
   int16_t* cell_ptr = GetTensorData<int16_t>(cell_state);
@@ -2206,7 +2214,7 @@ TfLiteStatus EvalInteger8x8_8(
         integer_lstm_param->layer_norm_output_scale_a,
         integer_lstm_param->layer_norm_output_scale_b,
 
-        input_bias_ptr, forget_bias_ptr, cell_bias_ptr, output_bias_ptr,
+        input_bias_ptr, forget_bias_ptr, cell_gate_bias_ptr, output_bias_ptr,
         proj_bias_ptr, params,
 
         integer_lstm_param->intermediate_scale_a,
diff --git a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
index b58900c0bc6..0d4c614511d 100644
--- a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
+++ b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
@@ -58,13 +58,13 @@ inline void LstmStepWithAuxInput(
     const float* cell_layer_norm_coefficients_ptr,
     const float* output_layer_norm_coefficients_ptr,
     const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
-    const float* cell_bias_ptr, const float* output_gate_bias_ptr,
+    const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
     const float* projection_weights_ptr, const float* projection_bias_ptr,
     const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
     int n_aux_input, int n_output, int output_batch_leading_dim,
     float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
-    float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
-    float* output_ptr, Logger* logger,
+    float* forget_gate_scratch, float* cell_gate_scratch,
+    float* output_gate_scratch, float* output_ptr, Logger* logger,
     const std::vector<int>& intermediate_tensor_indexes,
     ErrorReporter* error_reporter) {
   // Since we have already checked that weights are all there or none, we can
@@ -80,7 +80,7 @@ inline void LstmStepWithAuxInput(
       std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
     }
     std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
-    std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
+    std::fill_n(cell_gate_scratch, n_cell * n_batch, 0.0f);
     std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
   } else {
     if (!use_cifg) {
@@ -89,8 +89,8 @@ inline void LstmStepWithAuxInput(
     }
     tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
                                           forget_gate_scratch);
-    tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
-                                          cell_scratch);
+    tensor_utils::VectorBatchVectorAssign(cell_gate_bias_ptr, n_cell, n_batch,
+                                          cell_gate_scratch);
     tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
                                           output_gate_scratch);
   }
@@ -107,7 +107,7 @@ inline void LstmStepWithAuxInput(
                                                       forget_gate_scratch);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_cell_weights_ptr,
                                                       n_cell, n_input, input_ptr,
-                                                      n_batch, cell_scratch);
+                                                      n_batch, cell_gate_scratch);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         input_to_output_weights_ptr, n_cell, n_input, input_ptr, n_batch,
         output_gate_scratch);
@@ -125,7 +125,7 @@ inline void LstmStepWithAuxInput(
         n_batch, forget_gate_scratch);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
-        n_batch, cell_scratch);
+        n_batch, cell_gate_scratch);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         aux_input_to_output_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
         n_batch, output_gate_scratch);
@@ -142,7 +142,7 @@ inline void LstmStepWithAuxInput(
         n_batch, forget_gate_scratch);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
-        n_batch, cell_scratch);
+        n_batch, cell_gate_scratch);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
         n_batch, output_gate_scratch);
@@ -193,26 +193,28 @@ inline void LstmStepWithAuxInput(
   tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
                                          n_batch * n_cell, cell_state_ptr);
   if (use_layer_norm) {
-    logger->LogTensorValue(intermediate_tensor_indexes[2], cell_scratch,
+    logger->LogTensorValue(intermediate_tensor_indexes[2], cell_gate_scratch,
                            n_cell * n_batch, error_reporter);
-    tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
-                                          n_batch);
+    tensor_utils::MeanStddevNormalization(cell_gate_scratch, cell_gate_scratch,
+                                          n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
-        cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
-        cell_scratch);
-    tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
-                                       cell_scratch);
+        cell_layer_norm_coefficients_ptr, n_cell, cell_gate_scratch, n_batch,
+        cell_gate_scratch);
+    tensor_utils::VectorBatchVectorAdd(cell_gate_bias_ptr, n_cell, n_batch,
+                                       cell_gate_scratch);
   }
-  tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
-                                        params->activation, cell_scratch);
+  tensor_utils::ApplyActivationToVector(cell_gate_scratch, n_batch * n_cell,
+                                        params->activation, cell_gate_scratch);
   if (use_cifg) {
     tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                              forget_gate_scratch);
     tensor_utils::VectorVectorCwiseProductAccumulate(
-        cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+        cell_gate_scratch, forget_gate_scratch, n_batch * n_cell,
+        cell_state_ptr);
   } else {
     tensor_utils::VectorVectorCwiseProductAccumulate(
-        cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+        cell_gate_scratch, input_gate_scratch, n_batch * n_cell,
+        cell_state_ptr);
   }
   if (params->cell_clip > 0.0) {
     tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
@@ -239,8 +241,8 @@ inline void LstmStepWithAuxInput(
   tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                      output_gate_scratch);
   tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
-                                        params->activation, cell_scratch);
-  tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+                                        params->activation, cell_gate_scratch);
+  tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_gate_scratch,
                                          n_batch * n_cell, output_gate_scratch);
 
   logger->LogTensorValue(intermediate_tensor_indexes[4], output_gate_scratch,
@@ -329,16 +331,16 @@ TfLiteStatus EvalFloat(
   // Index the scratch buffers pointers to the global scratch buffer.
   float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
   float* input_gate_scratch = nullptr;
-  float* cell_scratch = nullptr;
+  float* cell_gate_scratch = nullptr;
   float* forget_gate_scratch = nullptr;
   float* output_gate_scratch = nullptr;
   if (use_cifg) {
-    cell_scratch = scratch_buffer_ptr;
+    cell_gate_scratch = scratch_buffer_ptr;
    forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
     output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
   } else {
     input_gate_scratch = scratch_buffer_ptr;
-    cell_scratch = scratch_buffer_ptr + n_cell * n_batch;
+    cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
     forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
     output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
   }
@@ -390,7 +392,7 @@ TfLiteStatus EvalFloat(
           n_input, aux_input_size, n_output, output_batch_leading_dim,
           GetTensorData<float>(activation_state),
           GetTensorData<float>(cell_state), input_gate_scratch,
-          forget_gate_scratch, cell_scratch, output_gate_scratch,
+          forget_gate_scratch, cell_gate_scratch, output_gate_scratch,
           output_ptr_time, logger, intermediate_tensor_indexes, error_reporter);
     }
   } else {
@@ -420,7 +422,7 @@ TfLiteStatus EvalFloat(
       float* input_gate_scratch_ptr =
          input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
       float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
-      float* cell_scratch_ptr = cell_scratch + b * n_cell;
+      float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
       float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
 
       LstmStepWithAuxInput(
@@ -451,8 +453,9 @@ TfLiteStatus EvalFloat(
           GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
           n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
           activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
-          forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
-          output_ptr, logger, intermediate_tensor_indexes, error_reporter);
+          forget_gate_scratch_ptr, cell_gate_scratch_ptr,
+          output_gate_scratch_ptr, output_ptr, logger,
+          intermediate_tensor_indexes, error_reporter);
     }
   }
 }
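
Reviewer note (not part of the patch): to make the gate-vs-state distinction in the commit message concrete, here is a minimal, self-contained C++ sketch of one LSTM step for a single batch. The variable names are illustrative only, not the kernel's buffers, and it assumes sigmoid gates and a tanh cell activation while ignoring CIFG, layer norm, peepholes, clipping, and projection.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int n_cell = 2;
  // Pre-activation gate values; in the kernels these live in the
  // *_gate_scratch buffers after the matmuls and bias additions.
  std::vector<float> input_gate = {0.1f, -0.3f};
  std::vector<float> forget_gate = {0.8f, 0.2f};
  std::vector<float> cell_gate = {0.5f, -0.1f};   // what cell_gate_scratch holds
  std::vector<float> output_gate = {0.4f, 0.6f};
  std::vector<float> cell_state = {0.0f, 0.0f};   // what cell_state_ptr holds
  std::vector<float> output(n_cell);

  auto sigmoid = [](float x) { return 1.0f / (1.0f + std::exp(-x)); };
  for (int i = 0; i < n_cell; ++i) {
    const float in = sigmoid(input_gate[i]);
    const float forget = sigmoid(forget_gate[i]);
    const float g = std::tanh(cell_gate[i]);  // candidate update: the "cell gate"
    const float out = sigmoid(output_gate[i]);
    // The cell *state* mixes the previous state with the candidate via the
    // forget and input gates; it is a different array from the cell gate.
    cell_state[i] = forget * cell_state[i] + in * g;
    output[i] = out * std::tanh(cell_state[i]);
    std::printf("cell %d: state=%f output=%f\n", i, cell_state[i], output[i]);
  }
  return 0;
}
```

In the renamed kernels, cell_gate_scratch and cell_gate_bias_ptr correspond to the candidate g (its accumulator and bias), while cell_state_ptr corresponds to c, which is why the old names were misleading.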