Added new tensorflow::gtl::FlatMap and tensorflow::gtl::FlatSet classes.

Mostly drop-in replacements for std::unordered_map and std::unordered_set,
but much faster (does not do an allocation per entry, and represents
entries in groups of 8 in a flat array, which is much more cache efficient).

Benchmarks not included in this cl show about 3X to 5X performance
improvements over the std::unordered_{set,map} for many kinds of
common maps e.g. std::unordered_mapmap<int64, int64> or
std::unordered_map<string, int64>.
Change: 137401863
This commit is contained in:
A. Unique TensorFlower 2016-10-27 08:10:54 -08:00 committed by TensorFlower Gardener
parent e43eaf662d
commit 80aec93166
7 changed files with 2057 additions and 0 deletions

View File

@ -164,6 +164,8 @@ cc_library(
"lib/core/threadpool.h",
"lib/gtl/array_slice.h",
"lib/gtl/cleanup.h",
"lib/gtl/flatmap.h",
"lib/gtl/flatset.h",
"lib/gtl/inlined_vector.h",
"lib/gtl/priority_queue_util.h",
"lib/hash/crc32c.h",
@ -1447,6 +1449,8 @@ tf_cc_tests(
"lib/gtl/array_slice_test.cc",
"lib/gtl/cleanup_test.cc",
"lib/gtl/edit_distance_test.cc",
"lib/gtl/flatmap_test.cc",
"lib/gtl/flatset_test.cc",
"lib/gtl/inlined_vector_test.cc",
"lib/gtl/int_type_test.cc",
"lib/gtl/iterator_range_test.cc",

View File

@ -0,0 +1,349 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
#include <stddef.h>
#include <utility>
#include "tensorflow/core/lib/gtl/flatrep.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace gtl {
// FlatMap<K,V,...> provides a map from K to V.
//
// The map is implemented using an open-addressed hash table. A
// single array holds entire map contents and collisions are resolved
// by probing at a sequence of locations in the array.
template <typename Key, typename Val, class Hash, class Eq = std::equal_to<Key>>
class FlatMap {
private:
// Forward declare some internal types needed in public section.
struct Bucket;
public:
typedef Key key_type;
typedef Val mapped_type;
typedef Hash hasher;
typedef Eq key_equal;
typedef size_t size_type;
typedef ptrdiff_t difference_type;
// We cannot use std::pair<> since internal representation stores
// keys and values in separate arrays, so we make a custom struct
// that holds references to the internal key, value elements.
struct value_type {
typedef Key first_type;
typedef Val second_type;
const Key& first;
Val& second;
value_type(const Key& k, Val& v) : first(k), second(v) {}
};
typedef value_type* pointer;
typedef const value_type* const_pointer;
typedef value_type& reference;
typedef const value_type& const_reference;
FlatMap() : FlatMap(1) {}
explicit FlatMap(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq())
: rep_(N, hf, eq) {}
FlatMap(const FlatMap& src) : rep_(src.rep_) {}
template <typename InputIter>
FlatMap(InputIter first, InputIter last, size_t N = 1,
const Hash& hf = Hash(), const Eq& eq = Eq())
: FlatMap(N, hf, eq) {
insert(first, last);
}
FlatMap& operator=(const FlatMap& src) {
rep_.CopyFrom(src.rep_);
return *this;
}
~FlatMap() {}
void swap(FlatMap& x) { rep_.swap(x.rep_); }
void clear_no_resize() { rep_.clear_no_resize(); }
void clear() { rep_.clear(); }
void reserve(size_t N) { rep_.Resize(std::max(N, size())); }
void rehash(size_t N) { rep_.Resize(std::max(N, size())); }
void resize(size_t N) { rep_.Resize(std::max(N, size())); }
size_t size() const { return rep_.size(); }
bool empty() const { return size() == 0; }
size_t bucket_count() const { return rep_.bucket_count(); }
hasher hash_function() const { return rep_.hash_function(); }
key_equal key_eq() const { return rep_.key_eq(); }
class iterator {
public:
iterator() : b_(nullptr), end_(nullptr), i_(0) {}
// Make iterator pointing at first element at or after b.
explicit iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) {
SkipUnused();
}
// Make iterator pointing exactly at ith element in b, which must exist.
iterator(Bucket* b, Bucket* end, uint32 i) : b_(b), end_(end), i_(i) {
FillValue();
}
value_type& operator*() { return *val(); }
value_type* operator->() { return val(); }
bool operator==(const iterator& x) const {
return b_ == x.b_ && i_ == x.i_;
}
bool operator!=(const iterator& x) const { return !(*this == x); }
iterator& operator++() {
DCHECK(b_ != end_);
i_++;
SkipUnused();
return *this;
}
private:
friend class FlatMap;
Bucket* b_;
Bucket* end_;
uint32 i_;
char space_[sizeof(value_type)];
value_type* val() { return reinterpret_cast<value_type*>(space_); }
void FillValue() { new (space_) value_type(b_->key(i_), b_->val(i_)); }
void SkipUnused() {
while (b_ < end_) {
if (i_ >= Rep::kWidth) {
i_ = 0;
b_++;
} else if (b_->marker[i_] < 2) {
i_++;
} else {
FillValue();
break;
}
}
}
};
class const_iterator {
private:
mutable iterator rep_; // Share state and logic with non-const iterator.
public:
const_iterator() : rep_() {}
explicit const_iterator(Bucket* start, Bucket* end) : rep_(start, end) {}
const_iterator(Bucket* b, Bucket* end, uint32 i) : rep_(b, end, i) {}
const value_type& operator*() const { return *rep_.val(); }
const value_type* operator->() const { return rep_.val(); }
bool operator==(const const_iterator& x) const { return rep_ == x.rep_; }
bool operator!=(const const_iterator& x) const { return rep_ != x.rep_; }
const_iterator& operator++() {
++rep_;
return *this;
}
};
iterator begin() { return iterator(rep_.start(), rep_.limit()); }
iterator end() { return iterator(rep_.limit(), rep_.limit()); }
const_iterator begin() const {
return const_iterator(rep_.start(), rep_.limit());
}
const_iterator end() const {
return const_iterator(rep_.limit(), rep_.limit());
}
size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; }
iterator find(const Key& k) {
auto r = rep_.Find(k);
return r.found ? iterator(r.b, rep_.limit(), r.index) : end();
}
const_iterator find(const Key& k) const {
auto r = rep_.Find(k);
return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end();
}
Val& at(const Key& k) {
auto r = rep_.Find(k);
DCHECK(r.found);
return r.b->val(r.index);
}
const Val& at(const Key& k) const {
auto r = rep_.Find(k);
DCHECK(r.found);
return r.b->val(r.index);
}
template <typename P>
std::pair<iterator, bool> insert(const P& p) {
return Insert(p.first, p.second);
}
std::pair<iterator, bool> insert(const std::pair<const Key, Val>& p) {
return Insert(p.first, p.second);
}
template <typename InputIter>
void insert(InputIter first, InputIter last) {
for (; first != last; ++first) {
insert(*first);
}
}
Val& operator[](const Key& k) { return IndexOp(k); }
Val& operator[](Key&& k) { return IndexOp(std::forward<Key>(k)); }
template <typename... Args>
std::pair<iterator, bool> emplace(Args&&... args) {
return InsertPair(std::make_pair(std::forward<Args>(args)...));
}
size_t erase(const Key& k) {
auto r = rep_.Find(k);
if (!r.found) return 0;
rep_.Erase(r.b, r.index);
return 1;
}
iterator erase(iterator pos) {
rep_.Erase(pos.b_, pos.i_);
++pos;
return pos;
}
iterator erase(iterator pos, iterator last) {
for (; pos != last; ++pos) {
rep_.Erase(pos.b_, pos.i_);
}
return pos;
}
std::pair<iterator, iterator> equal_range(const Key& k) {
auto pos = find(k);
if (pos == end()) {
return std::make_pair(pos, pos);
} else {
auto next = pos;
++next;
return std::make_pair(pos, next);
}
}
std::pair<const_iterator, const_iterator> equal_range(const Key& k) const {
auto pos = find(k);
if (pos == end()) {
return std::make_pair(pos, pos);
} else {
auto next = pos;
++next;
return std::make_pair(pos, next);
}
}
bool operator==(const FlatMap& x) const {
if (size() != x.size()) return false;
for (auto& p : x) {
auto i = find(p.first);
if (i == end()) return false;
if (i->second != p.second) return false;
}
return true;
}
bool operator!=(const FlatMap& x) const { return !(*this == x); }
// If key exists in the table, prefetch the associated value. This
// is a hint, and may have no effect.
void prefetch_value(const Key& key) const { rep_.Prefetch(key); }
private:
using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>;
// Bucket stores kWidth <marker, key, value> triples.
// The data is organized as three parallel arrays to reduce padding.
struct Bucket {
uint8 marker[Rep::kWidth];
// Wrap keys and values in union to control construction and destruction.
union Storage {
struct {
Key key[Rep::kWidth];
Val val[Rep::kWidth];
};
Storage() {}
~Storage() {}
} storage;
Key& key(uint32 i) {
DCHECK_GE(marker[i], 2);
return storage.key[i];
}
Val& val(uint32 i) {
DCHECK_GE(marker[i], 2);
return storage.val[i];
}
template <typename V>
void InitVal(uint32 i, V&& v) {
new (&storage.val[i]) Val(std::forward<V>(v));
}
void Destroy(uint32 i) {
storage.key[i].Key::~Key();
storage.val[i].Val::~Val();
}
void MoveFrom(uint32 i, Bucket* src, uint32 src_index) {
new (&storage.key[i]) Key(std::move(src->storage.key[src_index]));
new (&storage.val[i]) Val(std::move(src->storage.val[src_index]));
}
void CopyFrom(uint32 i, Bucket* src, uint32 src_index) {
new (&storage.key[i]) Key(src->storage.key[src_index]);
new (&storage.val[i]) Val(src->storage.val[src_index]);
}
};
template <typename Pair>
std::pair<iterator, bool> InsertPair(Pair&& p) {
return Insert(std::forward<decltype(p.first)>(p.first),
std::forward<decltype(p.second)>(p.second));
}
template <typename K, typename V>
std::pair<iterator, bool> Insert(K&& k, V&& v) {
rep_.MaybeResize();
auto r = rep_.FindOrInsert(std::forward<K>(k));
const bool inserted = !r.found;
if (inserted) {
r.b->InitVal(r.index, std::forward<V>(v));
}
return {iterator(r.b, rep_.limit(), r.index), inserted};
}
template <typename K>
Val& IndexOp(K&& k) {
rep_.MaybeResize();
auto r = rep_.FindOrInsert(std::forward<K>(k));
Val* vptr = &r.b->val(r.index);
if (!r.found) {
new (vptr) Val(); // Initialize value in new slot.
}
return *vptr;
}
Rep rep_;
};
} // namespace gtl
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_

View File

@ -0,0 +1,576 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/lib/gtl/flatmap.h"
#include <algorithm>
#include <string>
#include <vector>
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace gtl {
namespace {
typedef FlatMap<int64, int32, HashInt64> NumMap;
// If map has an entry for k, return the corresponding value, else return def.
int32 Get(const NumMap& map, int64 k, int32 def = -1) {
auto iter = map.find(k);
if (iter == map.end()) {
EXPECT_EQ(map.count(k), 0);
return def;
} else {
EXPECT_EQ(map.count(k), 1);
EXPECT_EQ(&map.at(k), &iter->second);
EXPECT_EQ(iter->first, k);
return iter->second;
}
}
// Return contents of map as a sorted list of pairs.
typedef std::vector<std::pair<int64, int32>> NumMapContents;
NumMapContents Contents(const NumMap& map) {
NumMapContents result;
for (const auto& p : map) {
result.push_back({p.first, p.second});
}
std::sort(result.begin(), result.end());
return result;
}
// Fill entries with keys [start,limit).
void Fill(NumMap* map, int64 start, int64 limit) {
for (int64 i = start; i < limit; i++) {
map->insert({i, i * 100});
}
}
TEST(FlatMapTest, Find) {
NumMap map;
EXPECT_EQ(Get(map, 1), -1);
map.insert({1, 100});
map.insert({2, 200});
EXPECT_EQ(Get(map, 1), 100);
EXPECT_EQ(Get(map, 2), 200);
EXPECT_EQ(Get(map, 3), -1);
}
TEST(FlatMapTest, Insert) {
NumMap map;
EXPECT_EQ(Get(map, 1), -1);
// New entry.
auto result = map.insert({1, 100});
EXPECT_TRUE(result.second);
EXPECT_EQ(result.first->first, 1);
EXPECT_EQ(result.first->second, 100);
EXPECT_EQ(Get(map, 1), 100);
// Attempt to insert over existing entry.
result = map.insert({1, 200});
EXPECT_FALSE(result.second);
EXPECT_EQ(result.first->first, 1);
EXPECT_EQ(result.first->second, 100);
EXPECT_EQ(Get(map, 1), 100);
// Overwrite through iterator.
result.first->second = 300;
EXPECT_EQ(result.first->second, 300);
EXPECT_EQ(Get(map, 1), 300);
// Should get updated value.
result = map.insert({1, 400});
EXPECT_FALSE(result.second);
EXPECT_EQ(result.first->first, 1);
EXPECT_EQ(result.first->second, 300);
EXPECT_EQ(Get(map, 1), 300);
}
TEST(FlatMapTest, InsertGrowth) {
NumMap map;
const int n = 100;
Fill(&map, 0, 100);
EXPECT_EQ(map.size(), n);
for (int i = 0; i < n; i++) {
EXPECT_EQ(Get(map, i), i * 100) << i;
}
}
TEST(FlatMapTest, Emplace) {
NumMap map;
// New entry.
auto result = map.emplace(1, 100);
EXPECT_TRUE(result.second);
EXPECT_EQ(result.first->first, 1);
EXPECT_EQ(result.first->second, 100);
EXPECT_EQ(Get(map, 1), 100);
// Attempt to insert over existing entry.
result = map.emplace(1, 200);
EXPECT_FALSE(result.second);
EXPECT_EQ(result.first->first, 1);
EXPECT_EQ(result.first->second, 100);
EXPECT_EQ(Get(map, 1), 100);
// Overwrite through iterator.
result.first->second = 300;
EXPECT_EQ(result.first->second, 300);
EXPECT_EQ(Get(map, 1), 300);
// Update a second value
result = map.emplace(2, 400);
EXPECT_TRUE(result.second);
EXPECT_EQ(result.first->first, 2);
EXPECT_EQ(result.first->second, 400);
EXPECT_EQ(Get(map, 2), 400);
}
TEST(FlatMapTest, EmplaceUniquePtr) {
FlatMap<int64, std::unique_ptr<string>, HashInt64> smap;
smap.emplace(1, std::unique_ptr<string>(new string("hello")));
}
TEST(FlatMapTest, Size) {
NumMap map;
EXPECT_EQ(map.size(), 0);
map.insert({1, 100});
map.insert({2, 200});
EXPECT_EQ(map.size(), 2);
}
TEST(FlatMapTest, Empty) {
NumMap map;
EXPECT_TRUE(map.empty());
map.insert({1, 100});
map.insert({2, 200});
EXPECT_FALSE(map.empty());
}
TEST(FlatMapTest, ArrayOperator) {
NumMap map;
// Create new element if not found.
auto v1 = &map[1];
EXPECT_EQ(*v1, 0);
EXPECT_EQ(Get(map, 1), 0);
// Write through returned reference.
*v1 = 100;
EXPECT_EQ(map[1], 100);
EXPECT_EQ(Get(map, 1), 100);
// Reuse existing element if found.
auto v1a = &map[1];
EXPECT_EQ(v1, v1a);
EXPECT_EQ(*v1, 100);
// Create another element.
map[2] = 200;
EXPECT_EQ(Get(map, 1), 100);
EXPECT_EQ(Get(map, 2), 200);
}
TEST(FlatMapTest, Count) {
NumMap map;
EXPECT_EQ(map.count(1), 0);
EXPECT_EQ(map.count(2), 0);
map.insert({1, 100});
EXPECT_EQ(map.count(1), 1);
EXPECT_EQ(map.count(2), 0);
map.insert({2, 200});
EXPECT_EQ(map.count(1), 1);
EXPECT_EQ(map.count(2), 1);
}
TEST(FlatMapTest, Iter) {
NumMap map;
EXPECT_EQ(Contents(map), NumMapContents());
map.insert({1, 100});
map.insert({2, 200});
EXPECT_EQ(Contents(map), NumMapContents({{1, 100}, {2, 200}}));
}
TEST(FlatMapTest, Erase) {
NumMap map;
EXPECT_EQ(map.erase(1), 0);
map[1] = 100;
map[2] = 200;
EXPECT_EQ(map.erase(3), 0);
EXPECT_EQ(map.erase(1), 1);
EXPECT_EQ(map.size(), 1);
EXPECT_EQ(Get(map, 2), 200);
EXPECT_EQ(Contents(map), NumMapContents({{2, 200}}));
EXPECT_EQ(map.erase(2), 1);
EXPECT_EQ(Contents(map), NumMapContents());
}
TEST(FlatMapTest, EraseIter) {
NumMap map;
Fill(&map, 1, 11);
size_t size = 10;
for (auto iter = map.begin(); iter != map.end();) {
iter = map.erase(iter);
size--;
EXPECT_EQ(map.size(), size);
}
EXPECT_EQ(Contents(map), NumMapContents());
}
TEST(FlatMapTest, EraseIterPair) {
NumMap map;
Fill(&map, 1, 11);
NumMap expected;
auto p1 = map.begin();
expected.insert(*p1);
++p1;
expected.insert(*p1);
++p1;
auto p2 = map.end();
EXPECT_EQ(map.erase(p1, p2), map.end());
EXPECT_EQ(map.size(), 2);
EXPECT_EQ(Contents(map), Contents(expected));
}
TEST(FlatMapTest, EraseLongChains) {
// Make a map with lots of elements and erase a bunch of them to ensure
// that we are likely to hit them on future lookups.
NumMap map;
const int num = 128;
Fill(&map, 0, num);
for (int i = 0; i < num; i += 3) {
EXPECT_EQ(map.erase(i), 1);
}
for (int i = 0; i < num; i++) {
if ((i % 3) != 0) {
EXPECT_EQ(Get(map, i), i * 100);
} else {
EXPECT_EQ(map.count(i), 0);
}
}
// Erase remainder to trigger table shrinking.
const size_t orig_buckets = map.bucket_count();
for (int i = 0; i < num; i++) {
map.erase(i);
}
EXPECT_TRUE(map.empty());
EXPECT_EQ(map.bucket_count(), orig_buckets);
map[1] = 100; // Actual shrinking is triggered by an insert.
EXPECT_LT(map.bucket_count(), orig_buckets);
}
TEST(FlatMap, AlternatingInsertRemove) {
NumMap map;
map.insert({1000, 1000});
map.insert({2000, 1000});
map.insert({3000, 1000});
for (int i = 0; i < 10000; i++) {
map.insert({i, i});
map.erase(i);
}
}
TEST(FlatMap, ClearNoResize) {
NumMap map;
Fill(&map, 0, 100);
const size_t orig = map.bucket_count();
map.clear_no_resize();
EXPECT_EQ(map.size(), 0);
EXPECT_EQ(Contents(map), NumMapContents());
EXPECT_EQ(map.bucket_count(), orig);
}
TEST(FlatMap, Clear) {
NumMap map;
Fill(&map, 0, 100);
const size_t orig = map.bucket_count();
map.clear();
EXPECT_EQ(map.size(), 0);
EXPECT_EQ(Contents(map), NumMapContents());
EXPECT_LT(map.bucket_count(), orig);
}
TEST(FlatMap, Copy) {
for (int n = 0; n < 10; n++) {
NumMap src;
Fill(&src, 0, n);
NumMap copy = src;
EXPECT_EQ(Contents(src), Contents(copy));
NumMap copy2;
copy2 = src;
EXPECT_EQ(Contents(src), Contents(copy2));
copy2 = copy2; // Self-assignment
EXPECT_EQ(Contents(src), Contents(copy2));
}
}
TEST(FlatMap, InitFromIter) {
for (int n = 0; n < 10; n++) {
NumMap src;
Fill(&src, 0, n);
auto vec = Contents(src);
NumMap dst(vec.begin(), vec.end());
EXPECT_EQ(Contents(dst), vec);
}
}
TEST(FlatMap, InsertIter) {
NumMap a, b;
Fill(&a, 1, 10);
Fill(&b, 8, 20);
b[9] = 10000; // Should not get inserted into a since a already has 9
a.insert(b.begin(), b.end());
NumMap expected;
Fill(&expected, 1, 20);
EXPECT_EQ(Contents(a), Contents(expected));
}
TEST(FlatMap, Eq) {
NumMap empty;
NumMap elems;
Fill(&elems, 0, 5);
EXPECT_FALSE(empty == elems);
EXPECT_TRUE(empty != elems);
NumMap copy = elems;
EXPECT_TRUE(copy == elems);
EXPECT_FALSE(copy != elems);
NumMap changed = elems;
changed[3] = 1;
EXPECT_FALSE(changed == elems);
EXPECT_TRUE(changed != elems);
NumMap changed2 = elems;
changed2.erase(3);
EXPECT_FALSE(changed2 == elems);
EXPECT_TRUE(changed2 != elems);
}
TEST(FlatMap, Swap) {
NumMap a, b;
Fill(&a, 1, 5);
Fill(&b, 100, 200);
NumMap c = a;
NumMap d = b;
EXPECT_EQ(c, a);
EXPECT_EQ(d, b);
c.swap(d);
EXPECT_EQ(c, b);
EXPECT_EQ(d, a);
}
TEST(FlatMap, Reserve) {
NumMap src;
Fill(&src, 1, 100);
NumMap a = src;
a.reserve(10);
EXPECT_EQ(a, src);
NumMap b = src;
b.rehash(1000);
EXPECT_EQ(b, src);
}
TEST(FlatMap, EqualRangeMutable) {
NumMap map;
Fill(&map, 1, 10);
// Existing element
auto p1 = map.equal_range(3);
EXPECT_TRUE(p1.first != p1.second);
EXPECT_EQ(p1.first->first, 3);
EXPECT_EQ(p1.first->second, 300);
++p1.first;
EXPECT_TRUE(p1.first == p1.second);
// Missing element
auto p2 = map.equal_range(100);
EXPECT_TRUE(p2.first == p2.second);
}
TEST(FlatMap, EqualRangeConst) {
NumMap tmp;
Fill(&tmp, 1, 10);
const NumMap map = tmp;
// Existing element
auto p1 = map.equal_range(3);
EXPECT_TRUE(p1.first != p1.second);
EXPECT_EQ(p1.first->first, 3);
EXPECT_EQ(p1.first->second, 300);
++p1.first;
EXPECT_TRUE(p1.first == p1.second);
// Missing element
auto p2 = map.equal_range(100);
EXPECT_TRUE(p2.first == p2.second);
}
TEST(FlatMap, Prefetch) {
NumMap map;
Fill(&map, 0, 1000);
// Prefetch present and missing keys.
for (int i = 0; i < 2000; i++) {
map.prefetch_value(i);
}
}
// Non-copyable values should work.
struct NC {
int64 value;
NC() : value(-1) {}
NC(int64 v) : value(v) {}
NC(const NC& x) : value(x.value) {}
bool operator==(const NC& x) const { return value == x.value; }
};
struct HashNC {
size_t operator()(NC x) const { return x.value; }
};
TEST(FlatMap, NonCopyable) {
FlatMap<NC, NC, HashNC> map;
for (int i = 0; i < 100; i++) {
map[NC(i)] = NC(i * 100);
}
for (int i = 0; i < 100; i++) {
EXPECT_EQ(map.count(NC(i)), 1);
auto iter = map.find(NC(i));
EXPECT_NE(iter, map.end());
EXPECT_EQ(iter->first, NC(i));
EXPECT_EQ(iter->second, NC(i * 100));
EXPECT_EQ(map[NC(i)], NC(i * 100));
}
map.erase(NC(10));
EXPECT_EQ(map.count(NC(10)), 0);
}
// Test with heap-allocated objects so that mismanaged constructions
// or destructions will show up as errors under a sanitizer or
// heap checker.
TEST(FlatMap, ConstructDestruct) {
FlatMap<string, string, HashStr> map;
string k1 = "the quick brown fox jumped over the lazy dog";
string k2 = k1 + k1;
string k3 = k1 + k2;
map[k1] = k2;
map[k3] = k1;
EXPECT_EQ(k1, map.find(k1)->first);
EXPECT_EQ(k2, map.find(k1)->second);
EXPECT_EQ(k1, map[k3]);
map.erase(k3);
EXPECT_EQ(string(), map[k3]);
map.clear();
map[k1] = k2;
EXPECT_EQ(k2, map[k1]);
map.reserve(100);
EXPECT_EQ(k2, map[k1]);
}
// Type to use to ensure that custom equality operator is used
// that ignores extra value.
struct CustomCmpKey {
int64 a;
int64 b;
CustomCmpKey(int64 v1, int64 v2) : a(v1), b(v2) {}
bool operator==(const CustomCmpKey& x) const { return a == x.a && b == x.b; }
};
struct HashA {
size_t operator()(CustomCmpKey x) const { return x.a; }
};
struct EqA {
// Ignore b fields.
bool operator()(CustomCmpKey x, CustomCmpKey y) const { return x.a == y.a; }
};
TEST(FlatMap, CustomCmp) {
FlatMap<CustomCmpKey, int, HashA, EqA> map;
map[CustomCmpKey(100, 200)] = 300;
EXPECT_EQ(300, map[CustomCmpKey(100, 200)]);
EXPECT_EQ(300, map[CustomCmpKey(100, 500)]); // Differences in key.b ignored
}
// Test unique_ptr handling.
typedef std::unique_ptr<int> UniqInt;
static UniqInt MakeUniq(int i) { return UniqInt(new int(i)); }
struct HashUniq {
size_t operator()(const UniqInt& p) const { return *p; }
};
struct EqUniq {
bool operator()(const UniqInt& a, const UniqInt& b) const { return *a == *b; }
};
typedef FlatMap<UniqInt, UniqInt, HashUniq, EqUniq> UniqMap;
TEST(FlatMap, UniqueMap) {
UniqMap map;
// Fill map
const int N = 10;
for (int i = 0; i < N; i++) {
if ((i % 2) == 0) {
map[MakeUniq(i)] = MakeUniq(i + 100);
} else {
map.emplace(MakeUniq(i), MakeUniq(i + 100));
}
}
EXPECT_EQ(map.size(), N);
// Lookups
for (int i = 0; i < N; i++) {
EXPECT_EQ(*map.at(MakeUniq(i)), i + 100);
}
// find+erase
EXPECT_EQ(map.count(MakeUniq(2)), 1);
map.erase(MakeUniq(2));
EXPECT_EQ(map.count(MakeUniq(2)), 0);
// clear
map.clear();
EXPECT_EQ(map.size(), 0);
}
TEST(FlatMap, UniqueMapIter) {
UniqMap map;
const int kCount = 10;
const int kValueDelta = 100;
for (int i = 1; i <= kCount; i++) {
map[MakeUniq(i)] = MakeUniq(i + kValueDelta);
}
int key_sum = 0;
int val_sum = 0;
for (const auto& p : map) {
key_sum += *p.first;
val_sum += *p.second;
}
EXPECT_EQ(key_sum, (kCount * (kCount + 1)) / 2);
EXPECT_EQ(val_sum, key_sum + (kCount * kValueDelta));
}
} // namespace
} // namespace gtl
} // namespace tensorflow

View File

@ -0,0 +1,332 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
#include <string.h>
#include <utility>
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace gtl {
namespace internal {
// Internal representation for FlatMap and FlatSet.
//
// The representation is an open-addressed hash table. Conceptually,
// the representation is a flat array of entries. However we
// structure it as an array of of buckets where each bucket holds
// kWidth entries along with metadata for the kWidth entries. The
// metadata marker is
//
// (a) kEmpty: the entry is empty
// (b) kDeleted: the entry has been deleted
// (c) other: the entry is occupied and has low-8 bits of its hash.
// These hash bits can be used to avoid potentially expensive
// key comparisons.
//
// FlatMap passes in a bucket that contains keys and values, FlatSet
// passes in a bucket that does not contain values.
template <typename Key, typename Bucket, class Hash, class Eq>
class FlatRep {
public:
// kWidth is the number of entries stored in a bucket.
static const uint32 kBase = 3;
static const uint32 kWidth = (1 << kBase);
FlatRep(size_t N, const Hash& hf, const Eq& eq) : hash_(hf), equal_(eq) {
Init(N);
}
explicit FlatRep(const FlatRep& src) : hash_(src.hash_), equal_(src.equal_) {
Init(src.size());
CopyEntries(src.array_, src.end_, CopyEntry());
}
~FlatRep() {
clear_no_resize();
delete[] array_;
}
// Simple accessors.
size_t size() const { return not_empty_ - deleted_; }
size_t bucket_count() const { return mask_ + 1; }
Bucket* start() const { return array_; }
Bucket* limit() const { return end_; }
const Hash& hash_function() const { return hash_; }
const Eq& key_eq() const { return equal_; }
// Overwrite contents of *this with contents of src.
void CopyFrom(const FlatRep& src) {
if (this != &src) {
clear_no_resize();
delete[] array_;
Init(src.size());
CopyEntries(src.array_, src.end_, CopyEntry());
}
}
void clear_no_resize() {
for (Bucket* b = array_; b != end_; b++) {
for (uint32 i = 0; i < kWidth; i++) {
if (b->marker[i] >= 2) {
b->Destroy(i);
b->marker[i] = kEmpty;
}
}
}
not_empty_ = 0;
deleted_ = 0;
}
void clear() {
clear_no_resize();
grow_ = 0; // Consider shrinking in MaybeResize()
MaybeResize();
}
void swap(FlatRep& x) {
using std::swap;
swap(array_, x.array_);
swap(end_, x.end_);
swap(lglen_, x.lglen_);
swap(mask_, x.mask_);
swap(not_empty_, x.not_empty_);
swap(deleted_, x.deleted_);
swap(grow_, x.grow_);
swap(shrink_, x.shrink_);
}
struct SearchResult {
bool found;
Bucket* b;
uint32 index;
};
// Hash value is partitioned as follows:
// 1. Bottom 8 bits are stored in bucket to help speed up comparisons.
// 2. Next 3 bits give index inside bucket.
// 3. Remaining bits give bucket number.
// Find bucket/index for key k.
SearchResult Find(const Key& k) const {
size_t h = hash_(k);
const uint32 marker = Marker(h & 0xff);
size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket
uint32 num_probes = 1; // Needed for quadratic probing
while (true) {
uint32 bi = index & (kWidth - 1);
Bucket* b = &array_[index >> kBase];
const uint32 x = b->marker[bi];
if (x == marker && equal_(b->key(bi), k)) {
return {true, b, bi};
} else if (x == kEmpty) {
return {false, nullptr, 0};
}
// Quadratic probing.
index = (index + num_probes) & mask_;
num_probes++;
}
}
// Find bucket/index for key k, creating a new one if necessary.
//
// KeyType is a template parameter so that k's type is deduced and it
// becomes a universal reference which allows the key initialization
// below to use an rvalue constructor if available.
template <typename KeyType>
SearchResult FindOrInsert(KeyType&& k) {
size_t h = hash_(k);
const uint32 marker = Marker(h & 0xff);
size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket
uint32 num_probes = 1; // Needed for quadratic probing
Bucket* del = nullptr; // First encountered deletion for kInsert
uint32 di = 0;
while (true) {
uint32 bi = index & (kWidth - 1);
Bucket* b = &array_[index >> kBase];
const uint32 x = b->marker[bi];
if (x == marker && equal_(b->key(bi), k)) {
return {true, b, bi};
} else if (!del && x == kDeleted) {
// Remember deleted index to use for insertion.
del = b;
di = bi;
} else if (x == kEmpty) {
if (del) {
// Store in the first deleted slot we encountered
b = del;
bi = di;
deleted_--; // not_empty_ does not change
} else {
not_empty_++;
}
b->marker[bi] = marker;
new (&b->key(bi)) Key(std::forward<KeyType>(k));
return {false, b, bi};
}
// Quadratic probing.
index = (index + num_probes) & mask_;
num_probes++;
}
}
void Erase(Bucket* b, uint32 i) {
b->Destroy(i);
b->marker[i] = kDeleted;
deleted_++;
grow_ = 0; // Consider shrinking on next insert
}
void Prefetch(const Key& k) const {
size_t h = hash_(k);
size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket
uint32 bi = index & (kWidth - 1);
Bucket* b = &array_[index >> kBase];
prefetch(&b->storage.key[bi]);
}
void prefetch(const void* ptr) const {
// TODO(jeff,sanjay): Remove this routine when we add a
// prefetch(...) call to platform so that the Prefetch routine
// actually does something
}
inline void MaybeResize() {
if (not_empty_ < grow_) {
return; // Nothing to do
}
if (grow_ == 0) {
// Special value set by erase to cause shrink on next insert.
if (size() >= shrink_) {
// Not small enough to shrink.
grow_ = static_cast<size_t>(bucket_count() * 0.8);
if (not_empty_ < grow_) return;
}
}
Resize(size() + 1);
}
void Resize(size_t N) {
Bucket* old = array_;
Bucket* old_end = end_;
Init(N);
CopyEntries(old, old_end, MoveEntry());
delete[] old;
}
private:
enum { kEmpty = 0, kDeleted = 1 }; // Special markers for an entry.
Hash hash_; // User-supplied hasher
Eq equal_; // User-supplied comparator
uint8 lglen_; // lg(#buckets)
Bucket* array_; // array of length (1 << lglen_)
Bucket* end_; // Points just past last bucket in array_
size_t mask_; // (# of entries in table) - 1
size_t not_empty_; // Count of entries with marker != kEmpty
size_t deleted_; // Count of entries with marker == kDeleted
size_t grow_; // Grow array when not_empty_ >= grow_
size_t shrink_; // Shrink array when size() < shrink_
// Avoid kEmpty and kDeleted markers when computing hash values to
// store in Bucket::marker[].
static uint32 Marker(uint32 hb) { return hb + (hb < 2 ? 2 : 0); }
void Init(size_t N) {
// Make enough room for N elements.
size_t lg = 0; // Smallest table is just one bucket.
while (N >= 0.8 * ((1 << lg) * kWidth)) {
lg++;
}
const size_t n = (1 << lg);
Bucket* array = new Bucket[n];
for (size_t i = 0; i < n; i++) {
Bucket* b = &array[i];
memset(b->marker, kEmpty, kWidth);
}
const size_t capacity = (1 << lg) * kWidth;
lglen_ = lg;
mask_ = capacity - 1;
array_ = array;
end_ = array + n;
not_empty_ = 0;
deleted_ = 0;
grow_ = static_cast<size_t>(capacity * 0.8);
if (lg == 0) {
// Already down to one bucket; no more shrinking.
shrink_ = 0;
} else {
shrink_ = static_cast<size_t>(grow_ * 0.4); // Must be less than 0.5
}
}
// Used by FreshInsert when we should copy from source.
struct CopyEntry {
inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) {
dst->CopyFrom(dsti, src, srci);
}
};
// Used by FreshInsert when we should move from source.
struct MoveEntry {
inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) {
dst->MoveFrom(dsti, src, srci);
src->Destroy(srci);
src->marker[srci] = kDeleted;
}
};
template <typename Copier>
void CopyEntries(Bucket* start, Bucket* end, Copier copier) {
for (Bucket* b = start; b != end; b++) {
for (uint32 i = 0; i < kWidth; i++) {
if (b->marker[i] >= 2) {
FreshInsert(b, i, copier);
}
}
}
}
// Create an entry for the key numbered src_index in *src and return
// its bucket/index. Used for insertion into a fresh table. We
// assume that there are no deletions, and k does not already exist
// in the table.
template <typename Copier>
void FreshInsert(Bucket* src, uint32 src_index, Copier copier) {
size_t h = hash_(src->key(src_index));
const uint32 marker = Marker(h & 0xff);
size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket
uint32 num_probes = 1; // Needed for quadratic probing
while (true) {
uint32 bi = index & (kWidth - 1);
Bucket* b = &array_[index >> kBase];
const uint32 x = b->marker[bi];
if (x == 0) {
b->marker[bi] = marker;
not_empty_++;
copier(b, bi, src, src_index);
return;
}
// Quadratic probing.
index = (index + num_probes) & mask_;
num_probes++;
}
}
};
} // namespace internal
} // namespace gtl
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATREP_H_

View File

@ -0,0 +1,277 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
#include <stddef.h>
#include <utility>
#include "tensorflow/core/lib/gtl/flatrep.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace gtl {
// FlatSet<K,...> provides a set of K.
//
// The map is implemented using an open-addressed hash table. A
// single array holds entire map contents and collisions are resolved
// by probing at a sequence of locations in the array.
template <typename Key, class Hash, class Eq = std::equal_to<Key>>
class FlatSet {
private:
// Forward declare some internal types needed in public section.
struct Bucket;
public:
typedef Key key_type;
typedef Key value_type;
typedef Hash hasher;
typedef Eq key_equal;
typedef size_t size_type;
typedef ptrdiff_t difference_type;
typedef value_type* pointer;
typedef const value_type* const_pointer;
typedef value_type& reference;
typedef const value_type& const_reference;
FlatSet() : FlatSet(1) {}
explicit FlatSet(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq())
: rep_(N, hf, eq) {}
FlatSet(const FlatSet& src) : rep_(src.rep_) {}
template <typename InputIter>
FlatSet(InputIter first, InputIter last, size_t N = 1,
const Hash& hf = Hash(), const Eq& eq = Eq())
: FlatSet(N, hf, eq) {
insert(first, last);
}
FlatSet& operator=(const FlatSet& src) {
rep_.CopyFrom(src.rep_);
return *this;
}
~FlatSet() {}
void swap(FlatSet& x) { rep_.swap(x.rep_); }
void clear_no_resize() { rep_.clear_no_resize(); }
void clear() { rep_.clear(); }
void reserve(size_t N) { rep_.Resize(std::max(N, size())); }
void rehash(size_t N) { rep_.Resize(std::max(N, size())); }
void resize(size_t N) { rep_.Resize(std::max(N, size())); }
size_t size() const { return rep_.size(); }
bool empty() const { return size() == 0; }
size_t bucket_count() const { return rep_.bucket_count(); }
hasher hash_function() const { return rep_.hash_function(); }
key_equal key_eq() const { return rep_.key_eq(); }
class iterator {
public:
iterator() : b_(nullptr), end_(nullptr), i_(0) {}
// Make iterator pointing at first element at or after b.
explicit iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) {
SkipUnused();
}
// Make iterator pointing exactly at ith element in b, which must exist.
iterator(Bucket* b, Bucket* end, uint32 i) : b_(b), end_(end), i_(i) {}
Key& operator*() { return key(); }
Key* operator->() { return &key(); }
bool operator==(const iterator& x) const {
return b_ == x.b_ && i_ == x.i_;
}
bool operator!=(const iterator& x) const { return !(*this == x); }
iterator& operator++() {
DCHECK(b_ != end_);
i_++;
SkipUnused();
return *this;
}
private:
friend class FlatSet;
Bucket* b_;
Bucket* end_;
uint32 i_;
Key& key() const { return b_->key(i_); }
void SkipUnused() {
while (b_ < end_) {
if (i_ >= Rep::kWidth) {
i_ = 0;
b_++;
} else if (b_->marker[i_] < 2) {
i_++;
} else {
break;
}
}
}
};
class const_iterator {
private:
mutable iterator rep_; // Share state and logic with non-const iterator.
public:
const_iterator() : rep_() {}
explicit const_iterator(Bucket* start, Bucket* end) : rep_(start, end) {}
const_iterator(Bucket* b, Bucket* end, uint32 i) : rep_(b, end, i) {}
const Key& operator*() const { return rep_.key(); }
const Key* operator->() const { return &rep_.key(); }
bool operator==(const const_iterator& x) const { return rep_ == x.rep_; }
bool operator!=(const const_iterator& x) const { return rep_ != x.rep_; }
const_iterator& operator++() {
++rep_;
return *this;
}
};
iterator begin() { return iterator(rep_.start(), rep_.limit()); }
iterator end() { return iterator(rep_.limit(), rep_.limit()); }
const_iterator begin() const {
return const_iterator(rep_.start(), rep_.limit());
}
const_iterator end() const {
return const_iterator(rep_.limit(), rep_.limit());
}
size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; }
iterator find(const Key& k) {
auto r = rep_.Find(k);
return r.found ? iterator(r.b, rep_.limit(), r.index) : end();
}
const_iterator find(const Key& k) const {
auto r = rep_.Find(k);
return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end();
}
std::pair<iterator, bool> insert(const Key& k) { return Insert(k); }
template <typename InputIter>
void insert(InputIter first, InputIter last) {
for (; first != last; ++first) {
insert(*first);
}
}
template <typename... Args>
std::pair<iterator, bool> emplace(Args&&... args) {
rep_.MaybeResize();
auto r = rep_.FindOrInsert(std::forward<Args>(args)...);
const bool inserted = !r.found;
return {iterator(r.b, rep_.limit(), r.index), inserted};
}
size_t erase(const Key& k) {
auto r = rep_.Find(k);
if (!r.found) return 0;
rep_.Erase(r.b, r.index);
return 1;
}
iterator erase(iterator pos) {
rep_.Erase(pos.b_, pos.i_);
++pos;
return pos;
}
iterator erase(iterator pos, iterator last) {
for (; pos != last; ++pos) {
rep_.Erase(pos.b_, pos.i_);
}
return pos;
}
std::pair<iterator, iterator> equal_range(const Key& k) {
auto pos = find(k);
if (pos == end()) {
return std::make_pair(pos, pos);
} else {
auto next = pos;
++next;
return std::make_pair(pos, next);
}
}
std::pair<const_iterator, const_iterator> equal_range(const Key& k) const {
auto pos = find(k);
if (pos == end()) {
return std::make_pair(pos, pos);
} else {
auto next = pos;
++next;
return std::make_pair(pos, next);
}
}
bool operator==(const FlatSet& x) const {
if (size() != x.size()) return false;
for (const auto& elem : x) {
auto i = find(elem);
if (i == end()) return false;
}
return true;
}
bool operator!=(const FlatSet& x) const { return !(*this == x); }
// If key exists in the table, prefetch it. This is a hint, and may
// have no effect.
void prefetch_value(const Key& key) const { rep_.Prefetch(key); }
private:
using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>;
// Bucket stores kWidth <marker, key, value> triples.
// The data is organized as three parallel arrays to reduce padding.
struct Bucket {
uint8 marker[Rep::kWidth];
// Wrap keys in union to control construction and destruction.
union Storage {
Key key[Rep::kWidth];
Storage() {}
~Storage() {}
} storage;
Key& key(uint32 i) {
DCHECK_GE(marker[i], 2);
return storage.key[i];
}
void Destroy(uint32 i) { storage.key[i].Key::~Key(); }
void MoveFrom(uint32 i, Bucket* src, uint32 src_index) {
new (&storage.key[i]) Key(std::move(src->storage.key[src_index]));
}
void CopyFrom(uint32 i, Bucket* src, uint32 src_index) {
new (&storage.key[i]) Key(src->storage.key[src_index]);
}
};
std::pair<iterator, bool> Insert(const Key& k) {
rep_.MaybeResize();
auto r = rep_.FindOrInsert(k);
const bool inserted = !r.found;
return {iterator(r.b, rep_.limit(), r.index), inserted};
}
Rep rep_;
};
} // namespace gtl
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_

View File

@ -0,0 +1,501 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/lib/gtl/flatset.h"
#include <algorithm>
#include <string>
#include <vector>
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace gtl {
namespace {
typedef FlatSet<int64, HashInt64> NumSet;
// Returns true iff set has an entry for k.
// Also verifies that find and count give consistent results.
bool Has(const NumSet& set, int64 k) {
auto iter = set.find(k);
if (iter == set.end()) {
EXPECT_EQ(set.count(k), 0);
return false;
} else {
EXPECT_EQ(set.count(k), 1);
EXPECT_EQ(*iter, k);
return true;
}
}
// Return contents of set as a sorted list of numbers.
typedef std::vector<int64> NumSetContents;
NumSetContents Contents(const NumSet& set) {
NumSetContents result;
for (int64 n : set) {
result.push_back(n);
}
std::sort(result.begin(), result.end());
return result;
}
// Fill entries with keys [start,limit).
void Fill(NumSet* set, int64 start, int64 limit) {
for (int64 i = start; i < limit; i++) {
set->insert(i);
}
}
TEST(FlatSetTest, Find) {
NumSet set;
EXPECT_FALSE(Has(set, 1));
set.insert(1);
set.insert(2);
EXPECT_TRUE(Has(set, 1));
EXPECT_TRUE(Has(set, 2));
EXPECT_FALSE(Has(set, 3));
}
TEST(FlatSetTest, Insert) {
NumSet set;
EXPECT_FALSE(Has(set, 1));
// New entry.
auto result = set.insert(1);
EXPECT_TRUE(result.second);
EXPECT_EQ(*result.first, 1);
EXPECT_TRUE(Has(set, 1));
// Attempt to insert over existing entry.
result = set.insert(1);
EXPECT_FALSE(result.second);
EXPECT_EQ(*result.first, 1);
EXPECT_TRUE(Has(set, 1));
}
TEST(FlatSetTest, InsertGrowth) {
NumSet set;
const int n = 100;
Fill(&set, 0, 100);
EXPECT_EQ(set.size(), n);
for (int i = 0; i < n; i++) {
EXPECT_TRUE(Has(set, i)) << i;
}
}
TEST(FlatSetTest, Emplace) {
NumSet set;
// New entry.
auto result = set.emplace(73);
EXPECT_TRUE(result.second);
EXPECT_EQ(*result.first, 73);
EXPECT_TRUE(Has(set, 73));
// Attempt to insert an existing entry.
result = set.emplace(73);
EXPECT_FALSE(result.second);
EXPECT_EQ(*result.first, 73);
EXPECT_TRUE(Has(set, 73));
// Add a second value
result = set.emplace(103);
EXPECT_TRUE(result.second);
EXPECT_EQ(*result.first, 103);
EXPECT_TRUE(Has(set, 103));
}
TEST(FlatSetTest, Size) {
NumSet set;
EXPECT_EQ(set.size(), 0);
set.insert(1);
set.insert(2);
EXPECT_EQ(set.size(), 2);
}
TEST(FlatSetTest, Empty) {
NumSet set;
EXPECT_TRUE(set.empty());
set.insert(1);
set.insert(2);
EXPECT_FALSE(set.empty());
}
TEST(FlatSetTest, Count) {
NumSet set;
EXPECT_EQ(set.count(1), 0);
EXPECT_EQ(set.count(2), 0);
set.insert(1);
EXPECT_EQ(set.count(1), 1);
EXPECT_EQ(set.count(2), 0);
set.insert(2);
EXPECT_EQ(set.count(1), 1);
EXPECT_EQ(set.count(2), 1);
}
TEST(FlatSetTest, Iter) {
NumSet set;
EXPECT_EQ(Contents(set), NumSetContents());
set.insert(1);
set.insert(2);
EXPECT_EQ(Contents(set), NumSetContents({1, 2}));
}
TEST(FlatSetTest, Erase) {
NumSet set;
EXPECT_EQ(set.erase(1), 0);
set.insert(1);
set.insert(2);
EXPECT_EQ(set.erase(3), 0);
EXPECT_EQ(set.erase(1), 1);
EXPECT_EQ(set.size(), 1);
EXPECT_TRUE(Has(set, 2));
EXPECT_EQ(Contents(set), NumSetContents({2}));
EXPECT_EQ(set.erase(2), 1);
EXPECT_EQ(Contents(set), NumSetContents());
}
TEST(FlatSetTest, EraseIter) {
NumSet set;
Fill(&set, 1, 11);
size_t size = 10;
for (auto iter = set.begin(); iter != set.end();) {
iter = set.erase(iter);
size--;
EXPECT_EQ(set.size(), size);
}
EXPECT_EQ(Contents(set), NumSetContents());
}
TEST(FlatSetTest, EraseIterPair) {
NumSet set;
Fill(&set, 1, 11);
NumSet expected;
auto p1 = set.begin();
expected.insert(*p1);
++p1;
expected.insert(*p1);
++p1;
auto p2 = set.end();
EXPECT_EQ(set.erase(p1, p2), set.end());
EXPECT_EQ(set.size(), 2);
EXPECT_EQ(Contents(set), Contents(expected));
}
TEST(FlatSetTest, EraseLongChains) {
// Make a set with lots of elements and erase a bunch of them to ensure
// that we are likely to hit them on future lookups.
NumSet set;
const int num = 128;
Fill(&set, 0, num);
for (int i = 0; i < num; i += 3) {
EXPECT_EQ(set.erase(i), 1);
}
for (int i = 0; i < num; i++) {
// Multiples of 3 should be not present.
EXPECT_EQ(Has(set, i), ((i % 3) != 0)) << i;
}
// Erase remainder to trigger table shrinking.
const size_t orig_buckets = set.bucket_count();
for (int i = 0; i < num; i++) {
set.erase(i);
}
EXPECT_TRUE(set.empty());
EXPECT_EQ(set.bucket_count(), orig_buckets);
set.insert(1); // Actual shrinking is triggered by an insert.
EXPECT_LT(set.bucket_count(), orig_buckets);
}
TEST(FlatSet, ClearNoResize) {
NumSet set;
Fill(&set, 0, 100);
const size_t orig = set.bucket_count();
set.clear_no_resize();
EXPECT_EQ(set.size(), 0);
EXPECT_EQ(Contents(set), NumSetContents());
EXPECT_EQ(set.bucket_count(), orig);
}
TEST(FlatSet, Clear) {
NumSet set;
Fill(&set, 0, 100);
const size_t orig = set.bucket_count();
set.clear();
EXPECT_EQ(set.size(), 0);
EXPECT_EQ(Contents(set), NumSetContents());
EXPECT_LT(set.bucket_count(), orig);
}
TEST(FlatSet, Copy) {
for (int n = 0; n < 10; n++) {
NumSet src;
Fill(&src, 0, n);
NumSet copy = src;
EXPECT_EQ(Contents(src), Contents(copy));
NumSet copy2;
copy2 = src;
EXPECT_EQ(Contents(src), Contents(copy2));
copy2 = copy2; // Self-assignment
EXPECT_EQ(Contents(src), Contents(copy2));
}
}
TEST(FlatSet, InitFromIter) {
for (int n = 0; n < 10; n++) {
NumSet src;
Fill(&src, 0, n);
auto vec = Contents(src);
NumSet dst(vec.begin(), vec.end());
EXPECT_EQ(Contents(dst), vec);
}
}
TEST(FlatSet, InsertIter) {
NumSet a, b;
Fill(&a, 1, 10);
Fill(&b, 8, 20);
b.insert(9); // Should not get inserted into a since a already has 9
a.insert(b.begin(), b.end());
NumSet expected;
Fill(&expected, 1, 20);
EXPECT_EQ(Contents(a), Contents(expected));
}
TEST(FlatSet, Eq) {
NumSet empty;
NumSet elems;
Fill(&elems, 0, 5);
EXPECT_FALSE(empty == elems);
EXPECT_TRUE(empty != elems);
NumSet copy = elems;
EXPECT_TRUE(copy == elems);
EXPECT_FALSE(copy != elems);
NumSet changed = elems;
changed.insert(7);
EXPECT_FALSE(changed == elems);
EXPECT_TRUE(changed != elems);
NumSet changed2 = elems;
changed2.erase(3);
EXPECT_FALSE(changed2 == elems);
EXPECT_TRUE(changed2 != elems);
}
TEST(FlatSet, Swap) {
NumSet a, b;
Fill(&a, 1, 5);
Fill(&b, 100, 200);
NumSet c = a;
NumSet d = b;
EXPECT_EQ(c, a);
EXPECT_EQ(d, b);
c.swap(d);
EXPECT_EQ(c, b);
EXPECT_EQ(d, a);
}
TEST(FlatSet, Reserve) {
NumSet src;
Fill(&src, 1, 100);
NumSet a = src;
a.reserve(10);
EXPECT_EQ(a, src);
NumSet b = src;
b.rehash(1000);
EXPECT_EQ(b, src);
}
TEST(FlatSet, EqualRangeMutable) {
NumSet set;
Fill(&set, 1, 10);
// Existing element
auto p1 = set.equal_range(3);
EXPECT_TRUE(p1.first != p1.second);
EXPECT_EQ(*p1.first, 3);
++p1.first;
EXPECT_TRUE(p1.first == p1.second);
// Missing element
auto p2 = set.equal_range(100);
EXPECT_TRUE(p2.first == p2.second);
}
TEST(FlatSet, EqualRangeConst) {
NumSet tmp;
Fill(&tmp, 1, 10);
const NumSet set = tmp;
// Existing element
auto p1 = set.equal_range(3);
EXPECT_TRUE(p1.first != p1.second);
EXPECT_EQ(*p1.first, 3);
++p1.first;
EXPECT_TRUE(p1.first == p1.second);
// Missing element
auto p2 = set.equal_range(100);
EXPECT_TRUE(p2.first == p2.second);
}
TEST(FlatSet, Prefetch) {
NumSet set;
Fill(&set, 0, 1000);
// Prefetch present and missing keys.
for (int i = 0; i < 2000; i++) {
set.prefetch_value(i);
}
}
// Non-copyable values should work.
struct NC {
int64 value;
NC() : value(-1) {}
NC(int64 v) : value(v) {}
NC(const NC& x) : value(x.value) {}
bool operator==(const NC& x) const { return value == x.value; }
};
struct HashNC {
size_t operator()(NC x) const { return x.value; }
};
TEST(FlatSet, NonCopyable) {
FlatSet<NC, HashNC> set;
for (int i = 0; i < 100; i++) {
set.insert(NC(i));
}
for (int i = 0; i < 100; i++) {
EXPECT_EQ(set.count(NC(i)), 1);
auto iter = set.find(NC(i));
EXPECT_NE(iter, set.end());
EXPECT_EQ(*iter, NC(i));
}
set.erase(NC(10));
EXPECT_EQ(set.count(NC(10)), 0);
}
// Test with heap-allocated objects so that mismanaged constructions
// or destructions will show up as errors under a sanitizer or
// heap checker.
TEST(FlatSet, ConstructDestruct) {
FlatSet<string, HashStr> set;
string k1 = "the quick brown fox jumped over the lazy dog";
string k2 = k1 + k1;
string k3 = k1 + k2;
set.insert(k1);
set.insert(k3);
EXPECT_EQ(set.count(k1), 1);
EXPECT_EQ(set.count(k2), 0);
EXPECT_EQ(set.count(k3), 1);
set.erase(k3);
EXPECT_EQ(set.count(k3), 0);
set.clear();
set.insert(k1);
EXPECT_EQ(set.count(k1), 1);
EXPECT_EQ(set.count(k3), 0);
set.reserve(100);
EXPECT_EQ(set.count(k1), 1);
EXPECT_EQ(set.count(k3), 0);
}
// Type to use to ensure that custom equality operator is used
// that ignores extra value.
struct CustomCmpKey {
int64 a;
int64 b;
CustomCmpKey(int64 v1, int64 v2) : a(v1), b(v2) {}
bool operator==(const CustomCmpKey& x) const { return a == x.a && b == x.b; }
};
struct HashA {
size_t operator()(CustomCmpKey x) const { return x.a; }
};
struct EqA {
// Ignore b fields.
bool operator()(CustomCmpKey x, CustomCmpKey y) const { return x.a == y.a; }
};
TEST(FlatSet, CustomCmp) {
FlatSet<CustomCmpKey, HashA, EqA> set;
set.insert(CustomCmpKey(100, 200));
EXPECT_EQ(set.count(CustomCmpKey(100, 200)), 1);
EXPECT_EQ(set.count(CustomCmpKey(100, 500)), 1); // key.b ignored
}
// Test unique_ptr handling.
typedef std::unique_ptr<int> UniqInt;
static UniqInt MakeUniq(int i) { return UniqInt(new int(i)); }
struct HashUniq {
size_t operator()(const UniqInt& p) const { return *p; }
};
struct EqUniq {
bool operator()(const UniqInt& a, const UniqInt& b) const { return *a == *b; }
};
typedef FlatSet<UniqInt, HashUniq, EqUniq> UniqSet;
TEST(FlatSet, UniqueSet) {
UniqSet set;
// Fill set
const int N = 10;
for (int i = 0; i < N; i++) {
set.emplace(MakeUniq(i));
}
EXPECT_EQ(set.size(), N);
// Lookups
for (int i = 0; i < N; i++) {
EXPECT_EQ(set.count(MakeUniq(i)), 1);
}
// erase
set.erase(MakeUniq(2));
EXPECT_EQ(set.count(MakeUniq(2)), 0);
// clear
set.clear();
EXPECT_EQ(set.size(), 0);
}
TEST(FlatSet, UniqueSetIter) {
UniqSet set;
const int kCount = 10;
for (int i = 1; i <= kCount; i++) {
set.emplace(MakeUniq(i));
}
int sum = 0;
for (const auto& p : set) {
sum += *p;
}
EXPECT_EQ(sum, (kCount * (kCount + 1)) / 2);
}
} // namespace
} // namespace gtl
} // namespace tensorflow

View File

@ -42,6 +42,24 @@ inline uint64 Hash64Combine(uint64 a, uint64 b) {
return a ^ (b + 0x9e3779b97f4a7800ULL + (a << 10) + (a >> 4));
}
// Convenience Hash functors
struct HashInt64 {
size_t operator()(int64 x) const { return static_cast<size_t>(x); }
};
struct HashStr {
size_t operator()(const string& s) const {
return static_cast<size_t>(Hash64(s));
}
};
template <typename PTR>
struct HashPtr {
size_t operator()(const PTR p) const {
// Hash pointers as integers, but bring more entropy to the lower bits.
size_t k = static_cast<size_t>(reinterpret_cast<uintptr_t>(p));
return k + (k >> 6);
}
};
} // namespace tensorflow
#endif // TENSORFLOW_LIB_HASH_HASH_H_