Add a safeguard for tensor removal. If a tensor to be removed is at the beginning of the tensor list, keep it.
PiperOrigin-RevId: 314421491 Change-Id: I0992fc26301ac6803e6a11b2cdd6dad3de1fc573
This commit is contained in:
parent
eb9e0c1623
commit
408cb226f4
@ -209,7 +209,8 @@ TfLiteStatus SetOutputTypeToUINT8(ModelT* model,
|
||||
}
|
||||
|
||||
TfLiteStatus RemoveInputTensor(ModelT* model,
|
||||
const std::vector<TensorOpTensor>& inputs) {
|
||||
const std::vector<TensorOpTensor>& inputs,
|
||||
int32 original_number_tensors) {
|
||||
// Sanity check to make sure that erase start from the end.
|
||||
int last_op_index = std::numeric_limits<int32_t>::max();
|
||||
int last_tensor_index = std::numeric_limits<int32_t>::max();
|
||||
@ -224,7 +225,9 @@ TfLiteStatus RemoveInputTensor(ModelT* model,
|
||||
SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
|
||||
TFLITE_DCHECK(tot.input_index < subgraph->tensors.size());
|
||||
TFLITE_DCHECK(tot.op_index < subgraph->operators.size());
|
||||
subgraph->tensors.erase(subgraph->tensors.begin() + tot.input_index);
|
||||
if (tot.input_index >= original_number_tensors) {
|
||||
subgraph->tensors.erase(subgraph->tensors.begin() + tot.input_index);
|
||||
}
|
||||
subgraph->operators.erase(subgraph->operators.begin() + tot.op_index);
|
||||
subgraph->inputs[tot.model_index] = tot.output_index;
|
||||
}
|
||||
@ -232,7 +235,8 @@ TfLiteStatus RemoveInputTensor(ModelT* model,
|
||||
}
|
||||
|
||||
TfLiteStatus RemoveOutputTensor(ModelT* model,
|
||||
const std::vector<TensorOpTensor>& outputs) {
|
||||
const std::vector<TensorOpTensor>& outputs,
|
||||
int32 original_number_tensors) {
|
||||
// Sanity check to make sure that erase start from the end.
|
||||
int last_op_index = std::numeric_limits<int32_t>::max();
|
||||
int last_tensor_index = std::numeric_limits<int32_t>::max();
|
||||
@ -247,7 +251,9 @@ TfLiteStatus RemoveOutputTensor(ModelT* model,
|
||||
SubGraphT* subgraph = model->subgraphs.at(tot.subgraph_index).get();
|
||||
TFLITE_DCHECK(tot.output_index < subgraph->tensors.size());
|
||||
TFLITE_DCHECK(tot.op_index < subgraph->operators.size());
|
||||
subgraph->tensors.erase(subgraph->tensors.begin() + tot.output_index);
|
||||
if (tot.output_index >= original_number_tensors) {
|
||||
subgraph->tensors.erase(subgraph->tensors.begin() + tot.output_index);
|
||||
}
|
||||
subgraph->operators.erase(subgraph->operators.begin() + tot.op_index);
|
||||
subgraph->outputs[tot.model_index] = tot.input_index;
|
||||
}
|
||||
@ -282,21 +288,29 @@ std::unique_ptr<tflite::ModelT> CreateMutableModelFromFile(
|
||||
return copied_model;
|
||||
}
|
||||
|
||||
int GetOriginalNumberOfTensors(ModelT* model, ErrorReporter* error_reporter) {
|
||||
std::vector<TensorOpTensor> outputs = GetOutputTensors(model, error_reporter);
|
||||
std::vector<TensorOpTensor> inputs = GetInputTensors(model, error_reporter);
|
||||
return model->subgraphs[0]->tensors.size() - outputs.size() - inputs.size();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder,
|
||||
ModelT* model, const TensorType& input_type,
|
||||
const TensorType& output_type) {
|
||||
tflite::StderrReporter error_reporter;
|
||||
const int original_number_tensors =
|
||||
GetOriginalNumberOfTensors(model, &error_reporter);
|
||||
// Find float tensors that are model output and is consumed by a float to int8
|
||||
// quantize Op.
|
||||
// Do output first since the tensors are added into input first.,
|
||||
tflite::StderrReporter error_reporter;
|
||||
std::vector<TensorOpTensor> outputs =
|
||||
GetOutputTensors(model, &error_reporter);
|
||||
if (output_type == TensorType_UINT8) {
|
||||
SetOutputTypeToUINT8(model, outputs);
|
||||
} else if (output_type == TensorType_INT8) {
|
||||
RemoveOutputTensor(model, outputs);
|
||||
RemoveOutputTensor(model, outputs, original_number_tensors);
|
||||
} else {
|
||||
return kTfLiteError;
|
||||
}
|
||||
@ -307,7 +321,7 @@ TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder,
|
||||
if (input_type == TensorType_UINT8) {
|
||||
SetInputTypeToUINT8(model, inputs);
|
||||
} else if (input_type == TensorType_INT8) {
|
||||
RemoveInputTensor(model, inputs);
|
||||
RemoveInputTensor(model, inputs, original_number_tensors);
|
||||
} else {
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user