NFC: Pass OpAsmPrinter by reference instead of by pointer.
MLIR follows the LLVM style of pass-by-reference. PiperOrigin-RevId: 270401378
This commit is contained in:
parent
854d3f45e9
commit
db85dbe6b6
tensorflow/compiler/mlir/tensorflow/ir
third_party/mlir
include/mlir
Dialect
AffineOps
GPU
LLVMIR
Linalg/IR
LoopOps
SPIRV
StandardOps
VectorOps
IR
lib
Dialect
AffineOps
GPU/IR
LLVMIR/IR
Linalg/IR
LoopOps
SPIRV
StandardOps
VectorOps
IR
test/lib/TestDialect
tools/mlir-tblgen
@ -210,10 +210,10 @@ LogicalResult Verify(GraphOp graph) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void Print(GraphOp graph, OpAsmPrinter *p) {
|
||||
*p << graph.getOperationName();
|
||||
p->printRegion(graph.getOperation()->getRegion(0));
|
||||
p->printOptionalAttrDict(graph.getAttrs());
|
||||
void Print(GraphOp graph, OpAsmPrinter &p) {
|
||||
p << graph.getOperationName();
|
||||
p.printRegion(graph.getOperation()->getRegion(0));
|
||||
p.printOptionalAttrDict(graph.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -257,15 +257,15 @@ ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) {
|
||||
|
||||
namespace {
|
||||
|
||||
void Print(FetchOp fetch, OpAsmPrinter *p) {
|
||||
*p << fetch.getOperationName();
|
||||
void Print(FetchOp fetch, OpAsmPrinter &p) {
|
||||
p << fetch.getOperationName();
|
||||
if (fetch.getNumOperands() > 0) {
|
||||
*p << ' ';
|
||||
p->printOperands(fetch.operand_begin(), fetch.operand_end());
|
||||
*p << " : ";
|
||||
interleaveComma(fetch.getOperandTypes(), *p);
|
||||
p << ' ';
|
||||
p.printOperands(fetch.operand_begin(), fetch.operand_end());
|
||||
p << " : ";
|
||||
interleaveComma(fetch.getOperandTypes(), p);
|
||||
}
|
||||
p->printOptionalAttrDict(fetch.getAttrs());
|
||||
p.printOptionalAttrDict(fetch.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseFetchOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -322,13 +322,13 @@ LogicalResult Verify(IslandOp island) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void Print(IslandOp op, OpAsmPrinter *p) {
|
||||
*p << op.getOperationName();
|
||||
void Print(IslandOp op, OpAsmPrinter &p) {
|
||||
p << op.getOperationName();
|
||||
if (op.getNumOperands()) {
|
||||
// These are always control operand, no explicit type needed.
|
||||
*p << '(';
|
||||
p->printOperands(op.getOperands());
|
||||
*p << ')';
|
||||
p << '(';
|
||||
p.printOperands(op.getOperands());
|
||||
p << ')';
|
||||
}
|
||||
|
||||
// Check if we can print the short "wraps" form: that is if the island
|
||||
@ -342,13 +342,13 @@ void Print(IslandOp op, OpAsmPrinter *p) {
|
||||
std::equal(wrapped_op.getResults().begin(),
|
||||
wrapped_op.getResults().end(),
|
||||
yield_op.getOperands().begin())) {
|
||||
*p << " wraps ";
|
||||
p->printGenericOp(&op.GetBody().front());
|
||||
p << " wraps ";
|
||||
p.printGenericOp(&op.GetBody().front());
|
||||
return;
|
||||
}
|
||||
}
|
||||
p->printRegion(op.getOperation()->getRegion(0));
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
p.printRegion(op.getOperation()->getRegion(0));
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -404,15 +404,15 @@ ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) {
|
||||
|
||||
namespace {
|
||||
|
||||
void Print(YieldOp yield, OpAsmPrinter *p) {
|
||||
*p << yield.getOperationName();
|
||||
void Print(YieldOp yield, OpAsmPrinter &p) {
|
||||
p << yield.getOperationName();
|
||||
if (yield.getNumOperands() > 0) {
|
||||
*p << ' ';
|
||||
p->printOperands(yield.operand_begin(), yield.operand_end());
|
||||
*p << " : ";
|
||||
interleaveComma(yield.getOperandTypes(), *p);
|
||||
p << ' ';
|
||||
p.printOperands(yield.operand_begin(), yield.operand_end());
|
||||
p << " : ";
|
||||
interleaveComma(yield.getOperandTypes(), p);
|
||||
}
|
||||
p->printOptionalAttrDict(yield.getAttrs());
|
||||
p.printOptionalAttrDict(yield.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseYieldOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -469,20 +469,20 @@ ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) {
|
||||
return parser.parseOptionalAttributeDict(result.attributes);
|
||||
}
|
||||
|
||||
void Print(SwitchOp switch_op, OpAsmPrinter *p) {
|
||||
*p << switch_op.getOperationName() << ' ';
|
||||
p->printOperands(switch_op.getOperands());
|
||||
void Print(SwitchOp switch_op, OpAsmPrinter &p) {
|
||||
p << switch_op.getOperationName() << ' ';
|
||||
p.printOperands(switch_op.getOperands());
|
||||
Type data_operand_ty = switch_op.data()->getType();
|
||||
// If the types aren't perfectly matching, print the functional type syntax
|
||||
// else print the shorter single type.
|
||||
*p << " : ";
|
||||
p << " : ";
|
||||
if (switch_op.trueOutput()->getType() != data_operand_ty ||
|
||||
switch_op.falseOutput()->getType() != data_operand_ty) {
|
||||
p->printFunctionalType(switch_op.getOperation());
|
||||
p.printFunctionalType(switch_op.getOperation());
|
||||
} else {
|
||||
*p << switch_op.getType(0);
|
||||
p << switch_op.getType(0);
|
||||
}
|
||||
p->printOptionalAttrDict(switch_op.getAttrs());
|
||||
p.printOptionalAttrDict(switch_op.getAttrs());
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
@ -514,20 +514,20 @@ LogicalResult Verify(SwitchNOp switchn) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void Print(SwitchNOp switchn, OpAsmPrinter *p) {
|
||||
*p << switchn.getOperationName() << ' ';
|
||||
void Print(SwitchNOp switchn, OpAsmPrinter &p) {
|
||||
p << switchn.getOperationName() << ' ';
|
||||
auto operands = switchn.getOperands();
|
||||
// Print the 2 data operands.
|
||||
p->printOperands(operands.begin(), std::next(operands.begin(), 2));
|
||||
*p << " of " << (switchn.getNumResults() - 1);
|
||||
p.printOperands(operands.begin(), std::next(operands.begin(), 2));
|
||||
p << " of " << (switchn.getNumResults() - 1);
|
||||
// print control dependencies if any
|
||||
if (!llvm::empty(switchn.controlInputs())) {
|
||||
*p << " (";
|
||||
p->printOperands(switchn.controlInputs());
|
||||
*p << ")";
|
||||
p << " (";
|
||||
p.printOperands(switchn.controlInputs());
|
||||
p << ")";
|
||||
}
|
||||
*p << " : " << switchn.getType(0);
|
||||
p->printOptionalAttrDict(switchn.getAttrs(), {"num_outs"});
|
||||
p << " : " << switchn.getType(0);
|
||||
p.printOptionalAttrDict(switchn.getAttrs(), {"num_outs"});
|
||||
}
|
||||
|
||||
ParseResult ParseSwitchNOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -629,7 +629,7 @@ LogicalResult Verify(MergeOp merge) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void Print(MergeOp merge, OpAsmPrinter *p) {
|
||||
void Print(MergeOp merge, OpAsmPrinter &p) {
|
||||
// Use short form only when there are exactly two data operands and their
|
||||
// type matches the output type. Otherwise, use the generic printer.
|
||||
bool use_short_form = true;
|
||||
@ -646,18 +646,18 @@ void Print(MergeOp merge, OpAsmPrinter *p) {
|
||||
}
|
||||
}
|
||||
|
||||
*p << merge.getOperationName() << ' ';
|
||||
p->printOperands(merge.getOperands());
|
||||
p << merge.getOperationName() << ' ';
|
||||
p.printOperands(merge.getOperands());
|
||||
|
||||
// Print the type signature of the operation.
|
||||
*p << " : ";
|
||||
p << " : ";
|
||||
if (!use_short_form || num_data_operands != 2) {
|
||||
p->printFunctionalType(merge.getOperation());
|
||||
p.printFunctionalType(merge.getOperation());
|
||||
} else {
|
||||
*p << output_type;
|
||||
p << output_type;
|
||||
}
|
||||
|
||||
p->printOptionalAttrDict(merge.getAttrs());
|
||||
p.printOptionalAttrDict(merge.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseMergeOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -705,28 +705,28 @@ namespace {
|
||||
// Default number for the parallel_iterations attributes on Enter nodes.
|
||||
constexpr int kDefaultParallelIterations = 10;
|
||||
|
||||
void Print(EnterOp enter, OpAsmPrinter *p) {
|
||||
*p << enter.getOperationName() << ' ';
|
||||
p->printOperands(enter.getOperands());
|
||||
void Print(EnterOp enter, OpAsmPrinter &p) {
|
||||
p << enter.getOperationName() << ' ';
|
||||
p.printOperands(enter.getOperands());
|
||||
|
||||
*p << " frame \"";
|
||||
printEscapedString(enter.frame_name(), p->getStream());
|
||||
*p << "\"";
|
||||
p << " frame \"";
|
||||
printEscapedString(enter.frame_name(), p.getStream());
|
||||
p << "\"";
|
||||
if (enter.parallel_iterations() != kDefaultParallelIterations)
|
||||
*p << " parallel_iterations " << enter.parallel_iterations();
|
||||
if (enter.is_constant()) *p << " constant ";
|
||||
p << " parallel_iterations " << enter.parallel_iterations();
|
||||
if (enter.is_constant()) p << " constant ";
|
||||
|
||||
// If the types aren't perfectly matching, print the functional type syntax
|
||||
// else print the shorter single type.
|
||||
*p << " : ";
|
||||
p << " : ";
|
||||
if (enter.data()->getType() != enter.output()->getType()) {
|
||||
p->printFunctionalType(enter.getOperation());
|
||||
p.printFunctionalType(enter.getOperation());
|
||||
} else {
|
||||
*p << enter.getType(0);
|
||||
p << enter.getType(0);
|
||||
}
|
||||
|
||||
p->printOptionalAttrDict(
|
||||
enter.getAttrs(), {"frame_name", "parallel_iterations", "is_constant"});
|
||||
p.printOptionalAttrDict(enter.getAttrs(),
|
||||
{"frame_name", "parallel_iterations", "is_constant"});
|
||||
}
|
||||
|
||||
ParseResult ParseEnterOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -801,9 +801,9 @@ LogicalResult Verify(NextIterationSourceOp source) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void Print(NextIterationSourceOp next_iteration, OpAsmPrinter *p) {
|
||||
*p << next_iteration.getOperationName() << " : " << next_iteration.getType(0);
|
||||
p->printOptionalAttrDict(next_iteration.getAttrs());
|
||||
void Print(NextIterationSourceOp next_iteration, OpAsmPrinter &p) {
|
||||
p << next_iteration.getOperationName() << " : " << next_iteration.getType(0);
|
||||
p.printOptionalAttrDict(next_iteration.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseNextIterationSourceOp(OpAsmParser &parser,
|
||||
@ -844,13 +844,13 @@ LogicalResult Verify(NextIterationSinkOp sink) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void Print(NextIterationSinkOp next_iteration, OpAsmPrinter *p) {
|
||||
*p << next_iteration.getOperationName() << " [";
|
||||
p->printOperand(next_iteration.getOperand(0));
|
||||
*p << "] ";
|
||||
p->printOperands(llvm::drop_begin(next_iteration.getOperands(), 1));
|
||||
*p << " : " << next_iteration.getOperand(1)->getType();
|
||||
p->printOptionalAttrDict(next_iteration.getAttrs());
|
||||
void Print(NextIterationSinkOp next_iteration, OpAsmPrinter &p) {
|
||||
p << next_iteration.getOperationName() << " [";
|
||||
p.printOperand(next_iteration.getOperand(0));
|
||||
p << "] ";
|
||||
p.printOperands(llvm::drop_begin(next_iteration.getOperands(), 1));
|
||||
p << " : " << next_iteration.getOperand(1)->getType();
|
||||
p.printOptionalAttrDict(next_iteration.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseNextIterationSinkOp(OpAsmParser &parser,
|
||||
@ -882,11 +882,11 @@ ParseResult ParseNextIterationSinkOp(OpAsmParser &parser,
|
||||
|
||||
namespace {
|
||||
|
||||
void Print(ExitOp exit, OpAsmPrinter *p) {
|
||||
*p << exit.getOperationName() << ' ';
|
||||
p->printOperands(exit.getOperands());
|
||||
*p << " : " << exit.getType(0);
|
||||
p->printOptionalAttrDict(exit.getAttrs());
|
||||
void Print(ExitOp exit, OpAsmPrinter &p) {
|
||||
p << exit.getOperationName() << ' ';
|
||||
p.printOperands(exit.getOperands());
|
||||
p << " : " << exit.getType(0);
|
||||
p.printOptionalAttrDict(exit.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseExitOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -914,10 +914,10 @@ ParseResult ParseExitOp(OpAsmParser &parser, OperationState &result) {
|
||||
|
||||
namespace {
|
||||
|
||||
void Print(ControlTriggerOp trigger, OpAsmPrinter *p) {
|
||||
*p << trigger.getOperationName() << ' ';
|
||||
p->printOperands(trigger.getOperands());
|
||||
p->printOptionalAttrDict(trigger.getAttrs());
|
||||
void Print(ControlTriggerOp trigger, OpAsmPrinter &p) {
|
||||
p << trigger.getOperationName() << ' ';
|
||||
p.printOperands(trigger.getOperands());
|
||||
p.printOptionalAttrDict(trigger.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseControlTriggerOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -944,19 +944,19 @@ ParseResult ParseControlTriggerOp(OpAsmParser &parser, OperationState &result) {
|
||||
|
||||
namespace {
|
||||
|
||||
void Print(LoopCondOp loop_cond, OpAsmPrinter *p) {
|
||||
*p << loop_cond.getOperationName() << ' ';
|
||||
p->printOperands(loop_cond.getOperands());
|
||||
void Print(LoopCondOp loop_cond, OpAsmPrinter &p) {
|
||||
p << loop_cond.getOperationName() << ' ';
|
||||
p.printOperands(loop_cond.getOperands());
|
||||
|
||||
// If the types aren't matching (broadcast), print the functional type syntax.
|
||||
if (loop_cond.input()->getType() != loop_cond.output()->getType()) {
|
||||
*p << " : ";
|
||||
p->printFunctionalType(loop_cond.getOperation());
|
||||
p << " : ";
|
||||
p.printFunctionalType(loop_cond.getOperation());
|
||||
} else {
|
||||
*p << " : " << loop_cond.input()->getType();
|
||||
p << " : " << loop_cond.input()->getType();
|
||||
}
|
||||
|
||||
p->printOptionalAttrDict(loop_cond.getAttrs());
|
||||
p.printOptionalAttrDict(loop_cond.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseLoopCondOp(OpAsmParser &parser, OperationState &result) {
|
||||
|
@ -87,7 +87,7 @@ public:
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
void print(OpAsmPrinter *p);
|
||||
void print(OpAsmPrinter &p);
|
||||
LogicalResult verify();
|
||||
OpFoldResult fold(ArrayRef<Attribute> operands);
|
||||
|
||||
@ -287,7 +287,7 @@ public:
|
||||
|
||||
static StringRef getOperationName() { return "affine.dma_start"; }
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
void print(OpAsmPrinter *p);
|
||||
void print(OpAsmPrinter &p);
|
||||
LogicalResult verify();
|
||||
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context);
|
||||
@ -372,7 +372,7 @@ public:
|
||||
|
||||
static StringRef getTagMapAttrName() { return "tag_map"; }
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
void print(OpAsmPrinter *p);
|
||||
void print(OpAsmPrinter &p);
|
||||
LogicalResult verify();
|
||||
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context);
|
||||
@ -440,7 +440,7 @@ public:
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
void print(OpAsmPrinter *p);
|
||||
void print(OpAsmPrinter &p);
|
||||
LogicalResult verify();
|
||||
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context);
|
||||
@ -511,7 +511,7 @@ public:
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
void print(OpAsmPrinter *p);
|
||||
void print(OpAsmPrinter &p);
|
||||
LogicalResult verify();
|
||||
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context);
|
||||
|
@ -39,7 +39,7 @@ def Affine_Dialect : Dialect {
|
||||
class Affine_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Affine_Dialect, mnemonic, traits> {
|
||||
// For every affine op, there needs to be a:
|
||||
// * void print(OpAsmPrinter *p, ${C++ class of Op} op)
|
||||
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
|
||||
// * LogicalResult verify(${C++ class of Op} op)
|
||||
// * ParseResult parse${C++ class of Op}(OpAsmParser &parser,
|
||||
// OperationState &result)
|
||||
|
@ -98,7 +98,7 @@ public:
|
||||
LogicalResult verify();
|
||||
|
||||
/// Custom syntax support.
|
||||
void print(OpAsmPrinter *p);
|
||||
void print(OpAsmPrinter &p);
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
|
||||
static StringRef getOperationName() { return "gpu.launch"; }
|
||||
|
@ -56,7 +56,7 @@ def gpu_Return : GPU_Op<"return", [Terminator]>, Arguments<(ins)>,
|
||||
}];
|
||||
|
||||
let parser = [{ return success(); }];
|
||||
let printer = [{ *p << getOperationName(); }];
|
||||
let printer = [{ p << getOperationName(); }];
|
||||
}
|
||||
|
||||
#endif // GPU_OPS
|
||||
|
@ -452,7 +452,7 @@ def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []> {
|
||||
def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
|
||||
string llvmBuilder = [{ builder.CreateUnreachable(); }];
|
||||
let parser = [{ return success(); }];
|
||||
let printer = [{ *p << getOperationName(); }];
|
||||
let printer = [{ p << getOperationName(); }];
|
||||
}
|
||||
|
||||
// Auxiliary operations (do not appear in LLVM IR but necessary for the dialect
|
||||
|
@ -30,7 +30,7 @@ include "mlir/Dialect/Linalg/IR/LinalgBase.td"
|
||||
class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Linalg_Dialect, mnemonic, traits> {
|
||||
// For every linalg op, there needs to be a:
|
||||
// * void print(OpAsmPrinter *p, ${C++ class of Op} op)
|
||||
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
|
||||
// * LogicalResult verify(${C++ class of Op} op)
|
||||
// * ParseResult parse${C++ class of Op}(OpAsmParser &parser,
|
||||
// OperationState &result)
|
||||
|
@ -37,7 +37,7 @@ def Loop_Dialect : Dialect {
|
||||
class Loop_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Loop_Dialect, mnemonic, traits> {
|
||||
// For every standard op, there needs to be a:
|
||||
// * void print(OpAsmPrinter *p, ${C++ class of Op} op)
|
||||
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
|
||||
// * LogicalResult verify(${C++ class of Op} op)
|
||||
// * ParseResult parse${C++ class of Op}(OpAsmParser &parser,
|
||||
// OperationState &result)
|
||||
|
@ -1138,7 +1138,7 @@ class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
//
|
||||
// * static ParseResult parse<op-c++-class-name>(OpAsmParser &parser,
|
||||
// OperationState &result)
|
||||
// * static void print(OpAsmPrinter *p, <op-c++-class-name> op)
|
||||
// * static void print(OpAsmPrinter &p, <op-c++-class-name> op)
|
||||
// * static LogicalResult verify(<op-c++-class-name> op)
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
let printer = [{ return ::print(*this, p); }];
|
||||
|
@ -278,7 +278,7 @@ public:
|
||||
|
||||
static StringRef getOperationName() { return "std.dma_start"; }
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
void print(OpAsmPrinter *p);
|
||||
void print(OpAsmPrinter &p);
|
||||
LogicalResult verify();
|
||||
|
||||
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
@ -343,7 +343,7 @@ public:
|
||||
Value *getNumElements() { return getOperand(1 + getTagMemRefRank()); }
|
||||
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
void print(OpAsmPrinter *p);
|
||||
void print(OpAsmPrinter &p);
|
||||
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context);
|
||||
};
|
||||
@ -351,7 +351,7 @@ public:
|
||||
/// Prints dimension and symbol list.
|
||||
void printDimAndSymbolList(Operation::operand_iterator begin,
|
||||
Operation::operand_iterator end, unsigned numDims,
|
||||
OpAsmPrinter *p);
|
||||
OpAsmPrinter &p);
|
||||
|
||||
/// Parses dimension and symbol list and returns true if parsing failed.
|
||||
ParseResult parseDimAndSymbolList(OpAsmParser &parser,
|
||||
|
@ -37,7 +37,7 @@ def Std_Dialect : Dialect {
|
||||
class Std_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Std_Dialect, mnemonic, traits> {
|
||||
// For every standard op, there needs to be a:
|
||||
// * void print(OpAsmPrinter *p, ${C++ class of Op} op)
|
||||
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
|
||||
// * LogicalResult verify(${C++ class of Op} op)
|
||||
// * ParseResult parse${C++ class of Op}(OpAsmParser &parser,
|
||||
// OperationState &result)
|
||||
|
@ -116,7 +116,7 @@ public:
|
||||
AffineMap getPermutationMap();
|
||||
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
void print(OpAsmPrinter *p);
|
||||
void print(OpAsmPrinter &p);
|
||||
LogicalResult verify();
|
||||
};
|
||||
|
||||
@ -177,7 +177,7 @@ public:
|
||||
AffineMap getPermutationMap();
|
||||
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
void print(OpAsmPrinter *p);
|
||||
void print(OpAsmPrinter &p);
|
||||
LogicalResult verify();
|
||||
};
|
||||
|
||||
@ -199,7 +199,7 @@ public:
|
||||
static void build(Builder *builder, OperationState &result, Value *srcVector,
|
||||
Type dstType);
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
void print(OpAsmPrinter *p);
|
||||
void print(OpAsmPrinter &p);
|
||||
LogicalResult verify();
|
||||
};
|
||||
|
||||
|
@ -37,7 +37,7 @@ def Vector_Dialect : Dialect {
|
||||
class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Vector_Dialect, mnemonic, traits> {
|
||||
// For every vector op, there needs to be a:
|
||||
// * void print(OpAsmPrinter *p, ${C++ class of Op} op)
|
||||
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
|
||||
// * LogicalResult verify(${C++ class of Op} op)
|
||||
// * ParseResult parse${C++ class of Op}(OpAsmParser &parser,
|
||||
// OperationState &result)
|
||||
|
2
third_party/mlir/include/mlir/IR/Function.h
vendored
2
third_party/mlir/include/mlir/IR/Function.h
vendored
@ -60,7 +60,7 @@ public:
|
||||
|
||||
/// Operation hooks.
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
void print(OpAsmPrinter *p);
|
||||
void print(OpAsmPrinter &p);
|
||||
LogicalResult verify();
|
||||
|
||||
/// Returns the type of this function.
|
||||
|
@ -85,7 +85,7 @@ ParseResult parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
|
||||
|
||||
/// Printer implementation for function-like operations. Accepts lists of
|
||||
/// argument and result types to use while printing.
|
||||
void printFunctionLikeOp(OpAsmPrinter *p, Operation *op,
|
||||
void printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
|
||||
ArrayRef<Type> argTypes, bool isVariadic,
|
||||
ArrayRef<Type> results);
|
||||
|
||||
|
2
third_party/mlir/include/mlir/IR/Module.h
vendored
2
third_party/mlir/include/mlir/IR/Module.h
vendored
@ -53,7 +53,7 @@ public:
|
||||
|
||||
/// Operation hooks.
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
void print(OpAsmPrinter *p);
|
||||
void print(OpAsmPrinter &p);
|
||||
LogicalResult verify();
|
||||
|
||||
/// Return body of this module.
|
||||
|
@ -213,7 +213,7 @@ protected:
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
|
||||
// The fallback for the printer is to print it the generic assembly form.
|
||||
void print(OpAsmPrinter *p);
|
||||
void print(OpAsmPrinter &p);
|
||||
|
||||
/// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
|
||||
/// so we can cast it away here.
|
||||
@ -941,7 +941,7 @@ public:
|
||||
|
||||
/// This is the hook used by the AsmPrinter to emit this to the .mlir file.
|
||||
/// Op implementations should provide a print method.
|
||||
static void printAssembly(Operation *op, OpAsmPrinter *p) {
|
||||
static void printAssembly(Operation *op, OpAsmPrinter &p) {
|
||||
auto opPointer = dyn_cast<ConcreteType>(op);
|
||||
assert(opPointer &&
|
||||
"op's name does not match name of concrete type instantiated with");
|
||||
@ -1149,7 +1149,7 @@ ParseResult parseBinaryOp(OpAsmParser &parser, OperationState &result);
|
||||
// Prints the given binary `op` in custom assembly form if both the two operands
|
||||
// and the result have the same time. Otherwise, prints the generic assembly
|
||||
// form.
|
||||
void printBinaryOp(Operation *op, OpAsmPrinter *p);
|
||||
void printBinaryOp(Operation *op, OpAsmPrinter &p);
|
||||
} // namespace impl
|
||||
|
||||
// These functions are out-of-line implementations of the methods in CastOp,
|
||||
@ -1158,7 +1158,7 @@ namespace impl {
|
||||
void buildCastOp(Builder *builder, OperationState &result, Value *source,
|
||||
Type destType);
|
||||
ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
|
||||
void printCastOp(Operation *op, OpAsmPrinter *p);
|
||||
void printCastOp(Operation *op, OpAsmPrinter &p);
|
||||
Value *foldCastOp(Operation *op);
|
||||
} // namespace impl
|
||||
} // end namespace mlir
|
||||
|
@ -101,7 +101,7 @@ public:
|
||||
ParseResult (&parseAssembly)(OpAsmParser &parser, OperationState &result);
|
||||
|
||||
/// This hook implements the AsmPrinter for this operation.
|
||||
void (&printAssembly)(Operation *op, OpAsmPrinter *p);
|
||||
void (&printAssembly)(Operation *op, OpAsmPrinter &p);
|
||||
|
||||
/// This hook implements the verifier for this operation. It should emits an
|
||||
/// error message and returns failure if a problem is detected, or returns
|
||||
@ -172,7 +172,7 @@ private:
|
||||
StringRef name, Dialect &dialect, OperationProperties opProperties,
|
||||
bool (&classof)(Operation *op),
|
||||
ParseResult (&parseAssembly)(OpAsmParser &parser, OperationState &result),
|
||||
void (&printAssembly)(Operation *op, OpAsmPrinter *p),
|
||||
void (&printAssembly)(Operation *op, OpAsmPrinter &p),
|
||||
LogicalResult (&verifyInvariants)(Operation *op),
|
||||
LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<OpFoldResult> &results),
|
||||
|
126
third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp
vendored
126
third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp
vendored
@ -208,11 +208,11 @@ ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void AffineApplyOp::print(OpAsmPrinter *p) {
|
||||
*p << "affine.apply " << getAttr("map");
|
||||
void AffineApplyOp::print(OpAsmPrinter &p) {
|
||||
p << "affine.apply " << getAttr("map");
|
||||
printDimAndSymbolList(operand_begin(), operand_end(),
|
||||
getAffineMap().getNumDims(), p);
|
||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"map"});
|
||||
p.printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"map"});
|
||||
}
|
||||
|
||||
LogicalResult AffineApplyOp::verify() {
|
||||
@ -816,23 +816,23 @@ void AffineDmaStartOp::build(Builder *builder, OperationState &result,
|
||||
}
|
||||
}
|
||||
|
||||
void AffineDmaStartOp::print(OpAsmPrinter *p) {
|
||||
*p << "affine.dma_start " << *getSrcMemRef() << '[';
|
||||
void AffineDmaStartOp::print(OpAsmPrinter &p) {
|
||||
p << "affine.dma_start " << *getSrcMemRef() << '[';
|
||||
SmallVector<Value *, 8> operands(getSrcIndices());
|
||||
p->printAffineMapOfSSAIds(getSrcMapAttr(), operands);
|
||||
*p << "], " << *getDstMemRef() << '[';
|
||||
p.printAffineMapOfSSAIds(getSrcMapAttr(), operands);
|
||||
p << "], " << *getDstMemRef() << '[';
|
||||
operands.assign(getDstIndices().begin(), getDstIndices().end());
|
||||
p->printAffineMapOfSSAIds(getDstMapAttr(), operands);
|
||||
*p << "], " << *getTagMemRef() << '[';
|
||||
p.printAffineMapOfSSAIds(getDstMapAttr(), operands);
|
||||
p << "], " << *getTagMemRef() << '[';
|
||||
operands.assign(getTagIndices().begin(), getTagIndices().end());
|
||||
p->printAffineMapOfSSAIds(getTagMapAttr(), operands);
|
||||
*p << "], " << *getNumElements();
|
||||
p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
|
||||
p << "], " << *getNumElements();
|
||||
if (isStrided()) {
|
||||
*p << ", " << *getStride();
|
||||
*p << ", " << *getNumElementsPerStride();
|
||||
p << ", " << *getStride();
|
||||
p << ", " << *getNumElementsPerStride();
|
||||
}
|
||||
*p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
|
||||
<< getTagMemRefType();
|
||||
p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
|
||||
<< getTagMemRefType();
|
||||
}
|
||||
|
||||
// Parse AffineDmaStartOp.
|
||||
@ -975,13 +975,13 @@ void AffineDmaWaitOp::build(Builder *builder, OperationState &result,
|
||||
result.addOperands(numElements);
|
||||
}
|
||||
|
||||
void AffineDmaWaitOp::print(OpAsmPrinter *p) {
|
||||
*p << "affine.dma_wait " << *getTagMemRef() << '[';
|
||||
void AffineDmaWaitOp::print(OpAsmPrinter &p) {
|
||||
p << "affine.dma_wait " << *getTagMemRef() << '[';
|
||||
SmallVector<Value *, 2> operands(getTagIndices());
|
||||
p->printAffineMapOfSSAIds(getTagMapAttr(), operands);
|
||||
*p << "], ";
|
||||
p->printOperand(getNumElements());
|
||||
*p << " : " << getTagMemRef()->getType();
|
||||
p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
|
||||
p << "], ";
|
||||
p.printOperand(getNumElements());
|
||||
p << " : " << getTagMemRef()->getType();
|
||||
}
|
||||
|
||||
// Parse AffineDmaWaitOp.
|
||||
@ -1258,7 +1258,7 @@ ParseResult parseAffineForOp(OpAsmParser &parser, OperationState &result) {
|
||||
|
||||
static void printBound(AffineMapAttr boundMap,
|
||||
Operation::operand_range boundOperands,
|
||||
const char *prefix, OpAsmPrinter *p) {
|
||||
const char *prefix, OpAsmPrinter &p) {
|
||||
AffineMap map = boundMap.getValue();
|
||||
|
||||
// Check if this bound should be printed using custom assembly form.
|
||||
@ -1273,7 +1273,7 @@ static void printBound(AffineMapAttr boundMap,
|
||||
// Print constant bound.
|
||||
if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
|
||||
if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
|
||||
*p << constExpr.getValue();
|
||||
p << constExpr.getValue();
|
||||
return;
|
||||
}
|
||||
}
|
||||
@ -1282,38 +1282,38 @@ static void printBound(AffineMapAttr boundMap,
|
||||
// single symbol.
|
||||
if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
|
||||
if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
|
||||
p->printOperand(*boundOperands.begin());
|
||||
p.printOperand(*boundOperands.begin());
|
||||
return;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Map has multiple results. Print 'min' or 'max' prefix.
|
||||
*p << prefix << ' ';
|
||||
p << prefix << ' ';
|
||||
}
|
||||
|
||||
// Print the map and its operands.
|
||||
*p << boundMap;
|
||||
p << boundMap;
|
||||
printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
|
||||
map.getNumDims(), p);
|
||||
}
|
||||
|
||||
void print(OpAsmPrinter *p, AffineForOp op) {
|
||||
*p << "affine.for ";
|
||||
p->printOperand(op.getBody()->getArgument(0));
|
||||
*p << " = ";
|
||||
void print(OpAsmPrinter &p, AffineForOp op) {
|
||||
p << "affine.for ";
|
||||
p.printOperand(op.getBody()->getArgument(0));
|
||||
p << " = ";
|
||||
printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p);
|
||||
*p << " to ";
|
||||
p << " to ";
|
||||
printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p);
|
||||
|
||||
if (op.getStep() != 1)
|
||||
*p << " step " << op.getStep();
|
||||
p->printRegion(op.region(),
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
p->printOptionalAttrDict(op.getAttrs(),
|
||||
/*elidedAttrs=*/{op.getLowerBoundAttrName(),
|
||||
op.getUpperBoundAttrName(),
|
||||
op.getStepAttrName()});
|
||||
p << " step " << op.getStep();
|
||||
p.printRegion(op.region(),
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
p.printOptionalAttrDict(op.getAttrs(),
|
||||
/*elidedAttrs=*/{op.getLowerBoundAttrName(),
|
||||
op.getUpperBoundAttrName(),
|
||||
op.getStepAttrName()});
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -1636,28 +1636,28 @@ ParseResult parseAffineIfOp(OpAsmParser &parser, OperationState &result) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void print(OpAsmPrinter *p, AffineIfOp op) {
|
||||
void print(OpAsmPrinter &p, AffineIfOp op) {
|
||||
auto conditionAttr =
|
||||
op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
|
||||
*p << "affine.if " << conditionAttr;
|
||||
p << "affine.if " << conditionAttr;
|
||||
printDimAndSymbolList(op.operand_begin(), op.operand_end(),
|
||||
conditionAttr.getValue().getNumDims(), p);
|
||||
p->printRegion(op.thenRegion(),
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
p.printRegion(op.thenRegion(),
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
|
||||
// Print the 'else' regions if it has any blocks.
|
||||
auto &elseRegion = op.elseRegion();
|
||||
if (!elseRegion.empty()) {
|
||||
*p << " else";
|
||||
p->printRegion(elseRegion,
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
p << " else";
|
||||
p.printRegion(elseRegion,
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
}
|
||||
|
||||
// Print the attribute list.
|
||||
p->printOptionalAttrDict(op.getAttrs(),
|
||||
/*elidedAttrs=*/op.getConditionAttrName());
|
||||
p.printOptionalAttrDict(op.getAttrs(),
|
||||
/*elidedAttrs=*/op.getConditionAttrName());
|
||||
}
|
||||
|
||||
IntegerSet AffineIfOp::getIntegerSet() {
|
||||
@ -1771,16 +1771,16 @@ ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
parser.addTypeToList(type.getElementType(), result.types));
|
||||
}
|
||||
|
||||
void AffineLoadOp::print(OpAsmPrinter *p) {
|
||||
*p << "affine.load " << *getMemRef() << '[';
|
||||
void AffineLoadOp::print(OpAsmPrinter &p) {
|
||||
p << "affine.load " << *getMemRef() << '[';
|
||||
AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
|
||||
if (mapAttr) {
|
||||
SmallVector<Value *, 2> operands(getMapOperands());
|
||||
p->printAffineMapOfSSAIds(mapAttr, operands);
|
||||
p.printAffineMapOfSSAIds(mapAttr, operands);
|
||||
}
|
||||
*p << ']';
|
||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()});
|
||||
*p << " : " << getMemRefType();
|
||||
p << ']';
|
||||
p.printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()});
|
||||
p << " : " << getMemRefType();
|
||||
}
|
||||
|
||||
LogicalResult AffineLoadOp::verify() {
|
||||
@ -1865,17 +1865,17 @@ ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
parser.resolveOperands(mapOperands, affineIntTy, result.operands));
|
||||
}
|
||||
|
||||
void AffineStoreOp::print(OpAsmPrinter *p) {
|
||||
*p << "affine.store " << *getValueToStore();
|
||||
*p << ", " << *getMemRef() << '[';
|
||||
void AffineStoreOp::print(OpAsmPrinter &p) {
|
||||
p << "affine.store " << *getValueToStore();
|
||||
p << ", " << *getMemRef() << '[';
|
||||
AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
|
||||
if (mapAttr) {
|
||||
SmallVector<Value *, 2> operands(getMapOperands());
|
||||
p->printAffineMapOfSSAIds(mapAttr, operands);
|
||||
p.printAffineMapOfSSAIds(mapAttr, operands);
|
||||
}
|
||||
*p << ']';
|
||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()});
|
||||
*p << " : " << getMemRefType();
|
||||
p << ']';
|
||||
p.printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()});
|
||||
p << " : " << getMemRefType();
|
||||
}
|
||||
|
||||
LogicalResult AffineStoreOp::verify() {
|
||||
|
@ -169,22 +169,22 @@ LogicalResult LaunchOp::verify() {
|
||||
// (%iter-x, %iter-y, %iter-z) in
|
||||
// (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
|
||||
// where %size-* and %iter-* will correspond to the body region arguments.
|
||||
static void printSizeAssignment(OpAsmPrinter *p, KernelDim3 size,
|
||||
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
|
||||
ArrayRef<Value *> operands, KernelDim3 ids) {
|
||||
*p << '(' << *ids.x << ", " << *ids.y << ", " << *ids.z << ") in (";
|
||||
*p << *size.x << " = " << *operands[0] << ", ";
|
||||
*p << *size.y << " = " << *operands[1] << ", ";
|
||||
*p << *size.z << " = " << *operands[2] << ')';
|
||||
p << '(' << *ids.x << ", " << *ids.y << ", " << *ids.z << ") in (";
|
||||
p << *size.x << " = " << *operands[0] << ", ";
|
||||
p << *size.y << " = " << *operands[1] << ", ";
|
||||
p << *size.z << " = " << *operands[2] << ')';
|
||||
}
|
||||
|
||||
void LaunchOp::print(OpAsmPrinter *p) {
|
||||
void LaunchOp::print(OpAsmPrinter &p) {
|
||||
SmallVector<Value *, 12> operandContainer(operand_begin(), operand_end());
|
||||
ArrayRef<Value *> operands(operandContainer);
|
||||
|
||||
// Print the launch configuration.
|
||||
*p << getOperationName() << ' ' << getBlocksKeyword();
|
||||
p << getOperationName() << ' ' << getBlocksKeyword();
|
||||
printSizeAssignment(p, getGridSize(), operands.take_front(3), getBlockIds());
|
||||
*p << ' ' << getThreadsKeyword();
|
||||
p << ' ' << getThreadsKeyword();
|
||||
printSizeAssignment(p, getBlockSize(), operands.slice(3, 3), getThreadIds());
|
||||
|
||||
// From now on, the first kNumConfigOperands operands corresponding to grid
|
||||
@ -193,28 +193,28 @@ void LaunchOp::print(OpAsmPrinter *p) {
|
||||
|
||||
// Print the data argument remapping.
|
||||
if (!getBody().empty() && !operands.empty()) {
|
||||
*p << ' ' << getArgsKeyword() << '(';
|
||||
p << ' ' << getArgsKeyword() << '(';
|
||||
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
|
||||
if (i != 0)
|
||||
*p << ", ";
|
||||
*p << *getBody().front().getArgument(kNumConfigRegionAttributes + i)
|
||||
<< " = " << *operands[i];
|
||||
p << ", ";
|
||||
p << *getBody().front().getArgument(kNumConfigRegionAttributes + i)
|
||||
<< " = " << *operands[i];
|
||||
}
|
||||
*p << ") ";
|
||||
p << ") ";
|
||||
}
|
||||
|
||||
// Print the types of data arguments.
|
||||
if (!operands.empty()) {
|
||||
*p << ": ";
|
||||
p << ": ";
|
||||
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
|
||||
if (i != 0)
|
||||
*p << ", ";
|
||||
*p << operands[i]->getType();
|
||||
p << ", ";
|
||||
p << operands[i]->getType();
|
||||
}
|
||||
}
|
||||
|
||||
p->printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
p->printOptionalAttrDict(getAttrs());
|
||||
p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
p.printOptionalAttrDict(getAttrs());
|
||||
}
|
||||
|
||||
// Parse the size assignment blocks for blocks and threads. These have the form
|
||||
|
@ -40,18 +40,18 @@ using namespace mlir::LLVM;
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printing/parsing for LLVM::CmpOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
static void printICmpOp(OpAsmPrinter *p, ICmpOp &op) {
|
||||
*p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate())
|
||||
<< "\" " << *op.getOperand(0) << ", " << *op.getOperand(1);
|
||||
p->printOptionalAttrDict(op.getAttrs(), {"predicate"});
|
||||
*p << " : " << op.lhs()->getType();
|
||||
static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) {
|
||||
p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate())
|
||||
<< "\" " << *op.getOperand(0) << ", " << *op.getOperand(1);
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"predicate"});
|
||||
p << " : " << op.lhs()->getType();
|
||||
}
|
||||
|
||||
static void printFCmpOp(OpAsmPrinter *p, FCmpOp &op) {
|
||||
*p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate())
|
||||
<< "\" " << *op.getOperand(0) << ", " << *op.getOperand(1);
|
||||
p->printOptionalAttrDict(op.getAttrs(), {"predicate"});
|
||||
*p << " : " << op.lhs()->getType();
|
||||
static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) {
|
||||
p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate())
|
||||
<< "\" " << *op.getOperand(0) << ", " << *op.getOperand(1);
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"predicate"});
|
||||
p << " : " << op.lhs()->getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
|
||||
@ -124,18 +124,18 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
|
||||
// Printing/parsing for LLVM::AllocaOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printAllocaOp(OpAsmPrinter *p, AllocaOp &op) {
|
||||
static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
|
||||
auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
|
||||
|
||||
auto funcTy = FunctionType::get({op.arraySize()->getType()}, {op.getType()},
|
||||
op.getContext());
|
||||
|
||||
*p << op.getOperationName() << ' ' << *op.arraySize() << " x " << elemTy;
|
||||
p << op.getOperationName() << ' ' << *op.arraySize() << " x " << elemTy;
|
||||
if (op.alignment().hasValue() && op.alignment()->getSExtValue() != 0)
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
else
|
||||
p->printOptionalAttrDict(op.getAttrs(), {"alignment"});
|
||||
*p << " : " << funcTy;
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"alignment"});
|
||||
p << " : " << funcTy;
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict?
|
||||
@ -171,15 +171,15 @@ static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
|
||||
// Printing/parsing for LLVM::GEPOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printGEPOp(OpAsmPrinter *p, GEPOp &op) {
|
||||
static void printGEPOp(OpAsmPrinter &p, GEPOp &op) {
|
||||
SmallVector<Type, 8> types(op.getOperandTypes());
|
||||
auto funcTy = FunctionType::get(types, op.getType(), op.getContext());
|
||||
|
||||
*p << op.getOperationName() << ' ' << *op.base() << '[';
|
||||
p->printOperands(std::next(op.operand_begin()), op.operand_end());
|
||||
*p << ']';
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << funcTy;
|
||||
p << op.getOperationName() << ' ' << *op.base() << '[';
|
||||
p.printOperands(std::next(op.operand_begin()), op.operand_end());
|
||||
p << ']';
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << funcTy;
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.getelementptr` ssa-use `[` ssa-use-list `]`
|
||||
@ -219,10 +219,10 @@ static ParseResult parseGEPOp(OpAsmParser &parser, OperationState &result) {
|
||||
// Printing/parsing for LLVM::LoadOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printLoadOp(OpAsmPrinter *p, LoadOp &op) {
|
||||
*p << op.getOperationName() << ' ' << *op.addr();
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.addr()->getType();
|
||||
static void printLoadOp(OpAsmPrinter &p, LoadOp &op) {
|
||||
p << op.getOperationName() << ' ' << *op.addr();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.addr()->getType();
|
||||
}
|
||||
|
||||
// Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return
|
||||
@ -263,10 +263,10 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
|
||||
// Printing/parsing for LLVM::StoreOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printStoreOp(OpAsmPrinter *p, StoreOp &op) {
|
||||
*p << op.getOperationName() << ' ' << *op.value() << ", " << *op.addr();
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.addr()->getType();
|
||||
static void printStoreOp(OpAsmPrinter &p, StoreOp &op) {
|
||||
p << op.getOperationName() << ' ' << *op.value() << ", " << *op.addr();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.addr()->getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type
|
||||
@ -298,30 +298,30 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
|
||||
// Printing/parsing for LLVM::CallOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printCallOp(OpAsmPrinter *p, CallOp &op) {
|
||||
static void printCallOp(OpAsmPrinter &p, CallOp &op) {
|
||||
auto callee = op.callee();
|
||||
bool isDirect = callee.hasValue();
|
||||
|
||||
// Print the direct callee if present as a function attribute, or an indirect
|
||||
// callee (first operand) otherwise.
|
||||
*p << op.getOperationName() << ' ';
|
||||
p << op.getOperationName() << ' ';
|
||||
if (isDirect)
|
||||
*p << '@' << callee.getValue();
|
||||
p << '@' << callee.getValue();
|
||||
else
|
||||
*p << *op.getOperand(0);
|
||||
p << *op.getOperand(0);
|
||||
|
||||
*p << '(';
|
||||
p->printOperands(llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1));
|
||||
*p << ')';
|
||||
p << '(';
|
||||
p.printOperands(llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1));
|
||||
p << ')';
|
||||
|
||||
p->printOptionalAttrDict(op.getAttrs(), {"callee"});
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"callee"});
|
||||
|
||||
// Reconstruct the function MLIR function type from operand and result types.
|
||||
SmallVector<Type, 1> resultTypes(op.getResultTypes());
|
||||
SmallVector<Type, 8> argTypes(
|
||||
llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
|
||||
|
||||
*p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
|
||||
p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
|
||||
@ -425,10 +425,10 @@ void LLVM::ExtractElementOp::build(Builder *b, OperationState &result,
|
||||
result.addAttributes(attrs);
|
||||
}
|
||||
|
||||
static void printExtractElementOp(OpAsmPrinter *p, ExtractElementOp &op) {
|
||||
*p << op.getOperationName() << ' ' << *op.vector() << ", " << *op.position();
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.vector()->getType();
|
||||
static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) {
|
||||
p << op.getOperationName() << ' ' << *op.vector() << ", " << *op.position();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.vector()->getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use
|
||||
@ -461,10 +461,10 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
|
||||
// Printing/parsing for LLVM::ExtractValueOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printExtractValueOp(OpAsmPrinter *p, ExtractValueOp &op) {
|
||||
*p << op.getOperationName() << ' ' << *op.container() << op.position();
|
||||
p->printOptionalAttrDict(op.getAttrs(), {"position"});
|
||||
*p << " : " << op.container()->getType();
|
||||
static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) {
|
||||
p << op.getOperationName() << ' ' << *op.container() << op.position();
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"position"});
|
||||
p << " : " << op.container()->getType();
|
||||
}
|
||||
|
||||
// Extract the type at `position` in the wrapped LLVM IR aggregate type
|
||||
@ -553,11 +553,11 @@ static ParseResult parseExtractValueOp(OpAsmParser &parser,
|
||||
// Printing/parsing for LLVM::InsertElementOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printInsertElementOp(OpAsmPrinter *p, InsertElementOp &op) {
|
||||
*p << op.getOperationName() << ' ' << *op.vector() << ", " << *op.value()
|
||||
<< ", " << *op.position();
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.vector()->getType();
|
||||
static void printInsertElementOp(OpAsmPrinter &p, InsertElementOp &op) {
|
||||
p << op.getOperationName() << ' ' << *op.vector() << ", " << *op.value()
|
||||
<< ", " << *op.position();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.vector()->getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use
|
||||
@ -599,11 +599,11 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser,
|
||||
// Printing/parsing for LLVM::InsertValueOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printInsertValueOp(OpAsmPrinter *p, InsertValueOp &op) {
|
||||
*p << op.getOperationName() << ' ' << *op.value() << ", " << *op.container()
|
||||
<< op.position();
|
||||
p->printOptionalAttrDict(op.getAttrs(), {"position"});
|
||||
*p << " : " << op.container()->getType();
|
||||
static void printInsertValueOp(OpAsmPrinter &p, InsertValueOp &op) {
|
||||
p << op.getOperationName() << ' ' << *op.value() << ", " << *op.container()
|
||||
<< op.position();
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"position"});
|
||||
p << " : " << op.container()->getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use
|
||||
@ -642,11 +642,11 @@ static ParseResult parseInsertValueOp(OpAsmParser &parser,
|
||||
// Printing/parsing for LLVM::SelectOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printSelectOp(OpAsmPrinter *p, SelectOp &op) {
|
||||
*p << op.getOperationName() << ' ' << *op.condition() << ", "
|
||||
<< *op.trueValue() << ", " << *op.falseValue();
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.condition()->getType() << ", " << op.trueValue()->getType();
|
||||
static void printSelectOp(OpAsmPrinter &p, SelectOp &op) {
|
||||
p << op.getOperationName() << ' ' << *op.condition() << ", "
|
||||
<< *op.trueValue() << ", " << *op.falseValue();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.condition()->getType() << ", " << op.trueValue()->getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.select` ssa-use `,` ssa-use `,` ssa-use
|
||||
@ -676,10 +676,10 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
|
||||
// Printing/parsing for LLVM::BrOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printBrOp(OpAsmPrinter *p, BrOp &op) {
|
||||
*p << op.getOperationName() << ' ';
|
||||
p->printSuccessorAndUseList(op.getOperation(), 0);
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
static void printBrOp(OpAsmPrinter &p, BrOp &op) {
|
||||
p << op.getOperationName() << ' ';
|
||||
p.printSuccessorAndUseList(op.getOperation(), 0);
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.br` bb-id (`[` ssa-use-and-type-list `]`)?
|
||||
@ -699,12 +699,12 @@ static ParseResult parseBrOp(OpAsmParser &parser, OperationState &result) {
|
||||
// Printing/parsing for LLVM::CondBrOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printCondBrOp(OpAsmPrinter *p, CondBrOp &op) {
|
||||
*p << op.getOperationName() << ' ' << *op.getOperand(0) << ", ";
|
||||
p->printSuccessorAndUseList(op.getOperation(), 0);
|
||||
*p << ", ";
|
||||
p->printSuccessorAndUseList(op.getOperation(), 1);
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
static void printCondBrOp(OpAsmPrinter &p, CondBrOp &op) {
|
||||
p << op.getOperationName() << ' ' << *op.getOperand(0) << ", ";
|
||||
p.printSuccessorAndUseList(op.getOperation(), 0);
|
||||
p << ", ";
|
||||
p.printSuccessorAndUseList(op.getOperation(), 1);
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.cond_br` ssa-use `,`
|
||||
@ -739,15 +739,15 @@ static ParseResult parseCondBrOp(OpAsmParser &parser, OperationState &result) {
|
||||
// Printing/parsing for LLVM::ReturnOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printReturnOp(OpAsmPrinter *p, ReturnOp &op) {
|
||||
*p << op.getOperationName();
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
static void printReturnOp(OpAsmPrinter &p, ReturnOp &op) {
|
||||
p << op.getOperationName();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
assert(op.getNumOperands() <= 1);
|
||||
|
||||
if (op.getNumOperands() == 0)
|
||||
return;
|
||||
|
||||
*p << ' ' << *op.getOperand(0) << " : " << op.getOperand(0)->getType();
|
||||
p << ' ' << *op.getOperand(0) << " : " << op.getOperand(0)->getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:`
|
||||
@ -772,10 +772,10 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
|
||||
// Printing/parsing for LLVM::UndefOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printUndefOp(OpAsmPrinter *p, UndefOp &op) {
|
||||
*p << op.getOperationName();
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.res()->getType();
|
||||
static void printUndefOp(OpAsmPrinter &p, UndefOp &op) {
|
||||
p << op.getOperationName();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.res()->getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.mlir.undef` attribute-dict? : type
|
||||
@ -800,10 +800,10 @@ GlobalOp AddressOfOp::getGlobal() {
|
||||
return module.lookupSymbol<LLVM::GlobalOp>(global_name());
|
||||
}
|
||||
|
||||
static void printAddressOfOp(OpAsmPrinter *p, AddressOfOp op) {
|
||||
*p << op.getOperationName() << " @" << op.global_name();
|
||||
p->printOptionalAttrDict(op.getAttrs(), {"global_name"});
|
||||
*p << " : " << op.getResult()->getType();
|
||||
static void printAddressOfOp(OpAsmPrinter &p, AddressOfOp op) {
|
||||
p << op.getOperationName() << " @" << op.global_name();
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"global_name"});
|
||||
p << " : " << op.getResult()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseAddressOfOp(OpAsmParser &parser,
|
||||
@ -837,10 +837,10 @@ static LogicalResult verify(AddressOfOp op) {
|
||||
// Printing/parsing for LLVM::ConstantOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printConstantOp(OpAsmPrinter *p, ConstantOp &op) {
|
||||
*p << op.getOperationName() << '(' << op.value() << ')';
|
||||
p->printOptionalAttrDict(op.getAttrs(), {"value"});
|
||||
*p << " : " << op.res()->getType();
|
||||
static void printConstantOp(OpAsmPrinter &p, ConstantOp &op) {
|
||||
p << op.getOperationName() << '(' << op.value() << ')';
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"value"});
|
||||
p << " : " << op.res()->getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.mlir.constant` `(` attribute `)` attribute-list? : type
|
||||
@ -876,21 +876,21 @@ void GlobalOp::build(Builder *builder, OperationState &result, LLVMType type,
|
||||
result.attributes.append(attrs.begin(), attrs.end());
|
||||
}
|
||||
|
||||
static void printGlobalOp(OpAsmPrinter *p, GlobalOp op) {
|
||||
*p << op.getOperationName() << ' ';
|
||||
static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
|
||||
p << op.getOperationName() << ' ';
|
||||
if (op.constant())
|
||||
*p << "constant ";
|
||||
*p << '@' << op.sym_name() << '(';
|
||||
p->printAttribute(op.value());
|
||||
*p << ')';
|
||||
p->printOptionalAttrDict(op.getAttrs(), {SymbolTable::getSymbolAttrName(),
|
||||
"type", "constant", "value"});
|
||||
p << "constant ";
|
||||
p << '@' << op.sym_name() << '(';
|
||||
p.printAttribute(op.value());
|
||||
p << ')';
|
||||
p.printOptionalAttrDict(op.getAttrs(), {SymbolTable::getSymbolAttrName(),
|
||||
"type", "constant", "value"});
|
||||
|
||||
// Print the trailing type unless it's a string global.
|
||||
if (op.value().isa<StringAttr>())
|
||||
return;
|
||||
*p << " : ";
|
||||
p->printType(op.type());
|
||||
p << " : ";
|
||||
p.printType(op.type());
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.mlir.global` `constant`? `@` identifier
|
||||
@ -967,11 +967,11 @@ void LLVM::ShuffleVectorOp::build(Builder *b, OperationState &result, Value *v1,
|
||||
result.addAttributes(attrs);
|
||||
}
|
||||
|
||||
static void printShuffleVectorOp(OpAsmPrinter *p, ShuffleVectorOp &op) {
|
||||
*p << op.getOperationName() << ' ' << *op.v1() << ", " << *op.v2() << " "
|
||||
<< op.mask();
|
||||
p->printOptionalAttrDict(op.getAttrs(), {"mask"});
|
||||
*p << " : " << op.v1()->getType() << ", " << op.v2()->getType();
|
||||
static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) {
|
||||
p << op.getOperationName() << ' ' << *op.v1() << ", " << *op.v2() << " "
|
||||
<< op.mask();
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"mask"});
|
||||
p << " : " << op.v1()->getType() << ", " << op.v2()->getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use
|
||||
@ -1072,7 +1072,7 @@ static Type buildLLVMFunctionType(Builder &b, ArrayRef<Type> inputs,
|
||||
|
||||
// Print the LLVMFuncOp. Collects argument and result types and passes them
|
||||
// to the trait printer. Drops "void" result since it cannot be parsed back.
|
||||
static void printLLVMFuncOp(OpAsmPrinter *p, LLVMFuncOp op) {
|
||||
static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
|
||||
LLVMType fnType = op.getType();
|
||||
SmallVector<Type, 8> argTypes;
|
||||
SmallVector<Type, 1> resTypes;
|
||||
|
@ -43,11 +43,11 @@ namespace NVVM {
|
||||
// Printing/parsing for NVVM ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printNVVMIntrinsicOp(OpAsmPrinter *p, Operation *op) {
|
||||
*p << op->getName() << " ";
|
||||
p->printOperands(op->getOperands());
|
||||
static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
|
||||
p << op->getName() << " ";
|
||||
p.printOperands(op->getOperands());
|
||||
if (op->getNumResults() > 0)
|
||||
interleaveComma(op->getResultTypes(), *p << " : ");
|
||||
interleaveComma(op->getResultTypes(), p << " : ");
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.nvvm.XYZ` : type
|
||||
|
172
third_party/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
vendored
172
third_party/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
vendored
@ -128,16 +128,16 @@ SimplifyDimOp::matchAndRewrite(linalg::DimOp dimOp,
|
||||
// BufferAllocOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, BufferAllocOp op) {
|
||||
*p << op.getOperationName() << " ";
|
||||
static void print(OpAsmPrinter &p, BufferAllocOp op) {
|
||||
p << op.getOperationName() << " ";
|
||||
if (!llvm::empty(op.size()))
|
||||
*p << *op.getOperand(0);
|
||||
p << *op.getOperand(0);
|
||||
if (op.alignment().hasValue() && op.alignment()->getSExtValue() != 0)
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
else
|
||||
p->printOptionalAttrDict(op.getAttrs(),
|
||||
BufferAllocOp::getAlignmentAttrName());
|
||||
*p << " : " << op.getBufferType();
|
||||
p.printOptionalAttrDict(op.getAttrs(),
|
||||
BufferAllocOp::getAlignmentAttrName());
|
||||
p << " : " << op.getBufferType();
|
||||
}
|
||||
|
||||
static ParseResult parseBufferAllocOp(OpAsmParser &parser,
|
||||
@ -181,10 +181,10 @@ static LogicalResult verify(BufferAllocOp op) {
|
||||
// BufferDeallocOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, BufferDeallocOp op) {
|
||||
*p << op.getOperationName() << " " << *op.buffer();
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.getBufferType();
|
||||
static void print(OpAsmPrinter &p, BufferDeallocOp op) {
|
||||
p << op.getOperationName() << " " << *op.buffer();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getBufferType();
|
||||
}
|
||||
|
||||
static ParseResult parseBufferDeallocOp(OpAsmParser &parser,
|
||||
@ -202,10 +202,10 @@ static ParseResult parseBufferDeallocOp(OpAsmParser &parser,
|
||||
// BufferSizeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, BufferSizeOp op) {
|
||||
*p << op.getOperationName() << " " << *op.buffer();
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.buffer()->getType();
|
||||
static void print(OpAsmPrinter &p, BufferSizeOp op) {
|
||||
p << op.getOperationName() << " " << *op.buffer();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.buffer()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseBufferSizeOp(OpAsmParser &parser,
|
||||
@ -228,11 +228,11 @@ void mlir::linalg::DimOp::getCanonicalizationPatterns(
|
||||
results.insert<SimplifyDimOp>(context);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, linalg::DimOp op) {
|
||||
*p << op.getOperationName() << " " << *op.getOperand() << ", "
|
||||
<< op.getIndex();
|
||||
p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
|
||||
*p << " : " << op.getOperand()->getType();
|
||||
static void print(OpAsmPrinter &p, linalg::DimOp op) {
|
||||
p << op.getOperationName() << " " << *op.getOperand() << ", "
|
||||
<< op.getIndex();
|
||||
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
|
||||
p << " : " << op.getOperand()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseDimOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -253,7 +253,7 @@ static ParseResult parseDimOp(OpAsmParser &parser, OperationState &result) {
|
||||
// GenericOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, GenericOp op) {
|
||||
static void print(OpAsmPrinter &p, GenericOp op) {
|
||||
auto attrNames = op.linalgTraitAttrNames();
|
||||
llvm::StringSet<> linalgTraitAttrsSet;
|
||||
linalgTraitAttrsSet.insert(attrNames.begin(), attrNames.end());
|
||||
@ -263,13 +263,13 @@ static void print(OpAsmPrinter *p, GenericOp op) {
|
||||
attrs.push_back(attr);
|
||||
}
|
||||
auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
|
||||
*p << op.getOperationName() << " " << dictAttr << " ";
|
||||
p->printOperands(op.getOperands());
|
||||
p << op.getOperationName() << " " << dictAttr << " ";
|
||||
p.printOperands(op.getOperands());
|
||||
if (!op.region().empty())
|
||||
p->printRegion(op.region());
|
||||
p->printOptionalAttrDict(op.getAttrs(), attrNames);
|
||||
*p << ": ";
|
||||
interleaveComma(op.getOperandTypes(), *p);
|
||||
p.printRegion(op.region());
|
||||
p.printOptionalAttrDict(op.getAttrs(), attrNames);
|
||||
p << ": ";
|
||||
interleaveComma(op.getOperandTypes(), p);
|
||||
}
|
||||
|
||||
static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -393,12 +393,12 @@ static LogicalResult verify(GenericOp op) {
|
||||
// LoadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, linalg::LoadOp op) {
|
||||
*p << op.getOperationName() << " " << *op.view() << '[';
|
||||
p->printOperands(op.indices());
|
||||
*p << ']';
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.getViewType();
|
||||
static void print(OpAsmPrinter &p, linalg::LoadOp op) {
|
||||
p << op.getOperationName() << " " << *op.view() << '[';
|
||||
p.printOperands(op.indices());
|
||||
p << ']';
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getViewType();
|
||||
}
|
||||
|
||||
static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -428,11 +428,11 @@ static LogicalResult verify(linalg::LoadOp op) {
|
||||
// RangeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, RangeOp op) {
|
||||
*p << op.getOperationName() << " " << *op.min() << ":" << *op.max() << ":"
|
||||
<< *op.step();
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.getResult()->getType();
|
||||
static void print(OpAsmPrinter &p, RangeOp op) {
|
||||
p << op.getOperationName() << " " << *op.min() << ":" << *op.max() << ":"
|
||||
<< *op.step();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getResult()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseRangeOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -467,16 +467,16 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
|
||||
result.addTypes({ViewType::get(b->getContext(), elementType, rank)});
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, SliceOp op) {
|
||||
*p << SliceOp::getOperationName() << " " << *op.view() << "[";
|
||||
p->printOperands(op.indexings());
|
||||
*p << "] ";
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.getBaseViewType();
|
||||
static void print(OpAsmPrinter &p, SliceOp op) {
|
||||
p << SliceOp::getOperationName() << " " << *op.view() << "[";
|
||||
p.printOperands(op.indexings());
|
||||
p << "] ";
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getBaseViewType();
|
||||
for (auto indexing : op.indexings()) {
|
||||
*p << ", " << indexing->getType();
|
||||
p << ", " << indexing->getType();
|
||||
}
|
||||
*p << ", " << op.getType();
|
||||
p << ", " << op.getType();
|
||||
}
|
||||
|
||||
static ParseResult parseSliceOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -523,13 +523,13 @@ static LogicalResult verify(SliceOp op) {
|
||||
// StoreOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, linalg::StoreOp op) {
|
||||
*p << op.getOperationName() << " " << *op.value();
|
||||
*p << ", " << *op.view() << '[';
|
||||
p->printOperands(op.indices());
|
||||
*p << ']';
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.getViewType();
|
||||
static void print(OpAsmPrinter &p, linalg::StoreOp op) {
|
||||
p << op.getOperationName() << " " << *op.value();
|
||||
p << ", " << *op.view() << '[';
|
||||
p.printOperands(op.indices());
|
||||
p << ']';
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getViewType();
|
||||
}
|
||||
|
||||
static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -564,15 +564,15 @@ static LogicalResult verify(linalg::StoreOp op) {
|
||||
// SubViewOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, SubViewOp op) {
|
||||
*p << op.getOperationName() << " " << *op.getOperand(0) << "[";
|
||||
static void print(OpAsmPrinter &p, SubViewOp op) {
|
||||
p << op.getOperationName() << " " << *op.getOperand(0) << "[";
|
||||
auto ranges = op.getRanges();
|
||||
interleaveComma(ranges, *p, [&p](const SubViewOp::Range &i) {
|
||||
*p << *i.min << ", " << *i.max << ", " << *i.step;
|
||||
interleaveComma(ranges, p, [&p](const SubViewOp::Range &i) {
|
||||
p << *i.min << ", " << *i.max << ", " << *i.step;
|
||||
});
|
||||
*p << "]";
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.getViewType();
|
||||
p << "]";
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getViewType();
|
||||
}
|
||||
|
||||
static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -608,11 +608,11 @@ void mlir::linalg::TransposeOp::build(Builder *b, OperationState &result,
|
||||
result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, TransposeOp op) {
|
||||
*p << op.getOperationName() << " " << *op.view() << " " << op.permutation();
|
||||
p->printOptionalAttrDict(op.getAttrs(),
|
||||
{TransposeOp::getPermutationAttrName()});
|
||||
*p << " : " << op.view()->getType();
|
||||
static void print(OpAsmPrinter &p, TransposeOp op) {
|
||||
p << op.getOperationName() << " " << *op.view() << " " << op.permutation();
|
||||
p.printOptionalAttrDict(op.getAttrs(),
|
||||
{TransposeOp::getPermutationAttrName()});
|
||||
p << " : " << op.view()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseTransposeOp(OpAsmParser &parser,
|
||||
@ -645,12 +645,12 @@ void mlir::linalg::ViewOp::build(Builder *b, OperationState &result,
|
||||
result.addAttributes(attrs);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, ViewOp op) {
|
||||
*p << op.getOperationName() << " " << *op.buffer() << "[";
|
||||
interleaveComma(op.ranges(), *p, [&](Value *v) { *p << *v; });
|
||||
*p << "] ";
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.buffer()->getType() << " -> " << op.getType();
|
||||
static void print(OpAsmPrinter &p, ViewOp op) {
|
||||
p << op.getOperationName() << " " << *op.buffer() << "[";
|
||||
interleaveComma(op.ranges(), p, [&](Value *v) { p << *v; });
|
||||
p << "] ";
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.buffer()->getType() << " -> " << op.getType();
|
||||
}
|
||||
|
||||
static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -683,17 +683,17 @@ static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
|
||||
// YieldOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, YieldOp op) {
|
||||
*p << op.getOperationName();
|
||||
static void print(OpAsmPrinter &p, YieldOp op) {
|
||||
p << op.getOperationName();
|
||||
if (op.getNumOperands() > 0) {
|
||||
*p << ' ';
|
||||
p->printOperands(op.operand_begin(), op.operand_end());
|
||||
p << ' ';
|
||||
p.printOperands(op.operand_begin(), op.operand_end());
|
||||
}
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
if (op.getNumOperands() > 0) {
|
||||
*p << " : ";
|
||||
interleaveComma(op.getOperands(), *p,
|
||||
[&](Value *e) { p->printType(e->getType()); });
|
||||
p << " : ";
|
||||
interleaveComma(op.getOperands(), p,
|
||||
[&](Value *e) { p.printType(e->getType()); });
|
||||
}
|
||||
}
|
||||
|
||||
@ -752,18 +752,18 @@ static LogicalResult verify(YieldOp op) {
|
||||
// ```
|
||||
//
|
||||
// Where %0, %1 and %2 are ssa-values of type ViewType.
|
||||
static void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) {
|
||||
static void printLinalgLibraryOp(OpAsmPrinter &p, Operation *op) {
|
||||
assert(op->getAbstractOperation() && "unregistered operation");
|
||||
*p << op->getName().getStringRef() << "(";
|
||||
p << op->getName().getStringRef() << "(";
|
||||
interleave(
|
||||
op->getOperands().begin(), op->getOperands().end(),
|
||||
[&](Value *v) { *p << *v; }, [&]() { *p << ", "; });
|
||||
*p << ")";
|
||||
p->printOptionalAttrDict(op->getAttrs());
|
||||
*p << " : ";
|
||||
[&](Value *v) { p << *v; }, [&]() { p << ", "; });
|
||||
p << ")";
|
||||
p.printOptionalAttrDict(op->getAttrs());
|
||||
p << " : ";
|
||||
interleave(
|
||||
op->getOperands().begin(), op->getOperands().end(),
|
||||
[&](Value *v) { *p << v->getType(); }, [&]() { *p << ", "; });
|
||||
[&](Value *v) { p << v->getType(); }, [&]() { p << ", "; });
|
||||
}
|
||||
|
||||
static ParseResult parseLinalgLibraryOp(OpAsmParser &parser,
|
||||
|
35
third_party/mlir/lib/Dialect/LoopOps/LoopOps.cpp
vendored
35
third_party/mlir/lib/Dialect/LoopOps/LoopOps.cpp
vendored
@ -72,14 +72,13 @@ LogicalResult verify(ForOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, ForOp op) {
|
||||
*p << op.getOperationName() << " " << *op.getInductionVar() << " = "
|
||||
<< *op.lowerBound() << " to " << *op.upperBound() << " step "
|
||||
<< *op.step();
|
||||
p->printRegion(op.region(),
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
static void print(OpAsmPrinter &p, ForOp op) {
|
||||
p << op.getOperationName() << " " << *op.getInductionVar() << " = "
|
||||
<< *op.lowerBound() << " to " << *op.upperBound() << " step " << *op.step();
|
||||
p.printRegion(op.region(),
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
}
|
||||
|
||||
static ParseResult parseForOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -182,22 +181,22 @@ static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result) {
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, IfOp op) {
|
||||
*p << IfOp::getOperationName() << " " << *op.condition();
|
||||
p->printRegion(op.thenRegion(),
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
static void print(OpAsmPrinter &p, IfOp op) {
|
||||
p << IfOp::getOperationName() << " " << *op.condition();
|
||||
p.printRegion(op.thenRegion(),
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
|
||||
// Print the 'else' regions if it exists and has a block.
|
||||
auto &elseRegion = op.elseRegion();
|
||||
if (!elseRegion.empty()) {
|
||||
*p << " else";
|
||||
p->printRegion(elseRegion,
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
p << " else";
|
||||
p.printRegion(elseRegion,
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
}
|
||||
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
246
third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
vendored
246
third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
vendored
@ -162,27 +162,27 @@ static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) {
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printBinaryLogicalOp(Operation *logicalOp, OpAsmPrinter *printer) {
|
||||
*printer << logicalOp->getName() << ' ' << *logicalOp->getOperand(0) << ", "
|
||||
<< *logicalOp->getOperand(1);
|
||||
*printer << " : " << logicalOp->getOperand(0)->getType();
|
||||
static void printBinaryLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
|
||||
printer << logicalOp->getName() << ' ' << *logicalOp->getOperand(0) << ", "
|
||||
<< *logicalOp->getOperand(1);
|
||||
printer << " : " << logicalOp->getOperand(0)->getType();
|
||||
}
|
||||
|
||||
template <typename LoadStoreOpTy>
|
||||
static void
|
||||
printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter *printer,
|
||||
printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter &printer,
|
||||
SmallVectorImpl<StringRef> &elidedAttrs) {
|
||||
// Print optional memory access attribute.
|
||||
if (auto memAccess = loadStoreOp.memory_access()) {
|
||||
elidedAttrs.push_back(spirv::attributeName<spirv::MemoryAccess>());
|
||||
*printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
|
||||
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
|
||||
|
||||
// Print integer alignment attribute.
|
||||
if (auto alignment = loadStoreOp.alignment()) {
|
||||
elidedAttrs.push_back(kAlignmentAttrName);
|
||||
*printer << ", " << alignment;
|
||||
printer << ", " << alignment;
|
||||
}
|
||||
*printer << "]";
|
||||
printer << "]";
|
||||
}
|
||||
elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
|
||||
}
|
||||
@ -243,9 +243,9 @@ static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value *ptr,
|
||||
}
|
||||
|
||||
// Prints an op that has no inputs and no outputs.
|
||||
static void printNoIOOp(Operation *op, OpAsmPrinter *printer) {
|
||||
*printer << op->getName();
|
||||
printer->printOptionalAttrDict(op->getAttrs());
|
||||
static void printNoIOOp(Operation *op, OpAsmPrinter &printer) {
|
||||
printer << op->getName();
|
||||
printer.printOptionalAttrDict(op->getAttrs());
|
||||
}
|
||||
|
||||
static ParseResult parseVariableDecorations(OpAsmParser &parser,
|
||||
@ -285,7 +285,7 @@ static ParseResult parseVariableDecorations(OpAsmParser &parser,
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printVariableDecorations(Operation *op, OpAsmPrinter *printer,
|
||||
static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
|
||||
SmallVectorImpl<StringRef> &elidedAttrs) {
|
||||
// Print optional descriptor binding
|
||||
auto descriptorSetName =
|
||||
@ -297,19 +297,19 @@ static void printVariableDecorations(Operation *op, OpAsmPrinter *printer,
|
||||
if (descriptorSet && binding) {
|
||||
elidedAttrs.push_back(descriptorSetName);
|
||||
elidedAttrs.push_back(bindingName);
|
||||
*printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
|
||||
<< ")";
|
||||
printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
|
||||
<< ")";
|
||||
}
|
||||
|
||||
// Print BuiltIn attribute if present
|
||||
auto builtInName =
|
||||
convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
|
||||
if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
|
||||
*printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
|
||||
printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
|
||||
elidedAttrs.push_back(builtInName);
|
||||
}
|
||||
|
||||
printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
|
||||
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
|
||||
}
|
||||
|
||||
// Extracts an element from the given `composite` by following the given
|
||||
@ -431,11 +431,11 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser,
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::AccessChainOp op, OpAsmPrinter *printer) {
|
||||
*printer << spirv::AccessChainOp::getOperationName() << ' ' << *op.base_ptr()
|
||||
<< '[';
|
||||
printer->printOperands(op.indices());
|
||||
*printer << "] : " << op.base_ptr()->getType();
|
||||
static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
|
||||
printer << spirv::AccessChainOp::getOperationName() << ' ' << *op.base_ptr()
|
||||
<< '[';
|
||||
printer.printOperands(op.indices());
|
||||
printer << "] : " << op.base_ptr()->getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
|
||||
@ -485,15 +485,15 @@ static ParseResult parseAddressOfOp(OpAsmParser &parser,
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter *printer) {
|
||||
static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter &printer) {
|
||||
SmallVector<StringRef, 4> elidedAttrs;
|
||||
*printer << spirv::AddressOfOp::getOperationName();
|
||||
printer << spirv::AddressOfOp::getOperationName();
|
||||
|
||||
// Print symbol name.
|
||||
*printer << " @" << addressOfOp.variable();
|
||||
printer << " @" << addressOfOp.variable();
|
||||
|
||||
// Print the type.
|
||||
*printer << " : " << addressOfOp.pointer()->getType();
|
||||
printer << " : " << addressOfOp.pointer()->getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
|
||||
@ -523,9 +523,9 @@ static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &state) {
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::BranchOp branchOp, OpAsmPrinter *printer) {
|
||||
*printer << spirv::BranchOp::getOperationName() << ' ';
|
||||
printer->printSuccessorAndUseList(branchOp.getOperation(), /*index=*/0);
|
||||
static void print(spirv::BranchOp branchOp, OpAsmPrinter &printer) {
|
||||
printer << spirv::BranchOp::getOperationName() << ' ';
|
||||
printer.printSuccessorAndUseList(branchOp.getOperation(), /*index=*/0);
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::BranchOp branchOp) {
|
||||
@ -585,24 +585,24 @@ static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter *printer) {
|
||||
*printer << spirv::BranchConditionalOp::getOperationName() << ' ';
|
||||
printer->printOperand(branchOp.condition());
|
||||
static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
|
||||
printer << spirv::BranchConditionalOp::getOperationName() << ' ';
|
||||
printer.printOperand(branchOp.condition());
|
||||
|
||||
if (auto weights = branchOp.branch_weights()) {
|
||||
*printer << " [";
|
||||
interleaveComma(weights->getValue(), *printer, [&](Attribute a) {
|
||||
*printer << a.cast<IntegerAttr>().getInt();
|
||||
printer << " [";
|
||||
interleaveComma(weights->getValue(), printer, [&](Attribute a) {
|
||||
printer << a.cast<IntegerAttr>().getInt();
|
||||
});
|
||||
*printer << "]";
|
||||
printer << "]";
|
||||
}
|
||||
|
||||
*printer << ", ";
|
||||
printer->printSuccessorAndUseList(branchOp.getOperation(),
|
||||
spirv::BranchConditionalOp::kTrueIndex);
|
||||
*printer << ", ";
|
||||
printer->printSuccessorAndUseList(branchOp.getOperation(),
|
||||
spirv::BranchConditionalOp::kFalseIndex);
|
||||
printer << ", ";
|
||||
printer.printSuccessorAndUseList(branchOp.getOperation(),
|
||||
spirv::BranchConditionalOp::kTrueIndex);
|
||||
printer << ", ";
|
||||
printer.printSuccessorAndUseList(branchOp.getOperation(),
|
||||
spirv::BranchConditionalOp::kFalseIndex);
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
|
||||
@ -684,10 +684,10 @@ static ParseResult parseCompositeExtractOp(OpAsmParser &parser,
|
||||
}
|
||||
|
||||
static void print(spirv::CompositeExtractOp compositeExtractOp,
|
||||
OpAsmPrinter *printer) {
|
||||
*printer << spirv::CompositeExtractOp::getOperationName() << ' '
|
||||
<< *compositeExtractOp.composite() << compositeExtractOp.indices()
|
||||
<< " : " << compositeExtractOp.composite()->getType();
|
||||
OpAsmPrinter &printer) {
|
||||
printer << spirv::CompositeExtractOp::getOperationName() << ' '
|
||||
<< *compositeExtractOp.composite() << compositeExtractOp.indices()
|
||||
<< " : " << compositeExtractOp.composite()->getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
|
||||
@ -752,10 +752,10 @@ static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) {
|
||||
return parser.addTypeToList(type, state.types);
|
||||
}
|
||||
|
||||
static void print(spirv::ConstantOp constOp, OpAsmPrinter *printer) {
|
||||
*printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value();
|
||||
static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) {
|
||||
printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value();
|
||||
if (constOp.getType().isa<spirv::ArrayType>()) {
|
||||
*printer << " : " << constOp.getType();
|
||||
printer << " : " << constOp.getType();
|
||||
}
|
||||
}
|
||||
|
||||
@ -853,13 +853,13 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser,
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter *printer) {
|
||||
*printer << spirv::EntryPointOp::getOperationName() << " \""
|
||||
<< stringifyExecutionModel(entryPointOp.execution_model()) << "\" @"
|
||||
<< entryPointOp.fn();
|
||||
static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter &printer) {
|
||||
printer << spirv::EntryPointOp::getOperationName() << " \""
|
||||
<< stringifyExecutionModel(entryPointOp.execution_model()) << "\" @"
|
||||
<< entryPointOp.fn();
|
||||
if (auto interface = entryPointOp.interface()) {
|
||||
*printer << ", ";
|
||||
interleaveComma(interface.getValue().getValue(), *printer);
|
||||
printer << ", ";
|
||||
interleaveComma(interface.getValue().getValue(), printer);
|
||||
}
|
||||
}
|
||||
|
||||
@ -897,18 +897,18 @@ static ParseResult parseExecutionModeOp(OpAsmParser &parser,
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter *printer) {
|
||||
*printer << spirv::ExecutionModeOp::getOperationName() << " @"
|
||||
<< execModeOp.fn() << " \""
|
||||
<< stringifyExecutionMode(execModeOp.execution_mode()) << "\"";
|
||||
static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
|
||||
printer << spirv::ExecutionModeOp::getOperationName() << " @"
|
||||
<< execModeOp.fn() << " \""
|
||||
<< stringifyExecutionMode(execModeOp.execution_mode()) << "\"";
|
||||
auto values = execModeOp.values();
|
||||
if (!values) {
|
||||
return;
|
||||
}
|
||||
*printer << ", ";
|
||||
printer << ", ";
|
||||
interleaveComma(
|
||||
values.getValue().cast<ArrayAttr>(), *printer,
|
||||
[&](Attribute a) { *printer << a.cast<IntegerAttr>().getInt(); });
|
||||
values.getValue().cast<ArrayAttr>(), printer,
|
||||
[&](Attribute a) { printer << a.cast<IntegerAttr>().getInt(); });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -944,16 +944,16 @@ static ParseResult parseFunctionCallOp(OpAsmParser &parser,
|
||||
state.operands));
|
||||
}
|
||||
|
||||
static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter *printer) {
|
||||
static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter &printer) {
|
||||
SmallVector<Type, 4> argTypes(functionCallOp.getOperandTypes());
|
||||
SmallVector<Type, 1> resultTypes(functionCallOp.getResultTypes());
|
||||
Type functionType =
|
||||
FunctionType::get(argTypes, resultTypes, functionCallOp.getContext());
|
||||
|
||||
*printer << spirv::FunctionCallOp::getOperationName() << ' '
|
||||
<< functionCallOp.getAttr(kCallee) << '(';
|
||||
printer->printOperands(functionCallOp.arguments());
|
||||
*printer << ") : " << functionType;
|
||||
printer << spirv::FunctionCallOp::getOperationName() << ' '
|
||||
<< functionCallOp.getAttr(kCallee) << '(';
|
||||
printer.printOperands(functionCallOp.arguments());
|
||||
printer << ") : " << functionType;
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
|
||||
@ -1029,9 +1029,9 @@ static ParseResult parseGLSLUnaryOp(OpAsmParser &parser,
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printGLSLUnaryOp(Operation *unaryOp, OpAsmPrinter *printer) {
|
||||
*printer << unaryOp->getName() << ' ' << *unaryOp->getOperand(0) << " : "
|
||||
<< unaryOp->getOperand(0)->getType();
|
||||
static void printGLSLUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) {
|
||||
printer << unaryOp->getName() << ' ' << *unaryOp->getOperand(0) << " : "
|
||||
<< unaryOp->getOperand(0)->getType();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1074,26 +1074,26 @@ static ParseResult parseGlobalVariableOp(OpAsmParser &parser,
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter *printer) {
|
||||
static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter &printer) {
|
||||
auto *op = varOp.getOperation();
|
||||
SmallVector<StringRef, 4> elidedAttrs{
|
||||
spirv::attributeName<spirv::StorageClass>()};
|
||||
*printer << spirv::GlobalVariableOp::getOperationName();
|
||||
printer << spirv::GlobalVariableOp::getOperationName();
|
||||
|
||||
// Print variable name.
|
||||
*printer << " @" << varOp.sym_name();
|
||||
printer << " @" << varOp.sym_name();
|
||||
elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
|
||||
|
||||
// Print optional initializer
|
||||
if (auto initializer = varOp.initializer()) {
|
||||
*printer << " " << kInitializerAttrName << "(@" << initializer.getValue()
|
||||
<< ")";
|
||||
printer << " " << kInitializerAttrName << "(@" << initializer.getValue()
|
||||
<< ")";
|
||||
elidedAttrs.push_back(kInitializerAttrName);
|
||||
}
|
||||
|
||||
elidedAttrs.push_back(kTypeAttrName);
|
||||
printVariableDecorations(op, printer, elidedAttrs);
|
||||
*printer << " : " << varOp.type();
|
||||
printer << " : " << varOp.type();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::GlobalVariableOp varOp) {
|
||||
@ -1145,19 +1145,19 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &state) {
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::LoadOp loadOp, OpAsmPrinter *printer) {
|
||||
static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) {
|
||||
auto *op = loadOp.getOperation();
|
||||
SmallVector<StringRef, 4> elidedAttrs;
|
||||
StringRef sc = stringifyStorageClass(
|
||||
loadOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
|
||||
*printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" ";
|
||||
printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" ";
|
||||
// Print the pointer operand.
|
||||
printer->printOperand(loadOp.ptr());
|
||||
printer.printOperand(loadOp.ptr());
|
||||
|
||||
printMemoryAccessAttribute(loadOp, printer, elidedAttrs);
|
||||
|
||||
printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
|
||||
*printer << " : " << loadOp.getType();
|
||||
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
|
||||
printer << " : " << loadOp.getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::LoadOp loadOp) {
|
||||
@ -1186,12 +1186,12 @@ static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &state) {
|
||||
/*argTypes=*/{});
|
||||
}
|
||||
|
||||
static void print(spirv::LoopOp loopOp, OpAsmPrinter *printer) {
|
||||
static void print(spirv::LoopOp loopOp, OpAsmPrinter &printer) {
|
||||
auto *op = loopOp.getOperation();
|
||||
|
||||
*printer << spirv::LoopOp::getOperationName();
|
||||
printer->printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/true);
|
||||
printer << spirv::LoopOp::getOperationName();
|
||||
printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/true);
|
||||
}
|
||||
|
||||
/// Returns true if the given `block` only contains one `spv._merge` op.
|
||||
@ -1381,7 +1381,7 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::ModuleOp moduleOp, OpAsmPrinter *printer) {
|
||||
static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
|
||||
auto *op = moduleOp.getOperation();
|
||||
|
||||
// Only print out addressing model and memory model in a nicer way if both
|
||||
@ -1392,15 +1392,15 @@ static void print(spirv::ModuleOp moduleOp, OpAsmPrinter *printer) {
|
||||
auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
|
||||
if (op->getAttr(addressingModelAttrName) &&
|
||||
op->getAttr(memoryModelAttrName)) {
|
||||
*printer << spirv::ModuleOp::getOperationName() << " \""
|
||||
<< spirv::stringifyAddressingModel(moduleOp.addressing_model())
|
||||
<< "\" \"" << spirv::stringifyMemoryModel(moduleOp.memory_model())
|
||||
<< '"';
|
||||
printer << spirv::ModuleOp::getOperationName() << " \""
|
||||
<< spirv::stringifyAddressingModel(moduleOp.addressing_model())
|
||||
<< "\" \"" << spirv::stringifyMemoryModel(moduleOp.memory_model())
|
||||
<< '"';
|
||||
elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName});
|
||||
}
|
||||
|
||||
printer->printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
|
||||
bool printAttrDict =
|
||||
elidedAttrs.size() != 2 ||
|
||||
@ -1411,8 +1411,8 @@ static void print(spirv::ModuleOp moduleOp, OpAsmPrinter *printer) {
|
||||
});
|
||||
|
||||
if (printAttrDict) {
|
||||
*printer << " attributes";
|
||||
printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
|
||||
printer << " attributes";
|
||||
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1524,10 +1524,10 @@ static ParseResult parseReferenceOfOp(OpAsmParser &parser,
|
||||
return parser.addTypeToList(type, state.types);
|
||||
}
|
||||
|
||||
static void print(spirv::ReferenceOfOp referenceOfOp, OpAsmPrinter *printer) {
|
||||
*printer << spirv::ReferenceOfOp::getOperationName() << " @"
|
||||
<< referenceOfOp.spec_const() << " : "
|
||||
<< referenceOfOp.reference()->getType();
|
||||
static void print(spirv::ReferenceOfOp referenceOfOp, OpAsmPrinter &printer) {
|
||||
printer << spirv::ReferenceOfOp::getOperationName() << " @"
|
||||
<< referenceOfOp.spec_const() << " : "
|
||||
<< referenceOfOp.reference()->getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
|
||||
@ -1572,10 +1572,10 @@ static ParseResult parseReturnValueOp(OpAsmParser &parser,
|
||||
parser.resolveOperand(retValInfo, retValType, state.operands));
|
||||
}
|
||||
|
||||
static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter *printer) {
|
||||
*printer << spirv::ReturnValueOp::getOperationName() << ' ';
|
||||
printer->printOperand(retValOp.value());
|
||||
*printer << " : " << retValOp.value()->getType();
|
||||
static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter &printer) {
|
||||
printer << spirv::ReturnValueOp::getOperationName() << ' ';
|
||||
printer.printOperand(retValOp.value());
|
||||
printer << " : " << retValOp.value()->getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::ReturnValueOp retValOp) {
|
||||
@ -1621,15 +1621,15 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &state) {
|
||||
return parser.addTypesToList(types[1], state.types);
|
||||
}
|
||||
|
||||
static void print(spirv::SelectOp op, OpAsmPrinter *printer) {
|
||||
*printer << spirv::SelectOp::getOperationName() << " ";
|
||||
static void print(spirv::SelectOp op, OpAsmPrinter &printer) {
|
||||
printer << spirv::SelectOp::getOperationName() << " ";
|
||||
|
||||
// Print the operands.
|
||||
printer->printOperands(op.getOperands());
|
||||
printer.printOperands(op.getOperands());
|
||||
|
||||
// Print colon and types.
|
||||
*printer << " : " << op.condition()->getType() << ", "
|
||||
<< op.result()->getType();
|
||||
printer << " : " << op.condition()->getType() << ", "
|
||||
<< op.result()->getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::SelectOp op) {
|
||||
@ -1672,10 +1672,10 @@ static ParseResult parseSpecConstantOp(OpAsmParser &parser,
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::SpecConstantOp constOp, OpAsmPrinter *printer) {
|
||||
*printer << spirv::SpecConstantOp::getOperationName() << " @"
|
||||
<< constOp.sym_name() << " = ";
|
||||
printer->printAttribute(constOp.default_value());
|
||||
static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) {
|
||||
printer << spirv::SpecConstantOp::getOperationName() << " @"
|
||||
<< constOp.sym_name() << " = ";
|
||||
printer.printAttribute(constOp.default_value());
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::SpecConstantOp constOp) {
|
||||
@ -1721,23 +1721,23 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &state) {
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::StoreOp storeOp, OpAsmPrinter *printer) {
|
||||
static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) {
|
||||
auto *op = storeOp.getOperation();
|
||||
SmallVector<StringRef, 4> elidedAttrs;
|
||||
StringRef sc = stringifyStorageClass(
|
||||
storeOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
|
||||
*printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" ";
|
||||
printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" ";
|
||||
// Print the pointer operand
|
||||
printer->printOperand(storeOp.ptr());
|
||||
*printer << ", ";
|
||||
printer.printOperand(storeOp.ptr());
|
||||
printer << ", ";
|
||||
// Print the value operand
|
||||
printer->printOperand(storeOp.value());
|
||||
printer.printOperand(storeOp.value());
|
||||
|
||||
printMemoryAccessAttribute(storeOp, printer, elidedAttrs);
|
||||
|
||||
*printer << " : " << storeOp.value()->getType();
|
||||
printer << " : " << storeOp.value()->getType();
|
||||
|
||||
printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
|
||||
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::StoreOp storeOp) {
|
||||
@ -1796,22 +1796,22 @@ static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) {
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::VariableOp varOp, OpAsmPrinter *printer) {
|
||||
static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) {
|
||||
auto *op = varOp.getOperation();
|
||||
SmallVector<StringRef, 4> elidedAttrs{
|
||||
spirv::attributeName<spirv::StorageClass>()};
|
||||
*printer << spirv::VariableOp::getOperationName();
|
||||
printer << spirv::VariableOp::getOperationName();
|
||||
|
||||
// Print optional initializer
|
||||
if (op->getNumOperands() > 0) {
|
||||
*printer << " init(";
|
||||
printer->printOperands(varOp.initializer());
|
||||
*printer << ")";
|
||||
printer << " init(";
|
||||
printer.printOperands(varOp.initializer());
|
||||
printer << ")";
|
||||
}
|
||||
|
||||
printVariableDecorations(op, printer, elidedAttrs);
|
||||
|
||||
*printer << " : " << varOp.getType();
|
||||
printer << " : " << varOp.getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::VariableOp varOp) {
|
||||
|
284
third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
vendored
284
third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
vendored
@ -125,7 +125,7 @@ struct StdInlinerInterface : public DialectInlinerInterface {
|
||||
|
||||
/// A custom binary operation printer that omits the "std." prefix from the
|
||||
/// operation names.
|
||||
static void printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
|
||||
static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
|
||||
assert(op->getNumOperands() == 2 && "binary op should have two operands");
|
||||
assert(op->getNumResults() == 1 && "binary op should have one result");
|
||||
|
||||
@ -134,24 +134,24 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
|
||||
auto resultType = op->getResult(0)->getType();
|
||||
if (op->getOperand(0)->getType() != resultType ||
|
||||
op->getOperand(1)->getType() != resultType) {
|
||||
p->printGenericOp(op);
|
||||
p.printGenericOp(op);
|
||||
return;
|
||||
}
|
||||
|
||||
*p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
|
||||
<< *op->getOperand(0) << ", " << *op->getOperand(1);
|
||||
p->printOptionalAttrDict(op->getAttrs());
|
||||
p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
|
||||
<< *op->getOperand(0) << ", " << *op->getOperand(1);
|
||||
p.printOptionalAttrDict(op->getAttrs());
|
||||
|
||||
// Now we can output only one type for all operands and the result.
|
||||
*p << " : " << op->getResult(0)->getType();
|
||||
p << " : " << op->getResult(0)->getType();
|
||||
}
|
||||
|
||||
/// A custom cast operation printer that omits the "std." prefix from the
|
||||
/// operation names.
|
||||
static void printStandardCastOp(Operation *op, OpAsmPrinter *p) {
|
||||
*p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
|
||||
<< *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to "
|
||||
<< op->getResult(0)->getType();
|
||||
static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
|
||||
p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
|
||||
<< *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to "
|
||||
<< op->getResult(0)->getType();
|
||||
}
|
||||
|
||||
/// A custom cast operation verifier.
|
||||
@ -176,15 +176,15 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
|
||||
|
||||
void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
|
||||
Operation::operand_iterator end,
|
||||
unsigned numDims, OpAsmPrinter *p) {
|
||||
*p << '(';
|
||||
p->printOperands(begin, begin + numDims);
|
||||
*p << ')';
|
||||
unsigned numDims, OpAsmPrinter &p) {
|
||||
p << '(';
|
||||
p.printOperands(begin, begin + numDims);
|
||||
p << ')';
|
||||
|
||||
if (begin + numDims != end) {
|
||||
*p << '[';
|
||||
p->printOperands(begin + numDims, end);
|
||||
*p << ']';
|
||||
p << '[';
|
||||
p.printOperands(begin + numDims, end);
|
||||
p << ']';
|
||||
}
|
||||
}
|
||||
|
||||
@ -305,15 +305,15 @@ OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
|
||||
// AllocOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, AllocOp op) {
|
||||
*p << "alloc";
|
||||
static void print(OpAsmPrinter &p, AllocOp op) {
|
||||
p << "alloc";
|
||||
|
||||
// Print dynamic dimension operands.
|
||||
MemRefType type = op.getType();
|
||||
printDimAndSymbolList(op.operand_begin(), op.operand_end(),
|
||||
type.getNumDynamicDims(), p);
|
||||
p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
|
||||
*p << " : " << type;
|
||||
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
|
||||
p << " : " << type;
|
||||
}
|
||||
|
||||
static ParseResult parseAllocOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -468,9 +468,9 @@ static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) {
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, BranchOp op) {
|
||||
*p << "br ";
|
||||
p->printSuccessorAndUseList(op.getOperation(), 0);
|
||||
static void print(OpAsmPrinter &p, BranchOp op) {
|
||||
p << "br ";
|
||||
p.printSuccessorAndUseList(op.getOperation(), 0);
|
||||
}
|
||||
|
||||
Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); }
|
||||
@ -504,13 +504,13 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, CallOp op) {
|
||||
*p << "call " << op.getAttr("callee") << '(';
|
||||
p->printOperands(op.getOperands());
|
||||
*p << ')';
|
||||
p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
|
||||
*p << " : ";
|
||||
p->printType(op.getCalleeType());
|
||||
static void print(OpAsmPrinter &p, CallOp op) {
|
||||
p << "call " << op.getAttr("callee") << '(';
|
||||
p.printOperands(op.getOperands());
|
||||
p << ')';
|
||||
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
|
||||
p << " : ";
|
||||
p.printType(op.getCalleeType());
|
||||
}
|
||||
|
||||
static LogicalResult verify(CallOp op) {
|
||||
@ -592,14 +592,14 @@ static ParseResult parseCallIndirectOp(OpAsmParser &parser,
|
||||
parser.addTypesToList(calleeType.getResults(), result.types));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, CallIndirectOp op) {
|
||||
*p << "call_indirect ";
|
||||
p->printOperand(op.getCallee());
|
||||
*p << '(';
|
||||
p->printOperands(op.getArgOperands());
|
||||
*p << ')';
|
||||
p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
|
||||
*p << " : " << op.getCallee()->getType();
|
||||
static void print(OpAsmPrinter &p, CallIndirectOp op) {
|
||||
p << "call_indirect ";
|
||||
p.printOperand(op.getCallee());
|
||||
p << '(';
|
||||
p.printOperands(op.getArgOperands());
|
||||
p << ')';
|
||||
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
|
||||
p << " : " << op.getCallee()->getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(CallIndirectOp op) {
|
||||
@ -741,8 +741,8 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, CmpIOp op) {
|
||||
*p << "cmpi ";
|
||||
static void print(OpAsmPrinter &p, CmpIOp op) {
|
||||
p << "cmpi ";
|
||||
|
||||
auto predicateValue =
|
||||
op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
|
||||
@ -752,15 +752,15 @@ static void print(OpAsmPrinter *p, CmpIOp op) {
|
||||
Builder b(op.getContext());
|
||||
auto predicateStringAttr =
|
||||
b.getStringAttr(getCmpIPredicateNames()[predicateValue]);
|
||||
p->printAttribute(predicateStringAttr);
|
||||
p.printAttribute(predicateStringAttr);
|
||||
|
||||
*p << ", ";
|
||||
p->printOperand(op.lhs());
|
||||
*p << ", ";
|
||||
p->printOperand(op.rhs());
|
||||
p->printOptionalAttrDict(op.getAttrs(),
|
||||
/*elidedAttrs=*/{CmpIOp::getPredicateAttrName()});
|
||||
*p << " : " << op.lhs()->getType();
|
||||
p << ", ";
|
||||
p.printOperand(op.lhs());
|
||||
p << ", ";
|
||||
p.printOperand(op.rhs());
|
||||
p.printOptionalAttrDict(op.getAttrs(),
|
||||
/*elidedAttrs=*/{CmpIOp::getPredicateAttrName()});
|
||||
p << " : " << op.lhs()->getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(CmpIOp op) {
|
||||
@ -918,8 +918,8 @@ static ParseResult parseCmpFOp(OpAsmParser &parser, OperationState &result) {
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, CmpFOp op) {
|
||||
*p << "cmpf ";
|
||||
static void print(OpAsmPrinter &p, CmpFOp op) {
|
||||
p << "cmpf ";
|
||||
|
||||
auto predicateValue =
|
||||
op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()).getInt();
|
||||
@ -929,15 +929,15 @@ static void print(OpAsmPrinter *p, CmpFOp op) {
|
||||
Builder b(op.getContext());
|
||||
auto predicateStringAttr =
|
||||
b.getStringAttr(getCmpFPredicateNames()[predicateValue]);
|
||||
p->printAttribute(predicateStringAttr);
|
||||
p.printAttribute(predicateStringAttr);
|
||||
|
||||
*p << ", ";
|
||||
p->printOperand(op.lhs());
|
||||
*p << ", ";
|
||||
p->printOperand(op.rhs());
|
||||
p->printOptionalAttrDict(op.getAttrs(),
|
||||
/*elidedAttrs=*/{CmpFOp::getPredicateAttrName()});
|
||||
*p << " : " << op.lhs()->getType();
|
||||
p << ", ";
|
||||
p.printOperand(op.lhs());
|
||||
p << ", ";
|
||||
p.printOperand(op.rhs());
|
||||
p.printOptionalAttrDict(op.getAttrs(),
|
||||
/*elidedAttrs=*/{CmpFOp::getPredicateAttrName()});
|
||||
p << " : " << op.lhs()->getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(CmpFOp op) {
|
||||
@ -1085,13 +1085,13 @@ static ParseResult parseCondBranchOp(OpAsmParser &parser,
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, CondBranchOp op) {
|
||||
*p << "cond_br ";
|
||||
p->printOperand(op.getCondition());
|
||||
*p << ", ";
|
||||
p->printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
|
||||
*p << ", ";
|
||||
p->printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
|
||||
static void print(OpAsmPrinter &p, CondBranchOp op) {
|
||||
p << "cond_br ";
|
||||
p.printOperand(op.getCondition());
|
||||
p << ", ";
|
||||
p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
|
||||
p << ", ";
|
||||
p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
|
||||
}
|
||||
|
||||
void CondBranchOp::getCanonicalizationPatterns(
|
||||
@ -1103,17 +1103,17 @@ void CondBranchOp::getCanonicalizationPatterns(
|
||||
// Constant*Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, ConstantOp &op) {
|
||||
*p << "constant ";
|
||||
p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
|
||||
static void print(OpAsmPrinter &p, ConstantOp &op) {
|
||||
p << "constant ";
|
||||
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
|
||||
|
||||
if (op.getAttrs().size() > 1)
|
||||
*p << ' ';
|
||||
p->printAttribute(op.getValue());
|
||||
p << ' ';
|
||||
p.printAttribute(op.getValue());
|
||||
|
||||
// If the value is a symbol reference, print a trailing type.
|
||||
if (op.getValue().isa<SymbolRefAttr>())
|
||||
*p << " : " << op.getType();
|
||||
p << " : " << op.getType();
|
||||
}
|
||||
|
||||
static ParseResult parseConstantOp(OpAsmParser &parser,
|
||||
@ -1288,8 +1288,8 @@ struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
||||
static void print(OpAsmPrinter *p, DeallocOp op) {
|
||||
*p << "dealloc " << *op.memref() << " : " << op.memref()->getType();
|
||||
static void print(OpAsmPrinter &p, DeallocOp op) {
|
||||
p << "dealloc " << *op.memref() << " : " << op.memref()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseDeallocOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -1318,10 +1318,10 @@ void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
// DimOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, DimOp op) {
|
||||
*p << "dim " << *op.getOperand() << ", " << op.getIndex();
|
||||
p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
|
||||
*p << " : " << op.getOperand()->getType();
|
||||
static void print(OpAsmPrinter &p, DimOp op) {
|
||||
p << "dim " << *op.getOperand() << ", " << op.getIndex();
|
||||
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
|
||||
p << " : " << op.getOperand()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseDimOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -1446,23 +1446,23 @@ void DmaStartOp::build(Builder *builder, OperationState &result,
|
||||
result.addOperands({stride, elementsPerStride});
|
||||
}
|
||||
|
||||
void DmaStartOp::print(OpAsmPrinter *p) {
|
||||
*p << "dma_start " << *getSrcMemRef() << '[';
|
||||
p->printOperands(getSrcIndices());
|
||||
*p << "], " << *getDstMemRef() << '[';
|
||||
p->printOperands(getDstIndices());
|
||||
*p << "], " << *getNumElements();
|
||||
*p << ", " << *getTagMemRef() << '[';
|
||||
p->printOperands(getTagIndices());
|
||||
*p << ']';
|
||||
void DmaStartOp::print(OpAsmPrinter &p) {
|
||||
p << "dma_start " << *getSrcMemRef() << '[';
|
||||
p.printOperands(getSrcIndices());
|
||||
p << "], " << *getDstMemRef() << '[';
|
||||
p.printOperands(getDstIndices());
|
||||
p << "], " << *getNumElements();
|
||||
p << ", " << *getTagMemRef() << '[';
|
||||
p.printOperands(getTagIndices());
|
||||
p << ']';
|
||||
if (isStrided()) {
|
||||
*p << ", " << *getStride();
|
||||
*p << ", " << *getNumElementsPerStride();
|
||||
p << ", " << *getStride();
|
||||
p << ", " << *getNumElementsPerStride();
|
||||
}
|
||||
p->printOptionalAttrDict(getAttrs());
|
||||
*p << " : " << getSrcMemRef()->getType();
|
||||
*p << ", " << getDstMemRef()->getType();
|
||||
*p << ", " << getTagMemRef()->getType();
|
||||
p.printOptionalAttrDict(getAttrs());
|
||||
p << " : " << getSrcMemRef()->getType();
|
||||
p << ", " << getDstMemRef()->getType();
|
||||
p << ", " << getTagMemRef()->getType();
|
||||
}
|
||||
|
||||
// Parse DmaStartOp.
|
||||
@ -1589,15 +1589,15 @@ void DmaWaitOp::build(Builder *builder, OperationState &result,
|
||||
result.addOperands(numElements);
|
||||
}
|
||||
|
||||
void DmaWaitOp::print(OpAsmPrinter *p) {
|
||||
*p << "dma_wait ";
|
||||
p->printOperand(getTagMemRef());
|
||||
*p << '[';
|
||||
p->printOperands(getTagIndices());
|
||||
*p << "], ";
|
||||
p->printOperand(getNumElements());
|
||||
p->printOptionalAttrDict(getAttrs());
|
||||
*p << " : " << getTagMemRef()->getType();
|
||||
void DmaWaitOp::print(OpAsmPrinter &p) {
|
||||
p << "dma_wait ";
|
||||
p.printOperand(getTagMemRef());
|
||||
p << '[';
|
||||
p.printOperands(getTagIndices());
|
||||
p << "], ";
|
||||
p.printOperand(getNumElements());
|
||||
p.printOptionalAttrDict(getAttrs());
|
||||
p << " : " << getTagMemRef()->getType();
|
||||
}
|
||||
|
||||
// Parse DmaWaitOp.
|
||||
@ -1643,12 +1643,12 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
// ExtractElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, ExtractElementOp op) {
|
||||
*p << "extract_element " << *op.getAggregate() << '[';
|
||||
p->printOperands(op.getIndices());
|
||||
*p << ']';
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.getAggregate()->getType();
|
||||
static void print(OpAsmPrinter &p, ExtractElementOp op) {
|
||||
p << "extract_element " << *op.getAggregate() << '[';
|
||||
p.printOperands(op.getIndices());
|
||||
p << ']';
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getAggregate()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseExtractElementOp(OpAsmParser &parser,
|
||||
@ -1725,12 +1725,12 @@ bool IndexCastOp::areCastCompatible(Type a, Type b) {
|
||||
// LoadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, LoadOp op) {
|
||||
*p << "load " << *op.getMemRef() << '[';
|
||||
p->printOperands(op.getIndices());
|
||||
*p << ']';
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.getMemRefType();
|
||||
static void print(OpAsmPrinter &p, LoadOp op) {
|
||||
p << "load " << *op.getMemRef() << '[';
|
||||
p.printOperands(op.getIndices());
|
||||
p << ']';
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getMemRefType();
|
||||
}
|
||||
|
||||
static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -1837,8 +1837,8 @@ OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
|
||||
// RankOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, RankOp op) {
|
||||
*p << "rank " << *op.getOperand() << " : " << op.getOperand()->getType();
|
||||
static void print(OpAsmPrinter &p, RankOp op) {
|
||||
p << "rank " << *op.getOperand() << " : " << op.getOperand()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseRankOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -1924,13 +1924,13 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
|
||||
parser.resolveOperands(opInfo, types, loc, result.operands));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, ReturnOp op) {
|
||||
*p << "return";
|
||||
static void print(OpAsmPrinter &p, ReturnOp op) {
|
||||
p << "return";
|
||||
if (op.getNumOperands() != 0) {
|
||||
*p << ' ';
|
||||
p->printOperands(op.getOperands());
|
||||
*p << " : ";
|
||||
interleaveComma(op.getOperandTypes(), *p);
|
||||
p << ' ';
|
||||
p.printOperands(op.getOperands());
|
||||
p << " : ";
|
||||
interleaveComma(op.getOperandTypes(), p);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1987,11 +1987,11 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
|
||||
parser.addTypeToList(type, result.types));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, SelectOp op) {
|
||||
*p << "select ";
|
||||
p->printOperands(op.getOperands());
|
||||
*p << " : " << op.getTrueValue()->getType();
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
static void print(OpAsmPrinter &p, SelectOp op) {
|
||||
p << "select ";
|
||||
p.printOperands(op.getOperands());
|
||||
p << " : " << op.getTrueValue()->getType();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
}
|
||||
|
||||
static LogicalResult verify(SelectOp op) {
|
||||
@ -2022,13 +2022,13 @@ OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
|
||||
// StoreOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, StoreOp op) {
|
||||
*p << "store " << *op.getValueToStore();
|
||||
*p << ", " << *op.getMemRef() << '[';
|
||||
p->printOperands(op.getIndices());
|
||||
*p << ']';
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.getMemRefType();
|
||||
static void print(OpAsmPrinter &p, StoreOp op) {
|
||||
p << "store " << *op.getValueToStore();
|
||||
p << ", " << *op.getMemRef() << '[';
|
||||
p.printOperands(op.getIndices());
|
||||
p << ']';
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getMemRefType();
|
||||
}
|
||||
|
||||
static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -2197,10 +2197,10 @@ static Type getTensorTypeFromMemRefType(Builder &b, Type type) {
|
||||
// TensorLoadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, TensorLoadOp op) {
|
||||
*p << "tensor_load " << *op.getOperand();
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.getOperand()->getType();
|
||||
static void print(OpAsmPrinter &p, TensorLoadOp op) {
|
||||
p << "tensor_load " << *op.getOperand();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getOperand()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseTensorLoadOp(OpAsmParser &parser,
|
||||
@ -2220,10 +2220,10 @@ static ParseResult parseTensorLoadOp(OpAsmParser &parser,
|
||||
// TensorStoreOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, TensorStoreOp op) {
|
||||
*p << "tensor_store " << *op.tensor() << ", " << *op.memref();
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.memref()->getType();
|
||||
static void print(OpAsmPrinter &p, TensorStoreOp op) {
|
||||
p << "tensor_store " << *op.tensor() << ", " << *op.memref();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.memref()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseTensorStoreOp(OpAsmParser &parser,
|
||||
|
@ -49,10 +49,10 @@ mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
|
||||
// ExtractElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, ExtractElementOp op) {
|
||||
*p << op.getOperationName() << " " << *op.vector() << op.position();
|
||||
p->printOptionalAttrDict(op.getAttrs(), {"position"});
|
||||
*p << " : " << op.vector()->getType();
|
||||
static void print(OpAsmPrinter &p, ExtractElementOp op) {
|
||||
p << op.getOperationName() << " " << *op.vector() << op.position();
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"position"});
|
||||
p << " : " << op.vector()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseExtractElementOp(OpAsmParser &parser,
|
||||
@ -113,11 +113,11 @@ static LogicalResult verify(ExtractElementOp op) {
|
||||
// OuterProductOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, OuterProductOp op) {
|
||||
*p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
|
||||
static void print(OpAsmPrinter &p, OuterProductOp op) {
|
||||
p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
|
||||
if (llvm::size(op.acc()) > 0)
|
||||
*p << ", " << **op.acc().begin();
|
||||
*p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
|
||||
p << ", " << **op.acc().begin();
|
||||
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseOuterProductOp(OpAsmParser &parser,
|
||||
@ -228,21 +228,21 @@ AffineMap VectorTransferReadOp::getPermutationMap() {
|
||||
return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue();
|
||||
}
|
||||
|
||||
void VectorTransferReadOp::print(OpAsmPrinter *p) {
|
||||
*p << getOperationName() << " ";
|
||||
p->printOperand(getMemRef());
|
||||
*p << "[";
|
||||
p->printOperands(getIndices());
|
||||
*p << "]";
|
||||
void VectorTransferReadOp::print(OpAsmPrinter &p) {
|
||||
p << getOperationName() << " ";
|
||||
p.printOperand(getMemRef());
|
||||
p << "[";
|
||||
p.printOperands(getIndices());
|
||||
p << "]";
|
||||
auto optionalPaddingValue = getPaddingValue();
|
||||
if (optionalPaddingValue) {
|
||||
*p << ", (";
|
||||
p->printOperand(*optionalPaddingValue);
|
||||
*p << ")";
|
||||
p << ", (";
|
||||
p.printOperand(*optionalPaddingValue);
|
||||
p << ")";
|
||||
}
|
||||
p->printOptionalAttrDict(getAttrs());
|
||||
*p << " : " << getMemRefType();
|
||||
*p << ", " << getResultType();
|
||||
p.printOptionalAttrDict(getAttrs());
|
||||
p << " : " << getMemRefType();
|
||||
p << ", " << getResultType();
|
||||
}
|
||||
|
||||
ParseResult VectorTransferReadOp::parse(OpAsmParser &parser,
|
||||
@ -396,18 +396,18 @@ AffineMap VectorTransferWriteOp::getPermutationMap() {
|
||||
return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue();
|
||||
}
|
||||
|
||||
void VectorTransferWriteOp::print(OpAsmPrinter *p) {
|
||||
*p << getOperationName();
|
||||
*p << " " << *getVector();
|
||||
*p << ", " << *getMemRef();
|
||||
*p << "[";
|
||||
p->printOperands(getIndices());
|
||||
*p << "]";
|
||||
p->printOptionalAttrDict(getAttrs());
|
||||
*p << " : ";
|
||||
p->printType(getVectorType());
|
||||
*p << ", ";
|
||||
p->printType(getMemRefType());
|
||||
void VectorTransferWriteOp::print(OpAsmPrinter &p) {
|
||||
p << getOperationName();
|
||||
p << " " << *getVector();
|
||||
p << ", " << *getMemRef();
|
||||
p << "[";
|
||||
p.printOperands(getIndices());
|
||||
p << "]";
|
||||
p.printOptionalAttrDict(getAttrs());
|
||||
p << " : ";
|
||||
p.printType(getVectorType());
|
||||
p << ", ";
|
||||
p.printType(getMemRefType());
|
||||
}
|
||||
|
||||
ParseResult VectorTransferWriteOp::parse(OpAsmParser &parser,
|
||||
@ -524,9 +524,9 @@ ParseResult VectorTypeCastOp::parse(OpAsmParser &parser,
|
||||
parser.resolveOperand(operand, srcType, result.operands));
|
||||
}
|
||||
|
||||
void VectorTypeCastOp::print(OpAsmPrinter *p) {
|
||||
*p << getOperationName() << ' ' << *getOperand() << " : "
|
||||
<< getOperand()->getType() << ", " << getType();
|
||||
void VectorTypeCastOp::print(OpAsmPrinter &p) {
|
||||
p << getOperationName() << ' ' << *getOperand() << " : "
|
||||
<< getOperand()->getType() << ", " << getType();
|
||||
}
|
||||
|
||||
LogicalResult VectorTypeCastOp::verify() {
|
||||
|
2
third_party/mlir/lib/IR/AsmPrinter.cpp
vendored
2
third_party/mlir/lib/IR/AsmPrinter.cpp
vendored
@ -1599,7 +1599,7 @@ void OperationPrinter::printOperation(Operation *op) {
|
||||
// Check to see if this is a known operation. If so, use the registered
|
||||
// custom printer hook.
|
||||
if (auto *opInfo = op->getAbstractOperation()) {
|
||||
opInfo->printAssembly(op, this);
|
||||
opInfo->printAssembly(op, *this);
|
||||
return;
|
||||
}
|
||||
|
||||
|
2
third_party/mlir/lib/IR/Function.cpp
vendored
2
third_party/mlir/lib/IR/Function.cpp
vendored
@ -86,7 +86,7 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
buildFuncType);
|
||||
}
|
||||
|
||||
void FuncOp::print(OpAsmPrinter *p) {
|
||||
void FuncOp::print(OpAsmPrinter &p) {
|
||||
FunctionType fnType = getType();
|
||||
impl::printFunctionLikeOp(p, *this, fnType.getInputs(), /*isVariadic=*/false,
|
||||
fnType.getResults());
|
||||
|
34
third_party/mlir/lib/IR/FunctionSupport.cpp
vendored
34
third_party/mlir/lib/IR/FunctionSupport.cpp
vendored
@ -163,46 +163,46 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
|
||||
|
||||
/// Print the signature of the function-like operation `op`. Assumes `op` has
|
||||
/// the FunctionLike trait and passed the verification.
|
||||
static void printSignature(OpAsmPrinter *p, Operation *op,
|
||||
static void printSignature(OpAsmPrinter &p, Operation *op,
|
||||
ArrayRef<Type> argTypes, bool isVariadic,
|
||||
ArrayRef<Type> results) {
|
||||
Region &body = op->getRegion(0);
|
||||
bool isExternal = body.empty();
|
||||
|
||||
*p << '(';
|
||||
p << '(';
|
||||
for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
|
||||
if (i > 0)
|
||||
*p << ", ";
|
||||
p << ", ";
|
||||
|
||||
if (!isExternal) {
|
||||
p->printOperand(body.front().getArgument(i));
|
||||
*p << ": ";
|
||||
p.printOperand(body.front().getArgument(i));
|
||||
p << ": ";
|
||||
}
|
||||
|
||||
p->printType(argTypes[i]);
|
||||
p->printOptionalAttrDict(::mlir::impl::getArgAttrs(op, i));
|
||||
p.printType(argTypes[i]);
|
||||
p.printOptionalAttrDict(::mlir::impl::getArgAttrs(op, i));
|
||||
}
|
||||
|
||||
if (isVariadic) {
|
||||
if (!argTypes.empty())
|
||||
*p << ", ";
|
||||
*p << "...";
|
||||
p << ", ";
|
||||
p << "...";
|
||||
}
|
||||
|
||||
*p << ')';
|
||||
p->printOptionalArrowTypeList(results);
|
||||
p << ')';
|
||||
p.printOptionalArrowTypeList(results);
|
||||
}
|
||||
|
||||
/// Printer implementation for function-like operations. Accepts lists of
|
||||
/// argument and result types to use while printing.
|
||||
void mlir::impl::printFunctionLikeOp(OpAsmPrinter *p, Operation *op,
|
||||
void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
|
||||
ArrayRef<Type> argTypes, bool isVariadic,
|
||||
ArrayRef<Type> results) {
|
||||
// Print the operation and the function name.
|
||||
auto funcName =
|
||||
op->getAttrOfType<StringAttr>(::mlir::SymbolTable::getSymbolAttrName())
|
||||
.getValue();
|
||||
*p << op->getName() << " @" << funcName;
|
||||
p << op->getName() << " @" << funcName;
|
||||
|
||||
// Print the signature.
|
||||
printSignature(p, op, argTypes, isVariadic, results);
|
||||
@ -221,13 +221,13 @@ void mlir::impl::printFunctionLikeOp(OpAsmPrinter *p, Operation *op,
|
||||
|
||||
auto attrs = op->getAttrs();
|
||||
if (attrs.size() > ignoredAttrs.size()) {
|
||||
*p << "\n attributes ";
|
||||
p->printOptionalAttrDict(attrs, ignoredAttrs);
|
||||
p << "\n attributes ";
|
||||
p.printOptionalAttrDict(attrs, ignoredAttrs);
|
||||
}
|
||||
|
||||
// Print the body if this is not an external function.
|
||||
Region &body = op->getRegion(0);
|
||||
if (!body.empty())
|
||||
p->printRegion(body, /*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/true);
|
||||
p.printRegion(body, /*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/true);
|
||||
}
|
||||
|
12
third_party/mlir/lib/IR/Module.cpp
vendored
12
third_party/mlir/lib/IR/Module.cpp
vendored
@ -53,19 +53,19 @@ ParseResult ModuleOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void ModuleOp::print(OpAsmPrinter *p) {
|
||||
*p << "module";
|
||||
void ModuleOp::print(OpAsmPrinter &p) {
|
||||
p << "module";
|
||||
|
||||
// Print the module attributes.
|
||||
auto attrs = getAttrs();
|
||||
if (!attrs.empty()) {
|
||||
*p << " attributes";
|
||||
p->printOptionalAttrDict(attrs, {});
|
||||
p << " attributes";
|
||||
p.printOptionalAttrDict(attrs, {});
|
||||
}
|
||||
|
||||
// Print the region.
|
||||
p->printRegion(getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
p.printRegion(getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
}
|
||||
|
||||
LogicalResult ModuleOp::verify() {
|
||||
|
23
third_party/mlir/lib/IR/Operation.cpp
vendored
23
third_party/mlir/lib/IR/Operation.cpp
vendored
@ -617,7 +617,7 @@ ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) {
|
||||
}
|
||||
|
||||
// The fallback for the printer is to print in the generic assembly form.
|
||||
void OpState::print(OpAsmPrinter *p) { p->printGenericOp(getOperation()); }
|
||||
void OpState::print(OpAsmPrinter &p) { p.printGenericOp(getOperation()); }
|
||||
|
||||
/// Emit an error about fatal conditions with this operation, reporting up to
|
||||
/// any diagnostic handlers that may be listening.
|
||||
@ -960,7 +960,7 @@ ParseResult impl::parseBinaryOp(OpAsmParser &parser, OperationState &result) {
|
||||
parser.addTypeToList(type, result.types));
|
||||
}
|
||||
|
||||
void impl::printBinaryOp(Operation *op, OpAsmPrinter *p) {
|
||||
void impl::printBinaryOp(Operation *op, OpAsmPrinter &p) {
|
||||
assert(op->getNumOperands() == 2 && "binary op should have two operands");
|
||||
assert(op->getNumResults() == 1 && "binary op should have one result");
|
||||
|
||||
@ -969,15 +969,14 @@ void impl::printBinaryOp(Operation *op, OpAsmPrinter *p) {
|
||||
auto resultType = op->getResult(0)->getType();
|
||||
if (op->getOperand(0)->getType() != resultType ||
|
||||
op->getOperand(1)->getType() != resultType) {
|
||||
p->printGenericOp(op);
|
||||
p.printGenericOp(op);
|
||||
return;
|
||||
}
|
||||
|
||||
*p << op->getName() << ' ' << *op->getOperand(0) << ", "
|
||||
<< *op->getOperand(1);
|
||||
p->printOptionalAttrDict(op->getAttrs());
|
||||
p << op->getName() << ' ' << *op->getOperand(0) << ", " << *op->getOperand(1);
|
||||
p.printOptionalAttrDict(op->getAttrs());
|
||||
// Now we can output only one type for all operands and the result.
|
||||
*p << " : " << op->getResult(0)->getType();
|
||||
p << " : " << op->getResult(0)->getType();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1001,11 +1000,11 @@ ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) {
|
||||
parser.addTypeToList(dstType, result.types));
|
||||
}
|
||||
|
||||
void impl::printCastOp(Operation *op, OpAsmPrinter *p) {
|
||||
*p << op->getName() << ' ' << *op->getOperand(0);
|
||||
p->printOptionalAttrDict(op->getAttrs());
|
||||
*p << " : " << op->getOperand(0)->getType() << " to "
|
||||
<< op->getResult(0)->getType();
|
||||
void impl::printCastOp(Operation *op, OpAsmPrinter &p) {
|
||||
p << op->getName() << ' ' << *op->getOperand(0);
|
||||
p.printOptionalAttrDict(op->getAttrs());
|
||||
p << " : " << op->getOperand(0)->getType() << " to "
|
||||
<< op->getResult(0)->getType();
|
||||
}
|
||||
|
||||
Value *impl::foldCastOp(Operation *op) {
|
||||
|
@ -115,11 +115,11 @@ static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
|
||||
/*enableNameShadowing=*/true);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, IsolatedRegionOp op) {
|
||||
*p << "test.isolated_region ";
|
||||
p->printOperand(op.getOperand());
|
||||
p->shadowRegionArgs(op.region(), op.getOperand());
|
||||
p->printRegion(op.region(), /*printEntryBlockArgs=*/false);
|
||||
static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
|
||||
p << "test.isolated_region ";
|
||||
p.printOperand(op.getOperand());
|
||||
p.shadowRegionArgs(op.region(), op.getOperand());
|
||||
p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -135,8 +135,8 @@ static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, WrappedKeywordOp op) {
|
||||
*p << WrappedKeywordOp::getOperationName() << " " << op.keyword();
|
||||
static void print(OpAsmPrinter &p, WrappedKeywordOp op) {
|
||||
p << WrappedKeywordOp::getOperationName() << " " << op.keyword();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -169,9 +169,9 @@ static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, WrappingRegionOp op) {
|
||||
*p << op.getOperationName() << " wraps ";
|
||||
p->printGenericOp(&op.region().front().front());
|
||||
static void print(OpAsmPrinter &p, WrappingRegionOp op) {
|
||||
p << op.getOperationName() << " wraps ";
|
||||
p.printGenericOp(&op.region().front().front());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -119,7 +119,7 @@ def TypeArrayAttrOp : TEST_Op<"type_array_attr"> {
|
||||
}
|
||||
def TypeStringAttrWithTypeOp : TEST_Op<"string_attr_with_type"> {
|
||||
let arguments = (ins StrAttr:$attr);
|
||||
let printer = [{ *p << getAttr("attr"); }];
|
||||
let printer = [{ p << getAttr("attr"); }];
|
||||
let parser = [{
|
||||
Attribute attr;
|
||||
Type stringType = OpaqueType::get(Identifier::get("foo",
|
||||
|
@ -1090,7 +1090,7 @@ void OpEmitter::genPrinter() {
|
||||
if (!codeInit)
|
||||
return;
|
||||
|
||||
auto &method = opClass.newMethod("void", "print", "OpAsmPrinter *p");
|
||||
auto &method = opClass.newMethod("void", "print", "OpAsmPrinter &p");
|
||||
FmtContext fctx;
|
||||
fctx.addSubst("cppClass", opClass.getClassName());
|
||||
auto printer = codeInit->getValue().ltrim().rtrim(" \t\v\f\r");
|
||||
|
Loading…
Reference in New Issue
Block a user