Add tensor name to layer statistics, and order ops according to graph order
PiperOrigin-RevId: 360564509 Change-Id: I9bd317a1e9f7162f65211bb07ab71584a5ece3ca
This commit is contained in:
parent
a4fe022e3d
commit
011a0ca05a
@ -32,6 +32,8 @@ _DEFAULT_LAYER_DEBUG_METRICS = {
|
||||
'mean_square_error': lambda diffs: np.average(diffs**2),
|
||||
}
|
||||
|
||||
_NUMERIC_VERIFY_OP_NAME = 'NumericVerify'
|
||||
|
||||
|
||||
def _get_quant_params(
|
||||
tensor_detail: Mapping[str, Any]) -> Optional[Tuple[float, int]]:
|
||||
@ -225,8 +227,7 @@ class QuantizationDebugger:
|
||||
for metric_name, metric in model_statistics.items()
|
||||
}
|
||||
|
||||
def _set_input_tensors(self,
|
||||
interpreter: tf.lite.Interpreter,
|
||||
def _set_input_tensors(self, interpreter: tf.lite.Interpreter,
|
||||
tensor_data: Sequence[np.ndarray],
|
||||
initialize: bool) -> None:
|
||||
"""Sets input tensors into TFLite model Interpreter.
|
||||
@ -286,17 +287,30 @@ class QuantizationDebugger:
|
||||
|
||||
def _get_numeric_verify_tensor_details(self) -> List[str]:
|
||||
"""Returns all names of all tensors from NumericVerify op."""
|
||||
# pylint: disable=protected-access
|
||||
if not self._numeric_verify_tensor_details:
|
||||
self._numeric_verify_tensor_details = [
|
||||
detail for detail in self._quant_interpreter.get_tensor_details()
|
||||
if detail['name'].startswith('NumericVerify')
|
||||
]
|
||||
self._numeric_verify_tensor_details = []
|
||||
for op_info in self._quant_interpreter._get_ops_details():
|
||||
if op_info['op_name'] == _NUMERIC_VERIFY_OP_NAME:
|
||||
self._numeric_verify_tensor_details.append(
|
||||
self._quant_interpreter._get_tensor_details(
|
||||
op_info['outputs'][0]))
|
||||
# pylint: enable=protected-access
|
||||
return self._numeric_verify_tensor_details
|
||||
|
||||
def _get_operand_index(self, numeric_verify_name: str) -> int:
|
||||
"""Gets the index of NumericVerify Op's quantized input tensor."""
|
||||
tensor_idx = numeric_verify_name.rsplit(':', 1)[-1]
|
||||
return int(tensor_idx)
|
||||
def _get_operand_name_and_index(self,
|
||||
numeric_verify_name: str) -> Tuple[str, int]:
|
||||
"""Gets the index and name of NumericVerify Op's quantized input tensor.
|
||||
|
||||
Args:
|
||||
numeric_verify_name: name of the NumericVerify op's output tensor. It has
|
||||
format of `NumericVerify/{quantized_tensor_name}:{quantized_tensor_idx}`
|
||||
|
||||
Returns:
|
||||
Tuple of (tensor_name, tensor_idx) for quantized op's output tensor.
|
||||
"""
|
||||
tensor_name, tensor_idx = numeric_verify_name.rsplit(':', 1)
|
||||
return (tensor_name[len(_NUMERIC_VERIFY_OP_NAME) + 1:], int(tensor_idx))
|
||||
|
||||
def layer_statistics_dump(self, file: IO[str]) -> None:
|
||||
"""Dumps layer statistics into file, in csv format.
|
||||
@ -304,15 +318,17 @@ class QuantizationDebugger:
|
||||
Args:
|
||||
file: file, or file-like object to write.
|
||||
"""
|
||||
fields = ['op_name', 'op_idx'] + list(
|
||||
self._layer_debug_metrics.keys()) + ['scales', 'zero_points']
|
||||
# order of `fields` is the order of fields in csv.
|
||||
fields = ['op_name', 'tensor_idx'] + list(self._layer_debug_metrics.keys(
|
||||
)) + ['scales', 'zero_points', 'tensor_name']
|
||||
writer = csv.DictWriter(file, fields)
|
||||
writer.writeheader()
|
||||
for name, metrics in self.layer_statistics.items():
|
||||
data = metrics.copy()
|
||||
data['op_idx'] = self._get_operand_index(name)
|
||||
data['op_name'] = self._defining_op[data['op_idx']]
|
||||
details = self._quant_interpreter._get_tensor_details(data['op_idx']) # pylint: disable=protected-access
|
||||
(data['tensor_name'],
|
||||
data['tensor_idx']) = self._get_operand_name_and_index(name)
|
||||
data['op_name'] = self._defining_op[data['tensor_idx']]
|
||||
details = self._quant_interpreter._get_tensor_details(data['tensor_idx']) # pylint: disable=protected-access
|
||||
data['scales'], data['zero_points'] = (
|
||||
details['quantization_parameters']['scales'],
|
||||
details['quantization_parameters']['zero_points'])
|
||||
|
@ -131,9 +131,10 @@ class QuantizationDebuggerTest(test_util.TensorFlowTestCase,
|
||||
expected_values = expected_metrics.copy()
|
||||
expected_values.update({
|
||||
'op_name': 'CONV_2D',
|
||||
'op_idx': 7 if quantized_io else 8,
|
||||
'tensor_idx': 7 if quantized_io else 8,
|
||||
'scales': [0.15686275],
|
||||
'zero_points': [-128],
|
||||
'tensor_name': 'Identity' if quantized_io else 'Identity4'
|
||||
})
|
||||
for key, value in expected_values.items():
|
||||
if isinstance(value, str):
|
||||
|
Loading…
Reference in New Issue
Block a user