Support for post-training quantization for Gather
PiperOrigin-RevId: 239422886
Parent: b116bb78b2
Commit: 9f835183f0
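In practice, this change lets a model whose Gather reads from an embedding table be shrunk with the existing weight-quantization entry point. A minimal caller-side sketch (the QuantizeWeights signature matches the call in the new test below; the wrapper name and usage are illustrative, not part of this commit):

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/tools/optimize/quantize_weights.h"

// Quantizes the weights of a loaded TFLite model (now including the params
// tensor feeding a Gather) into `builder`. Passing 0 for
// weights_min_num_elements means even small tensors are considered, matching
// the new VerifyGatherQuantization test below.
TfLiteStatus QuantizeGatherModel(const tflite::Model* input_model,
                                 flatbuffers::FlatBufferBuilder* builder) {
  return tflite::optimize::QuantizeWeights(builder, input_model,
                                           /*weights_min_num_elements=*/0);
}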
@@ -74,6 +74,7 @@ tf_cc_test(
         "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
     ],
     data = [
+        "//tensorflow/lite/tools/optimize:testdata/quantized_with_gather.bin",
         "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
         "//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin",
     ],
@@ -60,7 +60,7 @@ std::vector<ConsumerOpInfo> GetTensorConsumers(const ModelT* model,
     for (size_t i = 0; i < op->inputs.size(); ++i) {
       if (op->inputs[i] == tensor_idx) {
         consumer_ops.push_back(
-            {op, static_cast<int>(op_idx), static_cast<int>(i)});
+            {op, static_cast<int32_t>(op_idx), static_cast<int32_t>(i)});
       }
     }
   }
@@ -95,6 +95,9 @@ std::vector<int32_t> GetWeightInputIndices(const BuiltinOperator& op_code) {
   } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN) {
     // https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc
     return {1, 2, 4, 5};
+  } else if (op_code == BuiltinOperator_GATHER) {
+    // https://www.tensorflow.org/code/tensorflow/lite/kernels/gather.cc
+    return {0};
   }
   return {};
 }
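A note on the {0} above: for Gather, input 0 is the params (weight) tensor and input 1 is the int32 indices tensor, so only index 0 is reported as quantizable. A toy sketch of consuming that answer (FakeOp is a stand-in for the generated OperatorT type, purely for illustration):

#include <cstdint>
#include <vector>

// Stand-in for the generated OperatorT type; illustration only.
struct FakeOp {
  std::vector<int32_t> inputs;  // tensor indices into the subgraph
};

// Returns the subgraph tensor index of a Gather op's quantizable weight.
int32_t QuantizableTensorOfGather(const FakeOp& op) {
  const int32_t weight_input = 0;  // mirrors GetWeightInputIndices(GATHER)
  return op.inputs[weight_input];  // op.inputs[1] (indices) is never touched
}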
@@ -194,7 +197,7 @@ TfLiteStatus InsertQuantizableInputTensorsFromOperator(
 // Returns the index of the Dequantize op_code.
 // If a Dequantize op_code doesn't exist, adds it and returns its index.
 int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) {
-  for (size_t i = 0; i < model->operator_codes.size(); ++i) {
+  for (int i = 0; i < model->operator_codes.size(); ++i) {
     if (model->operator_codes[i]->builtin_code == BuiltinOperator_DEQUANTIZE) {
       return i;
     }
@ -232,7 +235,7 @@ void MakeTensor(const string& name, const std::vector<int32_t>& shape,
|
|||||||
|
|
||||||
// Updates operator code versions for the operators with INT8 inputs.
|
// Updates operator code versions for the operators with INT8 inputs.
|
||||||
void UpdateInt8OperatorVersions(ModelT* model) {
|
void UpdateInt8OperatorVersions(ModelT* model) {
|
||||||
for (size_t i = 0; i < model->operator_codes.size(); ++i) {
|
for (int i = 0; i < model->operator_codes.size(); ++i) {
|
||||||
const BuiltinOperator& op_code = model->operator_codes[i]->builtin_code;
|
const BuiltinOperator& op_code = model->operator_codes[i]->builtin_code;
|
||||||
if (op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF ||
|
if (op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF ||
|
||||||
op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
|
op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
|
||||||
@@ -250,6 +253,57 @@ void UpdateInt8OperatorVersions(ModelT* model) {
   }
 }
 
+// Returns true if the op in consumer_op_infos can pass through quantization.
+bool IsQuantizationPassThroughOps(
+    const ModelT* model, const std::vector<ConsumerOpInfo>& consumer_op_infos) {
+  if (consumer_op_infos.size() != 1) {
+    return false;
+  }
+  const OperatorT* consumer_op = consumer_op_infos.front().op;
+  const BuiltinOperator op_code =
+      model->operator_codes[consumer_op->opcode_index]->builtin_code;
+  return op_code == BuiltinOperator_GATHER;
+}
+
+// Copies quantization parameters from input to output and returns consumers of
+// the output tensor as a tuple with values:
+// - index of the output tensor
+// - pointer to the output tensor
+// - vector of consumer ops.
+std::tuple<int32_t, TensorT*, std::vector<ConsumerOpInfo>>
+PassQuantizationAndGetConsumers(
+    const ModelT* model, const SubGraphT* subgraph,
+    const std::vector<ConsumerOpInfo>& consumer_op_infos) {
+  const OperatorT* op = consumer_op_infos.front().op;
+  const BuiltinOperator op_code =
+      model->operator_codes[op->opcode_index]->builtin_code;
+  if (op->outputs.size() != 1) {
+    LOG(ERROR)
+        << "An op that passes quantization has more than one quantized output";
+    return std::make_tuple(-1, nullptr, std::vector<ConsumerOpInfo>());
+  }
+  const int32_t output_tensor_idx = op->outputs.front();
+  const auto input_idx = GetWeightInputIndices(op_code);
+  if (input_idx.size() != 1) {
+    LOG(ERROR)
+        << "An op that passes quantization has more than one quantized input";
+    return std::make_tuple(-1, nullptr, std::vector<ConsumerOpInfo>());
+  }
+  const int32_t input_tensor_idx = op->inputs[input_idx.front()];
+
+  // Propagate quantization params.
+  const TensorT* input_tensor = subgraph->tensors[input_tensor_idx].get();
+  TensorT* output_tensor = subgraph->tensors[output_tensor_idx].get();
+  if (!output_tensor->quantization) {
+    output_tensor->quantization = absl::make_unique<QuantizationParametersT>();
+  }
+  *output_tensor->quantization = *input_tensor->quantization;
+  output_tensor->type = TensorType_INT8;
+  return std::make_tuple(
+      output_tensor_idx, output_tensor,
+      GetTensorConsumers(model, subgraph, output_tensor_idx));
+}
+
 TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
                                      const Model* input_model,
                                      bool use_hybrid_evaluation,
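The reason Gather can "pass through" quantization: it only copies rows out of its input table and performs no arithmetic, so its output is valid int8 data under the input's scale and zero point. A self-contained sketch of that invariant (simplified types and shapes assumed; this is not the commit's code):

#include <cstdint>
#include <vector>

// Simplified quantized tensor: real_value = scale * (q - zero_point).
struct QuantizedTensor {
  std::vector<int8_t> data;  // row-major [rows, row_size]
  float scale;
  int32_t zero_point;
};

// Gathering rows from a quantized table reorders values without arithmetic,
// so the output simply reuses the input's quantization parameters.
QuantizedTensor Gather(const QuantizedTensor& params, int row_size,
                       const std::vector<int32_t>& indices) {
  QuantizedTensor out{{}, params.scale, params.zero_point};  // params reused
  for (int32_t idx : indices) {
    out.data.insert(out.data.end(), params.data.begin() + idx * row_size,
                    params.data.begin() + (idx + 1) * row_size);
  }
  return out;
}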
@@ -268,7 +322,7 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
 
     std::vector<std::unique_ptr<OperatorT>> new_operators;
     std::unordered_map<int32_t, TensorT*> tensor_map;
-    for (size_t i = 0; i < subgraph->operators.size(); ++i) {
+    for (int i = 0; i < subgraph->operators.size(); ++i) {
       OperatorT* op = subgraph->operators[i].get();
       TF_LITE_ENSURE_STATUS(InsertQuantizableInputTensorsFromOperator(
           model.get(), op, weights_min_num_elements, &tensor_map));
@@ -285,10 +339,19 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
 
     // Examine the tensor consumers to determine which require dequantize ops.
     for (const auto& tensor_pair : tensor_map) {
-      const int32_t tensor_idx = tensor_pair.first;
+      int32_t tensor_idx = tensor_pair.first;
       TensorT* tensor = tensor_pair.second;
       std::vector<ConsumerOpInfo> consumer_op_infos =
           GetTensorConsumers(model.get(), subgraph, tensor_idx);
+      if (IsQuantizationPassThroughOps(model.get(), consumer_op_infos)) {
+        std::tie(tensor_idx, tensor, consumer_op_infos) =
+            PassQuantizationAndGetConsumers(model.get(), subgraph,
+                                            consumer_op_infos);
+        if (tensor_idx < 0) {
+          // Error message is already logged by PassQuantizationAndGetConsumers.
+          return kTfLiteError;
+        }
+      }
+
       std::vector<ConsumerOpInfo> dequant_op_infos;  // Ops that need dequants.
       for (ConsumerOpInfo& consumer_op_info : consumer_op_infos) {
@@ -307,8 +370,18 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
         }
       }
 
-      // If no ops require dequant, we are done for this tensor.
-      if (dequant_op_infos.empty()) {
+      // Check whether this tensor is a subgraph output.
+      int32_t output_index = -1;
+      for (int32_t i = 0; i < subgraph->outputs.size(); ++i) {
+        if (subgraph->outputs[i] == tensor_idx) {
+          output_index = i;
+          break;
+        }
+      }
+
+      // If no ops require dequant and it is not an output, we are done for
+      // this tensor.
+      if (dequant_op_infos.empty() && output_index < 0) {
         continue;
       }
 
@@ -328,12 +401,16 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
 
       // Update the op_input of all the ops that need the created dequantize
       // operation.
-      int32_t min_op_idx = 0;
+      int32_t min_op_idx = subgraph->operators.size();
       for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) {
         dequant_op_info.op->inputs[dequant_op_info.op_input_idx] =
             dequantize_output_idx;
         min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx);
       }
+      // Point the subgraph output at the dequantized tensor.
+      if (output_index >= 0) {
+        subgraph->outputs[output_index] = dequantize_output_idx;
+      }
+
       // Insert the newly created Dequantize operation before the earliest
       // consumer, since TFLite requires operators to be topo-sorted.
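The two hunks above also cover the case where the quantized tensor is itself a subgraph output: the output slot is redirected to the new Dequantize op's float result, so the model's external interface stays float. A minimal sketch of that redirection (container and index names assumed, not the commit's code):

#include <cstdint>
#include <vector>

// Rewires any subgraph output that pointed at the now-int8 tensor to the
// float tensor produced by the newly inserted Dequantize op.
void RedirectOutputToDequantize(std::vector<int32_t>& subgraph_outputs,
                                int32_t quantized_tensor_idx,
                                int32_t dequantize_output_idx) {
  for (int32_t& out : subgraph_outputs) {
    if (out == quantized_tensor_idx) {
      out = dequantize_output_idx;  // callers still receive float outputs
    }
  }
}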
@@ -48,6 +48,12 @@ std::unique_ptr<FlatBufferModel> ReadSharedWeightsTestModel() {
   return FlatBufferModel::BuildFromFile(model_path.c_str());
 }
 
+std::unique_ptr<FlatBufferModel> ReadGatherTestModel() {
+  auto model_path = tensorflow::io::JoinPath(*g_test_model_dir,
+                                             internal::kQuantizedWithGather);
+  return FlatBufferModel::BuildFromFile(model_path.c_str());
+}
+
 template <typename T>
 std::vector<T> GetAsVector(const flatbuffers::Vector<T>* vec) {
   return std::vector<T>(vec->begin(), vec->end());
@@ -67,6 +73,11 @@ class QuantizeWeightsTest : public testing::Test {
     model_ = input_model_->GetModel();
   }
 
+  void LoadGatherTestModel() {
+    input_model_ = ReadGatherTestModel();
+    model_ = input_model_->GetModel();
+  }
+
   std::unique_ptr<FlatBufferModel> input_model_;
   const Model* model_;
 
@@ -334,6 +345,34 @@ TEST_F(QuantizeWeightsTest, SharedWeights_Dequantize) {
   EXPECT_EQ(num_conv_ops, 2);
 }
 
+TEST_F(QuantizeWeightsTest, VerifyGatherQuantization) {
+  LoadGatherTestModel();
+  flatbuffers::FlatBufferBuilder builder;
+  auto status = QuantizeWeights(&builder, model_, 0);
+  EXPECT_EQ(status, kTfLiteOk);
+
+  const uint8_t* buffer = builder.GetBufferPointer();
+  const Model* output_model = GetModel(buffer);
+  ASSERT_TRUE(output_model);
+
+  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
+  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
+       ++subgraph_idx) {
+    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
+    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
+      const auto op = quantized_graph->operators()->Get(i);
+      const uint32_t op_code_idx = op->opcode_index();
+      const auto op_code =
+          output_model->operator_codes()->Get(op_code_idx)->builtin_code();
+      if (op_code == BuiltinOperator_GATHER) {
+        uint32_t input_tensor_index = op->inputs()->Get(0);
+        const auto weights_tensor =
+            quantized_graph->tensors()->Get(input_tensor_index);
+        EXPECT_EQ(weights_tensor->type(), TensorType_INT8);
+      }
+    }
+  }
+}
+
 } // namespace
 } // namespace optimize
 } // namespace tflite
@@ -32,8 +32,8 @@ const char* kSingleAvgPoolModelMinMinus5MaxPlus5 =
     "single_avg_pool_min_minus_5_max_plus_5.bin";
 
 const char* kModelWithSharedWeights = "weight_shared_between_convs.bin";
 
 const char* kMultiInputAddWithReshape = "multi_input_add_reshape.bin";
+const char* kQuantizedWithGather = "quantized_with_gather.bin";
 
 const char* kConstInputAddModel = "add_with_const_input.bin";
 
@@ -49,6 +49,9 @@ extern const char* kModelWithSharedWeights;
 // Test model with Add followed by a reshape. Model has 2 inputs for add.
 extern const char* kMultiInputAddWithReshape;
 
+// Test gather operation with quantized input.
+extern const char* kQuantizedWithGather;
+
 // Test model with a tf.constant input to tf.add. Model has 2 inputs one
 // constant and other placeholder.
 extern const char* kConstInputAddModel;
@@ -25,3 +25,6 @@ This directory contains test models for testing quantization.
   A floating point model with two convs that have a use the same weight tensor.
 * `multi_input_add_reshape.bin` \
   A floating point model with two inputs with an add followed by a reshape.
+* `quantized_with_gather.bin` \
+  A floating point model with an input with a gather, modeling a situation
+  of mapping categorical input to embeddings.
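For intuition, the float computation the new test model captures is a plain embedding lookup; a reference sketch in C++ (dimensions and names assumed, for illustration only):

#include <vector>

// Maps categorical ids to embedding rows, i.e. Gather over a float table
// of shape [num_rows, embedding_dim] stored row-major.
std::vector<float> EmbeddingLookup(const std::vector<float>& table,
                                   int embedding_dim,
                                   const std::vector<int>& ids) {
  std::vector<float> out;
  out.reserve(ids.size() * embedding_dim);
  for (int id : ids) {
    out.insert(out.end(), table.begin() + id * embedding_dim,
               table.begin() + (id + 1) * embedding_dim);
  }
  return out;
}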
BIN: tensorflow/lite/tools/optimize/testdata/quantized_with_gather.bin (new binary file, not shown)