Add a safeguard for tensor removal: if a tensor slated for removal sits at the beginning of the tensor list (i.e. it is one of the model's original tensors rather than one appended for the interface quantize ops), keep it.

PiperOrigin-RevId: 314421491
Change-Id: I0992fc26301ac6803e6a11b2cdd6dad3de1fc573
Author: Jian Li (committed by TensorFlower Gardener)
Date:   2020-06-02 15:59:04 -07:00
Parent: eb9e0c1623
Commit: 408cb226f4

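The guard added below skips subgraph->tensors.erase(...) whenever the tensor index falls below the original tensor count, so tensors that existed before the interface rewrite are never removed. A minimal standalone sketch of that pattern, using a plain std::vector and toy names (EraseAppendedOnly and indices_to_erase are illustrative; this is not TFLite code):

// Toy illustration of the safeguard; not TFLite code.
#include <iostream>
#include <string>
#include <vector>

// indices_to_erase must be sorted in descending order so that earlier
// erasures do not shift the positions of indices still to be erased.
void EraseAppendedOnly(std::vector<std::string>& tensors,
                       const std::vector<int>& indices_to_erase,
                       int original_number_tensors) {
  for (int index : indices_to_erase) {
    // The safeguard: tensors that were part of the original model
    // (index below original_number_tensors) are kept.
    if (index >= original_number_tensors) {
      tensors.erase(tensors.begin() + index);
    }
  }
}

int main() {
  std::vector<std::string> tensors = {"conv_in", "conv_out", "quant_in",
                                      "quant_out"};
  // The first three tensors belong to the original model.
  EraseAppendedOnly(tensors, /*indices_to_erase=*/{3, 1},
                    /*original_number_tensors=*/3);
  for (const std::string& t : tensors) std::cout << t << "\n";
  // Prints conv_in, conv_out, quant_in: index 3 was erased, index 1 kept.
  return 0;
}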

@@ -209,7 +209,8 @@ TfLiteStatus SetOutputTypeToUINT8(ModelT* model,
 }
 
 TfLiteStatus RemoveInputTensor(ModelT* model,
-                               const std::vector<TensorOpTensor>& inputs) {
+                               const std::vector<TensorOpTensor>& inputs,
+                               int32 original_number_tensors) {
   // Sanity check to make sure that erasing starts from the end.
   int last_op_index = std::numeric_limits<int32_t>::max();
   int last_tensor_index = std::numeric_limits<int32_t>::max();
@@ -224,7 +225,9 @@ TfLiteStatus RemoveInputTensor(ModelT* model,
     SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
     TFLITE_DCHECK(tot.input_index < subgraph->tensors.size());
     TFLITE_DCHECK(tot.op_index < subgraph->operators.size());
-    subgraph->tensors.erase(subgraph->tensors.begin() + tot.input_index);
+    if (tot.input_index >= original_number_tensors) {
+      subgraph->tensors.erase(subgraph->tensors.begin() + tot.input_index);
+    }
     subgraph->operators.erase(subgraph->operators.begin() + tot.op_index);
     subgraph->inputs[tot.model_index] = tot.output_index;
   }
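The "sanity check" in the context above only works because the TensorOpTensor entries are visited in descending index order; erasing a lower index first would shift every higher index and invalidate the remaining erasures. A small self-contained sketch of that invariant check, with assert standing in for TFLITE_DCHECK (names are illustrative):

#include <cassert>
#include <limits>
#include <vector>

// Mirrors the descending-order check: each index must be strictly
// smaller than the one visited before it.
void CheckStrictlyDescending(const std::vector<int>& tensor_indices) {
  int last_tensor_index = std::numeric_limits<int>::max();
  for (int index : tensor_indices) {
    assert(index < last_tensor_index);  // stand-in for TFLITE_DCHECK
    last_tensor_index = index;
  }
}

int main() {
  CheckStrictlyDescending({9, 7, 4});   // OK: strictly descending.
  // CheckStrictlyDescending({4, 7});   // Would trip the assert.
  return 0;
}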
@ -232,7 +235,8 @@ TfLiteStatus RemoveInputTensor(ModelT* model,
}
TfLiteStatus RemoveOutputTensor(ModelT* model,
const std::vector<TensorOpTensor>& outputs) {
const std::vector<TensorOpTensor>& outputs,
int32 original_number_tensors) {
// Sanity check to make sure that erase start from the end.
int last_op_index = std::numeric_limits<int32_t>::max();
int last_tensor_index = std::numeric_limits<int32_t>::max();
@@ -247,7 +251,9 @@ TfLiteStatus RemoveOutputTensor(ModelT* model,
     SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
     TFLITE_DCHECK(tot.output_index < subgraph->tensors.size());
     TFLITE_DCHECK(tot.op_index < subgraph->operators.size());
-    subgraph->tensors.erase(subgraph->tensors.begin() + tot.output_index);
+    if (tot.output_index >= original_number_tensors) {
+      subgraph->tensors.erase(subgraph->tensors.begin() + tot.output_index);
+    }
     subgraph->operators.erase(subgraph->operators.begin() + tot.op_index);
     subgraph->outputs[tot.model_index] = tot.input_index;
   }
@@ -282,21 +288,29 @@ std::unique_ptr<tflite::ModelT> CreateMutableModelFromFile(
   return copied_model;
 }
 
+int GetOriginalNumberOfTensors(ModelT* model, ErrorReporter* error_reporter) {
+  std::vector<TensorOpTensor> outputs = GetOutputTensors(model, error_reporter);
+  std::vector<TensorOpTensor> inputs = GetInputTensors(model, error_reporter);
+  return model->subgraphs[0]->tensors.size() - outputs.size() - inputs.size();
+}
+
 }  // namespace
 
 TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder,
                                   ModelT* model, const TensorType& input_type,
                                   const TensorType& output_type) {
+  tflite::StderrReporter error_reporter;
+  const int original_number_tensors =
+      GetOriginalNumberOfTensors(model, &error_reporter);
   // Find float tensors that are model outputs and are consumed by a
   // float-to-int8 quantize op.
   // Do outputs first since the tensors are added into inputs first.
-  tflite::StderrReporter error_reporter;
   std::vector<TensorOpTensor> outputs =
       GetOutputTensors(model, &error_reporter);
   if (output_type == TensorType_UINT8) {
     SetOutputTypeToUINT8(model, outputs);
   } else if (output_type == TensorType_INT8) {
-    RemoveOutputTensor(model, outputs);
+    RemoveOutputTensor(model, outputs, original_number_tensors);
   } else {
     return kTfLiteError;
   }
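GetOriginalNumberOfTensors, added in the hunk above, recovers the pre-rewrite tensor count by subtracting the tensors contributed by the interface quantize ops from the current total; everything at or past that count was appended and is fair game for removal. A worked example with made-up numbers:

#include <cassert>

int main() {
  const int total_tensors = 12;        // tensors currently in subgraph 0
  const int input_quant_tensors = 1;   // size of GetInputTensors() result
  const int output_quant_tensors = 1;  // size of GetOutputTensors() result
  const int original_number_tensors =
      total_tensors - output_quant_tensors - input_quant_tensors;
  assert(original_number_tensors == 10);
  // Indices 10 and 11 were appended for the interface quantize ops and
  // may be erased; indices 0..9 are original and are kept by the guard.
  return 0;
}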
@@ -307,7 +321,7 @@ TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder,
   if (input_type == TensorType_UINT8) {
     SetInputTypeToUINT8(model, inputs);
   } else if (input_type == TensorType_INT8) {
-    RemoveInputTensor(model, inputs);
+    RemoveInputTensor(model, inputs, original_number_tensors);
   } else {
     return kTfLiteError;
   }
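For completeness, a hedged usage sketch of the modified entry point, assuming the TFLite headers from tensorflow/lite/tools/optimize are available to the build and that ModifyModelInterface lives in tflite::optimize as in this file; the wrapper name MakeInterfaceInt8 is illustrative:

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/optimize/modify_model_interface.h"

// With TensorType_INT8 on both sides, the leading/trailing quantize ops
// are removed -- the code path guarded by original_number_tensors above.
TfLiteStatus MakeInterfaceInt8(tflite::ModelT* model) {
  flatbuffers::FlatBufferBuilder builder;
  return tflite::optimize::ModifyModelInterface(
      &builder, model, tflite::TensorType_INT8, tflite::TensorType_INT8);
}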