Internal change
PiperOrigin-RevId: 289974538 Change-Id: Ie67bf5810f8c529916a302100cc94b4883252c1b
This commit is contained in:
parent
8ff1179b74
commit
fc7e43de1f
@ -108,83 +108,6 @@ SparseTensor::SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
|
|||||||
DCHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank.";
|
DCHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank.";
|
||||||
}
|
}
|
||||||
|
|
||||||
// Optimized version of `IndicesValid()` with the following requirements:
|
|
||||||
// * The sparse tensor is two-dimensional.
|
|
||||||
// * The tensor's indices are in the "standard" (lexicographic) order.
|
|
||||||
// * All of the tensor's indices fit within the range of a signed int32.
|
|
||||||
//
|
|
||||||
// Returns true if the indices are valid, otherwise false.
|
|
||||||
// NOTE(mrry): If this method returns false, call IndicesValidHelper<true>()
|
|
||||||
// to obtain a meaningful error message.
|
|
||||||
bool SparseTensor::IndicesValid32BitFastPath() const {
|
|
||||||
const auto ix_t = ix_.matrix<int64>();
|
|
||||||
const int64* const shape_ptr = shape_.data();
|
|
||||||
|
|
||||||
DCHECK_EQ(shape_.size(), 2);
|
|
||||||
DCHECK_EQ(order_[0], 0);
|
|
||||||
DCHECK_EQ(order_[1], 1);
|
|
||||||
DCHECK_LE(shape_ptr[0], std::numeric_limits<int32>::max());
|
|
||||||
DCHECK_LE(shape_ptr[1], std::numeric_limits<int32>::max());
|
|
||||||
|
|
||||||
const int32 max_rows = static_cast<int32>(shape_ptr[0]);
|
|
||||||
const int32 max_cols = static_cast<int32>(shape_ptr[1]);
|
|
||||||
|
|
||||||
// We maintain separate bools for each validation predicate to enable
|
|
||||||
// vectorization across loop iterations.
|
|
||||||
bool row_zeros_valid = true;
|
|
||||||
bool row_in_range_valid = true;
|
|
||||||
bool col_zeros_valid = true;
|
|
||||||
bool col_in_range_valid = true;
|
|
||||||
bool order_valid = true;
|
|
||||||
|
|
||||||
int64 prev_index = -1;
|
|
||||||
|
|
||||||
// Points to the beginning of the current row of the indices matrix.
|
|
||||||
// Each row has two int64 elements, but we use an int32 pointer to access
|
|
||||||
// the low and high 32 bits of each element separately. This means that our
|
|
||||||
// stride per row is 4 elements.
|
|
||||||
const int32* index_ptr = reinterpret_cast<const int32*>(ix_t.data());
|
|
||||||
const size_t kInt32ElementsPerRow = 4;
|
|
||||||
|
|
||||||
for (std::size_t n = 0; n < ix_t.dimension(0); ++n) {
|
|
||||||
index_ptr += kInt32ElementsPerRow;
|
|
||||||
|
|
||||||
// Unpack the values on the current row of the indices matrix.
|
|
||||||
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
|
||||||
const int32 row_zeros = index_ptr[0];
|
|
||||||
const int32 row_32 = index_ptr[1];
|
|
||||||
const int32 col_zeros = index_ptr[2];
|
|
||||||
const int32 col_32 = index_ptr[3];
|
|
||||||
#else
|
|
||||||
const int32 row_32 = index_ptr[0];
|
|
||||||
const int32 row_zeros = index_ptr[1];
|
|
||||||
const int32 col_32 = index_ptr[2];
|
|
||||||
const int32 col_zeros = index_ptr[3];
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Validate that the high 32 bits of the row and column indices are zero.
|
|
||||||
row_zeros_valid = row_zeros_valid & (row_zeros == 0);
|
|
||||||
col_zeros_valid = col_zeros_valid & (col_zeros == 0);
|
|
||||||
|
|
||||||
// Validate that the low 32 bits of the row and column indices are within
|
|
||||||
// range of the shape.
|
|
||||||
row_in_range_valid =
|
|
||||||
row_in_range_valid & (row_32 >= 0) & (row_32 < max_rows);
|
|
||||||
col_in_range_valid =
|
|
||||||
col_in_range_valid & (col_32 >= 0) & (col_32 < max_cols);
|
|
||||||
|
|
||||||
// Interpret the row and column as a concatenated 64-bit integer, and
|
|
||||||
// validate that the concatenated indices are in strictly increasing order.
|
|
||||||
const int64 concatenated_index =
|
|
||||||
(static_cast<int64>(row_32) << 32) + col_32;
|
|
||||||
order_valid = order_valid & (concatenated_index > prev_index);
|
|
||||||
prev_index = concatenated_index;
|
|
||||||
}
|
|
||||||
|
|
||||||
return row_zeros_valid & row_in_range_valid & col_zeros_valid &
|
|
||||||
col_in_range_valid & order_valid;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <bool standard_order>
|
template <bool standard_order>
|
||||||
Status SparseTensor::IndicesValidHelper() const {
|
Status SparseTensor::IndicesValidHelper() const {
|
||||||
const auto ix_t = ix_.matrix<int64>();
|
const auto ix_t = ix_.matrix<int64>();
|
||||||
@ -251,12 +174,6 @@ Status SparseTensor::IndicesValid() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (standard_order) {
|
if (standard_order) {
|
||||||
if (shape_.size() == 2 && shape_[0] <= std::numeric_limits<int32>::max() &&
|
|
||||||
shape_[1] <= std::numeric_limits<int32>::max()) {
|
|
||||||
if (IndicesValid32BitFastPath()) {
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return IndicesValidHelper<true>();
|
return IndicesValidHelper<true>();
|
||||||
} else {
|
} else {
|
||||||
return IndicesValidHelper<false>();
|
return IndicesValidHelper<false>();
|
||||||
|
@ -201,8 +201,6 @@ class SparseTensor {
|
|||||||
return vec;
|
return vec;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IndicesValid32BitFastPath() const;
|
|
||||||
|
|
||||||
template <bool standard_order>
|
template <bool standard_order>
|
||||||
Status IndicesValidHelper() const;
|
Status IndicesValidHelper() const;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user