[tflite] Test for kTfLiteOptionalTensor
in GetInput
.
`GetInput`, `GetVariableInput` and `GetOutput` all fail to check for the case where `node->inputs->data[index]` is the special `kTfLiteOptionalTensor` value (-1) which then causes `context->tensors[node->inputs->data[index]]` to read from invalid memory location. This fix makes `GetInput` and related return `nullptr` in those cases, asking the caller to check for `nullptr`. This is better than having `GetOptionalInputTensor` and `GetOptionalOutputTensor` (does not exist but could be added) as using the patched `GetInput` in error would be caught by a sanitizer test in the default optimized build (due to the `-fsanitize=null` option). PiperOrigin-RevId: 332512190 Change-Id: Iabca54da2f2de02b6ece3c38b54f76d4277d689e
This commit is contained in:
parent
204945b19e
commit
46d5b08525
tensorflow/lite/kernels
@ -32,11 +32,17 @@ namespace {
|
||||
|
||||
inline TfLiteTensor* GetMutableInput(const TfLiteContext* context,
|
||||
const TfLiteNode* node, int index) {
|
||||
if (context->tensors != nullptr) {
|
||||
return &context->tensors[node->inputs->data[index]];
|
||||
} else {
|
||||
return context->GetTensor(context, node->inputs->data[index]);
|
||||
if (index >= 0 && index < node->inputs->size) {
|
||||
const int tensor_index = node->inputs->data[index];
|
||||
if (tensor_index != kTfLiteOptionalTensor) {
|
||||
if (context->tensors != nullptr) {
|
||||
return &context->tensors[tensor_index];
|
||||
} else {
|
||||
return context->GetTensor(context, tensor_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // anonymous namespace.
|
||||
@ -54,11 +60,17 @@ TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
|
||||
|
||||
TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
|
||||
int index) {
|
||||
if (context->tensors != nullptr) {
|
||||
return &context->tensors[node->outputs->data[index]];
|
||||
} else {
|
||||
return context->GetTensor(context, node->outputs->data[index]);
|
||||
if (index >= 0 && index < node->outputs->size) {
|
||||
const int tensor_index = node->outputs->data[index];
|
||||
if (tensor_index != kTfLiteOptionalTensor) {
|
||||
if (context->tensors != nullptr) {
|
||||
return &context->tensors[tensor_index];
|
||||
} else {
|
||||
return context->GetTensor(context, tensor_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,
|
||||
|
@ -29,18 +29,46 @@ namespace tflite {
|
||||
// benchmark_model for MobileNet + MobileBERT is unaffected. If such a change is
|
||||
// made, move the newly non-inlined function declarations to the top of this
|
||||
// header file.
|
||||
|
||||
// Note: You must check if result is not null:
|
||||
//
|
||||
// TfLiteTensor* my_tensor = GetInput(context, node, kMyTensorIdx);
|
||||
// TF_LITE_ENSURE(context, my_tensor != nullptr);
|
||||
//
|
||||
// This is because the index might point to the optional tensor constant
|
||||
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
|
||||
const TfLiteTensor* GetInput(const TfLiteContext* context,
|
||||
const TfLiteNode* node, int index);
|
||||
|
||||
// Note: You must check if result is not null:
|
||||
// TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx);
|
||||
// TF_LITE_ENSURE(context, my_tensor != nullptr);
|
||||
//
|
||||
// TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx);
|
||||
// TF_LITE_ENSURE(context, my_tensor != nullptr);
|
||||
//
|
||||
// This is because the index might point to the optional tensor constant
|
||||
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
|
||||
TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
|
||||
int index);
|
||||
|
||||
// Note: You must check if result is not null:
|
||||
//
|
||||
// TfLiteTensor* my_tensor = GetOutput(context, node, kMyTensorIdx);
|
||||
// TF_LITE_ENSURE(context, my_tensor != nullptr);
|
||||
//
|
||||
// This is because the index might point to the optional tensor constant
|
||||
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
|
||||
TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
|
||||
int index);
|
||||
|
||||
// Note: You must check if result is not null:
|
||||
//
|
||||
// TfLiteTensor* my_tensor = GetOptionalInputTensor(context, node, kIdx);
|
||||
// TF_LITE_ENSURE(context, my_tensor != nullptr);
|
||||
//
|
||||
// This is because the index might point to the optional tensor constant
|
||||
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
|
||||
//
|
||||
// Deprecated. GetInput has the same functionality.
|
||||
const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,
|
||||
const TfLiteNode* node, int index);
|
||||
|
||||
@ -50,14 +78,46 @@ inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
|
||||
}
|
||||
|
||||
#ifndef TF_LITE_STATIC_MEMORY
|
||||
// Note: You must check if result is not null:
|
||||
//
|
||||
// TfLiteTensor* my_tensor = GetTemporary(context, node, kMyTensorIdx);
|
||||
// TF_LITE_ENSURE(context, my_tensor != nullptr);
|
||||
//
|
||||
// This is because the index might point to the optional tensor constant
|
||||
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
|
||||
inline TfLiteTensor* GetTemporary(TfLiteContext* context,
|
||||
const TfLiteNode* node, int index) {
|
||||
return &context->tensors[node->temporaries->data[index]];
|
||||
if (index >= 0 && index < node->temporaries->size) {
|
||||
const int tensor_index = node->temporaries->data[index];
|
||||
if (tensor_index != kTfLiteOptionalTensor) {
|
||||
if (context->tensors != nullptr) {
|
||||
return &context->tensors[tensor_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Note: You must check if result is not null:
|
||||
//
|
||||
// TfLiteTensor* my_tensor = GetIntermediates(context, node, kMyTensorIdx);
|
||||
// TF_LITE_ENSURE(context, my_tensor != nullptr);
|
||||
//
|
||||
// This is because the index might point to the optional tensor constant
|
||||
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
|
||||
inline const TfLiteTensor* GetIntermediates(TfLiteContext* context,
|
||||
const TfLiteNode* node, int index) {
|
||||
return &context->tensors[node->intermediates->data[index]];
|
||||
if (index >= 0 && index < node->intermediates->size) {
|
||||
const int tensor_index = node->intermediates->data[index];
|
||||
if (tensor_index != kTfLiteOptionalTensor) {
|
||||
if (context->tensors != nullptr) {
|
||||
return &context->tensors[tensor_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
inline int NumIntermediates(const TfLiteNode* node) {
|
||||
return node->intermediates->size;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user