Use a more sophisticated technique to iterate over DeviceSet

Checking each bit individually is wasteful.

PiperOrigin-RevId: 247052784
This commit is contained in:
Sanjoy Das 2019-05-07 11:05:00 -07:00 committed by TensorFlower Gardener
parent 098ff16ec9
commit 6c6bfb97cc

View File

@ -71,17 +71,34 @@ class DeviceSet {
// iterator if this ends up being used widely.
for (int word_index = 0; word_index < storage_.size(); word_index++) {
uint64 word = storage_[word_index];
for (int bit_index = 0; bit_index < kWordSize; bit_index++) {
if (word & (1ull << bit_index)) {
while (word != 0) {
uint64 only_lowest_bit_set = word & -word;
// The number of trailing zeros in a non-zero word is the index of the
// least significant 1.
int bit_index = ctz_uint64(word);
if (!func(DeviceId(word_index * kWordSize + bit_index))) {
return;
}
}
word ^= only_lowest_bit_set;
}
}
}
private:
// Returns the number of trailing zero bits in `x`, which is the index of its
// least significant set bit. `x` must be non-zero (checked in debug builds).
static int ctz_uint64(uint64 x) {
  DCHECK_NE(x, 0);
#ifdef __GNUC__
  // Use the `unsigned long long` builtin. `unsigned long` is only 32 bits wide
  // on ILP32 and LLP64 targets, where __builtin_ctzl would silently drop the
  // upper half of the word and yield an undefined result whenever the lower
  // half is zero.
  return __builtin_ctzll(x);
#else
  // Portable fallback: shift right until the lowest bit is set, counting the
  // shifts. Terminates because `x` is non-zero.
  int result = 0;
  while ((x & 1u) == 0u) {
    x >>= 1;
    ++result;
  }
  return result;
#endif
}
// Packed bit words. The iteration code visits
// DeviceId(w * kWordSize + b) for every set bit `b` of storage_[w].
absl::InlinedVector<uint64, 1> storage_;
// Number of bits in each storage_ word (each word is a uint64).
const int kWordSize = 64;