diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index d2d27c7f814..b1c96e9becf 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -70,6 +70,8 @@ struct IndexTableEntry { template class ShapeTreeIterator; +template +class ShapeTreeLeafIterator; // A ShapeTree is a recursive data structure which mirrors the structure of a // XLA shape and holds a value of type T for each subshape (i.e. tuple or array) @@ -158,23 +160,25 @@ class ShapeTree { using reverse_iterator = std::reverse_iterator; using const_reverse_iterator = std::reverse_iterator; + using leaf_iterator = + ShapeTreeLeafIterator, + typename std::vector::iterator, + std::pair>; + using const_leaf_iterator = + ShapeTreeLeafIterator, + typename std::vector::const_iterator, + const std::pair>; + using reverse_leaf_iterator = std::reverse_iterator; + using const_reverse_leaf_iterator = + std::reverse_iterator; + // begin/end for iterating over all nodes. - iterator begin() { - return iterator(&nodes_, nodes_.begin(), - /*iterate_leaves_only=*/false); - } - iterator end() { - return iterator(&nodes_, nodes_.end(), - /*iterate_leaves_only=*/false); - } + iterator begin() { return iterator(&nodes_, nodes_.begin()); } + iterator end() { return iterator(&nodes_, nodes_.end()); } const_iterator begin() const { - return const_iterator(&nodes_, nodes_.begin(), - /*iterate_leaves_only=*/false); - } - const_iterator end() const { - return const_iterator(&nodes_, nodes_.end(), - /*iterate_leaves_only=*/false); + return const_iterator(&nodes_, nodes_.begin()); } + const_iterator end() const { return const_iterator(&nodes_, nodes_.end()); } // rbegin/rend for iterating over all nodes in reverse. reverse_iterator rbegin() { return reverse_iterator(end()); } @@ -188,37 +192,33 @@ class ShapeTree { // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no // children). - iterator leaf_begin() { - return iterator(&nodes_, nodes_.begin(), - /*iterate_leaves_only=*/true); + leaf_iterator leaf_begin() { return leaf_iterator(&nodes_, nodes_.begin()); } + leaf_iterator leaf_end() { return leaf_iterator(&nodes_, nodes_.end()); } + const_leaf_iterator leaf_begin() const { + return const_leaf_iterator(&nodes_, nodes_.begin()); } - iterator leaf_end() { - return iterator(&nodes_, nodes_.end(), - /*iterate_leaves_only=*/true); - } - const_iterator leaf_begin() const { - return const_iterator(&nodes_, nodes_.begin(), - /*iterate_leaves_only=*/true); - } - const_iterator leaf_end() const { - return const_iterator(&nodes_, nodes_.end(), - /*iterate_leaves_only=*/true); + const_leaf_iterator leaf_end() const { + return const_leaf_iterator(&nodes_, nodes_.end()); } // range-based iterator for leaf_begin()/leaf_end(). - tensorflow::gtl::iterator_range leaves() { + tensorflow::gtl::iterator_range leaves() { return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); } - tensorflow::gtl::iterator_range leaves() const { + tensorflow::gtl::iterator_range leaves() const { return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); } - reverse_iterator leaf_rbegin() { return reverse_iterator(leaf_end()); } - reverse_iterator leaf_rend() { return reverse_iterator(leaf_begin()); } - const_reverse_iterator leaf_rbegin() const { - return const_reverse_iterator(leaf_end()); + reverse_leaf_iterator leaf_rbegin() { + return reverse_leaf_iterator(leaf_end()); } - const_reverse_iterator leaf_rend() const { - return const_reverse_iterator(leaf_begin()); + reverse_leaf_iterator leaf_rend() { + return reverse_leaf_iterator(leaf_begin()); + } + const_reverse_leaf_iterator leaf_rbegin() const { + return const_reverse_leaf_iterator(leaf_end()); + } + const_reverse_leaf_iterator leaf_rend() const { + return const_reverse_leaf_iterator(leaf_begin()); } // Returns an iterator pointing to the given ShapeIndex. @@ -226,12 +226,12 @@ class ShapeTree { iterator find(ShapeIndexView index) { Node* element = Lookup(index); auto element_iter = nodes_.begin() + (element - &nodes_[0]); - return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false); + return iterator(&nodes_, element_iter); } const_iterator find(ShapeIndexView index) const { - Node* element = Lookup(index); + const Node* element = Lookup(index); auto element_iter = nodes_.cbegin() + (element - &nodes_[0]); - return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false); + return const_iterator(&nodes_, element_iter); } // Returns the number of leaf nodes in the tree. @@ -343,21 +343,11 @@ template class ShapeTreeIterator : public std::iterator { public: - ShapeTreeIterator(ContainerType* nodes, IteratorType node, - bool iterate_leaves_only) - : nodes_(nodes), - node_(std::move(node)), - iterate_leaves_only_(iterate_leaves_only) { - while (iterate_leaves_only && node_ != nodes_->end() && !node_->is_leaf) { - ++node_; - } - } + ShapeTreeIterator(ContainerType* nodes, IteratorType node) + : nodes_(nodes), node_(std::move(node)) {} ShapeTreeIterator& operator++() { ++node_; - while (iterate_leaves_only_ && node_ != nodes_->end() && !node_->is_leaf) { - ++node_; - } return *this; } ShapeTreeIterator operator++(int) { @@ -368,9 +358,6 @@ class ShapeTreeIterator ShapeTreeIterator& operator--() { --node_; - while (iterate_leaves_only_ && node_ > nodes_->begin() && !node_->is_leaf) { - --node_; - } return *this; } ShapeTreeIterator operator--(int) { @@ -391,8 +378,60 @@ class ShapeTreeIterator private: ContainerType* nodes_; IteratorType node_; - // True if we should not include interior nodes in our walk. - const bool iterate_leaves_only_; +}; + +// Internal iterator that performs a pre-order walk of the leaves. This is cheap +// to copy. The iterator value_type is equivalent to a std::pair&, +// similar to std::map. +template +class ShapeTreeLeafIterator + : public std::iterator { + public: + ShapeTreeLeafIterator(ContainerType* nodes, IteratorType node) + : nodes_(nodes), node_(std::move(node)) { + while (node_ != nodes_->end() && !node_->is_leaf) { + ++node_; + } + } + + ShapeTreeLeafIterator& operator++() { + ++node_; + while (node_ != nodes_->end() && !node_->is_leaf) { + ++node_; + } + return *this; + } + ShapeTreeLeafIterator operator++(int) { + auto i = *this; + ++(*this); + return i; + } + + ShapeTreeLeafIterator& operator--() { + --node_; + while (node_ > nodes_->begin() && !node_->is_leaf) { + --node_; + } + return *this; + } + ShapeTreeLeafIterator operator--(int) { + auto i = *this; + --(*this); + return i; + } + + bool operator==(const ShapeTreeLeafIterator& other) const { + return node_ == other.node_; + } + bool operator!=(const ShapeTreeLeafIterator& other) const { + return node_ != other.node_; + } + ValueType& operator*() const { return node_->data; } + ValueType* operator->() const { return &node_->data; } + + private: + ContainerType* nodes_; + IteratorType node_; }; template diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index 2b6c484bc4f..c294355e269 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -485,6 +485,30 @@ TEST_F(ShapeTreeTest, ReverseIterateOrder) { })); } +// Ensures that we can find an element at an index that we know ahead of time to +// be occupied in a 'ShapeTree' via the 'find' API. +TEST_F(ShapeTreeTest, Find) { + ShapeTree t(nested_tuple_shape_, 42); + auto found = t.find({1, 0}); + EXPECT_NE(found, t.end()); + // The found key must be the same key we searched for. + EXPECT_EQ(found->first, ShapeIndex({1, 0})); + // The 'ShapeTree' has 42 at every position. + EXPECT_EQ(found->second, 42); +} + +// Ensures that we can find an element at an index that we know ahead of time to +// be occupied in a 'const ShapeTree' via the 'find' API. +TEST_F(ShapeTreeTest, ConstFind) { + const ShapeTree t(nested_tuple_shape_, 42); + auto found = t.find({1, 0}); + EXPECT_NE(found, t.end()); + // The found key must be the same key we searched for. + EXPECT_EQ(found->first, ShapeIndex({1, 0})); + // The 'ShapeTree' has 42 at every position. + EXPECT_EQ(found->second, 42); +} + TEST_F(ShapeTreeTest, IterateOrderLeaves) { ShapeTree t(nested_tuple_shape_, 42); std::vector v;