Lower Expm1 kernel to math.ExpM1.
PiperOrigin-RevId: 358152908 Change-Id: I5ea4a0214b6640f9ce1d47d9f787388446ce5589
This commit is contained in:
parent
045b62dc3e
commit
0b25a0284f
@ -55,6 +55,7 @@ MAP_HLO_TO_LHLO(CustomCallOp);
|
||||
MAP_HLO_TO_LHLO(DivOp);
|
||||
MAP_HLO_TO_LHLO(DotOp);
|
||||
MAP_HLO_TO_LHLO(ExpOp);
|
||||
MAP_HLO_TO_LHLO(Expm1Op);
|
||||
MAP_HLO_TO_LHLO(FloorOp);
|
||||
MAP_HLO_TO_LHLO(GatherOp);
|
||||
MAP_HLO_TO_LHLO(ImagOp);
|
||||
|
@ -251,6 +251,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::Expm1Op>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::ExpM1Op>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
|
@ -658,6 +658,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||
HloToLhloOpConverter<mhlo::DivOp>,
|
||||
HloToLhloOpConverter<mhlo::DotOp>,
|
||||
HloToLhloOpConverter<mhlo::ExpOp>,
|
||||
HloToLhloOpConverter<mhlo::Expm1Op>,
|
||||
HloToLhloOpConverter<mhlo::FloorOp>,
|
||||
HloToLhloOpConverter<mhlo::GatherOp>,
|
||||
HloToLhloOpConverter<mhlo::ImagOp>,
|
||||
|
@ -1398,6 +1398,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
PointwiseToLinalgConverter<lmhlo::CosOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::DivOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::ExpOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::Expm1Op>,
|
||||
PointwiseToLinalgConverter<lmhlo::FloorOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::ImagOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
|
||||
@ -1524,6 +1525,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
PointwiseToLinalgConverter<mhlo::CosOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::DivOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::Expm1Op, false>,
|
||||
PointwiseToLinalgConverter<mhlo::FloorOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
|
||||
|
@ -87,6 +87,15 @@ func @exp(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @expm1
|
||||
func @expm1(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%result = "mhlo.exponential_minus_one"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: "lmhlo.exponential_minus_one"(%{{.*}}, %{{.*}})
|
||||
return %result : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @log
|
||||
func @log(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%result = "mhlo.log"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
|
@ -156,6 +156,16 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_expm1
|
||||
func @float_expm1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: expm1
|
||||
%0 = "mhlo.exponential_minus_one"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_log
|
||||
func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
|
@ -996,14 +996,3 @@ func @size_to_prod_shape_i64(%arg0 : tensor<1x?x2x3xf32>) -> tensor<i64> {
|
||||
// CHECK: %[[PROD:.*]] = "tf.Prod"(%[[SHAPE]], %[[CONSTANT]]) {keep_dims = false} : (tensor<4xi64>, tensor<i64>) -> tensor<i64>
|
||||
// CHECK: return %[[PROD]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @expm1
|
||||
// CHECK-SAME: (%[[X:.*]]: tensor<*xf32>)
|
||||
func @expm1(%x: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[EXP:.*]] = "tf.Exp"(%[[X]]) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: %[[RESULT:.*]] = "tf.Sub"(%[[EXP]], %[[ONE]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
|
||||
%0 = "tf.Expm1" (%x) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
return %0: tensor<*xf32>
|
||||
}
|
||||
|
@ -1542,7 +1542,6 @@ void PopulateTFLoweringBeforeHLOPatterns(MLIRContext *context,
|
||||
LowerBiasAddGradOp,
|
||||
LowerDivNoNanOp,
|
||||
LowerEmptyOp,
|
||||
LowerExpm1Op,
|
||||
LowerFakeQuantWithMinMaxArgs,
|
||||
LowerFillOp,
|
||||
LowerIsNanOp,
|
||||
|
@ -162,14 +162,6 @@ def LowerDivNoNanOp : BinaryNoNanPat<TF_DivNoNanOp, TF_DivOp>;
|
||||
|
||||
def LowerMulNoNanOp : BinaryNoNanPat<TF_MulNoNanOp, TF_MulOp>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Expm1 op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def LowerExpm1Op : Pat<(TF_Expm1Op $x),
|
||||
(TF_SubOp (TF_ExpOp $x),
|
||||
(TF_ConstOp (GetScalarOfType<1> $x)))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Fill op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2183,6 +2183,13 @@ func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @expm1
|
||||
func @expm1(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// CHECK: "mhlo.exponential_minus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
%0 = "tf.Expm1"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @exp_dynamic
|
||||
func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: "mhlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
|
@ -598,6 +598,7 @@ foreach Mapping = [
|
||||
[TF_CosOp, HLO_CosOp],
|
||||
[TF_DigammaOp, HLOClient_DigammaOp],
|
||||
[TF_ExpOp, HLO_ExpOp],
|
||||
[TF_Expm1Op, HLO_Expm1Op],
|
||||
[TF_ErfOp, HLOClient_ErfOp],
|
||||
[TF_ErfcOp, HLOClient_ErfcOp],
|
||||
[TF_FloorOp, HLO_FloorOp],
|
||||
|
Loading…
x
Reference in New Issue
Block a user