Add getRemappedValue to ConversionPatternRewriter
This method is needed for N->1 conversion patterns to retrieve remapped Values used in the original N operations. Closes #237 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/237 from dcaballe:dcaballe/getRemappedValue 1f64fadcf2b203f7b336ff0c5838b116ae3625db PiperOrigin-RevId: 281321881 Change-Id: I980a67573634c8a8b65ae74a6e7b84e6080ad2be
This commit is contained in:
parent
cd67f4f372
commit
a6ceaa8440
@ -338,6 +338,10 @@ public:
|
|||||||
return cast<OpT>(cloneWithoutRegions(op.getOperation()));
|
return cast<OpT>(cloneWithoutRegions(op.getOperation()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return the converted value that replaces 'key'. Return 'key' if there is
|
||||||
|
/// no such a converted value.
|
||||||
|
Value *getRemappedValue(Value *key);
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
// PatternRewriter Hooks
|
// PatternRewriter Hooks
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
|
@ -803,6 +803,12 @@ Operation *ConversionPatternRewriter::cloneWithoutRegions(Operation *op) {
|
|||||||
return newOp;
|
return newOp;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return the converted value that replaces 'key'. Return 'key' if there is
|
||||||
|
/// no such a converted value.
|
||||||
|
Value *ConversionPatternRewriter::getRemappedValue(Value *key) {
|
||||||
|
return impl->mapping.lookupOrDefault(key);
|
||||||
|
}
|
||||||
|
|
||||||
/// PatternRewriter hook for splitting a block into two parts.
|
/// PatternRewriter hook for splitting a block into two parts.
|
||||||
Block *ConversionPatternRewriter::splitBlock(Block *block,
|
Block *ConversionPatternRewriter::splitBlock(Block *block,
|
||||||
Block::iterator before) {
|
Block::iterator before) {
|
||||||
|
@ -435,3 +435,66 @@ static mlir::PassRegistration<TestLegalizePatternDriver>
|
|||||||
return std::make_unique<TestLegalizePatternDriver>(
|
return std::make_unique<TestLegalizePatternDriver>(
|
||||||
legalizerConversionMode);
|
legalizerConversionMode);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ConversionPatternRewriter::getRemappedValue testing. This method is used
|
||||||
|
// to get the remapped value of a original value that was replaced using
|
||||||
|
// ConversionPatternRewriter.
|
||||||
|
namespace {
|
||||||
|
/// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
|
||||||
|
/// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
|
||||||
|
/// operand twice.
|
||||||
|
///
|
||||||
|
/// Example:
|
||||||
|
/// %1 = test.one_variadic_out_one_variadic_in1"(%0)
|
||||||
|
/// is replaced with:
|
||||||
|
/// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
|
||||||
|
struct OneVResOneVOperandOp1Converter
|
||||||
|
: public OpConversionPattern<OneVResOneVOperandOp1> {
|
||||||
|
using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
|
||||||
|
|
||||||
|
PatternMatchResult
|
||||||
|
matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto origOps = op.getOperands();
|
||||||
|
assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
|
||||||
|
"One operand expected");
|
||||||
|
Value *origOp = *origOps.begin();
|
||||||
|
SmallVector<Value *, 2> remappedOperands;
|
||||||
|
// Replicate the remapped original operand twice. Note that we don't used
|
||||||
|
// the remapped 'operand' since the goal is testing 'getRemappedValue'.
|
||||||
|
remappedOperands.push_back(rewriter.getRemappedValue(origOp));
|
||||||
|
remappedOperands.push_back(rewriter.getRemappedValue(origOp));
|
||||||
|
|
||||||
|
SmallVector<Type, 1> resultTypes(op.getResultTypes());
|
||||||
|
rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, resultTypes,
|
||||||
|
remappedOperands);
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TestRemappedValue : public mlir::FunctionPass<TestRemappedValue> {
|
||||||
|
void runOnFunction() override {
|
||||||
|
mlir::OwningRewritePatternList patterns;
|
||||||
|
patterns.insert<OneVResOneVOperandOp1Converter>(&getContext());
|
||||||
|
|
||||||
|
mlir::ConversionTarget target(getContext());
|
||||||
|
target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>();
|
||||||
|
// We make OneVResOneVOperandOp1 legal only when it has more that one
|
||||||
|
// operand. This will trigger the conversion that will replace one-operand
|
||||||
|
// OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
|
||||||
|
target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
|
||||||
|
[](Operation *op) -> bool {
|
||||||
|
return std::distance(op->operand_begin(), op->operand_end()) > 1;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) {
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
static PassRegistration<TestRemappedValue> remapped_value_pass(
|
||||||
|
"test-remapped-value",
|
||||||
|
"Test public remapped value mechanism in ConversionPatternRewriter");
|
||||||
|
Loading…
Reference in New Issue
Block a user