Add syntactic sugar for strided memref parsing.

This CL implements the last remaining bit of the [strided memref proposal](https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/MaL8m2nXuio).

The syntax is a bit more explicit than what was originally proposed and resembles:
  `memref<?x?xf32, offset: 0 strides: [?, 1]>`

Nonnegative strides and offsets are currently supported. Future extensions will include negative strides.

This also gives a concrete example of syntactic sugar for the ([RFC] Proposed Changes to MemRef and Tensor MLIR Types)[https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/-wKHANzDNTg].

The underlying implementation still uses AffineMap layout.

PiperOrigin-RevId: 272717437
This commit is contained in:
Nicolas Vasilache 2019-10-03 12:33:47 -07:00 committed by TensorFlower Gardener
parent fbdc707d14
commit 8ce19dbfef
3 changed files with 104 additions and 13 deletions

View File

@ -492,7 +492,13 @@ static void extractStrides(AffineExpr e, MutableArrayRef<int64_t> strides,
return;
}
if (bin.getKind() == AffineExprKind::Mul) {
auto dim = bin.getLHS().cast<AffineDimExpr>();
// LHS may be more complex than just a single dim (e.g. multiple syms and
// dims). Bail out for now and revisit when we have evidence this is needed.
auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
if (!dim) {
failed = true;
return;
}
auto cst = bin.getRHS().dyn_cast<AffineConstantExpr>();
if (!cst) {
strides[dim.getPosition()] = MemRefType::kDynamicStrideOrOffset;

View File

@ -38,7 +38,6 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
@ -211,6 +210,14 @@ public:
bool allowDynamic = true);
ParseResult parseXInDimensionList();
/// Parse strided layout specification.
ParseResult parseStridedLayout(int64_t &offset,
SmallVectorImpl<int64_t> &strides);
// Parse a brace-delimiter list of comma-separated integers with `?` as an
// unknown marker.
ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions);
//===--------------------------------------------------------------------===//
// Attribute Parsing
//===--------------------------------------------------------------------===//
@ -634,6 +641,40 @@ Type Parser::parseFunctionType() {
return builder.getFunctionType(arguments, results);
}
/// Parse the offset and strides from a strided layout specification.
///
/// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
///
ParseResult Parser::parseStridedLayout(int64_t &offset,
SmallVectorImpl<int64_t> &strides) {
// Parse offset.
consumeToken(Token::kw_offset);
if (!consumeIf(Token::colon))
return emitError("expected colon after `offset` keyword");
auto maybeOffset = getToken().getUnsignedIntegerValue();
bool question = getToken().is(Token::question);
if (!maybeOffset && !question)
return emitError("invalid offset");
offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue())
: MemRefType::kDynamicStrideOrOffset;
consumeToken();
if (!consumeIf(Token::comma))
return emitError("expected comma after offset value");
// Parse stride list.
if (!consumeIf(Token::kw_strides))
return emitError("expected `strides` keyword after offset specification");
if (!consumeIf(Token::colon))
return emitError("expected colon after `strides` keyword");
if (failed(parseStrideList(strides)))
return emitError("invalid braces-enclosed stride list");
if (llvm::any_of(strides, [](int64_t st) { return st == 0; }))
return emitError("invalid memref stride");
return success();
}
/// Parse a memref type.
///
/// memref-type ::= `memref` `<` dimension-list-ranked type
@ -675,18 +716,28 @@ Type Parser::parseMemRefType() {
consumeToken(Token::integer);
parsedMemorySpace = true;
} else {
// Parse affine map.
if (parsedMemorySpace)
return emitError("affine map after memory space in memref type");
auto affineMap = parseAttribute();
if (!affineMap)
return failure();
// Verify that the parsed attribute is an affine map.
if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>())
affineMapComposition.push_back(affineMapAttr.getValue());
else
return emitError("expected affine map in memref type");
return emitError("expected memory space to be last in memref type");
if (getToken().is(Token::kw_offset)) {
int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(parseStridedLayout(offset, strides)))
return failure();
// Construct strided affine map.
auto map = makeStridedLinearLayoutMap(strides, offset,
elementType.getContext());
affineMapComposition.push_back(map);
} else {
// Parse affine map.
auto affineMap = parseAttribute();
if (!affineMap)
return failure();
// Verify that the parsed attribute is an affine map.
if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>())
affineMapComposition.push_back(affineMapAttr.getValue());
else
return emitError("expected affine map in memref type");
}
}
return success();
};
@ -935,6 +986,38 @@ ParseResult Parser::parseXInDimensionList() {
return success();
}
// Parse a comma-separated list of dimensions, possibly empty:
// stride-list ::= `[` (dimension (`,` dimension)*)? `]`
ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) {
if (!consumeIf(Token::l_square))
return failure();
// Empty list early exit.
if (consumeIf(Token::r_square))
return success();
while (true) {
if (consumeIf(Token::question)) {
dimensions.push_back(MemRefType::kDynamicStrideOrOffset);
} else {
// This must be an integer value.
int64_t val;
if (getToken().getSpelling().getAsInteger(10, val))
return emitError("invalid integer value: ") << getToken().getSpelling();
// Make sure it is not the one value for `?`.
if (ShapedType::isDynamic(val))
return emitError("invalid integer value: ")
<< getToken().getSpelling()
<< ", use `?` to specify a dynamic dimension";
dimensions.push_back(val);
consumeToken(Token::integer);
}
if (!consumeIf(Token::comma))
break;
}
if (!consumeIf(Token::r_square))
return failure();
return success();
}
//===----------------------------------------------------------------------===//
// Attribute parsing.
//===----------------------------------------------------------------------===//

View File

@ -110,10 +110,12 @@ TOK_KEYWORD(memref)
TOK_KEYWORD(min)
TOK_KEYWORD(mod)
TOK_KEYWORD(none)
TOK_KEYWORD(offset)
TOK_KEYWORD(opaque)
TOK_KEYWORD(size)
TOK_KEYWORD(sparse)
TOK_KEYWORD(step)
TOK_KEYWORD(strides)
TOK_KEYWORD(symbol)
TOK_KEYWORD(tensor)
TOK_KEYWORD(to)