Add a 'quantization_parameters' field to the Python interpreter's tensor details so the full per-axis quantization parameters can be read from Python.
PiperOrigin-RevId: 280220657
Change-Id: I3fe27b9c88a717f9c85e59164cd12449679006e3
commit a697e2cb1f (parent 53fc0686ec)
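For orientation, a minimal usage sketch of the new field from user code (hedged: the model path 'some_quantized_model.tflite' and the use of tensor index 0 are placeholders, not part of this change):

import tensorflow as tf

# Any TFLite model works; quantized models populate the new field with real values.
interpreter = tf.lite.Interpreter(model_path='some_quantized_model.tflite')
interpreter.allocate_tensors()

# Each tensor-detail dict now carries 'quantization_parameters' next to the
# legacy per-tensor 'quantization' tuple.
details = interpreter.get_tensor_details()[0]
qparams = details['quantization_parameters']
print(qparams['scales'])               # per-axis scales (length 1 for per-tensor)
print(qparams['zero_points'])          # matching zero points
print(qparams['quantized_dimension'])  # axis the per-axis parameters apply to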
tensorflow/lite/python
@@ -306,7 +306,18 @@ class Interpreter(object):
       tensor_index: Tensor index of tensor to query.
 
     Returns:
-      a dictionary containing the name, index, shape and type of the tensor.
+      A dictionary containing the following fields of the tensor:
+        'name': The tensor name.
+        'index': The tensor index in the interpreter.
+        'shape': The shape of the tensor.
+        'quantization': Deprecated, use 'quantization_parameters'. This field
+            only works for per-tensor quantization, whereas
+            'quantization_parameters' works in all cases.
+        'quantization_parameters': The parameters used to quantize the tensor:
+          'scales': List of scales (one if per-tensor quantization)
+          'zero_points': List of zero_points (one if per-tensor quantization)
+          'quantized_dimension': Specifies the dimension of per-axis
+              quantization, in the case of multiple scales/zero_points.
 
     Raises:
       ValueError: If tensor_index is invalid.
@@ -316,6 +327,8 @@ class Interpreter(object):
     tensor_size = self._interpreter.TensorSize(tensor_index)
     tensor_type = self._interpreter.TensorType(tensor_index)
     tensor_quantization = self._interpreter.TensorQuantization(tensor_index)
+    tensor_quantization_params = self._interpreter.TensorQuantizationParameters(
+        tensor_index)
 
     if not tensor_name or not tensor_type:
       raise ValueError('Could not get tensor details')
@@ -326,6 +339,11 @@ class Interpreter(object):
         'shape': tensor_size,
         'dtype': tensor_type,
         'quantization': tensor_quantization,
+        'quantization_parameters': {
+            'scales': tensor_quantization_params[0],
+            'zero_points': tensor_quantization_params[1],
+            'quantized_dimension': tensor_quantization_params[2],
+        }
     }
 
     return details
@@ -61,6 +61,12 @@ class InterpreterCustomOpsTest(test_util.TensorFlowTestCase):
 
 class InterpreterTest(test_util.TensorFlowTestCase):
 
+  def assertQuantizationParamsEqual(self, scales, zero_points,
+                                    quantized_dimension, params):
+    self.assertAllEqual(scales, params['scales'])
+    self.assertAllEqual(zero_points, params['zero_points'])
+    self.assertEqual(quantized_dimension, params['quantized_dimension'])
+
   def testFloat(self):
     interpreter = interpreter_wrapper.Interpreter(
         model_path=resource_loader.get_path_to_datafile(
@@ -73,6 +79,8 @@ class InterpreterTest(test_util.TensorFlowTestCase):
     self.assertEqual(np.float32, input_details[0]['dtype'])
     self.assertTrue(([1, 4] == input_details[0]['shape']).all())
     self.assertEqual((0.0, 0), input_details[0]['quantization'])
+    self.assertQuantizationParamsEqual(
+        [], [], 0, input_details[0]['quantization_parameters'])
 
     output_details = interpreter.get_output_details()
     self.assertEqual(1, len(output_details))
@@ -80,6 +88,8 @@ class InterpreterTest(test_util.TensorFlowTestCase):
     self.assertEqual(np.float32, output_details[0]['dtype'])
     self.assertTrue(([1, 4] == output_details[0]['shape']).all())
     self.assertEqual((0.0, 0), output_details[0]['quantization'])
+    self.assertQuantizationParamsEqual(
+        [], [], 0, output_details[0]['quantization_parameters'])
 
     test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
     expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32)
@@ -104,6 +114,8 @@ class InterpreterTest(test_util.TensorFlowTestCase):
     self.assertEqual(np.uint8, input_details[0]['dtype'])
     self.assertTrue(([1, 4] == input_details[0]['shape']).all())
     self.assertEqual((1.0, 0), input_details[0]['quantization'])
+    self.assertQuantizationParamsEqual(
+        [1.0], [0], 0, input_details[0]['quantization_parameters'])
 
     output_details = interpreter.get_output_details()
     self.assertEqual(1, len(output_details))
@@ -111,6 +123,8 @@ class InterpreterTest(test_util.TensorFlowTestCase):
     self.assertEqual(np.uint8, output_details[0]['dtype'])
     self.assertTrue(([1, 4] == output_details[0]['shape']).all())
     self.assertEqual((1.0, 0), output_details[0]['quantization'])
+    self.assertQuantizationParamsEqual(
+        [1.0], [0], 0, output_details[0]['quantization_parameters'])
 
     test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
     expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
@@ -135,10 +149,14 @@ class InterpreterTest(test_util.TensorFlowTestCase):
     self.assertEqual(np.string_, input_details[0]['dtype'])
     self.assertTrue(([10] == input_details[0]['shape']).all())
     self.assertEqual((0.0, 0), input_details[0]['quantization'])
+    self.assertQuantizationParamsEqual(
+        [], [], 0, input_details[0]['quantization_parameters'])
     self.assertEqual('indices', input_details[1]['name'])
     self.assertEqual(np.int64, input_details[1]['dtype'])
     self.assertTrue(([3] == input_details[1]['shape']).all())
     self.assertEqual((0.0, 0), input_details[1]['quantization'])
+    self.assertQuantizationParamsEqual(
+        [], [], 0, input_details[1]['quantization_parameters'])
 
     output_details = interpreter.get_output_details()
     self.assertEqual(1, len(output_details))
@@ -146,6 +164,8 @@ class InterpreterTest(test_util.TensorFlowTestCase):
     self.assertEqual(np.string_, output_details[0]['dtype'])
     self.assertTrue(([3] == output_details[0]['shape']).all())
     self.assertEqual((0.0, 0), output_details[0]['quantization'])
+    self.assertQuantizationParamsEqual(
+        [], [], 0, output_details[0]['quantization_parameters'])
 
     test_input = np.array([1, 2, 3], dtype=np.int64)
     interpreter.set_tensor(input_details[1]['index'], test_input)
@@ -158,6 +178,17 @@ class InterpreterTest(test_util.TensorFlowTestCase):
     output_data = interpreter.get_tensor(output_details[0]['index'])
     self.assertTrue((expected_output == output_data).all())
 
+  def testPerChannelParams(self):
+    interpreter = interpreter_wrapper.Interpreter(
+        model_path=resource_loader.get_path_to_datafile('testdata/pc_conv.bin'))
+    interpreter.allocate_tensors()
+
+    # Tensor index 1 is the weight.
+    weight_details = interpreter.get_tensor_details()[1]
+    qparams = weight_details['quantization_parameters']
+    # Ensure that we retrieve per channel quantization params correctly.
+    self.assertEqual(len(qparams['scales']), 128)
+
 
 class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
 
@@ -86,6 +86,12 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter(
   return interpreter;
 }
 
+PyObject* PyArrayFromFloatVector(const float* data, npy_intp size) {
+  void* pydata = malloc(size * sizeof(float));
+  memcpy(pydata, data, size * sizeof(float));
+  return PyArray_SimpleNewFromData(1, &size, NPY_FLOAT32, pydata);
+}
+
 PyObject* PyArrayFromIntVector(const int* data, npy_intp size) {
   void* pydata = malloc(size * sizeof(int));
   memcpy(pydata, data, size * sizeof(int));
@@ -301,6 +307,40 @@ PyObject* InterpreterWrapper::TensorQuantization(int i) const {
   return PyTupleFromQuantizationParam(tensor->params);
 }
 
+PyObject* InterpreterWrapper::TensorQuantizationParameters(int i) const {
+  TFLITE_PY_ENSURE_VALID_INTERPRETER();
+  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
+  const TfLiteTensor* tensor = interpreter_->tensor(i);
+  const TfLiteQuantization quantization = tensor->quantization;
+  float* scales_data = nullptr;
+  int32_t* zero_points_data = nullptr;
+  int32_t scales_size = 0;
+  int32_t zero_points_size = 0;
+  int32_t quantized_dimension = 0;
+  if (quantization.type == kTfLiteAffineQuantization) {
+    const TfLiteAffineQuantization* q_params =
+        reinterpret_cast<const TfLiteAffineQuantization*>(quantization.params);
+    if (q_params->scale) {
+      scales_data = q_params->scale->data;
+      scales_size = q_params->scale->size;
+    }
+    if (q_params->zero_point) {
+      zero_points_data = q_params->zero_point->data;
+      zero_points_size = q_params->zero_point->size;
+    }
+    quantized_dimension = q_params->quantized_dimension;
+  }
+  PyObject* scales_array = PyArrayFromFloatVector(scales_data, scales_size);
+  PyObject* zero_points_array =
+      PyArrayFromIntVector(zero_points_data, zero_points_size);
+
+  PyObject* result = PyTuple_New(3);
+  PyTuple_SET_ITEM(result, 0, scales_array);
+  PyTuple_SET_ITEM(result, 1, zero_points_array);
+  PyTuple_SET_ITEM(result, 2, PyLong_FromLong(quantized_dimension));
+  return result;
+}
+
 PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
   TFLITE_PY_ENSURE_VALID_INTERPRETER();
   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
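Note that the wrapper entry point added above returns a plain (scales, zero_points, quantized_dimension) tuple of NumPy arrays plus an int; the dict exposed to users is assembled in interpreter.py (hunk at -326 above). A hedged Python sketch of that mapping, with a hypothetical helper name:

def _quantization_params_dict(raw_params):
  """Repackage the wrapper's 3-tuple the way interpreter.py does (sketch)."""
  scales, zero_points, quantized_dimension = raw_params
  return {
      'scales': scales,                            # float32 array, empty if unquantized
      'zero_points': zero_points,                  # int32 array, empty if unquantized
      'quantized_dimension': quantized_dimension,  # 0 unless quantized per-axis
  }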
@@ -67,7 +67,9 @@ class InterpreterWrapper {
   std::string TensorName(int i) const;
   PyObject* TensorType(int i) const;
   PyObject* TensorSize(int i) const;
+  // Deprecated in favor of TensorQuantizationParameters, below.
   PyObject* TensorQuantization(int i) const;
+  PyObject* TensorQuantizationParameters(int i) const;
   PyObject* SetTensor(int i, PyObject* value);
   PyObject* GetTensor(int i) const;
   PyObject* ResetVariableTensors();
tensorflow/lite/python/testdata/BUILD
@@ -46,6 +46,7 @@ tf_to_tflite(
 filegroup(
     name = "interpreter_test_data",
     srcs = [
+        "pc_conv.bin",
         ":gather_string",
         ":permute_float",
         ":permute_uint8",
tensorflow/lite/python/testdata/pc_conv.bin (new binary file, not shown)