diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
index a07b7b8dd1d..8a2faebcbe6 100644
--- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
@@ -55,8 +55,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
   std::vector<string> node_names;
   std::vector<string> node_dtypes;
   std::vector<std::vector<int>> node_shapes;
-  std::vector<double> node_mins;
-  std::vector<double> node_maxs;
+  std::vector<llvm::Optional<double>> node_mins;
+  std::vector<llvm::Optional<double>> node_maxs;
 
   // Populate quantization specs.
   TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
index 51fcbb97360..ab80746f8b7 100644
--- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
@@ -125,8 +125,8 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
   std::vector<string> node_names;
   std::vector<string> node_dtypes;
   std::vector<std::vector<int>> node_shapes;
-  std::vector<double> node_mins;
-  std::vector<double> node_maxs;
+  std::vector<llvm::Optional<double>> node_mins;
+  std::vector<llvm::Optional<double>> node_maxs;
 
   // Populate quantization specs.
   TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
index a1401323e89..8f2c8bc362c 100644
--- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
+++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
@@ -177,14 +177,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
   return RegisterCustomBuiltinOps(extra_tf_opdefs);
 }
 
-Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
-                                 const toco::TocoFlags& toco_flags,
-                                 mlir::TFL::QuantizationSpecs* quant_specs,
-                                 std::vector<string>* node_names,
-                                 std::vector<string>* node_dtypes,
-                                 std::vector<std::vector<int>>* node_shapes,
-                                 std::vector<double>* node_mins,
-                                 std::vector<double>* node_maxs) {
+Status PopulateQuantizationSpecs(
+    const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
+    mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
+    std::vector<string>* node_dtypes,
+    std::vector<std::vector<int>>* node_shapes,
+    std::vector<llvm::Optional<double>>* node_mins,
+    std::vector<llvm::Optional<double>>* node_maxs) {
   quant_specs->inference_input_type =
       ConvertIODataTypeToDataType(toco_flags.inference_input_type());
   tensorflow::DataType inference_type =
@@ -211,11 +210,16 @@ Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
                                        flag.shape().dims().end()));
     // Currently, only UINT8 and INT8 require inputs stats
     if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
-      TF_ASSIGN_OR_RETURN(
-          auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(),
-                                           inference_type));
-      node_mins->push_back(min_max.first);
-      node_maxs->push_back(min_max.second);
+      if (flag.has_mean_value() && flag.has_std_value()) {
+        TF_ASSIGN_OR_RETURN(
+            auto min_max, InputStatsToMinMax(flag.mean_value(),
+                                             flag.std_value(), inference_type));
+        node_mins->push_back(min_max.first);
+        node_maxs->push_back(min_max.second);
+      } else {
+        node_mins->push_back(llvm::None);
+        node_maxs->push_back(llvm::None);
+      }
     }
   }
 
diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h
index 3ea36e5eb1d..87e73912a46 100644
--- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h
+++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h
@@ -34,14 +34,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags);
 
 // Populate quantization specs (or not) given user specified ranges for each
 // input arrays.
-Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
-                                 const toco::TocoFlags& toco_flags,
-                                 mlir::TFL::QuantizationSpecs* quant_specs,
-                                 std::vector<string>* node_names,
-                                 std::vector<string>* node_dtypes,
-                                 std::vector<std::vector<int>>* node_shapes,
-                                 std::vector<double>* node_mins,
-                                 std::vector<double>* node_maxs);
+Status PopulateQuantizationSpecs(
+    const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
+    mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
+    std::vector<string>* node_dtypes,
+    std::vector<std::vector<int>>* node_shapes,
+    std::vector<llvm::Optional<double>>* node_mins,
+    std::vector<llvm::Optional<double>>* node_maxs);
 
 // Convert imported MLIR file to TfLite flatbuffer.
 // This will also run relevant passes as well.
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc
index 6b897bd5608..3edd9c36760 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc
@@ -45,7 +45,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
                               absl::string_view inference_type,
                               QuantizationSpecs* quant_specs) {
   std::vector<std::string> input_nodes = absl::StrSplit(node_names, ',');
-  std::vector<double> node_mins;
+  std::vector<llvm::Optional<double>> node_mins;
   if (!min_values.empty()) {
     std::vector<std::string> node_mins_str = absl::StrSplit(min_values, ',');
     for (int i = 0; i < node_mins_str.size(); i++) {
@@ -57,7 +57,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
     }
   }
 
-  std::vector<double> node_maxs;
+  std::vector<llvm::Optional<double>> node_maxs;
   if (!max_values.empty()) {
     std::vector<std::string> node_maxs_str = absl::StrSplit(max_values, ',');
     for (int i = 0; i < node_maxs_str.size(); i++) {
@@ -79,11 +79,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
                                 quant_specs);
 }
 
-bool GetInputNodeQuantSpecs(const std::vector<std::string>& node_names,
-                            const std::vector<double>& node_mins,
-                            const std::vector<double>& node_maxs,
-                            tensorflow::DataType inference_type,
-                            QuantizationSpecs* quant_specs) {
+bool GetInputNodeQuantSpecs(
+    const std::vector<std::string>& node_names,
+    const std::vector<llvm::Optional<double>>& node_mins,
+    const std::vector<llvm::Optional<double>>& node_maxs,
+    tensorflow::DataType inference_type, QuantizationSpecs* quant_specs) {
   quant_specs->inference_type = inference_type;
 
   // If min/max are not specified, just return;
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h
index 2ffba579548..a4046553d17 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h
@@ -19,6 +19,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_
 #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_
 
+#include <optional>
 #include <string>
 #include <vector>
 
@@ -69,7 +70,8 @@ struct QuantizationSpecs {
   // arguments. They are only used when `weight_quantization` is set to false,
   // and the model is required to have quantization parameters, either from
   // quantization aware training or calibration, for the remaining tensors.
-  std::vector<std::pair<double, double>> input_ranges;
+  std::vector<std::pair<llvm::Optional<double>, llvm::Optional<double>>>
+      input_ranges;
 
   // The default ranges can be used when a tensor doesn't have quantization
   // parameters and couldn't be quantized. Used only for latency tests.
@@ -130,11 +132,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
 // Gets the quantization specification for input arrays. The array names are
 // not stored in the spec, and will be matched by position. The min/max will be
 // ignored if the inference_type isn't a quantized type. Returns true if failed.
-bool GetInputNodeQuantSpecs(const std::vector<std::string>& node_names,
-                            const std::vector<double>& node_mins,
-                            const std::vector<double>& node_maxs,
-                            tensorflow::DataType inference_type,
-                            QuantizationSpecs* quant_specs);
+bool GetInputNodeQuantSpecs(
+    const std::vector<std::string>& node_names,
+    const std::vector<llvm::Optional<double>>& node_mins,
+    const std::vector<llvm::Optional<double>>& node_maxs,
+    tensorflow::DataType inference_type, QuantizationSpecs* quant_specs);
 
 }  // namespace TFL
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
index 87cae3dd957..702808ac892 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
@@ -109,8 +109,8 @@ class PrepareQuantizePass
   // Get the min and max values from the quantization specification for the
   // current function and argument index. Uses default values if
   // the function is specified in the `quantize_whitelist`.
-  std::pair<double, double> GetMinMaxValuesForArgument(
-      llvm::StringRef func_name, int index) {
+  std::pair<llvm::Optional<double>, llvm::Optional<double>>
+  GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) {
    if (func_name == quant_specs_.target_func) {
      return quant_specs_.input_ranges[index];
    } else {
@@ -160,10 +160,14 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
     }
 
     auto min_max = GetMinMaxValuesForArgument(func_name, i);
+    // If the input min/max or mean/std are not specified, skip this input.
+    if (!min_max.first.hasValue() || !min_max.second.hasValue()) return;
+
     TypeAttr params = quant::GetQuantizedTypeAttr(
-        builder, input_type, builder.getF64FloatAttr(min_max.first),
-        builder.getF64FloatAttr(min_max.second), /*quant_dim=*/-1, num_bits,
-        narrow_range, is_signed);
+        builder, input_type,
+        builder.getF64FloatAttr(min_max.first.getValue()),
+        builder.getF64FloatAttr(min_max.second.getValue()),
+        /*quant_dim=*/-1, num_bits, narrow_range, is_signed);
     builder.setInsertionPoint(block, insertion_point);
     auto q_op =
         builder.create<quant::QuantizeCastOp>(loc, params.getValue(), arg);
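
Note: the snippet below is an illustrative sketch and is not part of the patch. It shows, with hypothetical helper and function names, the llvm::Optional-based pattern the patch adopts: inputs whose mean/std stats were not supplied carry llvm::None for their min/max and are skipped, instead of being quantized with placeholder bounds.

// Illustrative sketch only (hypothetical names); mirrors the optional
// min/max handling introduced above using llvm::Optional<double>.
#include <utility>
#include <vector>

#include "llvm/ADT/Optional.h"

using InputRange = std::pair<llvm::Optional<double>, llvm::Optional<double>>;

// Returns true only when both bounds of the range were provided by the user.
static bool HasFullRange(const InputRange& range) {
  return range.first.hasValue() && range.second.hasValue();
}

static void ProcessInputRanges(const std::vector<InputRange>& input_ranges) {
  for (const InputRange& range : input_ranges) {
    // Entries built from missing mean/std stats hold llvm::None and are
    // skipped, matching the early return added in prepare_quantize.cc.
    if (!HasFullRange(range)) continue;
    const double min = range.first.getValue();
    const double max = range.second.getValue();
    // ... use `min` and `max` to build quantization parameters ...
    (void)min;
    (void)max;
  }
}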