Add MakeAscendingLayout().
Also, add unit test for MakeAscendingLayout() and (while at it) MakeDescendingLayout(). PiperOrigin-RevId: 350767284 Change-Id: I348f5628c5a9458cd2f3d75df2e25c5395639586
This commit is contained in:
parent
cccd2f2d7e
commit
b74444efa8
@ -84,6 +84,12 @@ void SetDefaultLayoutToContainer(T* minor_to_major) {
|
||||
return MakeLayout(layout);
|
||||
}
|
||||
|
||||
/* static */ Layout LayoutUtil::MakeAscendingLayout(int64 rank) {
|
||||
std::vector<int64> layout(rank);
|
||||
std::iota(layout.begin(), layout.end(), static_cast<int64>(0));
|
||||
return MakeLayout(layout);
|
||||
}
|
||||
|
||||
/* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
|
||||
absl::Span<const int64> major_to_minor) {
|
||||
Layout layout;
|
||||
|
@ -45,10 +45,14 @@ class LayoutUtil {
|
||||
static Layout MakeLayoutFromMajorToMinor(
|
||||
absl::Span<const int64> major_to_minor);
|
||||
|
||||
// Returns a layout with descending ((i.e. {n, n-1, ..., 0}) minor-to-major
|
||||
// Returns a layout with descending ((i.e. {n-1, n-2, ... 0}) minor-to-major
|
||||
// dimensions.
|
||||
static Layout MakeDescendingLayout(int64 rank);
|
||||
|
||||
// Returns a layout with ascending ((i.e. {0, 1, ... n-1}) minor-to-major
|
||||
// dimensions.
|
||||
static Layout MakeAscendingLayout(int64 rank);
|
||||
|
||||
// Returns default layout for the given shape.
|
||||
static Layout GetDefaultLayoutForShape(const Shape& shape);
|
||||
|
||||
|
@ -251,6 +251,24 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) {
|
||||
ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25}))));
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, MakeDescending) {
|
||||
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeDescendingLayout(5),
|
||||
LayoutUtil::MakeLayout({4, 3, 2, 1, 0})));
|
||||
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeDescendingLayout(1),
|
||||
LayoutUtil::MakeLayout({0})));
|
||||
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeDescendingLayout(0),
|
||||
LayoutUtil::MakeLayout({})));
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, MakeAscending) {
|
||||
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeAscendingLayout(5),
|
||||
LayoutUtil::MakeLayout({0, 1, 2, 3, 4})));
|
||||
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeAscendingLayout(1),
|
||||
LayoutUtil::MakeLayout({0})));
|
||||
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeAscendingLayout(0),
|
||||
LayoutUtil::MakeLayout({})));
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, HumanStringWithTiling) {
|
||||
Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3, 4}, {0, 1, 2});
|
||||
Tile* tile;
|
||||
|
Loading…
Reference in New Issue
Block a user