Support tfl.maximum->tfl.relu with bf16 data types and/or splat constants
PiperOrigin-RevId: 337259680 Change-Id: I03380f020a347cfe88093e7f64400542181d0c70
This commit is contained in:
parent
19200d4ec5
commit
6b3a88b6e9
@ -991,6 +991,16 @@ func @Relu(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
// CHECK: return %[[RESULT]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: Relu_bf16
|
||||
func @Relu_bf16(%arg0: tensor<2x3xbf16>) -> tensor<2x3xbf16> {
|
||||
%cst = constant dense<0.0> : tensor<2x3xbf16>
|
||||
%0 = "tfl.maximum"(%arg0, %cst) : (tensor<2x3xbf16>, tensor<2x3xbf16>) -> tensor<2x3xbf16>
|
||||
return %0 : tensor<2x3xbf16>
|
||||
|
||||
// CHECK: %[[RESULT:.*]] = "tfl.relu"(%arg0)
|
||||
// CHECK: return %[[RESULT]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: Relu1
|
||||
func @Relu1(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%cst = constant dense<-1.0> : tensor<f32>
|
||||
|
@ -497,9 +497,9 @@ def ConvertExpandDimsToReshape : Pat<
|
||||
[(AnyStaticShapeTensor $expand_dims_op)]>;
|
||||
|
||||
class FloatValueEquals<string val> : Constraint<CPred<
|
||||
"$0.cast<DenseElementsAttr>().getNumElements() == 1 &&"
|
||||
"$0.isa<DenseFPElementsAttr>() &&"
|
||||
"*$0.cast<DenseElementsAttr>().getValues<float>().begin() == " # val>>;
|
||||
"$0.isa<DenseFPElementsAttr>() && "
|
||||
"llvm::all_of($0.cast<DenseElementsAttr>().getFloatValues(), "
|
||||
"[](const APFloat& f) { return f.isExactlyValue(" # val # "); })">>;
|
||||
|
||||
// ReLU patterns
|
||||
def MatchReluPattern : Pat<
|
||||
|
Loading…
x
Reference in New Issue
Block a user