Added new tensorflow::gtl::FlatMap and tensorflow::gtl::FlatSet classes.
Mostly drop-in replacements for std::unordered_map and std::unordered_set, but much faster (does not do an allocation per entry, and represents entries in groups of 8 in a flat array, which is much more cache efficient). Benchmarks not included in this cl show about 3X to 5X performance improvements over the std::unordered_{set,map} for many kinds of common maps e.g. std::unordered_mapmap<int64, int64> or std::unordered_map<string, int64>. Change: 137401863
This commit is contained in:
parent
e43eaf662d
commit
80aec93166
tensorflow/core
@ -164,6 +164,8 @@ cc_library(
|
||||
"lib/core/threadpool.h",
|
||||
"lib/gtl/array_slice.h",
|
||||
"lib/gtl/cleanup.h",
|
||||
"lib/gtl/flatmap.h",
|
||||
"lib/gtl/flatset.h",
|
||||
"lib/gtl/inlined_vector.h",
|
||||
"lib/gtl/priority_queue_util.h",
|
||||
"lib/hash/crc32c.h",
|
||||
@ -1447,6 +1449,8 @@ tf_cc_tests(
|
||||
"lib/gtl/array_slice_test.cc",
|
||||
"lib/gtl/cleanup_test.cc",
|
||||
"lib/gtl/edit_distance_test.cc",
|
||||
"lib/gtl/flatmap_test.cc",
|
||||
"lib/gtl/flatset_test.cc",
|
||||
"lib/gtl/inlined_vector_test.cc",
|
||||
"lib/gtl/int_type_test.cc",
|
||||
"lib/gtl/iterator_range_test.cc",
|
||||
|
349
tensorflow/core/lib/gtl/flatmap.h
Normal file
349
tensorflow/core/lib/gtl/flatmap.h
Normal file
@ -0,0 +1,349 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <utility>
|
||||
#include "tensorflow/core/lib/gtl/flatrep.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gtl {
|
||||
|
||||
// FlatMap<K,V,...> provides a map from K to V.
|
||||
//
|
||||
// The map is implemented using an open-addressed hash table. A
|
||||
// single array holds entire map contents and collisions are resolved
|
||||
// by probing at a sequence of locations in the array.
|
||||
template <typename Key, typename Val, class Hash, class Eq = std::equal_to<Key>>
|
||||
class FlatMap {
|
||||
private:
|
||||
// Forward declare some internal types needed in public section.
|
||||
struct Bucket;
|
||||
|
||||
public:
|
||||
typedef Key key_type;
|
||||
typedef Val mapped_type;
|
||||
typedef Hash hasher;
|
||||
typedef Eq key_equal;
|
||||
typedef size_t size_type;
|
||||
typedef ptrdiff_t difference_type;
|
||||
|
||||
// We cannot use std::pair<> since internal representation stores
|
||||
// keys and values in separate arrays, so we make a custom struct
|
||||
// that holds references to the internal key, value elements.
|
||||
struct value_type {
|
||||
typedef Key first_type;
|
||||
typedef Val second_type;
|
||||
|
||||
const Key& first;
|
||||
Val& second;
|
||||
value_type(const Key& k, Val& v) : first(k), second(v) {}
|
||||
};
|
||||
typedef value_type* pointer;
|
||||
typedef const value_type* const_pointer;
|
||||
typedef value_type& reference;
|
||||
typedef const value_type& const_reference;
|
||||
|
||||
FlatMap() : FlatMap(1) {}
|
||||
|
||||
explicit FlatMap(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq())
|
||||
: rep_(N, hf, eq) {}
|
||||
|
||||
FlatMap(const FlatMap& src) : rep_(src.rep_) {}
|
||||
|
||||
template <typename InputIter>
|
||||
FlatMap(InputIter first, InputIter last, size_t N = 1,
|
||||
const Hash& hf = Hash(), const Eq& eq = Eq())
|
||||
: FlatMap(N, hf, eq) {
|
||||
insert(first, last);
|
||||
}
|
||||
|
||||
FlatMap& operator=(const FlatMap& src) {
|
||||
rep_.CopyFrom(src.rep_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
~FlatMap() {}
|
||||
|
||||
void swap(FlatMap& x) { rep_.swap(x.rep_); }
|
||||
void clear_no_resize() { rep_.clear_no_resize(); }
|
||||
void clear() { rep_.clear(); }
|
||||
void reserve(size_t N) { rep_.Resize(std::max(N, size())); }
|
||||
void rehash(size_t N) { rep_.Resize(std::max(N, size())); }
|
||||
void resize(size_t N) { rep_.Resize(std::max(N, size())); }
|
||||
size_t size() const { return rep_.size(); }
|
||||
bool empty() const { return size() == 0; }
|
||||
size_t bucket_count() const { return rep_.bucket_count(); }
|
||||
hasher hash_function() const { return rep_.hash_function(); }
|
||||
key_equal key_eq() const { return rep_.key_eq(); }
|
||||
|
||||
class iterator {
|
||||
public:
|
||||
iterator() : b_(nullptr), end_(nullptr), i_(0) {}
|
||||
|
||||
// Make iterator pointing at first element at or after b.
|
||||
explicit iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) {
|
||||
SkipUnused();
|
||||
}
|
||||
|
||||
// Make iterator pointing exactly at ith element in b, which must exist.
|
||||
iterator(Bucket* b, Bucket* end, uint32 i) : b_(b), end_(end), i_(i) {
|
||||
FillValue();
|
||||
}
|
||||
|
||||
value_type& operator*() { return *val(); }
|
||||
value_type* operator->() { return val(); }
|
||||
bool operator==(const iterator& x) const {
|
||||
return b_ == x.b_ && i_ == x.i_;
|
||||
}
|
||||
bool operator!=(const iterator& x) const { return !(*this == x); }
|
||||
iterator& operator++() {
|
||||
DCHECK(b_ != end_);
|
||||
i_++;
|
||||
SkipUnused();
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class FlatMap;
|
||||
Bucket* b_;
|
||||
Bucket* end_;
|
||||
uint32 i_;
|
||||
char space_[sizeof(value_type)];
|
||||
|
||||
value_type* val() { return reinterpret_cast<value_type*>(space_); }
|
||||
void FillValue() { new (space_) value_type(b_->key(i_), b_->val(i_)); }
|
||||
void SkipUnused() {
|
||||
while (b_ < end_) {
|
||||
if (i_ >= Rep::kWidth) {
|
||||
i_ = 0;
|
||||
b_++;
|
||||
} else if (b_->marker[i_] < 2) {
|
||||
i_++;
|
||||
} else {
|
||||
FillValue();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class const_iterator {
|
||||
private:
|
||||
mutable iterator rep_; // Share state and logic with non-const iterator.
|
||||
public:
|
||||
const_iterator() : rep_() {}
|
||||
explicit const_iterator(Bucket* start, Bucket* end) : rep_(start, end) {}
|
||||
const_iterator(Bucket* b, Bucket* end, uint32 i) : rep_(b, end, i) {}
|
||||
|
||||
const value_type& operator*() const { return *rep_.val(); }
|
||||
const value_type* operator->() const { return rep_.val(); }
|
||||
bool operator==(const const_iterator& x) const { return rep_ == x.rep_; }
|
||||
bool operator!=(const const_iterator& x) const { return rep_ != x.rep_; }
|
||||
const_iterator& operator++() {
|
||||
++rep_;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
iterator begin() { return iterator(rep_.start(), rep_.limit()); }
|
||||
iterator end() { return iterator(rep_.limit(), rep_.limit()); }
|
||||
const_iterator begin() const {
|
||||
return const_iterator(rep_.start(), rep_.limit());
|
||||
}
|
||||
const_iterator end() const {
|
||||
return const_iterator(rep_.limit(), rep_.limit());
|
||||
}
|
||||
|
||||
size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; }
|
||||
iterator find(const Key& k) {
|
||||
auto r = rep_.Find(k);
|
||||
return r.found ? iterator(r.b, rep_.limit(), r.index) : end();
|
||||
}
|
||||
const_iterator find(const Key& k) const {
|
||||
auto r = rep_.Find(k);
|
||||
return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end();
|
||||
}
|
||||
|
||||
Val& at(const Key& k) {
|
||||
auto r = rep_.Find(k);
|
||||
DCHECK(r.found);
|
||||
return r.b->val(r.index);
|
||||
}
|
||||
const Val& at(const Key& k) const {
|
||||
auto r = rep_.Find(k);
|
||||
DCHECK(r.found);
|
||||
return r.b->val(r.index);
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
std::pair<iterator, bool> insert(const P& p) {
|
||||
return Insert(p.first, p.second);
|
||||
}
|
||||
std::pair<iterator, bool> insert(const std::pair<const Key, Val>& p) {
|
||||
return Insert(p.first, p.second);
|
||||
}
|
||||
template <typename InputIter>
|
||||
void insert(InputIter first, InputIter last) {
|
||||
for (; first != last; ++first) {
|
||||
insert(*first);
|
||||
}
|
||||
}
|
||||
|
||||
Val& operator[](const Key& k) { return IndexOp(k); }
|
||||
Val& operator[](Key&& k) { return IndexOp(std::forward<Key>(k)); }
|
||||
|
||||
template <typename... Args>
|
||||
std::pair<iterator, bool> emplace(Args&&... args) {
|
||||
return InsertPair(std::make_pair(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
size_t erase(const Key& k) {
|
||||
auto r = rep_.Find(k);
|
||||
if (!r.found) return 0;
|
||||
rep_.Erase(r.b, r.index);
|
||||
return 1;
|
||||
}
|
||||
iterator erase(iterator pos) {
|
||||
rep_.Erase(pos.b_, pos.i_);
|
||||
++pos;
|
||||
return pos;
|
||||
}
|
||||
iterator erase(iterator pos, iterator last) {
|
||||
for (; pos != last; ++pos) {
|
||||
rep_.Erase(pos.b_, pos.i_);
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
std::pair<iterator, iterator> equal_range(const Key& k) {
|
||||
auto pos = find(k);
|
||||
if (pos == end()) {
|
||||
return std::make_pair(pos, pos);
|
||||
} else {
|
||||
auto next = pos;
|
||||
++next;
|
||||
return std::make_pair(pos, next);
|
||||
}
|
||||
}
|
||||
std::pair<const_iterator, const_iterator> equal_range(const Key& k) const {
|
||||
auto pos = find(k);
|
||||
if (pos == end()) {
|
||||
return std::make_pair(pos, pos);
|
||||
} else {
|
||||
auto next = pos;
|
||||
++next;
|
||||
return std::make_pair(pos, next);
|
||||
}
|
||||
}
|
||||
|
||||
bool operator==(const FlatMap& x) const {
|
||||
if (size() != x.size()) return false;
|
||||
for (auto& p : x) {
|
||||
auto i = find(p.first);
|
||||
if (i == end()) return false;
|
||||
if (i->second != p.second) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
bool operator!=(const FlatMap& x) const { return !(*this == x); }
|
||||
|
||||
// If key exists in the table, prefetch the associated value. This
|
||||
// is a hint, and may have no effect.
|
||||
void prefetch_value(const Key& key) const { rep_.Prefetch(key); }
|
||||
|
||||
private:
|
||||
using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>;
|
||||
|
||||
// Bucket stores kWidth <marker, key, value> triples.
|
||||
// The data is organized as three parallel arrays to reduce padding.
|
||||
struct Bucket {
|
||||
uint8 marker[Rep::kWidth];
|
||||
|
||||
// Wrap keys and values in union to control construction and destruction.
|
||||
union Storage {
|
||||
struct {
|
||||
Key key[Rep::kWidth];
|
||||
Val val[Rep::kWidth];
|
||||
};
|
||||
Storage() {}
|
||||
~Storage() {}
|
||||
} storage;
|
||||
|
||||
Key& key(uint32 i) {
|
||||
DCHECK_GE(marker[i], 2);
|
||||
return storage.key[i];
|
||||
}
|
||||
Val& val(uint32 i) {
|
||||
DCHECK_GE(marker[i], 2);
|
||||
return storage.val[i];
|
||||
}
|
||||
template <typename V>
|
||||
void InitVal(uint32 i, V&& v) {
|
||||
new (&storage.val[i]) Val(std::forward<V>(v));
|
||||
}
|
||||
void Destroy(uint32 i) {
|
||||
storage.key[i].Key::~Key();
|
||||
storage.val[i].Val::~Val();
|
||||
}
|
||||
void MoveFrom(uint32 i, Bucket* src, uint32 src_index) {
|
||||
new (&storage.key[i]) Key(std::move(src->storage.key[src_index]));
|
||||
new (&storage.val[i]) Val(std::move(src->storage.val[src_index]));
|
||||
}
|
||||
void CopyFrom(uint32 i, Bucket* src, uint32 src_index) {
|
||||
new (&storage.key[i]) Key(src->storage.key[src_index]);
|
||||
new (&storage.val[i]) Val(src->storage.val[src_index]);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Pair>
|
||||
std::pair<iterator, bool> InsertPair(Pair&& p) {
|
||||
return Insert(std::forward<decltype(p.first)>(p.first),
|
||||
std::forward<decltype(p.second)>(p.second));
|
||||
}
|
||||
|
||||
template <typename K, typename V>
|
||||
std::pair<iterator, bool> Insert(K&& k, V&& v) {
|
||||
rep_.MaybeResize();
|
||||
auto r = rep_.FindOrInsert(std::forward<K>(k));
|
||||
const bool inserted = !r.found;
|
||||
if (inserted) {
|
||||
r.b->InitVal(r.index, std::forward<V>(v));
|
||||
}
|
||||
return {iterator(r.b, rep_.limit(), r.index), inserted};
|
||||
}
|
||||
|
||||
template <typename K>
|
||||
Val& IndexOp(K&& k) {
|
||||
rep_.MaybeResize();
|
||||
auto r = rep_.FindOrInsert(std::forward<K>(k));
|
||||
Val* vptr = &r.b->val(r.index);
|
||||
if (!r.found) {
|
||||
new (vptr) Val(); // Initialize value in new slot.
|
||||
}
|
||||
return *vptr;
|
||||
}
|
||||
|
||||
Rep rep_;
|
||||
};
|
||||
|
||||
} // namespace gtl
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
|
576
tensorflow/core/lib/gtl/flatmap_test.cc
Normal file
576
tensorflow/core/lib/gtl/flatmap_test.cc
Normal file
@ -0,0 +1,576 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gtl {
|
||||
namespace {
|
||||
|
||||
typedef FlatMap<int64, int32, HashInt64> NumMap;
|
||||
|
||||
// If map has an entry for k, return the corresponding value, else return def.
|
||||
int32 Get(const NumMap& map, int64 k, int32 def = -1) {
|
||||
auto iter = map.find(k);
|
||||
if (iter == map.end()) {
|
||||
EXPECT_EQ(map.count(k), 0);
|
||||
return def;
|
||||
} else {
|
||||
EXPECT_EQ(map.count(k), 1);
|
||||
EXPECT_EQ(&map.at(k), &iter->second);
|
||||
EXPECT_EQ(iter->first, k);
|
||||
return iter->second;
|
||||
}
|
||||
}
|
||||
|
||||
// Return contents of map as a sorted list of pairs.
|
||||
typedef std::vector<std::pair<int64, int32>> NumMapContents;
|
||||
NumMapContents Contents(const NumMap& map) {
|
||||
NumMapContents result;
|
||||
for (const auto& p : map) {
|
||||
result.push_back({p.first, p.second});
|
||||
}
|
||||
std::sort(result.begin(), result.end());
|
||||
return result;
|
||||
}
|
||||
|
||||
// Fill entries with keys [start,limit).
|
||||
void Fill(NumMap* map, int64 start, int64 limit) {
|
||||
for (int64 i = start; i < limit; i++) {
|
||||
map->insert({i, i * 100});
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FlatMapTest, Find) {
|
||||
NumMap map;
|
||||
EXPECT_EQ(Get(map, 1), -1);
|
||||
map.insert({1, 100});
|
||||
map.insert({2, 200});
|
||||
EXPECT_EQ(Get(map, 1), 100);
|
||||
EXPECT_EQ(Get(map, 2), 200);
|
||||
EXPECT_EQ(Get(map, 3), -1);
|
||||
}
|
||||
|
||||
TEST(FlatMapTest, Insert) {
|
||||
NumMap map;
|
||||
EXPECT_EQ(Get(map, 1), -1);
|
||||
|
||||
// New entry.
|
||||
auto result = map.insert({1, 100});
|
||||
EXPECT_TRUE(result.second);
|
||||
EXPECT_EQ(result.first->first, 1);
|
||||
EXPECT_EQ(result.first->second, 100);
|
||||
EXPECT_EQ(Get(map, 1), 100);
|
||||
|
||||
// Attempt to insert over existing entry.
|
||||
result = map.insert({1, 200});
|
||||
EXPECT_FALSE(result.second);
|
||||
EXPECT_EQ(result.first->first, 1);
|
||||
EXPECT_EQ(result.first->second, 100);
|
||||
EXPECT_EQ(Get(map, 1), 100);
|
||||
|
||||
// Overwrite through iterator.
|
||||
result.first->second = 300;
|
||||
EXPECT_EQ(result.first->second, 300);
|
||||
EXPECT_EQ(Get(map, 1), 300);
|
||||
|
||||
// Should get updated value.
|
||||
result = map.insert({1, 400});
|
||||
EXPECT_FALSE(result.second);
|
||||
EXPECT_EQ(result.first->first, 1);
|
||||
EXPECT_EQ(result.first->second, 300);
|
||||
EXPECT_EQ(Get(map, 1), 300);
|
||||
}
|
||||
|
||||
TEST(FlatMapTest, InsertGrowth) {
|
||||
NumMap map;
|
||||
const int n = 100;
|
||||
Fill(&map, 0, 100);
|
||||
EXPECT_EQ(map.size(), n);
|
||||
for (int i = 0; i < n; i++) {
|
||||
EXPECT_EQ(Get(map, i), i * 100) << i;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FlatMapTest, Emplace) {
|
||||
NumMap map;
|
||||
|
||||
// New entry.
|
||||
auto result = map.emplace(1, 100);
|
||||
EXPECT_TRUE(result.second);
|
||||
EXPECT_EQ(result.first->first, 1);
|
||||
EXPECT_EQ(result.first->second, 100);
|
||||
EXPECT_EQ(Get(map, 1), 100);
|
||||
|
||||
// Attempt to insert over existing entry.
|
||||
result = map.emplace(1, 200);
|
||||
EXPECT_FALSE(result.second);
|
||||
EXPECT_EQ(result.first->first, 1);
|
||||
EXPECT_EQ(result.first->second, 100);
|
||||
EXPECT_EQ(Get(map, 1), 100);
|
||||
|
||||
// Overwrite through iterator.
|
||||
result.first->second = 300;
|
||||
EXPECT_EQ(result.first->second, 300);
|
||||
EXPECT_EQ(Get(map, 1), 300);
|
||||
|
||||
// Update a second value
|
||||
result = map.emplace(2, 400);
|
||||
EXPECT_TRUE(result.second);
|
||||
EXPECT_EQ(result.first->first, 2);
|
||||
EXPECT_EQ(result.first->second, 400);
|
||||
EXPECT_EQ(Get(map, 2), 400);
|
||||
}
|
||||
|
||||
TEST(FlatMapTest, EmplaceUniquePtr) {
|
||||
FlatMap<int64, std::unique_ptr<string>, HashInt64> smap;
|
||||
smap.emplace(1, std::unique_ptr<string>(new string("hello")));
|
||||
}
|
||||
|
||||
TEST(FlatMapTest, Size) {
|
||||
NumMap map;
|
||||
EXPECT_EQ(map.size(), 0);
|
||||
|
||||
map.insert({1, 100});
|
||||
map.insert({2, 200});
|
||||
EXPECT_EQ(map.size(), 2);
|
||||
}
|
||||
|
||||
TEST(FlatMapTest, Empty) {
|
||||
NumMap map;
|
||||
EXPECT_TRUE(map.empty());
|
||||
|
||||
map.insert({1, 100});
|
||||
map.insert({2, 200});
|
||||
EXPECT_FALSE(map.empty());
|
||||
}
|
||||
|
||||
TEST(FlatMapTest, ArrayOperator) {
|
||||
NumMap map;
|
||||
|
||||
// Create new element if not found.
|
||||
auto v1 = &map[1];
|
||||
EXPECT_EQ(*v1, 0);
|
||||
EXPECT_EQ(Get(map, 1), 0);
|
||||
|
||||
// Write through returned reference.
|
||||
*v1 = 100;
|
||||
EXPECT_EQ(map[1], 100);
|
||||
EXPECT_EQ(Get(map, 1), 100);
|
||||
|
||||
// Reuse existing element if found.
|
||||
auto v1a = &map[1];
|
||||
EXPECT_EQ(v1, v1a);
|
||||
EXPECT_EQ(*v1, 100);
|
||||
|
||||
// Create another element.
|
||||
map[2] = 200;
|
||||
EXPECT_EQ(Get(map, 1), 100);
|
||||
EXPECT_EQ(Get(map, 2), 200);
|
||||
}
|
||||
|
||||
TEST(FlatMapTest, Count) {
|
||||
NumMap map;
|
||||
EXPECT_EQ(map.count(1), 0);
|
||||
EXPECT_EQ(map.count(2), 0);
|
||||
|
||||
map.insert({1, 100});
|
||||
EXPECT_EQ(map.count(1), 1);
|
||||
EXPECT_EQ(map.count(2), 0);
|
||||
|
||||
map.insert({2, 200});
|
||||
EXPECT_EQ(map.count(1), 1);
|
||||
EXPECT_EQ(map.count(2), 1);
|
||||
}
|
||||
|
||||
TEST(FlatMapTest, Iter) {
|
||||
NumMap map;
|
||||
EXPECT_EQ(Contents(map), NumMapContents());
|
||||
|
||||
map.insert({1, 100});
|
||||
map.insert({2, 200});
|
||||
EXPECT_EQ(Contents(map), NumMapContents({{1, 100}, {2, 200}}));
|
||||
}
|
||||
|
||||
TEST(FlatMapTest, Erase) {
|
||||
NumMap map;
|
||||
EXPECT_EQ(map.erase(1), 0);
|
||||
map[1] = 100;
|
||||
map[2] = 200;
|
||||
EXPECT_EQ(map.erase(3), 0);
|
||||
EXPECT_EQ(map.erase(1), 1);
|
||||
EXPECT_EQ(map.size(), 1);
|
||||
EXPECT_EQ(Get(map, 2), 200);
|
||||
EXPECT_EQ(Contents(map), NumMapContents({{2, 200}}));
|
||||
EXPECT_EQ(map.erase(2), 1);
|
||||
EXPECT_EQ(Contents(map), NumMapContents());
|
||||
}
|
||||
|
||||
TEST(FlatMapTest, EraseIter) {
|
||||
NumMap map;
|
||||
Fill(&map, 1, 11);
|
||||
size_t size = 10;
|
||||
for (auto iter = map.begin(); iter != map.end();) {
|
||||
iter = map.erase(iter);
|
||||
size--;
|
||||
EXPECT_EQ(map.size(), size);
|
||||
}
|
||||
EXPECT_EQ(Contents(map), NumMapContents());
|
||||
}
|
||||
|
||||
TEST(FlatMapTest, EraseIterPair) {
|
||||
NumMap map;
|
||||
Fill(&map, 1, 11);
|
||||
NumMap expected;
|
||||
auto p1 = map.begin();
|
||||
expected.insert(*p1);
|
||||
++p1;
|
||||
expected.insert(*p1);
|
||||
++p1;
|
||||
auto p2 = map.end();
|
||||
EXPECT_EQ(map.erase(p1, p2), map.end());
|
||||
EXPECT_EQ(map.size(), 2);
|
||||
EXPECT_EQ(Contents(map), Contents(expected));
|
||||
}
|
||||
|
||||
TEST(FlatMapTest, EraseLongChains) {
|
||||
// Make a map with lots of elements and erase a bunch of them to ensure
|
||||
// that we are likely to hit them on future lookups.
|
||||
NumMap map;
|
||||
const int num = 128;
|
||||
Fill(&map, 0, num);
|
||||
for (int i = 0; i < num; i += 3) {
|
||||
EXPECT_EQ(map.erase(i), 1);
|
||||
}
|
||||
for (int i = 0; i < num; i++) {
|
||||
if ((i % 3) != 0) {
|
||||
EXPECT_EQ(Get(map, i), i * 100);
|
||||
} else {
|
||||
EXPECT_EQ(map.count(i), 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Erase remainder to trigger table shrinking.
|
||||
const size_t orig_buckets = map.bucket_count();
|
||||
for (int i = 0; i < num; i++) {
|
||||
map.erase(i);
|
||||
}
|
||||
EXPECT_TRUE(map.empty());
|
||||
EXPECT_EQ(map.bucket_count(), orig_buckets);
|
||||
map[1] = 100; // Actual shrinking is triggered by an insert.
|
||||
EXPECT_LT(map.bucket_count(), orig_buckets);
|
||||
}
|
||||
|
||||
TEST(FlatMap, AlternatingInsertRemove) {
|
||||
NumMap map;
|
||||
map.insert({1000, 1000});
|
||||
map.insert({2000, 1000});
|
||||
map.insert({3000, 1000});
|
||||
for (int i = 0; i < 10000; i++) {
|
||||
map.insert({i, i});
|
||||
map.erase(i);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FlatMap, ClearNoResize) {
|
||||
NumMap map;
|
||||
Fill(&map, 0, 100);
|
||||
const size_t orig = map.bucket_count();
|
||||
map.clear_no_resize();
|
||||
EXPECT_EQ(map.size(), 0);
|
||||
EXPECT_EQ(Contents(map), NumMapContents());
|
||||
EXPECT_EQ(map.bucket_count(), orig);
|
||||
}
|
||||
|
||||
TEST(FlatMap, Clear) {
|
||||
NumMap map;
|
||||
Fill(&map, 0, 100);
|
||||
const size_t orig = map.bucket_count();
|
||||
map.clear();
|
||||
EXPECT_EQ(map.size(), 0);
|
||||
EXPECT_EQ(Contents(map), NumMapContents());
|
||||
EXPECT_LT(map.bucket_count(), orig);
|
||||
}
|
||||
|
||||
TEST(FlatMap, Copy) {
|
||||
for (int n = 0; n < 10; n++) {
|
||||
NumMap src;
|
||||
Fill(&src, 0, n);
|
||||
NumMap copy = src;
|
||||
EXPECT_EQ(Contents(src), Contents(copy));
|
||||
NumMap copy2;
|
||||
copy2 = src;
|
||||
EXPECT_EQ(Contents(src), Contents(copy2));
|
||||
copy2 = copy2; // Self-assignment
|
||||
EXPECT_EQ(Contents(src), Contents(copy2));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FlatMap, InitFromIter) {
|
||||
for (int n = 0; n < 10; n++) {
|
||||
NumMap src;
|
||||
Fill(&src, 0, n);
|
||||
auto vec = Contents(src);
|
||||
NumMap dst(vec.begin(), vec.end());
|
||||
EXPECT_EQ(Contents(dst), vec);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FlatMap, InsertIter) {
|
||||
NumMap a, b;
|
||||
Fill(&a, 1, 10);
|
||||
Fill(&b, 8, 20);
|
||||
b[9] = 10000; // Should not get inserted into a since a already has 9
|
||||
a.insert(b.begin(), b.end());
|
||||
NumMap expected;
|
||||
Fill(&expected, 1, 20);
|
||||
EXPECT_EQ(Contents(a), Contents(expected));
|
||||
}
|
||||
|
||||
TEST(FlatMap, Eq) {
|
||||
NumMap empty;
|
||||
|
||||
NumMap elems;
|
||||
Fill(&elems, 0, 5);
|
||||
EXPECT_FALSE(empty == elems);
|
||||
EXPECT_TRUE(empty != elems);
|
||||
|
||||
NumMap copy = elems;
|
||||
EXPECT_TRUE(copy == elems);
|
||||
EXPECT_FALSE(copy != elems);
|
||||
|
||||
NumMap changed = elems;
|
||||
changed[3] = 1;
|
||||
EXPECT_FALSE(changed == elems);
|
||||
EXPECT_TRUE(changed != elems);
|
||||
|
||||
NumMap changed2 = elems;
|
||||
changed2.erase(3);
|
||||
EXPECT_FALSE(changed2 == elems);
|
||||
EXPECT_TRUE(changed2 != elems);
|
||||
}
|
||||
|
||||
TEST(FlatMap, Swap) {
|
||||
NumMap a, b;
|
||||
Fill(&a, 1, 5);
|
||||
Fill(&b, 100, 200);
|
||||
NumMap c = a;
|
||||
NumMap d = b;
|
||||
EXPECT_EQ(c, a);
|
||||
EXPECT_EQ(d, b);
|
||||
c.swap(d);
|
||||
EXPECT_EQ(c, b);
|
||||
EXPECT_EQ(d, a);
|
||||
}
|
||||
|
||||
TEST(FlatMap, Reserve) {
|
||||
NumMap src;
|
||||
Fill(&src, 1, 100);
|
||||
NumMap a = src;
|
||||
a.reserve(10);
|
||||
EXPECT_EQ(a, src);
|
||||
NumMap b = src;
|
||||
b.rehash(1000);
|
||||
EXPECT_EQ(b, src);
|
||||
}
|
||||
|
||||
TEST(FlatMap, EqualRangeMutable) {
|
||||
NumMap map;
|
||||
Fill(&map, 1, 10);
|
||||
|
||||
// Existing element
|
||||
auto p1 = map.equal_range(3);
|
||||
EXPECT_TRUE(p1.first != p1.second);
|
||||
EXPECT_EQ(p1.first->first, 3);
|
||||
EXPECT_EQ(p1.first->second, 300);
|
||||
++p1.first;
|
||||
EXPECT_TRUE(p1.first == p1.second);
|
||||
|
||||
// Missing element
|
||||
auto p2 = map.equal_range(100);
|
||||
EXPECT_TRUE(p2.first == p2.second);
|
||||
}
|
||||
|
||||
TEST(FlatMap, EqualRangeConst) {
|
||||
NumMap tmp;
|
||||
Fill(&tmp, 1, 10);
|
||||
|
||||
const NumMap map = tmp;
|
||||
|
||||
// Existing element
|
||||
auto p1 = map.equal_range(3);
|
||||
EXPECT_TRUE(p1.first != p1.second);
|
||||
EXPECT_EQ(p1.first->first, 3);
|
||||
EXPECT_EQ(p1.first->second, 300);
|
||||
++p1.first;
|
||||
EXPECT_TRUE(p1.first == p1.second);
|
||||
|
||||
// Missing element
|
||||
auto p2 = map.equal_range(100);
|
||||
EXPECT_TRUE(p2.first == p2.second);
|
||||
}
|
||||
|
||||
TEST(FlatMap, Prefetch) {
|
||||
NumMap map;
|
||||
Fill(&map, 0, 1000);
|
||||
// Prefetch present and missing keys.
|
||||
for (int i = 0; i < 2000; i++) {
|
||||
map.prefetch_value(i);
|
||||
}
|
||||
}
|
||||
|
||||
// Non-copyable values should work.
|
||||
struct NC {
|
||||
int64 value;
|
||||
NC() : value(-1) {}
|
||||
NC(int64 v) : value(v) {}
|
||||
NC(const NC& x) : value(x.value) {}
|
||||
bool operator==(const NC& x) const { return value == x.value; }
|
||||
};
|
||||
struct HashNC {
|
||||
size_t operator()(NC x) const { return x.value; }
|
||||
};
|
||||
|
||||
TEST(FlatMap, NonCopyable) {
|
||||
FlatMap<NC, NC, HashNC> map;
|
||||
for (int i = 0; i < 100; i++) {
|
||||
map[NC(i)] = NC(i * 100);
|
||||
}
|
||||
for (int i = 0; i < 100; i++) {
|
||||
EXPECT_EQ(map.count(NC(i)), 1);
|
||||
auto iter = map.find(NC(i));
|
||||
EXPECT_NE(iter, map.end());
|
||||
EXPECT_EQ(iter->first, NC(i));
|
||||
EXPECT_EQ(iter->second, NC(i * 100));
|
||||
EXPECT_EQ(map[NC(i)], NC(i * 100));
|
||||
}
|
||||
map.erase(NC(10));
|
||||
EXPECT_EQ(map.count(NC(10)), 0);
|
||||
}
|
||||
|
||||
// Test with heap-allocated objects so that mismanaged constructions
|
||||
// or destructions will show up as errors under a sanitizer or
|
||||
// heap checker.
|
||||
TEST(FlatMap, ConstructDestruct) {
|
||||
FlatMap<string, string, HashStr> map;
|
||||
string k1 = "the quick brown fox jumped over the lazy dog";
|
||||
string k2 = k1 + k1;
|
||||
string k3 = k1 + k2;
|
||||
map[k1] = k2;
|
||||
map[k3] = k1;
|
||||
EXPECT_EQ(k1, map.find(k1)->first);
|
||||
EXPECT_EQ(k2, map.find(k1)->second);
|
||||
EXPECT_EQ(k1, map[k3]);
|
||||
map.erase(k3);
|
||||
EXPECT_EQ(string(), map[k3]);
|
||||
|
||||
map.clear();
|
||||
map[k1] = k2;
|
||||
EXPECT_EQ(k2, map[k1]);
|
||||
|
||||
map.reserve(100);
|
||||
EXPECT_EQ(k2, map[k1]);
|
||||
}
|
||||
|
||||
// Type to use to ensure that custom equality operator is used
|
||||
// that ignores extra value.
|
||||
struct CustomCmpKey {
|
||||
int64 a;
|
||||
int64 b;
|
||||
CustomCmpKey(int64 v1, int64 v2) : a(v1), b(v2) {}
|
||||
bool operator==(const CustomCmpKey& x) const { return a == x.a && b == x.b; }
|
||||
};
|
||||
struct HashA {
|
||||
size_t operator()(CustomCmpKey x) const { return x.a; }
|
||||
};
|
||||
struct EqA {
|
||||
// Ignore b fields.
|
||||
bool operator()(CustomCmpKey x, CustomCmpKey y) const { return x.a == y.a; }
|
||||
};
|
||||
TEST(FlatMap, CustomCmp) {
|
||||
FlatMap<CustomCmpKey, int, HashA, EqA> map;
|
||||
map[CustomCmpKey(100, 200)] = 300;
|
||||
EXPECT_EQ(300, map[CustomCmpKey(100, 200)]);
|
||||
EXPECT_EQ(300, map[CustomCmpKey(100, 500)]); // Differences in key.b ignored
|
||||
}
|
||||
|
||||
// Test unique_ptr handling.
|
||||
typedef std::unique_ptr<int> UniqInt;
|
||||
static UniqInt MakeUniq(int i) { return UniqInt(new int(i)); }
|
||||
|
||||
struct HashUniq {
|
||||
size_t operator()(const UniqInt& p) const { return *p; }
|
||||
};
|
||||
struct EqUniq {
|
||||
bool operator()(const UniqInt& a, const UniqInt& b) const { return *a == *b; }
|
||||
};
|
||||
typedef FlatMap<UniqInt, UniqInt, HashUniq, EqUniq> UniqMap;
|
||||
|
||||
TEST(FlatMap, UniqueMap) {
|
||||
UniqMap map;
|
||||
|
||||
// Fill map
|
||||
const int N = 10;
|
||||
for (int i = 0; i < N; i++) {
|
||||
if ((i % 2) == 0) {
|
||||
map[MakeUniq(i)] = MakeUniq(i + 100);
|
||||
} else {
|
||||
map.emplace(MakeUniq(i), MakeUniq(i + 100));
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(map.size(), N);
|
||||
|
||||
// Lookups
|
||||
for (int i = 0; i < N; i++) {
|
||||
EXPECT_EQ(*map.at(MakeUniq(i)), i + 100);
|
||||
}
|
||||
|
||||
// find+erase
|
||||
EXPECT_EQ(map.count(MakeUniq(2)), 1);
|
||||
map.erase(MakeUniq(2));
|
||||
EXPECT_EQ(map.count(MakeUniq(2)), 0);
|
||||
|
||||
// clear
|
||||
map.clear();
|
||||
EXPECT_EQ(map.size(), 0);
|
||||
}
|
||||
|
||||
TEST(FlatMap, UniqueMapIter) {
|
||||
UniqMap map;
|
||||
const int kCount = 10;
|
||||
const int kValueDelta = 100;
|
||||
for (int i = 1; i <= kCount; i++) {
|
||||
map[MakeUniq(i)] = MakeUniq(i + kValueDelta);
|
||||
}
|
||||
int key_sum = 0;
|
||||
int val_sum = 0;
|
||||
for (const auto& p : map) {
|
||||
key_sum += *p.first;
|
||||
val_sum += *p.second;
|
||||
}
|
||||
EXPECT_EQ(key_sum, (kCount * (kCount + 1)) / 2);
|
||||
EXPECT_EQ(val_sum, key_sum + (kCount * kValueDelta));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gtl
|
||||
} // namespace tensorflow
|
332
tensorflow/core/lib/gtl/flatrep.h
Normal file
332
tensorflow/core/lib/gtl/flatrep.h
Normal file
@ -0,0 +1,332 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
|
||||
|
||||
#include <string.h>
|
||||
#include <utility>
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gtl {
|
||||
namespace internal {
|
||||
|
||||
// Internal representation for FlatMap and FlatSet.
|
||||
//
|
||||
// The representation is an open-addressed hash table. Conceptually,
|
||||
// the representation is a flat array of entries. However we
|
||||
// structure it as an array of of buckets where each bucket holds
|
||||
// kWidth entries along with metadata for the kWidth entries. The
|
||||
// metadata marker is
|
||||
//
|
||||
// (a) kEmpty: the entry is empty
|
||||
// (b) kDeleted: the entry has been deleted
|
||||
// (c) other: the entry is occupied and has low-8 bits of its hash.
|
||||
// These hash bits can be used to avoid potentially expensive
|
||||
// key comparisons.
|
||||
//
|
||||
// FlatMap passes in a bucket that contains keys and values, FlatSet
|
||||
// passes in a bucket that does not contain values.
|
||||
template <typename Key, typename Bucket, class Hash, class Eq>
|
||||
class FlatRep {
|
||||
public:
|
||||
// kWidth is the number of entries stored in a bucket.
|
||||
static const uint32 kBase = 3;
|
||||
static const uint32 kWidth = (1 << kBase);
|
||||
|
||||
FlatRep(size_t N, const Hash& hf, const Eq& eq) : hash_(hf), equal_(eq) {
|
||||
Init(N);
|
||||
}
|
||||
explicit FlatRep(const FlatRep& src) : hash_(src.hash_), equal_(src.equal_) {
|
||||
Init(src.size());
|
||||
CopyEntries(src.array_, src.end_, CopyEntry());
|
||||
}
|
||||
~FlatRep() {
|
||||
clear_no_resize();
|
||||
delete[] array_;
|
||||
}
|
||||
|
||||
// Simple accessors.
|
||||
size_t size() const { return not_empty_ - deleted_; }
|
||||
size_t bucket_count() const { return mask_ + 1; }
|
||||
Bucket* start() const { return array_; }
|
||||
Bucket* limit() const { return end_; }
|
||||
const Hash& hash_function() const { return hash_; }
|
||||
const Eq& key_eq() const { return equal_; }
|
||||
|
||||
// Overwrite contents of *this with contents of src.
|
||||
void CopyFrom(const FlatRep& src) {
|
||||
if (this != &src) {
|
||||
clear_no_resize();
|
||||
delete[] array_;
|
||||
Init(src.size());
|
||||
CopyEntries(src.array_, src.end_, CopyEntry());
|
||||
}
|
||||
}
|
||||
|
||||
void clear_no_resize() {
|
||||
for (Bucket* b = array_; b != end_; b++) {
|
||||
for (uint32 i = 0; i < kWidth; i++) {
|
||||
if (b->marker[i] >= 2) {
|
||||
b->Destroy(i);
|
||||
b->marker[i] = kEmpty;
|
||||
}
|
||||
}
|
||||
}
|
||||
not_empty_ = 0;
|
||||
deleted_ = 0;
|
||||
}
|
||||
|
||||
void clear() {
|
||||
clear_no_resize();
|
||||
grow_ = 0; // Consider shrinking in MaybeResize()
|
||||
MaybeResize();
|
||||
}
|
||||
|
||||
void swap(FlatRep& x) {
|
||||
using std::swap;
|
||||
swap(array_, x.array_);
|
||||
swap(end_, x.end_);
|
||||
swap(lglen_, x.lglen_);
|
||||
swap(mask_, x.mask_);
|
||||
swap(not_empty_, x.not_empty_);
|
||||
swap(deleted_, x.deleted_);
|
||||
swap(grow_, x.grow_);
|
||||
swap(shrink_, x.shrink_);
|
||||
}
|
||||
|
||||
struct SearchResult {
|
||||
bool found;
|
||||
Bucket* b;
|
||||
uint32 index;
|
||||
};
|
||||
|
||||
// Hash value is partitioned as follows:
|
||||
// 1. Bottom 8 bits are stored in bucket to help speed up comparisons.
|
||||
// 2. Next 3 bits give index inside bucket.
|
||||
// 3. Remaining bits give bucket number.
|
||||
|
||||
// Find bucket/index for key k.
|
||||
SearchResult Find(const Key& k) const {
|
||||
size_t h = hash_(k);
|
||||
const uint32 marker = Marker(h & 0xff);
|
||||
size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket
|
||||
uint32 num_probes = 1; // Needed for quadratic probing
|
||||
while (true) {
|
||||
uint32 bi = index & (kWidth - 1);
|
||||
Bucket* b = &array_[index >> kBase];
|
||||
const uint32 x = b->marker[bi];
|
||||
if (x == marker && equal_(b->key(bi), k)) {
|
||||
return {true, b, bi};
|
||||
} else if (x == kEmpty) {
|
||||
return {false, nullptr, 0};
|
||||
}
|
||||
// Quadratic probing.
|
||||
index = (index + num_probes) & mask_;
|
||||
num_probes++;
|
||||
}
|
||||
}
|
||||
|
||||
// Find bucket/index for key k, creating a new one if necessary.
|
||||
//
|
||||
// KeyType is a template parameter so that k's type is deduced and it
|
||||
// becomes a universal reference which allows the key initialization
|
||||
// below to use an rvalue constructor if available.
|
||||
template <typename KeyType>
|
||||
SearchResult FindOrInsert(KeyType&& k) {
|
||||
size_t h = hash_(k);
|
||||
const uint32 marker = Marker(h & 0xff);
|
||||
size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket
|
||||
uint32 num_probes = 1; // Needed for quadratic probing
|
||||
Bucket* del = nullptr; // First encountered deletion for kInsert
|
||||
uint32 di = 0;
|
||||
while (true) {
|
||||
uint32 bi = index & (kWidth - 1);
|
||||
Bucket* b = &array_[index >> kBase];
|
||||
const uint32 x = b->marker[bi];
|
||||
if (x == marker && equal_(b->key(bi), k)) {
|
||||
return {true, b, bi};
|
||||
} else if (!del && x == kDeleted) {
|
||||
// Remember deleted index to use for insertion.
|
||||
del = b;
|
||||
di = bi;
|
||||
} else if (x == kEmpty) {
|
||||
if (del) {
|
||||
// Store in the first deleted slot we encountered
|
||||
b = del;
|
||||
bi = di;
|
||||
deleted_--; // not_empty_ does not change
|
||||
} else {
|
||||
not_empty_++;
|
||||
}
|
||||
b->marker[bi] = marker;
|
||||
new (&b->key(bi)) Key(std::forward<KeyType>(k));
|
||||
return {false, b, bi};
|
||||
}
|
||||
// Quadratic probing.
|
||||
index = (index + num_probes) & mask_;
|
||||
num_probes++;
|
||||
}
|
||||
}
|
||||
|
||||
void Erase(Bucket* b, uint32 i) {
|
||||
b->Destroy(i);
|
||||
b->marker[i] = kDeleted;
|
||||
deleted_++;
|
||||
grow_ = 0; // Consider shrinking on next insert
|
||||
}
|
||||
|
||||
void Prefetch(const Key& k) const {
|
||||
size_t h = hash_(k);
|
||||
size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket
|
||||
uint32 bi = index & (kWidth - 1);
|
||||
Bucket* b = &array_[index >> kBase];
|
||||
prefetch(&b->storage.key[bi]);
|
||||
}
|
||||
void prefetch(const void* ptr) const {
|
||||
// TODO(jeff,sanjay): Remove this routine when we add a
|
||||
// prefetch(...) call to platform so that the Prefetch routine
|
||||
// actually does something
|
||||
}
|
||||
|
||||
inline void MaybeResize() {
|
||||
if (not_empty_ < grow_) {
|
||||
return; // Nothing to do
|
||||
}
|
||||
if (grow_ == 0) {
|
||||
// Special value set by erase to cause shrink on next insert.
|
||||
if (size() >= shrink_) {
|
||||
// Not small enough to shrink.
|
||||
grow_ = static_cast<size_t>(bucket_count() * 0.8);
|
||||
if (not_empty_ < grow_) return;
|
||||
}
|
||||
}
|
||||
Resize(size() + 1);
|
||||
}
|
||||
|
||||
void Resize(size_t N) {
|
||||
Bucket* old = array_;
|
||||
Bucket* old_end = end_;
|
||||
Init(N);
|
||||
CopyEntries(old, old_end, MoveEntry());
|
||||
delete[] old;
|
||||
}
|
||||
|
||||
private:
|
||||
enum { kEmpty = 0, kDeleted = 1 }; // Special markers for an entry.
|
||||
|
||||
Hash hash_; // User-supplied hasher
|
||||
Eq equal_; // User-supplied comparator
|
||||
uint8 lglen_; // lg(#buckets)
|
||||
Bucket* array_; // array of length (1 << lglen_)
|
||||
Bucket* end_; // Points just past last bucket in array_
|
||||
size_t mask_; // (# of entries in table) - 1
|
||||
size_t not_empty_; // Count of entries with marker != kEmpty
|
||||
size_t deleted_; // Count of entries with marker == kDeleted
|
||||
size_t grow_; // Grow array when not_empty_ >= grow_
|
||||
size_t shrink_; // Shrink array when size() < shrink_
|
||||
|
||||
// Avoid kEmpty and kDeleted markers when computing hash values to
|
||||
// store in Bucket::marker[].
|
||||
static uint32 Marker(uint32 hb) { return hb + (hb < 2 ? 2 : 0); }
|
||||
|
||||
void Init(size_t N) {
|
||||
// Make enough room for N elements.
|
||||
size_t lg = 0; // Smallest table is just one bucket.
|
||||
while (N >= 0.8 * ((1 << lg) * kWidth)) {
|
||||
lg++;
|
||||
}
|
||||
const size_t n = (1 << lg);
|
||||
Bucket* array = new Bucket[n];
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
Bucket* b = &array[i];
|
||||
memset(b->marker, kEmpty, kWidth);
|
||||
}
|
||||
const size_t capacity = (1 << lg) * kWidth;
|
||||
lglen_ = lg;
|
||||
mask_ = capacity - 1;
|
||||
array_ = array;
|
||||
end_ = array + n;
|
||||
not_empty_ = 0;
|
||||
deleted_ = 0;
|
||||
grow_ = static_cast<size_t>(capacity * 0.8);
|
||||
if (lg == 0) {
|
||||
// Already down to one bucket; no more shrinking.
|
||||
shrink_ = 0;
|
||||
} else {
|
||||
shrink_ = static_cast<size_t>(grow_ * 0.4); // Must be less than 0.5
|
||||
}
|
||||
}
|
||||
|
||||
// Used by FreshInsert when we should copy from source.
|
||||
struct CopyEntry {
|
||||
inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) {
|
||||
dst->CopyFrom(dsti, src, srci);
|
||||
}
|
||||
};
|
||||
|
||||
// Used by FreshInsert when we should move from source.
|
||||
struct MoveEntry {
|
||||
inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) {
|
||||
dst->MoveFrom(dsti, src, srci);
|
||||
src->Destroy(srci);
|
||||
src->marker[srci] = kDeleted;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Copier>
|
||||
void CopyEntries(Bucket* start, Bucket* end, Copier copier) {
|
||||
for (Bucket* b = start; b != end; b++) {
|
||||
for (uint32 i = 0; i < kWidth; i++) {
|
||||
if (b->marker[i] >= 2) {
|
||||
FreshInsert(b, i, copier);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create an entry for the key numbered src_index in *src and return
|
||||
// its bucket/index. Used for insertion into a fresh table. We
|
||||
// assume that there are no deletions, and k does not already exist
|
||||
// in the table.
|
||||
template <typename Copier>
|
||||
void FreshInsert(Bucket* src, uint32 src_index, Copier copier) {
|
||||
size_t h = hash_(src->key(src_index));
|
||||
const uint32 marker = Marker(h & 0xff);
|
||||
size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket
|
||||
uint32 num_probes = 1; // Needed for quadratic probing
|
||||
while (true) {
|
||||
uint32 bi = index & (kWidth - 1);
|
||||
Bucket* b = &array_[index >> kBase];
|
||||
const uint32 x = b->marker[bi];
|
||||
if (x == 0) {
|
||||
b->marker[bi] = marker;
|
||||
not_empty_++;
|
||||
copier(b, bi, src, src_index);
|
||||
return;
|
||||
}
|
||||
// Quadratic probing.
|
||||
index = (index + num_probes) & mask_;
|
||||
num_probes++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
} // namespace gtl
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
|
277
tensorflow/core/lib/gtl/flatset.h
Normal file
277
tensorflow/core/lib/gtl/flatset.h
Normal file
@ -0,0 +1,277 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <utility>
|
||||
#include "tensorflow/core/lib/gtl/flatrep.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gtl {
|
||||
|
||||
// FlatSet<K,...> provides a set of K.
|
||||
//
|
||||
// The map is implemented using an open-addressed hash table. A
|
||||
// single array holds entire map contents and collisions are resolved
|
||||
// by probing at a sequence of locations in the array.
|
||||
template <typename Key, class Hash, class Eq = std::equal_to<Key>>
|
||||
class FlatSet {
|
||||
private:
|
||||
// Forward declare some internal types needed in public section.
|
||||
struct Bucket;
|
||||
|
||||
public:
|
||||
typedef Key key_type;
|
||||
typedef Key value_type;
|
||||
typedef Hash hasher;
|
||||
typedef Eq key_equal;
|
||||
typedef size_t size_type;
|
||||
typedef ptrdiff_t difference_type;
|
||||
typedef value_type* pointer;
|
||||
typedef const value_type* const_pointer;
|
||||
typedef value_type& reference;
|
||||
typedef const value_type& const_reference;
|
||||
|
||||
FlatSet() : FlatSet(1) {}
|
||||
|
||||
explicit FlatSet(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq())
|
||||
: rep_(N, hf, eq) {}
|
||||
|
||||
FlatSet(const FlatSet& src) : rep_(src.rep_) {}
|
||||
|
||||
template <typename InputIter>
|
||||
FlatSet(InputIter first, InputIter last, size_t N = 1,
|
||||
const Hash& hf = Hash(), const Eq& eq = Eq())
|
||||
: FlatSet(N, hf, eq) {
|
||||
insert(first, last);
|
||||
}
|
||||
|
||||
FlatSet& operator=(const FlatSet& src) {
|
||||
rep_.CopyFrom(src.rep_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
~FlatSet() {}
|
||||
|
||||
void swap(FlatSet& x) { rep_.swap(x.rep_); }
|
||||
void clear_no_resize() { rep_.clear_no_resize(); }
|
||||
void clear() { rep_.clear(); }
|
||||
void reserve(size_t N) { rep_.Resize(std::max(N, size())); }
|
||||
void rehash(size_t N) { rep_.Resize(std::max(N, size())); }
|
||||
void resize(size_t N) { rep_.Resize(std::max(N, size())); }
|
||||
size_t size() const { return rep_.size(); }
|
||||
bool empty() const { return size() == 0; }
|
||||
size_t bucket_count() const { return rep_.bucket_count(); }
|
||||
hasher hash_function() const { return rep_.hash_function(); }
|
||||
key_equal key_eq() const { return rep_.key_eq(); }
|
||||
|
||||
class iterator {
|
||||
public:
|
||||
iterator() : b_(nullptr), end_(nullptr), i_(0) {}
|
||||
|
||||
// Make iterator pointing at first element at or after b.
|
||||
explicit iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) {
|
||||
SkipUnused();
|
||||
}
|
||||
|
||||
// Make iterator pointing exactly at ith element in b, which must exist.
|
||||
iterator(Bucket* b, Bucket* end, uint32 i) : b_(b), end_(end), i_(i) {}
|
||||
|
||||
Key& operator*() { return key(); }
|
||||
Key* operator->() { return &key(); }
|
||||
bool operator==(const iterator& x) const {
|
||||
return b_ == x.b_ && i_ == x.i_;
|
||||
}
|
||||
bool operator!=(const iterator& x) const { return !(*this == x); }
|
||||
iterator& operator++() {
|
||||
DCHECK(b_ != end_);
|
||||
i_++;
|
||||
SkipUnused();
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class FlatSet;
|
||||
Bucket* b_;
|
||||
Bucket* end_;
|
||||
uint32 i_;
|
||||
|
||||
Key& key() const { return b_->key(i_); }
|
||||
void SkipUnused() {
|
||||
while (b_ < end_) {
|
||||
if (i_ >= Rep::kWidth) {
|
||||
i_ = 0;
|
||||
b_++;
|
||||
} else if (b_->marker[i_] < 2) {
|
||||
i_++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class const_iterator {
|
||||
private:
|
||||
mutable iterator rep_; // Share state and logic with non-const iterator.
|
||||
public:
|
||||
const_iterator() : rep_() {}
|
||||
explicit const_iterator(Bucket* start, Bucket* end) : rep_(start, end) {}
|
||||
const_iterator(Bucket* b, Bucket* end, uint32 i) : rep_(b, end, i) {}
|
||||
|
||||
const Key& operator*() const { return rep_.key(); }
|
||||
const Key* operator->() const { return &rep_.key(); }
|
||||
bool operator==(const const_iterator& x) const { return rep_ == x.rep_; }
|
||||
bool operator!=(const const_iterator& x) const { return rep_ != x.rep_; }
|
||||
const_iterator& operator++() {
|
||||
++rep_;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
iterator begin() { return iterator(rep_.start(), rep_.limit()); }
|
||||
iterator end() { return iterator(rep_.limit(), rep_.limit()); }
|
||||
const_iterator begin() const {
|
||||
return const_iterator(rep_.start(), rep_.limit());
|
||||
}
|
||||
const_iterator end() const {
|
||||
return const_iterator(rep_.limit(), rep_.limit());
|
||||
}
|
||||
|
||||
size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; }
|
||||
iterator find(const Key& k) {
|
||||
auto r = rep_.Find(k);
|
||||
return r.found ? iterator(r.b, rep_.limit(), r.index) : end();
|
||||
}
|
||||
const_iterator find(const Key& k) const {
|
||||
auto r = rep_.Find(k);
|
||||
return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end();
|
||||
}
|
||||
|
||||
std::pair<iterator, bool> insert(const Key& k) { return Insert(k); }
|
||||
template <typename InputIter>
|
||||
void insert(InputIter first, InputIter last) {
|
||||
for (; first != last; ++first) {
|
||||
insert(*first);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
std::pair<iterator, bool> emplace(Args&&... args) {
|
||||
rep_.MaybeResize();
|
||||
auto r = rep_.FindOrInsert(std::forward<Args>(args)...);
|
||||
const bool inserted = !r.found;
|
||||
return {iterator(r.b, rep_.limit(), r.index), inserted};
|
||||
}
|
||||
|
||||
size_t erase(const Key& k) {
|
||||
auto r = rep_.Find(k);
|
||||
if (!r.found) return 0;
|
||||
rep_.Erase(r.b, r.index);
|
||||
return 1;
|
||||
}
|
||||
iterator erase(iterator pos) {
|
||||
rep_.Erase(pos.b_, pos.i_);
|
||||
++pos;
|
||||
return pos;
|
||||
}
|
||||
iterator erase(iterator pos, iterator last) {
|
||||
for (; pos != last; ++pos) {
|
||||
rep_.Erase(pos.b_, pos.i_);
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
std::pair<iterator, iterator> equal_range(const Key& k) {
|
||||
auto pos = find(k);
|
||||
if (pos == end()) {
|
||||
return std::make_pair(pos, pos);
|
||||
} else {
|
||||
auto next = pos;
|
||||
++next;
|
||||
return std::make_pair(pos, next);
|
||||
}
|
||||
}
|
||||
std::pair<const_iterator, const_iterator> equal_range(const Key& k) const {
|
||||
auto pos = find(k);
|
||||
if (pos == end()) {
|
||||
return std::make_pair(pos, pos);
|
||||
} else {
|
||||
auto next = pos;
|
||||
++next;
|
||||
return std::make_pair(pos, next);
|
||||
}
|
||||
}
|
||||
|
||||
bool operator==(const FlatSet& x) const {
|
||||
if (size() != x.size()) return false;
|
||||
for (const auto& elem : x) {
|
||||
auto i = find(elem);
|
||||
if (i == end()) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
bool operator!=(const FlatSet& x) const { return !(*this == x); }
|
||||
|
||||
// If key exists in the table, prefetch it. This is a hint, and may
|
||||
// have no effect.
|
||||
void prefetch_value(const Key& key) const { rep_.Prefetch(key); }
|
||||
|
||||
private:
|
||||
using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>;
|
||||
|
||||
// Bucket stores kWidth <marker, key, value> triples.
|
||||
// The data is organized as three parallel arrays to reduce padding.
|
||||
struct Bucket {
|
||||
uint8 marker[Rep::kWidth];
|
||||
|
||||
// Wrap keys in union to control construction and destruction.
|
||||
union Storage {
|
||||
Key key[Rep::kWidth];
|
||||
Storage() {}
|
||||
~Storage() {}
|
||||
} storage;
|
||||
|
||||
Key& key(uint32 i) {
|
||||
DCHECK_GE(marker[i], 2);
|
||||
return storage.key[i];
|
||||
}
|
||||
void Destroy(uint32 i) { storage.key[i].Key::~Key(); }
|
||||
void MoveFrom(uint32 i, Bucket* src, uint32 src_index) {
|
||||
new (&storage.key[i]) Key(std::move(src->storage.key[src_index]));
|
||||
}
|
||||
void CopyFrom(uint32 i, Bucket* src, uint32 src_index) {
|
||||
new (&storage.key[i]) Key(src->storage.key[src_index]);
|
||||
}
|
||||
};
|
||||
|
||||
std::pair<iterator, bool> Insert(const Key& k) {
|
||||
rep_.MaybeResize();
|
||||
auto r = rep_.FindOrInsert(k);
|
||||
const bool inserted = !r.found;
|
||||
return {iterator(r.b, rep_.limit(), r.index), inserted};
|
||||
}
|
||||
|
||||
Rep rep_;
|
||||
};
|
||||
|
||||
} // namespace gtl
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
|
501
tensorflow/core/lib/gtl/flatset_test.cc
Normal file
501
tensorflow/core/lib/gtl/flatset_test.cc
Normal file
@ -0,0 +1,501 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gtl {
|
||||
namespace {
|
||||
|
||||
typedef FlatSet<int64, HashInt64> NumSet;
|
||||
|
||||
// Returns true iff set has an entry for k.
|
||||
// Also verifies that find and count give consistent results.
|
||||
bool Has(const NumSet& set, int64 k) {
|
||||
auto iter = set.find(k);
|
||||
if (iter == set.end()) {
|
||||
EXPECT_EQ(set.count(k), 0);
|
||||
return false;
|
||||
} else {
|
||||
EXPECT_EQ(set.count(k), 1);
|
||||
EXPECT_EQ(*iter, k);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Return contents of set as a sorted list of numbers.
|
||||
typedef std::vector<int64> NumSetContents;
|
||||
NumSetContents Contents(const NumSet& set) {
|
||||
NumSetContents result;
|
||||
for (int64 n : set) {
|
||||
result.push_back(n);
|
||||
}
|
||||
std::sort(result.begin(), result.end());
|
||||
return result;
|
||||
}
|
||||
|
||||
// Fill entries with keys [start,limit).
|
||||
void Fill(NumSet* set, int64 start, int64 limit) {
|
||||
for (int64 i = start; i < limit; i++) {
|
||||
set->insert(i);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FlatSetTest, Find) {
|
||||
NumSet set;
|
||||
EXPECT_FALSE(Has(set, 1));
|
||||
set.insert(1);
|
||||
set.insert(2);
|
||||
EXPECT_TRUE(Has(set, 1));
|
||||
EXPECT_TRUE(Has(set, 2));
|
||||
EXPECT_FALSE(Has(set, 3));
|
||||
}
|
||||
|
||||
TEST(FlatSetTest, Insert) {
|
||||
NumSet set;
|
||||
EXPECT_FALSE(Has(set, 1));
|
||||
|
||||
// New entry.
|
||||
auto result = set.insert(1);
|
||||
EXPECT_TRUE(result.second);
|
||||
EXPECT_EQ(*result.first, 1);
|
||||
EXPECT_TRUE(Has(set, 1));
|
||||
|
||||
// Attempt to insert over existing entry.
|
||||
result = set.insert(1);
|
||||
EXPECT_FALSE(result.second);
|
||||
EXPECT_EQ(*result.first, 1);
|
||||
EXPECT_TRUE(Has(set, 1));
|
||||
}
|
||||
|
||||
TEST(FlatSetTest, InsertGrowth) {
|
||||
NumSet set;
|
||||
const int n = 100;
|
||||
Fill(&set, 0, 100);
|
||||
EXPECT_EQ(set.size(), n);
|
||||
for (int i = 0; i < n; i++) {
|
||||
EXPECT_TRUE(Has(set, i)) << i;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FlatSetTest, Emplace) {
|
||||
NumSet set;
|
||||
|
||||
// New entry.
|
||||
auto result = set.emplace(73);
|
||||
EXPECT_TRUE(result.second);
|
||||
EXPECT_EQ(*result.first, 73);
|
||||
EXPECT_TRUE(Has(set, 73));
|
||||
|
||||
// Attempt to insert an existing entry.
|
||||
result = set.emplace(73);
|
||||
EXPECT_FALSE(result.second);
|
||||
EXPECT_EQ(*result.first, 73);
|
||||
EXPECT_TRUE(Has(set, 73));
|
||||
|
||||
// Add a second value
|
||||
result = set.emplace(103);
|
||||
EXPECT_TRUE(result.second);
|
||||
EXPECT_EQ(*result.first, 103);
|
||||
EXPECT_TRUE(Has(set, 103));
|
||||
}
|
||||
|
||||
TEST(FlatSetTest, Size) {
|
||||
NumSet set;
|
||||
EXPECT_EQ(set.size(), 0);
|
||||
|
||||
set.insert(1);
|
||||
set.insert(2);
|
||||
EXPECT_EQ(set.size(), 2);
|
||||
}
|
||||
|
||||
TEST(FlatSetTest, Empty) {
|
||||
NumSet set;
|
||||
EXPECT_TRUE(set.empty());
|
||||
|
||||
set.insert(1);
|
||||
set.insert(2);
|
||||
EXPECT_FALSE(set.empty());
|
||||
}
|
||||
|
||||
TEST(FlatSetTest, Count) {
|
||||
NumSet set;
|
||||
EXPECT_EQ(set.count(1), 0);
|
||||
EXPECT_EQ(set.count(2), 0);
|
||||
|
||||
set.insert(1);
|
||||
EXPECT_EQ(set.count(1), 1);
|
||||
EXPECT_EQ(set.count(2), 0);
|
||||
|
||||
set.insert(2);
|
||||
EXPECT_EQ(set.count(1), 1);
|
||||
EXPECT_EQ(set.count(2), 1);
|
||||
}
|
||||
|
||||
TEST(FlatSetTest, Iter) {
|
||||
NumSet set;
|
||||
EXPECT_EQ(Contents(set), NumSetContents());
|
||||
|
||||
set.insert(1);
|
||||
set.insert(2);
|
||||
EXPECT_EQ(Contents(set), NumSetContents({1, 2}));
|
||||
}
|
||||
|
||||
TEST(FlatSetTest, Erase) {
|
||||
NumSet set;
|
||||
EXPECT_EQ(set.erase(1), 0);
|
||||
set.insert(1);
|
||||
set.insert(2);
|
||||
EXPECT_EQ(set.erase(3), 0);
|
||||
EXPECT_EQ(set.erase(1), 1);
|
||||
EXPECT_EQ(set.size(), 1);
|
||||
EXPECT_TRUE(Has(set, 2));
|
||||
EXPECT_EQ(Contents(set), NumSetContents({2}));
|
||||
EXPECT_EQ(set.erase(2), 1);
|
||||
EXPECT_EQ(Contents(set), NumSetContents());
|
||||
}
|
||||
|
||||
TEST(FlatSetTest, EraseIter) {
|
||||
NumSet set;
|
||||
Fill(&set, 1, 11);
|
||||
size_t size = 10;
|
||||
for (auto iter = set.begin(); iter != set.end();) {
|
||||
iter = set.erase(iter);
|
||||
size--;
|
||||
EXPECT_EQ(set.size(), size);
|
||||
}
|
||||
EXPECT_EQ(Contents(set), NumSetContents());
|
||||
}
|
||||
|
||||
TEST(FlatSetTest, EraseIterPair) {
|
||||
NumSet set;
|
||||
Fill(&set, 1, 11);
|
||||
NumSet expected;
|
||||
auto p1 = set.begin();
|
||||
expected.insert(*p1);
|
||||
++p1;
|
||||
expected.insert(*p1);
|
||||
++p1;
|
||||
auto p2 = set.end();
|
||||
EXPECT_EQ(set.erase(p1, p2), set.end());
|
||||
EXPECT_EQ(set.size(), 2);
|
||||
EXPECT_EQ(Contents(set), Contents(expected));
|
||||
}
|
||||
|
||||
TEST(FlatSetTest, EraseLongChains) {
|
||||
// Make a set with lots of elements and erase a bunch of them to ensure
|
||||
// that we are likely to hit them on future lookups.
|
||||
NumSet set;
|
||||
const int num = 128;
|
||||
Fill(&set, 0, num);
|
||||
for (int i = 0; i < num; i += 3) {
|
||||
EXPECT_EQ(set.erase(i), 1);
|
||||
}
|
||||
for (int i = 0; i < num; i++) {
|
||||
// Multiples of 3 should be not present.
|
||||
EXPECT_EQ(Has(set, i), ((i % 3) != 0)) << i;
|
||||
}
|
||||
|
||||
// Erase remainder to trigger table shrinking.
|
||||
const size_t orig_buckets = set.bucket_count();
|
||||
for (int i = 0; i < num; i++) {
|
||||
set.erase(i);
|
||||
}
|
||||
EXPECT_TRUE(set.empty());
|
||||
EXPECT_EQ(set.bucket_count(), orig_buckets);
|
||||
set.insert(1); // Actual shrinking is triggered by an insert.
|
||||
EXPECT_LT(set.bucket_count(), orig_buckets);
|
||||
}
|
||||
|
||||
TEST(FlatSet, ClearNoResize) {
|
||||
NumSet set;
|
||||
Fill(&set, 0, 100);
|
||||
const size_t orig = set.bucket_count();
|
||||
set.clear_no_resize();
|
||||
EXPECT_EQ(set.size(), 0);
|
||||
EXPECT_EQ(Contents(set), NumSetContents());
|
||||
EXPECT_EQ(set.bucket_count(), orig);
|
||||
}
|
||||
|
||||
TEST(FlatSet, Clear) {
|
||||
NumSet set;
|
||||
Fill(&set, 0, 100);
|
||||
const size_t orig = set.bucket_count();
|
||||
set.clear();
|
||||
EXPECT_EQ(set.size(), 0);
|
||||
EXPECT_EQ(Contents(set), NumSetContents());
|
||||
EXPECT_LT(set.bucket_count(), orig);
|
||||
}
|
||||
|
||||
TEST(FlatSet, Copy) {
|
||||
for (int n = 0; n < 10; n++) {
|
||||
NumSet src;
|
||||
Fill(&src, 0, n);
|
||||
NumSet copy = src;
|
||||
EXPECT_EQ(Contents(src), Contents(copy));
|
||||
NumSet copy2;
|
||||
copy2 = src;
|
||||
EXPECT_EQ(Contents(src), Contents(copy2));
|
||||
copy2 = copy2; // Self-assignment
|
||||
EXPECT_EQ(Contents(src), Contents(copy2));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FlatSet, InitFromIter) {
|
||||
for (int n = 0; n < 10; n++) {
|
||||
NumSet src;
|
||||
Fill(&src, 0, n);
|
||||
auto vec = Contents(src);
|
||||
NumSet dst(vec.begin(), vec.end());
|
||||
EXPECT_EQ(Contents(dst), vec);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FlatSet, InsertIter) {
|
||||
NumSet a, b;
|
||||
Fill(&a, 1, 10);
|
||||
Fill(&b, 8, 20);
|
||||
b.insert(9); // Should not get inserted into a since a already has 9
|
||||
a.insert(b.begin(), b.end());
|
||||
NumSet expected;
|
||||
Fill(&expected, 1, 20);
|
||||
EXPECT_EQ(Contents(a), Contents(expected));
|
||||
}
|
||||
|
||||
TEST(FlatSet, Eq) {
|
||||
NumSet empty;
|
||||
|
||||
NumSet elems;
|
||||
Fill(&elems, 0, 5);
|
||||
EXPECT_FALSE(empty == elems);
|
||||
EXPECT_TRUE(empty != elems);
|
||||
|
||||
NumSet copy = elems;
|
||||
EXPECT_TRUE(copy == elems);
|
||||
EXPECT_FALSE(copy != elems);
|
||||
|
||||
NumSet changed = elems;
|
||||
changed.insert(7);
|
||||
EXPECT_FALSE(changed == elems);
|
||||
EXPECT_TRUE(changed != elems);
|
||||
|
||||
NumSet changed2 = elems;
|
||||
changed2.erase(3);
|
||||
EXPECT_FALSE(changed2 == elems);
|
||||
EXPECT_TRUE(changed2 != elems);
|
||||
}
|
||||
|
||||
TEST(FlatSet, Swap) {
|
||||
NumSet a, b;
|
||||
Fill(&a, 1, 5);
|
||||
Fill(&b, 100, 200);
|
||||
NumSet c = a;
|
||||
NumSet d = b;
|
||||
EXPECT_EQ(c, a);
|
||||
EXPECT_EQ(d, b);
|
||||
c.swap(d);
|
||||
EXPECT_EQ(c, b);
|
||||
EXPECT_EQ(d, a);
|
||||
}
|
||||
|
||||
TEST(FlatSet, Reserve) {
|
||||
NumSet src;
|
||||
Fill(&src, 1, 100);
|
||||
NumSet a = src;
|
||||
a.reserve(10);
|
||||
EXPECT_EQ(a, src);
|
||||
NumSet b = src;
|
||||
b.rehash(1000);
|
||||
EXPECT_EQ(b, src);
|
||||
}
|
||||
|
||||
TEST(FlatSet, EqualRangeMutable) {
|
||||
NumSet set;
|
||||
Fill(&set, 1, 10);
|
||||
|
||||
// Existing element
|
||||
auto p1 = set.equal_range(3);
|
||||
EXPECT_TRUE(p1.first != p1.second);
|
||||
EXPECT_EQ(*p1.first, 3);
|
||||
++p1.first;
|
||||
EXPECT_TRUE(p1.first == p1.second);
|
||||
|
||||
// Missing element
|
||||
auto p2 = set.equal_range(100);
|
||||
EXPECT_TRUE(p2.first == p2.second);
|
||||
}
|
||||
|
||||
TEST(FlatSet, EqualRangeConst) {
|
||||
NumSet tmp;
|
||||
Fill(&tmp, 1, 10);
|
||||
|
||||
const NumSet set = tmp;
|
||||
|
||||
// Existing element
|
||||
auto p1 = set.equal_range(3);
|
||||
EXPECT_TRUE(p1.first != p1.second);
|
||||
EXPECT_EQ(*p1.first, 3);
|
||||
++p1.first;
|
||||
EXPECT_TRUE(p1.first == p1.second);
|
||||
|
||||
// Missing element
|
||||
auto p2 = set.equal_range(100);
|
||||
EXPECT_TRUE(p2.first == p2.second);
|
||||
}
|
||||
|
||||
TEST(FlatSet, Prefetch) {
|
||||
NumSet set;
|
||||
Fill(&set, 0, 1000);
|
||||
// Prefetch present and missing keys.
|
||||
for (int i = 0; i < 2000; i++) {
|
||||
set.prefetch_value(i);
|
||||
}
|
||||
}
|
||||
|
||||
// Non-copyable values should work.
|
||||
struct NC {
|
||||
int64 value;
|
||||
NC() : value(-1) {}
|
||||
NC(int64 v) : value(v) {}
|
||||
NC(const NC& x) : value(x.value) {}
|
||||
bool operator==(const NC& x) const { return value == x.value; }
|
||||
};
|
||||
struct HashNC {
|
||||
size_t operator()(NC x) const { return x.value; }
|
||||
};
|
||||
|
||||
TEST(FlatSet, NonCopyable) {
|
||||
FlatSet<NC, HashNC> set;
|
||||
for (int i = 0; i < 100; i++) {
|
||||
set.insert(NC(i));
|
||||
}
|
||||
for (int i = 0; i < 100; i++) {
|
||||
EXPECT_EQ(set.count(NC(i)), 1);
|
||||
auto iter = set.find(NC(i));
|
||||
EXPECT_NE(iter, set.end());
|
||||
EXPECT_EQ(*iter, NC(i));
|
||||
}
|
||||
set.erase(NC(10));
|
||||
EXPECT_EQ(set.count(NC(10)), 0);
|
||||
}
|
||||
|
||||
// Test with heap-allocated objects so that mismanaged constructions
|
||||
// or destructions will show up as errors under a sanitizer or
|
||||
// heap checker.
|
||||
TEST(FlatSet, ConstructDestruct) {
|
||||
FlatSet<string, HashStr> set;
|
||||
string k1 = "the quick brown fox jumped over the lazy dog";
|
||||
string k2 = k1 + k1;
|
||||
string k3 = k1 + k2;
|
||||
set.insert(k1);
|
||||
set.insert(k3);
|
||||
EXPECT_EQ(set.count(k1), 1);
|
||||
EXPECT_EQ(set.count(k2), 0);
|
||||
EXPECT_EQ(set.count(k3), 1);
|
||||
|
||||
set.erase(k3);
|
||||
EXPECT_EQ(set.count(k3), 0);
|
||||
|
||||
set.clear();
|
||||
set.insert(k1);
|
||||
EXPECT_EQ(set.count(k1), 1);
|
||||
EXPECT_EQ(set.count(k3), 0);
|
||||
|
||||
set.reserve(100);
|
||||
EXPECT_EQ(set.count(k1), 1);
|
||||
EXPECT_EQ(set.count(k3), 0);
|
||||
}
|
||||
|
||||
// Type to use to ensure that custom equality operator is used
|
||||
// that ignores extra value.
|
||||
struct CustomCmpKey {
|
||||
int64 a;
|
||||
int64 b;
|
||||
CustomCmpKey(int64 v1, int64 v2) : a(v1), b(v2) {}
|
||||
bool operator==(const CustomCmpKey& x) const { return a == x.a && b == x.b; }
|
||||
};
|
||||
struct HashA {
|
||||
size_t operator()(CustomCmpKey x) const { return x.a; }
|
||||
};
|
||||
struct EqA {
|
||||
// Ignore b fields.
|
||||
bool operator()(CustomCmpKey x, CustomCmpKey y) const { return x.a == y.a; }
|
||||
};
|
||||
TEST(FlatSet, CustomCmp) {
|
||||
FlatSet<CustomCmpKey, HashA, EqA> set;
|
||||
set.insert(CustomCmpKey(100, 200));
|
||||
EXPECT_EQ(set.count(CustomCmpKey(100, 200)), 1);
|
||||
EXPECT_EQ(set.count(CustomCmpKey(100, 500)), 1); // key.b ignored
|
||||
}
|
||||
|
||||
// Test unique_ptr handling.
|
||||
typedef std::unique_ptr<int> UniqInt;
|
||||
static UniqInt MakeUniq(int i) { return UniqInt(new int(i)); }
|
||||
|
||||
struct HashUniq {
|
||||
size_t operator()(const UniqInt& p) const { return *p; }
|
||||
};
|
||||
struct EqUniq {
|
||||
bool operator()(const UniqInt& a, const UniqInt& b) const { return *a == *b; }
|
||||
};
|
||||
typedef FlatSet<UniqInt, HashUniq, EqUniq> UniqSet;
|
||||
|
||||
TEST(FlatSet, UniqueSet) {
|
||||
UniqSet set;
|
||||
|
||||
// Fill set
|
||||
const int N = 10;
|
||||
for (int i = 0; i < N; i++) {
|
||||
set.emplace(MakeUniq(i));
|
||||
}
|
||||
EXPECT_EQ(set.size(), N);
|
||||
|
||||
// Lookups
|
||||
for (int i = 0; i < N; i++) {
|
||||
EXPECT_EQ(set.count(MakeUniq(i)), 1);
|
||||
}
|
||||
|
||||
// erase
|
||||
set.erase(MakeUniq(2));
|
||||
EXPECT_EQ(set.count(MakeUniq(2)), 0);
|
||||
|
||||
// clear
|
||||
set.clear();
|
||||
EXPECT_EQ(set.size(), 0);
|
||||
}
|
||||
|
||||
TEST(FlatSet, UniqueSetIter) {
|
||||
UniqSet set;
|
||||
const int kCount = 10;
|
||||
for (int i = 1; i <= kCount; i++) {
|
||||
set.emplace(MakeUniq(i));
|
||||
}
|
||||
int sum = 0;
|
||||
for (const auto& p : set) {
|
||||
sum += *p;
|
||||
}
|
||||
EXPECT_EQ(sum, (kCount * (kCount + 1)) / 2);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gtl
|
||||
} // namespace tensorflow
|
@ -42,6 +42,24 @@ inline uint64 Hash64Combine(uint64 a, uint64 b) {
|
||||
return a ^ (b + 0x9e3779b97f4a7800ULL + (a << 10) + (a >> 4));
|
||||
}
|
||||
|
||||
// Convenience Hash functors
|
||||
struct HashInt64 {
|
||||
size_t operator()(int64 x) const { return static_cast<size_t>(x); }
|
||||
};
|
||||
struct HashStr {
|
||||
size_t operator()(const string& s) const {
|
||||
return static_cast<size_t>(Hash64(s));
|
||||
}
|
||||
};
|
||||
template <typename PTR>
|
||||
struct HashPtr {
|
||||
size_t operator()(const PTR p) const {
|
||||
// Hash pointers as integers, but bring more entropy to the lower bits.
|
||||
size_t k = static_cast<size_t>(reinterpret_cast<uintptr_t>(p));
|
||||
return k + (k >> 6);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_LIB_HASH_HASH_H_
|
||||
|
Loading…
Reference in New Issue
Block a user