diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py
index 091ddc35d46..15704a6da76 100644
--- a/tensorflow/lite/python/lite.py
+++ b/tensorflow/lite/python/lite.py
@@ -346,15 +346,9 @@ class TFLiteConverterBase(object):
         self.representative_dataset.input_gen, inference_input_type,
         inference_output_type, allow_float)
 
-  def _is_unknown_shapes_allowed(self, fp32_execution):
-    # TODO(b/128319310): Investigate which quantization methods work.
-    if not fp32_execution:
-      return False
-
+  def _is_unknown_shapes_allowed(self):
     # Unknown dimensions are only allowed with the new converter.
-    if not self.experimental_new_converter:
-      return False
-    return True
+    return self.experimental_new_converter
 
   def _get_base_converter_args(self):
     """Returns the base converter args.
@@ -657,7 +651,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
     quant_mode = QuantizationMode(self.optimizations, self.target_spec,
                                   self.representative_dataset, graph_def)
 
-    if not self._is_unknown_shapes_allowed(quant_mode.fp32_execution()):
+    if not self._is_unknown_shapes_allowed():
       # Checks dimensions in input tensor.
       for tensor in input_tensors:
         # Note that shape_list might be empty for scalar shapes.
@@ -1197,8 +1191,7 @@ class TFLiteConverter(TFLiteConverterBase):
                                   self.representative_dataset, self._graph_def)
 
     # Checks dimensions in input tensor.
-    if (not self._is_unknown_shapes_allowed(quant_mode.fp32_execution()) and
-        self._has_valid_tensors()):
+    if (not self._is_unknown_shapes_allowed() and self._has_valid_tensors()):
       for tensor in self._input_tensors:
         shape = tensor.shape
         if not shape:
@@ -1399,13 +1392,12 @@ class TFLiteConverter(TFLiteConverterBase):
           shape[0] = batch_size
           tensor.set_shape(shape)
 
-  def _is_unknown_shapes_allowed(self, fp32_execution):
+  def _is_unknown_shapes_allowed(self):
     # Ophint Converted nodes will need the shapes to be known.
     if _is_ophint_converted(self._graph_def):
       return False
 
-    if not super(TFLiteConverter,
-                 self)._is_unknown_shapes_allowed(fp32_execution):
+    if not super(TFLiteConverter, self)._is_unknown_shapes_allowed():
       return False
 
     # `conversion_summary_dir` calls TOCO. Unknown shapes are only supported by
diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py
index 70671f36265..d351fd492f6 100644
--- a/tensorflow/lite/python/lite_v2_test.py
+++ b/tensorflow/lite/python/lite_v2_test.py
@@ -876,6 +876,84 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
     np.testing.assert_almost_equal(
         expected_value.numpy(), actual_value[0], decimal=6)
 
+  def _getQuantizedModel(self):
+    # Returns a model with tf.MatMul and unknown dimensions.
+    @tf.function(
+        input_signature=[tf.TensorSpec(shape=[None, 33], dtype=tf.float32)])
+    def model(in_tensor):
+      # We need the tensor to have more than 1024 elements for quantize_weights
+      # to kick in. Thus, the [33, 33] shape.
+      const_tensor = tf.constant(
+          np.random.uniform(low=-10., high=10., size=[33, 33]),
+          shape=[33, 33],
+          dtype=tf.float32,
+          name='inputB')
+
+      shape = tf.shape(in_tensor)
+      fill = tf.transpose(tf.fill(shape, 1.))
+      mult = tf.matmul(fill, in_tensor)
+      return tf.matmul(mult, const_tensor)
+
+    concrete_func = model.get_concrete_function()
+
+    def calibration_gen():
+      for batch in range(5, 20, 5):
+        for _ in range(5):
+          yield [np.random.uniform(-1, 1, size=(batch, 33)).astype(np.float32)]
+
+    return concrete_func, calibration_gen
+
+  @test_util.run_v2_only
+  def testMatMulQuantize(self):
+    concrete_func, _ = self._getQuantizedModel()
+    float_converter = lite.TFLiteConverterV2.from_concrete_functions(
+        [concrete_func])
+    float_converter.experimental_new_converter = True
+    float_tflite_model = float_converter.convert()
+
+    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions(
+        [concrete_func])
+    quantized_converter.experimental_new_converter = True
+    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
+    quantized_tflite_model = quantized_converter.convert()
+
+    # The default input and output types should be float.
+    quantized_interpreter = Interpreter(model_content=quantized_tflite_model)
+    quantized_interpreter.allocate_tensors()
+    input_details = quantized_interpreter.get_input_details()
+    self.assertLen(input_details, 1)
+    self.assertEqual(np.float32, input_details[0]['dtype'])
+    self.assertTrue((input_details[0]['shape_signature'] == [-1, 33]).all())
+
+    # Ensure that the quantized weights tflite model is smaller.
+    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
+
+  @test_util.run_v2_only
+  def testMatMulCalibrateAndQuantize(self):
+    concrete_func, calibration_gen = self._getQuantizedModel()
+    float_converter = lite.TFLiteConverterV2.from_concrete_functions(
+        [concrete_func])
+    float_converter.experimental_new_converter = True
+    float_tflite_model = float_converter.convert()
+
+    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions(
+        [concrete_func])
+    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
+    quantized_converter.representative_dataset = calibration_gen
+    quantized_converter.experimental_new_converter = True
+    quantized_tflite_model = quantized_converter.convert()
+
+    # The default input and output types should be float.
+    quantized_interpreter = Interpreter(model_content=quantized_tflite_model)
+    quantized_interpreter.allocate_tensors()
+    input_details = quantized_interpreter.get_input_details()
+    self.assertLen(input_details, 1)
+    self.assertEqual(np.float32, input_details[0]['dtype'])
+    self.assertTrue((input_details[0]['shape_signature'] == [-1, 33]).all())
+
+    # Ensure that the quantized weights tflite model is smaller.
+    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
+
   def testBatchMatMul(self):
     input_data_1 = tf.constant(
         np.array(np.random.random_sample((1, 256, 256)), dtype=np.float32))
diff --git a/tensorflow/lite/python/lite_v2_test_util.py b/tensorflow/lite/python/lite_v2_test_util.py
index b3da6a89b86..d8f764711cd 100644
--- a/tensorflow/lite/python/lite_v2_test_util.py
+++ b/tensorflow/lite/python/lite_v2_test_util.py
@@ -53,10 +53,12 @@ class ModelTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     for idx, (shape_signature, final_shape) in enumerate(input_shapes):
       self.assertTrue(
           (input_details[idx]['shape_signature'] == shape_signature).all())
-      interpreter.resize_tensor_input(idx, final_shape, strict=True)
+      index = input_details[idx]['index']
+      interpreter.resize_tensor_input(index, final_shape, strict=True)
     interpreter.allocate_tensors()
 
     output_details = interpreter.get_output_details()
+    input_details = interpreter.get_input_details()
 
     for input_tensor, tensor_data in zip(input_details, input_data):
       interpreter.set_tensor(input_tensor['index'], tensor_data.numpy())
diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.cc b/tensorflow/lite/python/optimize/calibration_wrapper.cc
index 5a7a3ae2aa5..a115e401cfa 100644
--- a/tensorflow/lite/python/optimize/calibration_wrapper.cc
+++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc
@@ -218,16 +218,33 @@ PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) {
     return nullptr;
   }
 
+  std::vector<int> dims(PyArray_NDIM(array));
+  bool has_unknown_dims = false;
   for (int j = 0; j < PyArray_NDIM(array); j++) {
-    if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
+    // Ensure the calibration data input shape is the same as the model input
+    // shape unless the dimension is unknown.
+    if (tensor->dims_signature->size == tensor->dims->size &&
+        tensor->dims_signature->data[j] == -1) {
+      has_unknown_dims = true;
+    } else if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
       PyErr_Format(PyExc_ValueError,
                    "Cannot set tensor: Size mismatch, expected %d for dim "
                    "%d but found %ld",
                    tensor->dims->data[j], j, PyArray_SHAPE(array)[j]);
       return nullptr;
     }
+    dims[j] = PyArray_SHAPE(array)[j];
   }
 
+  // Resize the input tensor if there are unknown dimensions.
+  if (has_unknown_dims) {
+    // Does strict checking on the `ResizeInputTensor` call.
+    TFLITE_PY_CHECK(interpreter_->ResizeInputTensorStrict(index, dims));
+    TFLITE_PY_CHECK(interpreter_->AllocateTensors());
+  }
+
+  tensor = interpreter_->tensor(index);
+
   size_t size = PyArray_NBYTES(array);
   if (size != tensor->bytes) {
     PyErr_Format(PyExc_ValueError,
diff --git a/tensorflow/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py
index 26f518590df..71a1a31ac4c 100644
--- a/tensorflow/lite/testing/model_coverage/model_coverage_lib.py
+++ b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py
@@ -424,7 +424,8 @@ def test_frozen_graph_quant(filename,
 
   # Convert and load the quantized model.
   converter = _lite.TFLiteConverter.from_frozen_graph(filename, input_arrays,
-                                                       output_arrays)
+                                                       output_arrays,
+                                                       input_shapes)
   tflite_model_quant = _convert(
       converter, post_training_quantize=True, **kwargs)
 
diff --git a/tensorflow/lite/tools/optimize/model_utils.cc b/tensorflow/lite/tools/optimize/model_utils.cc
index 8c2e39e45fa..26dcff222bd 100644
--- a/tensorflow/lite/tools/optimize/model_utils.cc
+++ b/tensorflow/lite/tools/optimize/model_utils.cc
@@ -77,10 +77,14 @@ void MakeQuantizeOperator(ModelT* model, std::unique_ptr<OperatorT>* op,
 
 // Create a new TensorT object without quantization parameters.
 void MakeTensor(const string& name, const std::vector<int32_t>& shape,
+                const std::vector<int32_t>& shape_signature,
                 const TensorType& type, std::unique_ptr<TensorT>* tensor) {
   TensorT* tensor_raw = new TensorT;
   tensor_raw->name = name;
   tensor_raw->shape = shape;
+  if (!shape_signature.empty()) {
+    tensor_raw->shape_signature = shape_signature;
+  }
   tensor_raw->type = type;
 
   tensor->reset(tensor_raw);
@@ -89,10 +93,11 @@ void MakeTensor(const string& name, const std::vector<int32_t>& shape,
 // Create a new TensorT object with quantization parameters.
 void MakeTensorWithQuantParam(const string& name,
                               const std::vector<int32_t>& shape,
+                              const std::vector<int32_t>& shape_signature,
                               const TensorType& type, float scale,
                               int64_t zero_point,
                               std::unique_ptr<TensorT>* tensor) {
-  MakeTensor(name, shape, type, tensor);
+  MakeTensor(name, shape, shape_signature, type, tensor);
   (*tensor)->quantization = absl::make_unique<QuantizationParametersT>();
   (*tensor)->quantization->scale.push_back(scale);
   (*tensor)->quantization->zero_point.push_back(zero_point);
diff --git a/tensorflow/lite/tools/optimize/model_utils.h b/tensorflow/lite/tools/optimize/model_utils.h
index 6583d6a10db..f90e6b1a21d 100644
--- a/tensorflow/lite/tools/optimize/model_utils.h
+++ b/tensorflow/lite/tools/optimize/model_utils.h
@@ -34,11 +34,13 @@ void MakeQuantizeOperator(ModelT* model, std::unique_ptr<OperatorT>* op,
 
 // Create a new TensorT object without quantization parameters.
 void MakeTensor(const string& name, const std::vector<int32_t>& shape,
+                const std::vector<int32_t>& shape_signature,
                 const TensorType& type, std::unique_ptr<TensorT>* tensor);
 
 // Create a new TensorT object with quantization parameters.
 void MakeTensorWithQuantParam(const string& name,
                               const std::vector<int32_t>& shape,
+                              const std::vector<int32_t>& shape_signature,
                               const TensorType& type, float scale,
                               int64_t zero_point,
                               std::unique_ptr<TensorT>* tensor);
diff --git a/tensorflow/lite/tools/optimize/modify_model_interface.cc b/tensorflow/lite/tools/optimize/modify_model_interface.cc
index d173bb608aa..0d2441a9c58 100644
--- a/tensorflow/lite/tools/optimize/modify_model_interface.cc
+++ b/tensorflow/lite/tools/optimize/modify_model_interface.cc
@@ -383,9 +383,9 @@ void AddUint8Dequant(
   const std::pair<float, int32_t>& provided_quant_params =
       quant_params.at(string(tensor->name));
   utils::MakeTensorWithQuantParam(
-      added_tensor_name, tensor->shape, TensorType_UINT8,
-      provided_quant_params.first, provided_quant_params.second,
-      &leading_op_input);
+      added_tensor_name, tensor->shape, tensor->shape_signature,
+      TensorType_UINT8, provided_quant_params.first,
+      provided_quant_params.second, &leading_op_input);
   const int32_t leading_op_input_idx = subgraph->tensors.size();
   subgraph->tensors.push_back(std::move(leading_op_input));
 
@@ -423,9 +423,9 @@ void AddUint8Quant(
   const std::pair<float, int32_t>& provided_quant_params =
       quant_params.at(string(tensor->name));
   utils::MakeTensorWithQuantParam(
-      added_tensor_name, tensor->shape, TensorType_UINT8,
-      provided_quant_params.first, provided_quant_params.second,
-      &tailing_op_output);
+      added_tensor_name, tensor->shape, tensor->shape_signature,
+      TensorType_UINT8, provided_quant_params.first,
+      provided_quant_params.second, &tailing_op_output);
   const int32_t tailing_op_output_idx = subgraph->tensors.size();
   subgraph->tensors.push_back(std::move(tailing_op_output));
 
diff --git a/tensorflow/lite/tools/optimize/quantize_model.cc b/tensorflow/lite/tools/optimize/quantize_model.cc
index 93372c6d460..b7b99d9c393 100644
--- a/tensorflow/lite/tools/optimize/quantize_model.cc
+++ b/tensorflow/lite/tools/optimize/quantize_model.cc
@@ -141,8 +141,8 @@ int32_t SetInputType(ModelT* model, SubGraphT* subgraph,
     const string leading_op_name = tensor->name;
     const string new_name_original_input = tensor->name + "_int8";
     tensor->name = new_name_original_input;
-    utils::MakeTensor(leading_op_name, tensor->shape, input_type,
-                      &leading_op_input);
+    utils::MakeTensor(leading_op_name, tensor->shape, tensor->shape_signature,
+                      input_type, &leading_op_input);
   } else {
     // Get scale and zero point from the first tensor.
     const float scale = subgraph->tensors[tensor_idx]->quantization->scale[0];
@@ -156,9 +156,9 @@ int32_t SetInputType(ModelT* model, SubGraphT* subgraph,
     const string leading_op_name = tensor->name;
     const string new_name_original_input = tensor->name + "_int8";
     tensor->name = new_name_original_input;
-    utils::MakeTensorWithQuantParam(leading_op_name, tensor->shape,
-                                    input_type, scale, zero_point + 128,
-                                    &leading_op_input);
+    utils::MakeTensorWithQuantParam(
+        leading_op_name, tensor->shape, tensor->shape_signature, input_type,
+        scale, zero_point + 128, &leading_op_input);
   const int32_t leading_op_input_idx = subgraph->tensors.size();
   subgraph->tensors.push_back(std::move(leading_op_input));
 
@@ -193,8 +193,8 @@ int32_t SetOutputType(ModelT* model, SubGraphT* subgraph,
     const string tailing_op_name = tensor->name;
    const string new_name_original_output = tensor->name + "_int8";
     tensor->name = new_name_original_output;
-    utils::MakeTensor(tailing_op_name, tensor->shape, output_type,
-                      &tailing_op_output);
+    utils::MakeTensor(tailing_op_name, tensor->shape, tensor->shape_signature,
+                      output_type, &tailing_op_output);
   } else {
     // Get scale and zero point from the last tensor.
     const float scale = subgraph->tensors[tensor_idx]->quantization->scale[0];
@@ -208,9 +208,9 @@ int32_t SetOutputType(ModelT* model, SubGraphT* subgraph,
     const string tailing_op_name = tensor->name;
     const string new_name_original_output = tensor->name + "_int8";
     tensor->name = new_name_original_output;
-    utils::MakeTensorWithQuantParam(tailing_op_name, tensor->shape,
-                                    output_type, scale, zero_point + 128,
-                                    &tailing_op_output);
+    utils::MakeTensorWithQuantParam(
+        tailing_op_name, tensor->shape, tensor->shape_signature, output_type,
+        scale, zero_point + 128, &tailing_op_output);
   const int32_t tailing_op_output_idx = subgraph->tensors.size();
   subgraph->tensors.push_back(std::move(tailing_op_output));
 
@@ -340,8 +340,9 @@ TfLiteStatus ApplyConstraints(ModelT* model,
         std::unique_ptr<TensorT> additional_tensor;
         const string requant_tensor_name = input_tensor->name + "_requantized";
         utils::MakeTensorWithQuantParam(
-            requant_tensor_name, input_tensor->shape, TensorType_INT8,
-            output_scale, output_zp, &additional_tensor);
+            requant_tensor_name, input_tensor->shape,
+            input_tensor->shape_signature, TensorType_INT8, output_scale,
+            output_zp, &additional_tensor);
         const int32_t additional_tensor_idx = subgraph->tensors.size();
         subgraph->tensors.push_back(std::move(additional_tensor));
 
@@ -545,7 +546,7 @@ TfLiteStatus QuantizeOpInput(
           // operation since the preceding op may require a float output.
          std::unique_ptr<TensorT> op_output;
           utils::MakeTensor(tensor->name + "_int8", tensor->shape,
-                            TensorType_INT8, &op_output);
+                            tensor->shape_signature, TensorType_INT8, &op_output);
           op_output->quantization = absl::make_unique<QuantizationParametersT>();
           op_output->quantization->min.push_back(tensor->quantization->min[0]);
           op_output->quantization->max.push_back(tensor->quantization->max[0]);
@@ -573,7 +574,7 @@ TfLiteStatus QuantizeOpInput(
           // since this op is not quantizable.
           std::unique_ptr<TensorT> op_output;
           utils::MakeTensor(tensor->name + "_float", tensor->shape,
-                            TensorType_FLOAT32, &op_output);
+                            tensor->shape_signature, TensorType_FLOAT32, &op_output);
           const int32_t dequant_op_output_idx = subgraph->tensors.size();
           subgraph->tensors.push_back(std::move(op_output));
           std::unique_ptr<OperatorT> dequant_op;
diff --git a/tensorflow/lite/tools/optimize/quantize_weights.cc b/tensorflow/lite/tools/optimize/quantize_weights.cc
index 581819495b1..7e3853c645c 100644
--- a/tensorflow/lite/tools/optimize/quantize_weights.cc
+++ b/tensorflow/lite/tools/optimize/quantize_weights.cc
@@ -259,10 +259,14 @@ void MakeDequantizeOperator(ModelT* model, std::unique_ptr<OperatorT>* op,
 
 // Create a new TensorT object.
 void MakeTensor(const string& name, const std::vector<int32_t>& shape,
+                const std::vector<int32_t>& shape_signature,
                 std::unique_ptr<TensorT>* tensor) {
   TensorT* tensor_raw = new TensorT;
   tensor_raw->name = name;
   tensor_raw->shape = shape;
+  if (!shape_signature.empty()) {
+    tensor_raw->shape_signature = shape_signature;
+  }
   tensor->reset(tensor_raw);
 }
 
@@ -419,8 +423,8 @@ TfLiteStatus QuantizeWeightsInt8(flatbuffers::FlatBufferBuilder* builder,
     // Create a new tensor to be the output of the dequantize op.
     std::unique_ptr<TensorT> dequantize_output;
     const string dequant_name = tensor->name + "_dequantize";
-    utils::MakeTensor(dequant_name, tensor->shape, TensorType_FLOAT32,
-                      &dequantize_output);
+    utils::MakeTensor(dequant_name, tensor->shape, tensor->shape_signature,
+                      TensorType_FLOAT32, &dequantize_output);
     const int32_t dequantize_output_idx = subgraph->tensors.size();
     subgraph->tensors.push_back(std::move(dequantize_output));
 
@@ -503,8 +507,8 @@ TfLiteStatus QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder,
     // Create a new tensor to be the output of the dequantize op.
     std::unique_ptr<TensorT> dequantize_output;
     const string dequant_name = tensor->name + "_dequantize";
-    utils::MakeTensor(dequant_name, tensor->shape, TensorType_FLOAT32,
-                      &dequantize_output);
+    utils::MakeTensor(dequant_name, tensor->shape, tensor->shape_signature,
+                      TensorType_FLOAT32, &dequantize_output);
     const int32_t dequantize_output_idx = subgraph->tensors.size();
     subgraph->tensors.push_back(std::move(dequantize_output));
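
Note (reviewer addition, not part of the diff): taken together, these changes let post-training quantization run on models whose inputs carry unknown dimensions (-1 entries in shape_signature), provided the new MLIR-based converter is enabled; the calibration wrapper now resizes such inputs per calibration sample, and the quantization tooling propagates shape_signature onto every tensor it creates. Below is a minimal sketch of how the new path is exercised from the Python API; the model, shapes, and calibration generator are illustrative placeholders, not code taken from this change.

    import numpy as np
    import tensorflow as tf

    # Toy model with an unknown batch dimension in its input signature.
    @tf.function(
        input_signature=[tf.TensorSpec(shape=[None, 33], dtype=tf.float32)])
    def model(x):
      w = tf.constant(
          np.random.uniform(-1., 1., size=[33, 33]), dtype=tf.float32)
      return tf.matmul(x, w)

    # Calibration data with varying batch sizes; each sample exercises the
    # ResizeInputTensorStrict path added in calibration_wrapper.cc.
    def representative_dataset():
      for batch in (1, 4, 8):
        yield [np.random.uniform(-1, 1, size=(batch, 33)).astype(np.float32)]

    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [model.get_concrete_function()])
    converter.experimental_new_converter = True  # unknown dims require this
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset
    tflite_model = converter.convert()

As the new tests assert, the converted model keeps a shape_signature of [-1, 33] on its input, so it can still be resized at inference time after calibration-based quantization.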