From 731cd5e3ce64a5cbdde64d602274ff9f69d2ef93 Mon Sep 17 00:00:00 2001 From: Jian Li Date: Fri, 17 May 2019 15:19:40 -0700 Subject: [PATCH] Add layer norm to unidirectional lstm. With this change, unidirectional lstm takes 24 inputs with or without layer norm. The 20 input case is only kept for backward compatibility. PiperOrigin-RevId: 248797274 --- .../kernels/unidirectional_sequence_lstm.cc | 141 ++++++-- .../unidirectional_sequence_lstm_test.cc | 325 +++++++++++++++++- 2 files changed, 441 insertions(+), 25 deletions(-) diff --git a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc index f1793c13a72..d88f4265231 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc @@ -34,6 +34,13 @@ namespace ops { namespace builtin { namespace unidirectional_sequence_lstm { +struct OpData { + // If the lstm is layer norm. + bool is_layer_norm_lstm; + // The scratch tensor index. + int scratch_tensor_index; +}; + // Input Tensors of size {max_time, n_batch, n_input} constexpr int kInputTensor = 0; @@ -71,6 +78,13 @@ constexpr int kInputActivationStateTensor = 18; // Cell state tensor of size {n_batch, n_cell} constexpr int kInputCellStateTensor = 19; +// Layer norm coefficient tensors of size {n_cell}, representing a diagonal +// matrix. +constexpr int kInputLayerNormCoefficientsTensor = 20; // Optional +constexpr int kForgetLayerNormCoefficientsTensor = 21; // Optional +constexpr int kCellLayerNormCoefficientsTensor = 22; // Optional +constexpr int kOutputLayerNormCoefficientsTensor = 23; // Optional + // Output tensors. 
constexpr int kOutputTensor = 0; @@ -87,19 +101,21 @@ enum TemporaryTensor { }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* scratch_tensor_index = new int(); - context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index); - return scratch_tensor_index; + auto* op_data = new OpData(); + context->AddTensors(context, kNumTemporaryTensors, + &op_data->scratch_tensor_index); + return op_data; } void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast<int*>(buffer); + delete reinterpret_cast<OpData*>(buffer); } // Check that input tensor dimensions matches with each other. TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteNode* node, int n_input, - int n_output, int n_cell) { + int n_output, int n_cell, + bool is_layer_norm_lstm) { const auto* params = reinterpret_cast(node->builtin_data); // Making sure clipping parameters have valid values. @@ -242,6 +258,48 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, ((projection_weights != nullptr) || (projection_bias == nullptr)); TF_LITE_ENSURE(context, projecton_tensors_consistent == true); + if (is_layer_norm_lstm) { + const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor( + context, node, kInputLayerNormCoefficientsTensor); + if (use_cifg) { + TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr); + } else { + TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr); + TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1); + TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->type, + kTfLiteFloat32); + } + + const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor( + context, node, kForgetLayerNormCoefficientsTensor); + TF_LITE_ENSURE(context, forget_layer_norm_coefficients != nullptr); + TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1); + 
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->type, + kTfLiteFloat32); + + const TfLiteTensor* cell_layer_norm_coefficients = GetOptionalInputTensor( + context, node, kCellLayerNormCoefficientsTensor); + TF_LITE_ENSURE(context, cell_layer_norm_coefficients != nullptr); + TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->type, + kTfLiteFloat32); + + const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor( + context, node, kOutputLayerNormCoefficientsTensor); + TF_LITE_ENSURE(context, output_layer_norm_coefficients != nullptr); + TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1); + TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->type, + kTfLiteFloat32); + } + return kTfLiteOk; } @@ -249,11 +307,30 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, // Allocate a temporary scratch tensor. Also check that the sizes of the input // tensors match each other. TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data); + OpData* op_data = reinterpret_cast<OpData*>(node->user_data); + const int scratch_tensor_index = op_data->scratch_tensor_index; // Check we have all the inputs and outputs we need. 
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 20); + bool is_layer_norm_lstm = false; + if (node->inputs->size == 24) { + const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor( + context, node, kForgetLayerNormCoefficientsTensor); + if (forget_layer_norm_coefficients == nullptr) { + is_layer_norm_lstm = false; + } else { + is_layer_norm_lstm = true; + } + } else if (node->inputs->size == 20) { + // This is deprecated and is only kept here for backward compatibility. + is_layer_norm_lstm = false; + } else { + context->ReportError( + context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs", + node->inputs->size); + return kTfLiteError; + } TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + op_data->is_layer_norm_lstm = is_layer_norm_lstm; // Inferring batch size, number of outputs and sequence length and // number of cells from the input tensors. @@ -281,8 +358,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int n_output = recurrent_to_output_weights->dims->data[1]; // Check that input tensor dimensions matches with each other. - TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input, - n_output, n_cell)); + TF_LITE_ENSURE_OK(context, + CheckInputTensorDimensions(context, node, n_input, n_output, + n_cell, is_layer_norm_lstm)); // Get the pointer to output, activation_state and cell_state buffer tensors. TfLiteTensor* output = GetOutput(context, node, kOutputTensor); @@ -310,7 +388,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } else { node->temporaries = TfLiteIntArrayCreate(1); } - node->temporaries->data[0] = *scratch_tensor_index; + node->temporaries->data[0] = scratch_tensor_index; // Create a scratch buffer tensor. 
TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer); @@ -336,7 +414,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Allocate temporary tensors to store quantized values of input, // activation_state and cell_state tensors. node->temporaries->data[kInputQuantized] = - *scratch_tensor_index + kInputQuantized; + scratch_tensor_index + kInputQuantized; TfLiteTensor* input_quantized = GetTemporary(context, node, kInputQuantized); input_quantized->type = input_to_output_weights->type; @@ -347,7 +425,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { input_quantized_size)); } node->temporaries->data[kOutputStateQuantized] = - *scratch_tensor_index + kOutputStateQuantized; + scratch_tensor_index + kOutputStateQuantized; TfLiteTensor* activation_state_quantized = GetTemporary(context, node, kOutputStateQuantized); activation_state_quantized->type = input_to_output_weights->type; @@ -361,7 +439,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { activation_state_quantized_size)); } node->temporaries->data[kCellStateQuantized] = - *scratch_tensor_index + kCellStateQuantized; + scratch_tensor_index + kCellStateQuantized; TfLiteTensor* cell_state_quantized = GetTemporary(context, node, kCellStateQuantized); cell_state_quantized->type = input_to_output_weights->type; @@ -380,7 +458,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // different matrices (which requires multiplying the scaling factors with // the scaling factor of the matrix). 
node->temporaries->data[kScalingFactors] = - *scratch_tensor_index + kScalingFactors; + scratch_tensor_index + kScalingFactors; TfLiteTensor* scaling_factors = GetTemporary(context, node, kScalingFactors); scaling_factors->type = kTfLiteFloat32; @@ -393,7 +471,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { scaling_factors_size)); } node->temporaries->data[kProductScalingFactors] = - *scratch_tensor_index + kProductScalingFactors; + scratch_tensor_index + kProductScalingFactors; TfLiteTensor* prod_scaling_factors = GetTemporary(context, node, kProductScalingFactors); prod_scaling_factors->type = kTfLiteFloat32; @@ -410,7 +488,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Allocate a temporary tensor to store the recovered cell weights. Since // this is used for diagonal matrices, only need to store n_cell values. node->temporaries->data[kRecoveredCellWeights] = - *scratch_tensor_index + kRecoveredCellWeights; + scratch_tensor_index + kRecoveredCellWeights; TfLiteTensor* recovered_cell_weights = GetTemporary(context, node, kRecoveredCellWeights); recovered_cell_weights->type = kTfLiteFloat32; @@ -432,6 +510,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const auto* params = reinterpret_cast( node->builtin_data); + const OpData* op_data = reinterpret_cast(node->user_data); + const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm; const bool time_major = params->time_major; const TfLiteTensor* input = GetInput(context, node, kInputTensor); @@ -481,6 +561,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* cell_state = GetVariableInput(context, node, kInputCellStateTensor); + const TfLiteTensor* input_layer_norm_coefficients = + is_layer_norm_lstm ? GetOptionalInputTensor( + context, node, kInputLayerNormCoefficientsTensor) + : nullptr; + const TfLiteTensor* forget_layer_norm_coefficients = + is_layer_norm_lstm + ? 
GetInput(context, node, kForgetLayerNormCoefficientsTensor) + : nullptr; + const TfLiteTensor* cell_layer_norm_coefficients = + is_layer_norm_lstm + ? GetInput(context, node, kCellLayerNormCoefficientsTensor) + : nullptr; + const TfLiteTensor* output_layer_norm_coefficients = + is_layer_norm_lstm + ? GetInput(context, node, kOutputLayerNormCoefficientsTensor) + : nullptr; + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // Copy out the LSTM specific params so they can be passed in the function. @@ -497,10 +594,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, - /*input_layer_norm_coefficients=*/nullptr, - /*forget_layer_norm_coefficients=*/nullptr, - /*cell_layer_norm_coefficients=*/nullptr, - /*output_layer_norm_coefficients=*/nullptr, + input_layer_norm_coefficients, forget_layer_norm_coefficients, + cell_layer_norm_coefficients, output_layer_norm_coefficients, /*aux_input=*/nullptr, /*aux_input_to_input_weights=*/nullptr, /*aux_input_to_forget_weights=*/nullptr, @@ -529,10 +624,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, - /*input_layer_norm_coefficients=*/nullptr, - /*forget_layer_norm_coefficients=*/nullptr, - /*cell_layer_norm_coefficients=*/nullptr, - /*output_layer_norm_coefficients=*/nullptr, + input_layer_norm_coefficients, forget_layer_norm_coefficients, + cell_layer_norm_coefficients, output_layer_norm_coefficients, /*aux_input=*/nullptr, /*aux_input_to_input_weights=*/nullptr, /*aux_input_to_forget_weights=*/nullptr, diff --git a/tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc 
b/tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc index c9f9f158fd2..ce2b070fb37 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc @@ -37,7 +37,8 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { bool use_projection_bias, float cell_clip, float proj_clip, const std::vector>& input_shapes, - const TensorType& weights_type = TensorType_FLOAT32) + const TensorType& weights_type = TensorType_FLOAT32, + bool is_layer_norm = false) : n_batch_(n_batch), n_input_(n_input), n_cell_(n_cell), @@ -108,6 +109,22 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, /*is_variable=*/true); + // Layer norm weights. + if (is_layer_norm) { + if (use_cifg) { + input_layer_norm_coefficients_ = AddNullInput(); + } else { + input_layer_norm_coefficients_ = + AddLayerNormCoeffsTensor(20, input_shapes); + } + forget_layer_norm_coefficients_ = + AddLayerNormCoeffsTensor(21, input_shapes); + cell_layer_norm_coefficients_ = + AddLayerNormCoeffsTensor(22, input_shapes); + output_layer_norm_coefficients_ = + AddLayerNormCoeffsTensor(23, input_shapes); + } + output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, @@ -187,6 +204,22 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { PopulateTensor(projection_bias_, f); } + void SetInputLayerNormCoefficients(std::vector f) { + PopulateTensor(input_layer_norm_coefficients_, f); + } + + void SetForgetLayerNormCoefficients(std::vector f) { + PopulateTensor(forget_layer_norm_coefficients_, f); + } + + void SetCellLayerNormCoefficients(std::vector f) { + PopulateTensor(cell_layer_norm_coefficients_, f); + } + + void SetOutputLayerNormCoefficients(std::vector f) { + PopulateTensor(output_layer_norm_coefficients_, f); + } + void SetInput(int offset, const float* begin, const float* end) { PopulateTensor(input_, offset, 
const_cast<float*>(begin), const_cast<float*>(end)); @@ -227,6 +260,11 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { int input_activation_state_; int input_cell_state_; + int input_layer_norm_coefficients_; + int forget_layer_norm_coefficients_; + int cell_layer_norm_coefficients_; + int output_layer_norm_coefficients_; + int output_; int n_batch_; @@ -234,6 +272,16 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { int n_cell_; int n_output_; int sequence_length_; + + private: + int AddLayerNormCoeffsTensor( + int tensor_index, const std::vector<std::vector<int>>& input_shapes) { + if (input_shapes[tensor_index][0] != 0) { + return AddInput(TensorType_FLOAT32); + } else { + return AddNullInput(); + } + } }; // The hybrid model has quantized weights. @@ -2403,6 +2451,281 @@ TEST_F(NoCifgPeepholeProjectionAndBiasClippingLstmTest, LstmBlackBoxTest) { VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } +class LayerNormUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel { + public: + LayerNormUnidirectionalLSTMOpModel( + int n_batch, int n_input, int n_cell, int n_output, int sequence_length, + bool time_major, bool use_cifg, bool use_peephole, + bool use_projection_weights, bool use_projection_bias, float cell_clip, + float proj_clip, const std::vector<std::vector<int>>& input_shapes, + const TensorType& weights_type = TensorType_FLOAT32) + : UnidirectionalLSTMOpModel( + n_batch, n_input, n_cell, n_output, sequence_length, time_major, + use_cifg, use_peephole, use_projection_weights, use_projection_bias, + cell_clip, proj_clip, input_shapes, weights_type, true) {} +}; + +class BaseLayerNormLstmTest : public ::testing::Test { + protected: + // Weights of the LSTM model. Some are optional. 
+ std::vector<float> input_to_input_weights_; + std::vector<float> input_to_cell_weights_; + std::vector<float> input_to_forget_weights_; + std::vector<float> input_to_output_weights_; + std::vector<float> input_gate_bias_; + std::vector<float> cell_gate_bias_; + std::vector<float> forget_gate_bias_; + std::vector<float> output_gate_bias_; + std::vector<float> recurrent_to_input_weights_; + std::vector<float> recurrent_to_cell_weights_; + std::vector<float> recurrent_to_forget_weights_; + std::vector<float> recurrent_to_output_weights_; + std::vector<float> cell_to_input_weights_; + std::vector<float> cell_to_forget_weights_; + std::vector<float> cell_to_output_weights_; + std::vector<float> projection_weights_; + std::vector<float> projection_bias_; + std::vector<float> input_layer_norm_coefficients_; + std::vector<float> forget_layer_norm_coefficients_; + std::vector<float> cell_layer_norm_coefficients_; + std::vector<float> output_layer_norm_coefficients_; + + // LSTM input is stored as num_batch x num_inputs vector. + std::vector<std::vector<float>> lstm_input_; + // LSTM output is stored as num_batch x num_outputs vector. + std::vector<std::vector<float>> lstm_golden_output_; + + // Compares output up to tolerance to the result of the lstm given the input. + void VerifyGoldens(const std::vector<std::vector<float>>& input, + const std::vector<std::vector<float>>& output, + UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5) { + const int num_batches = input.size(); + ASSERT_GT(num_batches, 0); + const int num_inputs = lstm->num_inputs(); + ASSERT_GT(num_inputs, 0); + const int input_sequence_size = input[0].size() / num_inputs; + EXPECT_GT(input_sequence_size, 0); + // Feed the whole sequence as input. 
+ for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* batch_start = input[b].data() + i * num_inputs; + const float* batch_end = batch_start + num_inputs; + + lstm->SetInput(((i * num_batches) + b) * num_inputs, batch_start, + batch_end); + } + } + + lstm->Invoke(); + + const int num_outputs = lstm->num_outputs(); + EXPECT_GT(num_outputs, 0); + std::vector expected; + + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* golden_start_batch = output[b].data() + i * num_outputs; + const float* golden_end_batch = golden_start_batch + num_outputs; + + expected.insert(expected.end(), golden_start_batch, golden_end_batch); + } + } + EXPECT_THAT(lstm->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } +}; + +class CifgPeepholeNoProjectionNoClippingLayerNormLstmTest + : public BaseLayerNormLstmTest { + void SetUp() override { + input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726, + 0.05100781, 0.04717243, 0.48944736, + -0.38535351, -0.17212132}; + + input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, + 0.24407166, 0.33826375}; + + input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_cell_weights_ = { + 0.54066205, -0.32668582, -0.43562764, -0.56094903, + 0.42957711, 0.01841056, -0.32764608, -0.33027974, + -0.10826075, 0.20675004, 0.19069612, -0.03026325, + -0.54532051, 0.33003211, 0.44901288, 0.21193194}; + + recurrent_to_forget_weights_ = { + -0.13832897, -0.0515101, -0.2359007, -0.16661474, + -0.14340827, 0.36986142, 0.23414481, 0.55899, + 0.10798943, -0.41174671, 0.17751795, -0.34484994, + -0.35874045, -0.11352962, 0.27268326, 0.54058349}; + + 
recurrent_to_output_weights_ = { + 0.41613156, 0.42610586, -0.16495961, -0.5663873, + 0.30579174, -0.05115908, -0.33941799, 0.23364776, + 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}; + + cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408, + 0.31544167}; + cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703, + -0.77109635}; + + input_layer_norm_coefficients_ = {0.1, 0.2, 0.3, 0.5}; + forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3}; + cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8}; + output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.102089, 0.00653987, 0.0515139, -0.0630045, + -0.173317, 0.0109206, 0.0903292, -0.109497, + -0.23827, 0.0119514, 0.119525, -0.12748}}; + } +}; + +TEST_F(CifgPeepholeNoProjectionNoClippingLayerNormLstmTest, + LayerNormLstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. 
+ const int n_cell = 4; + const int n_output = 4; + const int sequence_length = 3; + + LayerNormUnidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, + /*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + {0}, // input_layer_norm_coefficient tensor + {n_cell}, // forget_layer_norm_coefficient tensor + {n_cell}, // cell_layer_norm_coefficient tensor + {n_cell}, // output_layer_norm_coefficient tensor + }); + + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + 
lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients_); + lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_); + lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients_); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} + +TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, + NonLayerNormLstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + const int sequence_length = 3; + + LayerNormUnidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, + /*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + {0}, // input_layer_norm_coefficient tensor + {0}, // forget_layer_norm_coefficient tensor + {0}, // 
cell_layer_norm_coefficient tensor + {0}, // output_layer_norm_coefficient tensor + }); + + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} + } // namespace } // namespace tflite