Fix GetNumberOfRuntimeInputsForNode crashing on optional input tensors.

Also use NumInputs/GetOptionalInputTensor from kernel_util.h instead of directly accessing TfLiteNode members.

PiperOrigin-RevId: 320971976
Change-Id: Ieb7073dbfe644ad1f87289738ae6ff0d24e5ffad
This commit is contained in:
Robert David 2020-07-13 09:36:06 -07:00 committed by TensorFlower Gardener
parent 8057a34b4b
commit f1ee6a406c
2 changed files with 6 additions and 4 deletions

View File

@ -127,8 +127,10 @@ absl::Status PopulateQuantParams(const TfLiteTensor& tensor,
int GetNumberOfRuntimeInputsForNode(const TfLiteContext* context,
const TfLiteNode* tflite_node) {
int number_of_runtime_inputs = 0;
for (int i = 0; i < tflite_node->inputs->size; i++) {
if (!IsConstantTensor(&context->tensors[tflite_node->inputs->data[i]])) {
for (int i = 0; i < NumInputs(tflite_node); i++) {
const TfLiteTensor* tensor =
GetOptionalInputTensor(context, tflite_node, i);
if (tensor != nullptr && !IsConstantTensor(tensor)) {
number_of_runtime_inputs++;
}
}
@ -137,7 +139,7 @@ int GetNumberOfRuntimeInputsForNode(const TfLiteContext* context,
int GetNumberOfConstInputsForNode(const TfLiteContext* context,
const TfLiteNode* tflite_node) {
return tflite_node->inputs->size -
return NumInputs(tflite_node) -
GetNumberOfRuntimeInputsForNode(context, tflite_node);
}

View File

@ -72,7 +72,7 @@ inline int64_t NumElements(const TfLiteTensor* t) {
return NumElements(t->dims);
}
inline const TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context,
inline const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,
const TfLiteNode* node,
int index) {
const bool use_tensor = index < node->inputs->size &&