diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h index 19f1c72fed4..252b15d72ed 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h @@ -55,6 +55,7 @@ MAP_HLO_TO_LHLO(CustomCallOp); MAP_HLO_TO_LHLO(DivOp); MAP_HLO_TO_LHLO(DotOp); MAP_HLO_TO_LHLO(ExpOp); +MAP_HLO_TO_LHLO(Expm1Op); MAP_HLO_TO_LHLO(FloorOp); MAP_HLO_TO_LHLO(GatherOp); MAP_HLO_TO_LHLO(ImagOp); diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index d726c7fd81e..5a957e7c5bc 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -251,6 +251,15 @@ inline Value MapLhloOpToStdScalarOp(Location loc, loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 7ec6629f7f2..b562e3b5a2a 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -658,6 +658,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 6cbb68cc549..8dff2f513c4 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -1398,6 +1398,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -1524,6 +1525,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir index a25ee89ecdb..5c497244912 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir @@ -87,6 +87,15 @@ func @exp(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @expm1 +func @expm1(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { + %result = "mhlo.exponential_minus_one"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: "lmhlo.exponential_minus_one"(%{{.*}}, %{{.*}}) + return %result : tensor<2x2xf32> +} + +// ----- + // CHECK-LABEL: func @log func @log(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { %result = "mhlo.log"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir index 545a781d58b..99415e09c8f 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir @@ -156,6 +156,16 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @float_expm1 +func @float_expm1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: expm1 + %0 = "mhlo.exponential_minus_one"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + // CHECK-LABEL: func @float_log func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index 936680934f0..5e387e7b383 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -996,14 +996,3 @@ func @size_to_prod_shape_i64(%arg0 : tensor<1x?x2x3xf32>) -> tensor { // CHECK: %[[PROD:.*]] = "tf.Prod"(%[[SHAPE]], %[[CONSTANT]]) {keep_dims = false} : (tensor<4xi64>, tensor) -> tensor // CHECK: return %[[PROD]] } - -// CHECK-LABEL: func @expm1 -// CHECK-SAME: (%[[X:.*]]: tensor<*xf32>) -func @expm1(%x: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor - // CHECK: %[[EXP:.*]] = "tf.Exp"(%[[X]]) : (tensor<*xf32>) -> tensor<*xf32> - // CHECK: %[[RESULT:.*]] = "tf.Sub"(%[[EXP]], %[[ONE]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> - %0 = "tf.Expm1" (%x) : (tensor<*xf32>) -> tensor<*xf32> - // CHECK: return %[[RESULT]] - return %0: tensor<*xf32> -} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index 81dbae148d3..6bc22e5d951 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -1542,7 +1542,6 @@ void PopulateTFLoweringBeforeHLOPatterns(MLIRContext *context, LowerBiasAddGradOp, LowerDivNoNanOp, LowerEmptyOp, - LowerExpm1Op, LowerFakeQuantWithMinMaxArgs, LowerFillOp, LowerIsNanOp, diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td index bc75c122b30..7a04795eeda 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td @@ -162,14 +162,6 @@ def LowerDivNoNanOp : BinaryNoNanPat; def LowerMulNoNanOp : BinaryNoNanPat; -//===----------------------------------------------------------------------===// -// Expm1 op patterns. -//===----------------------------------------------------------------------===// - -def LowerExpm1Op : Pat<(TF_Expm1Op $x), - (TF_SubOp (TF_ExpOp $x), - (TF_ConstOp (GetScalarOfType<1> $x)))>; - //===----------------------------------------------------------------------===// // Fill op patterns. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 1df14448ffd..aae265ef0dc 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -2183,6 +2183,13 @@ func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { return %0 : tensor<2xf32> } +// CHECK-LABEL: @expm1 +func @expm1(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: "mhlo.exponential_minus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "tf.Expm1"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + // CHECK-LABEL: func @exp_dynamic func @exp_dynamic(%arg0: tensor) -> tensor { // CHECK: "mhlo.exponential"(%arg0) : (tensor) -> tensor diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index dec0aace4d1..0baeba90f17 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -598,6 +598,7 @@ foreach Mapping = [ [TF_CosOp, HLO_CosOp], [TF_DigammaOp, HLOClient_DigammaOp], [TF_ExpOp, HLO_ExpOp], + [TF_Expm1Op, HLO_Expm1Op], [TF_ErfOp, HLOClient_ErfOp], [TF_ErfcOp, HLOClient_ErfcOp], [TF_FloorOp, HLO_FloorOp],