Lowerings to remove TF::IdentityNOp during XLA Lowering
PiperOrigin-RevId: 338118295 Change-Id: I888e1dd096f633816df6a42ae96042e5407f662c
This commit is contained in:
parent
59e30f648f
commit
3cf23c75bc
@ -1021,6 +1021,13 @@ func @identity(%arg0: tensor<1xi32>) -> tensor<1xi32> {
|
|||||||
return %0: tensor<1xi32>
|
return %0: tensor<1xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @identityN
|
||||||
|
func @identityN(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) {
|
||||||
|
// CHECK-NEXT: return %arg0, %arg1 : tensor<1xi32>, tensor<1xf32>
|
||||||
|
%0:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>)
|
||||||
|
return %0#0, %0#1: tensor<1xi32>, tensor<1xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @stopgradient
|
// CHECK-LABEL: func @stopgradient
|
||||||
func @stopgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> {
|
func @stopgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> {
|
||||||
// CHECK-NEXT: return %arg0 : tensor<1xi32>
|
// CHECK-NEXT: return %arg0 : tensor<1xi32>
|
||||||
|
@ -1705,6 +1705,17 @@ class ConvertEinsumOp : public OpRewritePattern<TF::EinsumOp> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Bypasses IdentityN op.
|
||||||
|
class ConvertIdentityNOp : public OpRewritePattern<TF::IdentityNOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<TF::IdentityNOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(TF::IdentityNOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
rewriter.replaceOp(op, op.getOperands());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
class ConvertFFTOp : public OpRewritePattern<OpTy> {
|
class ConvertFFTOp : public OpRewritePattern<OpTy> {
|
||||||
public:
|
public:
|
||||||
@ -6131,13 +6142,14 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
|
|||||||
ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
|
ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
|
||||||
ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
|
ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
|
||||||
ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
|
ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
|
||||||
ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp,
|
ConvertIdentityNOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp,
|
||||||
ConvertAvgPool2DOp, ConvertAvgPool3DOp, ConvertAvgPool2DGradOp,
|
ConvertMaxOp, ConvertMinOp, ConvertAvgPool2DOp, ConvertAvgPool3DOp,
|
||||||
ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp,
|
ConvertAvgPool2DGradOp, ConvertAvgPool3DGradOp, ConvertMaxPool2DOp,
|
||||||
ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp,
|
ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp,
|
||||||
ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
|
ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
|
||||||
ConvertDynamicRangeOp, ConvertMatrixDiagPartV3Op, ConvertRangeOp,
|
ConvertProdOp, ConvertQrOp, ConvertDynamicRangeOp,
|
||||||
ConvertSelectV2Op, ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
|
ConvertMatrixDiagPartV3Op, ConvertRangeOp, ConvertSelectV2Op,
|
||||||
|
ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
|
||||||
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
|
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
|
||||||
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
|
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
|
||||||
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
|
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user