From 011a0ca05a1c7b75f6b3905d58c297beeba7a485 Mon Sep 17 00:00:00 2001
From: Taehee Jeong <taeheej@google.com>
Date: Tue, 2 Mar 2021 18:50:14 -0800
Subject: [PATCH] Add tensor name to layer statistics, and order ops according
 to graph order

PiperOrigin-RevId: 360564509
Change-Id: I9bd317a1e9f7162f65211bb07ab71584a5ece3ca
---
 .../quantization_debugger/debugger.py         | 46 +++++++++++++------
 .../quantization_debugger/debugger_test.py    |  3 +-
 2 files changed, 33 insertions(+), 16 deletions(-)

diff --git a/tensorflow/lite/experimental/quantization_debugger/debugger.py b/tensorflow/lite/experimental/quantization_debugger/debugger.py
index bfccd35a898..5dedcbd9f81 100644
--- a/tensorflow/lite/experimental/quantization_debugger/debugger.py
+++ b/tensorflow/lite/experimental/quantization_debugger/debugger.py
@@ -32,6 +32,8 @@ _DEFAULT_LAYER_DEBUG_METRICS = {
     'mean_square_error': lambda diffs: np.average(diffs**2),
 }
 
+_NUMERIC_VERIFY_OP_NAME = 'NumericVerify'
+
 
 def _get_quant_params(
     tensor_detail: Mapping[str, Any]) -> Optional[Tuple[float, int]]:
@@ -225,8 +227,7 @@ class QuantizationDebugger:
         for metric_name, metric in model_statistics.items()
     }
 
-  def _set_input_tensors(self,
-                         interpreter: tf.lite.Interpreter,
+  def _set_input_tensors(self, interpreter: tf.lite.Interpreter,
                          tensor_data: Sequence[np.ndarray],
                          initialize: bool) -> None:
     """Sets input tensors into TFLite model Interpreter.
@@ -286,17 +287,30 @@ class QuantizationDebugger:
 
   def _get_numeric_verify_tensor_details(self) -> List[str]:
     """Returns all names of all tensors from NumericVerify op."""
+    # pylint: disable=protected-access
     if not self._numeric_verify_tensor_details:
-      self._numeric_verify_tensor_details = [
-          detail for detail in self._quant_interpreter.get_tensor_details()
-          if detail['name'].startswith('NumericVerify')
-      ]
+      self._numeric_verify_tensor_details = []
+      for op_info in self._quant_interpreter._get_ops_details():
+        if op_info['op_name'] == _NUMERIC_VERIFY_OP_NAME:
+          self._numeric_verify_tensor_details.append(
+              self._quant_interpreter._get_tensor_details(
+                  op_info['outputs'][0]))
+    # pylint: enable=protected-access
     return self._numeric_verify_tensor_details
 
-  def _get_operand_index(self, numeric_verify_name: str) -> int:
-    """Gets the index of NumericVerify Op's quantized input tensor."""
-    tensor_idx = numeric_verify_name.rsplit(':', 1)[-1]
-    return int(tensor_idx)
+  def _get_operand_name_and_index(self,
+                                  numeric_verify_name: str) -> Tuple[str, int]:
+    """Gets the index and name of NumericVerify Op's quantized input tensor.
+
+    Args:
+      numeric_verify_name: name of the NumericVerify op's output tensor. It has
+        format of `NumericVerify/{quantized_tensor_name}:{quantized_tensor_idx}`
+
+    Returns:
+      Tuple of (tensor_name, tensor_idx) for quantized op's output tensor.
+    """
+    tensor_name, tensor_idx = numeric_verify_name.rsplit(':', 1)
+    return (tensor_name[len(_NUMERIC_VERIFY_OP_NAME) + 1:], int(tensor_idx))
 
   def layer_statistics_dump(self, file: IO[str]) -> None:
     """Dumps layer statistics into file, in csv format.
@@ -304,15 +318,17 @@ class QuantizationDebugger:
     Args:
       file: file, or file-like object to write.
     """
-    fields = ['op_name', 'op_idx'] + list(
-        self._layer_debug_metrics.keys()) + ['scales', 'zero_points']
+    # order of `fields` is the order of fields in csv.
+    fields = ['op_name', 'tensor_idx'] + list(self._layer_debug_metrics.keys(
+    )) + ['scales', 'zero_points', 'tensor_name']
     writer = csv.DictWriter(file, fields)
     writer.writeheader()
     for name, metrics in self.layer_statistics.items():
       data = metrics.copy()
-      data['op_idx'] = self._get_operand_index(name)
-      data['op_name'] = self._defining_op[data['op_idx']]
-      details = self._quant_interpreter._get_tensor_details(data['op_idx'])  # pylint: disable=protected-access
+      (data['tensor_name'],
+       data['tensor_idx']) = self._get_operand_name_and_index(name)
+      data['op_name'] = self._defining_op[data['tensor_idx']]
+      details = self._quant_interpreter._get_tensor_details(data['tensor_idx'])  # pylint: disable=protected-access
       data['scales'], data['zero_points'] = (
           details['quantization_parameters']['scales'],
           details['quantization_parameters']['zero_points'])
diff --git a/tensorflow/lite/experimental/quantization_debugger/debugger_test.py b/tensorflow/lite/experimental/quantization_debugger/debugger_test.py
index 4339f4848eb..ceef9da6b6e 100644
--- a/tensorflow/lite/experimental/quantization_debugger/debugger_test.py
+++ b/tensorflow/lite/experimental/quantization_debugger/debugger_test.py
@@ -131,9 +131,10 @@ class QuantizationDebuggerTest(test_util.TensorFlowTestCase,
     expected_values = expected_metrics.copy()
     expected_values.update({
         'op_name': 'CONV_2D',
-        'op_idx': 7 if quantized_io else 8,
+        'tensor_idx': 7 if quantized_io else 8,
         'scales': [0.15686275],
         'zero_points': [-128],
+        'tensor_name': 'Identity' if quantized_io else 'Identity4'
     })
     for key, value in expected_values.items():
       if isinstance(value, str):