Make quantized_input_stats optional inside the converter

This is to avoid using some default protobuf field values inside the converter
when the user doesn't specify the quantized_input_stats.

PiperOrigin-RevId: 313611828
Change-Id: I2da39069b67aac409fe8290709712572b17a1b6e
This commit is contained in:
Feng Liu 2020-05-28 10:20:11 -07:00 committed by TensorFlower Gardener
parent 5c3ac7400a
commit 61330290aa
7 changed files with 52 additions and 43 deletions

View File

@@ -55,8 +55,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
   std::vector<string> node_names;
   std::vector<string> node_dtypes;
   std::vector<std::vector<int>> node_shapes;
-  std::vector<double> node_mins;
-  std::vector<double> node_maxs;
+  std::vector<llvm::Optional<double>> node_mins;
+  std::vector<llvm::Optional<double>> node_maxs;
 
   // Populate quantization specs.
   TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(

View File

@@ -125,8 +125,8 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
   std::vector<string> node_names;
   std::vector<string> node_dtypes;
   std::vector<std::vector<int>> node_shapes;
-  std::vector<double> node_mins;
-  std::vector<double> node_maxs;
+  std::vector<llvm::Optional<double>> node_mins;
+  std::vector<llvm::Optional<double>> node_maxs;
 
   // Populate quantization specs.
   TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(

View File

@@ -177,14 +177,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
   return RegisterCustomBuiltinOps(extra_tf_opdefs);
 }
 
-Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
-                                 const toco::TocoFlags& toco_flags,
-                                 mlir::TFL::QuantizationSpecs* quant_specs,
-                                 std::vector<string>* node_names,
-                                 std::vector<string>* node_dtypes,
-                                 std::vector<std::vector<int>>* node_shapes,
-                                 std::vector<double>* node_mins,
-                                 std::vector<double>* node_maxs) {
+Status PopulateQuantizationSpecs(
+    const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
+    mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
+    std::vector<string>* node_dtypes,
+    std::vector<std::vector<int>>* node_shapes,
+    std::vector<llvm::Optional<double>>* node_mins,
+    std::vector<llvm::Optional<double>>* node_maxs) {
   quant_specs->inference_input_type =
       ConvertIODataTypeToDataType(toco_flags.inference_input_type());
   tensorflow::DataType inference_type =
@@ -211,11 +210,16 @@ Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
                                             flag.shape().dims().end()));
     // Currently, only UINT8 and INT8 require inputs stats
     if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
-      TF_ASSIGN_OR_RETURN(
-          auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(),
-                                           inference_type));
-      node_mins->push_back(min_max.first);
-      node_maxs->push_back(min_max.second);
+      if (flag.has_mean_value() && flag.has_std_value()) {
+        TF_ASSIGN_OR_RETURN(
+            auto min_max, InputStatsToMinMax(flag.mean_value(),
+                                             flag.std_value(), inference_type));
+        node_mins->push_back(min_max.first);
+        node_maxs->push_back(min_max.second);
+      } else {
+        node_mins->push_back(llvm::None);
+        node_maxs->push_back(llvm::None);
+      }
     }
   }

View File

@@ -34,14 +34,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags);
 
 // Populate quantization specs (or not) given user specified ranges for each
 // input arrays.
-Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
-                                 const toco::TocoFlags& toco_flags,
-                                 mlir::TFL::QuantizationSpecs* quant_specs,
-                                 std::vector<string>* node_names,
-                                 std::vector<string>* node_dtypes,
-                                 std::vector<std::vector<int>>* node_shapes,
-                                 std::vector<double>* node_mins,
-                                 std::vector<double>* node_maxs);
+Status PopulateQuantizationSpecs(
+    const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
+    mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
+    std::vector<string>* node_dtypes,
+    std::vector<std::vector<int>>* node_shapes,
+    std::vector<llvm::Optional<double>>* node_mins,
+    std::vector<llvm::Optional<double>>* node_maxs);
 
 // Convert imported MLIR file to TfLite flatbuffer.
 // This will also run relevant passes as well.

View File

@@ -45,7 +45,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
                               absl::string_view inference_type,
                               QuantizationSpecs* quant_specs) {
   std::vector<std::string> input_nodes = absl::StrSplit(node_names, ',');
-  std::vector<double> node_mins;
+  std::vector<llvm::Optional<double>> node_mins;
   if (!min_values.empty()) {
     std::vector<std::string> node_mins_str = absl::StrSplit(min_values, ',');
     for (int i = 0; i < node_mins_str.size(); i++) {
@@ -57,7 +57,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
     }
   }
 
-  std::vector<double> node_maxs;
+  std::vector<llvm::Optional<double>> node_maxs;
   if (!max_values.empty()) {
     std::vector<std::string> node_maxs_str = absl::StrSplit(max_values, ',');
     for (int i = 0; i < node_maxs_str.size(); i++) {
@@ -79,11 +79,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
                                 quant_specs);
 }
 
-bool GetInputNodeQuantSpecs(const std::vector<std::string>& node_names,
-                            const std::vector<double>& node_mins,
-                            const std::vector<double>& node_maxs,
-                            tensorflow::DataType inference_type,
-                            QuantizationSpecs* quant_specs) {
+bool GetInputNodeQuantSpecs(
+    const std::vector<std::string>& node_names,
+    const std::vector<llvm::Optional<double>>& node_mins,
+    const std::vector<llvm::Optional<double>>& node_maxs,
+    tensorflow::DataType inference_type, QuantizationSpecs* quant_specs) {
   quant_specs->inference_type = inference_type;
 
   // If min/max are not specified, just return;

View File

@@ -19,6 +19,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_
 #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_
 
+#include <optional>
 #include <string>
 #include <vector>
@@ -69,7 +70,8 @@ struct QuantizationSpecs {
   // arguments. They are only used when `weight_quantization` is set to false,
   // and the model is required to have quantization parameters, either from
   // quantization aware training or calibration, for the remaining tensors.
-  std::vector<std::pair<double, double>> input_ranges;
+  std::vector<std::pair<llvm::Optional<double>, llvm::Optional<double>>>
+      input_ranges;
 
   // The default ranges can be used when a tensor doesn't have quantization
   // parameters and couldn't be quantized. Used only for latency tests.
@@ -130,11 +132,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
 // Gets the quantization specification for input arrays. The array names are not
 // stored in the spec, and will be matched by position. The min/max will be
 // ignored if the inference_type isn't a quantized type. Returns true if failed.
-bool GetInputNodeQuantSpecs(const std::vector<std::string>& node_names,
-                            const std::vector<double>& node_mins,
-                            const std::vector<double>& node_maxs,
-                            tensorflow::DataType inference_type,
-                            QuantizationSpecs* quant_specs);
+bool GetInputNodeQuantSpecs(
+    const std::vector<std::string>& node_names,
+    const std::vector<llvm::Optional<double>>& node_mins,
+    const std::vector<llvm::Optional<double>>& node_maxs,
+    tensorflow::DataType inference_type, QuantizationSpecs* quant_specs);
 
 } // namespace TFL
 } // namespace mlir

View File

@@ -109,8 +109,8 @@ class PrepareQuantizePass
   // Get the min and max values from the quantization specification for the
   // current function function and argument index. Uses default values if
   // the function is specified in the `quantize_whitelist`.
-  std::pair<double, double> GetMinMaxValuesForArgument(
-      llvm::StringRef func_name, int index) {
+  std::pair<llvm::Optional<double>, llvm::Optional<double>>
+  GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) {
     if (func_name == quant_specs_.target_func) {
       return quant_specs_.input_ranges[index];
     } else {
@@ -160,10 +160,14 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
     }
 
     auto min_max = GetMinMaxValuesForArgument(func_name, i);
+    // The input min/max or mean/std are not specified, then skip.
+    if (!min_max.first.hasValue() || !min_max.second.hasValue()) return;
+
     TypeAttr params = quant::GetQuantizedTypeAttr(
-        builder, input_type, builder.getF64FloatAttr(min_max.first),
-        builder.getF64FloatAttr(min_max.second), /*quant_dim=*/-1, num_bits,
-        narrow_range, is_signed);
+        builder, input_type,
+        builder.getF64FloatAttr(min_max.first.getValue()),
+        builder.getF64FloatAttr(min_max.second.getValue()),
+        /*quant_dim=*/-1, num_bits, narrow_range, is_signed);
     builder.setInsertionPoint(block, insertion_point);
     auto q_op =
         builder.create<quant::QuantizeCastOp>(loc, params.getValue(), arg);