[mlir] Enhance mlir::MemRefType -> xla::Shape conversion.
PiperOrigin-RevId: 323861437 Change-Id: If60b33c5b69a81b7f05843f42b602ce2945bed95
This commit is contained in:
parent
d6066885d7
commit
57fe05f57d
@ -209,6 +209,7 @@ tf_cc_test(
|
||||
name = "type_to_shape_test",
|
||||
srcs = ["type_to_shape_test.cc"],
|
||||
deps = [
|
||||
":hlo_utils",
|
||||
":type_to_shape",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
|
@ -145,11 +145,37 @@ Shape TypeToShape(mlir::Type type) {
|
||||
// For the primitive type case, the shape of the memref is similar to the
|
||||
// vector type case (i.e., it is, modulo the layout, the same dimensions
|
||||
// and primitive type).
|
||||
// Currently we only return shapes for identity affine maps.
|
||||
// TODO(andydavis) Map affine map layout function to XLA layout.
|
||||
if (m.getAffineMaps().empty() ||
|
||||
(m.getAffineMaps().size() == 1 && m.getAffineMaps()[0].isIdentity()))
|
||||
if (m.getAffineMaps().empty())
|
||||
return ShapeUtil::MakeShape(primitive_type, span);
|
||||
|
||||
if (m.getAffineMaps().size() == 1) {
|
||||
llvm::SmallVector<int64_t, 4> strides;
|
||||
int64_t offset;
|
||||
if (failed(mlir::getStridesAndOffset(m, strides, offset))) return {};
|
||||
|
||||
llvm::SmallVector<std::pair<int64_t, int>, 4> strides_with_indices;
|
||||
for (const auto& e : llvm::enumerate(strides)) {
|
||||
strides_with_indices.push_back({e.value(), e.index()});
|
||||
}
|
||||
std::sort(strides_with_indices.begin(), strides_with_indices.end());
|
||||
|
||||
llvm::SmallVector<int64, 4> minor_to_major;
|
||||
int64_t stride = 1;
|
||||
for (const auto& pr : strides_with_indices) {
|
||||
minor_to_major.push_back(pr.second);
|
||||
|
||||
// Either the affine map is not perfectly strided, or the dimensions
|
||||
// recovered from strides don't match the actual dimensions in shapes.
|
||||
if (stride != pr.first) return {};
|
||||
|
||||
stride *= m.getShape()[pr.second];
|
||||
}
|
||||
|
||||
llvm::SmallVector<int64, 4> dimensions(m.getShape().begin(),
|
||||
m.getShape().end());
|
||||
return ::xla::ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions,
|
||||
minor_to_major);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::RankedTensor: {
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
@ -178,5 +179,22 @@ TEST(TypeToShapeTest, ConvertWithShapeRepresentationFn) {
|
||||
EXPECT_EQ(captured_tensor_shape, tensorflow::TensorShape({1, 2, 3}));
|
||||
}
|
||||
|
||||
TEST(TypeToShapeTest, ConvertMemRefToShape) {
|
||||
Shape shape = ShapeUtil::MakeShapeWithLayout(PrimitiveType::F32, {10, 20, 30},
|
||||
{2, 0, 1});
|
||||
MLIRContext context;
|
||||
mlir::Builder builder(&context);
|
||||
|
||||
StatusOr<mlir::Type> mlir_type =
|
||||
ConvertShapeToType<MemRefType>(shape, builder);
|
||||
ASSERT_TRUE(mlir_type.ok());
|
||||
mlir::Type type = mlir_type.ConsumeValueOrDie();
|
||||
Shape converted = TypeToShape(type);
|
||||
EXPECT_TRUE(ShapeUtil::Equal(
|
||||
converted, ShapeUtil::MakeShapeWithLayout(PrimitiveType::F32,
|
||||
{10, 20, 30}, {2, 0, 1})));
|
||||
EXPECT_TRUE(ShapeUtil::Equal(converted, shape));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user