Add check for the content of traversal order.
PiperOrigin-RevId: 288337584 Change-Id: I28c3cb3ffda59ecefe5bd92766e6b8d74c773d69
This commit is contained in:
parent
a4b4a8d251
commit
d6e63d2151
|
@ -178,18 +178,42 @@ absl::optional<uint64_t> VerifyAndCountSparseElements(const Tensor& tensor) {
|
|||
}
|
||||
|
||||
const int total_dims = sparsity->traversal_order()->size();
|
||||
const int original_rank = tensor.shape()->size();
|
||||
|
||||
if (total_dims < tensor.shape()->size() ||
|
||||
if (total_dims < original_rank ||
|
||||
sparsity->dim_metadata()->size() != total_dims) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
const int block_rank = total_dims - tensor.shape()->size();
|
||||
const int block_rank = total_dims - original_rank;
|
||||
if (block_rank > 0 && (sparsity->block_map() == nullptr ||
|
||||
sparsity->block_map()->size() != block_rank)) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
// For a n-dimensional tensor (d0, ..., dn-1) with k-dimensional block (dn,
|
||||
// ..., dn+k-1), the first n elements in the traversal order should be a
|
||||
// permutation of (d0, ..., dn-1), and the last k elements should be a
|
||||
// permutation of (dn, ..., dn+k-1).
|
||||
std::vector<int> traversal_order(total_dims);
|
||||
for (int i = 0; i < total_dims; i++) {
|
||||
traversal_order[i] = sparsity->traversal_order()->Get(i);
|
||||
}
|
||||
|
||||
std::sort(traversal_order.begin(), traversal_order.begin() + original_rank);
|
||||
for (int i = 0; i < original_rank; i++) {
|
||||
if (traversal_order[i] != i) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(traversal_order.begin() + original_rank, traversal_order.end());
|
||||
for (int i = original_rank; i < total_dims; i++) {
|
||||
if (traversal_order[i] != i) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
// For a n-dimensional tensor (d0, ..., dn-1) with k-dimensional block (dn,
|
||||
// ..., dn+k-1), the expanded_dim_sizes holds the size of each dimension in
|
||||
// the order of (d0, ..., dn-1, dn, ..., dn+k-1), not the traversal order.
|
||||
|
@ -197,7 +221,6 @@ absl::optional<uint64_t> VerifyAndCountSparseElements(const Tensor& tensor) {
|
|||
// 2}.
|
||||
std::vector<int> expanded_dim_sizes;
|
||||
expanded_dim_sizes.resize(total_dims);
|
||||
const int original_rank = tensor.shape()->size();
|
||||
// First go through the original tensor dimensions, populate their sizes.
|
||||
for (int i = 0; i < original_rank; i++) {
|
||||
expanded_dim_sizes[i] = tensor.shape()->Get(i);
|
||||
|
|
|
@ -654,6 +654,31 @@ TEST(VerifyModel, InvalidSparseTensorInvalidBuffer) {
|
|||
"requires 12 bytes, but is allocated with 8 bytes buffer"));
|
||||
}
|
||||
|
||||
TEST(VerifyModel, InvalidSparseTensorInvalidTraversalOrder) {
|
||||
const auto model = FlatBufferModel::BuildFromFile(kSparseTensorTestModel);
|
||||
ASSERT_TRUE(model);
|
||||
|
||||
std::unique_ptr<ModelT> scoped_model;
|
||||
scoped_model.reset(model->GetModel()->UnPack());
|
||||
|
||||
auto* tensor = scoped_model->subgraphs[0]->tensors[0].get();
|
||||
// Valid dimensions are (0, 1, 2, 3) in this test model.
|
||||
tensor->sparsity->traversal_order[0] = 10;
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder;
|
||||
auto model_ = Model::Pack(builder, scoped_model.get());
|
||||
|
||||
::tflite::FinishModelBuffer(builder, model_);
|
||||
MockErrorReporter mock_reporter;
|
||||
MutableOpResolver resolver;
|
||||
TfLiteRegistration fake_op;
|
||||
resolver.AddCustom("FakeOp", &fake_op);
|
||||
ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(), resolver,
|
||||
&mock_reporter));
|
||||
EXPECT_THAT(mock_reporter.GetAsString(),
|
||||
::testing::ContainsRegex("invalid sparsity parameters"));
|
||||
}
|
||||
|
||||
TEST(VerifyModel, ValidSparseTensorBCSC) {
|
||||
const auto model = FlatBufferModel::BuildFromFile(kSparseTensorTestModel);
|
||||
ASSERT_TRUE(model);
|
||||
|
|
Loading…
Reference in New Issue