[tf:mlir] Hoist unary cwise operations out of Unpack

Rewrite:
   %unpacked:N = "tf.Unpack"(%0)
   %neg0 = "tf.Neg"(%unpacked#0)
   %neg1 = "tf.Neg"(%unpacked#1)
   ...
   %negN-1 = "tf.Neg"(%unpacked:N-1)

 To:

   %neg = "tf.Neg"(%0)
   %unpacked:N = "tf.Unpack"(%neg)

PiperOrigin-RevId: 346866427
Change-Id: I88ecb2011016d0bcfcff913dd8b4a570cf1e0ce6
This commit is contained in:
Eugene Zhulenev 2020-12-10 14:31:59 -08:00 committed by TensorFlower Gardener
parent 10c7c897c8
commit bd9ffecdb6
3 changed files with 84 additions and 1 deletions

View File

@ -8292,7 +8292,7 @@ def TF_NdtriOp : TF_Op<"Ndtri", [NoSideEffect]> {
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_NegOp : TF_Op<"Neg", [Involution, NoSideEffect, SameOperandsAndResultType]> {
def TF_NegOp : TF_Op<"Neg", [Involution, NoSideEffect, SameOperandsAndResultType, TF_CwiseUnary]> {
let summary = "Computes numerical negative value element-wise.";
let description = [{
@ -15680,6 +15680,8 @@ This is the opposite of `pack`.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let verifier = [{ return Verify(*this); }];
let hasCanonicalizer = 1;
}
def TF_UnsortedSegmentMaxOp : TF_Op<"UnsortedSegmentMax", [NoSideEffect]> {

View File

@ -2518,6 +2518,74 @@ static LogicalResult Verify(UnpackOp op) {
return success();
}
namespace {
// Hoist coefficient-wise unary operation out of the Unpack op:
//
// %unpacked:N = "tf.Unpack"(%0)
// %neg0 = "tf.Neg"(%unpacked#0)
// %neg1 = "tf.Neg"(%unpacked#1)
// ...
// %negN-1 = "tf.Neg"(%unpacked:N-1)
//
// Rewrite it to:
//
// %neg = "tf.Neg"(%0)
// %unpacked:N = "tf.Unpack"(%neg)
class HoistCwiseUnaryOutOfUnpack : public OpRewritePattern<UnpackOp> {
public:
explicit HoistCwiseUnaryOutOfUnpack(MLIRContext *context)
: OpRewritePattern<UnpackOp>(context) {}
LogicalResult matchAndRewrite(UnpackOp op,
PatternRewriter &rewriter) const override;
};
LogicalResult HoistCwiseUnaryOutOfUnpack::matchAndRewrite(
UnpackOp op, PatternRewriter &rewriter) const {
auto loc = op.getLoc();
// First unpack user must be coeff-wise unary operation.
Operation *first_user = *op->getUsers().begin();
if (!first_user->hasTrait<OpTrait::TF::CwiseUnary>()) return failure();
// All unpack users must be defined by the op of same kind.
bool users_same_op = llvm::all_of(op->getUsers(), [&](Operation *user) {
return user->getName() == first_user->getName();
});
if (!users_same_op) return failure();
// Pass unpack operand to unary operation.
OperationState new_unary_op_state(loc, first_user->getName().getStringRef(),
op.getOperand(), op.getOperand().getType(),
ArrayRef<NamedAttribute>());
Operation *new_unary_op = rewriter.createOperation(new_unary_op_state);
// Unpack results after applying unary operation.
auto unpack_unary_op = rewriter.create<UnpackOp>(
loc, op.getResultTypes(), new_unary_op->getResult(0), op.axis());
// Bypass all users of the original unpack operation and use `unpack_unary_op`
// results instead.
for (auto pair : llvm::zip(op.getResults(), unpack_unary_op.getResults())) {
OpResult old_result = std::get<0>(pair); // result of original Unpack
OpResult new_result = std::get<1>(pair); // result of transformed Unpack
for (Operation *user : llvm::make_early_inc_range(old_result.getUsers()))
rewriter.replaceOp(user, ValueRange(new_result));
}
// Erase original unpack operation.
rewriter.eraseOp(op.getOperation());
return success();
}
} // namespace
void UnpackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<HoistCwiseUnaryOutOfUnpack>(context);
}
//===----------------------------------------------------------------------===//
// Unsorted segment reduction ops
//===----------------------------------------------------------------------===//

View File

@ -1341,3 +1341,16 @@ func @testVariableToVariableV2() {
return
}
// CHECK-LABEL: testUnpackAndCwiseUnary
func @testUnpackAndCwiseUnary(%arg0: tensor<?x2xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
// CHECK: %[[NEG:.*]] = "tf.Neg"(%arg0)
// CHECK: %[[UNPACK:.*]]:2 = "tf.Unpack"(%[[NEG]])
%unpacked:2 = "tf.Unpack"(%arg0) {axis = 1 : i64, device = ""}
: (tensor<?x2xf32>) -> (tensor<?xf32>, tensor<?xf32>)
%0 = "tf.Neg"(%unpacked#0): (tensor<?xf32>) -> tensor<?xf32>
%1 = "tf.Neg"(%unpacked#1): (tensor<?xf32>) -> tensor<?xf32>
// CHECK: return %[[UNPACK]]#0, %[[UNPACK]]#1
return %0, %1 : tensor<?xf32>, tensor<?xf32>
}