From 5ff83f98c8296690907c29a50000032ccc72b55a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 30 Jul 2019 14:24:30 -0700 Subject: [PATCH] Support hexadecimal floats in tensor literals Extend the recently introduced support for hexadecimal float literals to tensor literals, which may also contain special floating point values such as infinities and NaNs. Modify TensorLiteralParser to store the list of tokens representing values until the type is parsed instead of trying to guess the tensor element type from the token kinds (hexadecimal values can be either integers or floats, and can be mixed with both). Maintain the error reports as close as possible to the existing implementation to avoid disturbing the tests. They can be improved in a separate clean-up if deemed necessary. PiperOrigin-RevId: 260794716 --- third_party/mlir/lib/Parser/Parser.cpp | 210 +++++++++++-------------- 1 file changed, 95 insertions(+), 115 deletions(-) diff --git a/third_party/mlir/lib/Parser/Parser.cpp b/third_party/mlir/lib/Parser/Parser.cpp index a5dd98138a2..8af99179f6f 100644 --- a/third_party/mlir/lib/Parser/Parser.cpp +++ b/third_party/mlir/lib/Parser/Parser.cpp @@ -1152,6 +1152,19 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) { return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue()); } +/// Construct a float attribute bitwise equivalent to the integer literal. +static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type, + uint64_t value) { + int width = type.getIntOrFloatBitWidth(); + APInt apInt(width, value); + if (apInt != value) { + p->emitError("hexadecimal float constant out of range for type"); + return nullptr; + } + APFloat apFloat(type.getFloatSemantics(), apInt); + return p->builder.getFloatAttr(type, apFloat); +} + /// Parse a decimal or a hexadecimal literal, which can be either an integer /// or a float attribute. Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { @@ -1188,14 +1201,7 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { } // Construct a float attribute bitwise equivalent to the integer literal. - int width = type.getIntOrFloatBitWidth(); - APInt apInt(width, *val, isNegative); - if (apInt != *val) { - emitError("hexadecimal float constant out of range for attribute"); - return nullptr; - } - APFloat apFloat(floatType.getFloatSemantics(), apInt); - return builder.getFloatAttr(type, apFloat); + return buildHexadecimalFloatLiteral(this, floatType, *val); } if (!type.isIntOrIndex()) @@ -1306,14 +1312,6 @@ private: /// parseElement([1]) -> Failure ParseResult parseElement(); - /// Parse an integer element value, returning failure if the value isn't - /// valid. - ParseResult parseIntegerElement(bool isSigned); - - /// Parse a floating-point element value, returning failure if the value isn't - /// valid. - ParseResult parseFloatElement(bool isNegative); - /// Parse a list of either lists or elements, returning the dimensions of the /// parsed sub-tensors in dims. For example: /// parseList([1, 2, 3]) -> Success, [3] @@ -1327,12 +1325,8 @@ private: /// The shape inferred from the parsed elements. SmallVector shape; - /// Storage used when parsing integer elements, this is a pair of . - std::vector> intStorage; - - /// Storage used when parsing float elements. - std::vector floatStorage; + /// Storage used when parsing elements, this is a pair of . + std::vector> storage; /// A flag that indicates the type of elements that have been parsed. llvm::Optional knownEltKind; @@ -1370,21 +1364,43 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc, ShapedType type, IntegerType eltTy) { - // Check to see if floating point values were parsed. - if (!floatStorage.empty()) { - p.emitError() << "expected integer elements, but parsed floating-point"; - return nullptr; + std::vector intElements; + intElements.reserve(storage.size()); + for (const auto &signAndToken : storage) { + bool isNegative = signAndToken.first; + const Token &token = signAndToken.second; + + // Check to see if floating point values were parsed. + if (token.is(Token::floatliteral)) { + p.emitError() << "expected integer elements, but parsed floating-point"; + return nullptr; + } + + assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && + "unexpected token type"); + if (token.isAny(Token::kw_true, Token::kw_false)) { + if (!eltTy.isInteger(1)) + p.emitError() << "expected i1 type for 'true' or 'false' values"; + APInt apInt(eltTy.getWidth(), token.is(Token::kw_true), + /*isSigned=*/false); + intElements.push_back(apInt); + continue; + } + + // Create APInt values for each element with the correct bitwidth. + auto val = token.getUInt64IntegerValue(); + if (!val.hasValue() || (isNegative ? (int64_t)-val.getValue() >= 0 + : (int64_t)val.getValue() < 0)) { + p.emitError(token.getLoc(), + "integer constant out of range for attribute"); + return nullptr; + } + APInt apInt(eltTy.getWidth(), val.getValue(), isNegative); + if (apInt != val.getValue()) + return (p.emitError("integer constant out of range for type"), nullptr); + intElements.push_back(isNegative ? -apInt : apInt); } - // Create APInt values for each element with the correct bitwidth. - std::vector intElements; - intElements.reserve(intStorage.size()); - for (auto &signAndValue : intStorage) { - APInt apInt(eltTy.getWidth(), signAndValue.second, signAndValue.first); - if (apInt != signAndValue.second) - return (p.emitError("integer constant out of range for type"), nullptr); - intElements.push_back(signAndValue.first ? -apInt : apInt); - } return DenseElementsAttr::get(type, intElements); } @@ -1392,109 +1408,73 @@ DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc, DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc, ShapedType type, FloatType eltTy) { - // Check to see if integer values were parsed. - if (!intStorage.empty()) { - p.emitError() << "expected floating-point elements, but parsed integer"; - return nullptr; + std::vector floatValues; + floatValues.reserve(storage.size()); + for (const auto &signAndToken : storage) { + bool isNegative = signAndToken.first; + const Token &token = signAndToken.second; + + // Handle hexadecimal float literals. + if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { + if (isNegative) { + p.emitError(token.getLoc()) + << "hexadecimal float literal should not have a leading minus"; + return nullptr; + } + auto val = token.getUInt64IntegerValue(); + if (!val.hasValue()) { + p.emitError("hexadecimal float constant out of range for attribute"); + return nullptr; + } + FloatAttr attr = buildHexadecimalFloatLiteral(&p, eltTy, *val); + if (!attr) + return nullptr; + floatValues.push_back(attr); + continue; + } + + // Check to see if any decimal integers or booleans were parsed. + if (!token.is(Token::floatliteral)) { + p.emitError() << "expected floating-point elements, but parsed integer"; + return nullptr; + } + + // Build the float values from tokens. + auto val = token.getFloatingPointValue(); + if (!val.hasValue()) { + p.emitError("floating point value too large for attribute"); + return nullptr; + } + floatValues.push_back(FloatAttr::get(eltTy, isNegative ? -*val : *val)); } - // Build the float values from the raw integer storage. - std::vector floatValues; - floatValues.reserve(floatStorage.size()); - for (auto &elt : floatStorage) - floatValues.push_back(FloatAttr::get(eltTy, elt)); return DenseElementsAttr::get(type, floatValues); } ParseResult TensorLiteralParser::parseElement() { - auto loc = p.getToken().getLoc(); - - ElementKind newEltKind; switch (p.getToken().getKind()) { // Parse a boolean element. case Token::kw_true: case Token::kw_false: - intStorage.emplace_back(false, p.getToken().is(Token::kw_true)); + case Token::floatliteral: + case Token::integer: + storage.emplace_back(/*isNegative=*/false, p.getToken()); p.consumeToken(); - newEltKind = ElementKind::Boolean; break; // Parse a signed integer or a negative floating-point element. case Token::minus: p.consumeToken(Token::minus); - - // Otherwise, check for an integer value. - if (p.getToken().is(Token::integer)) { - if (parseIntegerElement(/*isSigned=*/true)) - return failure(); - newEltKind = ElementKind::Integer; - - // Otherwise, check for a floating point value. - } else if (p.getToken().is(Token::floatliteral)) { - if (parseFloatElement(/*isNegative=*/true)) - return failure(); - newEltKind = ElementKind::Float; - } else { + if (!p.getToken().isAny(Token::floatliteral, Token::integer)) return p.emitError("expected integer or floating point literal"); - } + storage.emplace_back(/*isNegative=*/true, p.getToken()); + p.consumeToken(); break; - // Parse a floating-point element. - case Token::floatliteral: - if (parseFloatElement(/*isNegative=*/false)) - return failure(); - newEltKind = ElementKind::Float; - break; - - // Parse an integer element. - case Token::integer: - if (parseIntegerElement(/*isSigned=*/false)) - return failure(); - newEltKind = ElementKind::Integer; - break; default: return p.emitError("expected element literal of primitive type"); } - // Check to see if the element kind has changed from the previously inferred - // type. - if (!knownEltKind) - knownEltKind = newEltKind; - else if (knownEltKind != newEltKind) - return p.emitError(loc) - << "tensor element type differs from previously inferred type, with " - "old type of " - << getElementKindStr(*knownEltKind) << ", and new type of " - << getElementKindStr(newEltKind); - return success(); -} - -/// Parse an integer element value, returning failure if the value isn't -/// valid. -ParseResult TensorLiteralParser::parseIntegerElement(bool isSigned) { - // Check that the integer value is valid. - auto val = p.getToken().getUInt64IntegerValue(); - if (!val.hasValue() || - (isSigned ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0)) - return p.emitError("integer constant out of range for attribute"); - - // Add it to the storage. - p.consumeToken(Token::integer); - intStorage.emplace_back(isSigned, *val); - return success(); -} - -/// Parse a floating-point element value, returning failure if the value isn't -/// valid. -ParseResult TensorLiteralParser::parseFloatElement(bool isNegative) { - // Check that the float value is valid. - auto val = p.getToken().getFloatingPointValue(); - if (!val.hasValue()) - return p.emitError("floating point value too large for attribute"); - - // Add it to the storage. - p.consumeToken(Token::floatliteral); - floatStorage.push_back(isNegative ? -val.getValue() : val.getValue()); return success(); }