Fix model_utils::GetOrInsertOpCodeIndex() method

This method is used for adding Dequantize and Quantize ops in the quantization
process. This method should set up the deprecated_builtin_code field after
https://github.com/tensorflow/community/pull/285.

PiperOrigin-RevId: 336744237
Change-Id: I6a18fb751666163ded5c8d8a2275204fcd59303c
This commit is contained in:
Jaesung Chung 2020-10-12 14:30:16 -07:00 committed by TensorFlower Gardener
parent a92123447f
commit 3c61e13624
4 changed files with 19 additions and 0 deletions

View File

@ -763,6 +763,7 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
if quant_opcode_idx == -1: if quant_opcode_idx == -1:
quant_op = schema_fb.OperatorCodeT() quant_op = schema_fb.OperatorCodeT()
quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE
quant_op.deprecatedBuiltinCode = schema_fb.BuiltinOperator.QUANTIZE
model.operatorCodes.append(quant_op) model.operatorCodes.append(quant_op)
quant_opcode_idx = len(model.operatorCodes) - 1 quant_opcode_idx = len(model.operatorCodes) - 1
# Change dequant op (int8 to float) to quant op (int8 to uint8) # Change dequant op (int8 to float) to quant op (int8 to uint8)

View File

@ -42,6 +42,8 @@ int32_t GetOrInsertOpCodeIndex(ModelT* model, const BuiltinOperator& op_code,
model->operator_codes.push_back(absl::make_unique<OperatorCodeT>()); model->operator_codes.push_back(absl::make_unique<OperatorCodeT>());
int op_code_idx = model->operator_codes.size() - 1; int op_code_idx = model->operator_codes.size() - 1;
model->operator_codes[op_code_idx]->builtin_code = op_code; model->operator_codes[op_code_idx]->builtin_code = op_code;
model->operator_codes[op_code_idx]->deprecated_builtin_code =
ConvertBuiltinCodeToDeprecatedBuiltinCode(op_code);
// Version 2 and onwards supports INT8 inputs. // Version 2 and onwards supports INT8 inputs.
model->operator_codes[op_code_idx]->version = version; model->operator_codes[op_code_idx]->version = version;

View File

@ -46,12 +46,18 @@ std::unique_ptr<ModelT> CreateQuantizedModelSingleInputOutput(
// Op code // Op code
quant_op_code->builtin_code = BuiltinOperator_QUANTIZE; quant_op_code->builtin_code = BuiltinOperator_QUANTIZE;
quant_op_code->deprecated_builtin_code =
static_cast<int8_t>(BuiltinOperator_QUANTIZE);
quant_op_code->version = 2; quant_op_code->version = 2;
fc_op_code->builtin_code = BuiltinOperator_FULLY_CONNECTED; fc_op_code->builtin_code = BuiltinOperator_FULLY_CONNECTED;
fc_op_code->deprecated_builtin_code =
static_cast<int8_t>(BuiltinOperator_FULLY_CONNECTED);
fc_op_code->version = 2; fc_op_code->version = 2;
dequant_op_code->builtin_code = BuiltinOperator_DEQUANTIZE; dequant_op_code->builtin_code = BuiltinOperator_DEQUANTIZE;
dequant_op_code->deprecated_builtin_code =
static_cast<int8_t>(BuiltinOperator_DEQUANTIZE);
dequant_op_code->version = 2; dequant_op_code->version = 2;
// Op. // Op.
@ -137,12 +143,18 @@ std::unique_ptr<ModelT> CreateQuantizedModelMultipleInputOutput(
// Op code // Op code
quant_op_code->builtin_code = BuiltinOperator_QUANTIZE; quant_op_code->builtin_code = BuiltinOperator_QUANTIZE;
quant_op_code->deprecated_builtin_code =
static_cast<int8_t>(BuiltinOperator_QUANTIZE);
quant_op_code->version = 2; quant_op_code->version = 2;
fc_op_code->builtin_code = BuiltinOperator_FULLY_CONNECTED; fc_op_code->builtin_code = BuiltinOperator_FULLY_CONNECTED;
fc_op_code->deprecated_builtin_code =
static_cast<int8_t>(BuiltinOperator_FULLY_CONNECTED);
fc_op_code->version = 2; fc_op_code->version = 2;
dequant_op_code->builtin_code = BuiltinOperator_DEQUANTIZE; dequant_op_code->builtin_code = BuiltinOperator_DEQUANTIZE;
dequant_op_code->deprecated_builtin_code =
static_cast<int8_t>(BuiltinOperator_DEQUANTIZE);
dequant_op_code->version = 2; dequant_op_code->version = 2;
// Op. // Op.
@ -258,6 +270,8 @@ std::unique_ptr<ModelT> CreateFloatModel() {
// Op code // Op code
fc_op_code->builtin_code = BuiltinOperator_FULLY_CONNECTED; fc_op_code->builtin_code = BuiltinOperator_FULLY_CONNECTED;
fc_op_code->deprecated_builtin_code =
static_cast<int8_t>(BuiltinOperator_FULLY_CONNECTED);
fc_op_code->version = 2; fc_op_code->version = 2;
// Op. // Op.

View File

@ -36,6 +36,8 @@ TEST(LstmPreprocess, Add2Tensors) {
auto lstm_op = absl::make_unique<OperatorT>(); auto lstm_op = absl::make_unique<OperatorT>();
lstm_op_code->builtin_code = BuiltinOperator_LSTM; lstm_op_code->builtin_code = BuiltinOperator_LSTM;
lstm_op_code->deprecated_builtin_code =
static_cast<int8_t>(BuiltinOperator_LSTM);
lstm_op_code->version = 2; lstm_op_code->version = 2;
lstm_op->opcode_index = 0; lstm_op->opcode_index = 0;
lstm_op->inputs = {0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1, lstm_op->inputs = {0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1,