[XLA] Various improvements to ShapeTree.
Add support for holding non-copyable types, operator==, and a CopySubtreeFrom method for copying a subtree from one ShapeTree to another. PiperOrigin-RevId: 157777636
This commit is contained in:
parent
4f3ae76996
commit
2ee09b873a
@ -44,6 +44,7 @@ struct ShapeTreeNode {
|
|||||||
// Children of this node.
|
// Children of this node.
|
||||||
std::vector<std::unique_ptr<ShapeTreeNode>> children;
|
std::vector<std::unique_ptr<ShapeTreeNode>> children;
|
||||||
|
|
||||||
|
ShapeTreeNode() = default;
|
||||||
explicit ShapeTreeNode(const T& data) : data(data) {}
|
explicit ShapeTreeNode(const T& data) : data(data) {}
|
||||||
|
|
||||||
ShapeTreeNode(const ShapeTreeNode& other)
|
ShapeTreeNode(const ShapeTreeNode& other)
|
||||||
@ -85,8 +86,9 @@ class ShapeTree {
|
|||||||
public:
|
public:
|
||||||
// Default constructor creates a tree with a nil shape (i.e. an empty tuple).
|
// Default constructor creates a tree with a nil shape (i.e. an empty tuple).
|
||||||
ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
|
ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
|
||||||
// Create ShapeTree with the given shape, and default T values for all nodes.
|
// Create ShapeTree with the given shape, and default-constructed T values for
|
||||||
explicit ShapeTree(const Shape& shape) : ShapeTree(shape, T()) {}
|
// all nodes.
|
||||||
|
explicit ShapeTree(const Shape& shape);
|
||||||
// Create ShapeTree with the given shape, and init_value for all nodes.
|
// Create ShapeTree with the given shape, and init_value for all nodes.
|
||||||
ShapeTree(const Shape& shape, const T& init_value);
|
ShapeTree(const Shape& shape, const T& init_value);
|
||||||
|
|
||||||
@ -127,6 +129,19 @@ class ShapeTree {
|
|||||||
const ShapeIndex& /*index*/, bool /*is_leaf*/, T* /*data*/)>;
|
const ShapeIndex& /*index*/, bool /*is_leaf*/, T* /*data*/)>;
|
||||||
Status ForEachMutableElement(const MutableVisitorFunction& func);
|
Status ForEachMutableElement(const MutableVisitorFunction& func);
|
||||||
|
|
||||||
|
// Copy the subtree of values from 'other' rooted at ShapeIndex
|
||||||
|
// 'source_base_index' into the subtree of value in this ShapeTree rooted at
|
||||||
|
// 'target_base_index'.
|
||||||
|
//
|
||||||
|
// Precondition: The subshape of other.shape() at index source_base_index must
|
||||||
|
// be compatible with the subshape of shape() at index target_base_index.
|
||||||
|
void CopySubtreeFrom(const ShapeTree<T>& other,
|
||||||
|
const ShapeIndex& source_base_index,
|
||||||
|
const ShapeIndex& target_base_index);
|
||||||
|
|
||||||
|
bool operator==(const ShapeTree<T>& other) const;
|
||||||
|
bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
using Node = internal::ShapeTreeNode<T>;
|
using Node = internal::ShapeTreeNode<T>;
|
||||||
|
|
||||||
@ -134,6 +149,10 @@ class ShapeTree {
|
|||||||
// the given 'init_value'.
|
// the given 'init_value'.
|
||||||
void InitChildren(const Shape& shape, const T& init_value, Node* node);
|
void InitChildren(const Shape& shape, const T& init_value, Node* node);
|
||||||
|
|
||||||
|
// Initialize node->children based on 'shape'. All children have
|
||||||
|
// default-constructed data values.
|
||||||
|
void InitChildren(const Shape& shape, Node* node);
|
||||||
|
|
||||||
// Helpers for traversing the shape via ForEachElement. The helpers
|
// Helpers for traversing the shape via ForEachElement. The helpers
|
||||||
// recursively traverse the subtree rooted at "index" (defined as in
|
// recursively traverse the subtree rooted at "index" (defined as in
|
||||||
// ShapeUtil::GetSubshape).
|
// ShapeUtil::GetSubshape).
|
||||||
@ -165,6 +184,24 @@ void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void ShapeTree<T>::InitChildren(const Shape& shape, Node* node) {
|
||||||
|
if (ShapeUtil::IsTuple(shape)) {
|
||||||
|
for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
|
||||||
|
node->children.emplace_back(new Node());
|
||||||
|
InitChildren(shape.tuple_shapes(i), node->children.back().get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
ShapeTree<T>::ShapeTree(const Shape& shape) : root_(), shape_(shape) {
|
||||||
|
// The shape_ field is just used to hold the structure of the shape.
|
||||||
|
// It should not be relied upon to store layout information.
|
||||||
|
LayoutUtil::ClearLayout(&shape_);
|
||||||
|
InitChildren(shape_, &root_);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
ShapeTree<T>::ShapeTree(const Shape& shape, const T& init_value)
|
ShapeTree<T>::ShapeTree(const Shape& shape, const T& init_value)
|
||||||
: root_(init_value), shape_(shape) {
|
: root_(init_value), shape_(shape) {
|
||||||
@ -240,6 +277,48 @@ Status ShapeTree<T>::ForEachMutableElement(const MutableVisitorFunction& func) {
|
|||||||
return ForEachMutableHelper(func, &root_, &index);
|
return ForEachMutableHelper(func, &root_, &index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
|
||||||
|
const ShapeIndex& source_base_index,
|
||||||
|
const ShapeIndex& target_base_index) {
|
||||||
|
CHECK(ShapeUtil::Compatible(
|
||||||
|
ShapeUtil::GetSubshape(shape(), target_base_index),
|
||||||
|
ShapeUtil::GetSubshape(other.shape(), source_base_index)));
|
||||||
|
ForEachMutableElement(
|
||||||
|
[this, &other, &source_base_index, &target_base_index](
|
||||||
|
const ShapeIndex& index, bool /*is_leaf*/, T* data) {
|
||||||
|
// Copy the data element only if index is in the
|
||||||
|
// subtree rooted at target_base_index.
|
||||||
|
for (int i = 0; i < target_base_index.size(); ++i) {
|
||||||
|
if (i >= index.size() || index[i] != target_base_index[i]) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Construct source element index to copy from.
|
||||||
|
ShapeIndex source_index = source_base_index;
|
||||||
|
for (int i = target_base_index.size(); i < index.size(); ++i) {
|
||||||
|
source_index.push_back(index[i]);
|
||||||
|
}
|
||||||
|
*data = other.element(source_index);
|
||||||
|
return Status::OK();
|
||||||
|
})
|
||||||
|
.IgnoreError();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool ShapeTree<T>::operator==(const ShapeTree<T>& other) const {
|
||||||
|
bool equal = true;
|
||||||
|
ForEachElement([this, &other, &equal](const ShapeIndex& index,
|
||||||
|
bool /*is_leaf*/, const T& data) {
|
||||||
|
if (data != other.element(index)) {
|
||||||
|
equal = false;
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
})
|
||||||
|
.IgnoreError();
|
||||||
|
return equal;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
|
#endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
|
||||||
|
@ -245,5 +245,139 @@ TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) {
|
|||||||
EXPECT_DEATH(shape_tree.element({0, 0}), "");
|
EXPECT_DEATH(shape_tree.element({0, 0}), "");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) {
|
||||||
|
ShapeTree<std::unique_ptr<int>> shape_tree{tuple_shape_};
|
||||||
|
EXPECT_EQ(shape_tree.element({2}).get(), nullptr);
|
||||||
|
*shape_tree.mutable_element({2}) = MakeUnique<int>(42);
|
||||||
|
EXPECT_EQ(*shape_tree.element({2}), 42);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeTreeTest, CopySubtreeFromArrayShape) {
|
||||||
|
// Test CopySubtreeFrom method for a single value copied between array-shaped
|
||||||
|
// ShapeTrees.
|
||||||
|
ShapeTree<int> source(array_shape_);
|
||||||
|
*source.mutable_element(/*index=*/{}) = 42;
|
||||||
|
ShapeTree<int> destination(array_shape_, 123);
|
||||||
|
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{}), 123);
|
||||||
|
destination.CopySubtreeFrom(source, /*source_base_index=*/{},
|
||||||
|
/*target_base_index=*/{});
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{}), 42);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeTreeTest, FullCopySubtreeFromTupleShape) {
|
||||||
|
// Test CopySubtreeFrom method for a copy of all elements from one
|
||||||
|
// tuple-shaped ShapeTree to another.
|
||||||
|
ShapeTree<int> source(tuple_shape_);
|
||||||
|
*source.mutable_element(/*index=*/{}) = 10;
|
||||||
|
*source.mutable_element(/*index=*/{0}) = 11;
|
||||||
|
*source.mutable_element(/*index=*/{1}) = 12;
|
||||||
|
*source.mutable_element(/*index=*/{2}) = 13;
|
||||||
|
|
||||||
|
ShapeTree<int> destination(tuple_shape_, 0);
|
||||||
|
|
||||||
|
destination.CopySubtreeFrom(source, /*source_base_index=*/{},
|
||||||
|
/*target_base_index=*/{});
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{}), 10);
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{0}), 11);
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{1}), 12);
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{2}), 13);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeTreeTest, SingleElementCopySubtreeFromTupleShape) {
|
||||||
|
// Test CopySubtreeFrom method for a copy of a single element from one
|
||||||
|
// tuple-shaped ShapeTree to another.
|
||||||
|
ShapeTree<int> source(tuple_shape_);
|
||||||
|
*source.mutable_element(/*index=*/{}) = 10;
|
||||||
|
*source.mutable_element(/*index=*/{0}) = 11;
|
||||||
|
*source.mutable_element(/*index=*/{1}) = 12;
|
||||||
|
*source.mutable_element(/*index=*/{2}) = 13;
|
||||||
|
|
||||||
|
ShapeTree<int> destination(tuple_shape_, 0);
|
||||||
|
|
||||||
|
destination.CopySubtreeFrom(source, /*source_base_index=*/{0},
|
||||||
|
/*target_base_index=*/{1});
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{}), 0);
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{0}), 0);
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{1}), 11);
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{2}), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeTreeTest, CopySubtreeIntoNestedShape) {
|
||||||
|
// Test CopySubtreeFrom method for a copy of a tuple-shaped ShapeTree into a
|
||||||
|
// nested-tuple-shaped ShapeTree.
|
||||||
|
ShapeTree<int> source(
|
||||||
|
ShapeUtil::MakeTupleShape({array_shape_, array_shape_}));
|
||||||
|
*source.mutable_element(/*index=*/{}) = 10;
|
||||||
|
*source.mutable_element(/*index=*/{0}) = 11;
|
||||||
|
*source.mutable_element(/*index=*/{1}) = 12;
|
||||||
|
|
||||||
|
ShapeTree<int> destination(nested_tuple_shape_, 0);
|
||||||
|
|
||||||
|
destination.CopySubtreeFrom(source, /*source_base_index=*/{},
|
||||||
|
/*target_base_index=*/{2, 0});
|
||||||
|
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{}), 0);
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{0}), 0);
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{1}), 0);
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{1, 0}), 0);
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{1, 1}), 0);
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{2}), 0);
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{2, 0}), 10);
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{2, 0, 0}), 11);
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{2, 0, 1}), 12);
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{2, 1}), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeTreeTest, CopySubtreeFromNestedShape) {
|
||||||
|
// Test CopySubtreeFrom method for a copy from a nested-tuple-shape.
|
||||||
|
ShapeTree<int> source(nested_tuple_shape_, 42);
|
||||||
|
*source.mutable_element(/*index=*/{1}) = 10;
|
||||||
|
*source.mutable_element(/*index=*/{1, 0}) = 11;
|
||||||
|
*source.mutable_element(/*index=*/{1, 1}) = 12;
|
||||||
|
|
||||||
|
ShapeTree<int> destination(
|
||||||
|
ShapeUtil::MakeTupleShape({array_shape_, array_shape_}), 0);
|
||||||
|
|
||||||
|
destination.CopySubtreeFrom(source, /*source_base_index=*/{1},
|
||||||
|
/*target_base_index=*/{});
|
||||||
|
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{}), 10);
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{0}), 11);
|
||||||
|
EXPECT_EQ(destination.element(/*index=*/{1}), 12);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeTreeTest, OperatorEquals) {
|
||||||
|
{
|
||||||
|
ShapeTree<int> a(array_shape_, 123);
|
||||||
|
ShapeTree<int> b(array_shape_, 42);
|
||||||
|
ShapeTree<int> c(array_shape_, 42);
|
||||||
|
EXPECT_FALSE(a == b);
|
||||||
|
EXPECT_TRUE(a != b);
|
||||||
|
EXPECT_TRUE(b == c);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ShapeTree<int> a(tuple_shape_);
|
||||||
|
*a.mutable_element(/*index=*/{}) = 10;
|
||||||
|
*a.mutable_element(/*index=*/{0}) = 11;
|
||||||
|
*a.mutable_element(/*index=*/{1}) = 12;
|
||||||
|
|
||||||
|
ShapeTree<int> b(tuple_shape_);
|
||||||
|
*b.mutable_element(/*index=*/{}) = 10;
|
||||||
|
*b.mutable_element(/*index=*/{0}) = 42;
|
||||||
|
*b.mutable_element(/*index=*/{1}) = 11;
|
||||||
|
|
||||||
|
ShapeTree<int> c(tuple_shape_);
|
||||||
|
*c.mutable_element(/*index=*/{}) = 10;
|
||||||
|
*c.mutable_element(/*index=*/{0}) = 42;
|
||||||
|
*c.mutable_element(/*index=*/{1}) = 11;
|
||||||
|
|
||||||
|
EXPECT_FALSE(a == b);
|
||||||
|
EXPECT_TRUE(a != b);
|
||||||
|
EXPECT_TRUE(b == c);
|
||||||
|
EXPECT_FALSE(b != c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
x
Reference in New Issue
Block a user