Bugfix: TFLite If with dynamically allocated tensors.
Fixes a bug when branch subgraphs have dynamically allocated inputs/outputs. The newly added tests are failing without the fix and passing with the fix. PiperOrigin-RevId: 351181346 Change-Id: I5bdec6179c03157df9fd012018983384a6a2cb81
This commit is contained in:
parent
062e504d66
commit
24efa2683f
@ -156,6 +156,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i + 1, &input));
|
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i + 1, &input));
|
||||||
TfLiteTensor* subgraph_input =
|
TfLiteTensor* subgraph_input =
|
||||||
active_branch_subgraph.tensor(active_branch_subgraph.inputs()[i]);
|
active_branch_subgraph.tensor(active_branch_subgraph.inputs()[i]);
|
||||||
|
|
||||||
|
if (IsDynamicTensor(subgraph_input)) {
|
||||||
|
TfLiteTensorRealloc(input->bytes, subgraph_input);
|
||||||
|
}
|
||||||
|
|
||||||
TF_LITE_ENSURE_EQ(context, input->bytes, subgraph_input->bytes);
|
TF_LITE_ENSURE_EQ(context, input->bytes, subgraph_input->bytes);
|
||||||
memcpy(subgraph_input->data.raw, input->data.raw, input->bytes);
|
memcpy(subgraph_input->data.raw, input->data.raw, input->bytes);
|
||||||
}
|
}
|
||||||
@ -195,6 +200,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]);
|
active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]);
|
||||||
TfLiteTensor* output;
|
TfLiteTensor* output;
|
||||||
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
|
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
|
||||||
|
|
||||||
|
if (IsDynamicTensor(output)) {
|
||||||
|
TfLiteTensorRealloc(subgraph_output->bytes, output);
|
||||||
|
}
|
||||||
|
|
||||||
TF_LITE_ENSURE_EQ(context, output->bytes, subgraph_output->bytes);
|
TF_LITE_ENSURE_EQ(context, output->bytes, subgraph_output->bytes);
|
||||||
memcpy(output->data.raw, subgraph_output->data.raw, output->bytes);
|
memcpy(output->data.raw, subgraph_output->data.raw, output->bytes);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user