diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td index 92a02ea52fb..a4f425e8aae 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td @@ -23,10 +23,18 @@ include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td" // Unary op patterns. //===----------------------------------------------------------------------===// +def NonComplexElementType : Type< + CPred<"!$_self.cast().getElementType().isa()">, + "Non complex element type">; + // Expand acos to MHLO dialect as follows: // acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1 // = pi if x == -1 -def : Pat<(HLOClient_AcosOp $input), +// +// TODO(hinsu): Support operands with complex element types separately using +// the following formula. +// acos(x) = -(i * log(x + i * sqrt((1 + x) * (1 - x)))) +def : Pat<(HLOClient_AcosOp NonComplexElementType:$input), (HLO_SelectOp (HLO_CompareOp $input, @@ -68,7 +76,9 @@ def : Pat<(HLOClient_ConjOp $v), // Express `sinh` as // sinh(x) = (e^x - e^-x) / 2 if |x| < 1 // = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise. -def : Pat<(HLOClient_SinhOp $input), +// TODO(hinsu): Support operands with complex element types by always using the +// second formula. The compare op below is not legal for complex numbers. +def : Pat<(HLOClient_SinhOp NonComplexElementType:$input), (HLO_SelectOp (HLO_CompareOp (HLO_AbsOp $input), diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index e5aa94fd0b2..8aa39c676c8 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -2052,6 +2052,14 @@ func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> { return %0 : tensor<2xf32> } +// CHECK-LABEL: @acos_complex +// CHLO-LABEL: @acos_complex +func @acos_complex(%arg0: tensor<2xcomplex>) -> tensor<2xcomplex> { + // CHLO: tf.Acos + %0 = "tf.Acos"(%arg0) : (tensor<2xcomplex>) -> tensor<2xcomplex> + return %0 : tensor<2xcomplex> +} + // CHECK-LABEL: @acos_dynamic // CHLO-LABEL: @acos_dynamic func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> { @@ -2090,6 +2098,14 @@ func @tan_unranked(%arg : tensor<*xf32>) -> tensor<*xf32> { return %result : tensor<*xf32> } +// CHECK-LABEL: @sinh_complex +// CHLO-LABEL: @sinh_complex +func @sinh_complex(%arg0: tensor<2xcomplex>) -> tensor<2xcomplex> { + // CHLO: tf.Sinh + %0 = "tf.Sinh"(%arg0) : (tensor<2xcomplex>) -> tensor<2xcomplex> + return %0 : tensor<2xcomplex> +} + // CHECK-LABEL: func @cast_dynamic_i2f func @cast_dynamic_i2f(%arg0: tensor) -> tensor { // CHECK: "mhlo.convert"(%arg0) : (tensor) -> tensor