[XLA] Various improvements to ShapeTree.
Add support for holding non-copyable types, operator==, and a CopySubtreeFrom method for copying a subtree from one ShapeTree to another. PiperOrigin-RevId: 157777636
This commit is contained in:
parent
4f3ae76996
commit
2ee09b873a
tensorflow/compiler/xla
@ -44,6 +44,7 @@ struct ShapeTreeNode {
|
||||
// Children of this node.
|
||||
std::vector<std::unique_ptr<ShapeTreeNode>> children;
|
||||
|
||||
ShapeTreeNode() = default;
|
||||
explicit ShapeTreeNode(const T& data) : data(data) {}
|
||||
|
||||
ShapeTreeNode(const ShapeTreeNode& other)
|
||||
@ -85,8 +86,9 @@ class ShapeTree {
|
||||
public:
|
||||
// Default constructor creates a tree with a nil shape (i.e. an empty tuple).
|
||||
ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
|
||||
// Create ShapeTree with the given shape, and default T values for all nodes.
|
||||
explicit ShapeTree(const Shape& shape) : ShapeTree(shape, T()) {}
|
||||
// Create ShapeTree with the given shape, and default-constructed T values for
|
||||
// all nodes.
|
||||
explicit ShapeTree(const Shape& shape);
|
||||
// Create ShapeTree with the given shape, and init_value for all nodes.
|
||||
ShapeTree(const Shape& shape, const T& init_value);
|
||||
|
||||
@ -127,6 +129,19 @@ class ShapeTree {
|
||||
const ShapeIndex& /*index*/, bool /*is_leaf*/, T* /*data*/)>;
|
||||
Status ForEachMutableElement(const MutableVisitorFunction& func);
|
||||
|
||||
// Copy the subtree of values from 'other' rooted at ShapeIndex
|
||||
// 'source_base_index' into the subtree of value in this ShapeTree rooted at
|
||||
// 'target_base_index'.
|
||||
//
|
||||
// Precondition: The subshape of other.shape() at index source_base_index must
|
||||
// be compatible with the subshape of shape() at index target_base_index.
|
||||
void CopySubtreeFrom(const ShapeTree<T>& other,
|
||||
const ShapeIndex& source_base_index,
|
||||
const ShapeIndex& target_base_index);
|
||||
|
||||
bool operator==(const ShapeTree<T>& other) const;
|
||||
bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }
|
||||
|
||||
private:
|
||||
using Node = internal::ShapeTreeNode<T>;
|
||||
|
||||
@ -134,6 +149,10 @@ class ShapeTree {
|
||||
// the given 'init_value'.
|
||||
void InitChildren(const Shape& shape, const T& init_value, Node* node);
|
||||
|
||||
// Initialize node->children based on 'shape'. All children have
|
||||
// default-constructed data values.
|
||||
void InitChildren(const Shape& shape, Node* node);
|
||||
|
||||
// Helpers for traversing the shape via ForEachElement. The helpers
|
||||
// recursively traverse the subtree rooted at "index" (defined as in
|
||||
// ShapeUtil::GetSubshape).
|
||||
@ -165,6 +184,24 @@ void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ShapeTree<T>::InitChildren(const Shape& shape, Node* node) {
|
||||
if (ShapeUtil::IsTuple(shape)) {
|
||||
for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
|
||||
node->children.emplace_back(new Node());
|
||||
InitChildren(shape.tuple_shapes(i), node->children.back().get());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ShapeTree<T>::ShapeTree(const Shape& shape) : root_(), shape_(shape) {
|
||||
// The shape_ field is just used to hold the structure of the shape.
|
||||
// It should not be relied upon to store layout information.
|
||||
LayoutUtil::ClearLayout(&shape_);
|
||||
InitChildren(shape_, &root_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ShapeTree<T>::ShapeTree(const Shape& shape, const T& init_value)
|
||||
: root_(init_value), shape_(shape) {
|
||||
@ -240,6 +277,48 @@ Status ShapeTree<T>::ForEachMutableElement(const MutableVisitorFunction& func) {
|
||||
return ForEachMutableHelper(func, &root_, &index);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
|
||||
const ShapeIndex& source_base_index,
|
||||
const ShapeIndex& target_base_index) {
|
||||
CHECK(ShapeUtil::Compatible(
|
||||
ShapeUtil::GetSubshape(shape(), target_base_index),
|
||||
ShapeUtil::GetSubshape(other.shape(), source_base_index)));
|
||||
ForEachMutableElement(
|
||||
[this, &other, &source_base_index, &target_base_index](
|
||||
const ShapeIndex& index, bool /*is_leaf*/, T* data) {
|
||||
// Copy the data element only if index is in the
|
||||
// subtree rooted at target_base_index.
|
||||
for (int i = 0; i < target_base_index.size(); ++i) {
|
||||
if (i >= index.size() || index[i] != target_base_index[i]) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
// Construct source element index to copy from.
|
||||
ShapeIndex source_index = source_base_index;
|
||||
for (int i = target_base_index.size(); i < index.size(); ++i) {
|
||||
source_index.push_back(index[i]);
|
||||
}
|
||||
*data = other.element(source_index);
|
||||
return Status::OK();
|
||||
})
|
||||
.IgnoreError();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool ShapeTree<T>::operator==(const ShapeTree<T>& other) const {
|
||||
bool equal = true;
|
||||
ForEachElement([this, &other, &equal](const ShapeIndex& index,
|
||||
bool /*is_leaf*/, const T& data) {
|
||||
if (data != other.element(index)) {
|
||||
equal = false;
|
||||
}
|
||||
return Status::OK();
|
||||
})
|
||||
.IgnoreError();
|
||||
return equal;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
|
||||
|
@ -245,5 +245,139 @@ TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) {
|
||||
EXPECT_DEATH(shape_tree.element({0, 0}), "");
|
||||
}
|
||||
|
||||
TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) {
|
||||
ShapeTree<std::unique_ptr<int>> shape_tree{tuple_shape_};
|
||||
EXPECT_EQ(shape_tree.element({2}).get(), nullptr);
|
||||
*shape_tree.mutable_element({2}) = MakeUnique<int>(42);
|
||||
EXPECT_EQ(*shape_tree.element({2}), 42);
|
||||
}
|
||||
|
||||
TEST_F(ShapeTreeTest, CopySubtreeFromArrayShape) {
|
||||
// Test CopySubtreeFrom method for a single value copied between array-shaped
|
||||
// ShapeTrees.
|
||||
ShapeTree<int> source(array_shape_);
|
||||
*source.mutable_element(/*index=*/{}) = 42;
|
||||
ShapeTree<int> destination(array_shape_, 123);
|
||||
|
||||
EXPECT_EQ(destination.element(/*index=*/{}), 123);
|
||||
destination.CopySubtreeFrom(source, /*source_base_index=*/{},
|
||||
/*target_base_index=*/{});
|
||||
EXPECT_EQ(destination.element(/*index=*/{}), 42);
|
||||
}
|
||||
|
||||
TEST_F(ShapeTreeTest, FullCopySubtreeFromTupleShape) {
|
||||
// Test CopySubtreeFrom method for a copy of all elements from one
|
||||
// tuple-shaped ShapeTree to another.
|
||||
ShapeTree<int> source(tuple_shape_);
|
||||
*source.mutable_element(/*index=*/{}) = 10;
|
||||
*source.mutable_element(/*index=*/{0}) = 11;
|
||||
*source.mutable_element(/*index=*/{1}) = 12;
|
||||
*source.mutable_element(/*index=*/{2}) = 13;
|
||||
|
||||
ShapeTree<int> destination(tuple_shape_, 0);
|
||||
|
||||
destination.CopySubtreeFrom(source, /*source_base_index=*/{},
|
||||
/*target_base_index=*/{});
|
||||
EXPECT_EQ(destination.element(/*index=*/{}), 10);
|
||||
EXPECT_EQ(destination.element(/*index=*/{0}), 11);
|
||||
EXPECT_EQ(destination.element(/*index=*/{1}), 12);
|
||||
EXPECT_EQ(destination.element(/*index=*/{2}), 13);
|
||||
}
|
||||
|
||||
TEST_F(ShapeTreeTest, SingleElementCopySubtreeFromTupleShape) {
|
||||
// Test CopySubtreeFrom method for a copy of a single element from one
|
||||
// tuple-shaped ShapeTree to another.
|
||||
ShapeTree<int> source(tuple_shape_);
|
||||
*source.mutable_element(/*index=*/{}) = 10;
|
||||
*source.mutable_element(/*index=*/{0}) = 11;
|
||||
*source.mutable_element(/*index=*/{1}) = 12;
|
||||
*source.mutable_element(/*index=*/{2}) = 13;
|
||||
|
||||
ShapeTree<int> destination(tuple_shape_, 0);
|
||||
|
||||
destination.CopySubtreeFrom(source, /*source_base_index=*/{0},
|
||||
/*target_base_index=*/{1});
|
||||
EXPECT_EQ(destination.element(/*index=*/{}), 0);
|
||||
EXPECT_EQ(destination.element(/*index=*/{0}), 0);
|
||||
EXPECT_EQ(destination.element(/*index=*/{1}), 11);
|
||||
EXPECT_EQ(destination.element(/*index=*/{2}), 0);
|
||||
}
|
||||
|
||||
TEST_F(ShapeTreeTest, CopySubtreeIntoNestedShape) {
|
||||
// Test CopySubtreeFrom method for a copy of a tuple-shaped ShapeTree into a
|
||||
// nested-tuple-shaped ShapeTree.
|
||||
ShapeTree<int> source(
|
||||
ShapeUtil::MakeTupleShape({array_shape_, array_shape_}));
|
||||
*source.mutable_element(/*index=*/{}) = 10;
|
||||
*source.mutable_element(/*index=*/{0}) = 11;
|
||||
*source.mutable_element(/*index=*/{1}) = 12;
|
||||
|
||||
ShapeTree<int> destination(nested_tuple_shape_, 0);
|
||||
|
||||
destination.CopySubtreeFrom(source, /*source_base_index=*/{},
|
||||
/*target_base_index=*/{2, 0});
|
||||
|
||||
EXPECT_EQ(destination.element(/*index=*/{}), 0);
|
||||
EXPECT_EQ(destination.element(/*index=*/{0}), 0);
|
||||
EXPECT_EQ(destination.element(/*index=*/{1}), 0);
|
||||
EXPECT_EQ(destination.element(/*index=*/{1, 0}), 0);
|
||||
EXPECT_EQ(destination.element(/*index=*/{1, 1}), 0);
|
||||
EXPECT_EQ(destination.element(/*index=*/{2}), 0);
|
||||
EXPECT_EQ(destination.element(/*index=*/{2, 0}), 10);
|
||||
EXPECT_EQ(destination.element(/*index=*/{2, 0, 0}), 11);
|
||||
EXPECT_EQ(destination.element(/*index=*/{2, 0, 1}), 12);
|
||||
EXPECT_EQ(destination.element(/*index=*/{2, 1}), 0);
|
||||
}
|
||||
|
||||
TEST_F(ShapeTreeTest, CopySubtreeFromNestedShape) {
|
||||
// Test CopySubtreeFrom method for a copy from a nested-tuple-shape.
|
||||
ShapeTree<int> source(nested_tuple_shape_, 42);
|
||||
*source.mutable_element(/*index=*/{1}) = 10;
|
||||
*source.mutable_element(/*index=*/{1, 0}) = 11;
|
||||
*source.mutable_element(/*index=*/{1, 1}) = 12;
|
||||
|
||||
ShapeTree<int> destination(
|
||||
ShapeUtil::MakeTupleShape({array_shape_, array_shape_}), 0);
|
||||
|
||||
destination.CopySubtreeFrom(source, /*source_base_index=*/{1},
|
||||
/*target_base_index=*/{});
|
||||
|
||||
EXPECT_EQ(destination.element(/*index=*/{}), 10);
|
||||
EXPECT_EQ(destination.element(/*index=*/{0}), 11);
|
||||
EXPECT_EQ(destination.element(/*index=*/{1}), 12);
|
||||
}
|
||||
|
||||
TEST_F(ShapeTreeTest, OperatorEquals) {
|
||||
{
|
||||
ShapeTree<int> a(array_shape_, 123);
|
||||
ShapeTree<int> b(array_shape_, 42);
|
||||
ShapeTree<int> c(array_shape_, 42);
|
||||
EXPECT_FALSE(a == b);
|
||||
EXPECT_TRUE(a != b);
|
||||
EXPECT_TRUE(b == c);
|
||||
}
|
||||
{
|
||||
ShapeTree<int> a(tuple_shape_);
|
||||
*a.mutable_element(/*index=*/{}) = 10;
|
||||
*a.mutable_element(/*index=*/{0}) = 11;
|
||||
*a.mutable_element(/*index=*/{1}) = 12;
|
||||
|
||||
ShapeTree<int> b(tuple_shape_);
|
||||
*b.mutable_element(/*index=*/{}) = 10;
|
||||
*b.mutable_element(/*index=*/{0}) = 42;
|
||||
*b.mutable_element(/*index=*/{1}) = 11;
|
||||
|
||||
ShapeTree<int> c(tuple_shape_);
|
||||
*c.mutable_element(/*index=*/{}) = 10;
|
||||
*c.mutable_element(/*index=*/{0}) = 42;
|
||||
*c.mutable_element(/*index=*/{1}) = 11;
|
||||
|
||||
EXPECT_FALSE(a == b);
|
||||
EXPECT_TRUE(a != b);
|
||||
EXPECT_TRUE(b == c);
|
||||
EXPECT_FALSE(b != c);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user