Fix the output type in modify_model_inference.
PiperOrigin-RevId: 307155237 Change-Id: I5020db2ee599aea9afcae35114e2a9a6fafffe75
This commit is contained in:
parent
5fcc70306d
commit
b8651b5dd2
@ -295,7 +295,7 @@ TfLiteStatus ModifyModelInterface(flatbuffers::FlatBufferBuilder* builder,
|
|||||||
GetOutputTensors(model, &error_reporter);
|
GetOutputTensors(model, &error_reporter);
|
||||||
if (output_type == TensorType_UINT8) {
|
if (output_type == TensorType_UINT8) {
|
||||||
SetOutputTypeToUINT8(model, outputs);
|
SetOutputTypeToUINT8(model, outputs);
|
||||||
} else if (input_type == TensorType_INT8) {
|
} else if (output_type == TensorType_INT8) {
|
||||||
RemoveOutputTensor(model, outputs);
|
RemoveOutputTensor(model, outputs);
|
||||||
} else {
|
} else {
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
|
@ -351,6 +351,28 @@ TEST(ModelInterface, Int8SingleInputOutput) {
|
|||||||
EXPECT_EQ(model->subgraphs[0]->outputs[0], 1);
|
EXPECT_EQ(model->subgraphs[0]->outputs[0], 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ModelInterface, MixedTypeSingleInputOutput) {
|
||||||
|
auto model = CreateModelSingleInputOutput();
|
||||||
|
|
||||||
|
// Change model type.
|
||||||
|
flatbuffers::FlatBufferBuilder builder;
|
||||||
|
EXPECT_EQ(ModifyModelInterface(&builder, model.get(), TensorType_UINT8,
|
||||||
|
TensorType_INT8),
|
||||||
|
kTfLiteOk);
|
||||||
|
|
||||||
|
// Verify results.
|
||||||
|
EXPECT_EQ(model->operator_codes.size(), 3);
|
||||||
|
EXPECT_EQ(model->subgraphs.size(), 1);
|
||||||
|
EXPECT_EQ(model->subgraphs[0]->operators.size(), 2);
|
||||||
|
EXPECT_EQ(model->subgraphs[0]->tensors.size(), 3);
|
||||||
|
EXPECT_EQ(model->buffers.size(), 1);
|
||||||
|
|
||||||
|
EXPECT_EQ(model->subgraphs[0]->inputs.size(), 1);
|
||||||
|
EXPECT_EQ(model->subgraphs[0]->inputs[0], 2);
|
||||||
|
EXPECT_EQ(model->subgraphs[0]->outputs.size(), 1);
|
||||||
|
EXPECT_EQ(model->subgraphs[0]->outputs[0], 1);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(ModelInterface, Uint8MutipleInputOutput) {
|
TEST(ModelInterface, Uint8MutipleInputOutput) {
|
||||||
auto model = CreateModelMultipleInputOutput();
|
auto model = CreateModelMultipleInputOutput();
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user