diff --git a/tensorflow/lite/tools/verifier.cc b/tensorflow/lite/tools/verifier.cc index c16030be4e1..84275b34bb4 100644 --- a/tensorflow/lite/tools/verifier.cc +++ b/tensorflow/lite/tools/verifier.cc @@ -178,18 +178,42 @@ absl::optional VerifyAndCountSparseElements(const Tensor& tensor) { } const int total_dims = sparsity->traversal_order()->size(); + const int original_rank = tensor.shape()->size(); - if (total_dims < tensor.shape()->size() || + if (total_dims < original_rank || sparsity->dim_metadata()->size() != total_dims) { return absl::nullopt; } - const int block_rank = total_dims - tensor.shape()->size(); + const int block_rank = total_dims - original_rank; if (block_rank > 0 && (sparsity->block_map() == nullptr || sparsity->block_map()->size() != block_rank)) { return absl::nullopt; } + // For a n-dimensional tensor (d0, ..., dn-1) with k-dimensional block (dn, + // ..., dn+k-1), the first n elements in the traversal order should be a + // permutation of (d0, ..., dn-1), and the last k elements should be a + // permutation of (dn, ..., dn+k-1). + std::vector traversal_order(total_dims); + for (int i = 0; i < total_dims; i++) { + traversal_order[i] = sparsity->traversal_order()->Get(i); + } + + std::sort(traversal_order.begin(), traversal_order.begin() + original_rank); + for (int i = 0; i < original_rank; i++) { + if (traversal_order[i] != i) { + return absl::nullopt; + } + } + + std::sort(traversal_order.begin() + original_rank, traversal_order.end()); + for (int i = original_rank; i < total_dims; i++) { + if (traversal_order[i] != i) { + return absl::nullopt; + } + } + // For a n-dimensional tensor (d0, ..., dn-1) with k-dimensional block (dn, // ..., dn+k-1), the expanded_dim_sizes holds the size of each dimension in // the order of (d0, ..., dn-1, dn, ..., dn+k-1), not the traversal order. @@ -197,7 +221,6 @@ absl::optional VerifyAndCountSparseElements(const Tensor& tensor) { // 2}. std::vector expanded_dim_sizes; expanded_dim_sizes.resize(total_dims); - const int original_rank = tensor.shape()->size(); // First go through the original tensor dimensions, populate their sizes. for (int i = 0; i < original_rank; i++) { expanded_dim_sizes[i] = tensor.shape()->Get(i); diff --git a/tensorflow/lite/tools/verifier_test.cc b/tensorflow/lite/tools/verifier_test.cc index a945e980030..355ee6640c6 100644 --- a/tensorflow/lite/tools/verifier_test.cc +++ b/tensorflow/lite/tools/verifier_test.cc @@ -654,6 +654,31 @@ TEST(VerifyModel, InvalidSparseTensorInvalidBuffer) { "requires 12 bytes, but is allocated with 8 bytes buffer")); } +TEST(VerifyModel, InvalidSparseTensorInvalidTraversalOrder) { + const auto model = FlatBufferModel::BuildFromFile(kSparseTensorTestModel); + ASSERT_TRUE(model); + + std::unique_ptr scoped_model; + scoped_model.reset(model->GetModel()->UnPack()); + + auto* tensor = scoped_model->subgraphs[0]->tensors[0].get(); + // Valid dimensions are (0, 1, 2, 3) in this test model. + tensor->sparsity->traversal_order[0] = 10; + + flatbuffers::FlatBufferBuilder builder; + auto model_ = Model::Pack(builder, scoped_model.get()); + + ::tflite::FinishModelBuffer(builder, model_); + MockErrorReporter mock_reporter; + MutableOpResolver resolver; + TfLiteRegistration fake_op; + resolver.AddCustom("FakeOp", &fake_op); + ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(), resolver, + &mock_reporter)); + EXPECT_THAT(mock_reporter.GetAsString(), + ::testing::ContainsRegex("invalid sparsity parameters")); +} + TEST(VerifyModel, ValidSparseTensorBCSC) { const auto model = FlatBufferModel::BuildFromFile(kSparseTensorTestModel); ASSERT_TRUE(model);