Add MakeAscendingLayout().

Also, add unit test for MakeAscendingLayout() and (while at it) MakeDescendingLayout().

PiperOrigin-RevId: 350767284
Change-Id: I348f5628c5a9458cd2f3d75df2e25c5395639586
This commit is contained in:
A. Unique TensorFlower 2021-01-08 08:18:04 -08:00 committed by TensorFlower Gardener
parent cccd2f2d7e
commit b74444efa8
3 changed files with 29 additions and 1 deletions

View File

@ -84,6 +84,12 @@ void SetDefaultLayoutToContainer(T* minor_to_major) {
return MakeLayout(layout);
}
/* static */ Layout LayoutUtil::MakeAscendingLayout(int64 rank) {
std::vector<int64> layout(rank);
std::iota(layout.begin(), layout.end(), static_cast<int64>(0));
return MakeLayout(layout);
}
/* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
absl::Span<const int64> major_to_minor) {
Layout layout;

View File

@ -45,10 +45,14 @@ class LayoutUtil {
static Layout MakeLayoutFromMajorToMinor(
absl::Span<const int64> major_to_minor);
// Returns a layout with descending ((i.e. {n, n-1, ..., 0}) minor-to-major
// Returns a layout with descending ((i.e. {n-1, n-2, ... 0}) minor-to-major
// dimensions.
static Layout MakeDescendingLayout(int64 rank);
// Returns a layout with ascending ((i.e. {0, 1, ... n-1}) minor-to-major
// dimensions.
static Layout MakeAscendingLayout(int64 rank);
// Returns default layout for the given shape.
static Layout GetDefaultLayoutForShape(const Shape& shape);

View File

@ -251,6 +251,24 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) {
ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25}))));
}
TEST_F(LayoutUtilTest, MakeDescending) {
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeDescendingLayout(5),
LayoutUtil::MakeLayout({4, 3, 2, 1, 0})));
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeDescendingLayout(1),
LayoutUtil::MakeLayout({0})));
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeDescendingLayout(0),
LayoutUtil::MakeLayout({})));
}
TEST_F(LayoutUtilTest, MakeAscending) {
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeAscendingLayout(5),
LayoutUtil::MakeLayout({0, 1, 2, 3, 4})));
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeAscendingLayout(1),
LayoutUtil::MakeLayout({0})));
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeAscendingLayout(0),
LayoutUtil::MakeLayout({})));
}
TEST_F(LayoutUtilTest, HumanStringWithTiling) {
Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3, 4}, {0, 1, 2});
Tile* tile;