From 54081bae792cd50b911d2209996c9e7340a345fc Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 16 Oct 2020 10:19:35 -0700
Subject: [PATCH] Register Unidirectional_sequence_lstm logging op in
 calibrator.

PiperOrigin-RevId: 337529861
Change-Id: Ie4d13f8066cf6a75f2baab66e016336a95302c93
---
 .../testdata/unidirectional_sequence_lstm.bin | Bin 0 -> 2752 bytes
 .../lite/tools/optimize/calibration/BUILD     |   1 +
 .../calibration/builtin_logging_ops/lstm.cc   |  43 ++++++++--
 .../calibration/builtin_logging_ops/lstm.h    |   9 +++
 .../tools/optimize/calibration/calibrator.cc  |  20 +++--
 .../optimize/calibration/calibrator_test.cc   |  75 +++++++++++++++---
 .../lite/tools/optimize/operator_property.cc  |   6 +-
 7 files changed, 128 insertions(+), 26 deletions(-)
 create mode 100644 tensorflow/lite/testdata/unidirectional_sequence_lstm.bin

diff --git a/tensorflow/lite/testdata/unidirectional_sequence_lstm.bin b/tensorflow/lite/testdata/unidirectional_sequence_lstm.bin
new file mode 100644
index 0000000000000000000000000000000000000000..42c96d14faabbbaddbaede8aa80d3d200ea7b3af
GIT binary patch
literal 2752
zcmah~y-yQi9Db`+s#dLvwSJ~13=EjG1q&tyH6hBtWDIdIIVwHjOle6u5Tk>C00RRd
zny|Y#m`DhNF^&ujOz7a~;J|1|qv-iO@5gaRvG*oV-@E>P&+q$vFDrz2aCfL1n_eeG
zyNHW~Na8&uPT=<-@am`#?+yq^gg65%;dvGq0J6YNOo*?*GB6KRfI*-G5WqKVtH1&<
z2@C<90PSu5@jUJc;KhfJ{ljnH^yATA+Y$2g-`4zMgnT$cKBMI+L}cQ-ACG<dvYEcc
zAQvO}D-rU|2>Et|JngCc{}*2|LOvWJpNWt!M#xtp<eL$4e-OlB%t#!+YHs$+Hx4;x
ziDO53(x2v(nrH~(KttQ0-X%bYwA?xOZIUU#%-~<S?v=4U1O|aDKp%&I^S}k*B0xKr
zfeg^RVss_XRU3Xe>sT<8y+Z*xhR1gtDS7UUALiT!kcV>li>LHBz;k*MV9uw2(*VzD
zJ8%|Y9_Ik#O9q4YE5J(_?*X=u&@wO!+y)ZB1_F2ulz|+ugUetISO7+VE}(&6*MT{J
z_dycafP*Dq1Q-As@Usfc0(}5XiNr6@>(I}P&8fO}eZnp~CD*ocR*yCAl*@Ki#tgCO
zgMLoyUv1X!wkoCPcHOGh>Jyezot$=|1J`o7t+|C(-1n`y`BvPm*4$hxZqiH%s>MFb
znqjX^yGX9k!^Gx)dtMiI_M7qypybvV%UG#e&unLG+?^`OJD&AGd&C-etUYIiJno}*
zr4l3*@So9dXg{1U<L61OK4!a3e)3T9dAS|W>%==}lRVXDgx^HV`J8o3H%(;mLk@G3
z{j_9+Q!WzA-%xp&_)?gBb*5M;bAm;TzD>N!MrqV_W=gJ8t6Itu$5GbA*!qr`d6s#m
z$UO#3)@x6#$4+S~-|NSj#CItR8P6Z}s4w+&#q&jFm-QxYLcasM=_T_D_DY(+zFWJB
z)~A2kXYFMLWM`$eoB<Q}lGewkTI$n&ut$=CDKPr$TK|aF=M1PC?K%j~z^c|~EWA?~
zOF<`~Q@+OL(>~Xa{q%J}`{`%!onw7ypZAuuuL26H`}XIwKKo7%@4LD^x+kPQS$ori
z&zi%#lpN+vj;yQp(u@C$KKGxRW4=4aPdJrgh|`G~W-how{a(;dC>``V=;skQW?mT^
zYb^EmWW=Tic?;SXu42CqiWkhPJH%PjzGx@xE1VVkAM};P6jGnxNzO6jWxx3!d#3z`
mvXu8h*hW(7D=UgeKI<!)xa66OnNxE4On=l$;CxdTtLq=q0VEj!

literal 0
HcmV?d00001

diff --git a/tensorflow/lite/tools/optimize/calibration/BUILD b/tensorflow/lite/tools/optimize/calibration/BUILD
index 53bd1bb4faf..2f315ca509a 100644
--- a/tensorflow/lite/tools/optimize/calibration/BUILD
+++ b/tensorflow/lite/tools/optimize/calibration/BUILD
@@ -65,6 +65,7 @@ tf_cc_test(
     data = [
         "//tensorflow/lite:testdata/lstm.bin",
         "//tensorflow/lite:testdata/multi_add.bin",
+        "//tensorflow/lite:testdata/unidirectional_sequence_lstm.bin",
     ],
     tags = [
         "tflite_not_portable_android",
diff --git a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
index 040ec86f6e2..bdf27d9a980 100644
--- a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
+++ b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
@@ -461,10 +461,9 @@ struct OpData {
 // Resize the output, state tensors based on the sizes of the input tensors.
 // Allocate a temporary scratch tensor. Also check that the sizes of the input
 // tensors match each other.
-TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
+TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node,
+                       LSTMType lstm_type, Logger* logger,
                        ErrorReporter* error_reporter) {
-  const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
-
   const TfLiteTensor* input;
   TF_LITE_ENSURE_OK(
       context, GetInputSafe(context, node,
@@ -578,6 +577,31 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
     intermediate_tensor_indexes[i] = node->intermediates->data[i];
   }
 
+  TfLiteLSTMParams lstm_params;
+  bool time_major = true;
+  switch (lstm_type) {
+    case LSTMType::kLSTM: {
+      lstm_params = *(static_cast<TfLiteLSTMParams*>(node->builtin_data));
+      time_major = true;
+      break;
+    }
+    case LSTMType::kUnidirectionalSequenceLSTM: {
+      const auto* params = static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
+          node->builtin_data);
+      // Copy out the LSTM-specific params so they can be passed to the
+      // shared EvalCalibration call below.
+      lstm_params.activation = params->activation;
+      lstm_params.cell_clip = params->cell_clip;
+      lstm_params.proj_clip = params->proj_clip;
+      lstm_params.asymmetric_quantize_inputs =
+          params->asymmetric_quantize_inputs;
+      time_major = params->time_major;
+      break;
+    }
+    default:
+      return kTfLiteError;
+  }
+
   switch (input_to_output_weights->type) {
     case kTfLiteFloat32: {
       return EvalCalibration(
@@ -594,9 +618,9 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
           /*aux_input_to_cell_weights=*/nullptr,
           /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
           forget_gate_bias, cell_gate_bias, output_gate_bias,
-          projection_weights, projection_bias, params,
+          projection_weights, projection_bias, &lstm_params,
           /*forward_sequence=*/true,
-          /*time_major=*/true,
+          /*time_major=*/time_major,
           /*output_offset=*/0, scratch_buffer, output_state, cell_state, output,
           logger, intermediate_tensor_indexes, error_reporter);
     }
@@ -613,7 +637,14 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
 TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
                                  Logger* logger,
                                  ErrorReporter* error_reporter) {
-  return lstm_eval(context, node, logger, error_reporter);
+  return lstm_eval(context, node, LSTMType::kLSTM, logger, error_reporter);
+}
+
+TfLiteStatus unidirectional_sequence_lstm_logging_kernel(
+    TfLiteContext* context, TfLiteNode* node, Logger* logger,
+    ErrorReporter* error_reporter) {
+  return lstm_eval(context, node, LSTMType::kUnidirectionalSequenceLSTM, logger,
+                   error_reporter);
 }
 
 }  // namespace builtin
diff --git a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h
index f3306bc0564..0a9e7095507 100644
--- a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h
+++ b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h
@@ -23,9 +23,18 @@ namespace optimize {
 namespace calibration {
 namespace builtin {
 
+enum class LSTMType {
+  kLSTM,
+  kUnidirectionalSequenceLSTM,
+};
+
 TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
                                  Logger* logger, ErrorReporter* error_reporter);
 
+TfLiteStatus unidirectional_sequence_lstm_logging_kernel(
+    TfLiteContext* context, TfLiteNode* node, Logger* logger,
+    ErrorReporter* error_reporter);
+
 }  // namespace builtin
 }  // namespace calibration
 }  // namespace optimize
diff --git a/tensorflow/lite/tools/optimize/calibration/calibrator.cc b/tensorflow/lite/tools/optimize/calibration/calibrator.cc
index be8fad8a221..6cddbc53009 100644
--- a/tensorflow/lite/tools/optimize/calibration/calibrator.cc
+++ b/tensorflow/lite/tools/optimize/calibration/calibrator.cc
@@ -174,13 +174,17 @@ GlobalCalibratorRegistry* GetCalibratorRegistry() {
 // TODO(jianlijianli): extend this to support multiple recipe for the same
 // model.
 logging_kernel_func_ptr GetLoggingEvalFunc(TfLiteContext* context,
-                                           TfLiteNode* node) {
-  const int lstm_number_input = 24;
-  if (node->inputs->size == lstm_number_input) {
-    // LSTM Op.
-    return tflite::optimize::calibration::builtin::lstm_logging_kernel;
+                                           TfLiteNode* node,
+                                           int builtin_op_code) {
+  switch (builtin_op_code) {
+    case BuiltinOperator_LSTM:
+      return tflite::optimize::calibration::builtin::lstm_logging_kernel;
+    case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
+      return tflite::optimize::calibration::builtin::
+          unidirectional_sequence_lstm_logging_kernel;
+    default:
+      return nullptr;
   }
-  return nullptr;
 }
 
 // A wrapper implementation for |TfLiteRegistration.invoke| that logs inputs,
@@ -203,7 +207,9 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
     TF_LITE_ENSURE_STATUS(logger->LogTensorValue(
         i, tensor.data.f, tensor.bytes / sizeof(float), error_reporter));
   }
-  auto kernel_invoke_intermediate = GetLoggingEvalFunc(context, node);
+  auto builtin_op_code = calibrator->GetOpInfo(node).builtin_op_code;
+  auto kernel_invoke_intermediate =
+      GetLoggingEvalFunc(context, node, builtin_op_code);
   TfLiteStatus status;
   if (kernel_invoke_intermediate == nullptr) {
     status = kernel_invoke(context, node);
diff --git a/tensorflow/lite/tools/optimize/calibration/calibrator_test.cc b/tensorflow/lite/tools/optimize/calibration/calibrator_test.cc
index f0cd27ef620..c2e205f2a6e 100644
--- a/tensorflow/lite/tools/optimize/calibration/calibrator_test.cc
+++ b/tensorflow/lite/tools/optimize/calibration/calibrator_test.cc
@@ -283,7 +283,7 @@ TEST(CalibratorTest, LSTM) {
   auto status = BuildLoggingInterpreter(*flatbuffer_model,
                                         ops::builtin::BuiltinOpResolver{},
                                         &interpreter, &reader);
-  EXPECT_EQ(kTfLiteOk, status);
+  EXPECT_EQ(status, kTfLiteOk);
 
   auto readonly_model = flatbuffer_model->GetModel();
   tflite::ModelT model;
@@ -294,24 +294,17 @@ TEST(CalibratorTest, LSTM) {
   status = interpreter->AllocateTensors();
 
   EXPECT_EQ(kTfLiteOk, status);
-  const std::vector<float> lstm_input = {
-      0.3, 0.2, 0.9, 0.8, 0.1,  //
-      0.1, 0.5, 0.2, 0.4, 0.2,  //
-      0.6, 0.9, 0.2, 0.5, 0.7,  //
-  };
+  const std::vector<float> lstm_input = {0.3, 0.2};
   int input_tensor_idx = interpreter->inputs()[0];
   TfLiteTensor* tensor = interpreter->tensor(input_tensor_idx);
   for (size_t j = 0; j < lstm_input.size(); j++) {
     tensor->data.f[j] = lstm_input[j];
   }
 
-  // Invoke with update == true.
-  status = interpreter->Invoke();
-  ASSERT_EQ(kTfLiteOk, status);
+  ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
 
   absl::flat_hash_map<int, CalibrationReader::CalibrationStats> stats;
-  status = reader->GetTensorStatsAsMap(&stats);
-  EXPECT_EQ(kTfLiteOk, status);
+  EXPECT_EQ(reader->GetTensorStatsAsMap(&stats), kTfLiteOk);
 
   // Check the results.
   const float eps = 1e-6f;
@@ -344,6 +337,66 @@ TEST(CalibratorTest, LSTM) {
   }
 }
 
+TEST(CalibratorTest, UnidirectionalSequenceLSTM) {
+  auto flatbuffer_model = ReadModel("unidirectional_sequence_lstm.bin");
+  ASSERT_TRUE(flatbuffer_model);
+  std::unique_ptr<Interpreter> interpreter;
+  std::unique_ptr<CalibrationReader> reader;
+  auto status = BuildLoggingInterpreter(*flatbuffer_model,
+                                        ops::builtin::BuiltinOpResolver{},
+                                        &interpreter, &reader);
+  EXPECT_EQ(kTfLiteOk, status);
+
+  auto readonly_model = flatbuffer_model->GetModel();
+  tflite::ModelT model;
+  readonly_model->UnPackTo(&model);
+
+  ASSERT_TRUE(interpreter);
+  ASSERT_TRUE(reader);
+  EXPECT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
+  const std::vector<float> lstm_input = {0.3, 0.2, 0.9, 0.8};
+  int input_tensor_idx = interpreter->inputs()[0];
+  TfLiteTensor* tensor = interpreter->tensor(input_tensor_idx);
+  for (size_t j = 0; j < lstm_input.size(); j++) {
+    tensor->data.f[j] = lstm_input[j];
+  }
+
+  ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
+
+  absl::flat_hash_map<int, CalibrationReader::CalibrationStats> stats;
+  EXPECT_EQ(reader->GetTensorStatsAsMap(&stats), kTfLiteOk);
+
+  // Check the results.
+  const float eps = 1e-6f;
+  const std::unordered_map<int, CalibrationReader::CalibrationStats>
+      expected_calibration_result = {
+          // Input.
+          {0, {0.200000, 0.900000}},
+          // State.
+          {18, {0.000000, 0.520999}},
+          // State.
+          {19, {0.000000, 0.711364}},
+          // Output.
+          {24, {0.247992, 0.520999}},
+          // Intermediate_0.
+          {25, {0.080045, 0.824241}},
+          // Intermediate_1.
+          {26, {0.080045, 0.824241}},
+          // Intermediate_2.
+          {27, {0.080045, 0.824241}},
+          // Intermediate_3.
+          {28, {0.080045, 0.824241}},
+          // Intermediate_4.
+          {29, {0.000000, 0.413618}},
+      };
+  EXPECT_EQ(expected_calibration_result.size(), stats.size());
+  for (const auto& e : stats) {
+    auto expected_result = expected_calibration_result.at(e.first);
+    EXPECT_NEAR(e.second.min, expected_result.min, eps);
+    EXPECT_NEAR(e.second.max, expected_result.max, eps);
+  }
+}
+
 }  // namespace
 }  // namespace calibration
 }  // namespace optimize
diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc
index 6375016d527..6ec320c4144 100644
--- a/tensorflow/lite/tools/optimize/operator_property.cc
+++ b/tensorflow/lite/tools/optimize/operator_property.cc
@@ -44,7 +44,8 @@ const OpVariant GetOperatorVariant(const ModelT* model, int subgraph_index,
       model->subgraphs.at(subgraph_index)->operators[op_index].get();
   op_variant.op_code =
       GetBuiltinCode(model->operator_codes[op->opcode_index].get());
-  if (op_variant.op_code == BuiltinOperator_LSTM) {
+  if (op_variant.op_code == BuiltinOperator_LSTM ||
+      op_variant.op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM) {
     if (op->inputs.size() == 5) {
       // The 5 input ("basic") LSTM is not supported in this tooling (yet).
       op_variant.is_quantizable = false;
@@ -230,7 +231,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       property.version = 2;
       break;
     }
-    case BuiltinOperator_LSTM: {
+    case BuiltinOperator_LSTM:
+    case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: {
       if (!op_variant.is_quantizable) {
         // Early exist for 5 input LSTM.
         // It is not supported in this tooling yet.