From 57fe05f57dffc6d5ad356f8e1cc9d4f22cc12116 Mon Sep 17 00:00:00 2001 From: Tim Shen Date: Wed, 29 Jul 2020 13:54:33 -0700 Subject: [PATCH] [mlir] Enhance mlir::MemRefType -> xla::Shape conversion. PiperOrigin-RevId: 323861437 Change-Id: If60b33c5b69a81b7f05843f42b602ce2945bed95 --- tensorflow/compiler/mlir/xla/BUILD | 1 + tensorflow/compiler/mlir/xla/type_to_shape.cc | 34 ++++++++++++++++--- .../compiler/mlir/xla/type_to_shape_test.cc | 18 ++++++++++ 3 files changed, 49 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index a6eb9f2fe1c..0a7e44a275f 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -209,6 +209,7 @@ tf_cc_test( name = "type_to_shape_test", srcs = ["type_to_shape_test.cc"], deps = [ + ":hlo_utils", ":type_to_shape", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc index b684abde7a5..afc36916348 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc @@ -145,11 +145,37 @@ Shape TypeToShape(mlir::Type type) { // For the primitive type case, the shape of the memref is similar to the // vector type case (i.e., it is, modulo the layout, the same dimensions // and primitive type). - // Currently we only return shapes for identity affine maps. - // TODO(andydavis) Map affine map layout function to XLA layout. - if (m.getAffineMaps().empty() || - (m.getAffineMaps().size() == 1 && m.getAffineMaps()[0].isIdentity())) + if (m.getAffineMaps().empty()) return ShapeUtil::MakeShape(primitive_type, span); + + if (m.getAffineMaps().size() == 1) { + llvm::SmallVector strides; + int64_t offset; + if (failed(mlir::getStridesAndOffset(m, strides, offset))) return {}; + + llvm::SmallVector, 4> strides_with_indices; + for (const auto& e : llvm::enumerate(strides)) { + strides_with_indices.push_back({e.value(), e.index()}); + } + std::sort(strides_with_indices.begin(), strides_with_indices.end()); + + llvm::SmallVector minor_to_major; + int64_t stride = 1; + for (const auto& pr : strides_with_indices) { + minor_to_major.push_back(pr.second); + + // Either the affine map is not perfectly strided, or the dimensions + // recovered from strides don't match the actual dimensions in shapes. + if (stride != pr.first) return {}; + + stride *= m.getShape()[pr.second]; + } + + llvm::SmallVector dimensions(m.getShape().begin(), + m.getShape().end()); + return ::xla::ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, + minor_to_major); + } break; } case mlir::StandardTypes::RankedTensor: { diff --git a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc index b2a7cb85686..a4a2bc42d99 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -178,5 +179,22 @@ TEST(TypeToShapeTest, ConvertWithShapeRepresentationFn) { EXPECT_EQ(captured_tensor_shape, tensorflow::TensorShape({1, 2, 3})); } +TEST(TypeToShapeTest, ConvertMemRefToShape) { + Shape shape = ShapeUtil::MakeShapeWithLayout(PrimitiveType::F32, {10, 20, 30}, + {2, 0, 1}); + MLIRContext context; + mlir::Builder builder(&context); + + StatusOr mlir_type = + ConvertShapeToType(shape, builder); + ASSERT_TRUE(mlir_type.ok()); + mlir::Type type = mlir_type.ConsumeValueOrDie(); + Shape converted = TypeToShape(type); + EXPECT_TRUE(ShapeUtil::Equal( + converted, ShapeUtil::MakeShapeWithLayout(PrimitiveType::F32, + {10, 20, 30}, {2, 0, 1}))); + EXPECT_TRUE(ShapeUtil::Equal(converted, shape)); +} + } // namespace } // namespace xla