Add support to quantize only specified operators in the quantization tool.

The operators are keyed by their first output tensor name.

PiperOrigin-RevId: 265096821
commit 9e36fea5ff (parent a148b74a28)
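For orientation, here is a minimal sketch of how a caller could use the new selective-quantization entry point introduced by this change. Only the QuantizeModel overload and its operator_names parameter come from this commit; the wrapper function name and the tensor name in the whitelist are hypothetical placeholders.

#include <string>
#include <unordered_set>

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/optimize/quantize_model.h"

// Sketch: quantize only the operators whose first output tensor is named in
// the whitelist. "conv2d_1/output" is a placeholder; real names come from the
// model being quantized.
TfLiteStatus SelectivelyQuantize(tflite::ModelT* model,
                                 flatbuffers::FlatBufferBuilder* builder,
                                 tflite::ErrorReporter* error_reporter) {
  const std::unordered_set<std::string> operator_names = {"conv2d_1/output"};
  return tflite::optimize::QuantizeModel(
      builder, model, tflite::TensorType_INT8, tflite::TensorType_INT8,
      /*allow_float=*/true, operator_names, error_reporter);
}

Operators not named in the set keep their float kernels (Dequantize/Quantize ops inserted by the tool itself are always allowed), which is what the new SkipUnspecifiedLayer test below exercises with an empty whitelist.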
@@ -167,6 +167,7 @@ cc_library(
+        ":operator_property",
         ":quantization_utils",
         "//tensorflow/lite:framework",
         "//tensorflow/lite:util",
         "//tensorflow/lite/core/api",
         "//tensorflow/lite/schema:schema_fbs",
         "@flatbuffers",
@@ -20,6 +20,8 @@ limitations under the License.
 #include <memory>
+#include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>

 #include "flatbuffers/flexbuffers.h"
@@ -36,6 +38,24 @@ namespace optimize {

 namespace {

+// Gets the operator property from the operator_property list and additionally
+// modifies the quantizable parameter based on the user's specified
+// operator_names.
+operator_property::OperatorProperty GetOperatorProperty(
+    const std::unordered_set<string>& operator_names, const BuiltinOperator& op,
+    const string& operator_name) {
+  operator_property::OperatorProperty property =
+      operator_property::GetOperatorProperty(op);
+  // The algorithm adds Dequantize and Quantize, so we don't require them to be
+  // in the operator_names.
+  if (op != BuiltinOperator_DEQUANTIZE && op != BuiltinOperator_QUANTIZE) {
+    property.quantizable =
+        property.quantizable &&
+        (operator_names.find(operator_name) != operator_names.end());
+  }
+  return property;
+}
+
 TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
                           const TensorT* weight_tensor, TensorT* bias_tensor,
                           bool is_per_channel, int channel_dim_index,
@@ -239,8 +259,8 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
       // TODO(suharshs): Add support for this case if it ever comes up.
       if (tensor->type == TensorType_FLOAT32 && output_type != tensor->type) {
         error_reporter->Report(
-            "Unsupported output type %s for output tensor %d of type %s.",
-            EnumNameTensorType(output_type), subgraph->outputs[i],
+            "Unsupported output type %s for output tensor '%s' of type %s.",
+            EnumNameTensorType(output_type), tensor->name.c_str(),
             EnumNameTensorType(tensor->type));
         return kTfLiteError;
       }
@@ -260,7 +280,9 @@ TfLiteStatus SetInputAndOutputTypes(ModelT* model, const TensorType& input_type,
 // outputs must have the same scale and zero point. The other ones with
 // constraints (averagepool, maxpool, gather, softmax, tanh etc.) are handled in
 // QuantizeWeightsAndInput.
-TfLiteStatus ApplyConstraints(ModelT* model, ErrorReporter* error_reporter) {
+TfLiteStatus ApplyConstraints(ModelT* model,
+                              const std::unordered_set<string>& operator_names,
+                              ErrorReporter* error_reporter) {
   for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
        subgraph_idx++) {
     SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
@@ -269,8 +291,8 @@ TfLiteStatus ApplyConstraints(ModelT* model, ErrorReporter* error_reporter) {
       OperatorT* op = subgraph->operators[op_idx].get();
       const BuiltinOperator op_code =
           model->operator_codes[op->opcode_index]->builtin_code;
-      operator_property::OperatorProperty property =
-          operator_property::GetOperatorProperty(op_code);
+      operator_property::OperatorProperty property = GetOperatorProperty(
+          operator_names, op_code, subgraph->tensors[op->outputs[0]]->name);
       if (!property.quantizable) {
         continue;
       }
@@ -546,8 +568,10 @@ TfLiteStatus QuantizeOpOutput(

 // Quantize inputs and weights.
 // Because of ops such as lstm, still need to do per op, instead of weights.
-TfLiteStatus QuantizeWeightsInputOutput(ModelT* model, bool allow_float,
-                                        ErrorReporter* error_reporter) {
+TfLiteStatus QuantizeWeightsInputOutput(
+    ModelT* model, bool allow_float,
+    const std::unordered_set<string>& operator_names,
+    ErrorReporter* error_reporter) {
   for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
        subgraph_idx++) {
     SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
@@ -555,8 +579,8 @@ TfLiteStatus QuantizeWeightsInputOutput(ModelT* model, bool allow_float,
       OperatorT* op = subgraph->operators[op_idx].get();
       const BuiltinOperator op_code =
           model->operator_codes[op->opcode_index]->builtin_code;
-      operator_property::OperatorProperty property =
-          operator_property::GetOperatorProperty(op_code);
+      operator_property::OperatorProperty property = GetOperatorProperty(
+          operator_names, op_code, subgraph->tensors[op->outputs[0]]->name);

       if (!property.quantizable && !allow_float) {
         error_reporter->Report("Quantization not yet supported for op: %s",
@@ -583,7 +607,9 @@ TfLiteStatus QuantizeWeightsInputOutput(ModelT* model, bool allow_float,
 }

 // Quantize bias.
-TfLiteStatus QuantizeBiases(ModelT* model, ErrorReporter* error_reporter) {
+TfLiteStatus QuantizeBiases(ModelT* model,
+                            const std::unordered_set<string>& operator_names,
+                            ErrorReporter* error_reporter) {
   for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
        subgraph_idx++) {
     SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
@@ -591,8 +617,8 @@ TfLiteStatus QuantizeBiases(ModelT* model, ErrorReporter* error_reporter) {
       OperatorT* op = subgraph->operators[op_idx].get();
       const BuiltinOperator op_code =
           model->operator_codes[op->opcode_index]->builtin_code;
-      operator_property::OperatorProperty property =
-          operator_property::GetOperatorProperty(op_code);
+      operator_property::OperatorProperty property = GetOperatorProperty(
+          operator_names, op_code, subgraph->tensors[op->outputs[0]]->name);
       if (!property.quantizable) {
         continue;
       }
@@ -639,17 +665,32 @@ TfLiteStatus QuantizeBiases(ModelT* model, ErrorReporter* error_reporter) {
     return kTfLiteOk;
 }

+std::unordered_set<string> GetAllOperatorOutputs(ModelT* model) {
+  std::unordered_set<string> operator_names;
+  for (int32_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
+       subgraph_idx++) {
+    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
+    for (int32_t tensor_idx = 0; tensor_idx < subgraph->tensors.size();
+         tensor_idx++) {
+      operator_names.insert(subgraph->tensors[tensor_idx]->name);
+    }
+  }
+  return operator_names;
+}
+
 }  // namespace

 // Assumes that the operators in the model have been topologically sorted.
 TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
                            ModelT* model, const TensorType& input_type,
                            const TensorType& output_type, bool allow_float,
+                           const std::unordered_set<string>& operator_names,
                            ErrorReporter* error_reporter) {
+  TF_LITE_ENSURE_STATUS(QuantizeWeightsInputOutput(
+      model, allow_float, operator_names, error_reporter));
   TF_LITE_ENSURE_STATUS(
-      QuantizeWeightsInputOutput(model, allow_float, error_reporter));
-  TF_LITE_ENSURE_STATUS(ApplyConstraints(model, error_reporter));
-  TF_LITE_ENSURE_STATUS(QuantizeBiases(model, error_reporter));
+      ApplyConstraints(model, operator_names, error_reporter));
+  TF_LITE_ENSURE_STATUS(QuantizeBiases(model, operator_names, error_reporter));
   utils::SetOperatorCodeVersion(model);
   TF_LITE_ENSURE_STATUS(
       SetInputAndOutputTypes(model, input_type, output_type, error_reporter));
@@ -661,6 +702,14 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
   return kTfLiteOk;
 }

+TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
+                           ModelT* model, const TensorType& input_type,
+                           const TensorType& output_type, bool allow_float,
+                           ErrorReporter* error_reporter) {
+  return QuantizeModel(builder, model, input_type, output_type, allow_float,
+                       GetAllOperatorOutputs(model), error_reporter);
+}
+
 TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
                            ModelT* model, const TensorType& input_type,
                            const TensorType& output_type,
@@ -16,11 +16,13 @@ limitations under the License.
 #define TENSORFLOW_LITE_TOOLS_OPTIMIZE_QUANTIZE_MODEL_H_

 #include <memory>
+#include <unordered_set>

 #include "tensorflow/lite/context.h"
 #include "tensorflow/lite/core/api/error_reporter.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/util.h"

 namespace tflite {
 namespace optimize {
@@ -53,6 +55,16 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
                            const TensorType& output_type, bool allow_float,
                            ErrorReporter* error_reporter);

+// Same as above, but enables only quantizing a whitelist of operations,
+// specified by their operator output name.
+//
+// Note: This is a private API, subject to change.
+TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
+                           ModelT* input_model, const TensorType& input_type,
+                           const TensorType& output_type, bool allow_float,
+                           const std::unordered_set<string>& operator_names,
+                           ErrorReporter* error_reporter);
+
 }  // namespace optimize
 }  // namespace tflite
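As a usage note for the whitelist overload declared above: the set is keyed by each operator's first output tensor name, so a caller can assemble it directly from the model, mirroring how GetAllOperatorOutputs walks the subgraphs. The helper below is an illustrative sketch only; its name and the skip-set are hypothetical and not part of this commit.

#include <string>
#include <unordered_set>

#include "tensorflow/lite/schema/schema_generated.h"

// Illustrative: build a whitelist covering every operator's first output
// tensor name except those listed in `skip`, for use with the whitelist
// QuantizeModel overload.
std::unordered_set<std::string> OutputsExcept(
    const tflite::ModelT& model, const std::unordered_set<std::string>& skip) {
  std::unordered_set<std::string> names;
  for (const auto& subgraph : model.subgraphs) {
    for (const auto& op : subgraph->operators) {
      if (op->outputs.empty()) continue;
      const std::string& name = subgraph->tensors[op->outputs[0]]->name;
      if (skip.count(name) == 0) names.insert(name);
    }
  }
  return names;
}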
@@ -98,6 +98,31 @@ TEST_F(QuantizeConvModelTest, QuantizationSucceeds) {
   ASSERT_TRUE(output_model);
 }

+TEST_F(QuantizeConvModelTest, SkipUnspecifiedLayer) {
+  auto status =
+      QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32,
+                    /*allow_float=*/true, {}, &error_reporter_);
+  EXPECT_EQ(status, kTfLiteOk);
+  ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
+  // The resulting model should be the same.
+  ASSERT_EQ(model_.subgraphs.size(), readonly_model_->subgraphs()->size());
+  for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
+       subgraph_idx++) {
+    const auto quantized_graph = model_.subgraphs[subgraph_idx].get();
+    const auto float_graph = readonly_model_->subgraphs()->Get(subgraph_idx);
+    ASSERT_EQ(quantized_graph->tensors.size(), float_graph->tensors()->size());
+    for (size_t i = 0; i < quantized_graph->tensors.size(); i++) {
+      const auto quant_tensor = quantized_graph->tensors[i].get();
+      const auto float_tensor = float_graph->tensors()->Get(i);
+      EXPECT_EQ(quant_tensor->buffer, float_tensor->buffer());
+      EXPECT_EQ(quant_tensor->is_variable, float_tensor->is_variable());
+      EXPECT_EQ(quant_tensor->shape, GetAsVector(float_tensor->shape()));
+      EXPECT_EQ(quant_tensor->name, float_tensor->name()->str());
+      EXPECT_EQ(quant_tensor->type, float_tensor->type());
+    }
+  }
+}
+
 TEST_F(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) {
   auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
                               TensorType_INT8, &error_reporter_);