MLIR does not have support for parsing special floating point values such as infinities and NaNs. If programmatically constructed, these values are printed as NaN and (+-)Inf and cannot be parsed back. Add parser support for hexadecimal literals in float attributes, following LLVM IR. The literal corresponds to the in-memory representation of the floating point value. IEEE 754 defines a range of possible values for NaNs; storing the bitwise representation allows MLIR to properly round-trip NaNs whose significands have different bit values. The initial version of this commit was missing support for float literals that used to be printed in decimal notation as a fallback: they ended up being printed in hexadecimal format, which had become the fallback for special values. The decimal fallback behavior was not exercised by tests. It is now reinstated and tested by the newly added test @f32_potential_precision_loss in parser.mlir. PiperOrigin-RevId: 260790900
1778 lines
54 KiB
C++
1778 lines
54 KiB
C++
//===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
|
|
//
|
|
// Copyright 2019 The MLIR Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
// =============================================================================
|
|
//
|
|
// This file implements the MLIR AsmPrinter class, which is used to implement
|
|
// the various print() methods on the core IR objects.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/Dialect.h"
|
|
#include "mlir/IR/Function.h"
|
|
#include "mlir/IR/IntegerSet.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/Module.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/IR/StandardTypes.h"
|
|
#include "mlir/Support/STLExtras.h"
|
|
#include "llvm/ADT/APFloat.h"
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/MapVector.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/ScopedHashTable.h"
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/ADT/SmallString.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/ADT/StringSet.h"
|
|
#include "llvm/Support/CommandLine.h"
|
|
#include "llvm/Support/Regex.h"
|
|
using namespace mlir;
|
|
|
|
void Identifier::print(raw_ostream &os) const { os << str(); }
|
|
|
|
void Identifier::dump() const { print(llvm::errs()); }
|
|
|
|
void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
|
|
|
|
void OperationName::dump() const { print(llvm::errs()); }
|
|
|
|
// Out-of-line definition of the OpAsmPrinter destructor.
OpAsmPrinter::~OpAsmPrinter() = default;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ModuleState
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// TODO(riverriddle) Rethink this flag when we have a pass that can remove debug
|
|
// info or when we have a system for printer flags.
|
|
static llvm::cl::opt<bool>
|
|
shouldPrintDebugInfoOpt("mlir-print-debuginfo",
|
|
llvm::cl::desc("Print debug info in MLIR output"),
|
|
llvm::cl::init(false));
|
|
|
|
static llvm::cl::opt<bool> printPrettyDebugInfo(
|
|
"mlir-pretty-debuginfo",
|
|
llvm::cl::desc("Print pretty debug info in MLIR output"),
|
|
llvm::cl::init(false));
|
|
|
|
// Use the generic op output form in the operation printer even if the custom
|
|
// form is defined.
|
|
static llvm::cl::opt<bool>
|
|
printGenericOpForm("mlir-print-op-generic",
|
|
llvm::cl::desc("Print the generic op form"),
|
|
llvm::cl::init(false), llvm::cl::Hidden);
|
|
|
|
namespace {
|
|
/// A special index constant used for non-kind attribute aliases.
|
|
static constexpr int kNonAttrKindAlias = -1;
|
|
|
|
class ModuleState {
|
|
public:
|
|
/// This is the current context if it is knowable, otherwise this is null.
|
|
MLIRContext *const context;
|
|
|
|
explicit ModuleState(MLIRContext *context) : context(context) {}
|
|
|
|
// Initializes module state, populating affine map state.
|
|
void initialize(Operation *op);
|
|
|
|
Twine getAttributeAlias(Attribute attr) const {
|
|
auto alias = attrToAlias.find(attr);
|
|
if (alias == attrToAlias.end())
|
|
return Twine();
|
|
|
|
// Return the alias for this attribute, along with the index if this was
|
|
// generated by a kind alias.
|
|
int kindIndex = alias->second.second;
|
|
return alias->second.first +
|
|
(kindIndex == kNonAttrKindAlias ? Twine() : Twine(kindIndex));
|
|
}
|
|
|
|
void printAttributeAliases(raw_ostream &os) const {
|
|
auto printAlias = [&](StringRef alias, Attribute attr, int index) {
|
|
os << '#' << alias;
|
|
if (index != kNonAttrKindAlias)
|
|
os << index;
|
|
os << " = " << attr << '\n';
|
|
};
|
|
|
|
// Print all of the attribute kind aliases.
|
|
for (auto &kindAlias : attrKindToAlias) {
|
|
for (unsigned i = 0, e = kindAlias.second.second.size(); i != e; ++i)
|
|
printAlias(kindAlias.second.first, kindAlias.second.second[i], i);
|
|
os << "\n";
|
|
}
|
|
|
|
// In a second pass print all of the remaining attribute aliases that aren't
|
|
// kind aliases.
|
|
for (Attribute attr : usedAttributes) {
|
|
auto alias = attrToAlias.find(attr);
|
|
if (alias != attrToAlias.end() &&
|
|
alias->second.second == kNonAttrKindAlias)
|
|
printAlias(alias->second.first, attr, alias->second.second);
|
|
}
|
|
}
|
|
|
|
StringRef getTypeAlias(Type ty) const { return typeToAlias.lookup(ty); }
|
|
|
|
void printTypeAliases(raw_ostream &os) const {
|
|
for (Type type : usedTypes) {
|
|
auto alias = typeToAlias.find(type);
|
|
if (alias != typeToAlias.end())
|
|
os << '!' << alias->second << " = type " << type << '\n';
|
|
}
|
|
}
|
|
|
|
private:
|
|
void recordAttributeReference(Attribute attr) {
|
|
// Don't recheck attributes that have already been seen or those that
|
|
// already have an alias.
|
|
if (!usedAttributes.insert(attr) || attrToAlias.count(attr))
|
|
return;
|
|
|
|
// If this attribute kind has an alias, then record one for this attribute.
|
|
auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
|
|
if (alias == attrKindToAlias.end())
|
|
return;
|
|
std::pair<StringRef, int> attrAlias(alias->second.first,
|
|
alias->second.second.size());
|
|
attrToAlias.insert({attr, attrAlias});
|
|
alias->second.second.push_back(attr);
|
|
}
|
|
|
|
void recordTypeReference(Type ty) { usedTypes.insert(ty); }
|
|
|
|
// Visit functions.
|
|
void visitOperation(Operation *op);
|
|
void visitType(Type type);
|
|
void visitAttribute(Attribute attr);
|
|
|
|
// Initialize symbol aliases.
|
|
void initializeSymbolAliases();
|
|
|
|
/// Set of attributes known to be used within the module.
|
|
llvm::SetVector<Attribute> usedAttributes;
|
|
|
|
/// Mapping between attribute and a pair comprised of a base alias name and a
|
|
/// count suffix. If the suffix is set to -1, it is not displayed.
|
|
llvm::MapVector<Attribute, std::pair<StringRef, int>> attrToAlias;
|
|
|
|
/// Mapping between attribute kind and a pair comprised of a base alias name
|
|
/// and a unique list of attributes belonging to this kind sorted by location
|
|
/// seen in the module.
|
|
llvm::MapVector<unsigned, std::pair<StringRef, std::vector<Attribute>>>
|
|
attrKindToAlias;
|
|
|
|
/// Set of types known to be used within the module.
|
|
llvm::SetVector<Type> usedTypes;
|
|
|
|
/// A mapping between a type and a given alias.
|
|
DenseMap<Type, StringRef> typeToAlias;
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
// TODO Support visiting other types/operations when implemented.
|
|
void ModuleState::visitType(Type type) {
|
|
recordTypeReference(type);
|
|
if (auto funcType = type.dyn_cast<FunctionType>()) {
|
|
// Visit input and result types for functions.
|
|
for (auto input : funcType.getInputs())
|
|
visitType(input);
|
|
for (auto result : funcType.getResults())
|
|
visitType(result);
|
|
return;
|
|
}
|
|
if (auto memref = type.dyn_cast<MemRefType>()) {
|
|
// Visit affine maps in memref type.
|
|
for (auto map : memref.getAffineMaps())
|
|
recordAttributeReference(AffineMapAttr::get(map));
|
|
}
|
|
if (auto shapedType = type.dyn_cast<ShapedType>()) {
|
|
visitType(shapedType.getElementType());
|
|
}
|
|
}
|
|
|
|
/// Records `attr` as used and recurses into the elements of an ArrayAttr or
/// the type held by a TypeAttr.
void ModuleState::visitAttribute(Attribute attr) {
  recordAttributeReference(attr);

  if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
    visitType(typeAttr.getValue());
    return;
  }

  if (auto arrayAttr = attr.dyn_cast<ArrayAttr>())
    for (Attribute element : arrayAttr.getValue())
      visitAttribute(element);
}
|
|
|
|
/// Records the types and attributes referenced directly by `op`:
/// - its operand and result types,
/// - the argument types of the blocks in its regions,
/// - its attribute values.
/// Nested operations are not visited here; initialize() reaches them through
/// op->walk().
void ModuleState::visitOperation(Operation *op) {
  // Visit all the types used in the operation.
  for (auto type : op->getOperandTypes())
    visitType(type);
  for (auto type : op->getResultTypes())
    visitType(type);

  // Block argument types of every region attached to this operation.
  // (The reference binding `&region` was previously mis-encoded as `®ion`.)
  for (auto &region : op->getRegions())
    for (auto &block : region)
      for (auto *arg : block.getArguments())
        visitType(arg->getType());

  // Visit each of the attributes.
  for (auto elt : op->getAttrs())
    visitAttribute(elt.second);
}
|
|
|
|
// Utility to generate a function to register a symbol alias.
|
|
static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) {
|
|
assert(!name.empty() && "expected alias name to be non-empty");
|
|
// TODO(riverriddle) Assert that the provided alias name can be lexed as
|
|
// an identifier.
|
|
|
|
// Check that the alias doesn't contain a '.' character and the name is not
|
|
// already in use.
|
|
return !name.contains('.') && usedAliases.insert(name).second;
|
|
}
|
|
|
|
void ModuleState::initializeSymbolAliases() {
|
|
// Track the identifiers in use for each symbol so that the same identifier
|
|
// isn't used twice.
|
|
llvm::StringSet<> usedAliases;
|
|
|
|
// Get the currently registered dialects.
|
|
auto dialects = context->getRegisteredDialects();
|
|
|
|
// Collect the set of aliases from each dialect.
|
|
SmallVector<std::pair<unsigned, StringRef>, 8> attributeKindAliases;
|
|
SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases;
|
|
SmallVector<std::pair<Type, StringRef>, 16> typeAliases;
|
|
|
|
// AffineMap/Integer set have specific kind aliases.
|
|
attributeKindAliases.emplace_back(StandardAttributes::AffineMap, "map");
|
|
attributeKindAliases.emplace_back(StandardAttributes::IntegerSet, "set");
|
|
|
|
for (auto *dialect : dialects) {
|
|
dialect->getAttributeKindAliases(attributeKindAliases);
|
|
dialect->getAttributeAliases(attributeAliases);
|
|
dialect->getTypeAliases(typeAliases);
|
|
}
|
|
|
|
// Setup the attribute kind aliases.
|
|
StringRef alias;
|
|
unsigned attrKind;
|
|
for (auto &attrAliasPair : attributeKindAliases) {
|
|
std::tie(attrKind, alias) = attrAliasPair;
|
|
assert(!alias.empty() && "expected non-empty alias string");
|
|
if (!usedAliases.count(alias) && !alias.contains('.'))
|
|
attrKindToAlias.insert({attrKind, {alias, {}}});
|
|
}
|
|
|
|
// Clear the set of used identifiers so that the attribute kind aliases are
|
|
// just a prefix and not the full alias, i.e. there may be some overlap.
|
|
usedAliases.clear();
|
|
|
|
// Register the attribute aliases.
|
|
// Create a regex for the attribute kind alias names, these have a prefix with
|
|
// a counter appended to the end. We prevent normal aliases from having these
|
|
// names to avoid collisions.
|
|
llvm::Regex reservedAttrNames("[0-9]+$");
|
|
|
|
// Attribute value aliases.
|
|
Attribute attr;
|
|
for (auto &attrAliasPair : attributeAliases) {
|
|
std::tie(attr, alias) = attrAliasPair;
|
|
if (!reservedAttrNames.match(alias) && canRegisterAlias(alias, usedAliases))
|
|
attrToAlias.insert({attr, {alias, kNonAttrKindAlias}});
|
|
}
|
|
|
|
// Clear the set of used identifiers as types can have the same identifiers as
|
|
// affine structures.
|
|
usedAliases.clear();
|
|
|
|
// Type aliases.
|
|
for (auto &typeAliasPair : typeAliases)
|
|
if (canRegisterAlias(typeAliasPair.second, usedAliases))
|
|
typeToAlias.insert(typeAliasPair);
|
|
}
|
|
|
|
// Initializes module state, populating affine map and integer set state.
|
|
void ModuleState::initialize(Operation *op) {
|
|
// Initialize the symbol aliases.
|
|
initializeSymbolAliases();
|
|
|
|
// Visit each of the nested operations.
|
|
op->walk([&](Operation *op) { visitOperation(op); });
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ModulePrinter
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
class ModulePrinter {
|
|
public:
|
|
ModulePrinter(raw_ostream &os, ModuleState &state) : os(os), state(state) {}
|
|
explicit ModulePrinter(ModulePrinter &printer)
|
|
: os(printer.os), state(printer.state) {}
|
|
|
|
template <typename Container, typename UnaryFunctor>
|
|
inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
|
|
interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; });
|
|
}
|
|
|
|
void print(ModuleOp module);
|
|
|
|
/// Print the given attribute. If 'mayElideType' is true, some attributes are
|
|
/// printed without the type when the type matches the default used in the
|
|
/// parser (for example i64 is the default for integer attributes).
|
|
void printAttribute(Attribute attr, bool mayElideType = false);
|
|
|
|
void printType(Type type);
|
|
void printLocation(LocationAttr loc);
|
|
|
|
void printAffineMap(AffineMap map);
|
|
void printAffineExpr(
|
|
AffineExpr expr,
|
|
llvm::function_ref<void(unsigned, bool)> printValueName = nullptr);
|
|
void printAffineConstraint(AffineExpr expr, bool isEq);
|
|
void printIntegerSet(IntegerSet set);
|
|
|
|
protected:
|
|
raw_ostream &os;
|
|
ModuleState &state;
|
|
|
|
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
|
|
ArrayRef<StringRef> elidedAttrs = {});
|
|
void printTrailingLocation(Location loc);
|
|
void printLocationInternal(LocationAttr loc, bool pretty = false);
|
|
void printDenseElementsAttr(DenseElementsAttr attr);
|
|
|
|
/// This enum is used to represent the binding stength of the enclosing
|
|
/// context that an AffineExprStorage is being printed in, so we can
|
|
/// intelligently produce parens.
|
|
enum class BindingStrength {
|
|
Weak, // + and -
|
|
Strong, // All other binary operators.
|
|
};
|
|
void printAffineExprInternal(
|
|
AffineExpr expr, BindingStrength enclosingTightness,
|
|
llvm::function_ref<void(unsigned, bool)> printValueName = nullptr);
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
/// Prints " " followed by `loc`, but only when -mlir-print-debuginfo is set.
void ModulePrinter::printTrailingLocation(Location loc) {
  if (shouldPrintDebugInfoOpt) {
    os << " ";
    printLocation(loc);
  }
}
|
|
|
|
/// Prints the body of `loc`. In the default form the output is parseable,
/// e.g. `"file":1:2` or `callsite(... at ...)`. With `pretty` set, quotes
/// around file names and the `callsite(`/`fused` keywords are dropped, and
/// call sites may be split across lines.
void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
  switch (loc.getKind()) {
  case StandardAttributes::UnknownLocation:
    os << (pretty ? "[unknown]" : "unknown");
    break;

  case StandardAttributes::FileLineColLocation: {
    auto fileLoc = loc.cast<FileLineColLoc>();
    // Only the parseable form quotes the file name.
    const char *quote = pretty ? "" : "\"";
    os << quote << fileLoc.getFilename() << quote << ':' << fileLoc.getLine()
       << ':' << fileLoc.getColumn();
    break;
  }

  case StandardAttributes::NameLocation: {
    auto nameLoc = loc.cast<NameLoc>();
    os << '\"' << nameLoc.getName() << '\"';

    // A child location is appended in parentheses unless it is unknown.
    auto childLoc = nameLoc.getChildLoc();
    if (childLoc.isa<UnknownLoc>())
      break;
    os << '(';
    printLocationInternal(childLoc, pretty);
    os << ')';
    break;
  }

  case StandardAttributes::CallSiteLocation: {
    auto callSite = loc.cast<CallSiteLoc>();
    auto callee = callSite.getCallee();
    auto caller = callSite.getCaller();

    if (!pretty)
      os << "callsite(";
    printLocationInternal(callee, pretty);

    // The pretty form puts the caller on its own line, except when a named
    // callee is called from a plain file location.
    bool sameLine =
        !pretty || (callee.isa<NameLoc>() && caller.isa<FileLineColLoc>());
    os << (sameLine ? " at " : "\n at ");

    printLocationInternal(caller, pretty);
    if (!pretty)
      os << ")";
    break;
  }

  case StandardAttributes::FusedLocation: {
    auto fusedLoc = loc.cast<FusedLoc>();
    if (!pretty)
      os << "fused";
    if (auto metadata = fusedLoc.getMetadata())
      os << '<' << metadata << '>';
    os << '[';
    interleaveComma(fusedLoc.getLocations(), [&](Location member) {
      printLocationInternal(member, pretty);
    });
    os << ']';
    break;
  }
  }
}
|
|
|
|
/// Print a floating point value in a way that the parser will be able to
|
|
/// round-trip losslessly.
|
|
static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
|
|
// We would like to output the FP constant value in exponential notation,
|
|
// but we cannot do this if doing so will lose precision. Check here to
|
|
// make sure that we only output it in exponential format if we can parse
|
|
// the value back and get the same value.
|
|
bool isInf = apValue.isInfinity();
|
|
bool isNaN = apValue.isNaN();
|
|
if (!isInf && !isNaN) {
|
|
SmallString<128> strValue;
|
|
apValue.toString(strValue, 6, 0, false);
|
|
|
|
// Check to make sure that the stringized number is not some string like
|
|
// "Inf" or NaN, that atof will accept, but the lexer will not. Check
|
|
// that the string matches the "[-+]?[0-9]" regex.
|
|
assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
|
|
((strValue[0] == '-' || strValue[0] == '+') &&
|
|
(strValue[1] >= '0' && strValue[1] <= '9'))) &&
|
|
"[-+]?[0-9] regex does not match!");
|
|
|
|
// Parse back the stringized version and check that the value is equal
|
|
// (i.e., there is no precision loss). If it is not, use the default format
|
|
// of APFloat instead of the exponential notation.
|
|
if (!APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
|
|
strValue.clear();
|
|
apValue.toString(strValue);
|
|
}
|
|
os << strValue;
|
|
return;
|
|
}
|
|
|
|
// Print special values in hexadecimal format. The sign bit should be
|
|
// included in the literal.
|
|
SmallVector<char, 16> str;
|
|
APInt apInt = apValue.bitcastToAPInt();
|
|
apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
|
|
/*formatAsCLiteral=*/true);
|
|
os << str;
|
|
}
|
|
|
|
/// Prints `loc` wrapped as `loc(...)`, or in the pretty human-readable form
/// when -mlir-pretty-debuginfo is set.
void ModulePrinter::printLocation(LocationAttr loc) {
  if (!printPrettyDebugInfo) {
    os << "loc(";
    printLocationInternal(loc);
    os << ')';
    return;
  }
  printLocationInternal(loc, /*pretty=*/true);
}
|
|
|
|
/// Returns if the given dialect symbol data is simple enough to print in the
|
|
/// pretty form, i.e. without the enclosing "".
|
|
static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
|
|
// The name must start with an identifier.
|
|
if (symName.empty() || !isalpha(symName.front()))
|
|
return false;
|
|
|
|
// Ignore all the characters that are valid in an identifier in the symbol
|
|
// name.
|
|
symName =
|
|
symName.drop_while([](char c) { return llvm::isAlnum(c) || c == '.'; });
|
|
if (symName.empty())
|
|
return true;
|
|
|
|
// If we got to an unexpected character, then it must be a <>. Check those
|
|
// recursively.
|
|
if (symName.front() != '<' || symName.back() != '>')
|
|
return false;
|
|
|
|
SmallVector<char, 8> nestedPunctuation;
|
|
do {
|
|
// If we ran out of characters, then we had a punctuation mismatch.
|
|
if (symName.empty())
|
|
return false;
|
|
|
|
auto c = symName.front();
|
|
symName = symName.drop_front();
|
|
|
|
switch (c) {
|
|
// We never allow null characters. This is an EOF indicator for the lexer
|
|
// which we could handle, but isn't important for any known dialect.
|
|
case '\0':
|
|
return false;
|
|
case '<':
|
|
case '[':
|
|
case '(':
|
|
case '{':
|
|
nestedPunctuation.push_back(c);
|
|
continue;
|
|
// Reject types with mismatched brackets.
|
|
case '>':
|
|
if (nestedPunctuation.pop_back_val() != '<')
|
|
return false;
|
|
break;
|
|
case ']':
|
|
if (nestedPunctuation.pop_back_val() != '[')
|
|
return false;
|
|
break;
|
|
case ')':
|
|
if (nestedPunctuation.pop_back_val() != '(')
|
|
return false;
|
|
break;
|
|
case '}':
|
|
if (nestedPunctuation.pop_back_val() != '{')
|
|
return false;
|
|
break;
|
|
default:
|
|
continue;
|
|
}
|
|
|
|
// We're done when the punctuation is fully matched.
|
|
} while (!nestedPunctuation.empty());
|
|
|
|
// If there were extra characters, then we failed.
|
|
return symName.empty();
|
|
}
|
|
|
|
/// Print the given dialect symbol to the stream.
|
|
static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
|
|
StringRef dialectName, StringRef symString) {
|
|
os << symPrefix << dialectName;
|
|
|
|
// If this symbol name is simple enough, print it directly in pretty form,
|
|
// otherwise, we print it as an escaped string.
|
|
if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
|
|
os << '.' << symString;
|
|
return;
|
|
}
|
|
|
|
// TODO: escape the symbol name, it could contain " characters.
|
|
os << "<\"" << symString << "\">";
|
|
}
|
|
|
|
/// Prints `attr`, or `#alias` if one is registered for it. Unless the case
/// below returns early, the value is followed by ` : <type>` when the
/// attribute's type is not `none`. Bool and affine-map attributes never print
/// their type. With `mayElideType`, i64 integer and f64 float attributes also
/// omit it.
void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
  if (!attr) {
    os << "<<NULL ATTRIBUTE>>";
    return;
  }

  // Aliased attributes are printed by reference only.
  Twine alias = state.getAttributeAlias(attr);
  if (!alias.isTriviallyEmpty()) {
    os << '#' << alias;
    return;
  }

  switch (attr.getKind()) {
  default: {
    // Non-builtin attributes: let the owning dialect serialize the value. The
    // inner scope ensures the string stream is flushed before use.
    auto &dialect = attr.getDialect();
    std::string attrName;
    {
      llvm::raw_string_ostream attrNameStr(attrName);
      dialect.printAttribute(attr, attrNameStr);
    }
    printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
    break;
  }
  case StandardAttributes::Opaque: {
    auto opaqueAttr = attr.cast<OpaqueAttr>();
    printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
                       opaqueAttr.getAttrData());
    break;
  }
  case StandardAttributes::Unit:
    os << "unit";
    break;
  case StandardAttributes::Bool:
    os << (attr.cast<BoolAttr>().getValue() ? "true" : "false");
    // Bool attributes never print their type.
    return;
  case StandardAttributes::Dictionary: {
    os << '{';
    interleaveComma(attr.cast<DictionaryAttr>().getValue(),
                    [&](NamedAttribute entry) {
                      os << entry.first << " = ";
                      printAttribute(entry.second);
                    });
    os << '}';
    break;
  }
  case StandardAttributes::Integer: {
    auto intAttr = attr.cast<IntegerAttr>();
    Type intType = intAttr.getType();
    // Values are printed as signed, except for 1-bit integers.
    bool printSigned =
        intType.isIndex() || intType.getIntOrFloatBitWidth() != 1;
    intAttr.getValue().print(os, printSigned);
    if (mayElideType && intType.isInteger(64))
      return;
    break;
  }
  case StandardAttributes::Float: {
    auto floatAttr = attr.cast<FloatAttr>();
    printFloatValue(floatAttr.getValue(), os);
    if (mayElideType && floatAttr.getType().isF64())
      return;
    break;
  }
  case StandardAttributes::String:
    os << '"';
    printEscapedString(attr.cast<StringAttr>().getValue(), os);
    os << '"';
    break;
  case StandardAttributes::Array:
    os << '[';
    interleaveComma(attr.cast<ArrayAttr>().getValue(), [&](Attribute element) {
      printAttribute(element, /*mayElideType=*/true);
    });
    os << ']';
    break;
  case StandardAttributes::AffineMap:
    attr.cast<AffineMapAttr>().getValue().print(os);
    // Affine map attributes never print their type.
    return;
  case StandardAttributes::IntegerSet:
    attr.cast<IntegerSetAttr>().getValue().print(os);
    break;
  case StandardAttributes::Type:
    printType(attr.cast<TypeAttr>().getValue());
    break;
  case StandardAttributes::SymbolRef:
    os << '@' << attr.cast<SymbolRefAttr>().getValue();
    break;
  case StandardAttributes::OpaqueElements: {
    auto eltsAttr = attr.cast<OpaqueElementsAttr>();
    os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", \"0x"
       << llvm::toHex(eltsAttr.getValue()) << "\">";
    break;
  }
  case StandardAttributes::DenseElements: {
    os << "dense<";
    printDenseElementsAttr(attr.cast<DenseElementsAttr>());
    os << '>';
    break;
  }
  case StandardAttributes::SparseElements: {
    auto sparseAttr = attr.cast<SparseElementsAttr>();
    os << "sparse<";
    printDenseElementsAttr(sparseAttr.getIndices());
    os << ", ";
    printDenseElementsAttr(sparseAttr.getValues());
    os << '>';
    break;
  }
  // Location attributes.
  case StandardAttributes::CallSiteLocation:
  case StandardAttributes::FileLineColLocation:
  case StandardAttributes::FusedLocation:
  case StandardAttributes::NameLocation:
  case StandardAttributes::UnknownLocation:
    printLocation(attr.cast<LocationAttr>());
    break;
  }

  // Trailing type, omitted when it is `none`.
  Type attrType = attr.getType();
  if (attrType.isa<NoneType>())
    return;
  os << " : ";
  printType(attrType);
}
|
|
|
|
/// Print the integer element of the given DenseElementsAttr at 'index'.
|
|
static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os,
|
|
unsigned index) {
|
|
APInt value = *std::next(attr.getIntValues().begin(), index);
|
|
if (value.getBitWidth() == 1)
|
|
os << (value.getBoolValue() ? "true" : "false");
|
|
else
|
|
value.print(os, /*isSigned=*/true);
|
|
}
|
|
|
|
/// Print the float element of the given DenseElementsAttr at 'index'.
|
|
static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
|
|
unsigned index) {
|
|
APFloat value = *std::next(attr.getFloatValues().begin(), index);
|
|
printFloatValue(value, os);
|
|
}
|
|
|
|
/// Prints the elements of `attr` as nested, bracketed lists following its
/// shape. A splat is printed as its single value. An attribute with no
/// elements is printed as `rank` opening brackets followed by `rank` closing
/// ones.
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
  auto type = attr.getType();
  auto shape = type.getShape();
  auto rank = type.getRank();

  // Integer element types use the integer printer; all others the float one.
  auto printElement = type.getElementType().isa<IntegerType>()
                          ? printDenseIntElement
                          : printDenseFloatElement;

  // Special case for 0-d and splat tensors.
  if (attr.isSplat()) {
    printElement(attr, os, 0);
    return;
  }

  // Special case for degenerate tensors.
  auto numElements = type.getNumElements();
  if (numElements == 0) {
    for (int depth = 0; depth < rank; ++depth)
      os << '[';
    for (int depth = 0; depth < rank; ++depth)
      os << ']';
    return;
  }

  // Elements are visited with a mixed-radix counter whose radices are the
  // dimensions in `shape`. When a digit other than the innermost rolls over,
  // the matching bracket is closed. Closed brackets are reopened before the
  // next element is printed.
  SmallVector<unsigned, 4> counter(rank, 0);
  // Number of brackets currently open.
  unsigned openBrackets = 0;

  auto bumpCounter = [&]() {
    ++counter[rank - 1];
    for (unsigned dim = rank - 1; dim > 0; --dim) {
      if (counter[dim] < shape[dim])
        continue;
      // Digit `dim` rolled over: carry into `dim - 1` and close a bracket.
      counter[dim] = 0;
      ++counter[dim - 1];
      --openBrackets;
      os << ']';
    }
  };

  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    for (; openBrackets < rank; ++openBrackets)
      os << '[';
    printElement(attr, os, idx);
    bumpCounter();
  }
  for (; openBrackets > 0; --openBrackets)
    os << ']';
}
|
|
|
|
/// Prints `type`, or `!alias` if one is registered for it. Nested types
/// (function inputs/results, element types) are printed through printType.
/// The raw-stream `os << type` operator cannot see this printer's `state`,
/// so using it for nested types would bypass aliases.
void ModulePrinter::printType(Type type) {
  // Check for an alias for this type.
  StringRef alias = state.getTypeAlias(type);
  if (!alias.empty()) {
    os << '!' << alias;
    return;
  }

  switch (type.getKind()) {
  default: {
    auto &dialect = type.getDialect();

    // Ask the dialect to serialize the type to a string.
    std::string typeName;
    {
      llvm::raw_string_ostream typeNameStr(typeName);
      dialect.printType(type, typeNameStr);
    }

    printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
    return;
  }
  case Type::Kind::Opaque: {
    auto opaqueTy = type.cast<OpaqueType>();
    printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
                       opaqueTy.getTypeData());
    return;
  }
  case StandardTypes::Index:
    os << "index";
    return;
  case StandardTypes::BF16:
    os << "bf16";
    return;
  case StandardTypes::F16:
    os << "f16";
    return;
  case StandardTypes::F32:
    os << "f32";
    return;
  case StandardTypes::F64:
    os << "f64";
    return;

  case StandardTypes::Integer: {
    auto integer = type.cast<IntegerType>();
    os << 'i' << integer.getWidth();
    return;
  }
  case Type::Kind::Function: {
    auto func = type.cast<FunctionType>();
    os << '(';
    interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
    os << ") -> ";
    auto results = func.getResults();
    // A single non-function result is printed without parentheses.
    if (results.size() == 1 && !results[0].isa<FunctionType>()) {
      printType(results[0]);
    } else {
      os << '(';
      interleaveComma(results, [&](Type type) { printType(type); });
      os << ')';
    }
    return;
  }
  case StandardTypes::Vector: {
    auto v = type.cast<VectorType>();
    os << "vector<";
    for (auto dim : v.getShape())
      os << dim << 'x';
    printType(v.getElementType());
    os << '>';
    return;
  }
  case StandardTypes::RankedTensor: {
    auto v = type.cast<RankedTensorType>();
    os << "tensor<";
    for (auto dim : v.getShape()) {
      // Negative sizes are printed as '?' (presumably dynamic dimensions).
      if (dim < 0)
        os << '?';
      else
        os << dim;
      os << 'x';
    }
    printType(v.getElementType());
    os << '>';
    return;
  }
  case StandardTypes::UnrankedTensor: {
    auto v = type.cast<UnrankedTensorType>();
    os << "tensor<*x";
    printType(v.getElementType());
    os << '>';
    return;
  }
  case StandardTypes::MemRef: {
    auto v = type.cast<MemRefType>();
    os << "memref<";
    for (auto dim : v.getShape()) {
      if (dim < 0)
        os << '?';
      else
        os << dim;
      os << 'x';
    }
    printType(v.getElementType());
    for (auto map : v.getAffineMaps()) {
      os << ", ";
      printAttribute(AffineMapAttr::get(map));
    }
    // Only print the memory space if it is the non-default one.
    if (v.getMemorySpace())
      os << ", " << v.getMemorySpace();
    os << '>';
    return;
  }
  case StandardTypes::Complex:
    os << "complex<";
    printType(type.cast<ComplexType>().getElementType());
    os << '>';
    return;
  case StandardTypes::Tuple: {
    auto tuple = type.cast<TupleType>();
    os << "tuple<";
    interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); });
    os << '>';
    return;
  }
  case StandardTypes::None:
    os << "none";
    return;
  }
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Affine expressions and maps
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Prints `expr` as a top-level affine expression, which starts with weak
/// binding strength. If `printValueName` is provided, it is used to print
/// dimension and symbol identifiers instead of `dN`/`sN`.
void ModulePrinter::printAffineExpr(
    AffineExpr expr, llvm::function_ref<void(unsigned, bool)> printValueName) {
  const BindingStrength topLevel = BindingStrength::Weak;
  printAffineExprInternal(expr, topLevel, printValueName);
}
|
|
|
|
/// Prints `expr`. When `expr` is a binary expression and `enclosingTightness`
/// is Strong (it is an operand of a tightly binding operator), it is wrapped
/// in parentheses. Dimension and symbol identifiers are printed through
/// `printValueName` when provided, otherwise as `d<pos>` / `s<pos>`.
///
/// A few "pretty" forms are used for readability:
///   e * -1            ->  -e
///   a + b * -1        ->  a - b
///   a + b * -c (c>1)  ->  a - b * c
///   a + -c            ->  a - c
void ModulePrinter::printAffineExprInternal(
    AffineExpr expr, BindingStrength enclosingTightness,
    llvm::function_ref<void(unsigned, bool)> printValueName) {
  const char *binopSpelling = nullptr;
  switch (expr.getKind()) {
  case AffineExprKind::SymbolId: {
    unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
    if (printValueName)
      printValueName(pos, /*isSymbol=*/true);
    else
      os << 's' << pos;
    return;
  }
  case AffineExprKind::DimId: {
    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
    if (printValueName)
      printValueName(pos, /*isSymbol=*/false);
    else
      os << 'd' << pos;
    return;
  }
  case AffineExprKind::Constant:
    os << expr.cast<AffineConstantExpr>().getValue();
    return;
  case AffineExprKind::Add:
    binopSpelling = " + ";
    break;
  case AffineExprKind::Mul:
    binopSpelling = " * ";
    break;
  case AffineExprKind::FloorDiv:
    binopSpelling = " floordiv ";
    break;
  case AffineExprKind::CeilDiv:
    binopSpelling = " ceildiv ";
    break;
  case AffineExprKind::Mod:
    binopSpelling = " mod ";
    break;
  }

  auto binOp = expr.cast<AffineBinaryOpExpr>();
  AffineExpr lhsExpr = binOp.getLHS();
  AffineExpr rhsExpr = binOp.getRHS();

  // Handle tightly binding binary operators.
  if (binOp.getKind() != AffineExprKind::Add) {
    if (enclosingTightness == BindingStrength::Strong)
      os << '(';

    // Pretty print multiplication with -1 as a negation. This rewrite is only
    // valid for multiplication: `d0 mod -1` is not `-d0`, and printing
    // floordiv/ceildiv by -1 this way would not round-trip to the same
    // expression.
    auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
    if (binOp.getKind() == AffineExprKind::Mul && rhsConst &&
        rhsConst.getValue() == -1) {
      os << "-";
      printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
      // Close the parenthesis opened above so that the early return does not
      // leave the output unbalanced.
      if (enclosingTightness == BindingStrength::Strong)
        os << ')';
      return;
    }

    printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);

    os << binopSpelling;
    printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);

    if (enclosingTightness == BindingStrength::Strong)
      os << ')';
    return;
  }

  // Print out special "pretty" forms for add.
  if (enclosingTightness == BindingStrength::Strong)
    os << '(';

  // Pretty print addition to a product that has a negative operand as a
  // subtraction.
  if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
    if (rhs.getKind() == AffineExprKind::Mul) {
      AffineExpr rrhsExpr = rhs.getRHS();
      if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
        if (rrhs.getValue() == -1) {
          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
                                  printValueName);
          os << " - ";
          // A nested addition must be parenthesized after the minus sign.
          if (rhs.getLHS().getKind() == AffineExprKind::Add) {
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
                                    printValueName);
          } else {
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
                                    printValueName);
          }

          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }

        if (rrhs.getValue() < -1) {
          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
                                  printValueName);
          os << " - ";
          printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
                                  printValueName);
          os << " * " << -rrhs.getValue();
          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }
      }
    }
  }

  // Pretty print addition to a negative number as a subtraction.
  if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
    if (rhsConst.getValue() < 0) {
      printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
      os << " - " << -rhsConst.getValue();
      if (enclosingTightness == BindingStrength::Strong)
        os << ')';
      return;
    }
  }

  printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);

  os << " + ";
  printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);

  if (enclosingTightness == BindingStrength::Strong)
    os << ')';
}
|
|
|
|
/// Prints a single integer-set constraint: `<expr> == 0` for an equality
/// (`isEq` true) and `<expr> >= 0` for an inequality.
void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) {
  printAffineExprInternal(expr, BindingStrength::Weak);
  if (isEq)
    os << " == 0";
  else
    os << " >= 0";
}
|
|
|
|
/// Prints an affine map as `(d0, d1, ...)[s0, ...] -> (expr0, expr1, ...)`.
/// The dimension list is always printed (possibly as `()`); the symbol list
/// is omitted when the map has no symbols.
void ModulePrinter::printAffineMap(AffineMap map) {
  // Dimension identifiers.
  unsigned numDims = map.getNumDims();
  os << '(';
  for (unsigned d = 0; d < numDims; ++d) {
    if (d != 0)
      os << ", ";
    os << 'd' << d;
  }
  os << ')';

  // Symbolic identifiers, only when present.
  unsigned numSymbols = map.getNumSymbols();
  if (numSymbols != 0) {
    os << '[';
    for (unsigned s = 0; s < numSymbols; ++s) {
      if (s != 0)
        os << ", ";
      os << 's' << s;
    }
    os << ']';
  }

  // An affine map is expected to have at least one result.
  assert(!map.getResults().empty());
  os << " -> (";
  interleaveComma(map.getResults(),
                  [&](AffineExpr expr) { printAffineExpr(expr); });
  os << ')';
}
|
|
|
|
/// Prints an integer set as `(d0, ...)[s0, ...] : (c0, c1, ...)`, where each
/// constraint is printed by printAffineConstraint. The symbol list is omitted
/// when the set has no symbols.
void ModulePrinter::printIntegerSet(IntegerSet set) {
  // Dimension identifiers.
  unsigned numDims = set.getNumDims();
  os << '(';
  for (unsigned d = 0; d < numDims; ++d) {
    if (d != 0)
      os << ", ";
    os << 'd' << d;
  }
  os << ')';

  // Symbolic identifiers, only when present.
  unsigned numSymbols = set.getNumSymbols();
  if (numSymbols != 0) {
    os << '[';
    for (unsigned s = 0; s < numSymbols; ++s) {
      if (s != 0)
        os << ", ";
      os << 's' << s;
    }
    os << ']';
  }

  // Constraints, comma separated.
  os << " : (";
  int numConstraints = set.getNumConstraints();
  for (int c = 0; c < numConstraints; ++c) {
    if (c != 0)
      os << ", ";
    printAffineConstraint(set.getConstraint(c), set.isEq(c));
  }
  os << ')';
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Operation printing
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Prints `attrs` as ` {name = value, ...}`, skipping every attribute whose
/// name is listed in `elidedAttrs`. Unit attributes are printed by name only.
/// Nothing at all is printed if no attribute survives the filtering.
void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                                          ArrayRef<StringRef> elidedAttrs) {
  if (attrs.empty())
    return;

  // True if the caller asked for `attr` to be left out.
  auto isElided = [&](NamedAttribute attr) {
    return llvm::any_of(elidedAttrs,
                        [&](StringRef elided) { return attr.first.is(elided); });
  };

  SmallVector<NamedAttribute, 8> filteredAttrs;
  for (auto attr : attrs)
    if (!isElided(attr))
      filteredAttrs.push_back(attr);

  if (filteredAttrs.empty())
    return;

  os << " {";
  interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
    os << attr.first;
    // Unit attributes carry no value, so only the name is printed.
    if (!attr.second.isa<UnitAttr>()) {
      os << " = ";
      printAttribute(attr.second);
    }
  });
  os << '}';
}
|
|
|
|
namespace {
|
|
|
|
// OperationPrinter contains common functionality for printing operations.
|
|
class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
|
|
public:
|
|
OperationPrinter(Operation *op, ModulePrinter &other);
|
|
OperationPrinter(Region *region, ModulePrinter &other);
|
|
|
|
// Methods to print operations.
|
|
void print(Operation *op);
|
|
void print(Block *block, bool printBlockArgs = true,
|
|
bool printBlockTerminator = true);
|
|
|
|
void printOperation(Operation *op);
|
|
void printGenericOp(Operation *op) override;
|
|
|
|
// Implement OpAsmPrinter.
|
|
raw_ostream &getStream() const override { return os; }
|
|
void printType(Type type) override { ModulePrinter::printType(type); }
|
|
void printAttribute(Attribute attr) override {
|
|
ModulePrinter::printAttribute(attr);
|
|
}
|
|
void printOperand(Value *value) override { printValueID(value); }
|
|
|
|
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
|
|
ArrayRef<StringRef> elidedAttrs = {}) override {
|
|
return ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
|
|
};
|
|
|
|
enum { nameSentinel = ~0U };
|
|
|
|
void printBlockName(Block *block) {
|
|
auto id = getBlockID(block);
|
|
if (id != ~0U)
|
|
os << "^bb" << id;
|
|
else
|
|
os << "^INVALIDBLOCK";
|
|
}
|
|
|
|
unsigned getBlockID(Block *block) {
|
|
auto it = blockIDs.find(block);
|
|
return it != blockIDs.end() ? it->second : ~0U;
|
|
}
|
|
|
|
void printSuccessorAndUseList(Operation *term, unsigned index) override;
|
|
|
|
/// Print a region.
|
|
void printRegion(Region &blocks, bool printEntryBlockArgs,
|
|
bool printBlockTerminators) override {
|
|
os << " {\n";
|
|
if (!blocks.empty()) {
|
|
auto *entryBlock = &blocks.front();
|
|
print(entryBlock,
|
|
printEntryBlockArgs && entryBlock->getNumArguments() != 0,
|
|
printBlockTerminators);
|
|
for (auto &b : llvm::drop_begin(blocks.getBlocks(), 1))
|
|
print(&b);
|
|
}
|
|
os.indent(currentIndent) << "}";
|
|
}
|
|
|
|
void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
|
|
ArrayRef<Value *> operands) override {
|
|
AffineMap map = mapAttr.getValue();
|
|
unsigned numDims = map.getNumDims();
|
|
auto printValueName = [&](unsigned pos, bool isSymbol) {
|
|
unsigned index = isSymbol ? numDims + pos : pos;
|
|
assert(index < operands.size());
|
|
if (isSymbol)
|
|
os << "symbol(";
|
|
printValueID(operands[index]);
|
|
if (isSymbol)
|
|
os << ')';
|
|
};
|
|
|
|
interleaveComma(map.getResults(), [&](AffineExpr expr) {
|
|
printAffineExpr(expr, printValueName);
|
|
});
|
|
}
|
|
|
|
// Number of spaces used for indenting nested operations.
|
|
const static unsigned indentWidth = 2;
|
|
|
|
protected:
|
|
void numberValueID(Value *value);
|
|
void numberValuesInRegion(Region ®ion);
|
|
void numberValuesInBlock(Block &block);
|
|
void printValueID(Value *value, bool printResultNo = true) const;
|
|
|
|
private:
|
|
/// Uniques the given value name within the printer. If the given name
|
|
/// conflicts, it is automatically renamed.
|
|
StringRef uniqueValueName(StringRef name);
|
|
|
|
/// This is the value ID for each SSA value. If this returns ~0, then the
|
|
/// valueID has an entry in valueNames.
|
|
DenseMap<Value *, unsigned> valueIDs;
|
|
DenseMap<Value *, StringRef> valueNames;
|
|
|
|
/// This is the block ID for each block in the current.
|
|
DenseMap<Block *, unsigned> blockIDs;
|
|
|
|
/// This keeps track of all of the non-numeric names that are in flight,
|
|
/// allowing us to check for duplicates.
|
|
/// Note: the value of the map is unused.
|
|
llvm::ScopedHashTable<StringRef, char> usedNames;
|
|
llvm::BumpPtrAllocator usedNameAllocator;
|
|
|
|
// This is the current indentation level for nested structures.
|
|
unsigned currentIndent = 0;
|
|
|
|
/// This is the next value ID to assign in numbering.
|
|
unsigned nextValueID = 0;
|
|
/// This is the next ID to assign to a region entry block argument.
|
|
unsigned nextArgumentID = 0;
|
|
/// This is the next ID to assign when a name conflict is detected.
|
|
unsigned nextConflictID = 0;
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
/// Creates a printer for `op`. It numbers the operation's first result (if
/// any) and every value defined within the operation's regions.
OperationPrinter::OperationPrinter(Operation *op, ModulePrinter &other)
    : ModulePrinter(other) {
  // Only the first result gets an ID; printValueID prints the remaining
  // results relative to it as `#N`.
  if (op->getNumResults() != 0)
    numberValueID(op->getResult(0));
  for (auto &region : op->getRegions())
    numberValuesInRegion(region);
}
|
|
|
|
/// Creates a printer that numbers every value and block defined within
/// `region`.
OperationPrinter::OperationPrinter(Region *region, ModulePrinter &other)
    : ModulePrinter(other) {
  Region &toNumber = *region;
  numberValuesInRegion(toNumber);
}
|
|
|
|
/// Number all of the SSA values in the specified region.
|
|
void OperationPrinter::numberValuesInRegion(Region ®ion) {
|
|
// Save the current value ids to allow for numbering values in sibling regions
|
|
// the same.
|
|
unsigned curValueID = nextValueID;
|
|
unsigned curArgumentID = nextArgumentID;
|
|
unsigned curConflictID = nextConflictID;
|
|
|
|
// Push a new used names scope.
|
|
llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
|
|
|
|
// Number the values within this region in a breadth-first order.
|
|
unsigned nextBlockID = 0;
|
|
for (auto &block : region) {
|
|
// Each block gets a unique ID, and all of the operations within it get
|
|
// numbered as well.
|
|
blockIDs[&block] = nextBlockID++;
|
|
numberValuesInBlock(block);
|
|
}
|
|
|
|
// After that we traverse the nested regions.
|
|
// TODO: Rework this loop to not use recursion.
|
|
for (auto &block : region) {
|
|
for (auto &op : block)
|
|
for (auto &nestedRegion : op.getRegions())
|
|
numberValuesInRegion(nestedRegion);
|
|
}
|
|
|
|
// Restore the original value ids.
|
|
nextValueID = curValueID;
|
|
nextArgumentID = curArgumentID;
|
|
nextConflictID = curConflictID;
|
|
}
|
|
|
|
/// Number all of the SSA values in the specified block, without traversing
|
|
/// nested regions.
|
|
void OperationPrinter::numberValuesInBlock(Block &block) {
|
|
// Number the block arguments.
|
|
for (auto *arg : block.getArguments())
|
|
numberValueID(arg);
|
|
|
|
// We number operation that have results, and we only number the first result.
|
|
for (auto &op : block)
|
|
if (op.getNumResults() != 0)
|
|
numberValueID(op.getResult(0));
|
|
}
|
|
|
|
/// Assigns a printed identifier to `value`.
///
/// The following values get descriptive names:
///   - integer constants: `c<N>` for index type, `true`/`false` for i1,
///     `c<N>_<type>` otherwise;
///   - function-typed constants: `f`;
///   - other constants: `cst`;
///   - arguments of a region's entry block: `arg<N>`.
/// These names are uniqued, stored in `valueNames`, and `nameSentinel` is
/// recorded in `valueIDs`. Every other value gets the next sequential number.
void OperationPrinter::numberValueID(Value *value) {
  assert(!valueIDs.count(value) && "Value numbered multiple times");

  SmallString<32> specialNameBuffer;
  llvm::raw_svector_ostream specialName(specialNameBuffer);

  // Give constants special names.
  if (auto *op = value->getDefiningOp()) {
    Attribute cst;
    if (m_Constant(&cst).match(op)) {
      Type type = op->getResult(0)->getType();
      if (auto intCst = cst.dyn_cast<IntegerAttr>()) {
        // Nothing here guarantees that a non-index result type is an
        // IntegerType, so use dyn_cast rather than an asserting cast.
        auto intType = type.dyn_cast<IntegerType>();
        if (type.isIndex()) {
          specialName << 'c' << intCst.getInt();
        } else if (intType && intType.isInteger(1)) {
          // i1 constants get special names.
          specialName << (intCst.getInt() ? "true" : "false");
        } else {
          specialName << 'c' << intCst.getInt() << '_' << type;
        }
      } else if (type.isa<FunctionType>()) {
        specialName << 'f';
      } else {
        specialName << "cst";
      }
    }
  }

  if (specialNameBuffer.empty()) {
    switch (value->getKind()) {
    case Value::Kind::BlockArgument:
      // If this is an argument to the entry block of a region, give it an 'arg'
      // name.
      if (auto *block = cast<BlockArgument>(value)->getOwner()) {
        auto *parentRegion = block->getParent();
        if (parentRegion && block == &parentRegion->front()) {
          specialName << "arg" << nextArgumentID++;
          break;
        }
      }
      // Otherwise number it normally.
      valueIDs[value] = nextValueID++;
      return;
    case Value::Kind::OpResult:
      // This is an uninteresting result, give it a boring number and be
      // done with it.
      valueIDs[value] = nextValueID++;
      return;
    }
  }

  // Ok, this value had an interesting name. Remember it with a sentinel.
  valueIDs[value] = nameSentinel;
  valueNames[value] = uniqueValueName(specialName.str());
}
|
|
|
|
/// Uniques the given value name within the printer. If the given name
|
|
/// conflicts, it is automatically renamed.
|
|
StringRef OperationPrinter::uniqueValueName(StringRef name) {
|
|
// Check to see if this name is already unique.
|
|
if (!usedNames.count(name)) {
|
|
name = name.copy(usedNameAllocator);
|
|
} else {
|
|
// Otherwise, we had a conflict - probe until we find a unique name. This
|
|
// is guaranteed to terminate (and usually in a single iteration) because it
|
|
// generates new names by incrementing nextConflictID.
|
|
SmallString<64> probeName(name);
|
|
probeName.push_back('_');
|
|
while (1) {
|
|
probeName.resize(name.size() + 1);
|
|
probeName += llvm::utostr(nextConflictID++);
|
|
if (!usedNames.count(probeName)) {
|
|
name = StringRef(probeName).copy(usedNameAllocator);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
usedNames.insert(name, char());
|
|
return name;
|
|
}
|
|
|
|
/// Prints `block`.
///
/// When `printBlockArgs` is set, a `^bbN(args):` label line comes first,
/// followed by a comment describing the block's predecessors. The operations
/// are then printed one per line, indented one level deeper than the current
/// indentation. When `printBlockTerminator` is false, the block's last
/// operation is not printed.
void OperationPrinter::print(Block *block, bool printBlockArgs,
                             bool printBlockTerminator) {
  // Print the block label and argument list if requested.
  if (printBlockArgs) {
    os.indent(currentIndent);
    printBlockName(block);

    // Print the argument list if non-empty.
    if (!block->args_empty()) {
      os << '(';
      interleaveComma(block->getArguments(), [&](BlockArgument *arg) {
        printValueID(arg);
        os << ": ";
        printType(arg->getType());
      });
      os << ')';
    }
    os << ':';

    // Print out some context information about the predecessors of this block.
    if (!block->getParent()) {
      os << "\t// block is not in a region!";
    } else if (block->hasNoPredecessors()) {
      os << "\t// no predecessors";
    } else if (auto *pred = block->getSinglePredecessor()) {
      os << "\t// pred: ";
      printBlockName(pred);
    } else {
      // We want to print the predecessors in increasing numeric order, not in
      // whatever order the use-list is in, so gather and sort them.
      SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
      for (auto *pred : block->getPredecessors())
        predIDs.push_back({getBlockID(pred), pred});
      llvm::array_pod_sort(predIDs.begin(), predIDs.end());

      os << "\t// " << predIDs.size() << " preds: ";

      interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
        printBlockName(pred.second);
      });
    }
    os << '\n';
  }

  currentIndent += indentWidth;
  auto &operations = block->getOperations();
  auto end = operations.end();
  // Drop the terminator (the last operation) if requested. An empty block
  // has no operation to drop, and decrementing its begin iterator would be
  // undefined behavior.
  if (!printBlockTerminator && !operations.empty())
    end = std::prev(end);
  for (auto &op : llvm::make_range(operations.begin(), end)) {
    print(&op);
    os << '\n';
  }
  currentIndent -= indentWidth;
}
|
|
|
|
/// Prints `op` on the current indentation level, followed by whatever
/// printTrailingLocation emits for its location.
void OperationPrinter::print(Operation *op) {
  auto &stream = os.indent(currentIndent);
  (void)stream;
  printOperation(op);
  printTrailingLocation(op->getLoc());
}
|
|
|
|
/// Prints the SSA name of `value`, as `%<number>` or `%<name>`.
///
/// A result of an operation with more than one result is printed through the
/// operation's first result, followed by `#<resultNo>` when `printResultNo`
/// is true. Values this printer never numbered print as
/// `<<INVALID SSA VALUE>>`.
void OperationPrinter::printValueID(Value *value, bool printResultNo) const {
  Value *lookupValue = value;
  int resultNo = -1;

  // All results of a multi-result operation are keyed on result #0.
  auto *result = dyn_cast<OpResult>(value);
  if (result && result->getOwner()->getNumResults() != 1) {
    resultNo = result->getResultNumber();
    lookupValue = result->getOwner()->getResult(0);
  }

  auto idIt = valueIDs.find(lookupValue);
  if (idIt == valueIDs.end()) {
    os << "<<INVALID SSA VALUE>>";
    return;
  }

  os << '%';
  if (idIt->second == nameSentinel) {
    auto nameIt = valueNames.find(lookupValue);
    assert(nameIt != valueNames.end() && "Didn't have a name entry?");
    os << nameIt->second;
  } else {
    os << idIt->second;
  }

  if (printResultNo && resultNo != -1)
    os << '#' << resultNo;
}
|
|
|
|
/// Prints `op`, prefixed by `%result = ` (or `%result:N = ` when it has N > 1
/// results). If `printGenericOpForm` is set, the generic form is used (FuncOp
/// excepted). Otherwise, registered operations use their custom assembly
/// printer and unregistered operations fall back to the generic form.
void OperationPrinter::printOperation(Operation *op) {
  size_t numResults = op->getNumResults();
  if (numResults != 0) {
    printValueID(op->getResult(0), /*printResultNo=*/false);
    if (numResults > 1)
      os << ':' << numResults;
    os << " = ";
  }

  // TODO(riverriddle): FuncOp cannot be round-tripped currently, as
  // FunctionType cannot be used in a TypeAttr.
  if (printGenericOpForm && !isa<FuncOp>(op)) {
    printGenericOp(op);
    return;
  }

  // Registered operations provide a custom printer hook.
  if (auto *opInfo = op->getAbstractOperation())
    opInfo->printAssembly(op, this);
  else
    printGenericOp(op);
}
|
|
|
|
/// Prints `op` in the generic form:
///   "name"(%operands)[^succ(%args : types), ...] (regions) {attrs} : type
/// The successor list is omitted for operations without successors, and the
/// region list for operations without regions.
void OperationPrinter::printGenericOp(Operation *op) {
  os << '"';
  printEscapedString(op->getName().getStringRef(), os);
  os << "\"(";

  // The trailing operands are treated as successor operands; only the leading
  // (non-successor) operands are printed inside the parentheses.
  unsigned totalNumSuccessorOperands = 0;
  unsigned numSuccessors = op->getNumSuccessors();
  for (unsigned i = 0; i < numSuccessors; ++i)
    totalNumSuccessorOperands += op->getNumSuccessorOperands(i);
  unsigned numProperOperands = op->getNumOperands() - totalNumSuccessorOperands;
  SmallVector<Value *, 8> properOperands(
      op->operand_begin(), std::next(op->operand_begin(), numProperOperands));

  interleaveComma(properOperands, [&](Value *value) { printValueID(value); });

  os << ')';

  // For terminators, print the list of successors and their operands.
  if (numSuccessors != 0) {
    os << '[';
    for (unsigned i = 0; i < numSuccessors; ++i) {
      if (i != 0)
        os << ", ";
      printSuccessorAndUseList(op, i);
    }
    os << ']';
  }

  // Print regions.
  if (op->getNumRegions() != 0) {
    os << " (";
    interleaveComma(op->getRegions(), [&](Region &region) {
      printRegion(region, /*printEntryBlockArgs=*/true,
                  /*printBlockTerminators=*/true);
    });
    os << ')';
  }

  auto attrs = op->getAttrs();
  printOptionalAttrDict(attrs);

  // Print the type signature of the operation.
  os << " : ";
  printFunctionalType(op);
}
|
|
|
|
/// Prints successor `index` of `term` as `^bbN`. If the successor receives
/// operands, they follow as `(%v0, %v1 : type0, type1)`.
void OperationPrinter::printSuccessorAndUseList(Operation *term,
                                                unsigned index) {
  printBlockName(term->getSuccessor(index));

  auto succOperands = term->getSuccessorOperands(index);
  if (succOperands.begin() != succOperands.end()) {
    os << '(';
    interleaveComma(succOperands,
                    [this](Value *operand) { printValueID(operand); });
    os << " : ";
    interleaveComma(succOperands, [this](Value *operand) {
      printType(operand->getType());
    });
    os << ')';
  }
}
|
|
|
|
/// Prints `module`. The attribute and type alias definitions held in `state`
/// come first, then the module operation itself and a final newline.
void ModulePrinter::print(ModuleOp module) {
  state.printAttributeAliases(os);
  state.printTypeAliases(os);

  {
    OperationPrinter opPrinter(module, *this);
    opPrinter.print(module);
  }
  os << '\n';
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// print and dump methods
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Prints this attribute to `os`. The printer state is created without an
/// MLIRContext.
void Attribute::print(raw_ostream &os) const {
  ModuleState state(/*no context is known*/ nullptr);
  ModulePrinter printer(os, state);
  printer.printAttribute(*this);
}
|
|
|
|
void Attribute::dump() const {
|
|
print(llvm::errs());
|
|
llvm::errs() << "\n";
|
|
}
|
|
|
|
/// Prints this type to `os`. The printer state uses the type's context.
void Type::print(raw_ostream &os) {
  ModuleState state(getContext());
  ModulePrinter printer(os, state);
  printer.printType(*this);
}
|
|
|
|
/// Prints this type to stderr, followed by a newline. The newline matches the
/// other dump() helpers (Attribute::dump, AffineMap::dump, IntegerSet::dump,
/// AffineExpr::dump, Operation::dump), so later output starts on a new line.
void Type::dump() {
  print(llvm::errs());
  llvm::errs() << "\n";
}
|
|
|
|
void AffineMap::dump() const {
|
|
print(llvm::errs());
|
|
llvm::errs() << "\n";
|
|
}
|
|
|
|
void IntegerSet::dump() const {
|
|
print(llvm::errs());
|
|
llvm::errs() << "\n";
|
|
}
|
|
|
|
/// Prints this affine expression to `os`. A null expression prints
/// "null affine expr".
void AffineExpr::print(raw_ostream &os) const {
  if (expr != nullptr) {
    ModuleState state(getContext());
    ModulePrinter(os, state).printAffineExpr(*this);
  } else {
    os << "null affine expr";
  }
}
|
|
|
|
void AffineExpr::dump() const {
|
|
print(llvm::errs());
|
|
llvm::errs() << "\n";
|
|
}
|
|
|
|
/// Prints this affine map to `os`. A null map prints "null affine map".
void AffineMap::print(raw_ostream &os) const {
  if (map != nullptr) {
    ModuleState state(getContext());
    ModulePrinter(os, state).printAffineMap(*this);
  } else {
    os << "null affine map";
  }
}
|
|
|
|
/// Prints this integer set to `os`. The printer state is created without an
/// MLIRContext.
void IntegerSet::print(raw_ostream &os) const {
  ModuleState state(/*no context is known*/ nullptr);
  ModulePrinter printer(os, state);
  printer.printIntegerSet(*this);
}
|
|
|
|
/// Prints this value to `os`. An operation result is printed as its whole
/// defining operation. A block argument currently prints only a placeholder
/// line.
void Value::print(raw_ostream &os) {
  switch (getKind()) {
  case Value::Kind::OpResult:
    getDefiningOp()->print(os);
    break;
  case Value::Kind::BlockArgument:
    // TODO: Improve this.
    os << "<block argument>\n";
    break;
  }
}
|
|
|
|
void Value::dump() { print(llvm::errs()); }
|
|
|
|
/// Prints this operation to `os`.
///
/// Values are numbered over the outermost region that transitively contains
/// the operation; an operation without a parent block is numbered on its own.
/// An operation whose block is not linked into a region prints
/// "<<UNLINKED INSTRUCTION>>".
void Operation::print(raw_ostream &os) {
  // Top-level operation: number starting from the operation itself.
  if (getParent() == nullptr) {
    ModuleState state(getContext());
    ModulePrinter modulePrinter(os, state);
    OperationPrinter(this, modulePrinter).print(this);
    return;
  }

  Region *region = getContainingRegion();
  if (region == nullptr) {
    os << "<<UNLINKED INSTRUCTION>>\n";
    return;
  }

  // Climb to the top-level region.
  for (Region *parent = region->getContainingRegion(); parent != nullptr;
       parent = parent->getContainingRegion())
    region = parent;

  ModuleState state(getContext());
  ModulePrinter modulePrinter(os, state);
  OperationPrinter(region, modulePrinter).print(this);
}
|
|
|
|
void Operation::dump() {
|
|
print(llvm::errs());
|
|
llvm::errs() << "\n";
|
|
}
|
|
|
|
/// Prints this block to `os`, numbering values over the outermost region that
/// contains it. A block that is not attached to a region prints
/// "<<UNLINKED BLOCK>>".
void Block::print(raw_ostream &os) {
  Region *region = getParent();
  if (region == nullptr) {
    os << "<<UNLINKED BLOCK>>\n";
    return;
  }

  // Climb to the top-level region.
  for (Region *parent = region->getContainingRegion(); parent != nullptr;
       parent = parent->getContainingRegion())
    region = parent;

  ModuleState state(region->getContext());
  ModulePrinter modulePrinter(os, state);
  OperationPrinter(region, modulePrinter).print(this);
}
|
|
|
|
void Block::dump() { print(llvm::errs()); }
|
|
|
|
/// Print out the name of the block without printing its body.
|
|
void Block::printAsOperand(raw_ostream &os, bool printType) {
|
|
auto region = getParent();
|
|
if (!region) {
|
|
os << "<<UNLINKED BLOCK>>\n";
|
|
return;
|
|
}
|
|
|
|
// Get the top-level region.
|
|
while (auto *nextRegion = region->getContainingRegion())
|
|
region = nextRegion;
|
|
|
|
ModuleState state(region->getContext());
|
|
ModulePrinter modulePrinter(os, state);
|
|
OperationPrinter(region, modulePrinter).printBlockName(this);
|
|
}
|
|
|
|
/// Prints this module to `os`. The printer state is initialized from the
/// module before printing (presumably this collects the aliases that
/// ModulePrinter::print emits; confirm against ModuleState::initialize).
void ModuleOp::print(raw_ostream &os) {
  ModuleState state(getContext());
  state.initialize(*this);
  ModulePrinter printer(os, state);
  printer.print(*this);
}
|
|
|
|
void ModuleOp::dump() { print(llvm::errs()); }
|