Add tensor name to layer statistics, and order ops according to graph order

PiperOrigin-RevId: 360564509
Change-Id: I9bd317a1e9f7162f65211bb07ab71584a5ece3ca
This commit is contained in:
Taehee Jeong 2021-03-02 18:50:14 -08:00 committed by TensorFlower Gardener
parent a4fe022e3d
commit 011a0ca05a
2 changed files with 33 additions and 16 deletions

View File

@ -32,6 +32,8 @@ _DEFAULT_LAYER_DEBUG_METRICS = {
'mean_square_error': lambda diffs: np.average(diffs**2),
}
_NUMERIC_VERIFY_OP_NAME = 'NumericVerify'
def _get_quant_params(
tensor_detail: Mapping[str, Any]) -> Optional[Tuple[float, int]]:
@ -225,8 +227,7 @@ class QuantizationDebugger:
for metric_name, metric in model_statistics.items()
}
def _set_input_tensors(self,
interpreter: tf.lite.Interpreter,
def _set_input_tensors(self, interpreter: tf.lite.Interpreter,
tensor_data: Sequence[np.ndarray],
initialize: bool) -> None:
"""Sets input tensors into TFLite model Interpreter.
@ -286,17 +287,30 @@ class QuantizationDebugger:
def _get_numeric_verify_tensor_details(self) -> List[str]:
"""Returns all names of all tensors from NumericVerify op."""
# pylint: disable=protected-access
if not self._numeric_verify_tensor_details:
self._numeric_verify_tensor_details = [
detail for detail in self._quant_interpreter.get_tensor_details()
if detail['name'].startswith('NumericVerify')
]
self._numeric_verify_tensor_details = []
for op_info in self._quant_interpreter._get_ops_details():
if op_info['op_name'] == _NUMERIC_VERIFY_OP_NAME:
self._numeric_verify_tensor_details.append(
self._quant_interpreter._get_tensor_details(
op_info['outputs'][0]))
# pylint: enable=protected-access
return self._numeric_verify_tensor_details
def _get_operand_index(self, numeric_verify_name: str) -> int:
"""Gets the index of NumericVerify Op's quantized input tensor."""
tensor_idx = numeric_verify_name.rsplit(':', 1)[-1]
return int(tensor_idx)
def _get_operand_name_and_index(self,
numeric_verify_name: str) -> Tuple[str, int]:
"""Gets the index and name of NumericVerify Op's quantized input tensor.
Args:
numeric_verify_name: name of the NumericVerify op's output tensor. It has
format of `NumericVerify/{quantized_tensor_name}:{quantized_tensor_idx}`
Returns:
Tuple of (tensor_name, tensor_idx) for quantized op's output tensor.
"""
tensor_name, tensor_idx = numeric_verify_name.rsplit(':', 1)
return (tensor_name[len(_NUMERIC_VERIFY_OP_NAME) + 1:], int(tensor_idx))
def layer_statistics_dump(self, file: IO[str]) -> None:
"""Dumps layer statistics into file, in csv format.
@ -304,15 +318,17 @@ class QuantizationDebugger:
Args:
file: file, or file-like object to write.
"""
fields = ['op_name', 'op_idx'] + list(
self._layer_debug_metrics.keys()) + ['scales', 'zero_points']
# order of `fields` is the order of fields in csv.
fields = ['op_name', 'tensor_idx'] + list(self._layer_debug_metrics.keys(
)) + ['scales', 'zero_points', 'tensor_name']
writer = csv.DictWriter(file, fields)
writer.writeheader()
for name, metrics in self.layer_statistics.items():
data = metrics.copy()
data['op_idx'] = self._get_operand_index(name)
data['op_name'] = self._defining_op[data['op_idx']]
details = self._quant_interpreter._get_tensor_details(data['op_idx']) # pylint: disable=protected-access
(data['tensor_name'],
data['tensor_idx']) = self._get_operand_name_and_index(name)
data['op_name'] = self._defining_op[data['tensor_idx']]
details = self._quant_interpreter._get_tensor_details(data['tensor_idx']) # pylint: disable=protected-access
data['scales'], data['zero_points'] = (
details['quantization_parameters']['scales'],
details['quantization_parameters']['zero_points'])

View File

@ -131,9 +131,10 @@ class QuantizationDebuggerTest(test_util.TensorFlowTestCase,
expected_values = expected_metrics.copy()
expected_values.update({
'op_name': 'CONV_2D',
'op_idx': 7 if quantized_io else 8,
'tensor_idx': 7 if quantized_io else 8,
'scales': [0.15686275],
'zero_points': [-128],
'tensor_name': 'Identity' if quantized_io else 'Identity4'
})
for key, value in expected_values.items():
if isinstance(value, str):