Making GetStartIndicesOfEachDenseRow produce leading zeros in case if several beginning empty rows are empty.
PiperOrigin-RevId: 307769649 Change-Id: Ia2a23b19865c834ba33f8b1d550628dadfa5e82a
This commit is contained in:
parent
a51b060d16
commit
70a06f6402
@ -79,11 +79,14 @@ std::vector<Tindices> GetStartIndicesOfEachDenseRow(
|
||||
std::vector<Tindices> segment_indices;
|
||||
const Tindices num_entries_in_sparse_tensor = indices_mat.dimension(0);
|
||||
const Tindices num_dense_rows_in_sparse_tensor =
|
||||
1 + indices_mat(num_entries_in_sparse_tensor - 1, 0) - indices_mat(0, 0);
|
||||
1 + indices_mat(num_entries_in_sparse_tensor - 1, 0);
|
||||
// Reserve an extra slot for the 0 we store in the first entry by convention.
|
||||
segment_indices.reserve(1 + num_dense_rows_in_sparse_tensor);
|
||||
segment_indices.push_back(0);
|
||||
*contains_empty_rows = false;
|
||||
for (Tindices i = 0; i < indices_mat(0, 0); ++i) {
|
||||
segment_indices.push_back(0);
|
||||
}
|
||||
*contains_empty_rows = indices_mat(0, 0) > 0;
|
||||
while (true) {
|
||||
const Tindices start_sparse_index_of_next_dense_row =
|
||||
FindNextDenseRowStartIndex<Tindices>(
|
||||
@ -127,9 +130,9 @@ std::vector<Tindices> ParseRowStartIndices(
|
||||
|
||||
template <typename Tindices>
|
||||
bool ContainsEmptyRows(const std::vector<Tindices>& row_start_indices) {
|
||||
// Skip checking the lengths of the first and last dense rows since those are
|
||||
// Skip checking the length of the last dense row since it is
|
||||
// always non-empty.
|
||||
for (size_t i = 2; i < row_start_indices.size() - 1; ++i) {
|
||||
for (size_t i = 1; i < row_start_indices.size() - 1; ++i) {
|
||||
if (row_start_indices.at(i) - row_start_indices.at(i - 1) == 0) {
|
||||
return true;
|
||||
}
|
||||
|
@ -44,7 +44,7 @@ Tindices FindNextDenseRowStartIndex(
|
||||
// v.front() = 0, v.back() = indices_mat.dimension(0), and for i > 0,
|
||||
// v[i] - v[i-1] is the length of the ith dense row in indices_mat.
|
||||
// *contains_empty_rows = true if and only if indices_mat contains empty rows
|
||||
// (rows without values) between its first and last row.
|
||||
// (rows without values) between row 0 and the last row.
|
||||
template <typename Tindices>
|
||||
std::vector<Tindices> GetStartIndicesOfEachDenseRow(
|
||||
const typename TTypes<Tindices>::ConstMatrix& indices_mat,
|
||||
|
@ -66,8 +66,8 @@ TEST(SparseUtilsTest, GetStartIndicesOfEachDenseRow) {
|
||||
bool contains_empty_rows;
|
||||
EXPECT_TRUE(GetStartIndicesOfEachDenseRow<int64>(indices_mat,
|
||||
&contains_empty_rows) ==
|
||||
std::vector<int64>({0, 1}));
|
||||
EXPECT_FALSE(contains_empty_rows);
|
||||
std::vector<int64>({0, 0, 0, 0, 1}));
|
||||
EXPECT_TRUE(contains_empty_rows);
|
||||
}
|
||||
{
|
||||
uint32 data[] = {3, 0, 3, 0};
|
||||
@ -75,8 +75,8 @@ TEST(SparseUtilsTest, GetStartIndicesOfEachDenseRow) {
|
||||
bool contains_empty_rows;
|
||||
EXPECT_TRUE(GetStartIndicesOfEachDenseRow<uint32>(indices_mat,
|
||||
&contains_empty_rows) ==
|
||||
std::vector<uint32>({0, 2}));
|
||||
EXPECT_FALSE(contains_empty_rows);
|
||||
std::vector<uint32>({0, 0, 0, 0, 2}));
|
||||
EXPECT_TRUE(contains_empty_rows);
|
||||
}
|
||||
{
|
||||
uint16 data[] = {0, 0, 0, 0, 0, 0, 1, 0};
|
||||
@ -165,7 +165,7 @@ TEST(SparseUtilsTest, ContainsEmptyRows) {
|
||||
const auto segment_indices =
|
||||
GetStartIndicesOfEachDenseRow<int32>(indices_mat, &contains_empty_rows);
|
||||
// indices_list = {1, 1, 2, 2, 2, 3};
|
||||
EXPECT_FALSE(ContainsEmptyRows(segment_indices));
|
||||
EXPECT_TRUE(ContainsEmptyRows(segment_indices));
|
||||
}
|
||||
{
|
||||
uint16 data[] = {1, 0, 1, 1, 2, 0, 2, 1, 2, 2, 3, 4};
|
||||
@ -174,7 +174,7 @@ TEST(SparseUtilsTest, ContainsEmptyRows) {
|
||||
const auto segment_indices = GetStartIndicesOfEachDenseRow<uint16>(
|
||||
indices_mat, &contains_empty_rows);
|
||||
// indices_list = {1, 1, 2, 2, 2, 3};
|
||||
EXPECT_FALSE(ContainsEmptyRows(segment_indices));
|
||||
EXPECT_TRUE(ContainsEmptyRows(segment_indices));
|
||||
}
|
||||
{
|
||||
int32 data[] = {0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 2, 2, 3, 4};
|
||||
|
Loading…
Reference in New Issue
Block a user