[tflite] Ensure inputs and outputs don't overlap.

If a model uses the same tensor for both an input and an output then this can result in data loss and memory corruption. This should not happen.

PiperOrigin-RevId: 332522916
Change-Id: If0905b142415a9dfceaf2d181872f2a8fb88f48a
This commit is contained in:
Mihai Maruseac 2020-09-18 14:04:39 -07:00 committed by TensorFlower Gardener
parent 1970c2158b
commit d58c96946b
6 changed files with 66 additions and 0 deletions

View File

@ -466,6 +466,7 @@ cc_test(
data = [
"testdata/0_subgraphs.bin",
"testdata/2_subgraphs.bin",
"testdata/add_shared_tensors.bin",
"testdata/empty_model.bin",
"testdata/multi_add_flex.bin",
"testdata/sparse_tensor.bin",

View File

@ -581,6 +581,33 @@ TfLiteStatus Subgraph::CheckTensorIndices(const char* label, const int* indices,
return kTfLiteOk;
}
// We have two arrays and we need to check that elements from one array don't
// show up in the other. We could sort both arrays and then iterate with two
// pointers from start to finish always increasing the smaller one but since
// these arrays are usually short (<25 elements for inputs, usually <3 for
// outputs), this might be slower than the naive approach (if arrays have size n
// and m, with n >> m ~ O(1), first approach is O(nlogn) whereas the other is
// O(n)). Plus, sorting the input and output arrays might not be something we
// want as it destroys ordering of elements.
//
// If it turns out that this is an issue, we can switch to the other algorithm.
TfLiteStatus Subgraph::CheckInputAndOutputForOverlap(const int* input_indices,
int num_inputs,
const int* output_indices,
int num_outputs) {
for (int i = 0; i < num_inputs; i++) {
for (int j = 0; j < num_outputs; j++) {
if (input_indices[i] == output_indices[j]) {
ReportError("Tensor %d is both input %d and output %d\n",
input_indices[i], i, j);
consistent_ = false;
return kTfLiteError;
}
}
}
return kTfLiteOk;
}
namespace {
// Multiply two sizes and return true if overflow occurred;
// This is based off tensorflow/overflow.h but is simpler as we already
@ -707,6 +734,16 @@ TfLiteStatus Subgraph::AddNodeWithParameters(
&context_,
CheckTensorIndices("node outputs", outputs.data(), outputs.size()));
// For builtin ops, inputs and outputs must not overlap. Custom ops must do
// this check by themselves if they don't support overlapping tensors. This
// distinction is to allow custom ops to just forward a tensor, reusing it as
// both input and output.
if (builtin_data != nullptr) {
TF_LITE_ENSURE_OK(&context_, CheckInputAndOutputForOverlap(
inputs.data(), inputs.size(),
outputs.data(), outputs.size()));
}
int new_node_index = nodes_and_registration_.size();
if (node_index) *node_index = new_node_index;
nodes_and_registration_.resize(nodes_and_registration_.size() + 1);

View File

@ -451,6 +451,15 @@ class Subgraph {
TfLiteStatus CheckTensorIndices(const char* label, const int* indices,
int length);
// Check that the input indices and the output indices don't overlap.
// This is needed because same tensor must not be used both as input and
// output for an operator.
// NOTE: this changes consistent_ to be false if indices are out of bounds.
TfLiteStatus CheckInputAndOutputForOverlap(const int* input_indices,
int num_inputs,
const int* output_indices,
int num_outputs);
// Compute the number of bytes required to represent a tensor with dimensions
// specified by the array dims (of length dims_size). Returns the status code
// and bytes.

View File

@ -438,6 +438,25 @@ TEST(BasicFlatBufferModel, TestParseModelWithSparseTensor) {
}
// TODO(b/150072943): Add malformed model with sparse tensor tests.
TEST(BasicFlatBufferModel, TestHandleMalformedModel) {
const auto model_paths = {
// These models use the same tensor as both input and ouput of a node
"tensorflow/lite/testdata/add_shared_tensors.bin",
};
for (const auto& model_path : model_paths) {
std::unique_ptr<tflite::FlatBufferModel> model =
FlatBufferModel::BuildFromFile(model_path);
ASSERT_NE(model, nullptr);
tflite::ops::builtin::BuiltinOpResolver resolver;
InterpreterBuilder builder(*model, resolver);
std::unique_ptr<Interpreter> interpreter;
ASSERT_EQ(builder(&interpreter), kTfLiteOk);
ASSERT_NE(interpreter, nullptr);
ASSERT_NE(interpreter->AllocateTensors(), kTfLiteOk);
}
}
// TODO(aselle): Add tests for serialization of builtin op data types.
// These tests will occur with the evaluation tests of individual operators,

Binary file not shown.

Binary file not shown.