diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 410f432b917..d1c0dd20c05 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -475,6 +475,28 @@ func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32 // CHECK: return %[[RES]] } +// CHECK-LABEL: @FuseFullyConnectedRelu6 +func @FuseFullyConnectedRelu6(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> { + %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32> + %1 = "tfl.relu6"(%0) : (tensor<1x128xf32>) -> tensor<1x128xf32> + return %1 : tensor<1x128xf32> + + // CHECK: %[[RES:[0-9].*]] = "tfl.fully_connected" + // CHECK-SAME: fused_activation_function = "RELU6" + // CHECK: return %[[RES]] +} + +// CHECK-LABEL: @FuseFullyConnectedRelu1 +func @FuseFullyConnectedRelu1(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> { + %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32> + %1 = "tfl.relu_n1_to_1"(%0) : (tensor<1x128xf32>) -> tensor<1x128xf32> + return %1 : tensor<1x128xf32> + + // CHECK: %[[RES:[0-9].*]] = "tfl.fully_connected" + // CHECK-SAME: fused_activation_function = "RELU_N1_TO_1" + // CHECK: return %[[RES]] +} + // CHECK-LABEL: @HardSwishPattern func @HardSwishPattern(%arg0: tensor<1xf32>) -> tensor<1xf32> { %three = constant dense<3.> : tensor diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 6cbe3c49fb0..f013e73b75b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -51,6 +51,9 @@ namespace TFL { //===----------------------------------------------------------------------===// // The actual Optimize Pass. namespace { +const char kRelu[] = "RELU"; +const char kRelu6[] = "RELU6"; +const char kRelu1[] = "RELU_N1_TO_1"; bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) { if (sq_op.getType().cast().getRank() - 1 == @@ -300,10 +303,11 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { }; // TODO(b/136285429): Move to tablegen when variadic is supported. -struct FuseFullyConnectedAndRelu : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +template +struct FuseFullyConnectedAndReluX : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TFL::ReluOp relu_op, + LogicalResult matchAndRewrite(ReluXOp relu_op, PatternRewriter &rewriter) const override { Operation *input = relu_op.getOperand().getDefiningOp(); if (!isa_and_nonnull(input)) return failure(); @@ -311,7 +315,7 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern { if (fully_connected_op.fused_activation_function() != "NONE") return failure(); - auto new_activation_func = rewriter.getStringAttr("RELU"); + auto new_activation_func = rewriter.getStringAttr(Act); auto new_weights_format = rewriter.getStringAttr(fully_connected_op.weights_format()); auto new_keep_num_dims = @@ -708,7 +712,10 @@ void Optimize::runOnFunction() { // we explore these potentially first and then fuse the binary ops with the // following ops in a second pattern match. TFL::populateWithGenerated(ctx, &patterns); - patterns.insert, + FuseFullyConnectedAndReluX, + FuseFullyConnectedAndReluX, FuseFullyConnectedAndMul>(ctx); applyPatternsAndFoldGreedily(func, patterns);