Support tfl.maximum->tfl.relu with bf16 data types and/or splat constants

PiperOrigin-RevId: 337259680
Change-Id: I03380f020a347cfe88093e7f64400542181d0c70
This commit is contained in:
A. Unique TensorFlower 2020-10-15 01:15:11 -07:00 committed by TensorFlower Gardener
parent 19200d4ec5
commit 6b3a88b6e9
2 changed files with 13 additions and 3 deletions

View File

@ -991,6 +991,16 @@ func @Relu(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: return %[[RESULT]]
}
// CHECK-LABEL: Relu_bf16
func @Relu_bf16(%arg0: tensor<2x3xbf16>) -> tensor<2x3xbf16> {
%cst = constant dense<0.0> : tensor<2x3xbf16>
%0 = "tfl.maximum"(%arg0, %cst) : (tensor<2x3xbf16>, tensor<2x3xbf16>) -> tensor<2x3xbf16>
return %0 : tensor<2x3xbf16>
// CHECK: %[[RESULT:.*]] = "tfl.relu"(%arg0)
// CHECK: return %[[RESULT]]
}
// CHECK-LABEL: Relu1
func @Relu1(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%cst = constant dense<-1.0> : tensor<f32>

View File

@ -497,9 +497,9 @@ def ConvertExpandDimsToReshape : Pat<
[(AnyStaticShapeTensor $expand_dims_op)]>;
class FloatValueEquals<string val> : Constraint<CPred<
"$0.cast<DenseElementsAttr>().getNumElements() == 1 &&"
"$0.isa<DenseFPElementsAttr>() &&"
"*$0.cast<DenseElementsAttr>().getValues<float>().begin() == " # val>>;
"$0.isa<DenseFPElementsAttr>() && "
"llvm::all_of($0.cast<DenseElementsAttr>().getFloatValues(), "
"[](const APFloat& f) { return f.isExactlyValue(" # val # "); })">>;
// ReLU patterns
def MatchReluPattern : Pat<