[MLIR:TF] Fuse LeakyRelu into contraction

Pass the fused operation's attributes through to the fused contraction, and read them back in the contraction output kernel builder.

PiperOrigin-RevId: 331662512
Change-Id: I65c9459983ab28121a8066d2a7922743edec0cf4
Eugene Zhulenev 2020-09-14 17:13:53 -07:00 committed by TensorFlower Gardener
parent 4ea067f6f5
commit f08c6b6bc1
5 changed files with 47 additions and 5 deletions


@@ -4892,7 +4892,7 @@ def TF_LRNGradOp : TF_Op<"LRNGrad", [NoSideEffect]> {
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
-def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface]> {
   let summary = "Computes rectified linear: `max(features, features * alpha)`.";
 
   let arguments = (ins
@@ -4908,6 +4908,11 @@ def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    // TF_ContractionFusableInterface:
+    Optional<ContractionFusion> GetContractionFusion();
+  }];
 }
 
 def TF_LeakyReluGradOp : TF_Op<"LeakyReluGrad", [NoSideEffect, SameOperandsAndResultType]> {
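This TableGen change only declares the hook: TF_ContractionFusableInterface is added to the op's trait list, and GetContractionFusion() is declared in extraClassDeclaration. A minimal sketch of how a pass can query that interface on an arbitrary op (the C++ interface name is assumed to follow the TF_ContractionFusableInterface declaration; the helper itself is illustrative, not code from this commit):

#include "llvm/ADT/Optional.h"
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"

// Illustrative helper: returns the fusion descriptor if `op` implements the
// contraction-fusable interface and opts in, llvm::None otherwise.
llvm::Optional<mlir::TF::ContractionFusion> QueryFusion(mlir::Operation *op) {
  if (auto fusable = llvm::dyn_cast<mlir::TF::ContractionFusableInterface>(op))
    return fusable.GetContractionFusion();
  return llvm::None;
}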


@@ -15,6 +15,8 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
 
+#include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
@@ -27,10 +29,14 @@ namespace TF {
 //===----------------------------------------------------------------------===//
 
 struct ContractionFusion {
-  ContractionFusion(StringRef output_kernel, ArrayRef<int> additional_arguments)
+  explicit ContractionFusion(
+      StringRef output_kernel, ArrayRef<int> additional_arguments = {},
+      ArrayRef<NamedAttribute> additional_attributes = {})
       : output_kernel(output_kernel.str()),
         additional_arguments(additional_arguments.begin(),
-                             additional_arguments.end()) {}
+                             additional_arguments.end()),
+        additional_attributes(additional_attributes.begin(),
+                              additional_attributes.end()) {}
 
   // Name of the output kernel implementing the contraction fusion.
   std::string output_kernel;
@@ -38,6 +44,9 @@ struct ContractionFusion {
   // Indices of additional arguments that will be forwarded to the fused
   // operation (e.g. forward bias vector if fusing BiasAdd operation).
   SmallVector<int, 4> additional_arguments;
+
+  // Add additional attributes to the fused node.
+  SmallVector<NamedAttribute, 4> additional_attributes;
 };
 
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h.inc"
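With the extended constructor, existing fusions that only forward extra operands keep working unchanged, and a fusion can now also attach attributes to the fused node. A hedged sketch of both flavours of the ContractionFusion declared above (the helper and the BiasAdd operand index are illustrative; only the LeakyRelu case appears in this commit):

#include "mlir/IR/Attributes.h"   // from @llvm-project
#include "mlir/IR/Identifier.h"   // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"

// Illustrative only: build the two kinds of fusion descriptors.
void ExampleFusions(mlir::MLIRContext *ctx, mlir::FloatAttr alpha_value) {
  using mlir::TF::ContractionFusion;

  // Forward operand #1 into the fused contraction (e.g. the bias vector when
  // fusing a BiasAdd, as the additional_arguments comment suggests).
  ContractionFusion bias_fusion("BiasAdd", /*additional_arguments=*/{1});

  // Attach a named attribute instead; this is the new capability that the
  // LeakyRelu fusion below uses for its `alpha`.
  mlir::NamedAttribute alpha(mlir::Identifier::get("alpha", ctx), alpha_value);
  ContractionFusion leaky_relu("LeakyRelu", /*additional_arguments=*/{},
                               /*additional_attributes=*/{alpha});
  (void)bias_fusion;
  (void)leaky_relu;
}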


@@ -2164,6 +2164,15 @@ OpFoldResult LeakyReluOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+Optional<ContractionFusion> LeakyReluOp::GetContractionFusion() {
+  // Only f32 is supported for fusion.
+  if (!T().isF32()) return None;
+
+  NamedAttribute alpha(Identifier::get("alpha", getContext()), alphaAttr());
+  return ContractionFusion("LeakyRelu", /*additional_arguments=*/{},
+                           /*additional_attributes=*/{alpha});
+}
+
 //===----------------------------------------------------------------------===//
 // LogOp
 //===----------------------------------------------------------------------===//
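That covers the MLIR side; per the commit description, the forwarded attributes are read back in the contraction output kernel builder. A rough sketch of what that consuming end could look like, assuming the fused kernel is constructed from the fused node's attributes (none of this builder code is part of the diff):

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"

// Illustrative only: read the forwarded `alpha` attribute back when building
// the output kernel for a fused contraction that ends in LeakyRelu.
tensorflow::Status ReadLeakyReluAlpha(tensorflow::OpKernelConstruction *ctx,
                                      float *alpha) {
  return ctx->GetAttr("alpha", alpha);
}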


@@ -22,3 +22,16 @@ func @matmulBiasAddRelu(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: t
   // CHECK: return %[[FUSED]]
   return %5 : tensor<8x64xf32>
 }
+
+// CHECK-LABEL: matmulBiasAddLeakyRelu
+func @matmulBiasAddLeakyRelu(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> tensor<8x64xf32> {
+  // CHECK: %[[FUSED:.*]] = "tf._JitFusedMatMul"(%arg1, %arg2, %arg0)
+  // CHECK-SAME: alpha = 2.000000e-01 : f32
+  // CHECK-SAME: fusion = ["BiasAdd", "LeakyRelu"]
+  // CHECK-SAME: transpose_a = false, transpose_b = false
+  %3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<8x64xf32>
+  %4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<8x64xf32>, tensor<64xf32>) -> tensor<8x64xf32>
+  %5 = "tf.LeakyRelu"(%4) { alpha = 0.2 : f32 } : (tensor<8x64xf32>) -> tensor<8x64xf32>
+  // CHECK: return %[[FUSED]]
+  return %5 : tensor<8x64xf32>
+}
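The CHECK-SAME lines make the new behaviour concrete: after fusion, tf._JitFusedMatMul carries the original MatMul attributes (transpose_a/transpose_b), the fusion list ["BiasAdd", "LeakyRelu"], and the alpha value forwarded from tf.LeakyRelu.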


@@ -82,16 +82,22 @@ class FuseIntoContractionOp : public RewritePattern {
     // Fusion can't change the type of a fused operation.
     Type result_ty = fuse_into->getResult(0).getType();
 
-    // Copy all operands from a matmul and add additional fusion arguments.
+    // Copy all operands from a base op and add additional fusion arguments.
     SmallVector<Value, 3> operands(fuse_into->getOperands());
     for (int idx : fusion->additional_arguments) {
       operands.push_back(op->getOperand(idx));
     }
 
-    // Copy attributes from a MatMul operation.
+    // Copy attributes from a base op that we fuse into (e.g. copy all
+    // MatMul or Conv attributes to the fused operation).
     SmallVector<NamedAttribute, 4> attrs(fuse_into->getAttrs().begin(),
                                          fuse_into->getAttrs().end());
+
+    // Add fusion specific additional attributes.
+    for (auto attr : fusion->additional_attributes) {
+      attrs.push_back(attr);
+    }
 
     // Add a fused output kernel name to the list of fusions.
     Identifier fusion_id = Identifier::get("fusion", ctx);
     StringAttr fusion_name = StringAttr::get(fusion->output_kernel, ctx);
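The hunk ends before the fused operation is actually created, so the tail of FuseIntoContractionOp::matchAndRewrite is not shown. Assuming it appends fusion_name to a "fusion" array attribute and then materializes the fused op from the collected operands and attrs (which is what the FileCheck test above expects), the remainder could look roughly like this sketch; the names and rewriter calls are illustrative, not the commit's code:

#include "mlir/IR/Attributes.h"    // from @llvm-project
#include "mlir/IR/Identifier.h"    // from @llvm-project
#include "mlir/IR/Operation.h"     // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project

// Illustrative only: record the output kernel in the `fusion` attribute and
// build the fused op. Parameter roles mirror the locals in the excerpt above.
mlir::Operation *BuildFusedOp(mlir::PatternRewriter &rewriter,
                              mlir::Operation *op, mlir::Operation *fuse_into,
                              llvm::StringRef fused_op_name,
                              llvm::SmallVector<mlir::Value, 3> operands,
                              llvm::SmallVector<mlir::NamedAttribute, 4> attrs,
                              mlir::Type result_ty,
                              mlir::Identifier fusion_id,
                              mlir::StringAttr fusion_name) {
  mlir::MLIRContext *ctx = rewriter.getContext();

  // Keep any fusions already recorded on the base op and append the new one.
  llvm::SmallVector<mlir::Attribute, 4> fusions;
  if (auto existing = fuse_into->getAttrOfType<mlir::ArrayAttr>(fusion_id))
    fusions.append(existing.getValue().begin(), existing.getValue().end());
  fusions.push_back(fusion_name);
  attrs.push_back(
      mlir::NamedAttribute(fusion_id, mlir::ArrayAttr::get(fusions, ctx)));

  // Create the fused op (e.g. "tf._JitFusedMatMul" in the test above) and
  // replace the original op with its result.
  mlir::OperationState state(op->getLoc(), fused_op_name);
  state.addOperands(operands);
  state.addTypes(result_ty);
  state.addAttributes(attrs);
  mlir::Operation *fused = rewriter.createOperation(state);
  rewriter.replaceOp(op, fused->getResults());
  return fused;
}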