Split ShapeTreeIterator into two classes: ShapeTreeIterator and ShapeTreeLeafIterator

The const bool "iterate leaves only" member meant that one of these iterators was either a tree-traversing or leaf-traversing iterator "for life". This seems better served by just having two separate classes. It was causing concrete problems too, e.g. the const bool member prevented the type from being copy-assignable.

The only way this would cause problems is if there is some need for dynamic switching between leaf- and tree- traversal at runtime, which seems unlikely.

PiperOrigin-RevId: 331854790
Change-Id: Ic833fe8445ed82ee274e1a6c2856d5933f10f414
This commit is contained in:
A. Unique TensorFlower 2020-09-15 14:23:34 -07:00 committed by TensorFlower Gardener
parent 5cbcf59ba2
commit 139ba9c528
2 changed files with 119 additions and 56 deletions

View File

@ -70,6 +70,8 @@ struct IndexTableEntry {
template <typename ContainerType, typename IteratorType, typename ValueType>
class ShapeTreeIterator;
template <typename ContainerType, typename IteratorType, typename ValueType>
class ShapeTreeLeafIterator;
// A ShapeTree<T> is a recursive data structure which mirrors the structure of a
// XLA shape and holds a value of type T for each subshape (i.e. tuple or array)
@ -158,23 +160,25 @@ class ShapeTree {
using reverse_iterator = std::reverse_iterator<iterator>;
using const_reverse_iterator = std::reverse_iterator<const_iterator>;
using leaf_iterator =
ShapeTreeLeafIterator<std::vector<Node>,
typename std::vector<Node>::iterator,
std::pair<ShapeIndex, T>>;
using const_leaf_iterator =
ShapeTreeLeafIterator<const std::vector<Node>,
typename std::vector<Node>::const_iterator,
const std::pair<ShapeIndex, T>>;
using reverse_leaf_iterator = std::reverse_iterator<leaf_iterator>;
using const_reverse_leaf_iterator =
std::reverse_iterator<const_leaf_iterator>;
// begin/end for iterating over all nodes.
iterator begin() {
return iterator(&nodes_, nodes_.begin(),
/*iterate_leaves_only=*/false);
}
iterator end() {
return iterator(&nodes_, nodes_.end(),
/*iterate_leaves_only=*/false);
}
iterator begin() { return iterator(&nodes_, nodes_.begin()); }
iterator end() { return iterator(&nodes_, nodes_.end()); }
const_iterator begin() const {
return const_iterator(&nodes_, nodes_.begin(),
/*iterate_leaves_only=*/false);
}
const_iterator end() const {
return const_iterator(&nodes_, nodes_.end(),
/*iterate_leaves_only=*/false);
return const_iterator(&nodes_, nodes_.begin());
}
const_iterator end() const { return const_iterator(&nodes_, nodes_.end()); }
// rbegin/rend for iterating over all nodes in reverse.
reverse_iterator rbegin() { return reverse_iterator(end()); }
@ -188,37 +192,33 @@ class ShapeTree {
// leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no
// children).
iterator leaf_begin() {
return iterator(&nodes_, nodes_.begin(),
/*iterate_leaves_only=*/true);
leaf_iterator leaf_begin() { return leaf_iterator(&nodes_, nodes_.begin()); }
leaf_iterator leaf_end() { return leaf_iterator(&nodes_, nodes_.end()); }
const_leaf_iterator leaf_begin() const {
return const_leaf_iterator(&nodes_, nodes_.begin());
}
iterator leaf_end() {
return iterator(&nodes_, nodes_.end(),
/*iterate_leaves_only=*/true);
}
const_iterator leaf_begin() const {
return const_iterator(&nodes_, nodes_.begin(),
/*iterate_leaves_only=*/true);
}
const_iterator leaf_end() const {
return const_iterator(&nodes_, nodes_.end(),
/*iterate_leaves_only=*/true);
const_leaf_iterator leaf_end() const {
return const_leaf_iterator(&nodes_, nodes_.end());
}
// range-based iterator for leaf_begin()/leaf_end().
tensorflow::gtl::iterator_range<iterator> leaves() {
tensorflow::gtl::iterator_range<leaf_iterator> leaves() {
return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
}
tensorflow::gtl::iterator_range<const_iterator> leaves() const {
tensorflow::gtl::iterator_range<const_leaf_iterator> leaves() const {
return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
}
reverse_iterator leaf_rbegin() { return reverse_iterator(leaf_end()); }
reverse_iterator leaf_rend() { return reverse_iterator(leaf_begin()); }
const_reverse_iterator leaf_rbegin() const {
return const_reverse_iterator(leaf_end());
reverse_leaf_iterator leaf_rbegin() {
return reverse_leaf_iterator(leaf_end());
}
const_reverse_iterator leaf_rend() const {
return const_reverse_iterator(leaf_begin());
reverse_leaf_iterator leaf_rend() {
return reverse_leaf_iterator(leaf_begin());
}
const_reverse_leaf_iterator leaf_rbegin() const {
return const_reverse_leaf_iterator(leaf_end());
}
const_reverse_leaf_iterator leaf_rend() const {
return const_reverse_leaf_iterator(leaf_begin());
}
// Returns an iterator pointing to the given ShapeIndex.
@ -226,12 +226,12 @@ class ShapeTree {
iterator find(ShapeIndexView index) {
Node* element = Lookup(index);
auto element_iter = nodes_.begin() + (element - &nodes_[0]);
return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
return iterator(&nodes_, element_iter);
}
const_iterator find(ShapeIndexView index) const {
Node* element = Lookup(index);
const Node* element = Lookup(index);
auto element_iter = nodes_.cbegin() + (element - &nodes_[0]);
return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
return const_iterator(&nodes_, element_iter);
}
// Returns the number of leaf nodes in the tree.
@ -343,21 +343,11 @@ template <typename ContainerType, typename IteratorType, typename ValueType>
class ShapeTreeIterator
: public std::iterator<std::bidirectional_iterator_tag, ValueType> {
public:
ShapeTreeIterator(ContainerType* nodes, IteratorType node,
bool iterate_leaves_only)
: nodes_(nodes),
node_(std::move(node)),
iterate_leaves_only_(iterate_leaves_only) {
while (iterate_leaves_only && node_ != nodes_->end() && !node_->is_leaf) {
++node_;
}
}
ShapeTreeIterator(ContainerType* nodes, IteratorType node)
: nodes_(nodes), node_(std::move(node)) {}
ShapeTreeIterator& operator++() {
++node_;
while (iterate_leaves_only_ && node_ != nodes_->end() && !node_->is_leaf) {
++node_;
}
return *this;
}
ShapeTreeIterator operator++(int) {
@ -368,9 +358,6 @@ class ShapeTreeIterator
ShapeTreeIterator& operator--() {
--node_;
while (iterate_leaves_only_ && node_ > nodes_->begin() && !node_->is_leaf) {
--node_;
}
return *this;
}
ShapeTreeIterator operator--(int) {
@ -391,8 +378,60 @@ class ShapeTreeIterator
private:
ContainerType* nodes_;
IteratorType node_;
// True if we should not include interior nodes in our walk.
const bool iterate_leaves_only_;
};
// Internal iterator that performs a pre-order walk of the leaves. This is cheap
// to copy. The iterator value_type is equivalent to a std::pair<ShapeIndex,T>&,
// similar to std::map.
template <typename ContainerType, typename IteratorType, typename ValueType>
class ShapeTreeLeafIterator
: public std::iterator<std::bidirectional_iterator_tag, ValueType> {
public:
ShapeTreeLeafIterator(ContainerType* nodes, IteratorType node)
: nodes_(nodes), node_(std::move(node)) {
while (node_ != nodes_->end() && !node_->is_leaf) {
++node_;
}
}
ShapeTreeLeafIterator& operator++() {
++node_;
while (node_ != nodes_->end() && !node_->is_leaf) {
++node_;
}
return *this;
}
ShapeTreeLeafIterator operator++(int) {
auto i = *this;
++(*this);
return i;
}
ShapeTreeLeafIterator& operator--() {
--node_;
while (node_ > nodes_->begin() && !node_->is_leaf) {
--node_;
}
return *this;
}
ShapeTreeLeafIterator operator--(int) {
auto i = *this;
--(*this);
return i;
}
bool operator==(const ShapeTreeLeafIterator& other) const {
return node_ == other.node_;
}
bool operator!=(const ShapeTreeLeafIterator& other) const {
return node_ != other.node_;
}
ValueType& operator*() const { return node_->data; }
ValueType* operator->() const { return &node_->data; }
private:
ContainerType* nodes_;
IteratorType node_;
};
template <typename T>

View File

@ -485,6 +485,30 @@ TEST_F(ShapeTreeTest, ReverseIterateOrder) {
}));
}
// Ensures that we can find an element at an index that we know ahead of time to
// be occupied in a 'ShapeTree' via the 'find' API.
TEST_F(ShapeTreeTest, Find) {
ShapeTree<int> t(nested_tuple_shape_, 42);
auto found = t.find({1, 0});
EXPECT_NE(found, t.end());
// The found key must be the same key we searched for.
EXPECT_EQ(found->first, ShapeIndex({1, 0}));
// The 'ShapeTree' has 42 at every position.
EXPECT_EQ(found->second, 42);
}
// Ensures that we can find an element at an index that we know ahead of time to
// be occupied in a 'const ShapeTree' via the 'find' API.
TEST_F(ShapeTreeTest, ConstFind) {
const ShapeTree<int> t(nested_tuple_shape_, 42);
auto found = t.find({1, 0});
EXPECT_NE(found, t.end());
// The found key must be the same key we searched for.
EXPECT_EQ(found->first, ShapeIndex({1, 0}));
// The 'ShapeTree' has 42 at every position.
EXPECT_EQ(found->second, 42);
}
TEST_F(ShapeTreeTest, IterateOrderLeaves) {
ShapeTree<int> t(nested_tuple_shape_, 42);
std::vector<ShapeIndex> v;