Lower Expm1 kernel to math.ExpM1.

PiperOrigin-RevId: 358152908
Change-Id: I5ea4a0214b6640f9ce1d47d9f787388446ce5589
This commit is contained in:
Adrian Kuegel 2021-02-18 04:53:34 -08:00 committed by TensorFlower Gardener
parent 045b62dc3e
commit 0b25a0284f
11 changed files with 40 additions and 20 deletions

View File

@@ -55,6 +55,7 @@ MAP_HLO_TO_LHLO(CustomCallOp);
MAP_HLO_TO_LHLO(DivOp);
MAP_HLO_TO_LHLO(DotOp);
MAP_HLO_TO_LHLO(ExpOp);
MAP_HLO_TO_LHLO(Expm1Op);
MAP_HLO_TO_LHLO(FloorOp);
MAP_HLO_TO_LHLO(GatherOp);
MAP_HLO_TO_LHLO(ImagOp);

View File

@@ -251,6 +251,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::Expm1Op>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::ExpM1Op>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc,
ArrayRef<Type> result_types,

View File

@@ -658,6 +658,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
HloToLhloOpConverter<mhlo::DivOp>,
HloToLhloOpConverter<mhlo::DotOp>,
HloToLhloOpConverter<mhlo::ExpOp>,
HloToLhloOpConverter<mhlo::Expm1Op>,
HloToLhloOpConverter<mhlo::FloorOp>,
HloToLhloOpConverter<mhlo::GatherOp>,
HloToLhloOpConverter<mhlo::ImagOp>,

View File

@@ -1398,6 +1398,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<lmhlo::CosOp>,
PointwiseToLinalgConverter<lmhlo::DivOp>,
PointwiseToLinalgConverter<lmhlo::ExpOp>,
PointwiseToLinalgConverter<lmhlo::Expm1Op>,
PointwiseToLinalgConverter<lmhlo::FloorOp>,
PointwiseToLinalgConverter<lmhlo::ImagOp>,
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
@@ -1524,6 +1525,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<mhlo::CosOp, false>,
PointwiseToLinalgConverter<mhlo::DivOp, false>,
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
PointwiseToLinalgConverter<mhlo::Expm1Op, false>,
PointwiseToLinalgConverter<mhlo::FloorOp, false>,
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,

View File

@@ -87,6 +87,15 @@ func @exp(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @expm1
func @expm1(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
%result = "mhlo.exponential_minus_one"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: "lmhlo.exponential_minus_one"(%{{.*}}, %{{.*}})
return %result : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @log
func @log(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
%result = "mhlo.log"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>

View File

@@ -156,6 +156,16 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @float_expm1
func @float_expm1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: expm1
%0 = "mhlo.exponential_minus_one"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @float_log
func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic

View File

@@ -996,14 +996,3 @@ func @size_to_prod_shape_i64(%arg0 : tensor<1x?x2x3xf32>) -> tensor<i64> {
// CHECK: %[[PROD:.*]] = "tf.Prod"(%[[SHAPE]], %[[CONSTANT]]) {keep_dims = false} : (tensor<4xi64>, tensor<i64>) -> tensor<i64>
// CHECK: return %[[PROD]]
}
// CHECK-LABEL: func @expm1
// CHECK-SAME: (%[[X:.*]]: tensor<*xf32>)
func @expm1(%x: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[EXP:.*]] = "tf.Exp"(%[[X]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: %[[RESULT:.*]] = "tf.Sub"(%[[EXP]], %[[ONE]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
%0 = "tf.Expm1" (%x) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[RESULT]]
return %0: tensor<*xf32>
}

View File

@@ -1542,7 +1542,6 @@ void PopulateTFLoweringBeforeHLOPatterns(MLIRContext *context,
LowerBiasAddGradOp,
LowerDivNoNanOp,
LowerEmptyOp,
LowerExpm1Op,
LowerFakeQuantWithMinMaxArgs,
LowerFillOp,
LowerIsNanOp,

View File

@@ -162,14 +162,6 @@ def LowerDivNoNanOp : BinaryNoNanPat<TF_DivNoNanOp, TF_DivOp>;
def LowerMulNoNanOp : BinaryNoNanPat<TF_MulNoNanOp, TF_MulOp>;
//===----------------------------------------------------------------------===//
// Expm1 op patterns.
//===----------------------------------------------------------------------===//
def LowerExpm1Op : Pat<(TF_Expm1Op $x),
(TF_SubOp (TF_ExpOp $x),
(TF_ConstOp (GetScalarOfType<1> $x)))>;
//===----------------------------------------------------------------------===//
// Fill op patterns.
//===----------------------------------------------------------------------===//

View File

@@ -2183,6 +2183,13 @@ func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
return %0 : tensor<2xf32>
}
// CHECK-LABEL: @expm1
func @expm1(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: "mhlo.exponential_minus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "tf.Expm1"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// CHECK-LABEL: func @exp_dynamic
func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: "mhlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>

View File

@@ -598,6 +598,7 @@ foreach Mapping = [
[TF_CosOp, HLO_CosOp],
[TF_DigammaOp, HLOClient_DigammaOp],
[TF_ExpOp, HLO_ExpOp],
[TF_Expm1Op, HLO_Expm1Op],
[TF_ErfOp, HLOClient_ErfOp],
[TF_ErfcOp, HLOClient_ErfcOp],
[TF_FloorOp, HLO_FloorOp],