Lowerings to remove TF::IdentityNOp during XLA Lowering
PiperOrigin-RevId: 338118295 Change-Id: I888e1dd096f633816df6a42ae96042e5407f662c
This commit is contained in:
parent
59e30f648f
commit
3cf23c75bc
@ -1021,6 +1021,13 @@ func @identity(%arg0: tensor<1xi32>) -> tensor<1xi32> {
|
||||
return %0: tensor<1xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @identityN
|
||||
func @identityN(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) {
|
||||
// CHECK-NEXT: return %arg0, %arg1 : tensor<1xi32>, tensor<1xf32>
|
||||
%0:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>)
|
||||
return %0#0, %0#1: tensor<1xi32>, tensor<1xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @stopgradient
|
||||
func @stopgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> {
|
||||
// CHECK-NEXT: return %arg0 : tensor<1xi32>
|
||||
|
@ -1705,6 +1705,17 @@ class ConvertEinsumOp : public OpRewritePattern<TF::EinsumOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Bypasses IdentityN op.
|
||||
class ConvertIdentityNOp : public OpRewritePattern<TF::IdentityNOp> {
|
||||
public:
|
||||
using OpRewritePattern<TF::IdentityNOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(TF::IdentityNOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOp(op, op.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
class ConvertFFTOp : public OpRewritePattern<OpTy> {
|
||||
public:
|
||||
@ -6131,13 +6142,14 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
|
||||
ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
|
||||
ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
|
||||
ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
|
||||
ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp,
|
||||
ConvertAvgPool2DOp, ConvertAvgPool3DOp, ConvertAvgPool2DGradOp,
|
||||
ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp,
|
||||
ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp,
|
||||
ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
|
||||
ConvertDynamicRangeOp, ConvertMatrixDiagPartV3Op, ConvertRangeOp,
|
||||
ConvertSelectV2Op, ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
|
||||
ConvertIdentityNOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp,
|
||||
ConvertMaxOp, ConvertMinOp, ConvertAvgPool2DOp, ConvertAvgPool3DOp,
|
||||
ConvertAvgPool2DGradOp, ConvertAvgPool3DGradOp, ConvertMaxPool2DOp,
|
||||
ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp,
|
||||
ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
|
||||
ConvertProdOp, ConvertQrOp, ConvertDynamicRangeOp,
|
||||
ConvertMatrixDiagPartV3Op, ConvertRangeOp, ConvertSelectV2Op,
|
||||
ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
|
||||
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
|
||||
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
|
||||
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
|
||||
|
Loading…
x
Reference in New Issue
Block a user