Support for post-training quantization for Gather

PiperOrigin-RevId: 239422886
This commit is contained in:
A. Unique TensorFlower 2019-03-20 10:22:41 -07:00 committed by TensorFlower Gardener
parent b116bb78b2
commit 9f835183f0
7 changed files with 132 additions and 9 deletions

View File

@ -74,6 +74,7 @@ tf_cc_test(
"--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
], ],
data = [ data = [
"//tensorflow/lite/tools/optimize:testdata/quantized_with_gather.bin",
"//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
"//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin", "//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin",
], ],

View File

@ -60,7 +60,7 @@ std::vector<ConsumerOpInfo> GetTensorConsumers(const ModelT* model,
for (size_t i = 0; i < op->inputs.size(); ++i) { for (size_t i = 0; i < op->inputs.size(); ++i) {
if (op->inputs[i] == tensor_idx) { if (op->inputs[i] == tensor_idx) {
consumer_ops.push_back( consumer_ops.push_back(
{op, static_cast<int>(op_idx), static_cast<int>(i)}); {op, static_cast<int32_t>(op_idx), static_cast<int32_t>(i)});
} }
} }
} }
@ -95,6 +95,9 @@ std::vector<int32_t> GetWeightInputIndices(const BuiltinOperator& op_code) {
} else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN) { } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN) {
// https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc // https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc
return {1, 2, 4, 5}; return {1, 2, 4, 5};
} else if (op_code == BuiltinOperator_GATHER) {
// https://www.tensorflow.org/code/tensorflow/lite/kernels/gather.cc
return {0};
} }
return {}; return {};
} }
@ -194,7 +197,7 @@ TfLiteStatus InsertQuantizableInputTensorsFromOperator(
// Returns the index of the Dequantize op_code. // Returns the index of the Dequantize op_code.
// If a Dequantize op_code doesn't exist, adds it and returns its index. // If a Dequantize op_code doesn't exist, adds it and returns its index.
int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) { int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) {
for (size_t i = 0; i < model->operator_codes.size(); ++i) { for (int i = 0; i < model->operator_codes.size(); ++i) {
if (model->operator_codes[i]->builtin_code == BuiltinOperator_DEQUANTIZE) { if (model->operator_codes[i]->builtin_code == BuiltinOperator_DEQUANTIZE) {
return i; return i;
} }
@ -232,7 +235,7 @@ void MakeTensor(const string& name, const std::vector<int32_t>& shape,
// Updates operator code versions for the operators with INT8 inputs. // Updates operator code versions for the operators with INT8 inputs.
void UpdateInt8OperatorVersions(ModelT* model) { void UpdateInt8OperatorVersions(ModelT* model) {
for (size_t i = 0; i < model->operator_codes.size(); ++i) { for (int i = 0; i < model->operator_codes.size(); ++i) {
const BuiltinOperator& op_code = model->operator_codes[i]->builtin_code; const BuiltinOperator& op_code = model->operator_codes[i]->builtin_code;
if (op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF || if (op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF ||
op_code == BuiltinOperator_EMBEDDING_LOOKUP || op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
@ -250,6 +253,57 @@ void UpdateInt8OperatorVersions(ModelT* model) {
} }
} }
// Returns true if the op in consumer_op_infos can pass through quantization.
bool IsQuantizationPassThroughOps(
const ModelT* model, const std::vector<ConsumerOpInfo>& consumer_op_infos) {
if (consumer_op_infos.size() != 1) {
return false;
}
const OperatorT* consumer_op = consumer_op_infos.front().op;
const BuiltinOperator op_code =
model->operator_codes[consumer_op->opcode_index]->builtin_code;
return op_code == BuiltinOperator_GATHER;
}
// Copies quantization parameters from input to output and returns consumers of
// the output tensor as a tuple with values:
// - index of the output tensor
// - pointer to the output tensor
// - vector of consumers ops.
std::tuple<int32_t, TensorT*, std::vector<ConsumerOpInfo>>
PassQuantizationAndGetConsumers(
const ModelT* model, const SubGraphT* subgraph,
const std::vector<ConsumerOpInfo>& consumer_op_infos) {
const OperatorT* op = consumer_op_infos.front().op;
const BuiltinOperator op_code =
model->operator_codes[op->opcode_index]->builtin_code;
if (op->outputs.size() != 1) {
LOG(ERROR)
<< "An op that passes quantization has more than one quantized output";
return std::make_tuple(-1, nullptr, std::vector<ConsumerOpInfo>());
}
const int32_t output_tensor_idx = op->outputs.front();
const auto input_idx = GetWeightInputIndices(op_code);
if (input_idx.size() != 1) {
LOG(ERROR)
<< "An op that passes quantization has more than one quantized input";
return std::make_tuple(-1, nullptr, std::vector<ConsumerOpInfo>());
}
const int32_t input_tensor_idx = op->inputs[input_idx.front()];
// Propagate quantization params.
const TensorT* input_tensor = subgraph->tensors[input_tensor_idx].get();
TensorT* output_tensor = subgraph->tensors[output_tensor_idx].get();
if (!output_tensor->quantization) {
output_tensor->quantization = absl::make_unique<QuantizationParametersT>();
}
*output_tensor->quantization = *input_tensor->quantization;
output_tensor->type = TensorType_INT8;
return std::make_tuple(
output_tensor_idx, output_tensor,
GetTensorConsumers(model, subgraph, output_tensor_idx));
}
TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder, TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model, const Model* input_model,
bool use_hybrid_evaluation, bool use_hybrid_evaluation,
@ -268,7 +322,7 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
std::vector<std::unique_ptr<OperatorT>> new_operators; std::vector<std::unique_ptr<OperatorT>> new_operators;
std::unordered_map<int32_t, TensorT*> tensor_map; std::unordered_map<int32_t, TensorT*> tensor_map;
for (size_t i = 0; i < subgraph->operators.size(); ++i) { for (int i = 0; i < subgraph->operators.size(); ++i) {
OperatorT* op = subgraph->operators[i].get(); OperatorT* op = subgraph->operators[i].get();
TF_LITE_ENSURE_STATUS(InsertQuantizableInputTensorsFromOperator( TF_LITE_ENSURE_STATUS(InsertQuantizableInputTensorsFromOperator(
model.get(), op, weights_min_num_elements, &tensor_map)); model.get(), op, weights_min_num_elements, &tensor_map));
@ -285,10 +339,19 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
// Examine the tensor consumers to determine which require dequantize ops. // Examine the tensor consumers to determine which require dequantize ops.
for (const auto& tensor_pair : tensor_map) { for (const auto& tensor_pair : tensor_map) {
const int32_t tensor_idx = tensor_pair.first; int32_t tensor_idx = tensor_pair.first;
TensorT* tensor = tensor_pair.second; TensorT* tensor = tensor_pair.second;
std::vector<ConsumerOpInfo> consumer_op_infos = std::vector<ConsumerOpInfo> consumer_op_infos =
GetTensorConsumers(model.get(), subgraph, tensor_idx); GetTensorConsumers(model.get(), subgraph, tensor_idx);
if (IsQuantizationPassThroughOps(model.get(), consumer_op_infos)) {
std::tie(tensor_idx, tensor, consumer_op_infos) =
PassQuantizationAndGetConsumers(model.get(), subgraph,
consumer_op_infos);
if (tensor_idx < 0) {
// Error message is already logged by PassQuantizationAndGetConsumers.
return kTfLiteError;
}
}
std::vector<ConsumerOpInfo> dequant_op_infos; // Ops that need dequants. std::vector<ConsumerOpInfo> dequant_op_infos; // Ops that need dequants.
for (ConsumerOpInfo& consumer_op_info : consumer_op_infos) { for (ConsumerOpInfo& consumer_op_info : consumer_op_infos) {
@ -307,8 +370,18 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
} }
} }
// If no ops require dequant, we are done for this tensor. // Check that this tensor is an output tensor.
if (dequant_op_infos.empty()) { int32_t output_index = -1;
for (int32_t i = 0; i < subgraph->outputs.size(); ++i) {
if (subgraph->outputs[i] == tensor_idx) {
output_index = i;
break;
}
}
// If no ops require dequant and it is not output, we are done for this
// tensor.
if (dequant_op_infos.empty() && output_index < 0) {
continue; continue;
} }
@ -328,12 +401,16 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
// Update the op_input of all the ops that need the created dequantize // Update the op_input of all the ops that need the created dequantize
// operation. // operation.
int32_t min_op_idx = 0; int32_t min_op_idx = subgraph->operators.size();
for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) { for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) {
dequant_op_info.op->inputs[dequant_op_info.op_input_idx] = dequant_op_info.op->inputs[dequant_op_info.op_input_idx] =
dequantize_output_idx; dequantize_output_idx;
min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx); min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx);
} }
// Update output name.
if (output_index >= 0) {
subgraph->outputs[output_index] = dequantize_output_idx;
}
// Insert the newly created Dequantize operation before the earliest // Insert the newly created Dequantize operation before the earliest
// consumer, since TFLite requires operators to be topo-sorted. // consumer, since TFLite requires operators to be topo-sorted.

View File

@ -48,6 +48,12 @@ std::unique_ptr<FlatBufferModel> ReadSharedWeightsTestModel() {
return FlatBufferModel::BuildFromFile(model_path.c_str()); return FlatBufferModel::BuildFromFile(model_path.c_str());
} }
std::unique_ptr<FlatBufferModel> ReadGatherTestModel() {
auto model_path = tensorflow::io::JoinPath(*g_test_model_dir,
internal::kQuantizedWithGather);
return FlatBufferModel::BuildFromFile(model_path.c_str());
}
template <typename T> template <typename T>
std::vector<T> GetAsVector(const flatbuffers::Vector<T>* vec) { std::vector<T> GetAsVector(const flatbuffers::Vector<T>* vec) {
return std::vector<T>(vec->begin(), vec->end()); return std::vector<T>(vec->begin(), vec->end());
@ -67,6 +73,11 @@ class QuantizeWeightsTest : public testing::Test {
model_ = input_model_->GetModel(); model_ = input_model_->GetModel();
} }
void LoadGatherTestModel() {
input_model_ = ReadGatherTestModel();
model_ = input_model_->GetModel();
}
std::unique_ptr<FlatBufferModel> input_model_; std::unique_ptr<FlatBufferModel> input_model_;
const Model* model_; const Model* model_;
@ -334,6 +345,34 @@ TEST_F(QuantizeWeightsTest, SharedWeights_Dequantize) {
EXPECT_EQ(num_conv_ops, 2); EXPECT_EQ(num_conv_ops, 2);
} }
TEST_F(QuantizeWeightsTest, VerifyGatherQuantization) {
LoadGatherTestModel();
flatbuffers::FlatBufferBuilder builder;
auto status = QuantizeWeights(&builder, model_, 0);
EXPECT_EQ(status, kTfLiteOk);
const uint8_t* buffer = builder.GetBufferPointer();
const Model* output_model = GetModel(buffer);
ASSERT_TRUE(output_model);
ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
++subgraph_idx) {
const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
const auto op = quantized_graph->operators()->Get(i);
const uint32_t op_code_idx = op->opcode_index();
const auto op_code =
output_model->operator_codes()->Get(op_code_idx)->builtin_code();
if (op_code == BuiltinOperator_GATHER) {
uint32_t input_tensor_index = op->inputs()->Get(0);
const auto weights_tensor =
quantized_graph->tensors()->Get(input_tensor_index);
EXPECT_EQ(weights_tensor->type(), TensorType_INT8);
}
}
}
}
} // namespace } // namespace
} // namespace optimize } // namespace optimize
} // namespace tflite } // namespace tflite

View File

@ -32,8 +32,8 @@ const char* kSingleAvgPoolModelMinMinus5MaxPlus5 =
"single_avg_pool_min_minus_5_max_plus_5.bin"; "single_avg_pool_min_minus_5_max_plus_5.bin";
const char* kModelWithSharedWeights = "weight_shared_between_convs.bin"; const char* kModelWithSharedWeights = "weight_shared_between_convs.bin";
const char* kMultiInputAddWithReshape = "multi_input_add_reshape.bin"; const char* kMultiInputAddWithReshape = "multi_input_add_reshape.bin";
const char* kQuantizedWithGather = "quantized_with_gather.bin";
const char* kConstInputAddModel = "add_with_const_input.bin"; const char* kConstInputAddModel = "add_with_const_input.bin";

View File

@ -49,6 +49,9 @@ extern const char* kModelWithSharedWeights;
// Test model with Add followed by a reshape. Model has 2 inputs for add. // Test model with Add followed by a reshape. Model has 2 inputs for add.
extern const char* kMultiInputAddWithReshape; extern const char* kMultiInputAddWithReshape;
// Test gather operation with quantized input.
extern const char* kQuantizedWithGather;
// Test model with a tf.constant input to tf.add. Model has 2 inputs one // Test model with a tf.constant input to tf.add. Model has 2 inputs one
// constant and other placeholder. // constant and other placeholder.
extern const char* kConstInputAddModel; extern const char* kConstInputAddModel;

View File

@ -25,3 +25,6 @@ This directory contains test models for testing quantization.
A floating point model with two convs that have a use the same weight tensor. A floating point model with two convs that have a use the same weight tensor.
* `multi_input_add_reshape.bin` \ * `multi_input_add_reshape.bin` \
A floating point model with two inputs with an add followed by a reshape. A floating point model with two inputs with an add followed by a reshape.
* `quantized_with_gather.bin` \
A floating point model with an input with a gather, modeling a situation
of mapping categorical input to embeddings.

Binary file not shown.