Support for post-training quantization for Gather
PiperOrigin-RevId: 239422886
This commit is contained in:
parent
b116bb78b2
commit
9f835183f0
@ -74,6 +74,7 @@ tf_cc_test(
|
||||
"--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
|
||||
],
|
||||
data = [
|
||||
"//tensorflow/lite/tools/optimize:testdata/quantized_with_gather.bin",
|
||||
"//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
|
||||
"//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin",
|
||||
],
|
||||
|
@ -60,7 +60,7 @@ std::vector<ConsumerOpInfo> GetTensorConsumers(const ModelT* model,
|
||||
for (size_t i = 0; i < op->inputs.size(); ++i) {
|
||||
if (op->inputs[i] == tensor_idx) {
|
||||
consumer_ops.push_back(
|
||||
{op, static_cast<int>(op_idx), static_cast<int>(i)});
|
||||
{op, static_cast<int32_t>(op_idx), static_cast<int32_t>(i)});
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -95,6 +95,9 @@ std::vector<int32_t> GetWeightInputIndices(const BuiltinOperator& op_code) {
|
||||
} else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN) {
|
||||
// https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc
|
||||
return {1, 2, 4, 5};
|
||||
} else if (op_code == BuiltinOperator_GATHER) {
|
||||
// https://www.tensorflow.org/code/tensorflow/lite/kernels/gather.cc
|
||||
return {0};
|
||||
}
|
||||
return {};
|
||||
}
|
||||
@ -194,7 +197,7 @@ TfLiteStatus InsertQuantizableInputTensorsFromOperator(
|
||||
// Returns the index of the Dequantize op_code.
|
||||
// If a Dequantize op_code doesn't exist, adds it and returns its index.
|
||||
int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) {
|
||||
for (size_t i = 0; i < model->operator_codes.size(); ++i) {
|
||||
for (int i = 0; i < model->operator_codes.size(); ++i) {
|
||||
if (model->operator_codes[i]->builtin_code == BuiltinOperator_DEQUANTIZE) {
|
||||
return i;
|
||||
}
|
||||
@ -232,7 +235,7 @@ void MakeTensor(const string& name, const std::vector<int32_t>& shape,
|
||||
|
||||
// Updates operator code versions for the operators with INT8 inputs.
|
||||
void UpdateInt8OperatorVersions(ModelT* model) {
|
||||
for (size_t i = 0; i < model->operator_codes.size(); ++i) {
|
||||
for (int i = 0; i < model->operator_codes.size(); ++i) {
|
||||
const BuiltinOperator& op_code = model->operator_codes[i]->builtin_code;
|
||||
if (op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF ||
|
||||
op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
|
||||
@ -250,6 +253,57 @@ void UpdateInt8OperatorVersions(ModelT* model) {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true if the op in consumer_op_infos can pass through quantization.
|
||||
bool IsQuantizationPassThroughOps(
|
||||
const ModelT* model, const std::vector<ConsumerOpInfo>& consumer_op_infos) {
|
||||
if (consumer_op_infos.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
const OperatorT* consumer_op = consumer_op_infos.front().op;
|
||||
const BuiltinOperator op_code =
|
||||
model->operator_codes[consumer_op->opcode_index]->builtin_code;
|
||||
return op_code == BuiltinOperator_GATHER;
|
||||
}
|
||||
|
||||
// Copies quantization parameters from input to output and returns consumers of
|
||||
// the output tensor as a tuple with values:
|
||||
// - index of the output tensor
|
||||
// - pointer to the output tensor
|
||||
// - vector of consumers ops.
|
||||
std::tuple<int32_t, TensorT*, std::vector<ConsumerOpInfo>>
|
||||
PassQuantizationAndGetConsumers(
|
||||
const ModelT* model, const SubGraphT* subgraph,
|
||||
const std::vector<ConsumerOpInfo>& consumer_op_infos) {
|
||||
const OperatorT* op = consumer_op_infos.front().op;
|
||||
const BuiltinOperator op_code =
|
||||
model->operator_codes[op->opcode_index]->builtin_code;
|
||||
if (op->outputs.size() != 1) {
|
||||
LOG(ERROR)
|
||||
<< "An op that passes quantization has more than one quantized output";
|
||||
return std::make_tuple(-1, nullptr, std::vector<ConsumerOpInfo>());
|
||||
}
|
||||
const int32_t output_tensor_idx = op->outputs.front();
|
||||
const auto input_idx = GetWeightInputIndices(op_code);
|
||||
if (input_idx.size() != 1) {
|
||||
LOG(ERROR)
|
||||
<< "An op that passes quantization has more than one quantized input";
|
||||
return std::make_tuple(-1, nullptr, std::vector<ConsumerOpInfo>());
|
||||
}
|
||||
const int32_t input_tensor_idx = op->inputs[input_idx.front()];
|
||||
|
||||
// Propagate quantization params.
|
||||
const TensorT* input_tensor = subgraph->tensors[input_tensor_idx].get();
|
||||
TensorT* output_tensor = subgraph->tensors[output_tensor_idx].get();
|
||||
if (!output_tensor->quantization) {
|
||||
output_tensor->quantization = absl::make_unique<QuantizationParametersT>();
|
||||
}
|
||||
*output_tensor->quantization = *input_tensor->quantization;
|
||||
output_tensor->type = TensorType_INT8;
|
||||
return std::make_tuple(
|
||||
output_tensor_idx, output_tensor,
|
||||
GetTensorConsumers(model, subgraph, output_tensor_idx));
|
||||
}
|
||||
|
||||
TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
|
||||
const Model* input_model,
|
||||
bool use_hybrid_evaluation,
|
||||
@ -268,7 +322,7 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
|
||||
|
||||
std::vector<std::unique_ptr<OperatorT>> new_operators;
|
||||
std::unordered_map<int32_t, TensorT*> tensor_map;
|
||||
for (size_t i = 0; i < subgraph->operators.size(); ++i) {
|
||||
for (int i = 0; i < subgraph->operators.size(); ++i) {
|
||||
OperatorT* op = subgraph->operators[i].get();
|
||||
TF_LITE_ENSURE_STATUS(InsertQuantizableInputTensorsFromOperator(
|
||||
model.get(), op, weights_min_num_elements, &tensor_map));
|
||||
@ -285,10 +339,19 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
|
||||
|
||||
// Examine the tensor consumers to determine which require dequantize ops.
|
||||
for (const auto& tensor_pair : tensor_map) {
|
||||
const int32_t tensor_idx = tensor_pair.first;
|
||||
int32_t tensor_idx = tensor_pair.first;
|
||||
TensorT* tensor = tensor_pair.second;
|
||||
std::vector<ConsumerOpInfo> consumer_op_infos =
|
||||
GetTensorConsumers(model.get(), subgraph, tensor_idx);
|
||||
if (IsQuantizationPassThroughOps(model.get(), consumer_op_infos)) {
|
||||
std::tie(tensor_idx, tensor, consumer_op_infos) =
|
||||
PassQuantizationAndGetConsumers(model.get(), subgraph,
|
||||
consumer_op_infos);
|
||||
if (tensor_idx < 0) {
|
||||
// Error message is already logged by PassQuantizationAndGetConsumers.
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ConsumerOpInfo> dequant_op_infos; // Ops that need dequants.
|
||||
for (ConsumerOpInfo& consumer_op_info : consumer_op_infos) {
|
||||
@ -307,8 +370,18 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
|
||||
}
|
||||
}
|
||||
|
||||
// If no ops require dequant, we are done for this tensor.
|
||||
if (dequant_op_infos.empty()) {
|
||||
// Check that this tensor is an output tensor.
|
||||
int32_t output_index = -1;
|
||||
for (int32_t i = 0; i < subgraph->outputs.size(); ++i) {
|
||||
if (subgraph->outputs[i] == tensor_idx) {
|
||||
output_index = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If no ops require dequant and it is not output, we are done for this
|
||||
// tensor.
|
||||
if (dequant_op_infos.empty() && output_index < 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -328,12 +401,16 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
|
||||
|
||||
// Update the op_input of all the ops that need the created dequantize
|
||||
// operation.
|
||||
int32_t min_op_idx = 0;
|
||||
int32_t min_op_idx = subgraph->operators.size();
|
||||
for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) {
|
||||
dequant_op_info.op->inputs[dequant_op_info.op_input_idx] =
|
||||
dequantize_output_idx;
|
||||
min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx);
|
||||
}
|
||||
// Update output name.
|
||||
if (output_index >= 0) {
|
||||
subgraph->outputs[output_index] = dequantize_output_idx;
|
||||
}
|
||||
|
||||
// Insert the newly created Dequantize operation before the earliest
|
||||
// consumer, since TFLite requires operators to be topo-sorted.
|
||||
|
@ -48,6 +48,12 @@ std::unique_ptr<FlatBufferModel> ReadSharedWeightsTestModel() {
|
||||
return FlatBufferModel::BuildFromFile(model_path.c_str());
|
||||
}
|
||||
|
||||
std::unique_ptr<FlatBufferModel> ReadGatherTestModel() {
|
||||
auto model_path = tensorflow::io::JoinPath(*g_test_model_dir,
|
||||
internal::kQuantizedWithGather);
|
||||
return FlatBufferModel::BuildFromFile(model_path.c_str());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> GetAsVector(const flatbuffers::Vector<T>* vec) {
|
||||
return std::vector<T>(vec->begin(), vec->end());
|
||||
@ -67,6 +73,11 @@ class QuantizeWeightsTest : public testing::Test {
|
||||
model_ = input_model_->GetModel();
|
||||
}
|
||||
|
||||
void LoadGatherTestModel() {
|
||||
input_model_ = ReadGatherTestModel();
|
||||
model_ = input_model_->GetModel();
|
||||
}
|
||||
|
||||
std::unique_ptr<FlatBufferModel> input_model_;
|
||||
const Model* model_;
|
||||
|
||||
@ -334,6 +345,34 @@ TEST_F(QuantizeWeightsTest, SharedWeights_Dequantize) {
|
||||
EXPECT_EQ(num_conv_ops, 2);
|
||||
}
|
||||
|
||||
TEST_F(QuantizeWeightsTest, VerifyGatherQuantization) {
|
||||
LoadGatherTestModel();
|
||||
flatbuffers::FlatBufferBuilder builder;
|
||||
auto status = QuantizeWeights(&builder, model_, 0);
|
||||
EXPECT_EQ(status, kTfLiteOk);
|
||||
|
||||
const uint8_t* buffer = builder.GetBufferPointer();
|
||||
const Model* output_model = GetModel(buffer);
|
||||
ASSERT_TRUE(output_model);
|
||||
|
||||
ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
|
||||
for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
|
||||
++subgraph_idx) {
|
||||
const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
|
||||
for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
|
||||
const auto op = quantized_graph->operators()->Get(i);
|
||||
const uint32_t op_code_idx = op->opcode_index();
|
||||
const auto op_code =
|
||||
output_model->operator_codes()->Get(op_code_idx)->builtin_code();
|
||||
if (op_code == BuiltinOperator_GATHER) {
|
||||
uint32_t input_tensor_index = op->inputs()->Get(0);
|
||||
const auto weights_tensor =
|
||||
quantized_graph->tensors()->Get(input_tensor_index);
|
||||
EXPECT_EQ(weights_tensor->type(), TensorType_INT8);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
} // namespace optimize
|
||||
} // namespace tflite
|
||||
|
@ -32,8 +32,8 @@ const char* kSingleAvgPoolModelMinMinus5MaxPlus5 =
|
||||
"single_avg_pool_min_minus_5_max_plus_5.bin";
|
||||
|
||||
const char* kModelWithSharedWeights = "weight_shared_between_convs.bin";
|
||||
|
||||
const char* kMultiInputAddWithReshape = "multi_input_add_reshape.bin";
|
||||
const char* kQuantizedWithGather = "quantized_with_gather.bin";
|
||||
|
||||
const char* kConstInputAddModel = "add_with_const_input.bin";
|
||||
|
||||
|
@ -49,6 +49,9 @@ extern const char* kModelWithSharedWeights;
|
||||
// Test model with Add followed by a reshape. Model has 2 inputs for add.
|
||||
extern const char* kMultiInputAddWithReshape;
|
||||
|
||||
// Test gather operation with quantized input.
|
||||
extern const char* kQuantizedWithGather;
|
||||
|
||||
// Test model with a tf.constant input to tf.add. Model has 2 inputs one
|
||||
// constant and other placeholder.
|
||||
extern const char* kConstInputAddModel;
|
||||
|
@ -25,3 +25,6 @@ This directory contains test models for testing quantization.
|
||||
A floating point model with two convs that have a use the same weight tensor.
|
||||
* `multi_input_add_reshape.bin` \
|
||||
A floating point model with two inputs with an add followed by a reshape.
|
||||
* `quantized_with_gather.bin` \
|
||||
A floating point model with an input with a gather, modeling a situation
|
||||
of mapping categorical input to embeddings.
|
||||
|
BIN
tensorflow/lite/tools/optimize/testdata/quantized_with_gather.bin
vendored
Normal file
BIN
tensorflow/lite/tools/optimize/testdata/quantized_with_gather.bin
vendored
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user