diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc b/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc index 6ec910c8cee..453e33ec916 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc @@ -127,8 +127,10 @@ absl::Status PopulateQuantParams(const TfLiteTensor& tensor, int GetNumberOfRuntimeInputsForNode(const TfLiteContext* context, const TfLiteNode* tflite_node) { int number_of_runtime_inputs = 0; - for (int i = 0; i < tflite_node->inputs->size; i++) { - if (!IsConstantTensor(&context->tensors[tflite_node->inputs->data[i]])) { + for (int i = 0; i < NumInputs(tflite_node); i++) { + const TfLiteTensor* tensor = + GetOptionalInputTensor(context, tflite_node, i); + if (tensor != nullptr && !IsConstantTensor(tensor)) { number_of_runtime_inputs++; } } @@ -137,7 +139,7 @@ int GetNumberOfRuntimeInputsForNode(const TfLiteContext* context, int GetNumberOfConstInputsForNode(const TfLiteContext* context, const TfLiteNode* tflite_node) { - return tflite_node->inputs->size - + return NumInputs(tflite_node) - GetNumberOfRuntimeInputsForNode(context, tflite_node); } diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h index 98418399561..4660631dded 100644 --- a/tensorflow/lite/kernels/kernel_util.h +++ b/tensorflow/lite/kernels/kernel_util.h @@ -72,7 +72,7 @@ inline int64_t NumElements(const TfLiteTensor* t) { return NumElements(t->dims); } -inline const TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context, +inline const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context, const TfLiteNode* node, int index) { const bool use_tensor = index < node->inputs->size &&