Use declarative rewrite patterns for TopKV2 legalization
PiperOrigin-RevId: 257685698
This commit is contained in:
parent
dbd23da825
commit
14bfae8a04
@ -229,6 +229,9 @@ def : Pat<(TF_MeanOp $arg0, $arg1, BoolAttr:$arg2), (TFL_MeanOp $arg0, $arg1, $a
|
||||
|
||||
def : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2), (TFL_SumOp $arg, $axes, $arg2)>;
|
||||
|
||||
// TopK in TFL is always sorted so we ignore that attribute here.
|
||||
def : Pat<(TF_TopKV2Op $input, $k, $ignored_sorted), (TFL_TopKV2Op $input, $k)>;
|
||||
|
||||
def : Pat<(TF_MinOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMinOp $arg0, $arg1, $arg2)>;
|
||||
|
||||
def : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMaxOp $arg0, $arg1, $arg2)>;
|
||||
|
@ -72,7 +72,6 @@ DECL_CONVERT_OP(MatMul);
|
||||
DECL_CONVERT_OP(Pack);
|
||||
DECL_CONVERT_OP(Split);
|
||||
DECL_CONVERT_OP(SplitV);
|
||||
DECL_CONVERT_OP(TopKV2);
|
||||
DECL_CONVERT_OP(Unpack);
|
||||
|
||||
#undef DECL_CONVERT_OP
|
||||
@ -207,14 +206,6 @@ PatternMatchResult ConvertTFSplitVOp::matchAndRewrite(
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
PatternMatchResult ConvertTFTopKV2Op::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
// TopK in TFL is always sorted so we ignore that attribute here.
|
||||
rewriter.replaceOpWithNewOp<TFL::TopKV2Op>(op, op->getOperand(0),
|
||||
op->getOperand(1));
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
PatternMatchResult ConvertTFUnpackOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_unpack_op = cast<TF::UnpackOp>(op);
|
||||
@ -239,7 +230,7 @@ void LegalizeTF::runOnFunction() {
|
||||
populateWithGenerated(ctx, &patterns);
|
||||
RewriteListBuilder<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFGatherOp,
|
||||
ConvertTFGatherV2Op, ConvertTFMatMulOp, ConvertTFPackOp,
|
||||
ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFTopKV2Op,
|
||||
ConvertTFSplitOp, ConvertTFSplitVOp,
|
||||
ConvertTFUnpackOp>::build(patterns, ctx);
|
||||
applyPatternsGreedily(func, std::move(patterns));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user