From bd9ffecdb6214e1c5a9f5e3cf80840d068ee468e Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 10 Dec 2020 14:31:59 -0800 Subject: [PATCH] [tf:mlir] Hoist unary cwise operations out of Unpack Rewrite: %unpacked:N = "tf.Unpack"(%0) %neg0 = "tf.Neg"(%unpacked#0) %neg1 = "tf.Neg"(%unpacked#1) ... %negN-1 = "tf.Neg"(%unpacked:N-1) To: %neg = "tf.Neg"(%0) %unpacked:N = "tf.Unpack"(%neg) PiperOrigin-RevId: 346866427 Change-Id: I88ecb2011016d0bcfcff913dd8b4a570cf1e0ce6 --- .../mlir/tensorflow/ir/tf_generated_ops.td | 4 +- .../compiler/mlir/tensorflow/ir/tf_ops_n_z.cc | 68 +++++++++++++++++++ .../mlir/tensorflow/tests/canonicalize.mlir | 13 ++++ 3 files changed, 84 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 281e3f49c98..5cf85d939e4 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -8292,7 +8292,7 @@ def TF_NdtriOp : TF_Op<"Ndtri", [NoSideEffect]> { TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_NegOp : TF_Op<"Neg", [Involution, NoSideEffect, SameOperandsAndResultType]> { +def TF_NegOp : TF_Op<"Neg", [Involution, NoSideEffect, SameOperandsAndResultType, TF_CwiseUnary]> { let summary = "Computes numerical negative value element-wise."; let description = [{ @@ -15680,6 +15680,8 @@ This is the opposite of `pack`. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let verifier = [{ return Verify(*this); }]; + + let hasCanonicalizer = 1; } def TF_UnsortedSegmentMaxOp : TF_Op<"UnsortedSegmentMax", [NoSideEffect]> { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 67afbc18592..a2cbac6a4d2 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -2518,6 +2518,74 @@ static LogicalResult Verify(UnpackOp op) { return success(); } +namespace { + +// Hoist coefficient-wise unary operation out of the Unpack op: +// +// %unpacked:N = "tf.Unpack"(%0) +// %neg0 = "tf.Neg"(%unpacked#0) +// %neg1 = "tf.Neg"(%unpacked#1) +// ... +// %negN-1 = "tf.Neg"(%unpacked:N-1) +// +// Rewrite it to: +// +// %neg = "tf.Neg"(%0) +// %unpacked:N = "tf.Unpack"(%neg) +class HoistCwiseUnaryOutOfUnpack : public OpRewritePattern { + public: + explicit HoistCwiseUnaryOutOfUnpack(MLIRContext *context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(UnpackOp op, + PatternRewriter &rewriter) const override; +}; + +LogicalResult HoistCwiseUnaryOutOfUnpack::matchAndRewrite( + UnpackOp op, PatternRewriter &rewriter) const { + auto loc = op.getLoc(); + + // First unpack user must be coeff-wise unary operation. + Operation *first_user = *op->getUsers().begin(); + if (!first_user->hasTrait()) return failure(); + + // All unpack users must be defined by the op of same kind. + bool users_same_op = llvm::all_of(op->getUsers(), [&](Operation *user) { + return user->getName() == first_user->getName(); + }); + if (!users_same_op) return failure(); + + // Pass unpack operand to unary operation. + OperationState new_unary_op_state(loc, first_user->getName().getStringRef(), + op.getOperand(), op.getOperand().getType(), + ArrayRef()); + Operation *new_unary_op = rewriter.createOperation(new_unary_op_state); + + // Unpack results after applying unary operation. + auto unpack_unary_op = rewriter.create( + loc, op.getResultTypes(), new_unary_op->getResult(0), op.axis()); + + // Bypass all users of the original unpack operation and use `unpack_unary_op` + // results instead. + for (auto pair : llvm::zip(op.getResults(), unpack_unary_op.getResults())) { + OpResult old_result = std::get<0>(pair); // result of original Unpack + OpResult new_result = std::get<1>(pair); // result of transformed Unpack + for (Operation *user : llvm::make_early_inc_range(old_result.getUsers())) + rewriter.replaceOp(user, ValueRange(new_result)); + } + + // Erase original unpack operation. + rewriter.eraseOp(op.getOperation()); + + return success(); +} + +} // namespace + +void UnpackOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // Unsorted segment reduction ops //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 9b7993d97d7..e344d3250aa 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -1341,3 +1341,16 @@ func @testVariableToVariableV2() { return } +// CHECK-LABEL: testUnpackAndCwiseUnary +func @testUnpackAndCwiseUnary(%arg0: tensor) -> (tensor, tensor) { + + // CHECK: %[[NEG:.*]] = "tf.Neg"(%arg0) + // CHECK: %[[UNPACK:.*]]:2 = "tf.Unpack"(%[[NEG]]) + %unpacked:2 = "tf.Unpack"(%arg0) {axis = 1 : i64, device = ""} + : (tensor) -> (tensor, tensor) + %0 = "tf.Neg"(%unpacked#0): (tensor) -> tensor + %1 = "tf.Neg"(%unpacked#1): (tensor) -> tensor + + // CHECK: return %[[UNPACK]]#0, %[[UNPACK]]#1 + return %0, %1 : tensor, tensor +}