[MLIR:TF] Fuse LeakyRelu into contraction
Pass fused operation attributes to the fused operation, and read them in contraction output kernel builder. PiperOrigin-RevId: 331662512 Change-Id: I65c9459983ab28121a8066d2a7922743edec0cf4
This commit is contained in:
parent
4ea067f6f5
commit
f08c6b6bc1
@ -4892,7 +4892,7 @@ def TF_LRNGradOp : TF_Op<"LRNGrad", [NoSideEffect]> {
|
|||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType]> {
|
def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface]> {
|
||||||
let summary = "Computes rectified linear: `max(features, features * alpha)`.";
|
let summary = "Computes rectified linear: `max(features, features * alpha)`.";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
@ -4908,6 +4908,11 @@ def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType
|
|||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// TF_ContractionFusableInterface:
|
||||||
|
Optional<ContractionFusion> GetContractionFusion();
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_LeakyReluGradOp : TF_Op<"LeakyReluGrad", [NoSideEffect, SameOperandsAndResultType]> {
|
def TF_LeakyReluGradOp : TF_Op<"LeakyReluGrad", [NoSideEffect, SameOperandsAndResultType]> {
|
||||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
|
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
|
||||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
|
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||||
#include "mlir/IR/OpImplementation.h" // from @llvm-project
|
#include "mlir/IR/OpImplementation.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
|
||||||
@ -27,10 +29,14 @@ namespace TF {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
struct ContractionFusion {
|
struct ContractionFusion {
|
||||||
ContractionFusion(StringRef output_kernel, ArrayRef<int> additional_arguments)
|
explicit ContractionFusion(
|
||||||
|
StringRef output_kernel, ArrayRef<int> additional_arguments = {},
|
||||||
|
ArrayRef<NamedAttribute> additional_attributes = {})
|
||||||
: output_kernel(output_kernel.str()),
|
: output_kernel(output_kernel.str()),
|
||||||
additional_arguments(additional_arguments.begin(),
|
additional_arguments(additional_arguments.begin(),
|
||||||
additional_arguments.end()) {}
|
additional_arguments.end()),
|
||||||
|
additional_attributes(additional_attributes.begin(),
|
||||||
|
additional_attributes.end()) {}
|
||||||
|
|
||||||
// Name of the output kernel implementing the contraction fusion.
|
// Name of the output kernel implementing the contraction fusion.
|
||||||
std::string output_kernel;
|
std::string output_kernel;
|
||||||
@ -38,6 +44,9 @@ struct ContractionFusion {
|
|||||||
// Indices of additional arguments that will be forwarded to the fused
|
// Indices of additional arguments that will be forwarded to the fused
|
||||||
// operation (e.g. forward bias vector if fusing BiasAdd operation).
|
// operation (e.g. forward bias vector if fusing BiasAdd operation).
|
||||||
SmallVector<int, 4> additional_arguments;
|
SmallVector<int, 4> additional_arguments;
|
||||||
|
|
||||||
|
// Add additional attributes to the fused node.
|
||||||
|
SmallVector<NamedAttribute, 4> additional_attributes;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h.inc"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h.inc"
|
||||||
|
@ -2164,6 +2164,15 @@ OpFoldResult LeakyReluOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Optional<ContractionFusion> LeakyReluOp::GetContractionFusion() {
|
||||||
|
// Only f32 is supported for fusion.
|
||||||
|
if (!T().isF32()) return None;
|
||||||
|
|
||||||
|
NamedAttribute alpha(Identifier::get("alpha", getContext()), alphaAttr());
|
||||||
|
return ContractionFusion("LeakyRelu", /*additional_arguments=*/{},
|
||||||
|
/*additional_attributes=*/{alpha});
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// LogOp
|
// LogOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -22,3 +22,16 @@ func @matmulBiasAddRelu(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: t
|
|||||||
// CHECK: return %[[FUSED]]
|
// CHECK: return %[[FUSED]]
|
||||||
return %5 : tensor<8x64xf32>
|
return %5 : tensor<8x64xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: matmulBiasAddLeakyRelu
|
||||||
|
func @matmulBiasAddLeakyRelu(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> tensor<8x64xf32> {
|
||||||
|
// CHECK: %[[FUSED:.*]] = "tf._JitFusedMatMul"(%arg1, %arg2, %arg0)
|
||||||
|
// CHECK-SAME: alpha = 2.000000e-01 : f32
|
||||||
|
// CHECK-SAME: fusion = ["BiasAdd", "LeakyRelu"]
|
||||||
|
// CHECK-SAME: transpose_a = false, transpose_b = false
|
||||||
|
%3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<8x64xf32>
|
||||||
|
%4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<8x64xf32>, tensor<64xf32>) -> tensor<8x64xf32>
|
||||||
|
%5 = "tf.LeakyRelu"(%4) { alpha = 0.2 : f32 } : (tensor<8x64xf32>) -> tensor<8x64xf32>
|
||||||
|
// CHECK: return %[[FUSED]]
|
||||||
|
return %5 : tensor<8x64xf32>
|
||||||
|
}
|
||||||
|
@ -82,16 +82,22 @@ class FuseIntoContractionOp : public RewritePattern {
|
|||||||
// Fusion can't change the type of a fused operation.
|
// Fusion can't change the type of a fused operation.
|
||||||
Type result_ty = fuse_into->getResult(0).getType();
|
Type result_ty = fuse_into->getResult(0).getType();
|
||||||
|
|
||||||
// Copy all operands from a matmul and add additional fusion arguments.
|
// Copy all operands from a base op and add additional fusion arguments.
|
||||||
SmallVector<Value, 3> operands(fuse_into->getOperands());
|
SmallVector<Value, 3> operands(fuse_into->getOperands());
|
||||||
for (int idx : fusion->additional_arguments) {
|
for (int idx : fusion->additional_arguments) {
|
||||||
operands.push_back(op->getOperand(idx));
|
operands.push_back(op->getOperand(idx));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy attributes from a MatMul operation.
|
// Copy attributes from a base op that we fuse into (e.g. copy all
|
||||||
|
// MatMul or Conv attributes to the fused operation).
|
||||||
SmallVector<NamedAttribute, 4> attrs(fuse_into->getAttrs().begin(),
|
SmallVector<NamedAttribute, 4> attrs(fuse_into->getAttrs().begin(),
|
||||||
fuse_into->getAttrs().end());
|
fuse_into->getAttrs().end());
|
||||||
|
|
||||||
|
// Add fusion specific additional attributes.
|
||||||
|
for (auto attr : fusion->additional_attributes) {
|
||||||
|
attrs.push_back(attr);
|
||||||
|
}
|
||||||
|
|
||||||
// Add a fused output kernel name to the list of fusions.
|
// Add a fused output kernel name to the list of fusions.
|
||||||
Identifier fusion_id = Identifier::get("fusion", ctx);
|
Identifier fusion_id = Identifier::get("fusion", ctx);
|
||||||
StringAttr fusion_name = StringAttr::get(fusion->output_kernel, ctx);
|
StringAttr fusion_name = StringAttr::get(fusion->output_kernel, ctx);
|
||||||
|
Loading…
Reference in New Issue
Block a user