Lowerings to remove TF::IdentityNOp during XLA Lowering

PiperOrigin-RevId: 338118295
Change-Id: I888e1dd096f633816df6a42ae96042e5407f662c
This commit is contained in:
Robert Suderman 2020-10-20 12:55:16 -07:00 committed by TensorFlower Gardener
parent 59e30f648f
commit 3cf23c75bc
2 changed files with 26 additions and 7 deletions

View File

@ -1021,6 +1021,13 @@ func @identity(%arg0: tensor<1xi32>) -> tensor<1xi32> {
return %0: tensor<1xi32>
}
// CHECK-LABEL: func @identityN
func @identityN(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) {
// CHECK-NEXT: return %arg0, %arg1 : tensor<1xi32>, tensor<1xf32>
%0:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>)
return %0#0, %0#1: tensor<1xi32>, tensor<1xf32>
}
// CHECK-LABEL: func @stopgradient
func @stopgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> {
// CHECK-NEXT: return %arg0 : tensor<1xi32>

View File

@ -1705,6 +1705,17 @@ class ConvertEinsumOp : public OpRewritePattern<TF::EinsumOp> {
}
};
// Bypasses IdentityN op.
class ConvertIdentityNOp : public OpRewritePattern<TF::IdentityNOp> {
public:
using OpRewritePattern<TF::IdentityNOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TF::IdentityNOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOp(op, op.getOperands());
return success();
}
};
template <typename OpTy>
class ConvertFFTOp : public OpRewritePattern<OpTy> {
public:
@ -6131,13 +6142,14 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp,
ConvertAvgPool2DOp, ConvertAvgPool3DOp, ConvertAvgPool2DGradOp,
ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp,
ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp,
ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
ConvertDynamicRangeOp, ConvertMatrixDiagPartV3Op, ConvertRangeOp,
ConvertSelectV2Op, ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
ConvertIdentityNOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp,
ConvertMaxOp, ConvertMinOp, ConvertAvgPool2DOp, ConvertAvgPool3DOp,
ConvertAvgPool2DGradOp, ConvertAvgPool3DGradOp, ConvertMaxPool2DOp,
ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp,
ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
ConvertProdOp, ConvertQrOp, ConvertDynamicRangeOp,
ConvertMatrixDiagPartV3Op, ConvertRangeOp, ConvertSelectV2Op,
ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,