[tflite] Test for kTfLiteOptionalTensor in GetInput.

`GetInput`, `GetVariableInput` and `GetOutput` all fail to check for the case where `node->inputs->data[index]` is the special `kTfLiteOptionalTensor` value (-1) which then causes `context->tensors[node->inputs->data[index]]` to read from invalid memory location.

This fix makes `GetInput` and related return `nullptr` in those cases, asking the caller to check for `nullptr`. This is better than having `GetOptionalInputTensor` and `GetOptionalOutputTensor` (does not exist but could be added) as using the patched `GetInput` in error would be caught by a sanitizer test in the default optimized build (due to the `-fsanitize=null` option).

PiperOrigin-RevId: 332512190
Change-Id: Iabca54da2f2de02b6ece3c38b54f76d4277d689e
This commit is contained in:
Mihai Maruseac 2020-09-18 13:10:41 -07:00 committed by TensorFlower Gardener
parent 204945b19e
commit 46d5b08525
2 changed files with 84 additions and 12 deletions
tensorflow/lite/kernels

View File

@ -32,11 +32,17 @@ namespace {
inline TfLiteTensor* GetMutableInput(const TfLiteContext* context,
const TfLiteNode* node, int index) {
if (context->tensors != nullptr) {
return &context->tensors[node->inputs->data[index]];
} else {
return context->GetTensor(context, node->inputs->data[index]);
if (index >= 0 && index < node->inputs->size) {
const int tensor_index = node->inputs->data[index];
if (tensor_index != kTfLiteOptionalTensor) {
if (context->tensors != nullptr) {
return &context->tensors[tensor_index];
} else {
return context->GetTensor(context, tensor_index);
}
}
}
return nullptr;
}
} // anonymous namespace.
@ -54,11 +60,17 @@ TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
int index) {
if (context->tensors != nullptr) {
return &context->tensors[node->outputs->data[index]];
} else {
return context->GetTensor(context, node->outputs->data[index]);
if (index >= 0 && index < node->outputs->size) {
const int tensor_index = node->outputs->data[index];
if (tensor_index != kTfLiteOptionalTensor) {
if (context->tensors != nullptr) {
return &context->tensors[tensor_index];
} else {
return context->GetTensor(context, tensor_index);
}
}
}
return nullptr;
}
const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,

View File

@ -29,18 +29,46 @@ namespace tflite {
// benchmark_model for MobileNet + MobileBERT is unaffected. If such a change is
// made, move the newly non-inlined function declarations to the top of this
// header file.
// Note: You must check if result is not null:
//
// TfLiteTensor* my_tensor = GetInput(context, node, kMyTensorIdx);
// TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
const TfLiteTensor* GetInput(const TfLiteContext* context,
const TfLiteNode* node, int index);
// Note: You must check if result is not null:
// TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx);
// TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx);
// TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
int index);
// Note: You must check if result is not null:
//
// TfLiteTensor* my_tensor = GetOutput(context, node, kMyTensorIdx);
// TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
int index);
// Note: You must check if result is not null:
//
// TfLiteTensor* my_tensor = GetOptionalInputTensor(context, node, kIdx);
// TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
//
// Deprecated. GetInput has the same functionality.
const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,
const TfLiteNode* node, int index);
@ -50,14 +78,46 @@ inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
}
#ifndef TF_LITE_STATIC_MEMORY
// Note: You must check if result is not null:
//
// TfLiteTensor* my_tensor = GetTemporary(context, node, kMyTensorIdx);
// TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
inline TfLiteTensor* GetTemporary(TfLiteContext* context,
const TfLiteNode* node, int index) {
return &context->tensors[node->temporaries->data[index]];
if (index >= 0 && index < node->temporaries->size) {
const int tensor_index = node->temporaries->data[index];
if (tensor_index != kTfLiteOptionalTensor) {
if (context->tensors != nullptr) {
return &context->tensors[tensor_index];
}
}
}
return nullptr;
}
// Note: You must check if result is not null:
//
// TfLiteTensor* my_tensor = GetIntermediates(context, node, kMyTensorIdx);
// TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
inline const TfLiteTensor* GetIntermediates(TfLiteContext* context,
const TfLiteNode* node, int index) {
return &context->tensors[node->intermediates->data[index]];
if (index >= 0 && index < node->intermediates->size) {
const int tensor_index = node->intermediates->data[index];
if (tensor_index != kTfLiteOptionalTensor) {
if (context->tensors != nullptr) {
return &context->tensors[tensor_index];
}
}
}
return nullptr;
}
inline int NumIntermediates(const TfLiteNode* node) {
return node->intermediates->size;
}