[MLIR:TF] Fuse LeakyRelu into contraction
Pass fused operation attributes to the fused operation, and read them in the contraction output kernel builder.

PiperOrigin-RevId: 331662512
Change-Id: I65c9459983ab28121a8066d2a7922743edec0cf4
commit f08c6b6bc1
parent 4ea067f6f5
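In outline, the commit introduces a handshake between an activation op and the contraction fusion rewrite: the op advertises its fusion together with any attributes the fused kernel will need, and the rewrite forwards those attributes onto the fused contraction. A simplified C++ sketch of the producer side, condensed from the diff below (the f32 guard is omitted here):

// Sketch (simplified from the diff below): LeakyRelu advertises its fusion
// together with its `alpha` attribute.
Optional<ContractionFusion> LeakyReluOp::GetContractionFusion() {
  NamedAttribute alpha(Identifier::get("alpha", getContext()), alphaAttr());
  return ContractionFusion("LeakyRelu", /*additional_arguments=*/{},
                           /*additional_attributes=*/{alpha});
}
// FuseIntoContractionOp then appends fusion->additional_attributes to the
// attribute list of the fused op (e.g. tf._JitFusedMatMul), where the
// output kernel builder can read `alpha` back.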
@@ -4892,7 +4892,7 @@ def TF_LRNGradOp : TF_Op<"LRNGrad", [NoSideEffect]> {
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
-def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface]> {
   let summary = "Computes rectified linear: `max(features, features * alpha)`.";
 
   let arguments = (ins
@@ -4908,6 +4908,11 @@ def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    // TF_ContractionFusableInterface:
+    Optional<ContractionFusion> GetContractionFusion();
+  }];
 }
 
 def TF_LeakyReluGradOp : TF_Op<"LeakyReluGrad", [NoSideEffect, SameOperandsAndResultType]> {
@@ -15,6 +15,8 @@ limitations under the License.
 
 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
 
+#include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
@@ -27,10 +29,14 @@ namespace TF {
 //===----------------------------------------------------------------------===//
 
 struct ContractionFusion {
-  ContractionFusion(StringRef output_kernel, ArrayRef<int> additional_arguments)
+  explicit ContractionFusion(
+      StringRef output_kernel, ArrayRef<int> additional_arguments = {},
+      ArrayRef<NamedAttribute> additional_attributes = {})
       : output_kernel(output_kernel.str()),
         additional_arguments(additional_arguments.begin(),
-                             additional_arguments.end()) {}
+                             additional_arguments.end()),
+        additional_attributes(additional_attributes.begin(),
+                              additional_attributes.end()) {}
 
   // Name of the output kernel implementing the contraction fusion.
   std::string output_kernel;
@@ -38,6 +44,9 @@ struct ContractionFusion {
   // Indices of additional arguments that will be forwarded to the fused
   // operation (e.g. forward bias vector if fusing BiasAdd operation).
   SmallVector<int, 4> additional_arguments;
+
+  // Add additional attributes to the fused node.
+  SmallVector<NamedAttribute, 4> additional_attributes;
 };
 
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h.inc"
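For reference, a hedged usage sketch of the extended constructor. The op name, the `beta` attribute, and the in-scope `ctx`/`builder` are hypothetical; the defaulted trailing arguments keep existing callers (e.g. the BiasAdd fusion) compiling unchanged:

// Hypothetical fusion: output kernel "MyActivation" with one extra attribute.
// `ctx` is an MLIRContext* and `builder` an mlir::Builder assumed in scope.
NamedAttribute beta(Identifier::get("beta", ctx),
                    builder.getF32FloatAttr(0.5f));
ContractionFusion fusion("MyActivation",
                         /*additional_arguments=*/{},
                         /*additional_attributes=*/{beta});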
@@ -2164,6 +2164,15 @@ OpFoldResult LeakyReluOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+Optional<ContractionFusion> LeakyReluOp::GetContractionFusion() {
+  // Only f32 is supported for fusion.
+  if (!T().isF32()) return None;
+
+  NamedAttribute alpha(Identifier::get("alpha", getContext()), alphaAttr());
+  return ContractionFusion("LeakyRelu", /*additional_arguments=*/{},
+                           /*additional_attributes=*/{alpha});
+}
+
 //===----------------------------------------------------------------------===//
 // LogOp
 //===----------------------------------------------------------------------===//
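On the consumer side, a rewrite pattern can query this through the op interface. A minimal sketch, assuming the TableGen definition generates the usual `ContractionFusableInterface` class and that `activation_op` (hypothetical here) is the op being matched:

// Sketch: query an op for a contraction fusion before rewriting.
auto fusable = llvm::dyn_cast<ContractionFusableInterface>(activation_op);
if (!fusable) return failure();

Optional<ContractionFusion> fusion = fusable.GetContractionFusion();
// LeakyRelu returns None for non-f32 types, which disables the fusion.
if (!fusion.hasValue()) return failure();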
@@ -22,3 +22,16 @@ func @matmulBiasAddRelu(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: t
   // CHECK: return %[[FUSED]]
   return %5 : tensor<8x64xf32>
 }
+
+// CHECK-LABEL: matmulBiasAddLeakyRelu
+func @matmulBiasAddLeakyRelu(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> tensor<8x64xf32> {
+  // CHECK: %[[FUSED:.*]] = "tf._JitFusedMatMul"(%arg1, %arg2, %arg0)
+  // CHECK-SAME: alpha = 2.000000e-01 : f32
+  // CHECK-SAME: fusion = ["BiasAdd", "LeakyRelu"]
+  // CHECK-SAME: transpose_a = false, transpose_b = false
+  %3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<8x64xf32>
+  %4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<8x64xf32>, tensor<64xf32>) -> tensor<8x64xf32>
+  %5 = "tf.LeakyRelu"(%4) { alpha = 0.2 : f32 } : (tensor<8x64xf32>) -> tensor<8x64xf32>
+  // CHECK: return %[[FUSED]]
+  return %5 : tensor<8x64xf32>
+}
@@ -82,16 +82,22 @@ class FuseIntoContractionOp : public RewritePattern {
     // Fusion can't change the type of a fused operation.
     Type result_ty = fuse_into->getResult(0).getType();
 
-    // Copy all operands from a matmul and add additional fusion arguments.
+    // Copy all operands from a base op and add additional fusion arguments.
     SmallVector<Value, 3> operands(fuse_into->getOperands());
     for (int idx : fusion->additional_arguments) {
       operands.push_back(op->getOperand(idx));
     }
 
-    // Copy attributes from a MatMul operation.
+    // Copy attributes from a base op that we fuse into (e.g. copy all
+    // MatMul or Conv attributes to the fused operation).
     SmallVector<NamedAttribute, 4> attrs(fuse_into->getAttrs().begin(),
                                          fuse_into->getAttrs().end());
 
+    // Add fusion specific additional attributes.
+    for (auto attr : fusion->additional_attributes) {
+      attrs.push_back(attr);
+    }
+
     // Add a fused output kernel name to the list of fusions.
     Identifier fusion_id = Identifier::get("fusion", ctx);
     StringAttr fusion_name = StringAttr::get(fusion->output_kernel, ctx);
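The other half mentioned in the commit message, the contraction output kernel builder reading these attributes back, is not part of the excerpt above. As a rough illustration only, pulling `alpha` off the fused node with the standard NodeDef attribute helpers could look like this (hypothetical builder code, not from this commit):

#include "tensorflow/core/framework/node_def_util.h"

// Hypothetical: while building the output kernel for a "LeakyRelu" entry in
// the fusion list, read the alpha attribute the MLIR pass attached.
float alpha;
TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "alpha", &alpha));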