Add syntactic sugar for strided memref parsing.
This CL implements the last remaining bit of the [strided memref proposal](https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/MaL8m2nXuio). The syntax is a bit more explicit than what was originally proposed and resembles: `memref<?x?xf32, offset: 0 strides: [?, 1]>` Nonnegative strides and offsets are currently supported. Future extensions will include negative strides. This also gives a concrete example of syntactic sugar for the ([RFC] Proposed Changes to MemRef and Tensor MLIR Types)[https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/-wKHANzDNTg]. The underlying implementation still uses AffineMap layout. PiperOrigin-RevId: 272717437
This commit is contained in:
parent
fbdc707d14
commit
8ce19dbfef
8
third_party/mlir/lib/IR/StandardTypes.cpp
vendored
8
third_party/mlir/lib/IR/StandardTypes.cpp
vendored
@ -492,7 +492,13 @@ static void extractStrides(AffineExpr e, MutableArrayRef<int64_t> strides,
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (bin.getKind() == AffineExprKind::Mul) {
|
if (bin.getKind() == AffineExprKind::Mul) {
|
||||||
auto dim = bin.getLHS().cast<AffineDimExpr>();
|
// LHS may be more complex than just a single dim (e.g. multiple syms and
|
||||||
|
// dims). Bail out for now and revisit when we have evidence this is needed.
|
||||||
|
auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
|
||||||
|
if (!dim) {
|
||||||
|
failed = true;
|
||||||
|
return;
|
||||||
|
}
|
||||||
auto cst = bin.getRHS().dyn_cast<AffineConstantExpr>();
|
auto cst = bin.getRHS().dyn_cast<AffineConstantExpr>();
|
||||||
if (!cst) {
|
if (!cst) {
|
||||||
strides[dim.getPosition()] = MemRefType::kDynamicStrideOrOffset;
|
strides[dim.getPosition()] = MemRefType::kDynamicStrideOrOffset;
|
||||||
|
91
third_party/mlir/lib/Parser/Parser.cpp
vendored
91
third_party/mlir/lib/Parser/Parser.cpp
vendored
@ -38,7 +38,6 @@
|
|||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/StringSet.h"
|
#include "llvm/ADT/StringSet.h"
|
||||||
#include "llvm/ADT/bit.h"
|
#include "llvm/ADT/bit.h"
|
||||||
#include "llvm/Support/MemoryBuffer.h"
|
|
||||||
#include "llvm/Support/PrettyStackTrace.h"
|
#include "llvm/Support/PrettyStackTrace.h"
|
||||||
#include "llvm/Support/SMLoc.h"
|
#include "llvm/Support/SMLoc.h"
|
||||||
#include "llvm/Support/SourceMgr.h"
|
#include "llvm/Support/SourceMgr.h"
|
||||||
@ -211,6 +210,14 @@ public:
|
|||||||
bool allowDynamic = true);
|
bool allowDynamic = true);
|
||||||
ParseResult parseXInDimensionList();
|
ParseResult parseXInDimensionList();
|
||||||
|
|
||||||
|
/// Parse strided layout specification.
|
||||||
|
ParseResult parseStridedLayout(int64_t &offset,
|
||||||
|
SmallVectorImpl<int64_t> &strides);
|
||||||
|
|
||||||
|
// Parse a brace-delimiter list of comma-separated integers with `?` as an
|
||||||
|
// unknown marker.
|
||||||
|
ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions);
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
// Attribute Parsing
|
// Attribute Parsing
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
@ -634,6 +641,40 @@ Type Parser::parseFunctionType() {
|
|||||||
return builder.getFunctionType(arguments, results);
|
return builder.getFunctionType(arguments, results);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Parse the offset and strides from a strided layout specification.
|
||||||
|
///
|
||||||
|
/// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
|
||||||
|
///
|
||||||
|
ParseResult Parser::parseStridedLayout(int64_t &offset,
|
||||||
|
SmallVectorImpl<int64_t> &strides) {
|
||||||
|
// Parse offset.
|
||||||
|
consumeToken(Token::kw_offset);
|
||||||
|
if (!consumeIf(Token::colon))
|
||||||
|
return emitError("expected colon after `offset` keyword");
|
||||||
|
auto maybeOffset = getToken().getUnsignedIntegerValue();
|
||||||
|
bool question = getToken().is(Token::question);
|
||||||
|
if (!maybeOffset && !question)
|
||||||
|
return emitError("invalid offset");
|
||||||
|
offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue())
|
||||||
|
: MemRefType::kDynamicStrideOrOffset;
|
||||||
|
consumeToken();
|
||||||
|
|
||||||
|
if (!consumeIf(Token::comma))
|
||||||
|
return emitError("expected comma after offset value");
|
||||||
|
|
||||||
|
// Parse stride list.
|
||||||
|
if (!consumeIf(Token::kw_strides))
|
||||||
|
return emitError("expected `strides` keyword after offset specification");
|
||||||
|
if (!consumeIf(Token::colon))
|
||||||
|
return emitError("expected colon after `strides` keyword");
|
||||||
|
if (failed(parseStrideList(strides)))
|
||||||
|
return emitError("invalid braces-enclosed stride list");
|
||||||
|
if (llvm::any_of(strides, [](int64_t st) { return st == 0; }))
|
||||||
|
return emitError("invalid memref stride");
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
/// Parse a memref type.
|
/// Parse a memref type.
|
||||||
///
|
///
|
||||||
/// memref-type ::= `memref` `<` dimension-list-ranked type
|
/// memref-type ::= `memref` `<` dimension-list-ranked type
|
||||||
@ -675,19 +716,29 @@ Type Parser::parseMemRefType() {
|
|||||||
consumeToken(Token::integer);
|
consumeToken(Token::integer);
|
||||||
parsedMemorySpace = true;
|
parsedMemorySpace = true;
|
||||||
} else {
|
} else {
|
||||||
// Parse affine map.
|
|
||||||
if (parsedMemorySpace)
|
if (parsedMemorySpace)
|
||||||
return emitError("affine map after memory space in memref type");
|
return emitError("expected memory space to be last in memref type");
|
||||||
|
if (getToken().is(Token::kw_offset)) {
|
||||||
|
int64_t offset;
|
||||||
|
SmallVector<int64_t, 4> strides;
|
||||||
|
if (failed(parseStridedLayout(offset, strides)))
|
||||||
|
return failure();
|
||||||
|
// Construct strided affine map.
|
||||||
|
auto map = makeStridedLinearLayoutMap(strides, offset,
|
||||||
|
elementType.getContext());
|
||||||
|
affineMapComposition.push_back(map);
|
||||||
|
} else {
|
||||||
|
// Parse affine map.
|
||||||
auto affineMap = parseAttribute();
|
auto affineMap = parseAttribute();
|
||||||
if (!affineMap)
|
if (!affineMap)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Verify that the parsed attribute is an affine map.
|
// Verify that the parsed attribute is an affine map.
|
||||||
if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>())
|
if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>())
|
||||||
affineMapComposition.push_back(affineMapAttr.getValue());
|
affineMapComposition.push_back(affineMapAttr.getValue());
|
||||||
else
|
else
|
||||||
return emitError("expected affine map in memref type");
|
return emitError("expected affine map in memref type");
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return success();
|
return success();
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -935,6 +986,38 @@ ParseResult Parser::parseXInDimensionList() {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parse a comma-separated list of dimensions, possibly empty:
|
||||||
|
// stride-list ::= `[` (dimension (`,` dimension)*)? `]`
|
||||||
|
ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) {
|
||||||
|
if (!consumeIf(Token::l_square))
|
||||||
|
return failure();
|
||||||
|
// Empty list early exit.
|
||||||
|
if (consumeIf(Token::r_square))
|
||||||
|
return success();
|
||||||
|
while (true) {
|
||||||
|
if (consumeIf(Token::question)) {
|
||||||
|
dimensions.push_back(MemRefType::kDynamicStrideOrOffset);
|
||||||
|
} else {
|
||||||
|
// This must be an integer value.
|
||||||
|
int64_t val;
|
||||||
|
if (getToken().getSpelling().getAsInteger(10, val))
|
||||||
|
return emitError("invalid integer value: ") << getToken().getSpelling();
|
||||||
|
// Make sure it is not the one value for `?`.
|
||||||
|
if (ShapedType::isDynamic(val))
|
||||||
|
return emitError("invalid integer value: ")
|
||||||
|
<< getToken().getSpelling()
|
||||||
|
<< ", use `?` to specify a dynamic dimension";
|
||||||
|
dimensions.push_back(val);
|
||||||
|
consumeToken(Token::integer);
|
||||||
|
}
|
||||||
|
if (!consumeIf(Token::comma))
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (!consumeIf(Token::r_square))
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Attribute parsing.
|
// Attribute parsing.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
2
third_party/mlir/lib/Parser/TokenKinds.def
vendored
2
third_party/mlir/lib/Parser/TokenKinds.def
vendored
@ -110,10 +110,12 @@ TOK_KEYWORD(memref)
|
|||||||
TOK_KEYWORD(min)
|
TOK_KEYWORD(min)
|
||||||
TOK_KEYWORD(mod)
|
TOK_KEYWORD(mod)
|
||||||
TOK_KEYWORD(none)
|
TOK_KEYWORD(none)
|
||||||
|
TOK_KEYWORD(offset)
|
||||||
TOK_KEYWORD(opaque)
|
TOK_KEYWORD(opaque)
|
||||||
TOK_KEYWORD(size)
|
TOK_KEYWORD(size)
|
||||||
TOK_KEYWORD(sparse)
|
TOK_KEYWORD(sparse)
|
||||||
TOK_KEYWORD(step)
|
TOK_KEYWORD(step)
|
||||||
|
TOK_KEYWORD(strides)
|
||||||
TOK_KEYWORD(symbol)
|
TOK_KEYWORD(symbol)
|
||||||
TOK_KEYWORD(tensor)
|
TOK_KEYWORD(tensor)
|
||||||
TOK_KEYWORD(to)
|
TOK_KEYWORD(to)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user