Restrict CHLO Acos and Sinh op lowering to non complex types
These are failing for complex types. Complex types require special handling. We have a fallback lowering for these ops so we can disable complex element types for now. PiperOrigin-RevId: 348205002 Change-Id: I1a9abc1c7db8adecb447e7505eced5798976ad51
This commit is contained in:
parent
1cd185160a
commit
9a68efbad9
@ -23,10 +23,18 @@ include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td"
|
||||
// Unary op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def NonComplexElementType : Type<
|
||||
CPred<"!$_self.cast<ShapedType>().getElementType().isa<ComplexType>()">,
|
||||
"Non complex element type">;
|
||||
|
||||
// Expand acos to MHLO dialect as follows:
|
||||
// acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1
|
||||
// = pi if x == -1
|
||||
def : Pat<(HLOClient_AcosOp $input),
|
||||
//
|
||||
// TODO(hinsu): Support operands with complex element types separately using
|
||||
// the following formula.
|
||||
// acos(x) = -(i * log(x + i * sqrt((1 + x) * (1 - x))))
|
||||
def : Pat<(HLOClient_AcosOp NonComplexElementType:$input),
|
||||
(HLO_SelectOp
|
||||
(HLO_CompareOp
|
||||
$input,
|
||||
@ -68,7 +76,9 @@ def : Pat<(HLOClient_ConjOp $v),
|
||||
// Express `sinh` as
|
||||
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
|
||||
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
|
||||
def : Pat<(HLOClient_SinhOp $input),
|
||||
// TODO(hinsu): Support operands with complex element types by always using the
|
||||
// second formula. The compare op below is not legal for complex numbers.
|
||||
def : Pat<(HLOClient_SinhOp NonComplexElementType:$input),
|
||||
(HLO_SelectOp
|
||||
(HLO_CompareOp
|
||||
(HLO_AbsOp $input),
|
||||
|
@ -2052,6 +2052,14 @@ func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @acos_complex
|
||||
// CHLO-LABEL: @acos_complex
|
||||
func @acos_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
|
||||
// CHLO: tf.Acos
|
||||
%0 = "tf.Acos"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
|
||||
return %0 : tensor<2xcomplex<f32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @acos_dynamic
|
||||
// CHLO-LABEL: @acos_dynamic
|
||||
func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
@ -2090,6 +2098,14 @@ func @tan_unranked(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
||||
return %result : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @sinh_complex
|
||||
// CHLO-LABEL: @sinh_complex
|
||||
func @sinh_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
|
||||
// CHLO: tf.Sinh
|
||||
%0 = "tf.Sinh"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
|
||||
return %0 : tensor<2xcomplex<f32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cast_dynamic_i2f
|
||||
func @cast_dynamic_i2f(%arg0: tensor<?xi32>) -> tensor<?xf32> {
|
||||
// CHECK: "mhlo.convert"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
|
||||
|
Loading…
x
Reference in New Issue
Block a user