[tflite] Test for kTfLiteOptionalTensor
in GetInput
.
`GetInput`, `GetVariableInput` and `GetOutput` all fail to check for the case where `node->inputs->data[index]` is the special `kTfLiteOptionalTensor` value (-1) which then causes `context->tensors[node->inputs->data[index]]` to read from invalid memory location. This fix makes `GetInput` and related return `nullptr` in those cases, asking the caller to check for `nullptr`. This is better than having `GetOptionalInputTensor` and `GetOptionalOutputTensor` (does not exist but could be added) as using the patched `GetInput` in error would be caught by a sanitizer test in the default optimized build (due to the `-fsanitize=null` option). PiperOrigin-RevId: 332512190 Change-Id: Iabca54da2f2de02b6ece3c38b54f76d4277d689e
This commit is contained in:
parent
00c7ed7ce8
commit
42ed6ac868
@ -30,27 +30,49 @@ inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
|
||||
}
|
||||
inline const TfLiteTensor* GetInput(const TfLiteContext* context,
|
||||
const TfLiteNode* node, int index) {
|
||||
return &context->tensors[node->inputs->data[index]];
|
||||
const int tensor_index = node->inputs->data[index];
|
||||
if (tensor_index < 0) {
|
||||
return nullptr;
|
||||
}
|
||||
return &context->tensors[tensor_index];
|
||||
}
|
||||
// Note: You must check if result is not null:
|
||||
// TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx);
|
||||
// TF_LITE_ENSURE(context, my_tensor != nullptr);
|
||||
inline TfLiteTensor* GetVariableInput(TfLiteContext* context,
|
||||
const TfLiteNode* node, int index) {
|
||||
TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
|
||||
const int tensor_index = node->inputs->data[index];
|
||||
if (tensor_index < 0) {
|
||||
return nullptr;
|
||||
}
|
||||
TfLiteTensor* tensor = &context->tensors[tensor_index];
|
||||
>>>>>>> d8f8236c29 ([tflite] Test for `kTfLiteOptionalTensor` in `GetInput`.)
|
||||
return (tensor->is_variable) ? tensor : nullptr;
|
||||
}
|
||||
inline TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
|
||||
int index) {
|
||||
return &context->tensors[node->outputs->data[index]];
|
||||
const int tensor_index = node->outputs->data[index];
|
||||
if (tensor_index < 0) {
|
||||
return nullptr;
|
||||
}
|
||||
return &context->tensors[tensor_index];
|
||||
}
|
||||
inline TfLiteTensor* GetTemporary(TfLiteContext* context,
|
||||
const TfLiteNode* node, int index) {
|
||||
return &context->tensors[node->temporaries->data[index]];
|
||||
const int tensor_index = node->temporaries->data[index];
|
||||
if (tensor_index < 0) {
|
||||
return nullptr;
|
||||
}
|
||||
return &context->tensors[tensor_index];
|
||||
}
|
||||
|
||||
inline const TfLiteTensor* GetIntermediates(TfLiteContext* context,
|
||||
const TfLiteNode* node, int index) {
|
||||
return &context->tensors[node->intermediates->data[index]];
|
||||
const int tensor_index = node->intermediates->data[index];
|
||||
if (tensor_index < 0) {
|
||||
return nullptr;
|
||||
}
|
||||
return &context->tensors[tensor_index];
|
||||
}
|
||||
inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
|
||||
inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }
|
||||
|
Loading…
Reference in New Issue
Block a user