From 7e48fd7fcf0ef665ef8a4897bd45cb13fb24fc26 Mon Sep 17 00:00:00 2001
From: Robert David <lrdx@google.com>
Date: Fri, 8 Nov 2019 18:17:08 -0800
Subject: [PATCH] Remove the normalization_epsilon parameter of
 MeanStddevNormalization; make it a compile-time constant.

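The epsilon was the same value at every call site (kLayerNormEpsilon = 1e-8 in
lstm_eval.cc and the calibration LSTM, kNormalizationEpsilon = 1e-8 in the
unit test), so it is folded into the portable kernel as a constexpr
kNormalizationConstant and dropped from the public signature. For reference,
below is a minimal, self-contained sketch of the reference kernel and the new
call signature after this change; the accumulation and output loops are
reconstructed from context not shown in the hunks, and the main() driver is
only an assumed usage example, not part of the patch:

    #include <cmath>
    #include <cstdio>

    namespace tflite {
    namespace tensor_utils {

    // Layer norm for each batch. The epsilon used when the variance is zero
    // is now a compile-time constant instead of a caller-supplied parameter.
    void PortableMeanStddevNormalization(const float* input_vector,
                                         float* output_vector, int v_size,
                                         int n_batch) {
      for (int batch = 0; batch < n_batch; ++batch) {
        float sum = 0.0f;
        float sum_sq = 0.0f;
        for (int i = 0; i < v_size; ++i) {
          sum += input_vector[i];
          sum_sq += input_vector[i] * input_vector[i];
        }
        const float mean = sum / v_size;
        float stddev_inv = 0.0f;
        const float variance = sum_sq / v_size - mean * mean;
        if (variance == 0) {
          // Previously the normalization_epsilon argument; now fixed at 1e-8.
          constexpr float kNormalizationConstant = 1e-8;
          stddev_inv = 1.0f / std::sqrt(kNormalizationConstant);
        } else {
          stddev_inv = 1.0f / std::sqrt(variance);
        }
        for (int i = 0; i < v_size; ++i) {
          output_vector[i] = (input_vector[i] - mean) * stddev_inv;
        }
        input_vector += v_size;
        output_vector += v_size;
      }
    }

    }  // namespace tensor_utils
    }  // namespace tflite

    // Assumed usage example: callers simply drop the trailing epsilon argument.
    int main() {
      const float input[4] = {0.1f, 0.2f, 0.3f, 0.4f};
      float output[4];
      tflite::tensor_utils::PortableMeanStddevNormalization(input, output,
                                                            /*v_size=*/4,
                                                            /*n_batch=*/1);
      std::printf("%f %f %f %f\n", output[0], output[1], output[2], output[3]);
      return 0;
    }
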
PiperOrigin-RevId: 279430358
Change-Id: I25aee6a065617e94b5b8d0a20f9cb2d3dce62314
---
 .../internal/optimized/neon_tensor_utils.h    |  6 ++--
 .../internal/optimized/sse_tensor_utils.h     |  6 ++--
 .../reference/portable_tensor_utils.cc        |  5 +--
 .../reference/portable_tensor_utils.h         |  6 ++--
 .../reference/portable_tensor_utils_impl.h    |  3 +-
 .../lite/kernels/internal/tensor_utils.h      |  4 +--
 .../kernels/internal/tensor_utils_test.cc     |  4 +--
 tensorflow/lite/kernels/lstm_eval.cc          | 31 ++++++-------------
 .../calibration/builtin_logging_ops/lstm.cc   | 15 +++------
 9 files changed, 27 insertions(+), 53 deletions(-)

diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
index 457a15c1b5a..b4d0581f85a 100644
--- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -254,10 +254,8 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
 }
 
 void MeanStddevNormalization(const float* input_vector, float* output_vector,
-                             int v_size, int n_batch,
-                             float normalization_epsilon) {
-  PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch,
-                                  normalization_epsilon);
+                             int v_size, int n_batch) {
+  PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
 }
 
 }  // namespace tensor_utils
diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h
index f962d4a34ea..bc5858002c8 100644
--- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h
@@ -264,10 +264,8 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
 }
 
 void MeanStddevNormalization(const float* input_vector, float* output_vector,
-                             int v_size, int n_batch,
-                             float normalization_epsilon) {
-  PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch,
-                                  normalization_epsilon);
+                             int v_size, int n_batch) {
+  PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
 }
 
 }  // namespace tensor_utils
diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
index 7570dcf78f0..6e287718d3e 100644
--- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -627,7 +627,7 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector,
 
 void PortableMeanStddevNormalization(const float* input_vector,
                                      float* output_vector, int v_size,
-                                     int n_batch, float normalization_epsilon) {
+                                     int n_batch) {
   for (int batch = 0; batch < n_batch; ++batch) {
     float sum = 0.0f;
     float sum_sq = 0.0f;
@@ -639,7 +639,8 @@ void PortableMeanStddevNormalization(const float* input_vector,
     float stddev_inv = 0.0f;
     const float variance = sum_sq / v_size - mean * mean;
     if (variance == 0) {
-      stddev_inv = 1.0f / std::sqrt(normalization_epsilon);
+      constexpr float kNormalizationConstant = 1e-8;
+      stddev_inv = 1.0f / std::sqrt(kNormalizationConstant);
     } else {
       stddev_inv = 1.0f / std::sqrt(variance);
     }
diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
index 969be186d5f..51d64bfc328 100644
--- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -258,10 +258,8 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
 }
 
 void MeanStddevNormalization(const float* input_vector, float* output_vector,
-                             int v_size, int n_batch,
-                             float normalization_epsilon) {
-  PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch,
-                                  normalization_epsilon);
+                             int v_size, int n_batch) {
+  PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
 }
 
 }  // namespace tensor_utils
diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h
index d8dc86d59da..de878da5cb3 100644
--- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h
+++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h
@@ -200,10 +200,9 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector,
                                 int output_size, int reduction_size);
 
 // Layer norm for each batch.
-// normalization_epsilon is added to avoid divergence.
 void PortableMeanStddevNormalization(const float* input_vector,
                                      float* output_vector, int v_size,
-                                     int n_batch, float normalization_epsilon);
+                                     int n_batch);
 
 }  // namespace tensor_utils
 }  // namespace tflite
diff --git a/tensorflow/lite/kernels/internal/tensor_utils.h b/tensorflow/lite/kernels/internal/tensor_utils.h
index a1d55d11152..2c685518408 100644
--- a/tensorflow/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/tensor_utils.h
@@ -439,10 +439,8 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
                         int output_size, int reduction_size);
 
 // Layer norm for each batch.
-// normalization_epsilon is added to avoid divergence.
 void MeanStddevNormalization(const float* input_vector, float* output_vector,
-                             int v_size, int n_batch,
-                             float normalization_epsilon);
+                             int v_size, int n_batch);
 }  // namespace tensor_utils
 }  // namespace tflite
 
diff --git a/tensorflow/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/lite/kernels/internal/tensor_utils_test.cc
index f54004fc696..79a66c972d1 100644
--- a/tensorflow/lite/kernels/internal/tensor_utils_test.cc
+++ b/tensorflow/lite/kernels/internal/tensor_utils_test.cc
@@ -1466,7 +1466,6 @@ TEST(uKernels, ReductionSumVectorTest) {
 TEST(uKernels, MeanStddevNormalization) {
   constexpr int kVectorSize = 4;
   constexpr int kBatchSize = 8;  // 9, but large mean, small variance fails
-  constexpr float kNormalizationEpsilon = 1e-8;
 
   // None-zero input.
   static float input[kVectorSize * kBatchSize] = {
@@ -1480,8 +1479,7 @@ TEST(uKernels, MeanStddevNormalization) {
       -100.0f,  0.0f,    200.0f,  300.0f,   // large mean, large variance
   };
   float output[kVectorSize * kBatchSize];
-  MeanStddevNormalization(input, output, kVectorSize, kBatchSize,
-                          kNormalizationEpsilon);
+  MeanStddevNormalization(input, output, kVectorSize, kBatchSize);
   const float ksqrt16 = std::sqrt(1.6f);
   const float ksqrt04 = std::sqrt(0.4f);
   const std::vector<float> expected_output = {
diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc
index 6e496801363..86cf2cd054c 100644
--- a/tensorflow/lite/kernels/lstm_eval.cc
+++ b/tensorflow/lite/kernels/lstm_eval.cc
@@ -38,11 +38,6 @@ namespace builtin {
 namespace lstm_eval {
 
 namespace {
-
-// Small float to avoid divergence during calculation of deviation for layer
-// norm lstm.
-const float kLayerNormEpsilon = 1e-8;
-
 // Performs an LSTM batch inference step for input specified by input_ptr_batch.
 // The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
 // biases (*_bias_ptr), and buffers (*_scratch), along with additional
@@ -224,9 +219,8 @@ inline void LstmStepWithAuxInput(
           input_gate_scratch);
     }
     if (is_layer_norm_lstm) {
-      tensor_utils::MeanStddevNormalization(input_gate_scratch,
-                                            input_gate_scratch, n_cell, n_batch,
-                                            kLayerNormEpsilon);
+      tensor_utils::MeanStddevNormalization(
+          input_gate_scratch, input_gate_scratch, n_cell, n_batch);
       tensor_utils::VectorBatchVectorCwiseProduct(
           input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
           n_batch, input_gate_scratch);
@@ -245,8 +239,7 @@ inline void LstmStepWithAuxInput(
   }
   if (is_layer_norm_lstm) {
     tensor_utils::MeanStddevNormalization(forget_gate_scratch,
-                                          forget_gate_scratch, n_cell, n_batch,
-                                          kLayerNormEpsilon);
+                                          forget_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
         n_batch, forget_gate_scratch);
@@ -261,7 +254,7 @@ inline void LstmStepWithAuxInput(
                                          n_batch * n_cell, cell_state_ptr);
   if (is_layer_norm_lstm) {
     tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
-                                          n_batch, kLayerNormEpsilon);
+                                          n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
         cell_scratch);
@@ -292,8 +285,7 @@ inline void LstmStepWithAuxInput(
   }
   if (is_layer_norm_lstm) {
     tensor_utils::MeanStddevNormalization(output_gate_scratch,
-                                          output_gate_scratch, n_cell, n_batch,
-                                          kLayerNormEpsilon);
+                                          output_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
         n_batch, output_gate_scratch);
@@ -699,9 +691,8 @@ inline void LstmStepWithAuxInput(
           input_gate_scratch);
     }
     if (is_layer_norm_lstm) {
-      tensor_utils::MeanStddevNormalization(input_gate_scratch,
-                                            input_gate_scratch, n_cell, n_batch,
-                                            kLayerNormEpsilon);
+      tensor_utils::MeanStddevNormalization(
+          input_gate_scratch, input_gate_scratch, n_cell, n_batch);
       tensor_utils::VectorBatchVectorCwiseProduct(
           input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
           n_batch, input_gate_scratch);
@@ -723,8 +714,7 @@ inline void LstmStepWithAuxInput(
   }
   if (is_layer_norm_lstm) {
     tensor_utils::MeanStddevNormalization(forget_gate_scratch,
-                                          forget_gate_scratch, n_cell, n_batch,
-                                          kLayerNormEpsilon);
+                                          forget_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
         n_batch, forget_gate_scratch);
@@ -739,7 +729,7 @@ inline void LstmStepWithAuxInput(
                                          n_batch * n_cell, cell_state_ptr);
   if (is_layer_norm_lstm) {
     tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
-                                          n_batch, kLayerNormEpsilon);
+                                          n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
         cell_scratch);
@@ -775,8 +765,7 @@ inline void LstmStepWithAuxInput(
   }
   if (is_layer_norm_lstm) {
     tensor_utils::MeanStddevNormalization(output_gate_scratch,
-                                          output_gate_scratch, n_cell, n_batch,
-                                          kLayerNormEpsilon);
+                                          output_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
         n_batch, output_gate_scratch);
diff --git a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
index 102967cc936..4e36b7d22c9 100644
--- a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
+++ b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
@@ -32,8 +32,6 @@ namespace builtin {
 
 namespace {
 
-const float kLayerNormEpsilon = 1e-8;
-
 inline void LstmStepWithAuxInput(
     const float* input_ptr_batch, const float* input_to_input_weights_ptr,
     const float* input_to_forget_weights_ptr,
@@ -157,9 +155,8 @@ inline void LstmStepWithAuxInput(
     if (is_layer_norm_lstm) {
       logger->LogTensorValue(intemediate_tensor_indexes[0], input_gate_scratch,
                              n_cell * n_batch);
-      tensor_utils::MeanStddevNormalization(input_gate_scratch,
-                                            input_gate_scratch, n_cell, n_batch,
-                                            kLayerNormEpsilon);
+      tensor_utils::MeanStddevNormalization(
+          input_gate_scratch, input_gate_scratch, n_cell, n_batch);
       tensor_utils::VectorBatchVectorCwiseProduct(
           input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
           n_batch, input_gate_scratch);
@@ -180,8 +177,7 @@ inline void LstmStepWithAuxInput(
     logger->LogTensorValue(intemediate_tensor_indexes[1], forget_gate_scratch,
                            n_cell * n_batch);
     tensor_utils::MeanStddevNormalization(forget_gate_scratch,
-                                          forget_gate_scratch, n_cell, n_batch,
-                                          kLayerNormEpsilon);
+                                          forget_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
         n_batch, forget_gate_scratch);
@@ -198,7 +194,7 @@ inline void LstmStepWithAuxInput(
     logger->LogTensorValue(intemediate_tensor_indexes[2], cell_scratch,
                            n_cell * n_batch);
     tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
-                                          n_batch, kLayerNormEpsilon);
+                                          n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
         cell_scratch);
@@ -231,8 +227,7 @@ inline void LstmStepWithAuxInput(
     logger->LogTensorValue(intemediate_tensor_indexes[3], output_gate_scratch,
                            n_cell * n_batch);
     tensor_utils::MeanStddevNormalization(output_gate_scratch,
-                                          output_gate_scratch, n_cell, n_batch,
-                                          kLayerNormEpsilon);
+                                          output_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
         n_batch, output_gate_scratch);