Fix model_utils::GetOrInsertOpCodeIndex() method
This method is used for adding Dequantize and Quantize ops in the quantization process. This method should set up the deprecated_builtin_code field after https://github.com/tensorflow/community/pull/285. PiperOrigin-RevId: 336744237 Change-Id: I6a18fb751666163ded5c8d8a2275204fcd59303c
This commit is contained in:
parent
a92123447f
commit
3c61e13624
@ -763,6 +763,7 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
|
||||
if quant_opcode_idx == -1:
|
||||
quant_op = schema_fb.OperatorCodeT()
|
||||
quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE
|
||||
quant_op.deprecatedBuiltinCode = schema_fb.BuiltinOperator.QUANTIZE
|
||||
model.operatorCodes.append(quant_op)
|
||||
quant_opcode_idx = len(model.operatorCodes) - 1
|
||||
# Change dequant op (int8 to float) to quant op (int8 to uint8)
|
||||
|
@ -42,6 +42,8 @@ int32_t GetOrInsertOpCodeIndex(ModelT* model, const BuiltinOperator& op_code,
|
||||
model->operator_codes.push_back(absl::make_unique<OperatorCodeT>());
|
||||
int op_code_idx = model->operator_codes.size() - 1;
|
||||
model->operator_codes[op_code_idx]->builtin_code = op_code;
|
||||
model->operator_codes[op_code_idx]->deprecated_builtin_code =
|
||||
ConvertBuiltinCodeToDeprecatedBuiltinCode(op_code);
|
||||
// Version 2 and onwards supports INT8 inputs.
|
||||
model->operator_codes[op_code_idx]->version = version;
|
||||
|
||||
|
@ -46,12 +46,18 @@ std::unique_ptr<ModelT> CreateQuantizedModelSingleInputOutput(
|
||||
|
||||
// Op code
|
||||
quant_op_code->builtin_code = BuiltinOperator_QUANTIZE;
|
||||
quant_op_code->deprecated_builtin_code =
|
||||
static_cast<int8_t>(BuiltinOperator_QUANTIZE);
|
||||
quant_op_code->version = 2;
|
||||
|
||||
fc_op_code->builtin_code = BuiltinOperator_FULLY_CONNECTED;
|
||||
fc_op_code->deprecated_builtin_code =
|
||||
static_cast<int8_t>(BuiltinOperator_FULLY_CONNECTED);
|
||||
fc_op_code->version = 2;
|
||||
|
||||
dequant_op_code->builtin_code = BuiltinOperator_DEQUANTIZE;
|
||||
dequant_op_code->deprecated_builtin_code =
|
||||
static_cast<int8_t>(BuiltinOperator_DEQUANTIZE);
|
||||
dequant_op_code->version = 2;
|
||||
|
||||
// Op.
|
||||
@ -137,12 +143,18 @@ std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput(
|
||||
|
||||
// Op code
|
||||
quant_op_code->builtin_code = BuiltinOperator_QUANTIZE;
|
||||
quant_op_code->deprecated_builtin_code =
|
||||
static_cast<int8_t>(BuiltinOperator_QUANTIZE);
|
||||
quant_op_code->version = 2;
|
||||
|
||||
fc_op_code->builtin_code = BuiltinOperator_FULLY_CONNECTED;
|
||||
fc_op_code->deprecated_builtin_code =
|
||||
static_cast<int8_t>(BuiltinOperator_FULLY_CONNECTED);
|
||||
fc_op_code->version = 2;
|
||||
|
||||
dequant_op_code->builtin_code = BuiltinOperator_DEQUANTIZE;
|
||||
dequant_op_code->deprecated_builtin_code =
|
||||
static_cast<int8_t>(BuiltinOperator_DEQUANTIZE);
|
||||
dequant_op_code->version = 2;
|
||||
|
||||
// Op.
|
||||
@ -258,6 +270,8 @@ std::unique_ptr<ModelT> CreateFloatModel() {
|
||||
|
||||
// Op code
|
||||
fc_op_code->builtin_code = BuiltinOperator_FULLY_CONNECTED;
|
||||
fc_op_code->deprecated_builtin_code =
|
||||
static_cast<int8_t>(BuiltinOperator_FULLY_CONNECTED);
|
||||
fc_op_code->version = 2;
|
||||
|
||||
// Op.
|
||||
|
@ -36,6 +36,8 @@ TEST(LstmPreprocess, Add2Tensors) {
|
||||
auto lstm_op = absl::make_unique<OperatorT>();
|
||||
|
||||
lstm_op_code->builtin_code = BuiltinOperator_LSTM;
|
||||
lstm_op_code->deprecated_builtin_code =
|
||||
static_cast<int8_t>(BuiltinOperator_LSTM);
|
||||
lstm_op_code->version = 2;
|
||||
lstm_op->opcode_index = 0;
|
||||
lstm_op->inputs = {0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1,
|
||||
|
Loading…
Reference in New Issue
Block a user