diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex_patterns.td index 2cc97c90d1c..d13229719c8 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex_patterns.td +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex_patterns.td @@ -51,40 +51,22 @@ def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, (HLO_MulOp $lhs_real, $rhs_imag), (HLO_MulOp $lhs_imag, $rhs_real)))>; -// Multiplication between a complex and real tensor can be distributed by -// applying the real multiplicant to both the real and complex component. -// -// Note that the sourcep pattern is not legal according to the HLO dialect but -// instead handle intermediates generated by other patterns. -def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs), - (HLO_ComplexOp - (HLO_MulOp (HLO_RealOp $lhs), $rhs), - (HLO_MulOp (HLO_ImagOp $lhs), $rhs))>; - -def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs), - (HLO_ComplexOp - (HLO_MulOp $lhs, (HLO_RealOp $rhs)), - (HLO_MulOp $lhs, (HLO_ImagOp $rhs)))>; - // Division is performed by normalizing the denominator by multiplying by the // conjugate of the rhs. 
// numerator = lhs * conj(rhs) // denominator = rhs * conj(rhs) def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs), - (HLO_DivOp - (HLO_MulOp:$num $lhs, - (HLO_ComplexOp:$conj - (HLO_RealOp $rhs), - (HLO_NegOp (HLO_ImagOp $rhs)))), - (HLO_RealOp:$den (HLO_MulOp $rhs, $conj)))>; - - -def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs), (HLO_ComplexOp - (HLO_DivOp (HLO_RealOp $lhs), $rhs), - (HLO_DivOp (HLO_ImagOp $lhs), $rhs))>; - + (HLO_DivOp + (HLO_RealOp (HLO_MulOp:$num $lhs, + (HLO_ComplexOp:$conj + (HLO_RealOp $rhs), + (HLO_NegOp (HLO_ImagOp $rhs))))), + (HLO_AddOp:$den + (HLO_MulOp (HLO_RealOp $rhs), (HLO_RealOp $rhs)), + (HLO_MulOp (HLO_ImagOp $rhs), (HLO_ImagOp $rhs)))), + (HLO_DivOp (HLO_ImagOp $num), $den))>; // Absolute value is evaluated as: // result = sqrt(val.real * val.real + val.imag * val.imag) @@ -98,10 +80,10 @@ def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val), // sum of sinusoids of the imaginary component, which equates to a normal // exponential operator multiplied by Euler's formula. 
// -// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * (Cos(b) + iSin(b)) +// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * Cos(b) + Exp(a) * iSin(b) def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val), - (HLO_MulOp - (HLO_ExpOp (HLO_RealOp $val)), - (HLO_ComplexOp + (HLO_ComplexOp + (HLO_MulOp (HLO_CosOp (HLO_ImagOp:$imag $val)), - (HLO_SinOp $imag)))>; + (HLO_ExpOp:$exp (HLO_RealOp:$real $val))), + (HLO_MulOp (HLO_SinOp $imag), $exp))>; diff --git a/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir b/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir index b9c91d61377..141c238f930 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir @@ -114,8 +114,8 @@ func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % // Compute the real valued denominator as rhs * con(rhs): // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]] - // CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]] + // CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, %arg3 + // CHECK-DAG: [[VAL6:%.+]] = mhlo.add [[VAL4]], [[VAL5]] // Compute the numerator's imaginary component: // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag @@ -153,8 +153,8 @@ func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor< // Compute the real valued denominator as rhs * con(rhs): // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]] - // CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]] + // CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, %arg3 + // CHECK-DAG: [[VAL6:%.+]] = mhlo.add [[VAL4]], [[VAL5]] // Compute the numerator's imaginary component: // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag @@ -165,6 +165,7 @@ func 
@div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor< // Divide the numerator by the real valued denominator. // CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]] // CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]] + %4 = "mhlo.divide"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>) %5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>) @@ -192,32 +193,48 @@ func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) { func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) - // CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0) - // CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1) - // CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1) - // CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]] + // CHECK-DAG: [[EXP:%.+]] = "mhlo.exponential"(%arg0) + // CHECK-DAG: [[COS:%.+]] = "mhlo.cosine"(%arg1) + // CHECK-DAG: [[SIN:%.+]] = "mhlo.sine"(%arg1) + // CHECK-DAG: [[OUTR:%.+]] = mhlo.multiply [[COS]], [[EXP]] + // CHECK-DAG: [[OUTI:%.+]] = mhlo.multiply [[SIN]], [[EXP]] %1 = "mhlo.exponential"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>) + %2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) %3 = "mhlo.imag"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) - // CHECK: return [[VAL3]], [[VAL4]] + // CHECK: [[OUTR]], [[OUTI]] return %2, %3 : tensor<2xf32>, tensor<2xf32> } -// CHECK-LABEL: @exp_unranked -func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { - %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>) +// CHECK-LABEL: @exp_complex +func @exp_complex(%arg0 : tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>) { + // CHECK-DAG: [[REAL:%.+]] = "mhlo.real"(%arg0) + // CHECK-DAG: [[IMAG:%.+]] = "mhlo.imag"(%arg0) + 
// CHECK-DAG: [[EXP:%.+]] = "mhlo.exponential"([[REAL]]) + // CHECK-DAG: [[COS:%.+]] = "mhlo.cosine"([[IMAG]]) + // CHECK-DAG: [[SIN:%.+]] = "mhlo.sine"([[IMAG]]) + // CHECK-DAG: [[OUTR:%.+]] = mhlo.multiply [[COS]], [[EXP]] + // CHECK-DAG: [[OUTI:%.+]] = mhlo.multiply [[SIN]], [[EXP]] + // CHECK-DAG: [[OUT:%.+]] = "mhlo.complex"([[OUTR]], [[OUTI]]) + %0 = "mhlo.exponential"(%arg0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>) - // CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0) - // CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1) - // CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1) - // CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]] - %1 = "mhlo.exponential"(%0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>) - %2 = "mhlo.real"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>) - %3 = "mhlo.imag"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>) - - // CHECK: return [[VAL3]], [[VAL4]] - return %2, %3 : tensor<*xf32>, tensor<*xf32> +// CHECK-LABEL: @exp_unranked +func @exp_unranked(%arg0 : tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>) { + // CHECK-DAG: [[REAL:%.+]] = "mhlo.real"(%arg0) + // CHECK-DAG: [[IMAG:%.+]] = "mhlo.imag"(%arg0) + // CHECK-DAG: [[EXP:%.+]] = "mhlo.exponential"([[REAL]]) + // CHECK-DAG: [[COS:%.+]] = "mhlo.cosine"([[IMAG]]) + // CHECK-DAG: [[SIN:%.+]] = "mhlo.sine"([[IMAG]]) + // CHECK-DAG: [[OUTR:%.+]] = mhlo.multiply [[COS]], [[EXP]] + // CHECK-DAG: [[OUTI:%.+]] = mhlo.multiply [[SIN]], [[EXP]] + // CHECK-DAG: [[OUT:%.+]] = "mhlo.complex"([[OUTR]], [[OUTI]]) + %0 = "mhlo.exponential"(%arg0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>) + + // CHECK: [[OUT]] + return %0 : tensor<*xcomplex<f32>> }