Fold the MulOp into the preceding FullyConnectedOp

This CL folds the MulOp into the filter and bias of the preceding
FullyConnectedOp. If the bias is a NoneType, only the filter is updated.
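
For example, with the constants used in the new test case below:

    filter = [[1, 2], [3, 4]], bias = [2, 2], mul constant = [1, 2]

the rewrite emits a single fully_connected with

    new filter = filter * constant = [[1, 4], [3, 8]]   (elementwise tfl.mul, constant broadcast against the filter)
    new bias   = bias * constant   = [2, 4]

and moves the mul's RELU6 fused activation onto the fully_connected.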

PiperOrigin-RevId: 262056643
Feng Liu 2019-08-06 20:50:51 -07:00 committed by TensorFlower Gardener
parent 2fb51e98b9
commit 7d6b60f1ed
4 changed files with 118 additions and 5 deletions


@@ -252,6 +252,7 @@ cc_library(
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
],
alwayslink = 1,


@@ -96,6 +96,38 @@ func @intermOpUsedTwice(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf
}
// CHECK-LABEL: @fuseMulIntoFullyConnected
func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
%cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
%cst1 = constant dense<2.0> : tensor<2xf32>
%cst2 = constant dense<[1.0, 2.0]> : tensor<2xf32>
%0 = "tfl.fully_connected"(%arg0, %cst0, %cst1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<4x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
return %1 : tensor<4x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %cst_0 = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
// CHECK: %0 = "tfl.fully_connected"(%arg0, %cst, %cst_0) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
// CHECK: return %0 : tensor<4x2xf32>
}
// CHECK-LABEL: @fuseMulIntoFullyConnectedNoBias
func @fuseMulIntoFullyConnectedNoBias(%arg0: tensor<4x2xf32>, %arg1: none) -> tensor<4x2xf32> {
%cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
%cst2 = constant dense<[1.0, 2.0]> : tensor<2xf32>
%0 = "tfl.fully_connected"(%arg0, %cst0, %arg1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<4x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
return %1 : tensor<4x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %0 = "tfl.fully_connected"(%arg0, %cst, %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
// CHECK: return %0 : tensor<4x2xf32>
}
// CHECK-LABEL: @fuseMulIntoDepthwiseConv2d
func @fuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
%cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0], [17.0, 18.0]]]]> : tensor<1x3x3x2xf32>


@@ -29,7 +29,6 @@ class ExtractI32At<int i> : NativeCodeCall<
"$_builder.getI32IntegerAttr($_self.cast<ArrayAttr>().getValue()[" # i #
"].cast<IntegerAttr>().getInt())">;
// Merge the two Attributes into an ArrayAttr.
def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;


@@ -21,12 +21,17 @@ limitations under the License.
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
@@ -45,14 +50,20 @@ struct Optimize : public FunctionPass<Optimize> {
void runOnFunction() override;
};
// Returns whether the given type `a` is broadcast-compatible with `b`.
bool IsBroadcastableElementsAttrAndType(Type a, Type b) {
return OpTrait::util::getBroadcastedType(a, b) != Type();
}
// Returns whether the given `a` and `b` ElementsAttr have broadcast-compatible
// types.
bool IsBroadcastableElementsAttrs(Attribute a, Attribute b) {
return OpTrait::util::getBroadcastedType(a.getType(), b.getType()) != Type();
return IsBroadcastableElementsAttrAndType(a.getType(), b.getType());
}
#include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
// Fuse Add with FullyConnected.
// Fuse Add with preceding FullyConnected.
// Note that this assumes that the bias in the fullyConnected
// is always None.
// TODO(b/136285429): Move to tablegen when variadic is supported
@@ -153,6 +164,76 @@ struct FuseFullyConnectedAndRelu : public RewritePattern {
}
};
// Fuse Mul with preceding FullyConnected.
// TODO(b/136285429): Move to tablegen when variadic is supported
struct FuseFullyConnectedAndMul : public RewritePattern {
explicit FuseFullyConnectedAndMul(MLIRContext *context)
: RewritePattern(TFL::MulOp::getOperationName(),
{"tfl.fully_connected", "tfl.mul", "std.constant"}, 4,
context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
// Mul.
auto mul_op = cast<MulOp>(op);
DenseElementsAttr cst;
Value *constant_val = mul_op.rhs();
if (!matchPattern(constant_val, m_Constant(&cst))) {
return matchFailure();
}
// Fully Connected.
auto fc_op =
dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs()->getDefiningOp());
if (!fc_op) return matchFailure();
Value *filter = fc_op.filter();
Value *bias = fc_op.bias();
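// Only fuse when the fully_connected has no fused activation function; the
// multiplier cannot, in general, be pushed through a nonlinear activation.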
if (!fc_op.fused_activation_function().equals("NONE")) return matchFailure();
// Broadcast the constant operand of Mul if it isn't compatible with the
// filter input. We only support broadcasting the operand along the depth
// dimension, when the operand's depth is 1.
Value *new_const_val = constant_val;
if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter->getType())) {
auto original_shape = cst.getType().getShape();
llvm::SmallVector<int64_t, 4> normalized_shape(original_shape.begin(),
original_shape.end());
normalized_shape.push_back(1);
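// For example (shapes chosen for illustration): a tensor<3xf32> multiplier
// does not broadcast against a tensor<3x5xf32> filter, but after appending a
// trailing unit dimension the reshaped tensor<3x1xf32> constant does.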
auto new_cst = cst.reshape(rewriter.getTensorType(
normalized_shape, cst.getType().getElementType()));
Type new_type = new_cst.getType();
if (!IsBroadcastableElementsAttrAndType(new_type, filter->getType())) {
return matchFailure();
}
auto new_op =
rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
new_const_val = new_op.getResult();
}
// Rewrite.
Location loc = fc_op.getLoc();
auto af_none = rewriter.getStringAttr(fc_op.fused_activation_function());
auto new_filter =
rewriter.create<MulOp>(loc, filter, new_const_val, af_none);
// If bias isn't None, it needs to be multiplied as well.
if (!bias->getType().isa<NoneType>()) {
bias = rewriter.create<MulOp>(loc, bias, constant_val, af_none).output();
}
rewriter.replaceOpWithNewOp<TFL::FullyConnectedOp>(
mul_op, mul_op.getType(),
/*input=*/fc_op.input(),
/*filter=*/new_filter.output(),
/*bias=*/bias,
/*fused_activation_function=*/
rewriter.getStringAttr(mul_op.fused_activation_function()),
/*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
/*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
return matchSuccess();
}
};
// StridedSlice can have complicated attributes like begin_axis_mask,
// end_axis_mask, ellipsis_axis_mask, new_axis_mask, shrink_axis_mask. These
// masks will complicate the strided_slice computation logic, we can simplify
@@ -238,11 +319,11 @@ void Optimize::runOnFunction() {
OwningRewritePatternList patterns;
auto *ctx = &getContext();
auto func = getFunction();
// Add the generated patterns to the list.
TFL::populateWithGenerated(ctx, &patterns);
patterns.insert<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
PadStridedSliceDims>(ctx);
FuseFullyConnectedAndMul, PadStridedSliceDims>(ctx);
applyPatternsGreedily(func, std::move(patterns));
}
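
To exercise the new pattern manually, one option (assuming the TFL optimize pass
is still registered under the -tfl-optimize flag used by the optimize tests) is
to run the updated test file through tf-opt and FileCheck:

    tf-opt <path to the updated optimize test .mlir> -tfl-optimize | FileCheck <same .mlir file>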