diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD
index 772414161d8..0bbc797aebc 100644
--- a/tensorflow/lite/tools/optimize/BUILD
+++ b/tensorflow/lite/tools/optimize/BUILD
@@ -74,6 +74,7 @@ tf_cc_test(
         "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
     ],
     data = [
+        "//tensorflow/lite/tools/optimize:testdata/quantized_with_gather.bin",
        "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
        "//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin",
     ],
diff --git a/tensorflow/lite/tools/optimize/quantize_weights.cc b/tensorflow/lite/tools/optimize/quantize_weights.cc
index c8be07ec33c..20684babcfb 100644
--- a/tensorflow/lite/tools/optimize/quantize_weights.cc
+++ b/tensorflow/lite/tools/optimize/quantize_weights.cc
@@ -60,7 +60,7 @@ std::vector<ConsumerOpInfo> GetTensorConsumers(const ModelT* model,
     for (size_t i = 0; i < op->inputs.size(); ++i) {
       if (op->inputs[i] == tensor_idx) {
         consumer_ops.push_back(
-            {op, static_cast(op_idx), static_cast(i)});
+            {op, static_cast(op_idx), static_cast(i)});
       }
     }
   }
@@ -95,6 +95,9 @@ std::vector<int32_t> GetWeightInputIndices(const BuiltinOperator& op_code) {
   } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN) {
     // https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc
     return {1, 2, 4, 5};
+  } else if (op_code == BuiltinOperator_GATHER) {
+    // https://www.tensorflow.org/code/tensorflow/lite/kernels/gather.cc
+    return {0};
   }
   return {};
 }
@@ -194,7 +197,7 @@ TfLiteStatus InsertQuantizableInputTensorsFromOperator(
 // Returns the index of the Dequantize op_code.
 // If a Dequantize op_code doesn't exist, adds it and returns its index.
 int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) {
-  for (size_t i = 0; i < model->operator_codes.size(); ++i) {
+  for (int i = 0; i < model->operator_codes.size(); ++i) {
     if (model->operator_codes[i]->builtin_code == BuiltinOperator_DEQUANTIZE) {
       return i;
     }
@@ -232,7 +235,7 @@ void MakeTensor(const string& name, const std::vector<int32_t>& shape,

 // Updates operator code versions for the operators with INT8 inputs.
 void UpdateInt8OperatorVersions(ModelT* model) {
-  for (size_t i = 0; i < model->operator_codes.size(); ++i) {
+  for (int i = 0; i < model->operator_codes.size(); ++i) {
     const BuiltinOperator& op_code = model->operator_codes[i]->builtin_code;
     if (op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF ||
         op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
@@ -250,6 +253,57 @@ void UpdateInt8OperatorVersions(ModelT* model) {
   }
 }

+// Returns true if the op in consumer_op_infos can pass through quantization.
+bool IsQuantizationPassThroughOps(
+    const ModelT* model, const std::vector<ConsumerOpInfo>& consumer_op_infos) {
+  if (consumer_op_infos.size() != 1) {
+    return false;
+  }
+  const OperatorT* consumer_op = consumer_op_infos.front().op;
+  const BuiltinOperator op_code =
+      model->operator_codes[consumer_op->opcode_index]->builtin_code;
+  return op_code == BuiltinOperator_GATHER;
+}
+
+// Copies quantization parameters from input to output and returns consumers of
+// the output tensor as a tuple with values:
+// - index of the output tensor
+// - pointer to the output tensor
+// - vector of consumer ops.
+std::tuple<int32_t, TensorT*, std::vector<ConsumerOpInfo>>
+PassQuantizationAndGetConsumers(
+    const ModelT* model, const SubGraphT* subgraph,
+    const std::vector<ConsumerOpInfo>& consumer_op_infos) {
+  const OperatorT* op = consumer_op_infos.front().op;
+  const BuiltinOperator op_code =
+      model->operator_codes[op->opcode_index]->builtin_code;
+  if (op->outputs.size() != 1) {
+    LOG(ERROR)
+        << "An op that passes quantization has more than one quantized output";
+    return std::make_tuple(-1, nullptr, std::vector<ConsumerOpInfo>());
+  }
+  const int32_t output_tensor_idx = op->outputs.front();
+  const auto input_idx = GetWeightInputIndices(op_code);
+  if (input_idx.size() != 1) {
+    LOG(ERROR)
+        << "An op that passes quantization has more than one quantized input";
+    return std::make_tuple(-1, nullptr, std::vector<ConsumerOpInfo>());
+  }
+  const int32_t input_tensor_idx = op->inputs[input_idx.front()];
+
+  // Propagate quantization params.
+  const TensorT* input_tensor = subgraph->tensors[input_tensor_idx].get();
+  TensorT* output_tensor = subgraph->tensors[output_tensor_idx].get();
+  if (!output_tensor->quantization) {
+    output_tensor->quantization = absl::make_unique<QuantizationParametersT>();
+  }
+  *output_tensor->quantization = *input_tensor->quantization;
+  output_tensor->type = TensorType_INT8;
+  return std::make_tuple(
+      output_tensor_idx, output_tensor,
+      GetTensorConsumers(model, subgraph, output_tensor_idx));
+}
+
 TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
                                      const Model* input_model,
                                      bool use_hybrid_evaluation,
@@ -268,7 +322,7 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,

     std::vector<std::unique_ptr<OperatorT>> new_operators;
     std::unordered_map<int32_t, TensorT*> tensor_map;
-    for (size_t i = 0; i < subgraph->operators.size(); ++i) {
+    for (int i = 0; i < subgraph->operators.size(); ++i) {
       OperatorT* op = subgraph->operators[i].get();
       TF_LITE_ENSURE_STATUS(InsertQuantizableInputTensorsFromOperator(
           model.get(), op, weights_min_num_elements, &tensor_map));
@@ -285,10 +339,19 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,

     // Examine the tensor consumers to determine which require dequantize ops.
     for (const auto& tensor_pair : tensor_map) {
-      const int32_t tensor_idx = tensor_pair.first;
+      int32_t tensor_idx = tensor_pair.first;
       TensorT* tensor = tensor_pair.second;
       std::vector<ConsumerOpInfo> consumer_op_infos =
           GetTensorConsumers(model.get(), subgraph, tensor_idx);
+      if (IsQuantizationPassThroughOps(model.get(), consumer_op_infos)) {
+        std::tie(tensor_idx, tensor, consumer_op_infos) =
+            PassQuantizationAndGetConsumers(model.get(), subgraph,
+                                            consumer_op_infos);
+        if (tensor_idx < 0) {
+          // Error message is already logged by PassQuantizationAndGetConsumers.
+          return kTfLiteError;
+        }
+      }

       std::vector<ConsumerOpInfo> dequant_op_infos;  // Ops that need dequants.
       for (ConsumerOpInfo& consumer_op_info : consumer_op_infos) {
@@ -307,8 +370,18 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
         }
       }

-      // If no ops require dequant, we are done for this tensor.
-      if (dequant_op_infos.empty()) {
+      // Check that this tensor is an output tensor.
+      int32_t output_index = -1;
+      for (int32_t i = 0; i < subgraph->outputs.size(); ++i) {
+        if (subgraph->outputs[i] == tensor_idx) {
+          output_index = i;
+          break;
+        }
+      }
+
+      // If no ops require dequant and it is not output, we are done for this
+      // tensor.
+      if (dequant_op_infos.empty() && output_index < 0) {
         continue;
       }

@@ -328,12 +401,16 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,

       // Update the op_input of all the ops that need the created dequantize
       // operation.
-      int32_t min_op_idx = 0;
+      int32_t min_op_idx = subgraph->operators.size();
       for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) {
         dequant_op_info.op->inputs[dequant_op_info.op_input_idx] =
             dequantize_output_idx;
         min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx);
       }
+      // Update the subgraph output to point at the dequantized tensor.
+      if (output_index >= 0) {
+        subgraph->outputs[output_index] = dequantize_output_idx;
+      }

       // Insert the newly created Dequantize operation before the earliest
       // consumer, since TFLite requires operators to be topo-sorted.
diff --git a/tensorflow/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/lite/tools/optimize/quantize_weights_test.cc
index a18b3bb7ffe..d93817ebc5b 100644
--- a/tensorflow/lite/tools/optimize/quantize_weights_test.cc
+++ b/tensorflow/lite/tools/optimize/quantize_weights_test.cc
@@ -48,6 +48,12 @@ std::unique_ptr<FlatBufferModel> ReadSharedWeightsTestModel() {
   return FlatBufferModel::BuildFromFile(model_path.c_str());
 }

+std::unique_ptr<FlatBufferModel> ReadGatherTestModel() {
+  auto model_path = tensorflow::io::JoinPath(*g_test_model_dir,
+                                             internal::kQuantizedWithGather);
+  return FlatBufferModel::BuildFromFile(model_path.c_str());
+}
+
 template <typename T>
 std::vector<T> GetAsVector(const flatbuffers::Vector<T>* vec) {
   return std::vector<T>(vec->begin(), vec->end());
 }
@@ -67,6 +73,11 @@ class QuantizeWeightsTest : public testing::Test {
     model_ = input_model_->GetModel();
   }

+  void LoadGatherTestModel() {
+    input_model_ = ReadGatherTestModel();
+    model_ = input_model_->GetModel();
+  }
+
   std::unique_ptr<FlatBufferModel> input_model_;
   const Model* model_;

@@ -334,6 +345,34 @@ TEST_F(QuantizeWeightsTest, SharedWeights_Dequantize) {
   EXPECT_EQ(num_conv_ops, 2);
 }

+TEST_F(QuantizeWeightsTest, VerifyGatherQuantization) {
+  LoadGatherTestModel();
+  flatbuffers::FlatBufferBuilder builder;
+  auto status = QuantizeWeights(&builder, model_, 0);
+  EXPECT_EQ(status, kTfLiteOk);
+
+  const uint8_t* buffer = builder.GetBufferPointer();
+  const Model* output_model = GetModel(buffer);
+  ASSERT_TRUE(output_model);
+
+  ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
+  for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
+       ++subgraph_idx) {
+    const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
+    for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
+      const auto op = quantized_graph->operators()->Get(i);
+      const uint32_t op_code_idx = op->opcode_index();
+      const auto op_code =
+          output_model->operator_codes()->Get(op_code_idx)->builtin_code();
+      if (op_code == BuiltinOperator_GATHER) {
+        uint32_t input_tensor_index = op->inputs()->Get(0);
+        const auto weights_tensor =
+            quantized_graph->tensors()->Get(input_tensor_index);
+        EXPECT_EQ(weights_tensor->type(), TensorType_INT8);
+      }
+    }
+  }
+}
 }  // namespace
 }  // namespace optimize
 }  // namespace tflite
diff --git a/tensorflow/lite/tools/optimize/test_util.cc b/tensorflow/lite/tools/optimize/test_util.cc
index 255ba76a280..e520db2f6fc 100644
--- a/tensorflow/lite/tools/optimize/test_util.cc
+++ b/tensorflow/lite/tools/optimize/test_util.cc
@@ -32,8 +32,8 @@ const char* kSingleAvgPoolModelMinMinus5MaxPlus5 =
     "single_avg_pool_min_minus_5_max_plus_5.bin";

 const char* kModelWithSharedWeights = "weight_shared_between_convs.bin";
-
 const char* kMultiInputAddWithReshape = "multi_input_add_reshape.bin";
+const char* kQuantizedWithGather = "quantized_with_gather.bin";

 const char* kConstInputAddModel = "add_with_const_input.bin";

diff --git a/tensorflow/lite/tools/optimize/test_util.h b/tensorflow/lite/tools/optimize/test_util.h
index 6d63d07090c..32dbe8aca96 100644
--- a/tensorflow/lite/tools/optimize/test_util.h
+++ b/tensorflow/lite/tools/optimize/test_util.h
@@ -49,6 +49,9 @@ extern const char* kModelWithSharedWeights;
 // Test model with Add followed by a reshape. Model has 2 inputs for add.
 extern const char* kMultiInputAddWithReshape;

+// Test model with a gather operation on a quantizable input.
+extern const char* kQuantizedWithGather;
+
 // Test model with a tf.constant input to tf.add. Model has 2 inputs one
 // constant and other placeholder.
 extern const char* kConstInputAddModel;
diff --git a/tensorflow/lite/tools/optimize/testdata/README.md b/tensorflow/lite/tools/optimize/testdata/README.md
index 0a924816f99..de178e050f7 100644
--- a/tensorflow/lite/tools/optimize/testdata/README.md
+++ b/tensorflow/lite/tools/optimize/testdata/README.md
@@ -25,3 +25,6 @@ This directory contains test models for testing quantization.
     A floating point model with two convs that have a use the same weight tensor.
 *   `multi_input_add_reshape.bin` \
     A floating point model with two inputs with an add followed by a reshape.
+*   `quantized_with_gather.bin` \
+    A floating point model with a gather applied to an input, modeling the
+    mapping of a categorical input to embeddings.
diff --git a/tensorflow/lite/tools/optimize/testdata/quantized_with_gather.bin b/tensorflow/lite/tools/optimize/testdata/quantized_with_gather.bin
new file mode 100644
index 00000000000..3ecbbd82828
Binary files /dev/null and b/tensorflow/lite/tools/optimize/testdata/quantized_with_gather.bin differ
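
For reference, a minimal sketch of how the `QuantizeWeights` entry point exercised by this change can be driven from user code, mirroring the call pattern in the new `VerifyGatherQuantization` test (`QuantizeWeights(&builder, model, 0)`). The file paths and the surrounding `main()` are hypothetical placeholders, and passing `0` for the minimum weight element count is just the setting the test uses; see `quantize_weights.h` for the exact overloads.

```cpp
// Sketch only: invokes the weights quantizer the same way the new test does.
// The .tflite paths below are hypothetical.
#include <fstream>
#include <memory>

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/tools/optimize/quantize_weights.h"

int main() {
  // Load a float model, e.g. one whose weights feed a GATHER op.
  std::unique_ptr<tflite::FlatBufferModel> input_model =
      tflite::FlatBufferModel::BuildFromFile("/tmp/float_model.tflite");
  if (!input_model) return 1;
  const tflite::Model* model = input_model->GetModel();

  // Quantize the weights; 0 means no minimum element count, as in the test.
  flatbuffers::FlatBufferBuilder builder;
  if (tflite::optimize::QuantizeWeights(&builder, model, 0) != kTfLiteOk) {
    return 1;
  }

  // Write the quantized model out.
  std::ofstream out("/tmp/quantized_model.tflite", std::ios::binary);
  out.write(reinterpret_cast<const char*>(builder.GetBufferPointer()),
            builder.GetSize());
  return 0;
}
```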