diff --git a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc index 1620374f466..b0be6d0dbd7 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc @@ -886,11 +886,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights, fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights, fw_cell_to_input_weights, fw_cell_to_forget_weights, - fw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights, - fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights, - fw_aux_input_to_output_weights, fw_input_gate_bias, - fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias, - fw_projection_weights, fw_projection_bias, &lstm_params, + fw_cell_to_output_weights, + /*input_layer_norm_coefficients=*/nullptr, + /*forget_layer_norm_coefficients=*/nullptr, + /*cell_layer_norm_coefficients=*/nullptr, + /*output_layer_norm_coefficients=*/nullptr, aux_input, + fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights, + fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights, + fw_input_gate_bias, fw_forget_gate_bias, fw_cell_bias, + fw_output_gate_bias, fw_projection_weights, fw_projection_bias, + &lstm_params, /*forward_sequence=*/true, time_major, /*output_offset=*/0, fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output); TF_LITE_ENSURE_OK(context, fw_pass_status); @@ -901,11 +906,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights, bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights, bw_cell_to_input_weights, bw_cell_to_forget_weights, - bw_cell_to_output_weights, aux_input, bw_aux_input_to_input_weights, - bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights, - bw_aux_input_to_output_weights, bw_input_gate_bias, - bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias, - bw_projection_weights, bw_projection_bias, &lstm_params, + bw_cell_to_output_weights, + /*input_layer_norm_coefficients=*/nullptr, + /*forget_layer_norm_coefficients=*/nullptr, + /*cell_layer_norm_coefficients=*/nullptr, + /*output_layer_norm_coefficients=*/nullptr, aux_input, + bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights, + bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights, + bw_input_gate_bias, bw_forget_gate_bias, bw_cell_bias, + bw_output_gate_bias, bw_projection_weights, bw_projection_bias, + &lstm_params, /*forward_sequence=*/false, time_major, bw_output_offset, bw_scratch_buffer, bw_activation_state, bw_cell_state, actual_bw_output); @@ -940,11 +950,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights, fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights, fw_cell_to_input_weights, fw_cell_to_forget_weights, - fw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights, - fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights, - fw_aux_input_to_output_weights, fw_input_gate_bias, - fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias, - fw_projection_weights, fw_projection_bias, &lstm_params, + fw_cell_to_output_weights, + /*input_layer_norm_coefficients=*/nullptr, + /*forget_layer_norm_coefficients=*/nullptr, + /*cell_layer_norm_coefficients=*/nullptr, + /*output_layer_norm_coefficients=*/nullptr, aux_input, + fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights, + fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights, + fw_input_gate_bias, fw_forget_gate_bias, fw_cell_bias, + fw_output_gate_bias, fw_projection_weights, fw_projection_bias, + &lstm_params, /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0, fw_scratch_buffer, scaling_factors, prod_scaling_factors, recovered_cell_weights, input_quantized, aux_input_quantized, @@ -958,11 +973,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights, bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights, bw_cell_to_input_weights, bw_cell_to_forget_weights, - bw_cell_to_output_weights, aux_input, bw_aux_input_to_input_weights, - bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights, - bw_aux_input_to_output_weights, bw_input_gate_bias, - bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias, - bw_projection_weights, bw_projection_bias, &lstm_params, + bw_cell_to_output_weights, + /*input_layer_norm_coefficients=*/nullptr, + /*forget_layer_norm_coefficients=*/nullptr, + /*cell_layer_norm_coefficients=*/nullptr, + /*output_layer_norm_coefficients=*/nullptr, aux_input, + bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights, + bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights, + bw_input_gate_bias, bw_forget_gate_bias, bw_cell_bias, + bw_output_gate_bias, bw_projection_weights, bw_projection_bias, + &lstm_params, /*forward_sequence=*/false, /*time_major=*/true, bw_output_offset, bw_scratch_buffer, scaling_factors, prod_scaling_factors, recovered_cell_weights, input_quantized, aux_input_quantized, diff --git a/tensorflow/lite/kernels/lstm.cc b/tensorflow/lite/kernels/lstm.cc index b57e2883b05..3689d77b012 100644 --- a/tensorflow/lite/kernels/lstm.cc +++ b/tensorflow/lite/kernels/lstm.cc @@ -38,17 +38,24 @@ namespace builtin { namespace lstm { struct OpData { - // Which kernel type to use. Full kernel (20 inputs) or basic kernel - // (5 inputs). + // Which kernel type to use. Full kernel (24 inputs) or basic kernel (5 + // inputs). + // Please note the 20-input full kernel is deprecated and only kept + // here for backward compatibility. TfLiteLSTMKernelType kernel_type; + // If the lstm is layer norm. + bool is_layer_norm_lstm; + // These fields are only used by full kernel. int activation_state_tensor_index; int cell_state_tensor_index; int scratch_tensor_index; }; -// For full inputs kernel (20-inputs). +// For full inputs kernel (24-inputs). +// Please note the 20-input full kernel is deprecated and only kept +// here for backward compatibility. namespace full { // Input Tensors of size {n_batch, n_input} @@ -87,6 +94,13 @@ constexpr int kProjectionBiasTensor = 17; // Optional constexpr int kInputActivationStateTensor = 18; constexpr int kInputCellStateTensor = 19; +// Layer norm coefficient tensors of size {n_cell}, representing a diagonal +// matrix. +constexpr int kInputLayerNormCoefficientsTensor = 20; // Optional +constexpr int kForgetLayerNormCoefficientsTensor = 21; // Optional +constexpr int kCellLayerNormCoefficientsTensor = 22; // Optional +constexpr int kOutputLayerNormCoefficientsTensor = 23; // Optional + // Output tensors. constexpr int kOutputTensor = 0; @@ -101,7 +115,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { // Check that input tensor dimensions matches with each other. TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteNode* node, int n_input, - int n_output, int n_cell) { + int n_output, int n_cell, + bool is_layer_norm_lstm) { const auto* params = reinterpret_cast(node->builtin_data); // Making sure clipping parameters have valid values. @@ -112,7 +127,8 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); - if (input_to_input_weights != nullptr) { + const bool use_cifg = (input_to_input_weights == nullptr); + if (!use_cifg) { TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); @@ -186,7 +202,6 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, } // Making sure the peephole weights are there all or none. - const bool use_cifg = (input_to_input_weights == nullptr); const bool peephole_weights_all_or_none = ((cell_to_input_weights != nullptr || use_cifg) && (cell_to_forget_weights != nullptr) && @@ -244,6 +259,40 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, ((projection_weights != nullptr) || (projection_bias == nullptr)); TF_LITE_ENSURE(context, projection_tensors_consistent == true); + if (is_layer_norm_lstm) { + const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor( + context, node, kInputLayerNormCoefficientsTensor); + if (use_cifg) { + TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr); + } else { + TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr); + TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1); + TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0], + n_cell); + } + + const TfLiteTensor* forget_layer_norm_coefficients = + GetInput(context, node, kForgetLayerNormCoefficientsTensor); + TF_LITE_ENSURE(context, forget_layer_norm_coefficients != nullptr); + TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1); + TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0], + n_cell); + + const TfLiteTensor* cell_layer_norm_coefficients = + GetInput(context, node, kCellLayerNormCoefficientsTensor); + TF_LITE_ENSURE(context, cell_layer_norm_coefficients != nullptr); + TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0], + n_cell); + + const TfLiteTensor* output_layer_norm_coefficients = + GetInput(context, node, kOutputLayerNormCoefficientsTensor); + TF_LITE_ENSURE(context, output_layer_norm_coefficients != nullptr); + TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1); + TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0], + n_cell); + } + return kTfLiteOk; } @@ -254,8 +303,32 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { OpData* op_data = reinterpret_cast(node->user_data); TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); - TF_LITE_ENSURE_EQ(context, node->inputs->size, 20); + // Logic for determining regular lstm and layer norm lstm: + // input_size, forget_gate_layer_norm_tensor (20) null? is_layer_norm? + // 20, N/A, No. + // 24, null, No. + // 24, not null, Yes. + // 20-inputs lstm are deprecated and is only kept here for backward + // compatibility. + if (node->inputs->size == 24) { + const TfLiteTensor* forget_layer_norm_coefficients = + GetInput(context, node, kForgetLayerNormCoefficientsTensor); + if (forget_layer_norm_coefficients == nullptr) { + op_data->is_layer_norm_lstm = false; + } else { + op_data->is_layer_norm_lstm = true; + } + } else if (node->inputs->size == 20) { + // This is deprecated and is only kept here for backward compatibility. + op_data->is_layer_norm_lstm = false; + } else { + context->ReportError( + context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs", + node->inputs->size); + return kTfLiteError; + } + const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm; op_data->activation_state_tensor_index = node->inputs->data[kInputActivationStateTensor]; op_data->cell_state_tensor_index = node->inputs->data[kInputCellStateTensor]; @@ -282,8 +355,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int n_output = recurrent_to_output_weights->dims->data[1]; // Check that input tensor dimensions matches with each other. - TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input, - n_output, n_cell)); + TF_LITE_ENSURE_OK(context, + CheckInputTensorDimensions(context, node, n_input, n_output, + n_cell, is_layer_norm_lstm)); // Get the pointer to output, activation_state and cell_state tensors. TfLiteTensor* output = GetOutput(context, node, kOutputTensor); @@ -430,6 +504,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const auto* params = reinterpret_cast(node->builtin_data); OpData* op_data = reinterpret_cast(node->user_data); + const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm; const TfLiteTensor* input = GetInput(context, node, kInputTensor); @@ -458,6 +533,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + const TfLiteTensor* input_layer_norm_coefficients = + is_layer_norm_lstm ? GetOptionalInputTensor( + context, node, kInputLayerNormCoefficientsTensor) + : nullptr; + const TfLiteTensor* forget_layer_norm_coefficients = + is_layer_norm_lstm + ? GetInput(context, node, kForgetLayerNormCoefficientsTensor) + : nullptr; + const TfLiteTensor* cell_layer_norm_coefficients = + is_layer_norm_lstm + ? GetInput(context, node, kCellLayerNormCoefficientsTensor) + : nullptr; + const TfLiteTensor* output_layer_norm_coefficients = + is_layer_norm_lstm + ? GetInput(context, node, kOutputLayerNormCoefficientsTensor) + : nullptr; + const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(context, node, kInputGateBiasTensor); const TfLiteTensor* forget_gate_bias = @@ -490,6 +582,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, + input_layer_norm_coefficients, forget_layer_norm_coefficients, + cell_layer_norm_coefficients, output_layer_norm_coefficients, /*aux_input=*/nullptr, /*aux_input_to_input_weights=*/nullptr, /*aux_input_to_forget_weights=*/nullptr, @@ -518,6 +612,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, + input_layer_norm_coefficients, forget_layer_norm_coefficients, + cell_layer_norm_coefficients, output_layer_norm_coefficients, /*aux_input=*/nullptr, /*aux_input_to_input_weights=*/nullptr, /*aux_input_to_forget_weights=*/nullptr, diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc index 50b2bca7b54..6ba1e193437 100644 --- a/tensorflow/lite/kernels/lstm_eval.cc +++ b/tensorflow/lite/kernels/lstm_eval.cc @@ -27,6 +27,10 @@ namespace lstm_eval { namespace { +// Small float to avoid divergence during calculation of deviation for layer +// norm lstm. +const float kLayerNormEpsilon = 1e-8; + // Performs an LSTM batch inference step for input specified by input_ptr_batch. // The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and // biases (*_bias_ptr), and buffers (*_scratch), along with additional @@ -65,30 +69,47 @@ inline void LstmStepWithAuxInput( const float* recurrent_to_output_weights_ptr, const float* cell_to_input_weights_ptr, const float* cell_to_forget_weights_ptr, - const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const float* projection_weights_ptr, - const float* projection_bias_ptr, const TfLiteLSTMParams* params, - int n_batch, int n_cell, int n_input, int n_aux_input, int n_output, - int output_batch_leading_dim, float* output_state_ptr, - float* cell_state_ptr, float* input_gate_scratch, + const float* cell_to_output_weights_ptr, + const float* input_layer_norm_coefficients_ptr, + const float* forget_layer_norm_coefficients_ptr, + const float* cell_layer_norm_coefficients_ptr, + const float* output_layer_norm_coefficients_ptr, + const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr, + const float* cell_bias_ptr, const float* output_gate_bias_ptr, + const float* projection_weights_ptr, const float* projection_bias_ptr, + const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, + int n_aux_input, int n_output, int output_batch_leading_dim, + float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, float* output_ptr_batch) { // Since we have already checked that weights are all there or none, we can // check the existense of only one to the get the condition. const bool use_cifg = (input_to_input_weights_ptr == nullptr); const bool use_peephole = (cell_to_output_weights_ptr != nullptr); - // Initialize scratch buffers with bias. - if (!use_cifg) { - tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch, - input_gate_scratch); + const bool is_layer_norm_lstm = + (forget_layer_norm_coefficients_ptr != nullptr); + + // Initialize scratch buffers with bias for regular lstm or initialize with + // zero for layer norm lstm. + if (is_layer_norm_lstm) { + if (!use_cifg) { + tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch); + } + tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch); + tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch); + tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch); + } else { + if (!use_cifg) { + tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, + n_batch, input_gate_scratch); + } + tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, + forget_gate_scratch); + tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, + output_gate_scratch); } - tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, - forget_gate_scratch); - tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, - cell_scratch); - tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, - output_gate_scratch); // For each batch and cell: compute input_weight * input. if (!use_cifg) { @@ -152,6 +173,16 @@ inline void LstmStepWithAuxInput( cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch, input_gate_scratch); } + if (is_layer_norm_lstm) { + tensor_utils::MeanStddevNormalization(input_gate_scratch, + input_gate_scratch, n_cell, n_batch, + kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct( + input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch, + n_batch, input_gate_scratch); + tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch, + input_gate_scratch); + } tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, input_gate_scratch); } @@ -162,12 +193,31 @@ inline void LstmStepWithAuxInput( cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch, forget_gate_scratch); } + if (is_layer_norm_lstm) { + tensor_utils::MeanStddevNormalization(forget_gate_scratch, + forget_gate_scratch, n_cell, n_batch, + kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct( + forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch, + n_batch, forget_gate_scratch); + tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch, + forget_gate_scratch); + } tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, forget_gate_scratch); // For each batch and cell: update the cell. tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, n_batch * n_cell, cell_state_ptr); + if (is_layer_norm_lstm) { + tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, + n_batch, kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct( + cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch, + cell_scratch); + } tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, params->activation, cell_scratch); if (use_cifg) { @@ -190,6 +240,16 @@ inline void LstmStepWithAuxInput( cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch, output_gate_scratch); } + if (is_layer_norm_lstm) { + tensor_utils::MeanStddevNormalization(output_gate_scratch, + output_gate_scratch, n_cell, n_batch, + kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct( + output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch, + n_batch, output_gate_scratch); + tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch, + output_gate_scratch); + } tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, output_gate_scratch); tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, @@ -344,33 +404,50 @@ inline void LstmStepWithAuxInput( const int8_t* cell_to_forget_weights_ptr, float cell_to_forget_weights_scale, const int8_t* cell_to_output_weights_ptr, - float cell_to_output_weights_scale, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, - float projection_weights_scale, const float* projection_bias_ptr, - const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, - int n_aux_input, int n_output, int output_batch_leading_dim, - float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch, - float* output_gate_scratch, float* scaling_factors, - float* product_scaling_factors, float* recovered_cell_weights, - int8_t* quantized_input_ptr_batch, int8_t* quantized_aux_input_ptr_batch, - int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr, - float* output_state_ptr, float* cell_state_ptr, float* output_ptr_batch) { + float cell_to_output_weights_scale, + const float* input_layer_norm_coefficients_ptr, + const float* forget_layer_norm_coefficients_ptr, + const float* cell_layer_norm_coefficients_ptr, + const float* output_layer_norm_coefficients_ptr, + const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr, + const float* cell_bias_ptr, const float* output_gate_bias_ptr, + const int8_t* projection_weights_ptr, float projection_weights_scale, + const float* projection_bias_ptr, const TfLiteLSTMParams* params, + int n_batch, int n_cell, int n_input, int n_aux_input, int n_output, + int output_batch_leading_dim, float* input_gate_scratch, + float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, + float* scaling_factors, float* product_scaling_factors, + float* recovered_cell_weights, int8_t* quantized_input_ptr_batch, + int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr, + int8_t* quantized_cell_state_ptr, float* output_state_ptr, + float* cell_state_ptr, float* output_ptr_batch) { // Since we have already checked that weights are all there or none, we // can check the existense of only one to the get the condition. const bool use_cifg = (input_to_input_weights_ptr == nullptr); const bool use_peephole = (cell_to_output_weights_ptr != nullptr); + const bool is_layer_norm_lstm = + (forget_layer_norm_coefficients_ptr != nullptr); + // Initialize scratch buffers with bias. - if (!use_cifg) { - tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch, - input_gate_scratch); + if (is_layer_norm_lstm) { + if (!use_cifg) { + tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch); + } + tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch); + tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch); + tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch); + } else { + if (!use_cifg) { + tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, + n_batch, input_gate_scratch); + } + tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, + forget_gate_scratch); + tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, + output_gate_scratch); } - tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, - forget_gate_scratch); - tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, - cell_scratch); - tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, - output_gate_scratch); if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) { // Save quantization and matmul computation for all zero input. @@ -535,6 +612,16 @@ inline void LstmStepWithAuxInput( recovered_cell_weights, n_cell, cell_state_ptr, n_batch, input_gate_scratch); } + if (is_layer_norm_lstm) { + tensor_utils::MeanStddevNormalization(input_gate_scratch, + input_gate_scratch, n_cell, n_batch, + kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct( + input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch, + n_batch, input_gate_scratch); + tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch, + input_gate_scratch); + } tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, input_gate_scratch); } @@ -548,12 +635,31 @@ inline void LstmStepWithAuxInput( recovered_cell_weights, n_cell, cell_state_ptr, n_batch, forget_gate_scratch); } + if (is_layer_norm_lstm) { + tensor_utils::MeanStddevNormalization(forget_gate_scratch, + forget_gate_scratch, n_cell, n_batch, + kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct( + forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch, + n_batch, forget_gate_scratch); + tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch, + forget_gate_scratch); + } tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, forget_gate_scratch); // For each batch and cell: update the cell. tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, n_batch * n_cell, cell_state_ptr); + if (is_layer_norm_lstm) { + tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, + n_batch, kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct( + cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch, + cell_scratch); + } tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, params->activation, cell_scratch); if (use_cifg) { @@ -581,6 +687,16 @@ inline void LstmStepWithAuxInput( recovered_cell_weights, n_cell, cell_state_ptr, n_batch, output_gate_scratch); } + if (is_layer_norm_lstm) { + tensor_utils::MeanStddevNormalization(output_gate_scratch, + output_gate_scratch, n_cell, n_batch, + kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct( + output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch, + n_batch, output_gate_scratch); + tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch, + output_gate_scratch); + } tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, output_gate_scratch); tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, @@ -702,7 +818,12 @@ TfLiteStatus EvalFloat( const TfLiteTensor* recurrent_to_output_weights, const TfLiteTensor* cell_to_input_weights, const TfLiteTensor* cell_to_forget_weights, - const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_layer_norm_coefficients, + const TfLiteTensor* forget_layer_norm_coefficients, + const TfLiteTensor* cell_layer_norm_coefficients, + const TfLiteTensor* output_layer_norm_coefficients, + const TfLiteTensor* aux_input, const TfLiteTensor* aux_input_to_input_weights, const TfLiteTensor* aux_input_to_forget_weights, const TfLiteTensor* aux_input_to_cell_weights, @@ -735,6 +856,7 @@ TfLiteStatus EvalFloat( // check the existense of only one to the get the condition. const bool use_cifg = (input_to_input_weights == nullptr); const bool use_peephole = (cell_to_output_weights != nullptr); + const bool is_layer_norm_lstm = (forget_layer_norm_coefficients != nullptr); // Index the scratch buffers pointers to the global scratch buffer. float* input_gate_scratch = nullptr; @@ -765,6 +887,15 @@ TfLiteStatus EvalFloat( (use_peephole) ? cell_to_forget_weights->data.f : nullptr; const float* cell_to_output_weights_ptr = (use_peephole) ? cell_to_output_weights->data.f : nullptr; + const float* input_layer_norm_coefficients_ptr = + (is_layer_norm_lstm && !use_cifg) ? input_layer_norm_coefficients->data.f + : nullptr; + const float* forget_layer_norm_coefficients_ptr = + is_layer_norm_lstm ? forget_layer_norm_coefficients->data.f : nullptr; + const float* cell_layer_norm_coefficients_ptr = + is_layer_norm_lstm ? cell_layer_norm_coefficients->data.f : nullptr; + const float* output_layer_norm_coefficients_ptr = + is_layer_norm_lstm ? output_layer_norm_coefficients->data.f : nullptr; const float* projection_weights_ptr = (projection_weights == nullptr) ? nullptr : projection_weights->data.f; const float* projection_bias_ptr = @@ -811,6 +942,8 @@ TfLiteStatus EvalFloat( recurrent_to_cell_weights->data.f, recurrent_to_output_weights->data.f, cell_to_input_weights_ptr, cell_to_forget_weights_ptr, cell_to_output_weights_ptr, + input_layer_norm_coefficients_ptr, forget_layer_norm_coefficients_ptr, + cell_layer_norm_coefficients_ptr, output_layer_norm_coefficients_ptr, input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f, output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr, params, n_batch, n_cell, n_input, aux_input_size, n_output, @@ -855,7 +988,11 @@ TfLiteStatus EvalFloat( recurrent_to_cell_weights->data.f, recurrent_to_output_weights->data.f, cell_to_input_weights_ptr, cell_to_forget_weights_ptr, cell_to_output_weights_ptr, - input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f, + input_layer_norm_coefficients_ptr, + forget_layer_norm_coefficients_ptr, + cell_layer_norm_coefficients_ptr, + output_layer_norm_coefficients_ptr, input_gate_bias_ptr, + forget_gate_bias->data.f, cell_bias->data.f, output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr, params, /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim, @@ -879,7 +1016,12 @@ TfLiteStatus EvalHybrid( const TfLiteTensor* recurrent_to_output_weights, const TfLiteTensor* cell_to_input_weights, const TfLiteTensor* cell_to_forget_weights, - const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_layer_norm_coefficients, + const TfLiteTensor* forget_layer_norm_coefficients, + const TfLiteTensor* cell_layer_norm_coefficients, + const TfLiteTensor* output_layer_norm_coefficients, + const TfLiteTensor* aux_input, const TfLiteTensor* aux_input_to_input_weights, const TfLiteTensor* aux_input_to_forget_weights, const TfLiteTensor* aux_input_to_cell_weights, @@ -914,6 +1056,7 @@ TfLiteStatus EvalHybrid( // check the existence of only one to get the condition. const bool use_cifg = (input_to_input_weights == nullptr); const bool use_peephole = (cell_to_output_weights != nullptr); + const bool is_layer_norm_lstm = (forget_layer_norm_coefficients != nullptr); float* input_gate_scratch = nullptr; float* cell_scratch = nullptr; @@ -966,6 +1109,16 @@ TfLiteStatus EvalHybrid( cell_to_output_weights_scale = cell_to_output_weights->params.scale; } + const float* input_layer_norm_coefficients_ptr = + (is_layer_norm_lstm && !use_cifg) ? input_layer_norm_coefficients->data.f + : nullptr; + const float* forget_layer_norm_coefficients_ptr = + is_layer_norm_lstm ? forget_layer_norm_coefficients->data.f : nullptr; + const float* cell_layer_norm_coefficients_ptr = + is_layer_norm_lstm ? cell_layer_norm_coefficients->data.f : nullptr; + const float* output_layer_norm_coefficients_ptr = + is_layer_norm_lstm ? output_layer_norm_coefficients->data.f : nullptr; + const int8_t* projection_weights_ptr = (projection_weights == nullptr) ? nullptr @@ -1084,6 +1237,8 @@ TfLiteStatus EvalHybrid( cell_to_input_weights_ptr, cell_to_input_weights_scale, cell_to_forget_weights_ptr, cell_to_forget_weights_scale, cell_to_output_weights_ptr, cell_to_output_weights_scale, + input_layer_norm_coefficients_ptr, forget_layer_norm_coefficients_ptr, + cell_layer_norm_coefficients_ptr, output_layer_norm_coefficients_ptr, input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale, projection_bias_ptr, params, n_batch, @@ -1138,7 +1293,10 @@ TfLiteStatus EvalHybrid( recurrent_to_output_weights_scale, cell_to_input_weights_ptr, cell_to_input_weights_scale, cell_to_forget_weights_ptr, cell_to_forget_weights_scale, cell_to_output_weights_ptr, - cell_to_output_weights_scale, input_gate_bias_ptr, + cell_to_output_weights_scale, input_layer_norm_coefficients_ptr, + forget_layer_norm_coefficients_ptr, + cell_layer_norm_coefficients_ptr, + output_layer_norm_coefficients_ptr, input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale, projection_bias_ptr, params, diff --git a/tensorflow/lite/kernels/lstm_eval.h b/tensorflow/lite/kernels/lstm_eval.h index c8a4d284f3c..33e5bc07819 100644 --- a/tensorflow/lite/kernels/lstm_eval.h +++ b/tensorflow/lite/kernels/lstm_eval.h @@ -34,7 +34,12 @@ TfLiteStatus EvalFloat( const TfLiteTensor* recurrent_to_output_weights, const TfLiteTensor* cell_to_input_weights, const TfLiteTensor* cell_to_forget_weights, - const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_layer_norm_coefficients, + const TfLiteTensor* forget_layer_norm_coefficients, + const TfLiteTensor* cell_layer_norm_coefficients, + const TfLiteTensor* output_layer_norm_coefficients, + const TfLiteTensor* aux_input, const TfLiteTensor* aux_input_to_input_weights, const TfLiteTensor* aux_input_to_forget_weights, const TfLiteTensor* aux_input_to_cell_weights, @@ -58,7 +63,12 @@ TfLiteStatus EvalHybrid( const TfLiteTensor* recurrent_to_output_weights, const TfLiteTensor* cell_to_input_weights, const TfLiteTensor* cell_to_forget_weights, - const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_layer_norm_coefficients, + const TfLiteTensor* forget_layer_norm_coefficients, + const TfLiteTensor* cell_layer_norm_coefficients, + const TfLiteTensor* output_layer_norm_coefficients, + const TfLiteTensor* aux_input, const TfLiteTensor* aux_input_to_input_weights, const TfLiteTensor* aux_input_to_forget_weights, const TfLiteTensor* aux_input_to_cell_weights, diff --git a/tensorflow/lite/kernels/lstm_test.cc b/tensorflow/lite/kernels/lstm_test.cc index 03ad2e899d2..fea95aacb1f 100644 --- a/tensorflow/lite/kernels/lstm_test.cc +++ b/tensorflow/lite/kernels/lstm_test.cc @@ -38,7 +38,8 @@ class LSTMOpModel : public SingleOpModel { bool use_peephole, bool use_projection_weights, bool use_projection_bias, float cell_clip, float proj_clip, const std::vector>& input_shapes, - const TensorType& weight_type = TensorType_FLOAT32) + const TensorType& weight_type = TensorType_FLOAT32, + bool is_layer_norm = false) : n_batch_(n_batch), n_input_(n_input), n_cell_(n_cell), @@ -106,6 +107,18 @@ class LSTMOpModel : public SingleOpModel { input_cell_state_ = AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); + // Layer norm weights. + if (is_layer_norm) { + if (use_cifg) { + input_layer_norm_coefficients_ = AddNullInput(); + } else { + input_layer_norm_coefficients_ = AddInput(TensorType_FLOAT32); + } + forget_layer_norm_coefficients_ = AddInput(TensorType_FLOAT32); + cell_layer_norm_coefficients_ = AddInput(TensorType_FLOAT32); + output_layer_norm_coefficients_ = AddInput(TensorType_FLOAT32); + } + output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, @@ -160,6 +173,22 @@ class LSTMOpModel : public SingleOpModel { PopulateTensor(cell_to_output_weights_, f); } + void SetInputLayerNormCoefficients(std::vector f) { + PopulateTensor(input_layer_norm_coefficients_, f); + } + + void SetForgetLayerNormCoefficients(std::vector f) { + PopulateTensor(forget_layer_norm_coefficients_, f); + } + + void SetCellLayerNormCoefficients(std::vector f) { + PopulateTensor(cell_layer_norm_coefficients_, f); + } + + void SetOutputLayerNormCoefficients(std::vector f) { + PopulateTensor(output_layer_norm_coefficients_, f); + } + void SetInputGateBias(std::vector f) { PopulateTensor(input_gate_bias_, f); } @@ -210,6 +239,11 @@ class LSTMOpModel : public SingleOpModel { int cell_to_forget_weights_; int cell_to_output_weights_; + int input_layer_norm_coefficients_; + int forget_layer_norm_coefficients_; + int cell_layer_norm_coefficients_; + int output_layer_norm_coefficients_; + int input_gate_bias_; int forget_gate_bias_; int cell_bias_; @@ -1392,6 +1426,644 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); } +class LayerNormLSTMOpModel : public LSTMOpModel { + public: + LayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, + bool use_cifg, bool use_peephole, + bool use_projection_weights, bool use_projection_bias, + float cell_clip, float proj_clip, + const std::vector>& input_shapes, + const TensorType& weight_type = TensorType_FLOAT32) + : LSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg, use_peephole, + use_projection_weights, use_projection_bias, cell_clip, + proj_clip, input_shapes, weight_type, + /*is_layer_norm*/ true) {} +}; + +class HybridLayerNormLSTMOpModel : public LayerNormLSTMOpModel { + public: + HybridLayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, + bool use_cifg, bool use_peephole, + bool use_projection_weights, + bool use_projection_bias, float cell_clip, + float proj_clip, + const std::vector>& input_shapes) + : LayerNormLSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg, + use_peephole, use_projection_weights, + use_projection_bias, cell_clip, proj_clip, + input_shapes, TensorType_UINT8) {} + + void SetInputToInputWeights(std::vector f) { + SymmetricQuantizeAndPopulate(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::vector f) { + SymmetricQuantizeAndPopulate(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::vector f) { + SymmetricQuantizeAndPopulate(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::vector f) { + SymmetricQuantizeAndPopulate(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::vector f) { + SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::vector f) { + SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::vector f) { + SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::vector f) { + SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::vector f) { + SymmetricQuantizeAndPopulate(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::vector f) { + SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::vector f) { + SymmetricQuantizeAndPopulate(cell_to_output_weights_, f); + } + + void SetInputLayerNormCoefficients(std::vector f) { + PopulateTensor(input_layer_norm_coefficients_, f); + } + + void SetForgetLayerNormCoefficients(std::vector f) { + PopulateTensor(forget_layer_norm_coefficients_, f); + } + + void SetCellLayerNormCoefficients(std::vector f) { + PopulateTensor(cell_layer_norm_coefficients_, f); + } + + void SetOutputLayerNormCoefficients(std::vector f) { + PopulateTensor(output_layer_norm_coefficients_, f); + } + + void SetProjectionWeights(std::vector f) { + SymmetricQuantizeAndPopulate(projection_weights_, f); + } +}; + +class BaseLayerNormLstmTest : public ::testing::Test { + protected: + // Weights of the Layer Norm LSTM model. Some are optional. + std::vector input_to_input_weights_; + std::vector input_to_cell_weights_; + std::vector input_to_forget_weights_; + std::vector input_to_output_weights_; + std::vector input_gate_bias_; + std::vector cell_gate_bias_; + std::vector forget_gate_bias_; + std::vector output_gate_bias_; + std::vector recurrent_to_input_weights_; + std::vector recurrent_to_cell_weights_; + std::vector recurrent_to_forget_weights_; + std::vector recurrent_to_output_weights_; + std::vector cell_to_input_weights_; + std::vector cell_to_forget_weights_; + std::vector cell_to_output_weights_; + std::vector projection_weights_; + std::vector input_layer_norm_coefficients_; + std::vector forget_layer_norm_coefficients_; + std::vector cell_layer_norm_coefficients_; + std::vector output_layer_norm_coefficients_; + + // Layer Norm LSTM input is stored as num_batch x num_inputs vector. + std::vector> layer_norm_lstm_input_; + + // Compares output up to tolerance to the result of the layer_norm_lstm given + // the input. + void VerifyGoldens(const std::vector>& input, + const std::vector>& output, + LayerNormLSTMOpModel* layer_norm_lstm, + float tolerance = 1e-5) { + const int num_batches = input.size(); + EXPECT_GT(num_batches, 0); + const int num_inputs = layer_norm_lstm->num_inputs(); + EXPECT_GT(num_inputs, 0); + const int input_sequence_size = input[0].size() / num_inputs; + EXPECT_GT(input_sequence_size, 0); + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* batch_start = input[b].data() + i * num_inputs; + const float* batch_end = batch_start + num_inputs; + + layer_norm_lstm->SetInput(b * layer_norm_lstm->num_inputs(), + batch_start, batch_end); + } + + layer_norm_lstm->Invoke(); + + const int num_outputs = layer_norm_lstm->num_outputs(); + std::vector expected; + for (int b = 0; b < num_batches; ++b) { + const float* golden_start_batch = output[b].data() + i * num_outputs; + const float* golden_end_batch = golden_start_batch + num_outputs; + expected.insert(expected.end(), golden_start_batch, golden_end_batch); + } + EXPECT_THAT(layer_norm_lstm->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } + } +}; + +class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest + : public BaseLayerNormLstmTest { + void SetUp() override { + input_to_input_weights_ = {0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, + 0.3, -0.4, 0.5, -0.8, 0.7, -0.6, 0.5, + -0.4, -0.5, -0.4, -0.3, -0.2, -0.1}; + + input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, + -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4, + -0.6, 0.3, -0.4, -0.6, -0.5, -0.5}; + + input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, + -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3, + -0.7, 0.7, -0.9, -0.5, 0.8, 0.6}; + + input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, + -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7, + -0.3, -0.5, 0.1, 0.5, -0.6, -0.4}; + + input_gate_bias_ = {0.03, 0.15, 0.22, 0.38}; + + forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1}; + + cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08}; + + output_gate_bias_ = {0.05, -0.01, 0.2, 0.1}; + + recurrent_to_input_weights_ = {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, + -0.2, -0.3, -0.7, 0.05, -0.2, -0.6}; + + recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, + -0.2, 0.3, 0.8, -0.6, -0.1, 0.2}; + + recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, + 0.9, 0.3, -0.1, 0.2, 0.5, 0.2}; + + recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, + -0.2, -0.6, -0.1, -0.4, -0.7, -0.2}; + + cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15}; + + cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03}; + + cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05}; + + input_layer_norm_coefficients_ = {0.1, 0.2, 0.3, 0.5}; + forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3}; + cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8}; + output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5}; + + projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, + 0.3, 0.08, 0.07, 0.2, -0.4, 0.2}; + + layer_norm_lstm_input_ = { + {// Batch0: 3 (input_sequence_size) * 5 (n_input) + 0.7, 0.8, 0.1, 0.2, 0.3, // seq 0 + 0.8, 0.1, 0.2, 0.4, 0.5, // seq 1 + 0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2 + + {// Batch1: 3 (input_sequence_size) * 5 (n_input) + 0.3, 0.2, 0.9, 0.8, 0.1, // seq 0 + 0.1, 0.5, 0.2, 0.4, 0.2, // seq 1 + 0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2 + }; + } +}; + +TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, + LayerNormLstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 4; + const int n_output = 3; + const float ceil_clip = 0.0; + const float proj_clip = 0.0; + + LayerNormLSTMOpModel layer_norm_lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, ceil_clip, proj_clip, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + {n_cell}, // input_layer_norm_coefficient tensor + {n_cell}, // forget_layer_norm_coefficient tensor + {n_cell}, // cell_layer_norm_coefficient tensor + {n_cell}, // output_layer_norm_coefficient tensor + }); + + layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_); + layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_); + layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_); + layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_); + + layer_norm_lstm.SetInputGateBias(input_gate_bias_); + layer_norm_lstm.SetCellBias(cell_gate_bias_); + layer_norm_lstm.SetForgetGateBias(forget_gate_bias_); + layer_norm_lstm.SetOutputGateBias(output_gate_bias_); + + layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_); + layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_); + layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_); + + layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_); + layer_norm_lstm.SetForgetLayerNormCoefficients( + forget_layer_norm_coefficients_); + layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_); + layer_norm_lstm.SetOutputLayerNormCoefficients( + output_layer_norm_coefficients_); + + layer_norm_lstm.SetProjectionWeights(projection_weights_); + + // Verify the final output. + const std::vector> layer_norm_lstm_golden_output = { + { + // Batch0: 3 (input_sequence_size) * 3 (n_output) + 0.0244077, 0.128027, -0.00170918, // seq 0 + 0.0137642, 0.140751, 0.0395835, // seq 1 + -0.00459231, 0.155278, 0.0837377, // seq 2 + }, + { + // Batch1: 3 (input_sequence_size) * 3 (n_output) + -0.00692428, 0.0848741, 0.063445, // seq 0 + -0.00403912, 0.139963, 0.072681, // seq 1 + 0.00752706, 0.161903, 0.0561371, // seq 2 + }}; + + VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output, + &layer_norm_lstm); +} + +TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, + HybridLayerNormLstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 4; + const int n_output = 3; + const float ceil_clip = 0.0; + const float proj_clip = 0.0; + + HybridLayerNormLSTMOpModel layer_norm_lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, ceil_clip, proj_clip, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + {n_cell}, // input_layer_norm_coefficient tensor + {n_cell}, // forget_layer_norm_coefficient tensor + {n_cell}, // cell_layer_norm_coefficient tensor + {n_cell}, // output_layer_norm_coefficient tensor + }); + + layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_); + layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_); + layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_); + layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_); + + layer_norm_lstm.SetInputGateBias(input_gate_bias_); + layer_norm_lstm.SetCellBias(cell_gate_bias_); + layer_norm_lstm.SetForgetGateBias(forget_gate_bias_); + layer_norm_lstm.SetOutputGateBias(output_gate_bias_); + + layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_); + layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_); + layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_); + + layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_); + layer_norm_lstm.SetForgetLayerNormCoefficients( + forget_layer_norm_coefficients_); + layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_); + layer_norm_lstm.SetOutputLayerNormCoefficients( + output_layer_norm_coefficients_); + + layer_norm_lstm.SetProjectionWeights(projection_weights_); + + const std::vector> layer_norm_lstm_golden_output = { + { + // Batch0: 3 (input_sequence_size) * 3 (n_output) + 0.0244576, 0.127847, -0.00181765, // seq 0 + 0.0137518, 0.140892, 0.0402234, // seq 1 + -0.0048839, 0.155096, 0.0840309, // seq 2 + }, + { + // Batch1: 3 (input_sequence_size) * 3 (n_output) + -0.00728636, 0.0843957, 0.0634786, // seq 0 + -0.00448382, 0.139278, 0.0737372, // seq 1 + 0.00734616, 0.161793, 0.0560238, // seq 2 + }}; + + VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output, + &layer_norm_lstm); +} + +class CifgPeepholeProjectionNoClippingLayerNormLstmTest + : public BaseLayerNormLstmTest { + void SetUp() override { + input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, + -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4, + -0.6, 0.3, -0.4, -0.6, -0.5, -0.5}; + input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, + -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3, + -0.7, 0.7, -0.9, -0.5, 0.8, 0.6}; + input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, + -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7, + -0.3, -0.5, 0.1, 0.5, -0.6, -0.4}; + + forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1}; + cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08}; + output_gate_bias_ = {0.05, -0.01, 0.2, 0.1}; + + recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, + -0.2, 0.3, 0.8, -0.6, -0.1, 0.2}; + recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, + 0.9, 0.3, -0.1, 0.2, 0.5, 0.2}; + recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, + -0.2, -0.6, -0.1, -0.4, -0.7, -0.2}; + + cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03}; + cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05}; + + forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3}; + cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8}; + output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5}; + projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, + 0.3, 0.08, 0.07, 0.2, -0.4, 0.2}; + + layer_norm_lstm_input_ = { + {// Batch0: 3 (input_sequence_size) * 5 (n_input) + 0.7, 0.8, 0.1, 0.2, 0.3, // seq 0 + 0.8, 0.1, 0.2, 0.4, 0.5, // seq 1 + 0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2 + + {// Batch1: 3 (input_sequence_size) * 5 (n_input) + 0.3, 0.2, 0.9, 0.8, 0.1, // seq 0 + 0.1, 0.5, 0.2, 0.4, 0.2, // seq 1 + 0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2 + }; + } +}; + +TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest, + LayerNormLstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 4; + const int n_output = 3; + const float ceil_clip = 0.0; + const float proj_clip = 0.0; + + LayerNormLSTMOpModel layer_norm_lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, ceil_clip, proj_clip, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + {0}, // input_layer_norm_coefficient tensor + {n_cell}, // forget_layer_norm_coefficient tensor + {n_cell}, // cell_layer_norm_coefficient tensor + {n_cell}, // output_layer_norm_coefficient tensor + }); + + layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_); + layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_); + layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_); + + layer_norm_lstm.SetCellBias(cell_gate_bias_); + layer_norm_lstm.SetForgetGateBias(forget_gate_bias_); + layer_norm_lstm.SetOutputGateBias(output_gate_bias_); + + layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_); + layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_); + + layer_norm_lstm.SetForgetLayerNormCoefficients( + forget_layer_norm_coefficients_); + layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_); + layer_norm_lstm.SetOutputLayerNormCoefficients( + output_layer_norm_coefficients_); + + layer_norm_lstm.SetProjectionWeights(projection_weights_); + + // Verify the final output. + const std::vector> layer_norm_lstm_golden_output = { + { + // Batch0: 3 (input_sequence_size) * 3 (n_output) + 0.02129706, 0.140816242, 0.0112733059, // seq 0 + 0.0132302344, 0.152308047, 0.0346313119, // seq 1 + -0.0123688057, 0.165790111, 0.0893077999, // seq 2 + }, + { + // Batch1: 3 (input_sequence_size) * 3 (n_output) + -0.0226350538, 0.0916948169, 0.0769175813, // seq 0 + -0.0269966982, 0.149707705, 0.094149217, // seq 1 + -0.0103429332, 0.173016444, 0.0720508844, // seq 2 + }}; + + VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output, + &layer_norm_lstm); +} + +TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest, + HybridLayerNormLstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 4; + const int n_output = 3; + const float ceil_clip = 0.0; + const float proj_clip = 0.0; + + HybridLayerNormLSTMOpModel layer_norm_lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, ceil_clip, proj_clip, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + {0}, // input_layer_norm_coefficient tensor + {n_cell}, // forget_layer_norm_coefficient tensor + {n_cell}, // cell_layer_norm_coefficient tensor + {n_cell}, // output_layer_norm_coefficient tensor + }); + + layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_); + layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_); + layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_); + + layer_norm_lstm.SetCellBias(cell_gate_bias_); + layer_norm_lstm.SetForgetGateBias(forget_gate_bias_); + layer_norm_lstm.SetOutputGateBias(output_gate_bias_); + + layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_); + layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_); + + layer_norm_lstm.SetForgetLayerNormCoefficients( + forget_layer_norm_coefficients_); + layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_); + layer_norm_lstm.SetOutputLayerNormCoefficients( + output_layer_norm_coefficients_); + + layer_norm_lstm.SetProjectionWeights(projection_weights_); + + // Verify the final output. + const std::vector> layer_norm_lstm_golden_output = { + { + // Batch0: 3 (input_sequence_size) * 3 (n_output) + 0.0212250091, 0.140474007, 0.0115012666, // seq 0 + 0.0130806509, 0.152660668, 0.0347516984, // seq 1 + -0.0124010444, 0.166042402, 0.0898982584, // seq 2 + }, + { + // Batch1: 3 (input_sequence_size) * 3 (n_output) + -0.0228835996, 0.0917588323, 0.0778886303, // seq 0 + -0.0275101066, 0.148769245, 0.0938384682, // seq 1 + -0.0103605557, 0.172605693, 0.0728750974, // seq 2 + }}; + + VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output, + &layer_norm_lstm); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc index 08e56b0ebd3..7d41491ba33 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc @@ -502,6 +502,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, + /*input_layer_norm_coefficients=*/nullptr, + /*forget_layer_norm_coefficients=*/nullptr, + /*cell_layer_norm_coefficients=*/nullptr, + /*output_layer_norm_coefficients=*/nullptr, /*aux_input=*/nullptr, /*aux_input_to_input_weights=*/nullptr, /*aux_input_to_forget_weights=*/nullptr, @@ -529,6 +533,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, + /*input_layer_norm_coefficients=*/nullptr, + /*forget_layer_norm_coefficients=*/nullptr, + /*cell_layer_norm_coefficients=*/nullptr, + /*output_layer_norm_coefficients=*/nullptr, /*aux_input=*/nullptr, /*aux_input_to_input_weights=*/nullptr, /*aux_input_to_forget_weights=*/nullptr,