From 37d4d0484cbb516875e97edfd482d3934aee9d45 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Wed, 19 Feb 2020 17:15:59 -0800 Subject: [PATCH] Fuse hardwish for mobilenet v3 The mobilenet v3 frozen graph has extra FakeQuant ops which blocks the fusion, thus we create a special pattern to remove the redundant FakeQuant ops. PiperOrigin-RevId: 296093529 Change-Id: Ic5bc6808afb12b2004ed7b6f3a81f914df917d5e --- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 2 +- .../mlir/lite/transforms/optimize_patterns.td | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 3bb2b67be35..a04e1d44ea6 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -1349,7 +1349,7 @@ def TFL_GreaterOp : TFL_Op<"greater", [ } def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect, - SameOperandsAndResultType]> { + SameOperandsAndResultShape]> { let summary = "Hardswish activation function."; let description = [{ Computes hard-swish activation function diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index bdf73ff3787..71017fe2801 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -199,6 +199,22 @@ def : Pat< (TFL_HardSwishOp $x), [(EqualOperands $x, $y)]>; +// Matching HardSwish with extra FakeQuant. These FakeQuant ops were due to +// incorrect placement in the quantization aware training. +// TODO(b/149735743): We should make the placement automatically. +def : Pat< + (TFL_MulOp (TFL_DequantizeOp (TFL_QuantizeOp + (TFL_MulOp + $x, (TFL_DequantizeOp (TFL_QuantizeOp (TFL_AddOp + $y, + (ConstantOp ConstantAttr, "3.0f">), + TFL_AF_Relu6), $qattr2)), + TFL_AF_None), $qattr1)), + (ConstantOp ConstantAttr, "0.166666666f">), + TFL_AF_None), + (TFL_HardSwishOp $x), + [(EqualOperands $x, $y)]>; + // Constraint that the attribute value is less than 'n' class ConstDoubleValueLessThan : Constraint< CPred<"$0.isa() && "