diff --git a/tensorflow/lite/tools/optimize/sparsity/format_converter.cc b/tensorflow/lite/tools/optimize/sparsity/format_converter.cc index 90d297ca42b..05cb8b32bf7 100644 --- a/tensorflow/lite/tools/optimize/sparsity/format_converter.cc +++ b/tensorflow/lite/tools/optimize/sparsity/format_converter.cc @@ -285,10 +285,12 @@ void FormatConverter::Populate(const T* src_data, std::vector indices, } const int metadata_idx = 2 * level; + const int shape_of_level = dim_metadata_[metadata_idx][0]; if (format_[level] == kTfLiteDimDense) { - for (int i = 0; i < dim_metadata_[metadata_idx][0]; i++) { + for (int i = 0; i < shape_of_level; i++) { indices[level] = i; - Populate(src_data, indices, level + 1, i, src_data_ptr); + Populate(src_data, indices, level + 1, prev_idx * shape_of_level + i, + src_data_ptr); } } else { const auto& array_segments = dim_metadata_[metadata_idx]; diff --git a/tensorflow/lite/tools/optimize/sparsity/format_converter_test.cc b/tensorflow/lite/tools/optimize/sparsity/format_converter_test.cc index 4531e7c3341..c3351810283 100644 --- a/tensorflow/lite/tools/optimize/sparsity/format_converter_test.cc +++ b/tensorflow/lite/tools/optimize/sparsity/format_converter_test.cc @@ -230,6 +230,66 @@ TEST(FormatConverterTest, SimpleTestS1S0) { EXPECT_EQ(data_back, dense_values); } +TEST(FormatConverterTest, 3DTestS0D1S2) { + const std::vector dense_values = {6, 0, 9, 8, 0, 0, 0, 0, 5, 0, 0, 7}; + const std::vector dense_shape = {3, 2, 2}; + const std::vector traversal_order = {0, 1, 2}; + const std::vector format = { + kTfLiteDimSparseCSR, kTfLiteDimDense, kTfLiteDimSparseCSR}; + FormatConverter converter(dense_shape, traversal_order, format); + converter.DenseToSparse(dense_values.data()); + + const auto& dim_metadata = converter.GetDimMetadata(); + const std::vector dm0_0 = {0, 2}; + const std::vector dm0_1 = {0, 2}; + const std::vector dm1 = {2}; + const std::vector dm2_0 = {0, 1, 3, 4, 5}; + const std::vector dm2_1 = {0, 0, 1, 0, 1}; + + EXPECT_EQ(dm0_0, dim_metadata[0]); + EXPECT_EQ(dm0_1, dim_metadata[1]); + EXPECT_EQ(dm1, dim_metadata[2]); + EXPECT_EQ(dm2_0, dim_metadata[4]); + EXPECT_EQ(dm2_1, dim_metadata[5]); + + const auto& data = converter.GetData(); + const std::vector expected_data = {6, 9, 8, 5, 7}; + EXPECT_EQ(expected_data, data); + + converter.SparseToDense(expected_data.data()); + const auto& data_back = converter.GetData(); + EXPECT_EQ(data_back, dense_values); +} + +TEST(FormatConverterTest, 3DTestD0D1S2) { + const std::vector dense_values = {6, 0, 9, 8, 0, 0, 0, 0, 5, 0, 0, 7}; + const std::vector dense_shape = {3, 2, 2}; + const std::vector traversal_order = {0, 1, 2}; + const std::vector format = { + kTfLiteDimDense, kTfLiteDimDense, kTfLiteDimSparseCSR}; + FormatConverter converter(dense_shape, traversal_order, format); + converter.DenseToSparse(dense_values.data()); + + const auto& dim_metadata = converter.GetDimMetadata(); + const std::vector dm0 = {3}; + const std::vector dm1 = {2}; + const std::vector dm2_0 = {0, 1, 3, 3, 3, 4, 5}; + const std::vector dm2_1 = {0, 0, 1, 0, 1}; + + EXPECT_EQ(dm0, dim_metadata[0]); + EXPECT_EQ(dm1, dim_metadata[2]); + EXPECT_EQ(dm2_0, dim_metadata[4]); + EXPECT_EQ(dm2_1, dim_metadata[5]); + + const auto& data = converter.GetData(); + const std::vector expected_data = {6, 9, 8, 5, 7}; + EXPECT_EQ(expected_data, data); + + converter.SparseToDense(expected_data.data()); + const auto& data_back = converter.GetData(); + EXPECT_EQ(data_back, dense_values); +} + TEST(FormatConverterTest, 3DTestS0S1S2) { const std::vector dense_values = {1, 7, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 4, 8, 3, 9};