diff --git a/tensorflow/lite/tools/optimize/modify_model_interface.cc b/tensorflow/lite/tools/optimize/modify_model_interface.cc
index 0d2441a9c58..9451483b79d 100644
--- a/tensorflow/lite/tools/optimize/modify_model_interface.cc
+++ b/tensorflow/lite/tools/optimize/modify_model_interface.cc
@@ -209,7 +209,8 @@ TfLiteStatus SetOutputTypeToUINT8(ModelT* model,
 }
 
 TfLiteStatus RemoveInputTensor(ModelT* model,
-                               const std::vector<TensorOpTensor>& inputs) {
+                               const std::vector<TensorOpTensor>& inputs,
+                               int32 original_number_tensors) {
   // Sanity check to make sure that erase start from the end.
   int last_op_index = std::numeric_limits<int32_t>::max();
   int last_tensor_index = std::numeric_limits<int32_t>::max();
@@ -224,7 +225,9 @@ TfLiteStatus RemoveInputTensor(ModelT* model,
     SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
     TFLITE_DCHECK(tot.input_index < subgraph->tensors.size());
     TFLITE_DCHECK(tot.op_index < subgraph->operators.size());
-    subgraph->tensors.erase(subgraph->tensors.begin() + tot.input_index);
+    if (tot.input_index >= original_number_tensors) {
+      subgraph->tensors.erase(subgraph->tensors.begin() + tot.input_index);
+    }
     subgraph->operators.erase(subgraph->operators.begin() + tot.op_index);
     subgraph->inputs[tot.model_index] = tot.output_index;
   }
@@ -232,7 +235,8 @@ TfLiteStatus RemoveInputTensor(ModelT* model,
 }
 
 TfLiteStatus RemoveOutputTensor(ModelT* model,
-                                const std::vector<TensorOpTensor>& outputs) {
+                                const std::vector<TensorOpTensor>& outputs,
+                                int32 original_number_tensors) {
   // Sanity check to make sure that erase start from the end.
   int last_op_index = std::numeric_limits<int32_t>::max();
   int last_tensor_index = std::numeric_limits<int32_t>::max();
@@ -247,7 +251,9 @@ TfLiteStatus RemoveOutputTensor(ModelT* model,
     SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
     TFLITE_DCHECK(tot.output_index < subgraph->tensors.size());
     TFLITE_DCHECK(tot.op_index < subgraph->operators.size());
-    subgraph->tensors.erase(subgraph->tensors.begin() + tot.output_index);
+    if (tot.output_index >= original_number_tensors) {
+      subgraph->tensors.erase(subgraph->tensors.begin() + tot.output_index);
+    }
     subgraph->operators.erase(subgraph->operators.begin() + tot.op_index);
     subgraph->outputs[tot.model_index] = tot.input_index;
   }
@@ -282,21 +288,29 @@ std::unique_ptr<ModelT> CreateMutableModelFromFile(
   return copied_model;
 }
 
+int GetOriginalNumberOfTensors(ModelT* model, ErrorReporter* error_reporter) {
+  std::vector<TensorOpTensor> outputs = GetOutputTensors(model, error_reporter);
+  std::vector<TensorOpTensor> inputs = GetInputTensors(model, error_reporter);
+  return model->subgraphs[0]->tensors.size() - outputs.size() - inputs.size();
+}
+
 }  // namespace
 
 TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder,
                                   ModelT* model, const TensorType& input_type,
                                   const TensorType& output_type) {
+  tflite::StderrReporter error_reporter;
+  const int original_number_tensors =
+      GetOriginalNumberOfTensors(model, &error_reporter);
   // Find float tensors that are model output and is consumed by a float to int8
   // quantize Op.
   // Do output first since the tensors are added into input first.,
-  tflite::StderrReporter error_reporter;
   std::vector<TensorOpTensor> outputs =
       GetOutputTensors(model, &error_reporter);
   if (output_type == TensorType_UINT8) {
     SetOutputTypeToUINT8(model, outputs);
   } else if (output_type == TensorType_INT8) {
-    RemoveOutputTensor(model, outputs);
+    RemoveOutputTensor(model, outputs, original_number_tensors);
   } else {
     return kTfLiteError;
   }
@@ -307,7 +321,7 @@ TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder,
   if (input_type == TensorType_UINT8) {
     SetInputTypeToUINT8(model, inputs);
   } else if (input_type == TensorType_INT8) {
-    RemoveInputTensor(model, inputs);
+    RemoveInputTensor(model, inputs, original_number_tensors);
   } else {
     return kTfLiteError;
   }
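
For reviewers, a minimal standalone sketch of the idea behind the new guard: `GetOriginalNumberOfTensors` records how many tensors the subgraph had before quantization appended the interface tensors, and `RemoveInputTensor`/`RemoveOutputTensor` now only erase tensors whose index is at or past that boundary. The tensor names and five-tensor setup below are hypothetical, purely for illustration; only the `index >= original_number_tensors` check mirrors the patch.

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Hypothetical subgraph: 3 original tensors, plus 2 tensors (indices 3
  // and 4) appended when quantize ops were inserted at the interface.
  std::vector<std::string> tensors = {"input_f32", "weights", "output_f32",
                                      "input_i8", "output_i8"};
  const int32_t original_number_tensors = 3;

  // Visit erase candidates in descending index order (the "erase start
  // from the end" invariant the sanity check above enforces), so an erase
  // never shifts an index that is still pending.
  for (int32_t index : {4, 1}) {
    if (index >= original_number_tensors) {
      // Index 4 was added by quantization, so it is safe to erase.
      tensors.erase(tensors.begin() + index);
    }
    // Index 1 points at an original tensor and is skipped; before this
    // patch it would have been erased unconditionally, shifting every
    // later tensor index in the subgraph.
  }

  for (const std::string& name : tensors) std::cout << name << "\n";
  // Prints: input_f32, weights, output_f32, input_i8
}
```

Note that the count is taken once, at the top of `ModifyModelInterface`, before either removal pass mutates the tensor list; this is also why the `error_reporter` declaration moves above the comment block in the last two hunks.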