Add syntactic sugar for strided memref parsing.
This CL implements the last remaining bit of the [strided memref proposal](https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/MaL8m2nXuio). The syntax is a bit more explicit than what was originally proposed and resembles: `memref<?x?xf32, offset: 0 strides: [?, 1]>` Nonnegative strides and offsets are currently supported. Future extensions will include negative strides. This also gives a concrete example of syntactic sugar for the ([RFC] Proposed Changes to MemRef and Tensor MLIR Types)[https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/-wKHANzDNTg]. The underlying implementation still uses AffineMap layout. PiperOrigin-RevId: 272717437
This commit is contained in:
parent
fbdc707d14
commit
8ce19dbfef
8
third_party/mlir/lib/IR/StandardTypes.cpp
vendored
8
third_party/mlir/lib/IR/StandardTypes.cpp
vendored
@ -492,7 +492,13 @@ static void extractStrides(AffineExpr e, MutableArrayRef<int64_t> strides,
|
||||
return;
|
||||
}
|
||||
if (bin.getKind() == AffineExprKind::Mul) {
|
||||
auto dim = bin.getLHS().cast<AffineDimExpr>();
|
||||
// LHS may be more complex than just a single dim (e.g. multiple syms and
|
||||
// dims). Bail out for now and revisit when we have evidence this is needed.
|
||||
auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
|
||||
if (!dim) {
|
||||
failed = true;
|
||||
return;
|
||||
}
|
||||
auto cst = bin.getRHS().dyn_cast<AffineConstantExpr>();
|
||||
if (!cst) {
|
||||
strides[dim.getPosition()] = MemRefType::kDynamicStrideOrOffset;
|
||||
|
107
third_party/mlir/lib/Parser/Parser.cpp
vendored
107
third_party/mlir/lib/Parser/Parser.cpp
vendored
@ -38,7 +38,6 @@
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/ADT/bit.h"
|
||||
#include "llvm/Support/MemoryBuffer.h"
|
||||
#include "llvm/Support/PrettyStackTrace.h"
|
||||
#include "llvm/Support/SMLoc.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
@ -211,6 +210,14 @@ public:
|
||||
bool allowDynamic = true);
|
||||
ParseResult parseXInDimensionList();
|
||||
|
||||
/// Parse strided layout specification.
|
||||
ParseResult parseStridedLayout(int64_t &offset,
|
||||
SmallVectorImpl<int64_t> &strides);
|
||||
|
||||
// Parse a brace-delimiter list of comma-separated integers with `?` as an
|
||||
// unknown marker.
|
||||
ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Attribute Parsing
|
||||
//===--------------------------------------------------------------------===//
|
||||
@ -634,6 +641,40 @@ Type Parser::parseFunctionType() {
|
||||
return builder.getFunctionType(arguments, results);
|
||||
}
|
||||
|
||||
/// Parse the offset and strides from a strided layout specification.
|
||||
///
|
||||
/// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
|
||||
///
|
||||
ParseResult Parser::parseStridedLayout(int64_t &offset,
|
||||
SmallVectorImpl<int64_t> &strides) {
|
||||
// Parse offset.
|
||||
consumeToken(Token::kw_offset);
|
||||
if (!consumeIf(Token::colon))
|
||||
return emitError("expected colon after `offset` keyword");
|
||||
auto maybeOffset = getToken().getUnsignedIntegerValue();
|
||||
bool question = getToken().is(Token::question);
|
||||
if (!maybeOffset && !question)
|
||||
return emitError("invalid offset");
|
||||
offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue())
|
||||
: MemRefType::kDynamicStrideOrOffset;
|
||||
consumeToken();
|
||||
|
||||
if (!consumeIf(Token::comma))
|
||||
return emitError("expected comma after offset value");
|
||||
|
||||
// Parse stride list.
|
||||
if (!consumeIf(Token::kw_strides))
|
||||
return emitError("expected `strides` keyword after offset specification");
|
||||
if (!consumeIf(Token::colon))
|
||||
return emitError("expected colon after `strides` keyword");
|
||||
if (failed(parseStrideList(strides)))
|
||||
return emitError("invalid braces-enclosed stride list");
|
||||
if (llvm::any_of(strides, [](int64_t st) { return st == 0; }))
|
||||
return emitError("invalid memref stride");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Parse a memref type.
|
||||
///
|
||||
/// memref-type ::= `memref` `<` dimension-list-ranked type
|
||||
@ -675,18 +716,28 @@ Type Parser::parseMemRefType() {
|
||||
consumeToken(Token::integer);
|
||||
parsedMemorySpace = true;
|
||||
} else {
|
||||
// Parse affine map.
|
||||
if (parsedMemorySpace)
|
||||
return emitError("affine map after memory space in memref type");
|
||||
auto affineMap = parseAttribute();
|
||||
if (!affineMap)
|
||||
return failure();
|
||||
|
||||
// Verify that the parsed attribute is an affine map.
|
||||
if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>())
|
||||
affineMapComposition.push_back(affineMapAttr.getValue());
|
||||
else
|
||||
return emitError("expected affine map in memref type");
|
||||
return emitError("expected memory space to be last in memref type");
|
||||
if (getToken().is(Token::kw_offset)) {
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
if (failed(parseStridedLayout(offset, strides)))
|
||||
return failure();
|
||||
// Construct strided affine map.
|
||||
auto map = makeStridedLinearLayoutMap(strides, offset,
|
||||
elementType.getContext());
|
||||
affineMapComposition.push_back(map);
|
||||
} else {
|
||||
// Parse affine map.
|
||||
auto affineMap = parseAttribute();
|
||||
if (!affineMap)
|
||||
return failure();
|
||||
// Verify that the parsed attribute is an affine map.
|
||||
if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>())
|
||||
affineMapComposition.push_back(affineMapAttr.getValue());
|
||||
else
|
||||
return emitError("expected affine map in memref type");
|
||||
}
|
||||
}
|
||||
return success();
|
||||
};
|
||||
@ -935,6 +986,38 @@ ParseResult Parser::parseXInDimensionList() {
|
||||
return success();
|
||||
}
|
||||
|
||||
// Parse a comma-separated list of dimensions, possibly empty:
|
||||
// stride-list ::= `[` (dimension (`,` dimension)*)? `]`
|
||||
ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) {
|
||||
if (!consumeIf(Token::l_square))
|
||||
return failure();
|
||||
// Empty list early exit.
|
||||
if (consumeIf(Token::r_square))
|
||||
return success();
|
||||
while (true) {
|
||||
if (consumeIf(Token::question)) {
|
||||
dimensions.push_back(MemRefType::kDynamicStrideOrOffset);
|
||||
} else {
|
||||
// This must be an integer value.
|
||||
int64_t val;
|
||||
if (getToken().getSpelling().getAsInteger(10, val))
|
||||
return emitError("invalid integer value: ") << getToken().getSpelling();
|
||||
// Make sure it is not the one value for `?`.
|
||||
if (ShapedType::isDynamic(val))
|
||||
return emitError("invalid integer value: ")
|
||||
<< getToken().getSpelling()
|
||||
<< ", use `?` to specify a dynamic dimension";
|
||||
dimensions.push_back(val);
|
||||
consumeToken(Token::integer);
|
||||
}
|
||||
if (!consumeIf(Token::comma))
|
||||
break;
|
||||
}
|
||||
if (!consumeIf(Token::r_square))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Attribute parsing.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
2
third_party/mlir/lib/Parser/TokenKinds.def
vendored
2
third_party/mlir/lib/Parser/TokenKinds.def
vendored
@ -110,10 +110,12 @@ TOK_KEYWORD(memref)
|
||||
TOK_KEYWORD(min)
|
||||
TOK_KEYWORD(mod)
|
||||
TOK_KEYWORD(none)
|
||||
TOK_KEYWORD(offset)
|
||||
TOK_KEYWORD(opaque)
|
||||
TOK_KEYWORD(size)
|
||||
TOK_KEYWORD(sparse)
|
||||
TOK_KEYWORD(step)
|
||||
TOK_KEYWORD(strides)
|
||||
TOK_KEYWORD(symbol)
|
||||
TOK_KEYWORD(tensor)
|
||||
TOK_KEYWORD(to)
|
||||
|
Loading…
Reference in New Issue
Block a user