Fold the MulOp into the preceding FullyConnectedOp
This CL folds the MulOp into the filter and bias of the preceding FullyConnectedOp. If the bias is a NoneType, only the filter is folded.

PiperOrigin-RevId: 262056643
Parent: 2fb51e98b9
Commit: 7d6b60f1ed
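A worked illustration of the fold, using only the constants from the new fuseMulIntoFullyConnected test below: the tfl.mul constant is multiplied element-wise into the fully_connected filter (broadcast against the filter's last dimension) and into the bias, and the mul's fused activation (RELU6 in the test) moves onto the fully_connected op, so the mul itself disappears. With the test's values:

    filter' = [[1, 2], [3, 4]] * [1, 2] = [[1, 4], [3, 8]]
    bias'   = [2, 2] * [1, 2]           = [2, 4]

These are exactly the constants the test's CHECK lines expect.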
@@ -252,6 +252,7 @@ cc_library(
        "@local_config_mlir//:Analysis",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Pass",
        "@local_config_mlir//:StandardOps",
        "@local_config_mlir//:Support",
    ],
    alwayslink = 1,
@@ -96,6 +96,38 @@ func @intermOpUsedTwice(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf
}

// CHECK-LABEL: @fuseMulIntoFullyConnected
func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
  %cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
  %cst1 = constant dense<2.0> : tensor<2xf32>
  %cst2 = constant dense<[1.0, 2.0]> : tensor<2xf32>

  %0 = "tfl.fully_connected"(%arg0, %cst0, %cst1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
  %1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<4x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>

  return %1 : tensor<4x2xf32>

// CHECK: %cst = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %cst_0 = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
// CHECK: %0 = "tfl.fully_connected"(%arg0, %cst, %cst_0) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
// CHECK: return %0 : tensor<4x2xf32>
}

// CHECK-LABEL: @fuseMulIntoFullyConnectedNoBias
func @fuseMulIntoFullyConnectedNoBias(%arg0: tensor<4x2xf32>, %arg1: none) -> tensor<4x2xf32> {
  %cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
  %cst2 = constant dense<[1.0, 2.0]> : tensor<2xf32>

  %0 = "tfl.fully_connected"(%arg0, %cst0, %arg1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
  %1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<4x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>

  return %1 : tensor<4x2xf32>

// CHECK: %cst = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %0 = "tfl.fully_connected"(%arg0, %cst, %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
// CHECK: return %0 : tensor<4x2xf32>
}

// CHECK-LABEL: @fuseMulIntoDepthwiseConv2d
func @fuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
  %cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0], [17.0, 18.0]]]]> : tensor<1x3x3x2xf32>
@@ -29,7 +29,6 @@ class ExtractI32At<int i> : NativeCodeCall<
    "$_builder.getI32IntegerAttr($_self.cast<ArrayAttr>().getValue()[" # i #
    "].cast<IntegerAttr>().getInt())">;

// Merge the two Attributes to a ArrayAttr;
def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
@@ -21,12 +21,17 @@ limitations under the License.

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Matchers.h"  // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/Support/Functional.h"  // TF:local_config_mlir
#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
@@ -45,14 +50,20 @@ struct Optimize : public FunctionPass<Optimize> {
  void runOnFunction() override;
};

// Returns whether the given type `a` is broadcast-compatible with `b`.
bool IsBroadcastableElementsAttrAndType(Type a, Type b) {
  return OpTrait::util::getBroadcastedType(a, b) != Type();
}

// Returns whether the given `a` and `b` ElementsAttr have broadcast-compatible
// types.
bool IsBroadcastableElementsAttrs(Attribute a, Attribute b) {
-  return OpTrait::util::getBroadcastedType(a.getType(), b.getType()) != Type();
+  return IsBroadcastableElementsAttrAndType(a.getType(), b.getType());
}

#include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
-// Fuse Add with FullyConnected.
+// Fuse Add with preceding FullyConnected.
// Note that this assumes that the bias in the fullyConnected
// is always None.
// TODO(b/136285429): Move to tablegen when variadic is supported
@@ -153,6 +164,76 @@ struct FuseFullyConnectedAndRelu : public RewritePattern {
  }
};

// Fuse Mul with preceding FullyConnected.
// TODO(b/136285429): Move to tablegen when variadic is supported
struct FuseFullyConnectedAndMul : public RewritePattern {
  explicit FuseFullyConnectedAndMul(MLIRContext *context)
      : RewritePattern(TFL::MulOp::getOperationName(),
                       {"tfl.fully_connected", "tfl.mul", "std.constant"}, 4,
                       context) {}

  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    // Mul.
    auto mul_op = cast<MulOp>(op);
    DenseElementsAttr cst;
    Value *constant_val = mul_op.rhs();
    if (!matchPattern(constant_val, m_Constant(&cst))) {
      return matchFailure();
    }

    // Fully Connected.
    auto fc_op =
        dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs()->getDefiningOp());
    if (!fc_op) return matchFailure();
    Value *filter = fc_op.filter();
    Value *bias = fc_op.bias();
    if (fc_op.fused_activation_function().equals("None")) return matchFailure();

    // Broadcast the constant operand of Mul if it isn't compatible to the
    // filter input. We only support broadcasting the operand along the depth
    // dimension, when the operand's depth is 1.
    Value *new_const_val = constant_val;
    if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter->getType())) {
      auto original_shape = cst.getType().getShape();
      llvm::SmallVector<int64_t, 4> normalized_shape(original_shape.begin(),
                                                     original_shape.end());
      normalized_shape.push_back(1);
      auto new_cst = cst.reshape(rewriter.getTensorType(
          normalized_shape, cst.getType().getElementType()));
      Type new_type = new_cst.getType();
      if (!IsBroadcastableElementsAttrAndType(new_type, filter->getType())) {
        return matchFailure();
      }
      auto new_op =
          rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
      new_const_val = new_op.getResult();
    }

    // Rewrite.
    Location loc = fc_op.getLoc();
    auto af_none = rewriter.getStringAttr(fc_op.fused_activation_function());
    auto new_filter =
        rewriter.create<MulOp>(loc, filter, new_const_val, af_none);
    // If bias isn't None, it needs to be multiplied as well.
    if (!bias->getType().isa<NoneType>()) {
      bias = rewriter.create<MulOp>(loc, bias, constant_val, af_none).output();
    }

    rewriter.replaceOpWithNewOp<TFL::FullyConnectedOp>(
        mul_op, mul_op.getType(),
        /*input=*/fc_op.input(),
        /*filter=*/new_filter.output(),
        /*bias=*/bias,
        /*fused_activation_function=*/
        rewriter.getStringAttr(mul_op.fused_activation_function()),
        /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
        /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));

    return matchSuccess();
  }
};

// StridedSlice can have complicated attributes like begin_axis_mask,
// end_axis_mask, ellipsis_axis_mask, new_axis_mask, shrink_axis_mask. These
// masks will complicate the strided_slice computation logic, we can simplify
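A note on the broadcast-normalization step in FuseFullyConnectedAndMul above, using hypothetical shapes not taken from the new tests: a mul constant of type tensor<3xf32> is not broadcast-compatible with a filter of type tensor<3x4xf32> (the trailing dimensions 3 and 4 differ), so the pattern appends a unit dimension and reshapes the constant to tensor<3x1xf32>, which does broadcast along the filter's last dimension; if even the reshaped type is still not broadcast-compatible, the pattern returns matchFailure().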
@@ -238,11 +319,11 @@ void Optimize::runOnFunction() {
  OwningRewritePatternList patterns;
  auto *ctx = &getContext();
  auto func = getFunction();

  // Add the generated patterns to the list.
  TFL::populateWithGenerated(ctx, &patterns);
  patterns.insert<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
-                  PadStridedSliceDims>(ctx);
+                  FuseFullyConnectedAndMul, PadStridedSliceDims>(ctx);
  applyPatternsGreedily(func, std::move(patterns));
}