From b13a153d077046fc0c823400c176ce39f41ce94d Mon Sep 17 00:00:00 2001
From: Robert David <lrdx@google.com>
Date: Fri, 17 Jul 2020 18:06:19 -0700
Subject: [PATCH] Separate "is layer norm" and "has layer norm tensors"
 parameters of LSTMOpModel.

PiperOrigin-RevId: 321890450
Change-Id: Ie5a07786688bd1e3e2914362e78735fee29df093
---
 tensorflow/lite/kernels/lstm_test.cc | 161 +++++++++++++--------------
 1 file changed, 80 insertions(+), 81 deletions(-)

diff --git a/tensorflow/lite/kernels/lstm_test.cc b/tensorflow/lite/kernels/lstm_test.cc
index 1a42d637c08..754aaba9319 100644
--- a/tensorflow/lite/kernels/lstm_test.cc
+++ b/tensorflow/lite/kernels/lstm_test.cc
@@ -40,8 +40,8 @@ class LSTMOpModel : public SingleOpModel {
               bool use_peephole, bool use_projection_weights,
               bool use_projection_bias, float cell_clip, float proj_clip,
               const std::vector<std::vector<int>>& input_shapes,
-              const TensorType weight_type, bool is_layer_norm,
-              bool asymmetric_quantize_inputs)
+              const TensorType weight_type, bool model_has_legacy_20_inputs,
+              bool is_layer_norm, bool asymmetric_quantize_inputs)
       : n_batch_(n_batch),
         n_input_(n_input),
         n_cell_(n_cell),
@@ -111,23 +111,19 @@ class LSTMOpModel : public SingleOpModel {
         AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_cell_}}, true);
 
     // Layer norm weights.
-    if (is_layer_norm) {
-      const int kInputLayerNormCoeffsIndex = 20;
-      const int kForgetLayerNormCoeffsIndex = 21;
-      const int kCellLayerNormCoeffsIndex = 22;
-      const int kOutputLayerNormCoeffsIndex = 23;
+    if (!model_has_legacy_20_inputs) {
       if (use_cifg) {
         input_layer_norm_coefficients_ = AddNullInput();
       } else {
         input_layer_norm_coefficients_ =
-            AddLayerNormCoeffsTensor(kInputLayerNormCoeffsIndex, input_shapes);
+            is_layer_norm ? AddInput(TensorType_FLOAT32) : AddNullInput();
       }
       forget_layer_norm_coefficients_ =
-          AddLayerNormCoeffsTensor(kForgetLayerNormCoeffsIndex, input_shapes);
+          is_layer_norm ? AddInput(TensorType_FLOAT32) : AddNullInput();
       cell_layer_norm_coefficients_ =
-          AddLayerNormCoeffsTensor(kCellLayerNormCoeffsIndex, input_shapes);
+          is_layer_norm ? AddInput(TensorType_FLOAT32) : AddNullInput();
       output_layer_norm_coefficients_ =
-          AddLayerNormCoeffsTensor(kOutputLayerNormCoeffsIndex, input_shapes);
+          is_layer_norm ? AddInput(TensorType_FLOAT32) : AddNullInput();
     }
 
     output_ = AddOutput(TensorType_FLOAT32);
@@ -277,15 +273,6 @@ class LSTMOpModel : public SingleOpModel {
   int n_output_;
 
  private:
-  int AddLayerNormCoeffsTensor(
-      int tensor_index, const std::vector<std::vector<int>>& input_shapes) {
-    if (input_shapes[tensor_index][0] != 0) {
-      return AddInput(TensorType_FLOAT32);
-    } else {
-      return AddNullInput();
-    }
-  }
-
   template <typename T>
   void PopulateTensor(int index, const std::vector<T>& data) {
     // Nothing to do if tensor is an optional input or if data vector is empty.
@@ -504,16 +491,17 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
                        {0, 0},  // projection_weight tensor
                        {0},     // projection_bias tensor
                    },
-                   /*weight_type=*/TensorType_FLOAT32, /*is_layer_norm=*/false,
+                   /*weight_type=*/TensorType_FLOAT32,
+                   /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
                    /*asymmetric_quantize_inputs=*/false);
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }
 
-class NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest
+class NoCifgNoPeepholeNoProjectionNoClippingNoLayerNormLstmTest
     : public NoCifgNoPeepholeNoProjectionNoClippingLstmTest {};
 
-TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingNoLayerNormLstmTest,
        LstmBlackBoxTest) {
   const int n_batch = 1;
   const int n_input = 2;
@@ -559,7 +547,9 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,
                        {0},  // cell_layer_norm_coefficient tensor
                        {0},  // output_layer_norm_coefficient tensor
                    },
-                   /*weight_type=*/TensorType_FLOAT32, /*is_layer_norm=*/true,
+                   /*weight_type=*/TensorType_FLOAT32,
+                   /*model_has_legacy_20_inputs=*/false,
+                   /*is_layer_norm=*/false,
                    /*asymmetric_quantize_inputs=*/false);
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
@@ -607,7 +597,8 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
                        {0, 0},  // projection_weight tensor
                        {0},     // projection_bias tensor
                    },
-                   /*weight_type=*/TensorType_UINT8, /*is_layer_norm=*/false,
+                   /*weight_type=*/TensorType_UINT8,
+                   /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
                    /*asymmetric_quantize_inputs=*/GetParam());
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
@@ -658,7 +649,8 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
                        {0, 0},  // projection_weight tensor
                        {0},     // projection_bias tensor
                    },
-                   /*weight_type=*/TensorType_INT8, /*is_layer_norm=*/false,
+                   /*weight_type=*/TensorType_INT8,
+                   /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
                    /*asymmetric_quantize_inputs=*/GetParam());
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
@@ -749,7 +741,8 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
                        {0, 0},  // projection_weight tensor
                        {0},     // projection_bias tensor
                    },
-                   /*weight_type=*/TensorType_FLOAT32, /*is_layer_norm=*/false,
+                   /*weight_type=*/TensorType_FLOAT32,
+                   /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
                    /*asymmetric_quantize_inputs=*/false);
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
@@ -797,7 +790,8 @@ TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmTest,
                        {0, 0},  // projection_weight tensor
                        {0},     // projection_bias tensor
                    },
-                   /*weight_type=*/TensorType_UINT8, /*is_layer_norm=*/false,
+                   /*weight_type=*/TensorType_UINT8,
+                   /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
                    /*asymmetric_quantize_inputs=*/GetParam());
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
@@ -846,7 +840,8 @@ TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
                        {0, 0},  // projection_weight tensor
                        {0},     // projection_bias tensor
                    },
-                   /*weight_type=*/TensorType_INT8, /*is_layer_norm=*/false,
+                   /*weight_type=*/TensorType_INT8,
+                   /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
                    /*asymmetric_quantize_inputs=*/GetParam());
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
@@ -1487,7 +1482,8 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) {
                        {n_output, n_cell},  // projection_weight tensor
                        {0},                 // projection_bias tensor
                    },
-                   /*weight_type=*/TensorType_FLOAT32, /*is_layer_norm=*/false,
+                   /*weight_type=*/TensorType_FLOAT32,
+                   /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
                    /*asymmetric_quantize_inputs=*/false);
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
@@ -1534,7 +1530,8 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLstmTest,
                        {n_output, n_cell},  // projection_weight tensor
                        {0},                 // projection_bias tensor
                    },
-                   /*weight_type=*/TensorType_UINT8, /*is_layer_norm=*/false,
+                   /*weight_type=*/TensorType_UINT8,
+                   /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
                    /*asymmetric_quantize_inputs=*/GetParam());
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
@@ -1583,7 +1580,8 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLstmInt8Test,
                        {n_output, n_cell},  // projection_weight tensor
                        {0},                 // projection_bias tensor
                    },
-                   /*weight_type=*/TensorType_INT8, /*is_layer_norm=*/false,
+                   /*weight_type=*/TensorType_INT8,
+                   /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
                    /*asymmetric_quantize_inputs=*/GetParam());
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.0015);
@@ -1703,8 +1701,8 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
           {n_cell},  // cell_layer_norm_coefficient tensor
           {n_cell},  // output_layer_norm_coefficient tensor
       },
-      /*weight_type=*/TensorType_FLOAT32, /*is_layer_norm=*/true,
-      /*asymmetric_quantize_inputs=*/false);
+      /*weight_type=*/TensorType_FLOAT32, /*model_has_legacy_20_inputs=*/false,
+      /*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/false);
 
   // Verify the final output.
   lstm_golden_output_ = {{
@@ -1774,8 +1772,8 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
           {n_cell},  // cell_layer_norm_coefficient tensor
           {n_cell},  // output_layer_norm_coefficient tensor
       },
-      /*weight_type=*/TensorType_UINT8, /*is_layer_norm=*/true,
-      /*asymmetric_quantize_inputs=*/GetParam());
+      /*weight_type=*/TensorType_UINT8, /*model_has_legacy_20_inputs=*/false,
+      /*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/GetParam());
 
   lstm_golden_output_ = {{
                              // Batch0: 3 (input_sequence_size) * 3 (n_output)
@@ -1847,8 +1845,8 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
           {n_cell},  // cell_layer_norm_coefficient tensor
           {n_cell},  // output_layer_norm_coefficient tensor
       },
-      /*weight_type=*/TensorType_INT8, /*is_layer_norm=*/true,
-      /*asymmetric_quantize_inputs=*/GetParam());
+      /*weight_type=*/TensorType_INT8, /*model_has_legacy_20_inputs=*/false,
+      /*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/GetParam());
 
   // Goldens are calculated from weight_type=TensorType_FLOAT32.
   lstm_golden_output_ = {{
@@ -1961,8 +1959,8 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
           {n_cell},  // cell_layer_norm_coefficient tensor
           {n_cell},  // output_layer_norm_coefficient tensor
       },
-      /*weight_type=*/TensorType_FLOAT32, /*is_layer_norm=*/true,
-      /*asymmetric_quantize_inputs=*/false);
+      /*weight_type=*/TensorType_FLOAT32, /*model_has_legacy_20_inputs=*/false,
+      /*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/false);
 
   // Verify the final output.
   lstm_golden_output_ = {
@@ -2032,8 +2030,8 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
           {n_cell},  // cell_layer_norm_coefficient tensor
           {n_cell},  // output_layer_norm_coefficient tensor
       },
-      /*weight_type=*/TensorType_UINT8, /*is_layer_norm=*/true,
-      /*asymmetric_quantize_inputs=*/GetParam());
+      /*weight_type=*/TensorType_UINT8, /*model_has_legacy_20_inputs=*/false,
+      /*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/GetParam());
 
   // Verify the final output.
   lstm_golden_output_ = {
@@ -2104,8 +2102,8 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
           {n_cell},  // cell_layer_norm_coefficient tensor
           {n_cell},  // output_layer_norm_coefficient tensor
       },
-      /*weight_type=*/TensorType_INT8, /*is_layer_norm=*/true,
-      /*asymmetric_quantize_inputs=*/GetParam());
+      /*weight_type=*/TensorType_INT8, /*model_has_legacy_20_inputs=*/false,
+      /*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/GetParam());
 
   // Goldens are results using FLOAT32 inference.
   lstm_golden_output_ = {{
@@ -3278,41 +3276,6 @@ TEST(LSTMOpModel, InvalidTypeTest) {
   const int n_cell = 4;
   const int n_output = 4;
 
-  EXPECT_DEATH(LSTMOpModel lstm(
-                   n_batch, n_input, n_cell, n_output,
-                   /*use_cifg=*/false, /*use_peephole=*/false,
-                   /*use_projection_weights=*/false,
-                   /*use_projection_bias=*/false,
-                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
-                   {
-                       {n_batch, n_input},  // input tensor
-
-                       {n_cell, n_input},  // input_to_input_weight tensor
-                       {n_cell, n_input},  // input_to_forget_weight tensor
-                       {n_cell, n_input},  // input_to_cell_weight tensor
-                       {n_cell, n_input},  // input_to_output_weight tensor
-
-                       {n_cell, n_output},  // recurrent_to_input_weight_tensor
-                       {n_cell, n_output},  // recurrent_to_forget_weight_tensor
-                       {n_cell, n_output},  // recurrent_to_cell_weight_tensor
-                       {n_cell, n_output},  // recurrent_to_output_weight_tensor
-
-                       {0},  // cell_to_input_weight tensor
-                       {0},  // cell_to_forget_weight tensor
-                       {0},  // cell_to_output_weight tensor
-
-                       {n_cell},  // input_gate_bias tensor
-                       {n_cell},  // forget_gate_bias tensor
-                       {n_cell},  // cell_gate_bias tensor
-                       {n_cell},  // output_gate_bias tensor
-
-                       {0, 0},  // projection_weight tensor
-                       {0},     // projection_bias tensor
-                   },
-                   /*weight_type=*/TensorType_INT32, /*is_layer_norm=*/false,
-                   /*asymmetric_quantize_inputs=*/false),
-               "");
-
   EXPECT_DEATH(
       LSTMOpModel lstm(
           n_batch, n_input, n_cell, n_output,
@@ -3345,9 +3308,45 @@ TEST(LSTMOpModel, InvalidTypeTest) {
               {0, 0},  // projection_weight tensor
               {0},     // projection_bias tensor
           },
-          /*weight_type=*/TensorType_COMPLEX64, /*is_layer_norm=*/false,
-          /*asymmetric_quantize_inputs=*/false),
+          /*weight_type=*/TensorType_INT32, /*model_has_legacy_20_inputs=*/true,
+          /*is_layer_norm=*/false, /*asymmetric_quantize_inputs=*/false),
       "");
+
+  EXPECT_DEATH(LSTMOpModel lstm(
+                   n_batch, n_input, n_cell, n_output,
+                   /*use_cifg=*/false, /*use_peephole=*/false,
+                   /*use_projection_weights=*/false,
+                   /*use_projection_bias=*/false,
+                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+                   {
+                       {n_batch, n_input},  // input tensor
+
+                       {n_cell, n_input},  // input_to_input_weight tensor
+                       {n_cell, n_input},  // input_to_forget_weight tensor
+                       {n_cell, n_input},  // input_to_cell_weight tensor
+                       {n_cell, n_input},  // input_to_output_weight tensor
+
+                       {n_cell, n_output},  // recurrent_to_input_weight_tensor
+                       {n_cell, n_output},  // recurrent_to_forget_weight_tensor
+                       {n_cell, n_output},  // recurrent_to_cell_weight_tensor
+                       {n_cell, n_output},  // recurrent_to_output_weight_tensor
+
+                       {0},  // cell_to_input_weight tensor
+                       {0},  // cell_to_forget_weight tensor
+                       {0},  // cell_to_output_weight tensor
+
+                       {n_cell},  // input_gate_bias tensor
+                       {n_cell},  // forget_gate_bias tensor
+                       {n_cell},  // cell_gate_bias tensor
+                       {n_cell},  // output_gate_bias tensor
+
+                       {0, 0},  // projection_weight tensor
+                       {0},     // projection_bias tensor
+                   },
+                   /*weight_type=*/TensorType_COMPLEX64,
+                   /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
+                   /*asymmetric_quantize_inputs=*/false),
+               "");
 }
 #endif