Add check for the content of traversal order.

PiperOrigin-RevId: 288337584
Change-Id: I28c3cb3ffda59ecefe5bd92766e6b8d74c773d69
This commit is contained in:
Yunlu Li 2020-01-06 11:16:22 -08:00 committed by TensorFlower Gardener
parent a4b4a8d251
commit d6e63d2151
2 changed files with 51 additions and 3 deletions

View File

@ -178,18 +178,42 @@ absl::optional<uint64_t> VerifyAndCountSparseElements(const Tensor& tensor) {
}
const int total_dims = sparsity->traversal_order()->size();
const int original_rank = tensor.shape()->size();
if (total_dims < tensor.shape()->size() ||
if (total_dims < original_rank ||
sparsity->dim_metadata()->size() != total_dims) {
return absl::nullopt;
}
const int block_rank = total_dims - tensor.shape()->size();
const int block_rank = total_dims - original_rank;
if (block_rank > 0 && (sparsity->block_map() == nullptr ||
sparsity->block_map()->size() != block_rank)) {
return absl::nullopt;
}
// For a n-dimensional tensor (d0, ..., dn-1) with k-dimensional block (dn,
// ..., dn+k-1), the first n elements in the traversal order should be a
// permutation of (d0, ..., dn-1), and the last k elements should be a
// permutation of (dn, ..., dn+k-1).
std::vector<int> traversal_order(total_dims);
for (int i = 0; i < total_dims; i++) {
traversal_order[i] = sparsity->traversal_order()->Get(i);
}
std::sort(traversal_order.begin(), traversal_order.begin() + original_rank);
for (int i = 0; i < original_rank; i++) {
if (traversal_order[i] != i) {
return absl::nullopt;
}
}
std::sort(traversal_order.begin() + original_rank, traversal_order.end());
for (int i = original_rank; i < total_dims; i++) {
if (traversal_order[i] != i) {
return absl::nullopt;
}
}
// For a n-dimensional tensor (d0, ..., dn-1) with k-dimensional block (dn,
// ..., dn+k-1), the expanded_dim_sizes holds the size of each dimension in
// the order of (d0, ..., dn-1, dn, ..., dn+k-1), not the traversal order.
@ -197,7 +221,6 @@ absl::optional<uint64_t> VerifyAndCountSparseElements(const Tensor& tensor) {
// 2}.
std::vector<int> expanded_dim_sizes;
expanded_dim_sizes.resize(total_dims);
const int original_rank = tensor.shape()->size();
// First go through the original tensor dimensions, populate their sizes.
for (int i = 0; i < original_rank; i++) {
expanded_dim_sizes[i] = tensor.shape()->Get(i);

View File

@ -654,6 +654,31 @@ TEST(VerifyModel, InvalidSparseTensorInvalidBuffer) {
"requires 12 bytes, but is allocated with 8 bytes buffer"));
}
TEST(VerifyModel, InvalidSparseTensorInvalidTraversalOrder) {
const auto model = FlatBufferModel::BuildFromFile(kSparseTensorTestModel);
ASSERT_TRUE(model);
std::unique_ptr<ModelT> scoped_model;
scoped_model.reset(model->GetModel()->UnPack());
auto* tensor = scoped_model->subgraphs[0]->tensors[0].get();
// Valid dimensions are (0, 1, 2, 3) in this test model.
tensor->sparsity->traversal_order[0] = 10;
flatbuffers::FlatBufferBuilder builder;
auto model_ = Model::Pack(builder, scoped_model.get());
::tflite::FinishModelBuffer(builder, model_);
MockErrorReporter mock_reporter;
MutableOpResolver resolver;
TfLiteRegistration fake_op;
resolver.AddCustom("FakeOp", &fake_op);
ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(), resolver,
&mock_reporter));
EXPECT_THAT(mock_reporter.GetAsString(),
::testing::ContainsRegex("invalid sparsity parameters"));
}
TEST(VerifyModel, ValidSparseTensorBCSC) {
const auto model = FlatBufferModel::BuildFromFile(kSparseTensorTestModel);
ASSERT_TRUE(model);