diff --git a/tensorflow/lite/experimental/quantization_debugger/debugger.py b/tensorflow/lite/experimental/quantization_debugger/debugger.py index bfccd35a898..5dedcbd9f81 100644 --- a/tensorflow/lite/experimental/quantization_debugger/debugger.py +++ b/tensorflow/lite/experimental/quantization_debugger/debugger.py @@ -32,6 +32,8 @@ _DEFAULT_LAYER_DEBUG_METRICS = { 'mean_square_error': lambda diffs: np.average(diffs**2), } +_NUMERIC_VERIFY_OP_NAME = 'NumericVerify' + def _get_quant_params( tensor_detail: Mapping[str, Any]) -> Optional[Tuple[float, int]]: @@ -225,8 +227,7 @@ class QuantizationDebugger: for metric_name, metric in model_statistics.items() } - def _set_input_tensors(self, - interpreter: tf.lite.Interpreter, + def _set_input_tensors(self, interpreter: tf.lite.Interpreter, tensor_data: Sequence[np.ndarray], initialize: bool) -> None: """Sets input tensors into TFLite model Interpreter. @@ -286,17 +287,30 @@ class QuantizationDebugger: def _get_numeric_verify_tensor_details(self) -> List[str]: """Returns all names of all tensors from NumericVerify op.""" + # pylint: disable=protected-access if not self._numeric_verify_tensor_details: - self._numeric_verify_tensor_details = [ - detail for detail in self._quant_interpreter.get_tensor_details() - if detail['name'].startswith('NumericVerify') - ] + self._numeric_verify_tensor_details = [] + for op_info in self._quant_interpreter._get_ops_details(): + if op_info['op_name'] == _NUMERIC_VERIFY_OP_NAME: + self._numeric_verify_tensor_details.append( + self._quant_interpreter._get_tensor_details( + op_info['outputs'][0])) + # pylint: enable=protected-access return self._numeric_verify_tensor_details - def _get_operand_index(self, numeric_verify_name: str) -> int: - """Gets the index of NumericVerify Op's quantized input tensor.""" - tensor_idx = numeric_verify_name.rsplit(':', 1)[-1] - return int(tensor_idx) + def _get_operand_name_and_index(self, + numeric_verify_name: str) -> Tuple[str, int]: + """Gets the index and name of NumericVerify Op's quantized input tensor. + + Args: + numeric_verify_name: name of the NumericVerify op's output tensor. It has + format of `NumericVerify/{quantized_tensor_name}:{quantized_tensor_idx}` + + Returns: + Tuple of (tensor_name, tensor_idx) for quantized op's output tensor. + """ + tensor_name, tensor_idx = numeric_verify_name.rsplit(':', 1) + return (tensor_name[len(_NUMERIC_VERIFY_OP_NAME) + 1:], int(tensor_idx)) def layer_statistics_dump(self, file: IO[str]) -> None: """Dumps layer statistics into file, in csv format. @@ -304,15 +318,17 @@ class QuantizationDebugger: Args: file: file, or file-like object to write. """ - fields = ['op_name', 'op_idx'] + list( - self._layer_debug_metrics.keys()) + ['scales', 'zero_points'] + # order of `fields` is the order of fields in csv. + fields = ['op_name', 'tensor_idx'] + list(self._layer_debug_metrics.keys( + )) + ['scales', 'zero_points', 'tensor_name'] writer = csv.DictWriter(file, fields) writer.writeheader() for name, metrics in self.layer_statistics.items(): data = metrics.copy() - data['op_idx'] = self._get_operand_index(name) - data['op_name'] = self._defining_op[data['op_idx']] - details = self._quant_interpreter._get_tensor_details(data['op_idx']) # pylint: disable=protected-access + (data['tensor_name'], + data['tensor_idx']) = self._get_operand_name_and_index(name) + data['op_name'] = self._defining_op[data['tensor_idx']] + details = self._quant_interpreter._get_tensor_details(data['tensor_idx']) # pylint: disable=protected-access data['scales'], data['zero_points'] = ( details['quantization_parameters']['scales'], details['quantization_parameters']['zero_points']) diff --git a/tensorflow/lite/experimental/quantization_debugger/debugger_test.py b/tensorflow/lite/experimental/quantization_debugger/debugger_test.py index 4339f4848eb..ceef9da6b6e 100644 --- a/tensorflow/lite/experimental/quantization_debugger/debugger_test.py +++ b/tensorflow/lite/experimental/quantization_debugger/debugger_test.py @@ -131,9 +131,10 @@ class QuantizationDebuggerTest(test_util.TensorFlowTestCase, expected_values = expected_metrics.copy() expected_values.update({ 'op_name': 'CONV_2D', - 'op_idx': 7 if quantized_io else 8, + 'tensor_idx': 7 if quantized_io else 8, 'scales': [0.15686275], 'zero_points': [-128], + 'tensor_name': 'Identity' if quantized_io else 'Identity4' }) for key, value in expected_values.items(): if isinstance(value, str):