Fuse relu6 & relu1 to fc.

PiperOrigin-RevId: 306566180
Change-Id: I23c2e7c4301c0478ad16c26d46a7ad0f0cecef70
Author: Renjie Liu, 2020-04-14 20:03:51 -07:00 (committed by TensorFlower Gardener)
Parent: a9f8a9b1c1
Commit: 6964061dc0
2 changed files with 34 additions and 5 deletions
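For reference, a minimal sketch (not TFLite code; ApplyFusedActivation is a hypothetical name used only for illustration) of what each fused_activation_function string computes per output element once the standalone relu op is folded into tfl.fully_connected:

#include <algorithm>
#include <cstring>

// Illustrative only: per-element effect of the fused activation strings.
// "RELU" clamps below at 0, "RELU6" clamps to [0, 6],
// "RELU_N1_TO_1" clamps to [-1, 1], and "NONE" passes through.
inline float ApplyFusedActivation(float x, const char *act) {
  if (std::strcmp(act, "RELU") == 0) return std::max(x, 0.0f);
  if (std::strcmp(act, "RELU6") == 0) return std::min(std::max(x, 0.0f), 6.0f);
  if (std::strcmp(act, "RELU_N1_TO_1") == 0) return std::min(std::max(x, -1.0f), 1.0f);
  return x;  // "NONE"
}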


@@ -475,6 +475,28 @@ func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32
   // CHECK: return %[[RES]]
 }
 
+// CHECK-LABEL: @FuseFullyConnectedRelu6
+func @FuseFullyConnectedRelu6(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
+  %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
+  %1 = "tfl.relu6"(%0) : (tensor<1x128xf32>) -> tensor<1x128xf32>
+  return %1 : tensor<1x128xf32>
+
+  // CHECK: %[[RES:[0-9].*]] = "tfl.fully_connected"
+  // CHECK-SAME: fused_activation_function = "RELU6"
+  // CHECK: return %[[RES]]
+}
+
+// CHECK-LABEL: @FuseFullyConnectedRelu1
+func @FuseFullyConnectedRelu1(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
+  %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
+  %1 = "tfl.relu_n1_to_1"(%0) : (tensor<1x128xf32>) -> tensor<1x128xf32>
+  return %1 : tensor<1x128xf32>
+
+  // CHECK: %[[RES:[0-9].*]] = "tfl.fully_connected"
+  // CHECK-SAME: fused_activation_function = "RELU_N1_TO_1"
+  // CHECK: return %[[RES]]
+}
+
 // CHECK-LABEL: @HardSwishPattern
 func @HardSwishPattern(%arg0: tensor<1xf32>) -> tensor<1xf32> {
   %three = constant dense<3.> : tensor<f32>


@@ -51,6 +51,9 @@ namespace TFL {
 //===----------------------------------------------------------------------===//
 // The actual Optimize Pass.
 namespace {
+const char kRelu[] = "RELU";
+const char kRelu6[] = "RELU6";
+const char kRelu1[] = "RELU_N1_TO_1";
 
 bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
   if (sq_op.getType().cast<ShapedType>().getRank() - 1 ==
@@ -300,10 +303,11 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
 };
 
 // TODO(b/136285429): Move to tablegen when variadic is supported.
-struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
-  using OpRewritePattern<TFL::ReluOp>::OpRewritePattern;
+template <typename ReluXOp, char const *Act>
+struct FuseFullyConnectedAndReluX : public OpRewritePattern<ReluXOp> {
+  using OpRewritePattern<ReluXOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TFL::ReluOp relu_op,
+  LogicalResult matchAndRewrite(ReluXOp relu_op,
                                 PatternRewriter &rewriter) const override {
     Operation *input = relu_op.getOperand().getDefiningOp();
     if (!isa_and_nonnull<FullyConnectedOp>(input)) return failure();
@@ -311,7 +315,7 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
     if (fully_connected_op.fused_activation_function() != "NONE")
       return failure();
 
-    auto new_activation_func = rewriter.getStringAttr("RELU");
+    auto new_activation_func = rewriter.getStringAttr(Act);
     auto new_weights_format =
         rewriter.getStringAttr(fully_connected_op.weights_format());
     auto new_keep_num_dims =
@@ -708,7 +712,10 @@ void Optimize::runOnFunction() {
   // we explore these potentially first and then fuse the binary ops with the
   // following ops in a second pattern match.
   TFL::populateWithGenerated(ctx, &patterns);
-  patterns.insert<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
+  patterns.insert<FuseFullyConnectedAndAdd,
+                  FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
+                  FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
+                  FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>,
                   FuseFullyConnectedAndMul>(ctx);
   applyPatternsAndFoldGreedily(func, patterns);
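A note on the char const *Act template parameter, sketched from the general C++ rule rather than anything TFLite-specific: a string literal cannot be passed as a non-type template argument, which is presumably why the activation names are declared as namespace-scope arrays (kRelu, kRelu6, kRelu1) and passed by address. A standalone illustration with hypothetical names:

#include <iostream>

// Hypothetical constant mirroring kRelu6: a namespace-scope const char array
// has linkage, so its address is a valid non-type template argument.
const char kActName[] = "RELU6";

// Minimal stand-in for a class templated on the activation-name pointer,
// like FuseFullyConnectedAndReluX's char const *Act parameter.
template <char const *Act>
struct ActivationTag {
  const char *name() const { return Act; }
};

int main() {
  ActivationTag<kActName> tag;      // OK: argument is an object with linkage
  // ActivationTag<"RELU6"> bad;    // ill-formed: a string literal has no linkage
  std::cout << tag.name() << "\n";  // prints RELU6
  return 0;
}

This is also why the constants are defined once at file scope instead of being written inline at the patterns.insert<...> call site.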