diff --git a/third_party/mlir/include/mlir/Transforms/DialectConversion.h b/third_party/mlir/include/mlir/Transforms/DialectConversion.h index 2deb0c9c048..88669505e23 100644 --- a/third_party/mlir/include/mlir/Transforms/DialectConversion.h +++ b/third_party/mlir/include/mlir/Transforms/DialectConversion.h @@ -338,6 +338,10 @@ public: return cast(cloneWithoutRegions(op.getOperation())); } + /// Return the converted value that replaces 'key'. Return 'key' if there is + /// no such a converted value. + Value *getRemappedValue(Value *key); + //===--------------------------------------------------------------------===// // PatternRewriter Hooks //===--------------------------------------------------------------------===// diff --git a/third_party/mlir/lib/Transforms/DialectConversion.cpp b/third_party/mlir/lib/Transforms/DialectConversion.cpp index a2065f16a21..7931932a789 100644 --- a/third_party/mlir/lib/Transforms/DialectConversion.cpp +++ b/third_party/mlir/lib/Transforms/DialectConversion.cpp @@ -803,6 +803,12 @@ Operation *ConversionPatternRewriter::cloneWithoutRegions(Operation *op) { return newOp; } +/// Return the converted value that replaces 'key'. Return 'key' if there is +/// no such a converted value. +Value *ConversionPatternRewriter::getRemappedValue(Value *key) { + return impl->mapping.lookupOrDefault(key); +} + /// PatternRewriter hook for splitting a block into two parts. Block *ConversionPatternRewriter::splitBlock(Block *block, Block::iterator before) { diff --git a/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp b/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp index 936d7632967..5ef03606dbe 100644 --- a/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -435,3 +435,66 @@ static mlir::PassRegistration return std::make_unique( legalizerConversionMode); }); + +//===----------------------------------------------------------------------===// +// ConversionPatternRewriter::getRemappedValue testing. This method is used +// to get the remapped value of a original value that was replaced using +// ConversionPatternRewriter. +namespace { +/// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with +/// a one-operand two-result OneVResOneVOperandOp1 by replicating its original +/// operand twice. +/// +/// Example: +/// %1 = test.one_variadic_out_one_variadic_in1"(%0) +/// is replaced with: +/// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) +struct OneVResOneVOperandOp1Converter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + PatternMatchResult + matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto origOps = op.getOperands(); + assert(std::distance(origOps.begin(), origOps.end()) == 1 && + "One operand expected"); + Value *origOp = *origOps.begin(); + SmallVector remappedOperands; + // Replicate the remapped original operand twice. Note that we don't used + // the remapped 'operand' since the goal is testing 'getRemappedValue'. + remappedOperands.push_back(rewriter.getRemappedValue(origOp)); + remappedOperands.push_back(rewriter.getRemappedValue(origOp)); + + SmallVector resultTypes(op.getResultTypes()); + rewriter.replaceOpWithNewOp(op, resultTypes, + remappedOperands); + return matchSuccess(); + } +}; + +struct TestRemappedValue : public mlir::FunctionPass { + void runOnFunction() override { + mlir::OwningRewritePatternList patterns; + patterns.insert(&getContext()); + + mlir::ConversionTarget target(getContext()); + target.addLegalOp(); + // We make OneVResOneVOperandOp1 legal only when it has more that one + // operand. This will trigger the conversion that will replace one-operand + // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. + target.addDynamicallyLegalOp( + [](Operation *op) -> bool { + return std::distance(op->operand_begin(), op->operand_end()) > 1; + }); + + if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) { + signalPassFailure(); + } + } +}; +} // end anonymous namespace + +static PassRegistration remapped_value_pass( + "test-remapped-value", + "Test public remapped value mechanism in ConversionPatternRewriter");