diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index afd7141477f..8d38de0e5c7 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -84,6 +84,12 @@ void SetDefaultLayoutToContainer(T* minor_to_major) { return MakeLayout(layout); } +/* static */ Layout LayoutUtil::MakeAscendingLayout(int64 rank) { + std::vector layout(rank); + std::iota(layout.begin(), layout.end(), static_cast(0)); + return MakeLayout(layout); +} + /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor( absl::Span major_to_minor) { Layout layout; diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 60e135de354..7abedbd2a64 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -45,10 +45,14 @@ class LayoutUtil { static Layout MakeLayoutFromMajorToMinor( absl::Span major_to_minor); - // Returns a layout with descending ((i.e. {n, n-1, ..., 0}) minor-to-major + // Returns a layout with descending ((i.e. {n-1, n-2, ... 0}) minor-to-major // dimensions. static Layout MakeDescendingLayout(int64 rank); + // Returns a layout with ascending ((i.e. {0, 1, ... n-1}) minor-to-major + // dimensions. + static Layout MakeAscendingLayout(int64 rank); + // Returns default layout for the given shape. static Layout GetDefaultLayoutForShape(const Shape& shape); diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 398baa13fca..07fa17cdc81 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -251,6 +251,24 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25})))); } +TEST_F(LayoutUtilTest, MakeDescending) { + EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeDescendingLayout(5), + LayoutUtil::MakeLayout({4, 3, 2, 1, 0}))); + EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeDescendingLayout(1), + LayoutUtil::MakeLayout({0}))); + EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeDescendingLayout(0), + LayoutUtil::MakeLayout({}))); +} + +TEST_F(LayoutUtilTest, MakeAscending) { + EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeAscendingLayout(5), + LayoutUtil::MakeLayout({0, 1, 2, 3, 4}))); + EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeAscendingLayout(1), + LayoutUtil::MakeLayout({0}))); + EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeAscendingLayout(0), + LayoutUtil::MakeLayout({}))); +} + TEST_F(LayoutUtilTest, HumanStringWithTiling) { Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3, 4}, {0, 1, 2}); Tile* tile;