[tf:mlir] Hoist unary cwise operations out of Unpack
Rewrite: %unpacked:N = "tf.Unpack"(%0) %neg0 = "tf.Neg"(%unpacked#0) %neg1 = "tf.Neg"(%unpacked#1) ... %negN-1 = "tf.Neg"(%unpacked:N-1) To: %neg = "tf.Neg"(%0) %unpacked:N = "tf.Unpack"(%neg) PiperOrigin-RevId: 346866427 Change-Id: I88ecb2011016d0bcfcff913dd8b4a570cf1e0ce6
This commit is contained in:
parent
10c7c897c8
commit
bd9ffecdb6
@ -8292,7 +8292,7 @@ def TF_NdtriOp : TF_Op<"Ndtri", [NoSideEffect]> {
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_NegOp : TF_Op<"Neg", [Involution, NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_NegOp : TF_Op<"Neg", [Involution, NoSideEffect, SameOperandsAndResultType, TF_CwiseUnary]> {
|
||||
let summary = "Computes numerical negative value element-wise.";
|
||||
|
||||
let description = [{
|
||||
@ -15680,6 +15680,8 @@ This is the opposite of `pack`.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_UnsortedSegmentMaxOp : TF_Op<"UnsortedSegmentMax", [NoSideEffect]> {
|
||||
|
||||
@ -2518,6 +2518,74 @@ static LogicalResult Verify(UnpackOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Hoist coefficient-wise unary operation out of the Unpack op:
|
||||
//
|
||||
// %unpacked:N = "tf.Unpack"(%0)
|
||||
// %neg0 = "tf.Neg"(%unpacked#0)
|
||||
// %neg1 = "tf.Neg"(%unpacked#1)
|
||||
// ...
|
||||
// %negN-1 = "tf.Neg"(%unpacked:N-1)
|
||||
//
|
||||
// Rewrite it to:
|
||||
//
|
||||
// %neg = "tf.Neg"(%0)
|
||||
// %unpacked:N = "tf.Unpack"(%neg)
|
||||
class HoistCwiseUnaryOutOfUnpack : public OpRewritePattern<UnpackOp> {
|
||||
public:
|
||||
explicit HoistCwiseUnaryOutOfUnpack(MLIRContext *context)
|
||||
: OpRewritePattern<UnpackOp>(context) {}
|
||||
LogicalResult matchAndRewrite(UnpackOp op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
LogicalResult HoistCwiseUnaryOutOfUnpack::matchAndRewrite(
|
||||
UnpackOp op, PatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// First unpack user must be coeff-wise unary operation.
|
||||
Operation *first_user = *op->getUsers().begin();
|
||||
if (!first_user->hasTrait<OpTrait::TF::CwiseUnary>()) return failure();
|
||||
|
||||
// All unpack users must be defined by the op of same kind.
|
||||
bool users_same_op = llvm::all_of(op->getUsers(), [&](Operation *user) {
|
||||
return user->getName() == first_user->getName();
|
||||
});
|
||||
if (!users_same_op) return failure();
|
||||
|
||||
// Pass unpack operand to unary operation.
|
||||
OperationState new_unary_op_state(loc, first_user->getName().getStringRef(),
|
||||
op.getOperand(), op.getOperand().getType(),
|
||||
ArrayRef<NamedAttribute>());
|
||||
Operation *new_unary_op = rewriter.createOperation(new_unary_op_state);
|
||||
|
||||
// Unpack results after applying unary operation.
|
||||
auto unpack_unary_op = rewriter.create<UnpackOp>(
|
||||
loc, op.getResultTypes(), new_unary_op->getResult(0), op.axis());
|
||||
|
||||
// Bypass all users of the original unpack operation and use `unpack_unary_op`
|
||||
// results instead.
|
||||
for (auto pair : llvm::zip(op.getResults(), unpack_unary_op.getResults())) {
|
||||
OpResult old_result = std::get<0>(pair); // result of original Unpack
|
||||
OpResult new_result = std::get<1>(pair); // result of transformed Unpack
|
||||
for (Operation *user : llvm::make_early_inc_range(old_result.getUsers()))
|
||||
rewriter.replaceOp(user, ValueRange(new_result));
|
||||
}
|
||||
|
||||
// Erase original unpack operation.
|
||||
rewriter.eraseOp(op.getOperation());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void UnpackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<HoistCwiseUnaryOutOfUnpack>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Unsorted segment reduction ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -1341,3 +1341,16 @@ func @testVariableToVariableV2() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testUnpackAndCwiseUnary
|
||||
func @testUnpackAndCwiseUnary(%arg0: tensor<?x2xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
|
||||
|
||||
// CHECK: %[[NEG:.*]] = "tf.Neg"(%arg0)
|
||||
// CHECK: %[[UNPACK:.*]]:2 = "tf.Unpack"(%[[NEG]])
|
||||
%unpacked:2 = "tf.Unpack"(%arg0) {axis = 1 : i64, device = ""}
|
||||
: (tensor<?x2xf32>) -> (tensor<?xf32>, tensor<?xf32>)
|
||||
%0 = "tf.Neg"(%unpacked#0): (tensor<?xf32>) -> tensor<?xf32>
|
||||
%1 = "tf.Neg"(%unpacked#1): (tensor<?xf32>) -> tensor<?xf32>
|
||||
|
||||
// CHECK: return %[[UNPACK]]#0, %[[UNPACK]]#1
|
||||
return %0, %1 : tensor<?xf32>, tensor<?xf32>
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user