[mlir] Enhance mlir::MemRefType -> xla::Shape conversion.

PiperOrigin-RevId: 323861437
Change-Id: If60b33c5b69a81b7f05843f42b602ce2945bed95
This commit is contained in:
Tim Shen 2020-07-29 13:54:33 -07:00 committed by TensorFlower Gardener
parent d6066885d7
commit 57fe05f57d
3 changed files with 49 additions and 4 deletions

View File

@ -209,6 +209,7 @@ tf_cc_test(
name = "type_to_shape_test",
srcs = ["type_to_shape_test.cc"],
deps = [
":hlo_utils",
":type_to_shape",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",

View File

@ -145,11 +145,37 @@ Shape TypeToShape(mlir::Type type) {
// For the primitive type case, the shape of the memref is similar to the
// vector type case (i.e., it is, modulo the layout, the same dimensions
// and primitive type).
// Currently we only return shapes for identity affine maps.
// TODO(andydavis) Map affine map layout function to XLA layout.
if (m.getAffineMaps().empty() ||
(m.getAffineMaps().size() == 1 && m.getAffineMaps()[0].isIdentity()))
if (m.getAffineMaps().empty())
return ShapeUtil::MakeShape(primitive_type, span);
if (m.getAffineMaps().size() == 1) {
llvm::SmallVector<int64_t, 4> strides;
int64_t offset;
if (failed(mlir::getStridesAndOffset(m, strides, offset))) return {};
llvm::SmallVector<std::pair<int64_t, int>, 4> strides_with_indices;
for (const auto& e : llvm::enumerate(strides)) {
strides_with_indices.push_back({e.value(), e.index()});
}
std::sort(strides_with_indices.begin(), strides_with_indices.end());
llvm::SmallVector<int64, 4> minor_to_major;
int64_t stride = 1;
for (const auto& pr : strides_with_indices) {
minor_to_major.push_back(pr.second);
// Either the affine map is not perfectly strided, or the dimensions
// recovered from strides don't match the actual dimensions in shapes.
if (stride != pr.first) return {};
stride *= m.getShape()[pr.second];
}
llvm::SmallVector<int64, 4> dimensions(m.getShape().begin(),
m.getShape().end());
return ::xla::ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions,
minor_to_major);
}
break;
}
case mlir::StandardTypes::RankedTensor: {

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -178,5 +179,22 @@ TEST(TypeToShapeTest, ConvertWithShapeRepresentationFn) {
EXPECT_EQ(captured_tensor_shape, tensorflow::TensorShape({1, 2, 3}));
}
TEST(TypeToShapeTest, ConvertMemRefToShape) {
Shape shape = ShapeUtil::MakeShapeWithLayout(PrimitiveType::F32, {10, 20, 30},
{2, 0, 1});
MLIRContext context;
mlir::Builder builder(&context);
StatusOr<mlir::Type> mlir_type =
ConvertShapeToType<MemRefType>(shape, builder);
ASSERT_TRUE(mlir_type.ok());
mlir::Type type = mlir_type.ConsumeValueOrDie();
Shape converted = TypeToShape(type);
EXPECT_TRUE(ShapeUtil::Equal(
converted, ShapeUtil::MakeShapeWithLayout(PrimitiveType::F32,
{10, 20, 30}, {2, 0, 1})));
EXPECT_TRUE(ShapeUtil::Equal(converted, shape));
}
} // namespace
} // namespace xla