From ec07b637cecf04897f34c745bd37404ad795e645 Mon Sep 17 00:00:00 2001 From: Robert David Date: Tue, 30 Jun 2020 15:39:21 -0700 Subject: [PATCH] LSTM: Split gate calculations to separate functions. PiperOrigin-RevId: 319113154 Change-Id: I3f6d5358ac2bc619e3705d6f8aaf66c38c7c1b66 --- tensorflow/lite/kernels/lstm_eval.cc | 1106 ++++++++--------- .../calibration/builtin_logging_ops/lstm.cc | 264 ++-- 2 files changed, 633 insertions(+), 737 deletions(-) diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc index 42d6f89c0e4..9087bbeada9 100644 --- a/tensorflow/lite/kernels/lstm_eval.cc +++ b/tensorflow/lite/kernels/lstm_eval.cc @@ -128,6 +128,91 @@ inline float GetTensorScale(const TfLiteTensor* tensor) { } // LINT.IfChange +// Calculates a single LSTM gate. +// +// Implements the following formula: (* is matrix multiply) +// gate = activate(W_input * input + W_aux * aux_input + +// W_peephole * cell + W_recurrent * prev_output + bias) +// with layer norm: +// gate = activate(W_norm * normalize(...) + bias) // not adding bias inside +// +// Activation is sigmoid except for the "cell" gate (configurable, usually tanh) +// +// Parameters: +// Input vectors (to LSTM): | Size: | Optional? +// input | n_input | +// aux_input | n_aux_input | y (bidir LSTM) +// Input vectors (persistent states): +// output_state | n_output | +// cell_state | n_cell | +// 'Constant' inputs: +// input_to_gate_weights | n_cell * n_input | +// aux_input_to_gate_weights | n_cell * n_aux_input | y (bidir LSTM) +// recurrent_to_gate_weights | n_cell * n_output | +// cell_to_gate_weights | n_cell | y (peephole) +// gate_bias | n_cell | +// layer_norm_coefficients | n_cell | y (layer norm) +// Output vector: +// gate | n_cell | +// Scalar parameters: +// n_batch - batch size / number of vectors +// n_input, n_aux_input, n_output, n_cell - size of vectors. +// activation - activation to use. +// is_input_all_zeros, is_aux_input_all_zeros - if input vectors are all zero. +// use_layer_norm - if doing layer norm LSTM. +inline void CalculateLstmGateFloat( + const float* input, const float* input_to_gate_weights, + const float* aux_input, const float* aux_input_to_gate_weights, + const float* output_state, const float* recurrent_to_gate_weights, + const float* cell_state, const float* cell_to_gate_weights, + const float* layer_norm_coefficients, const float* gate_bias, + const int n_batch, const int n_input, const int n_aux_input, + const int n_output, const int n_cell, + const TfLiteFusedActivation activation, float* gate, + const bool is_input_all_zeros, const bool is_aux_input_all_zeros) { + const bool use_peephole = (cell_to_gate_weights != nullptr); + const bool use_layer_norm = (layer_norm_coefficients != nullptr); + + // Initialize scratch buffers with bias for regular lstm or initialize with + // zero for layer norm lstm. + if (use_layer_norm) { + std::fill_n(gate, n_cell * n_batch, 0.0f); + } else { + tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate); + } + // For each batch and cell: compute input_weight * input. + // Skip if input is all zeros. + if (!is_input_all_zeros) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_gate_weights, n_cell, n_input, input, n_batch, gate); + } + // For each batch and cell: compute aux_input_weight * aux_input. + // Skip if auxiliary input is not available or all zeros. 
+ if (!is_aux_input_all_zeros) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights, + n_cell, n_aux_input, + aux_input, n_batch, gate); + } + // For each batch and cell: compute recurrent_weight * output_state. + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_gate_weights, n_cell, n_output, output_state, n_batch, gate); + // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM) + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_gate_weights, n_cell, cell_state, n_batch, gate); + } + // Do layer normalization (if layer norm LSTM) + if (use_layer_norm) { + tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch); + tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell, + gate, n_batch, gate); + tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate); + } + // Apply activation + tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell, activation, + gate); +} + // Updates the LSTM cell state, used by both float and hybrid LSTM versions. // // Implements the following formula: @@ -221,6 +306,101 @@ void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output, // LINT.ThenChange(../tools/optimize/calibration/builtin_logging_ops/lstm.cc,\ // ../experimental/kernels/fp16/lstm_eval.cc) +// Calculates a single LSTM gate, hybrid version. +// Implements the same functionality as CalculateLstmGateFloat. +void CalculateLstmGateHybrid( + // Input and weights + const int8_t* input, const float* input_sf, const int32_t* input_zp, + const int8_t* input_to_gate_weights, + const float input_to_gate_weights_scale, int32_t* input_to_gate_row_sums, + // Aux input and weights + const int8_t* aux_input, const float* aux_input_sf, + const int32_t* aux_input_zp, const int8_t* aux_input_to_gate_weights, + const float aux_input_to_gate_weights_scale, + int32_t* aux_input_to_gate_row_sums, + // Output state and weights + const int8_t* output_state, const float* output_state_sf, + const int32_t* output_state_zp, const int8_t* recurrent_to_gate_weights, + const float recurrent_to_gate_weights_scale, + int32_t* recurrent_to_gate_row_sums, + // Cell state and weights (peephole LSTM) + const float* cell_state, const int8_t* cell_to_gate_weights, + const float cell_to_gate_weights_scale, + // Layer normalization coefficients (layer norm LSTM) + gate bias + const float* layer_norm_coefficients, const float* gate_bias, + // Array sizes + const int n_batch, const int n_input, const int n_aux_input, + const int n_output, const int n_cell, + const TfLiteFusedActivation activation, + // Output + float* gate, + // Parameters for performance optimizations + const bool is_input_all_zeros, const bool is_aux_input_all_zeros, + const bool is_output_state_all_zeros, bool* compute_row_sums, + CpuBackendContext* context, + // Scratch arrays + float* scratch0, // size: n_batch + float* scratch1, // size: n_cell, only used if peephole LSTM + int32_t* accum_scratch // For MatrixBatchVectorMultiplyAccumulate +) { + const bool use_peephole = (cell_to_gate_weights != nullptr); + const bool use_layer_norm = (layer_norm_coefficients != nullptr); + + // Initialize scratch buffers with bias for regular lstm or initialize with + // zero for layer norm lstm. + if (use_layer_norm) { + std::fill_n(gate, n_cell * n_batch, 0.0f); + } else { + tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate); + } + // For each batch and cell: compute input_weight * input. 
+ // Skip if input is all zeros. + if (!is_input_all_zeros) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_gate_weights, n_cell, n_input, input, + input_to_gate_weights_scale, input_sf, n_batch, gate, + /*per_channel_scale=*/nullptr, input_zp, accum_scratch, + input_to_gate_row_sums, compute_row_sums, scratch0, context); + } + // For each batch and cell: compute aux_input_weight * aux_input. + // Skip if auxiliary input is not available or all zeros. + if (!is_aux_input_all_zeros) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_gate_weights, n_cell, n_aux_input, aux_input, + aux_input_to_gate_weights_scale, aux_input_sf, n_batch, gate, + /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch, + aux_input_to_gate_row_sums, compute_row_sums, scratch0, context); + } + // For each batch and cell: compute recurrent_weight * output_state. + // Skip if output state is all zeros. + if (!is_output_state_all_zeros) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_gate_weights, n_cell, n_output, output_state, + recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate, + /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch, + recurrent_to_gate_row_sums, compute_row_sums, scratch0, context); + } + // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM) + if (use_peephole) { + float* recovered_cell_weights = scratch1; + tensor_utils::VectorScalarMultiply(cell_to_gate_weights, n_cell, + cell_to_gate_weights_scale, + recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state, n_batch, gate); + } + // Do layer normalization (if layer norm LSTM) + if (use_layer_norm) { + tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch); + tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell, + gate, n_batch, gate); + tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate); + } + // Apply activation + tensor_utils::ApplyActivationToVector(gate, n_cell * n_batch, activation, + gate); +} + // Calculates the output state tensor of an LSTM step. See Float version too. // // Parameters: @@ -281,6 +461,80 @@ void CalculateLstmOutputHybrid( } } +// Calculates a single LSTM gate, int8x8_16 version. +// Implements the same functionality as CalculateLstmGateFloat. 
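+// Unlike the float version there is no aux input; weights are int8 (cell and
+// layer-norm weights are int16), each matmul carries a (scale_a, scale_b)
+// quantized multiplier, and the resulting gate values are int16.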
+void CalculateLstmGateInteger8x8_16( + // Input and weights + const int8_t* input, const int8_t* input_to_gate_weights, + const int32_t* input_to_gate_bias, const int32_t input_to_gate_scale_a, + const int32_t input_to_gate_scale_b, + // Output state and weights + const int8_t* output_state, const int8_t* recurrent_to_gate_weights, + const int32_t* recurrent_to_gate_bias, + const int32_t recurrent_to_gate_scale_a, + const int32_t recurrent_to_gate_scale_b, + // Cell state and weights + const int16_t* cell_state, const int16_t* cell_to_gate_weights, + const int32_t cell_to_gate_scale_a, const int32_t cell_to_gate_scale_b, + // Layer normalization parameters (layer norm LSTM) + const int16_t* layer_norm_coefficients, const int32_t* layer_norm_bias, + const int32_t layer_norm_input_scale_a, + const int32_t layer_norm_input_scale_b, + const int32_t layer_norm_variance_guard, + // Array sizes + const int n_batch, const int n_input, const int n_output, const int n_cell, + const TfLiteFusedActivation activation, + // Output + int16_t* gate, + // Parameters for performance optimizations + CpuBackendContext* context, + // Scratch arrays + int32_t* scratch5) { + const bool use_peephole = (cell_to_gate_weights != nullptr); + const bool use_layer_norm = (layer_norm_coefficients != nullptr); + + // Initialize scratch buffers with zeros. Note that unlike float and hybrid + // versions, bias is only used in layer normalization. + std::fill_n(gate, n_batch * n_cell, 0); + // For each batch and cell: compute input_weight * input. + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input, input_to_gate_bias, input_to_gate_weights, input_to_gate_scale_a, + input_to_gate_scale_b, n_batch, n_input, n_cell, 0, scratch5, gate, + context); + // Note: no aux_input. + + // For each batch and cell: compute recurrent_weight * output_state. + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + output_state, recurrent_to_gate_bias, recurrent_to_gate_weights, + recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output, + n_cell, 0, scratch5, gate, context); + // For each batch and cell: compute cell_weight * cell_state (peephole LSTM) + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_gate_weights, n_output, cell_state, n_batch, + cell_to_gate_scale_a, cell_to_gate_scale_b, gate); + } + // Do layer normalization (if layer norm LSTM) + if (use_layer_norm) { + tensor_utils::ApplyLayerNorm( + gate, layer_norm_coefficients, layer_norm_bias, + layer_norm_input_scale_a, layer_norm_input_scale_b, + layer_norm_variance_guard, n_batch, n_cell, gate); + } + // Apply activation + switch (activation) { + case kTfLiteActSigmoid: + tensor_utils::ApplySigmoid(gate, n_batch, n_cell, gate); + break; + case kTfLiteActTanh: + tensor_utils::ApplyTanh(3, gate, n_batch, n_cell, gate); + break; + default: + // Only Sigmoid or Tanh is used. + TFLITE_ASSERT_FALSE; + } +} + // Updates the LSTM cell state, used by both integer LSTM versions. // Also see UpdateLstmCellFloat. // @@ -327,9 +581,9 @@ void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state, // - n_cell, n_output: sizes of vectors. // - cell_state, output_gate: input vectors, size n_batch*n_cell. // - cell_state_scale: scaling of cell_state. 
-// - effective_hidden_scale_[a|b]: effective scale of cell_state.*output_gate +// - hidden_scale_[a|b]: effective scale of cell_state.*output_gate // - hidden_zp: zero_point for cell_state.*output_gate -// - projection_weights, effective_proj_scale_[a|b], projection_effective_bias: +// - projection_weights, proj_scale_[a|b], projection_bias: // constant inputs, describing projection matrix and bias. // - output_state_zp: zero point of output_state. (Input, calibrated value.) // - quantized_proj_clip: if > 0, clip the output of the projection. @@ -341,19 +595,17 @@ void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state, void CalculateLstmOutputInteger8x8_16( int n_batch, int n_cell, int n_output, const int16_t* cell_state, int32_t cell_state_scale, const int16_t* output_gate, - int32_t effective_hidden_scale_a, int32_t effective_hidden_scale_b, - int32_t hidden_zp, const int8_t* projection_weights, - int32_t effective_proj_scale_a, int32_t effective_proj_scale_b, - const int32_t* projection_effective_bias, int32_t output_state_zp, - int8_t quantized_proj_clip, int8_t* output_state, + int32_t hidden_scale_a, int32_t hidden_scale_b, int32_t hidden_zp, + const int8_t* projection_weights, int32_t proj_scale_a, + int32_t proj_scale_b, const int32_t* projection_bias, + int32_t output_state_zp, int8_t quantized_proj_clip, int8_t* output_state, CpuBackendContext* context, int16_t* scratch0, int8_t* scratch1, int32_t* scratch2) { // Note: unlike float/hybrid, the activation is always Tanh. tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state, n_batch, n_cell, scratch0); - tensor_utils::CwiseMul(output_gate, scratch0, effective_hidden_scale_a, - effective_hidden_scale_b, n_batch, n_cell, hidden_zp, - scratch1); + tensor_utils::CwiseMul(output_gate, scratch0, hidden_scale_a, hidden_scale_b, + n_batch, n_cell, hidden_zp, scratch1); const bool use_projection = (projection_weights != nullptr); @@ -361,9 +613,9 @@ void CalculateLstmOutputInteger8x8_16( // Note: no bias like in float/hybrid std::fill_n(output_state, n_batch * n_output, 0); tensor_utils::MatrixBatchVectorMultiplyAccumulate( - scratch1, projection_effective_bias, projection_weights, - effective_proj_scale_a, effective_proj_scale_b, n_batch, n_cell, - n_output, output_state_zp, scratch2, output_state, context); + scratch1, projection_bias, projection_weights, proj_scale_a, + proj_scale_b, n_batch, n_cell, n_output, output_state_zp, scratch2, + output_state, context); if (quantized_proj_clip > 0) { tensor_utils::CwiseClipping(output_state, n_batch * n_output, quantized_proj_clip); @@ -373,6 +625,68 @@ void CalculateLstmOutputInteger8x8_16( } } +// Calculates a single LSTM gate, int8x8_8 version. +// Implements the same functionality as CalculateLstmGateFloat. 
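+// This variant has no aux input and no peephole; layer normalization is
+// always applied (via ApplyLayerNormFloat), and the two matmul results are
+// combined with TwoGateSaturatingAdd before the activation.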
+void CalculateLstmGateInteger8x8_8( + // Inputs and weights + const int8_t* input, int32_t input_zp, const int8_t* input_to_gate_weight, + const int32_t input_to_gate_scale_a, const int32_t input_to_gate_scale_b, + const int32_t input_times_weights_scale_a, + const int32_t input_times_weights_scale_b, + const int32_t input_times_weights_zp, + // Output state and weights + const int8_t* output_state, const int32_t output_state_zp, + const int8_t* recurrent_to_gate_weight, + const int32_t recurrent_to_gate_scale_a, + const int32_t recurrent_to_gate_scale_b, + const int32_t output_state_times_weights_scale_a, + const int32_t output_state_times_weights_scale_b, + const int32_t output_state_times_weights_zp, + // Layer normalization parameters (layer norm LSTM) + const int16_t* layer_norm_gate_weight, + const int32_t layer_norm_gate_scale_a, + const int32_t layer_norm_gate_scale_b, const int32_t* gate_bias, + // Array sizes + const int n_batch, const int n_input, const int n_output, const int n_cell, + const TfLiteFusedActivation activation, + // Output + int16_t* gate, + // Scratch arrays, both sized n_batch*n_cell + int8_t* scratch0, int8_t* scratch1) { + // Multiply input * input_weights => scratch0 + tensor_utils::MatrixBatchVectorMultiply( + input, input_zp, input_to_gate_weight, input_to_gate_scale_a, + input_to_gate_scale_b, n_batch, n_input, n_cell, scratch0, + input_times_weights_zp); + // Multiply output_state * recurrent_weights => scratch1 + tensor_utils::MatrixBatchVectorMultiply( + output_state, output_state_zp, recurrent_to_gate_weight, + recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output, + n_cell, scratch1, output_state_times_weights_zp); + // Add scratch0 + scratch1 => gate + tensor_utils::TwoGateSaturatingAdd( + scratch0, input_times_weights_zp, scratch1, output_state_times_weights_zp, + input_times_weights_scale_a, input_times_weights_scale_b, + output_state_times_weights_scale_a, output_state_times_weights_scale_b, + n_batch, n_cell, gate); + // Apply layer normalization. + tensor_utils::ApplyLayerNormFloat( + gate, layer_norm_gate_weight, layer_norm_gate_scale_a, + layer_norm_gate_scale_b, gate_bias, n_batch, n_cell, gate); + // Apply activation. // Apply activation + switch (activation) { + case kTfLiteActSigmoid: + tensor_utils::ApplySigmoidFloat(gate, n_batch, n_cell, gate); + break; + case kTfLiteActTanh: + tensor_utils::ApplyTanhFloat(gate, n_batch, n_cell, -12, gate); + break; + default: + // Only Sigmoid or Tanh is used. + TFLITE_ASSERT_FALSE; + } +} + // Calculates the output state tensor of an LSTM step. See Float and hybrid // versions as well. // @@ -380,7 +694,7 @@ void CalculateLstmOutputInteger8x8_16( // - n_batch: batches: the number of distinct vectors in each array. // - n_cell, n_output: sizes of vectors. // - cell_state, output_gate: input vectors, size n_batch*n_cell. -// - projection_weights, effective_proj_scale_[a|b], projection_bias: +// - projection_weights, proj_scale_[a|b], projection_bias: // constant inputs, describing projection matrix and bias. // - output_state_zp: zero point of the output state. // - quantized_proj_clip: if > 0, clip the output of the projection. 
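As a reading aid, here is a minimal float-domain sketch of the computation every CalculateLstmGate* helper above performs: the gate is the activation of bias + input matmul + recurrent matmul. All names below are illustrative and not part of lstm_eval.cc; peephole, aux input, layer norm, and batching are omitted. The quantized variants perform the same accumulation in integer arithmetic, with the (scale_a, scale_b) pairs standing in for the float scales.

#include <cmath>
#include <vector>

// Illustrative reference of one LSTM gate for a single batch element.
std::vector<float> ReferenceLstmGate(
    const std::vector<float>& input,         // size n_input
    const std::vector<float>& output_state,  // size n_output
    const std::vector<float>& w_input,       // n_cell x n_input, row-major
    const std::vector<float>& w_recurrent,   // n_cell x n_output, row-major
    const std::vector<float>& bias,          // size n_cell
    int n_input, int n_output, int n_cell, bool use_tanh) {
  std::vector<float> gate(n_cell);
  for (int c = 0; c < n_cell; ++c) {
    float acc = bias[c];
    for (int i = 0; i < n_input; ++i) {
      acc += w_input[c * n_input + i] * input[i];
    }
    for (int o = 0; o < n_output; ++o) {
      acc += w_recurrent[c * n_output + o] * output_state[o];
    }
    // Sigmoid for the input/forget/output gates; tanh (or the configured
    // activation) for the cell gate.
    gate[c] = use_tanh ? std::tanh(acc) : 1.0f / (1.0f + std::exp(-acc));
  }
  return gate;
}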
@@ -389,18 +703,17 @@ void CalculateLstmOutputInteger8x8_16( void CalculateLstmOutputInteger8x8_8( int n_batch, int n_cell, int n_output, const int16_t* cell_state, const int16_t* output_gate, const int8_t* projection_weights, - int32_t effective_proj_scale_a, int32_t effective_proj_scale_b, - const int32_t* projection_bias, int32_t output_state_zp, - int32_t quantized_proj_clip, int8_t* output_state, int16_t* scratch) { + int32_t proj_scale_a, int32_t proj_scale_b, const int32_t* projection_bias, + int32_t output_state_zp, int32_t quantized_proj_clip, int8_t* output_state, + int16_t* scratch) { // Note: unlike float/hybrid, the activation is always Tanh. tensor_utils::ApplyTanhFloat(cell_state, n_batch, n_cell, -15, scratch); tensor_utils::CwiseMul(output_gate, scratch, n_batch, n_cell, 15 + 15 - 15, scratch); // Note: no bias like in float/hybrid tensor_utils::MatrixBatchVectorMultiply( - scratch, projection_weights, effective_proj_scale_a, - effective_proj_scale_b, projection_bias, n_batch, n_cell, n_output, - output_state_zp, output_state); + scratch, projection_weights, proj_scale_a, proj_scale_b, projection_bias, + n_batch, n_cell, n_output, output_state_zp, output_state); if (quantized_proj_clip > 0) { tensor_utils::CwiseClipping(output_state, n_batch * n_output, quantized_proj_clip); @@ -502,178 +815,67 @@ inline void LstmStepFloat( // Since we have already checked that weights are all there or none, we can // check the existence of only one to the get the condition. const bool use_cifg = (input_to_input_weights_ptr == nullptr); - const bool use_peephole = (cell_to_output_weights_ptr != nullptr); - const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr); - // Make named scratch buffers for the different gates. + // Make named scratch buffers. float* input_gate_scratch = scratch0; float* forget_gate_scratch = scratch1; float* cell_gate_scratch = scratch2; float* output_gate_scratch = scratch3; + // Check if inputs are all zeros so we can skip some computations. const bool is_input_all_zeros = tensor_utils::IsZeroVector(input_ptr, n_batch * n_input); const bool is_aux_input_all_zeros = (aux_input_ptr == nullptr || tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input)); - - // Initialize scratch buffers with bias for regular lstm or initialize with - // zero for layer norm lstm. - if (use_layer_norm) { - if (!use_cifg) { - std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f); - } - std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f); - std::fill_n(cell_gate_scratch, n_cell * n_batch, 0.0f); - std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f); - } else { - if (!use_cifg) { - tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, - n_batch, input_gate_scratch); - } - tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, - forget_gate_scratch); - tensor_utils::VectorBatchVectorAssign(cell_gate_bias_ptr, n_cell, n_batch, - cell_gate_scratch); - tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, - output_gate_scratch); - } - - // For each batch and cell: compute input_weight * input. - // Skip if input is all zeros. 
- if (!is_input_all_zeros) { - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_input_weights_ptr, n_cell, n_input, input_ptr, n_batch, - input_gate_scratch); - } - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_forget_weights_ptr, n_cell, n_input, input_ptr, n_batch, - forget_gate_scratch); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_cell_weights_ptr, n_cell, n_input, input_ptr, n_batch, - cell_gate_scratch); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_output_weights_ptr, n_cell, n_input, input_ptr, n_batch, - output_gate_scratch); - } - - // For each batch and cell: compute aux_input_weight * aux_input. - // Skip if auxiliary input is not available or all zeros. - if (!is_aux_input_all_zeros) { - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_input_weights_ptr, n_cell, n_aux_input, aux_input_ptr, - n_batch, input_gate_scratch); - } - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_forget_weights_ptr, n_cell, n_aux_input, aux_input_ptr, - n_batch, forget_gate_scratch); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr, - n_batch, cell_gate_scratch); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_output_weights_ptr, n_cell, n_aux_input, aux_input_ptr, - n_batch, output_gate_scratch); - } - - // For each batch and cell: compute recurrent_weight * output_state. if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, input_gate_scratch); + // Calculate the input gate. (If not CIFG.) + CalculateLstmGateFloat( + input_ptr, input_to_input_weights_ptr, aux_input_ptr, + aux_input_to_input_weights_ptr, output_state_ptr, + recurrent_to_input_weights_ptr, cell_state_ptr, + cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr, + input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell, + /*activation=*/kTfLiteActSigmoid, input_gate_scratch, + is_input_all_zeros, is_aux_input_all_zeros); } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, forget_gate_scratch); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, cell_gate_scratch); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, output_gate_scratch); - - // For each batch and cell: update input gate. - if (!use_cifg) { - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch, - input_gate_scratch); - } - if (use_layer_norm) { - tensor_utils::MeanStddevNormalization( - input_gate_scratch, input_gate_scratch, n_cell, n_batch); - tensor_utils::VectorBatchVectorCwiseProduct( - input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch, - n_batch, input_gate_scratch); - tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch, - input_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, - input_gate_scratch); - } - - // For each batch and cell: update forget gate. 
- if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch, - forget_gate_scratch); - } - if (use_layer_norm) { - tensor_utils::MeanStddevNormalization(forget_gate_scratch, - forget_gate_scratch, n_cell, n_batch); - tensor_utils::VectorBatchVectorCwiseProduct( - forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch, - n_batch, forget_gate_scratch); - tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch, - forget_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, - forget_gate_scratch); - - // For each batch and cell: update the cell. - if (use_layer_norm) { - tensor_utils::MeanStddevNormalization(cell_gate_scratch, cell_gate_scratch, - n_cell, n_batch); - tensor_utils::VectorBatchVectorCwiseProduct( - cell_layer_norm_coefficients_ptr, n_cell, cell_gate_scratch, n_batch, - cell_gate_scratch); - tensor_utils::VectorBatchVectorAdd(cell_gate_bias_ptr, n_cell, n_batch, - cell_gate_scratch); - } - tensor_utils::ApplyActivationToVector(cell_gate_scratch, n_batch * n_cell, - params->activation, cell_gate_scratch); - + // Calculate the forget gate. + CalculateLstmGateFloat( + input_ptr, input_to_forget_weights_ptr, aux_input_ptr, + aux_input_to_forget_weights_ptr, output_state_ptr, + recurrent_to_forget_weights_ptr, cell_state_ptr, + cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr, + forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell, + /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros, + is_aux_input_all_zeros); + // Calculate the cell update gate. + CalculateLstmGateFloat(input_ptr, input_to_cell_weights_ptr, aux_input_ptr, + aux_input_to_cell_weights_ptr, output_state_ptr, + recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr, + /*cell_to_gate_weights=*/nullptr, + cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr, + n_batch, n_input, n_aux_input, n_output, n_cell, + params->activation, cell_gate_scratch, + is_input_all_zeros, is_aux_input_all_zeros); + // Update the cell state. UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch, forget_gate_scratch, cell_gate_scratch, use_cifg, params->cell_clip); - - // For each batch and cell: update the output gate. - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch, - output_gate_scratch); - } - if (use_layer_norm) { - tensor_utils::MeanStddevNormalization(output_gate_scratch, - output_gate_scratch, n_cell, n_batch); - tensor_utils::VectorBatchVectorCwiseProduct( - output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch, - n_batch, output_gate_scratch); - tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch, - output_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, - output_gate_scratch); - + // Calculate output gate. + CalculateLstmGateFloat( + input_ptr, input_to_output_weights_ptr, aux_input_ptr, + aux_input_to_output_weights_ptr, output_state_ptr, + recurrent_to_output_weights_ptr, cell_state_ptr, + cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr, + output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell, + /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros, + is_aux_input_all_zeros); + // Update the output state. 
CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch, params->activation, projection_weights_ptr, projection_bias_ptr, params->proj_clip, output_state_ptr, scratch2); - - // Copy output_state to the output. Note that the output batch rows may not be + // Copy output state to the output. Note that the output's rows may not be // contiguous (output_batch_leading_dim != n_output). for (int b = 0; b < n_batch; b++) { std::copy_n(output_state_ptr + b * n_output, n_output, @@ -803,9 +1005,6 @@ inline void LstmStepHybrid( // Since we have already checked that weights are all there or none, we // can check the existence of only one to the get the condition. const bool use_cifg = (input_to_input_weights_ptr == nullptr); - const bool use_peephole = (cell_to_output_weights_ptr != nullptr); - const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr); - // Make named scratch buffers for the different gates. float* input_gate_scratch = scratch0; float* forget_gate_scratch = scratch1; @@ -879,6 +1078,7 @@ inline void LstmStepHybrid( } } + // Check if inputs are all zeros so we can skip some computations. const bool is_input_all_zeros = tensor_utils::IsZeroVector(input_ptr, n_batch * n_input); const bool is_aux_input_all_zeros = @@ -886,7 +1086,7 @@ inline void LstmStepHybrid( tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input)); const bool is_output_state_all_zeros = tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output); - + // Quantize inputs. if (!is_input_all_zeros) { tensor_utils::BatchQuantizeFloats(input_ptr, n_batch, n_input, quantized_input_ptr, input_sf, input_zp, @@ -902,217 +1102,74 @@ inline void LstmStepHybrid( output_state_ptr, n_batch, n_output, quantized_output_state_ptr, output_state_sf, output_state_zp, asymmetric_quantize_inputs); } - - // Initialize scratch buffers with bias for regular lstm or initialize with - // zero for layer norm lstm. 
- if (use_layer_norm) { - if (!use_cifg) { - std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f); - } - std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f); - std::fill_n(cell_gate_scratch, n_cell * n_batch, 0.0f); - std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f); - } else { - if (!use_cifg) { - tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, - n_batch, input_gate_scratch); - } - tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, - forget_gate_scratch); - tensor_utils::VectorBatchVectorAssign(cell_gate_bias_ptr, n_cell, n_batch, - cell_gate_scratch); - tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, - output_gate_scratch); - } - - if (!is_input_all_zeros) { - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr, - input_to_input_weights_scale, input_sf, n_batch, input_gate_scratch, - /*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr, - input_to_input_row_sums, compute_row_sums, scaling_factors_scratch, - context); - } - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr, - input_to_forget_weights_scale, input_sf, n_batch, forget_gate_scratch, - /*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr, - input_to_forget_row_sums, compute_row_sums, scaling_factors_scratch, - context); - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr, - input_to_cell_weights_scale, input_sf, n_batch, cell_gate_scratch, - /*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr, - input_to_cell_row_sums, compute_row_sums, scaling_factors_scratch, - context); - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr, - input_to_output_weights_scale, input_sf, n_batch, output_gate_scratch, - /*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr, - input_to_output_row_sums, compute_row_sums, scaling_factors_scratch, - context); - } - - // For each batch and cell: compute aux_input_weight * aux_input. - // Skip if auxiliary input is not available or all zeros. 
- if (!is_aux_input_all_zeros) { - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_input_weights_ptr, n_cell, n_aux_input, - quantized_aux_input_ptr, aux_input_to_input_weights_scale, - aux_input_sf, n_batch, input_gate_scratch, - /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch_ptr, - aux_input_to_input_row_sums, compute_row_sums, - scaling_factors_scratch, context); - } - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_forget_weights_ptr, n_cell, n_aux_input, - quantized_aux_input_ptr, aux_input_to_forget_weights_scale, - aux_input_sf, n_batch, forget_gate_scratch, - /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch_ptr, - aux_input_to_forget_row_sums, compute_row_sums, scaling_factors_scratch, - context); - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_cell_weights_ptr, n_cell, n_aux_input, - quantized_aux_input_ptr, aux_input_to_cell_weights_scale, aux_input_sf, - n_batch, cell_gate_scratch, /*per_channel_scale=*/nullptr, aux_input_zp, - accum_scratch_ptr, aux_input_to_cell_row_sums, compute_row_sums, - scaling_factors_scratch, context); - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_output_weights_ptr, n_cell, n_aux_input, - quantized_aux_input_ptr, aux_input_to_output_weights_scale, - aux_input_sf, n_batch, output_gate_scratch, - /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch_ptr, - aux_input_to_output_row_sums, compute_row_sums, scaling_factors_scratch, - context); - } - - if (!is_output_state_all_zeros) { - // For each batch and cell: compute recurrent_weight * output_state. - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_input_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, recurrent_to_input_weights_scale, - output_state_sf, n_batch, input_gate_scratch, - /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr, - recurrent_to_input_row_sums, compute_row_sums, - scaling_factors_scratch, context); - } - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_forget_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, recurrent_to_forget_weights_scale, - output_state_sf, n_batch, forget_gate_scratch, - /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr, - recurrent_to_forget_row_sums, compute_row_sums, scaling_factors_scratch, - context); - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_cell_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, recurrent_to_cell_weights_scale, - output_state_sf, n_batch, cell_gate_scratch, - /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr, - recurrent_to_cell_row_sums, compute_row_sums, scaling_factors_scratch, - context); - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_output_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, recurrent_to_output_weights_scale, - output_state_sf, n_batch, output_gate_scratch, - /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr, - recurrent_to_output_row_sums, compute_row_sums, scaling_factors_scratch, - context); - } - - // For each batch and cell: update input gate. 
if (!use_cifg) { - if (use_peephole) { - tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell, - cell_to_input_weights_scale, - recovered_cell_weights); - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - recovered_cell_weights, n_cell, cell_state_ptr, n_batch, - input_gate_scratch); - } - if (use_layer_norm) { - tensor_utils::MeanStddevNormalization( - input_gate_scratch, input_gate_scratch, n_cell, n_batch); - tensor_utils::VectorBatchVectorCwiseProduct( - input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch, - n_batch, input_gate_scratch); - tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch, - input_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, - input_gate_scratch); + // Calculate the input gate. (If not CIFG.) + CalculateLstmGateHybrid( + quantized_input_ptr, input_sf, input_zp, input_to_input_weights_ptr, + input_to_input_weights_scale, input_to_input_row_sums, + quantized_aux_input_ptr, aux_input_sf, aux_input_zp, + aux_input_to_input_weights_ptr, aux_input_to_input_weights_scale, + aux_input_to_input_row_sums, quantized_output_state_ptr, + output_state_sf, output_state_zp, recurrent_to_input_weights_ptr, + recurrent_to_input_weights_scale, recurrent_to_input_row_sums, + cell_state_ptr, cell_to_input_weights_ptr, cell_to_input_weights_scale, + input_layer_norm_coefficients_ptr, input_gate_bias_ptr, n_batch, + n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid, + input_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros, + is_output_state_all_zeros, compute_row_sums, context, + scaling_factors_scratch, recovered_cell_weights, accum_scratch_ptr); } - - // For each batch and cell: update forget gate. - if (use_peephole) { - tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell, - cell_to_forget_weights_scale, - recovered_cell_weights); - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - recovered_cell_weights, n_cell, cell_state_ptr, n_batch, - forget_gate_scratch); - } - if (use_layer_norm) { - tensor_utils::MeanStddevNormalization(forget_gate_scratch, - forget_gate_scratch, n_cell, n_batch); - tensor_utils::VectorBatchVectorCwiseProduct( - forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch, - n_batch, forget_gate_scratch); - tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch, - forget_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, - forget_gate_scratch); - - // For each batch and cell: update the cell. - if (use_layer_norm) { - tensor_utils::MeanStddevNormalization(cell_gate_scratch, cell_gate_scratch, - n_cell, n_batch); - tensor_utils::VectorBatchVectorCwiseProduct( - cell_layer_norm_coefficients_ptr, n_cell, cell_gate_scratch, n_batch, - cell_gate_scratch); - tensor_utils::VectorBatchVectorAdd(cell_gate_bias_ptr, n_cell, n_batch, - cell_gate_scratch); - } - tensor_utils::ApplyActivationToVector(cell_gate_scratch, n_batch * n_cell, - params->activation, cell_gate_scratch); - + // Calculate the forget gate. 
+ CalculateLstmGateHybrid( + quantized_input_ptr, input_sf, input_zp, input_to_forget_weights_ptr, + input_to_forget_weights_scale, input_to_forget_row_sums, + quantized_aux_input_ptr, aux_input_sf, aux_input_zp, + aux_input_to_forget_weights_ptr, aux_input_to_forget_weights_scale, + aux_input_to_forget_row_sums, quantized_output_state_ptr, output_state_sf, + output_state_zp, recurrent_to_forget_weights_ptr, + recurrent_to_forget_weights_scale, recurrent_to_forget_row_sums, + cell_state_ptr, cell_to_forget_weights_ptr, cell_to_forget_weights_scale, + forget_layer_norm_coefficients_ptr, forget_gate_bias_ptr, n_batch, + n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid, + forget_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros, + is_output_state_all_zeros, compute_row_sums, context, + scaling_factors_scratch, recovered_cell_weights, accum_scratch_ptr); + // Calculate the cell update gate. + CalculateLstmGateHybrid( + quantized_input_ptr, input_sf, input_zp, input_to_cell_weights_ptr, + input_to_cell_weights_scale, input_to_cell_row_sums, + quantized_aux_input_ptr, aux_input_sf, aux_input_zp, + aux_input_to_cell_weights_ptr, aux_input_to_cell_weights_scale, + aux_input_to_cell_row_sums, quantized_output_state_ptr, output_state_sf, + output_state_zp, recurrent_to_cell_weights_ptr, + recurrent_to_cell_weights_scale, recurrent_to_cell_row_sums, + /*cell_state=*/nullptr, /*cell_to_gate_weights=*/nullptr, + /*cell_to_gate_weights_scale=*/0.0f, cell_layer_norm_coefficients_ptr, + cell_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell, + params->activation, cell_gate_scratch, is_input_all_zeros, + is_aux_input_all_zeros, is_output_state_all_zeros, compute_row_sums, + context, scaling_factors_scratch, recovered_cell_weights, + accum_scratch_ptr); + // Update the cell state. UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch, forget_gate_scratch, cell_gate_scratch, use_cifg, params->cell_clip); - - // For each batch and cell: update the output gate. - if (use_peephole) { - tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell, - cell_to_output_weights_scale, - recovered_cell_weights); - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - recovered_cell_weights, n_cell, cell_state_ptr, n_batch, - output_gate_scratch); - } - if (use_layer_norm) { - tensor_utils::MeanStddevNormalization(output_gate_scratch, - output_gate_scratch, n_cell, n_batch); - tensor_utils::VectorBatchVectorCwiseProduct( - output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch, - n_batch, output_gate_scratch); - tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch, - output_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, - output_gate_scratch); - + // Calculate the output gate. 
+ CalculateLstmGateHybrid( + quantized_input_ptr, input_sf, input_zp, input_to_output_weights_ptr, + input_to_output_weights_scale, input_to_output_row_sums, + quantized_aux_input_ptr, aux_input_sf, aux_input_zp, + aux_input_to_output_weights_ptr, aux_input_to_output_weights_scale, + aux_input_to_output_row_sums, quantized_output_state_ptr, output_state_sf, + output_state_zp, recurrent_to_output_weights_ptr, + recurrent_to_output_weights_scale, recurrent_to_output_row_sums, + cell_state_ptr, cell_to_output_weights_ptr, cell_to_output_weights_scale, + output_layer_norm_coefficients_ptr, output_gate_bias_ptr, n_batch, + n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid, + output_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros, + is_output_state_all_zeros, compute_row_sums, context, + scaling_factors_scratch, recovered_cell_weights, accum_scratch_ptr); + // Update the output state. CalculateLstmOutputHybrid( n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch, params->activation, projection_weights_ptr, projection_weights_scale, @@ -1120,9 +1177,8 @@ inline void LstmStepHybrid( asymmetric_quantize_inputs, projection_weights_row_sums, compute_row_sums, context, scratch2, quantized_output_scratch, input_sf, input_zp, accum_scratch_ptr); - - // Copy output_state_ptr to the output. Note that the output batch rows may - // not be contiguous (output_batch_leading_dim != n_output). + // Copy output state to the output. Note that the output's rows may not be + // contiguous (output_batch_leading_dim != n_output). for (int b = 0; b < n_batch; b++) { std::copy_n(output_state_ptr + b * n_output, n_output, output_ptr + b * output_batch_leading_dim); @@ -1292,10 +1348,9 @@ inline void LstmStepInteger8x8_16( int16_t* cell_gate_scratch = scratch2; int16_t* output_gate_scratch = scratch3; - // Get hyper parameters. + // Since we have already checked that weights are all there or none, we + // can check the existence of only one to the get the condition. const bool use_cifg = (input_to_input_weight_ptr == nullptr); - const bool use_peephole = (cell_to_output_weight_ptr != nullptr); - const bool use_layer_norm = (layer_norm_forget_weight_ptr != nullptr); // Check for nullptrs. TFLITE_DCHECK(input_to_forget_effective_bias); @@ -1310,125 +1365,63 @@ inline void LstmStepInteger8x8_16( } TFLITE_DCHECK(projection_effective_bias); - // Set scratch to 0. if (!use_cifg) { - std::fill_n(input_gate_scratch, n_batch * n_cell, 0); - } - std::fill_n(forget_gate_scratch, n_batch * n_cell, 0); - std::fill_n(cell_gate_scratch, n_batch * n_cell, 0); - std::fill_n(output_gate_scratch, n_batch * n_cell, 0); - - // Forget gate. 
- tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_ptr, input_to_forget_effective_bias, input_to_forget_weight_ptr, - effective_input_to_forget_scale_a, effective_input_to_forget_scale_b, - n_batch, n_input, n_cell, 0, scratch5, forget_gate_scratch, context); - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - output_state_ptr, recurrent_to_forget_effective_bias, - recurrent_to_forget_weight_ptr, effective_recurrent_to_forget_scale_a, - effective_recurrent_to_forget_scale_b, n_batch, n_output, n_cell, 0, - scratch5, forget_gate_scratch, context); - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_forget_weight_ptr, n_output, cell_state_ptr, n_batch, - effective_cell_to_forget_scale_a, effective_cell_to_forget_scale_b, - forget_gate_scratch); - } - - if (use_layer_norm) { - tensor_utils::ApplyLayerNorm( - forget_gate_scratch, layer_norm_forget_weight_ptr, forget_gate_bias_ptr, - layer_norm_forget_scale_a, layer_norm_forget_scale_b, - forget_variance_guard, n_batch, n_cell, forget_gate_scratch); - } - - tensor_utils::ApplySigmoid(forget_gate_scratch, n_batch, n_cell, - forget_gate_scratch); - - // Cell gate. - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_ptr, input_to_cell_effective_bias, input_to_cell_weight_ptr, - effective_input_to_cell_scale_a, effective_input_to_cell_scale_b, n_batch, - n_input, n_cell, 0, scratch5, cell_gate_scratch, context); - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - output_state_ptr, recurrent_to_cell_effective_bias, - recurrent_to_cell_weight_ptr, effective_recurrent_to_cell_scale_a, - effective_recurrent_to_cell_scale_b, n_batch, n_output, n_cell, 0, - scratch5, cell_gate_scratch, context); - - if (use_layer_norm) { - tensor_utils::ApplyLayerNorm(cell_gate_scratch, layer_norm_cell_weight_ptr, - cell_gate_bias_ptr, layer_norm_cell_scale_a, - layer_norm_cell_scale_b, cell_variance_guard, - n_batch, n_cell, cell_gate_scratch); - } - - tensor_utils::ApplyTanh(3, cell_gate_scratch, n_batch, n_cell, - cell_gate_scratch); - - // Input gate. - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_ptr, input_to_input_effective_bias, input_to_input_weight_ptr, + // Calculate the input gate. (If not CIFG.) 
+ CalculateLstmGateInteger8x8_16( + input_ptr, input_to_input_weight_ptr, input_to_input_effective_bias, effective_input_to_input_scale_a, effective_input_to_input_scale_b, - n_batch, n_input, n_cell, 0, scratch5, input_gate_scratch, context); - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - output_state_ptr, recurrent_to_input_effective_bias, - recurrent_to_input_weight_ptr, effective_recurrent_to_input_scale_a, - effective_recurrent_to_input_scale_b, n_batch, n_output, n_cell, 0, - scratch5, input_gate_scratch, context); - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_input_weight_ptr, n_output, cell_state_ptr, n_batch, - effective_cell_to_input_scale_a, effective_cell_to_input_scale_b, - input_gate_scratch); - } - - if (use_layer_norm) { - tensor_utils::ApplyLayerNorm( - input_gate_scratch, layer_norm_input_weight_ptr, input_gate_bias_ptr, - layer_norm_input_scale_a, layer_norm_input_scale_b, - input_variance_guard, n_batch, n_cell, input_gate_scratch); - } - tensor_utils::ApplySigmoid(input_gate_scratch, n_batch, n_cell, - input_gate_scratch); + output_state_ptr, recurrent_to_input_weight_ptr, + recurrent_to_input_effective_bias, effective_recurrent_to_input_scale_a, + effective_recurrent_to_input_scale_b, cell_state_ptr, + cell_to_input_weight_ptr, effective_cell_to_input_scale_a, + effective_cell_to_input_scale_b, layer_norm_input_weight_ptr, + input_gate_bias_ptr, layer_norm_input_scale_a, layer_norm_input_scale_b, + input_variance_guard, n_batch, n_input, n_output, n_cell, + kTfLiteActSigmoid, input_gate_scratch, context, scratch5); } - + // Calculate the forget gate. + CalculateLstmGateInteger8x8_16( + input_ptr, input_to_forget_weight_ptr, input_to_forget_effective_bias, + effective_input_to_forget_scale_a, effective_input_to_forget_scale_b, + output_state_ptr, recurrent_to_forget_weight_ptr, + recurrent_to_forget_effective_bias, effective_recurrent_to_forget_scale_a, + effective_recurrent_to_forget_scale_b, cell_state_ptr, + cell_to_forget_weight_ptr, effective_cell_to_forget_scale_a, + effective_cell_to_forget_scale_b, layer_norm_forget_weight_ptr, + forget_gate_bias_ptr, layer_norm_forget_scale_a, + layer_norm_forget_scale_b, forget_variance_guard, n_batch, n_input, + n_output, n_cell, kTfLiteActSigmoid, forget_gate_scratch, context, + scratch5); + // Calculate the cell update gate. + CalculateLstmGateInteger8x8_16( + input_ptr, input_to_cell_weight_ptr, input_to_cell_effective_bias, + effective_input_to_cell_scale_a, effective_input_to_cell_scale_b, + output_state_ptr, recurrent_to_cell_weight_ptr, + recurrent_to_cell_effective_bias, effective_recurrent_to_cell_scale_a, + effective_recurrent_to_cell_scale_b, cell_state_ptr, + /*cell_to_gate_weights=*/nullptr, /*cell_to_gate_scale_a=*/0, + /*cell_to_gate_scale_b=*/0, layer_norm_cell_weight_ptr, + cell_gate_bias_ptr, layer_norm_cell_scale_a, layer_norm_cell_scale_b, + cell_variance_guard, n_batch, n_input, n_output, n_cell, kTfLiteActTanh, + cell_gate_scratch, context, scratch5); + // Update the cell state. UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr, cell_state_scale, input_gate_scratch, forget_gate_scratch, cell_gate_scratch, use_cifg, quantized_cell_clip); - - // Ouptut gate. - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_ptr, input_to_output_effective_bias, input_to_output_weight_ptr, + // Calculate the output gate. 
+ CalculateLstmGateInteger8x8_16( + input_ptr, input_to_output_weight_ptr, input_to_output_effective_bias, effective_input_to_output_scale_a, effective_input_to_output_scale_b, - n_batch, n_input, n_cell, 0, scratch5, output_gate_scratch, context); - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - output_state_ptr, recurrent_to_output_effective_bias, - recurrent_to_output_weight_ptr, effective_recurrent_to_output_scale_a, - effective_recurrent_to_output_scale_b, n_batch, n_output, n_cell, 0, - scratch5, output_gate_scratch, context); - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_output_weight_ptr, n_output, cell_state_ptr, n_batch, - effective_cell_to_output_scale_a, effective_cell_to_output_scale_b, - output_gate_scratch); - } - - if (use_layer_norm) { - tensor_utils::ApplyLayerNorm( - output_gate_scratch, layer_norm_output_weight_ptr, output_gate_bias_ptr, - layer_norm_output_scale_a, layer_norm_output_scale_b, - output_variance_guard, n_batch, n_cell, output_gate_scratch); - } - - tensor_utils::ApplySigmoid(output_gate_scratch, n_batch, n_cell, - output_gate_scratch); - + output_state_ptr, recurrent_to_output_weight_ptr, + recurrent_to_output_effective_bias, effective_recurrent_to_output_scale_a, + effective_recurrent_to_output_scale_b, cell_state_ptr, + cell_to_output_weight_ptr, effective_cell_to_output_scale_a, + effective_cell_to_output_scale_b, layer_norm_output_weight_ptr, + output_gate_bias_ptr, layer_norm_output_scale_a, + layer_norm_output_scale_b, output_variance_guard, n_batch, n_input, + n_output, n_cell, kTfLiteActSigmoid, output_gate_scratch, context, + scratch5); + // Update the output state. CalculateLstmOutputInteger8x8_16( n_batch, n_cell, n_output, cell_state_ptr, cell_state_scale, output_gate_scratch, effective_hidden_scale_a, effective_hidden_scale_b, @@ -1436,7 +1429,6 @@ inline void LstmStepInteger8x8_16( effective_proj_scale_b, projection_effective_bias, output_state_zp, quantized_proj_clip, output_state_ptr, context, scratch0, scratch4, scratch5); - // Copy output state to the output. Note that unlike float or hybrid, output // is always contigous. std::copy_n(output_state_ptr, n_batch * n_output, output_ptr); @@ -1591,109 +1583,61 @@ inline void LstmStepInteger8x8_8( int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3, int16_t* scratch4, int16_t* scratch5, int16_t* scratch6, int16_t* scratch7) { + // TODO(b/159066113): scratch5 is unused, remove. + ruy::profiler::ScopeLabel label("LstmStepInteger8x8_8"); // Make named scratch buffers for the different gates. int16_t* forget_gate_scratch = scratch2; int16_t* cell_gate_scratch = scratch3; int16_t* output_gate_scratch = scratch4; + // no-CIFG is not supported here - // Forget gate. - std::fill_n(scratch0, n_batch * n_cell, 0); - std::fill_n(scratch1, n_batch * n_cell, 0); - tensor_utils::MatrixBatchVectorMultiply( + // Calculate the forget gate. 
+ CalculateLstmGateInteger8x8_8( input_ptr, input_zp, input_to_forget_weight_ptr, effective_input_to_forget_scale_a, effective_input_to_forget_scale_b, - n_batch, n_input, n_cell, scratch0, intermediate_zp[4]); - - tensor_utils::MatrixBatchVectorMultiply( + intermediate_scale_a[2], intermediate_scale_b[2], intermediate_zp[4], output_state_ptr, output_state_zp, recurrent_to_forget_weight_ptr, effective_recurrent_to_forget_scale_a, - effective_recurrent_to_forget_scale_b, n_batch, n_output, n_cell, - scratch1, intermediate_zp[5]); - - tensor_utils::TwoGateSaturatingAdd( - scratch0, intermediate_zp[4], scratch1, intermediate_zp[5], - intermediate_scale_a[2], intermediate_scale_b[2], intermediate_scale_a[3], - intermediate_scale_b[3], n_batch, n_cell, forget_gate_scratch); - - // Forget gate layer norm. - tensor_utils::ApplyLayerNormFloat( - forget_gate_scratch, layer_norm_forget_weight_ptr, + effective_recurrent_to_forget_scale_b, intermediate_scale_a[3], + intermediate_scale_b[3], intermediate_zp[5], layer_norm_forget_weight_ptr, layer_norm_forget_scale_a, layer_norm_forget_scale_b, - forget_gate_bias_ptr, n_batch, n_cell, forget_gate_scratch); - - // Forget gate sigmoid. - tensor_utils::ApplySigmoidFloat(forget_gate_scratch, n_batch, n_cell, - forget_gate_scratch); - - // Cell gate. - std::fill_n(scratch0, n_batch * n_cell, 0); - std::fill_n(scratch1, n_batch * n_cell, 0); - tensor_utils::MatrixBatchVectorMultiply( + forget_gate_bias_ptr, n_batch, n_input, n_output, n_cell, + kTfLiteActSigmoid, forget_gate_scratch, scratch0, scratch1); + // Calculate the cell update gate. + CalculateLstmGateInteger8x8_8( input_ptr, input_zp, input_to_cell_weight_ptr, - effective_input_to_cell_scale_a, effective_input_to_cell_scale_b, n_batch, - n_input, n_cell, scratch0, intermediate_zp[7]); - - tensor_utils::MatrixBatchVectorMultiply( + effective_input_to_cell_scale_a, effective_input_to_cell_scale_b, + intermediate_scale_a[4], intermediate_scale_b[4], intermediate_zp[7], output_state_ptr, output_state_zp, recurrent_to_cell_weight_ptr, effective_recurrent_to_cell_scale_a, effective_recurrent_to_cell_scale_b, - n_batch, n_output, n_cell, scratch1, intermediate_zp[8]); - - tensor_utils::TwoGateSaturatingAdd( - scratch0, intermediate_zp[7], scratch1, intermediate_zp[8], - intermediate_scale_a[4], intermediate_scale_b[4], intermediate_scale_a[5], - intermediate_scale_b[5], n_batch, n_cell, cell_gate_scratch); - - // Cell gate layer norm. - tensor_utils::ApplyLayerNormFloat( - cell_gate_scratch, layer_norm_cell_weight_ptr, layer_norm_cell_scale_a, - layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_cell, - cell_gate_scratch); - - // Cell gate tanh. - tensor_utils::ApplyTanhFloat(cell_gate_scratch, n_batch, n_cell, -12, - cell_gate_scratch); - - // Output gate. 
-  std::fill_n(scratch0, n_batch * n_cell, 0);
-  std::fill_n(scratch1, n_batch * n_cell, 0);
-  tensor_utils::MatrixBatchVectorMultiply(
-      input_ptr, input_zp, input_to_output_weight_ptr,
-      effective_input_to_output_scale_a, effective_input_to_output_scale_b,
-      n_batch, n_input, n_cell, scratch0, intermediate_zp[10]);
-
-  tensor_utils::MatrixBatchVectorMultiply(
-      output_state_ptr, output_state_zp, recurrent_to_output_weight_ptr,
-      effective_recurrent_to_output_scale_a,
-      effective_recurrent_to_output_scale_b, n_batch, n_output, n_cell,
-      scratch1, intermediate_zp[11]);
-
-  tensor_utils::TwoGateSaturatingAdd(
-      scratch0, intermediate_zp[10], scratch1, intermediate_zp[11],
-      intermediate_scale_a[6], intermediate_scale_b[6], intermediate_scale_a[7],
-      intermediate_scale_b[7], n_batch, n_cell, output_gate_scratch);
-
-  // Output gate with layer norm.
-  tensor_utils::ApplyLayerNormFloat(
-      output_gate_scratch, layer_norm_output_weight_ptr,
-      layer_norm_output_scale_a, layer_norm_output_scale_b,
-      output_gate_bias_ptr, n_batch, n_cell, output_gate_scratch);
-
-  // Output gate sigmoid.
-  tensor_utils::ApplySigmoidFloat(output_gate_scratch, n_batch, n_cell,
-                                  output_gate_scratch);
-
+      intermediate_scale_a[5], intermediate_scale_b[5], intermediate_zp[8],
+      layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
+      layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_input, n_output,
+      n_cell, kTfLiteActTanh, cell_gate_scratch, scratch0, scratch1);
+  // Update the cell state.
   UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr,
                         /*cell_state_scale=*/-15, /*input_gate=*/nullptr,
                         forget_gate_scratch, cell_gate_scratch,
                         /*use_cifg=*/true, quantized_cell_clip);
-
+  // Calculate the output gate.
+  CalculateLstmGateInteger8x8_8(
+      input_ptr, input_zp, input_to_output_weight_ptr,
+      effective_input_to_output_scale_a, effective_input_to_output_scale_b,
+      intermediate_scale_a[6], intermediate_scale_b[6], intermediate_zp[10],
+      output_state_ptr, output_state_zp, recurrent_to_output_weight_ptr,
+      effective_recurrent_to_output_scale_a,
+      effective_recurrent_to_output_scale_b, intermediate_scale_a[7],
+      intermediate_scale_b[7], intermediate_zp[11], layer_norm_output_weight_ptr,
+      layer_norm_output_scale_a, layer_norm_output_scale_b,
+      output_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
+      kTfLiteActSigmoid, output_gate_scratch, scratch0, scratch1);
+  // Update the output state.
   CalculateLstmOutputInteger8x8_8(
       n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
       projection_weight_ptr, effective_proj_scale_a, effective_proj_scale_b,
       projection_bias_ptr, output_state_zp, quantized_proj_clip,
       output_state_ptr, scratch2);
-
   // Copy output state to the output. Note that unlike float or hybrid, output
   // is always contigous.
diff --git a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
index 19be9c59e70..1ac996abe87 100644
--- a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
+++ b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
@@ -37,6 +37,64 @@ namespace builtin {
 namespace {
 
+inline void CalculateLstmGateFloat(
+    const float* input, const float* input_to_gate_weights,
+    const float* aux_input, const float* aux_input_to_gate_weights,
+    const float* output_state, const float* recurrent_to_gate_weights,
+    const float* cell_state, const float* cell_to_gate_weights,
+    const float* layer_norm_coefficients, const float* gate_bias,
+    const int n_batch, const int n_input, const int n_aux_input,
+    const int n_output, const int n_cell,
+    const TfLiteFusedActivation activation, float* gate,
+    const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
+    Logger* logger, int intermediate_tensor_index,
+    ErrorReporter* error_reporter) {
+  const bool use_peephole = (cell_to_gate_weights != nullptr);
+  const bool use_layer_norm = (layer_norm_coefficients != nullptr);
+
+  // Initialize scratch buffers with bias for regular lstm or initialize with
+  // zero for layer norm lstm.
+  if (use_layer_norm) {
+    std::fill_n(gate, n_cell * n_batch, 0.0f);
+  } else {
+    tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
+  }
+  // For each batch and cell: compute input_weight * input.
+  // Skip if input is all zeros.
+  if (!is_input_all_zeros) {
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        input_to_gate_weights, n_cell, n_input, input, n_batch, gate);
+  }
+  // For each batch and cell: compute aux_input_weight * aux_input.
+  // Skip if auxiliary input is not available or all zeros.
+  if (!is_aux_input_all_zeros) {
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights,
+                                                      n_cell, n_aux_input,
+                                                      aux_input, n_batch, gate);
+  }
+  // For each batch and cell: compute recurrent_weight * output_state.
+  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+      recurrent_to_gate_weights, n_cell, n_output, output_state, n_batch, gate);
+  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
+  if (use_peephole) {
+    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+        cell_to_gate_weights, n_cell, cell_state, n_batch, gate);
+  }
+  // Do layer normalization (if layer norm LSTM)
+  if (use_layer_norm) {
+    logger->LogTensorValue(intermediate_tensor_index, gate, n_cell * n_batch,
+                           error_reporter);
+
+    tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
+    tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell,
+                                                gate, n_batch, gate);
+    tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
+  }
+  // Apply activation
+  tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell, activation,
+                                        gate);
+}
+
 // TODO(b/159066113): This is the exact same function as UpdateLstmCellFloat in
 // kernels/lstm_eval.cc, make that public and remove this.
 void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state,
@@ -130,178 +188,72 @@ inline void LstmStepCalibration(
   // Since we have already checked that weights are all there or none, we can
   // check the existence of only one to get the condition.
   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
-  const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
-  const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);
-  // Make named scratch buffers for the different gates.
+  // Make named scratch buffers.
   float* input_gate_scratch = scratch0;
   float* forget_gate_scratch = scratch1;
   float* cell_gate_scratch = scratch2;
   float* output_gate_scratch = scratch3;
 
-  // Initialize scratch buffers with bias for regular lstm or initialize with
-  // zero for layer norm lstm.
-  if (use_layer_norm) {
-    if (!use_cifg) {
-      std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
-    }
-    std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
-    std::fill_n(cell_gate_scratch, n_cell * n_batch, 0.0f);
-    std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
-  } else {
-    if (!use_cifg) {
-      tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
-                                            n_batch, input_gate_scratch);
-    }
-    tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
-                                          forget_gate_scratch);
-    tensor_utils::VectorBatchVectorAssign(cell_gate_bias_ptr, n_cell, n_batch,
-                                          cell_gate_scratch);
-    tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
-                                          output_gate_scratch);
-  }
-
-  // For each batch and cell: compute input_weight * input.
+  // Check if inputs are all zeros so we can skip some computations.
+  const bool is_input_all_zeros =
+      tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
+  const bool is_aux_input_all_zeros =
+      (aux_input_ptr == nullptr ||
+       tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
   if (!use_cifg) {
-    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        input_to_input_weights_ptr, n_cell, n_input, input_ptr, n_batch,
-        input_gate_scratch);
+    // Calculate the input gate. (If not CIFG.)
+    CalculateLstmGateFloat(
+        input_ptr, input_to_input_weights_ptr, aux_input_ptr,
+        aux_input_to_input_weights_ptr, output_state_ptr,
+        recurrent_to_input_weights_ptr, cell_state_ptr,
+        cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
+        input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
+        /*activation=*/kTfLiteActSigmoid, input_gate_scratch,
+        is_input_all_zeros, is_aux_input_all_zeros, logger,
+        intermediate_tensor_indexes[0], error_reporter);
   }
-
-  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-      input_to_forget_weights_ptr, n_cell, n_input, input_ptr, n_batch,
-      forget_gate_scratch);
-  tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_cell_weights_ptr,
-                                                    n_cell, n_input, input_ptr,
-                                                    n_batch, cell_gate_scratch);
-  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-      input_to_output_weights_ptr, n_cell, n_input, input_ptr, n_batch,
-      output_gate_scratch);
-
-  // For each batch and cell: compute aux_input_weight * aux_input.
-  // Skip if auxiliary input is not available.
-  if (aux_input_ptr != nullptr) {
-    if (!use_cifg) {
-      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-          aux_input_to_input_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
-          n_batch, input_gate_scratch);
-    }
-
-    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        aux_input_to_forget_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
-        n_batch, forget_gate_scratch);
-    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
-        n_batch, cell_gate_scratch);
-    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        aux_input_to_output_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
-        n_batch, output_gate_scratch);
-  }
-
-  // For each batch and cell: compute recurrent_weight * output_state.
-  if (!use_cifg) {
-    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
-        n_batch, input_gate_scratch);
-  }
-  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-      recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
-      n_batch, forget_gate_scratch);
-  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-      recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
-      n_batch, cell_gate_scratch);
-  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-      recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
-      n_batch, output_gate_scratch);
-
-  // For each batch and cell: update input gate.
-  if (!use_cifg) {
-    if (use_peephole) {
-      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
-          cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
-          input_gate_scratch);
-    }
-    if (use_layer_norm) {
-      logger->LogTensorValue(intermediate_tensor_indexes[0], input_gate_scratch,
-                             n_cell * n_batch, error_reporter);
-      tensor_utils::MeanStddevNormalization(
-          input_gate_scratch, input_gate_scratch, n_cell, n_batch);
-      tensor_utils::VectorBatchVectorCwiseProduct(
-          input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
-          n_batch, input_gate_scratch);
-      tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
-                                         input_gate_scratch);
-    }
-    tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
-                                       input_gate_scratch);
-  }
-
-  // For each batch and cell: update forget gate.
-  if (use_peephole) {
-    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
-        cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
-        forget_gate_scratch);
-  }
-  if (use_layer_norm) {
-    logger->LogTensorValue(intermediate_tensor_indexes[1], forget_gate_scratch,
-                           n_cell * n_batch, error_reporter);
-    tensor_utils::MeanStddevNormalization(forget_gate_scratch,
-                                          forget_gate_scratch, n_cell, n_batch);
-    tensor_utils::VectorBatchVectorCwiseProduct(
-        forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
-        n_batch, forget_gate_scratch);
-    tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
-                                       forget_gate_scratch);
-  }
-  tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
-                                     forget_gate_scratch);
-
-  // For each batch and cell: update the cell.
-  if (use_layer_norm) {
-    logger->LogTensorValue(intermediate_tensor_indexes[2], cell_gate_scratch,
-                           n_cell * n_batch, error_reporter);
-    tensor_utils::MeanStddevNormalization(cell_gate_scratch, cell_gate_scratch,
-                                          n_cell, n_batch);
-    tensor_utils::VectorBatchVectorCwiseProduct(
-        cell_layer_norm_coefficients_ptr, n_cell, cell_gate_scratch, n_batch,
-        cell_gate_scratch);
-    tensor_utils::VectorBatchVectorAdd(cell_gate_bias_ptr, n_cell, n_batch,
-                                       cell_gate_scratch);
-  }
-  tensor_utils::ApplyActivationToVector(cell_gate_scratch, n_batch * n_cell,
-                                        params->activation, cell_gate_scratch);
-
+  // Calculate the forget gate.
+  CalculateLstmGateFloat(
+      input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
+      aux_input_to_forget_weights_ptr, output_state_ptr,
+      recurrent_to_forget_weights_ptr, cell_state_ptr,
+      cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
+      forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
+      /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
+      is_aux_input_all_zeros, logger, intermediate_tensor_indexes[1],
+      error_reporter);
+  // Calculate the cell update gate.
+  CalculateLstmGateFloat(input_ptr, input_to_cell_weights_ptr, aux_input_ptr,
+                         aux_input_to_cell_weights_ptr, output_state_ptr,
+                         recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
+                         /*cell_to_gate_weights=*/nullptr,
+                         cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr,
+                         n_batch, n_input, n_aux_input, n_output, n_cell,
+                         params->activation, cell_gate_scratch,
+                         is_input_all_zeros, is_aux_input_all_zeros, logger,
+                         intermediate_tensor_indexes[2], error_reporter);
+  // Update the cell state.
   UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
                       forget_gate_scratch, cell_gate_scratch, use_cifg,
                       params->cell_clip);
-
-  // For each batch and cell: update the output gate.
-  if (use_peephole) {
-    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
-        cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
-        output_gate_scratch);
-  }
-  if (use_layer_norm) {
-    logger->LogTensorValue(intermediate_tensor_indexes[3], output_gate_scratch,
-                           n_cell * n_batch, error_reporter);
-    tensor_utils::MeanStddevNormalization(output_gate_scratch,
-                                          output_gate_scratch, n_cell, n_batch);
-    tensor_utils::VectorBatchVectorCwiseProduct(
-        output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
-        n_batch, output_gate_scratch);
-    tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
-                                       output_gate_scratch);
-  }
-  tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
-                                     output_gate_scratch);
-
+  // Calculate the output gate.
+  CalculateLstmGateFloat(
+      input_ptr, input_to_output_weights_ptr, aux_input_ptr,
+      aux_input_to_output_weights_ptr, output_state_ptr,
+      recurrent_to_output_weights_ptr, cell_state_ptr,
+      cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
+      output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
+      /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
+      is_aux_input_all_zeros, logger, intermediate_tensor_indexes[3],
+      error_reporter);
+  // Update the output state.
   CalculateLstmOutputCalibration(
       n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
       params->activation, projection_weights_ptr, projection_bias_ptr,
       params->proj_clip, output_state_ptr, scratch2, logger,
       intermediate_tensor_indexes, error_reporter);
-
-  // Copy output_state to the output. Note that the output batch rows may not be
+  // Copy output state to the output. Note that the output's rows may not be
   // contiguous (output_batch_leading_dim != n_output).
   for (int b = 0; b < n_batch; b++) {
     std::copy_n(output_state_ptr + b * n_output, n_output,
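
The calibration path finishes with a per-batch copy because the output tensor's rows may not be contiguous (output_batch_leading_dim != n_output). A hedged sketch of that strided copy, with illustrative names rather than the kernel's actual pointers:

// Illustrative sketch (not TFLite code): copying a packed [n_batch x n_output]
// state into an output buffer whose rows are output_batch_leading_dim floats
// apart, which is why one std::copy_n over the whole buffer is not enough.
#include <algorithm>

void CopyStateToStridedOutput(const float* output_state, float* output,
                              int n_batch, int n_output,
                              int output_batch_leading_dim) {
  for (int b = 0; b < n_batch; ++b) {
    std::copy_n(output_state + b * n_output, n_output,
                output + b * output_batch_leading_dim);
  }
}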