diff --git a/third_party/mlir/lib/Parser/Parser.cpp b/third_party/mlir/lib/Parser/Parser.cpp index a5dd98138a2..8af99179f6f 100644 --- a/third_party/mlir/lib/Parser/Parser.cpp +++ b/third_party/mlir/lib/Parser/Parser.cpp @@ -1152,6 +1152,19 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) { return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue()); } +/// Construct a float attribute bitwise equivalent to the integer literal. +static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type, + uint64_t value) { + int width = type.getIntOrFloatBitWidth(); + APInt apInt(width, value); + if (apInt != value) { + p->emitError("hexadecimal float constant out of range for type"); + return nullptr; + } + APFloat apFloat(type.getFloatSemantics(), apInt); + return p->builder.getFloatAttr(type, apFloat); +} + /// Parse a decimal or a hexadecimal literal, which can be either an integer /// or a float attribute. Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { @@ -1188,14 +1201,7 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { } // Construct a float attribute bitwise equivalent to the integer literal. - int width = type.getIntOrFloatBitWidth(); - APInt apInt(width, *val, isNegative); - if (apInt != *val) { - emitError("hexadecimal float constant out of range for attribute"); - return nullptr; - } - APFloat apFloat(floatType.getFloatSemantics(), apInt); - return builder.getFloatAttr(type, apFloat); + return buildHexadecimalFloatLiteral(this, floatType, *val); } if (!type.isIntOrIndex()) @@ -1306,14 +1312,6 @@ private: /// parseElement([1]) -> Failure ParseResult parseElement(); - /// Parse an integer element value, returning failure if the value isn't - /// valid. - ParseResult parseIntegerElement(bool isSigned); - - /// Parse a floating-point element value, returning failure if the value isn't - /// valid. - ParseResult parseFloatElement(bool isNegative); - /// Parse a list of either lists or elements, returning the dimensions of the /// parsed sub-tensors in dims. For example: /// parseList([1, 2, 3]) -> Success, [3] @@ -1327,12 +1325,8 @@ private: /// The shape inferred from the parsed elements. SmallVector shape; - /// Storage used when parsing integer elements, this is a pair of . - std::vector> intStorage; - - /// Storage used when parsing float elements. - std::vector floatStorage; + /// Storage used when parsing elements, this is a pair of . + std::vector> storage; /// A flag that indicates the type of elements that have been parsed. llvm::Optional knownEltKind; @@ -1370,21 +1364,43 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc, ShapedType type, IntegerType eltTy) { - // Check to see if floating point values were parsed. - if (!floatStorage.empty()) { - p.emitError() << "expected integer elements, but parsed floating-point"; - return nullptr; + std::vector intElements; + intElements.reserve(storage.size()); + for (const auto &signAndToken : storage) { + bool isNegative = signAndToken.first; + const Token &token = signAndToken.second; + + // Check to see if floating point values were parsed. + if (token.is(Token::floatliteral)) { + p.emitError() << "expected integer elements, but parsed floating-point"; + return nullptr; + } + + assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && + "unexpected token type"); + if (token.isAny(Token::kw_true, Token::kw_false)) { + if (!eltTy.isInteger(1)) + p.emitError() << "expected i1 type for 'true' or 'false' values"; + APInt apInt(eltTy.getWidth(), token.is(Token::kw_true), + /*isSigned=*/false); + intElements.push_back(apInt); + continue; + } + + // Create APInt values for each element with the correct bitwidth. + auto val = token.getUInt64IntegerValue(); + if (!val.hasValue() || (isNegative ? (int64_t)-val.getValue() >= 0 + : (int64_t)val.getValue() < 0)) { + p.emitError(token.getLoc(), + "integer constant out of range for attribute"); + return nullptr; + } + APInt apInt(eltTy.getWidth(), val.getValue(), isNegative); + if (apInt != val.getValue()) + return (p.emitError("integer constant out of range for type"), nullptr); + intElements.push_back(isNegative ? -apInt : apInt); } - // Create APInt values for each element with the correct bitwidth. - std::vector intElements; - intElements.reserve(intStorage.size()); - for (auto &signAndValue : intStorage) { - APInt apInt(eltTy.getWidth(), signAndValue.second, signAndValue.first); - if (apInt != signAndValue.second) - return (p.emitError("integer constant out of range for type"), nullptr); - intElements.push_back(signAndValue.first ? -apInt : apInt); - } return DenseElementsAttr::get(type, intElements); } @@ -1392,109 +1408,73 @@ DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc, DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc, ShapedType type, FloatType eltTy) { - // Check to see if integer values were parsed. - if (!intStorage.empty()) { - p.emitError() << "expected floating-point elements, but parsed integer"; - return nullptr; + std::vector floatValues; + floatValues.reserve(storage.size()); + for (const auto &signAndToken : storage) { + bool isNegative = signAndToken.first; + const Token &token = signAndToken.second; + + // Handle hexadecimal float literals. + if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { + if (isNegative) { + p.emitError(token.getLoc()) + << "hexadecimal float literal should not have a leading minus"; + return nullptr; + } + auto val = token.getUInt64IntegerValue(); + if (!val.hasValue()) { + p.emitError("hexadecimal float constant out of range for attribute"); + return nullptr; + } + FloatAttr attr = buildHexadecimalFloatLiteral(&p, eltTy, *val); + if (!attr) + return nullptr; + floatValues.push_back(attr); + continue; + } + + // Check to see if any decimal integers or booleans were parsed. + if (!token.is(Token::floatliteral)) { + p.emitError() << "expected floating-point elements, but parsed integer"; + return nullptr; + } + + // Build the float values from tokens. + auto val = token.getFloatingPointValue(); + if (!val.hasValue()) { + p.emitError("floating point value too large for attribute"); + return nullptr; + } + floatValues.push_back(FloatAttr::get(eltTy, isNegative ? -*val : *val)); } - // Build the float values from the raw integer storage. - std::vector floatValues; - floatValues.reserve(floatStorage.size()); - for (auto &elt : floatStorage) - floatValues.push_back(FloatAttr::get(eltTy, elt)); return DenseElementsAttr::get(type, floatValues); } ParseResult TensorLiteralParser::parseElement() { - auto loc = p.getToken().getLoc(); - - ElementKind newEltKind; switch (p.getToken().getKind()) { // Parse a boolean element. case Token::kw_true: case Token::kw_false: - intStorage.emplace_back(false, p.getToken().is(Token::kw_true)); + case Token::floatliteral: + case Token::integer: + storage.emplace_back(/*isNegative=*/false, p.getToken()); p.consumeToken(); - newEltKind = ElementKind::Boolean; break; // Parse a signed integer or a negative floating-point element. case Token::minus: p.consumeToken(Token::minus); - - // Otherwise, check for an integer value. - if (p.getToken().is(Token::integer)) { - if (parseIntegerElement(/*isSigned=*/true)) - return failure(); - newEltKind = ElementKind::Integer; - - // Otherwise, check for a floating point value. - } else if (p.getToken().is(Token::floatliteral)) { - if (parseFloatElement(/*isNegative=*/true)) - return failure(); - newEltKind = ElementKind::Float; - } else { + if (!p.getToken().isAny(Token::floatliteral, Token::integer)) return p.emitError("expected integer or floating point literal"); - } + storage.emplace_back(/*isNegative=*/true, p.getToken()); + p.consumeToken(); break; - // Parse a floating-point element. - case Token::floatliteral: - if (parseFloatElement(/*isNegative=*/false)) - return failure(); - newEltKind = ElementKind::Float; - break; - - // Parse an integer element. - case Token::integer: - if (parseIntegerElement(/*isSigned=*/false)) - return failure(); - newEltKind = ElementKind::Integer; - break; default: return p.emitError("expected element literal of primitive type"); } - // Check to see if the element kind has changed from the previously inferred - // type. - if (!knownEltKind) - knownEltKind = newEltKind; - else if (knownEltKind != newEltKind) - return p.emitError(loc) - << "tensor element type differs from previously inferred type, with " - "old type of " - << getElementKindStr(*knownEltKind) << ", and new type of " - << getElementKindStr(newEltKind); - return success(); -} - -/// Parse an integer element value, returning failure if the value isn't -/// valid. -ParseResult TensorLiteralParser::parseIntegerElement(bool isSigned) { - // Check that the integer value is valid. - auto val = p.getToken().getUInt64IntegerValue(); - if (!val.hasValue() || - (isSigned ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0)) - return p.emitError("integer constant out of range for attribute"); - - // Add it to the storage. - p.consumeToken(Token::integer); - intStorage.emplace_back(isSigned, *val); - return success(); -} - -/// Parse a floating-point element value, returning failure if the value isn't -/// valid. -ParseResult TensorLiteralParser::parseFloatElement(bool isNegative) { - // Check that the float value is valid. - auto val = p.getToken().getFloatingPointValue(); - if (!val.hasValue()) - return p.emitError("floating point value too large for attribute"); - - // Add it to the storage. - p.consumeToken(Token::floatliteral); - floatStorage.push_back(isNegative ? -val.getValue() : val.getValue()); return success(); }