Add 'quantization_parameters' field to python interpreter to allow getting the full per-axis quantization parameters via python.
PiperOrigin-RevId: 280220657 Change-Id: I3fe27b9c88a717f9c85e59164cd12449679006e3
This commit is contained in:
parent
53fc0686ec
commit
a697e2cb1f
@ -306,7 +306,18 @@ class Interpreter(object):
|
||||
tensor_index: Tensor index of tensor to query.
|
||||
|
||||
Returns:
|
||||
a dictionary containing the name, index, shape and type of the tensor.
|
||||
A dictionary containing the following fields of the tensor:
|
||||
'name': The tensor name.
|
||||
'index': The tensor index in the interpreter.
|
||||
'shape': The shape of the tensor.
|
||||
'quantization': Deprecated, use 'quantization_parameters'. This field
|
||||
only works for per-tensor quantization, whereas
|
||||
'quantization_parameters' works in all cases.
|
||||
'quantization_parameters': The parameters used to quantize the tensor:
|
||||
'scales': List of scales (one if per-tensor quantization)
|
||||
'zero_points': List of zero_points (one if per-tensor quantization)
|
||||
'quantized_dimension': Specifies the dimension of per-axis
|
||||
quantization, in the case of multiple scales/zero_points.
|
||||
|
||||
Raises:
|
||||
ValueError: If tensor_index is invalid.
|
||||
@ -316,6 +327,8 @@ class Interpreter(object):
|
||||
tensor_size = self._interpreter.TensorSize(tensor_index)
|
||||
tensor_type = self._interpreter.TensorType(tensor_index)
|
||||
tensor_quantization = self._interpreter.TensorQuantization(tensor_index)
|
||||
tensor_quantization_params = self._interpreter.TensorQuantizationParameters(
|
||||
tensor_index)
|
||||
|
||||
if not tensor_name or not tensor_type:
|
||||
raise ValueError('Could not get tensor details')
|
||||
@ -326,6 +339,11 @@ class Interpreter(object):
|
||||
'shape': tensor_size,
|
||||
'dtype': tensor_type,
|
||||
'quantization': tensor_quantization,
|
||||
'quantization_parameters': {
|
||||
'scales': tensor_quantization_params[0],
|
||||
'zero_points': tensor_quantization_params[1],
|
||||
'quantized_dimension': tensor_quantization_params[2],
|
||||
}
|
||||
}
|
||||
|
||||
return details
|
||||
|
@ -61,6 +61,12 @@ class InterpreterCustomOpsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
class InterpreterTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def assertQuantizationParamsEqual(self, scales, zero_points,
|
||||
quantized_dimension, params):
|
||||
self.assertAllEqual(scales, params['scales'])
|
||||
self.assertAllEqual(zero_points, params['zero_points'])
|
||||
self.assertEqual(quantized_dimension, params['quantized_dimension'])
|
||||
|
||||
def testFloat(self):
|
||||
interpreter = interpreter_wrapper.Interpreter(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
@ -73,6 +79,8 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(np.float32, input_details[0]['dtype'])
|
||||
self.assertTrue(([1, 4] == input_details[0]['shape']).all())
|
||||
self.assertEqual((0.0, 0), input_details[0]['quantization'])
|
||||
self.assertQuantizationParamsEqual(
|
||||
[], [], 0, input_details[0]['quantization_parameters'])
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
self.assertEqual(1, len(output_details))
|
||||
@ -80,6 +88,8 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(np.float32, output_details[0]['dtype'])
|
||||
self.assertTrue(([1, 4] == output_details[0]['shape']).all())
|
||||
self.assertEqual((0.0, 0), output_details[0]['quantization'])
|
||||
self.assertQuantizationParamsEqual(
|
||||
[], [], 0, output_details[0]['quantization_parameters'])
|
||||
|
||||
test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
|
||||
expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32)
|
||||
@ -104,6 +114,8 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(np.uint8, input_details[0]['dtype'])
|
||||
self.assertTrue(([1, 4] == input_details[0]['shape']).all())
|
||||
self.assertEqual((1.0, 0), input_details[0]['quantization'])
|
||||
self.assertQuantizationParamsEqual(
|
||||
[1.0], [0], 0, input_details[0]['quantization_parameters'])
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
self.assertEqual(1, len(output_details))
|
||||
@ -111,6 +123,8 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(np.uint8, output_details[0]['dtype'])
|
||||
self.assertTrue(([1, 4] == output_details[0]['shape']).all())
|
||||
self.assertEqual((1.0, 0), output_details[0]['quantization'])
|
||||
self.assertQuantizationParamsEqual(
|
||||
[1.0], [0], 0, output_details[0]['quantization_parameters'])
|
||||
|
||||
test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
|
||||
expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
|
||||
@ -135,10 +149,14 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(np.string_, input_details[0]['dtype'])
|
||||
self.assertTrue(([10] == input_details[0]['shape']).all())
|
||||
self.assertEqual((0.0, 0), input_details[0]['quantization'])
|
||||
self.assertQuantizationParamsEqual(
|
||||
[], [], 0, input_details[0]['quantization_parameters'])
|
||||
self.assertEqual('indices', input_details[1]['name'])
|
||||
self.assertEqual(np.int64, input_details[1]['dtype'])
|
||||
self.assertTrue(([3] == input_details[1]['shape']).all())
|
||||
self.assertEqual((0.0, 0), input_details[1]['quantization'])
|
||||
self.assertQuantizationParamsEqual(
|
||||
[], [], 0, input_details[1]['quantization_parameters'])
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
self.assertEqual(1, len(output_details))
|
||||
@ -146,6 +164,8 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(np.string_, output_details[0]['dtype'])
|
||||
self.assertTrue(([3] == output_details[0]['shape']).all())
|
||||
self.assertEqual((0.0, 0), output_details[0]['quantization'])
|
||||
self.assertQuantizationParamsEqual(
|
||||
[], [], 0, output_details[0]['quantization_parameters'])
|
||||
|
||||
test_input = np.array([1, 2, 3], dtype=np.int64)
|
||||
interpreter.set_tensor(input_details[1]['index'], test_input)
|
||||
@ -158,6 +178,17 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
||||
output_data = interpreter.get_tensor(output_details[0]['index'])
|
||||
self.assertTrue((expected_output == output_data).all())
|
||||
|
||||
def testPerChannelParams(self):
|
||||
interpreter = interpreter_wrapper.Interpreter(
|
||||
model_path=resource_loader.get_path_to_datafile('testdata/pc_conv.bin'))
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
# Tensor index 1 is the weight.
|
||||
weight_details = interpreter.get_tensor_details()[1]
|
||||
qparams = weight_details['quantization_parameters']
|
||||
# Ensure that we retrieve per channel quantization params correctly.
|
||||
self.assertEqual(len(qparams['scales']), 128)
|
||||
|
||||
|
||||
class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
|
||||
|
||||
|
@ -86,6 +86,12 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter(
|
||||
return interpreter;
|
||||
}
|
||||
|
||||
PyObject* PyArrayFromFloatVector(const float* data, npy_intp size) {
|
||||
void* pydata = malloc(size * sizeof(float));
|
||||
memcpy(pydata, data, size * sizeof(float));
|
||||
return PyArray_SimpleNewFromData(1, &size, NPY_FLOAT32, pydata);
|
||||
}
|
||||
|
||||
PyObject* PyArrayFromIntVector(const int* data, npy_intp size) {
|
||||
void* pydata = malloc(size * sizeof(int));
|
||||
memcpy(pydata, data, size * sizeof(int));
|
||||
@ -301,6 +307,40 @@ PyObject* InterpreterWrapper::TensorQuantization(int i) const {
|
||||
return PyTupleFromQuantizationParam(tensor->params);
|
||||
}
|
||||
|
||||
PyObject* InterpreterWrapper::TensorQuantizationParameters(int i) const {
|
||||
TFLITE_PY_ENSURE_VALID_INTERPRETER();
|
||||
TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
|
||||
const TfLiteTensor* tensor = interpreter_->tensor(i);
|
||||
const TfLiteQuantization quantization = tensor->quantization;
|
||||
float* scales_data = nullptr;
|
||||
int32_t* zero_points_data = nullptr;
|
||||
int32_t scales_size = 0;
|
||||
int32_t zero_points_size = 0;
|
||||
int32_t quantized_dimension = 0;
|
||||
if (quantization.type == kTfLiteAffineQuantization) {
|
||||
const TfLiteAffineQuantization* q_params =
|
||||
reinterpret_cast<const TfLiteAffineQuantization*>(quantization.params);
|
||||
if (q_params->scale) {
|
||||
scales_data = q_params->scale->data;
|
||||
scales_size = q_params->scale->size;
|
||||
}
|
||||
if (q_params->zero_point) {
|
||||
zero_points_data = q_params->zero_point->data;
|
||||
zero_points_size = q_params->zero_point->size;
|
||||
}
|
||||
quantized_dimension = q_params->quantized_dimension;
|
||||
}
|
||||
PyObject* scales_array = PyArrayFromFloatVector(scales_data, scales_size);
|
||||
PyObject* zero_points_array =
|
||||
PyArrayFromIntVector(zero_points_data, zero_points_size);
|
||||
|
||||
PyObject* result = PyTuple_New(3);
|
||||
PyTuple_SET_ITEM(result, 0, scales_array);
|
||||
PyTuple_SET_ITEM(result, 1, zero_points_array);
|
||||
PyTuple_SET_ITEM(result, 2, PyLong_FromLong(quantized_dimension));
|
||||
return result;
|
||||
}
|
||||
|
||||
PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
|
||||
TFLITE_PY_ENSURE_VALID_INTERPRETER();
|
||||
TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
|
||||
|
@ -67,7 +67,9 @@ class InterpreterWrapper {
|
||||
std::string TensorName(int i) const;
|
||||
PyObject* TensorType(int i) const;
|
||||
PyObject* TensorSize(int i) const;
|
||||
// Deprecated in favor of TensorQuantizationScales, below.
|
||||
PyObject* TensorQuantization(int i) const;
|
||||
PyObject* TensorQuantizationParameters(int i) const;
|
||||
PyObject* SetTensor(int i, PyObject* value);
|
||||
PyObject* GetTensor(int i) const;
|
||||
PyObject* ResetVariableTensors();
|
||||
|
1
tensorflow/lite/python/testdata/BUILD
vendored
1
tensorflow/lite/python/testdata/BUILD
vendored
@ -46,6 +46,7 @@ tf_to_tflite(
|
||||
filegroup(
|
||||
name = "interpreter_test_data",
|
||||
srcs = [
|
||||
"pc_conv.bin",
|
||||
":gather_string",
|
||||
":permute_float",
|
||||
":permute_uint8",
|
||||
|
BIN
tensorflow/lite/python/testdata/pc_conv.bin
vendored
Normal file
BIN
tensorflow/lite/python/testdata/pc_conv.bin
vendored
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user