diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index fb9f1c53abc..688b1c20346 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -928,7 +928,7 @@ func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> { // CHECK-LABEL: func @complex func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> { - // CHECK: "mhlo.complex" + // CHECK: chlo.broadcast_complex %1 = "tf.Complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>> return %1 : tensor<3xcomplex<f32>> } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 0ab98208c21..cae70d55f87 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -92,6 +92,7 @@ class DirectBinaryPat foreach fromToBinPair = [[TF_AddV2Op, HLOClient_BroadcastAddOp], [TF_Atan2Op, HLOClient_BroadcastAtan2Op], + [TF_ComplexOp, HLOClient_BroadcastComplexOp], [TF_DivOp, HLOClient_BroadcastDivOp], [TF_LeftShiftOp, HLOClient_BroadcastShiftLeftOp], [TF_MaximumOp, HLOClient_BroadcastMaxOp], @@ -111,8 +112,6 @@ def LowerRightShiftSigned : // TODO(hinsu): Lower unsigned types to HLO_ShiftRightLogical once the HLO op // supports unsigned integers.
-def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>; - // Performs a substitution of FloorDiv, pseudo code below: // // return floor(div(x, y)) diff --git a/tensorflow/core/kernels/cwise_op_complex.cc b/tensorflow/core/kernels/cwise_op_complex.cc index 309e65a7621..c23ea64b4b4 100644 --- a/tensorflow/core/kernels/cwise_op_complex.cc +++ b/tensorflow/core/kernels/cwise_op_complex.cc @@ -27,9 +27,12 @@ REGISTER_COMPLEX(CPU, float, complex64); REGISTER_COMPLEX(CPU, double, complex128); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \ + !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED) REGISTER_COMPLEX(GPU, float, complex64); REGISTER_COMPLEX(GPU, double, complex128); #endif +#endif #undef REGISTER_COMPLEX } // namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index 7a2e1189396..5660ba9a9bf 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -50,6 +50,7 @@ filegroup( "gpu_op_asin.cc", "gpu_op_atan.cc", "gpu_op_ceil.cc", + "gpu_op_complex.cc", "gpu_op_conj.cc", "gpu_op_cos.cc", "gpu_op_exp.cc", @@ -112,6 +113,7 @@ tf_kernel_library( ":asin_kernels", ":atan_kernels", ":ceil_kernels", + ":complex_kernels", ":conj_kernels", ":cos_kernels", ":exp_kernels", diff --git a/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc index 3504dc975f1..9eba7f5b898 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc @@ -469,6 +469,25 @@ GENERATE_DEFAULT_TESTS(BitwiseXor, GENERATE_DEFAULT_TESTS(BitwiseXor, /*test_name=*/Int64, int64, int64, baseline_bitwise_xor) +/// Test `tf.Complex`. 
+ +template <typename T> +std::complex<T> baseline_complex(T lhs, T rhs) { + return std::complex<T>(lhs, rhs); +} + +GENERATE_DEFAULT_TESTS_2( + Complex, + /*test_name=*/C64, float, float, std::complex<float>, std::complex<float>, + test::DefaultInput<float>(), test::DefaultInput<float>(), baseline_complex, + test::GpuOpsTestConfig().ExpectStrictlyEqual().AddTout()) +GENERATE_DEFAULT_TESTS_2( + Complex, + /*test_name=*/C128, double, double, std::complex<double>, + std::complex<double>, test::DefaultInput<double>(), + test::DefaultInput<double>(), baseline_complex, + test::GpuOpsTestConfig().ExpectStrictlyEqual().AddTout()) + /// Test `tf.Div`. template <typename T> diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_complex.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_complex.cc index d076e5c5c7d..d61d9f31b2f 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_complex.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_complex.cc @@ -20,10 +20,10 @@ limitations under the License. namespace tensorflow { -GENERATE_UNARY_KERNEL2(Complex, f32, DT_COMPLEX64, std::complex<float>, float); +GENERATE_BINARY_KERNEL2(Complex, f32, DT_COMPLEX64, std::complex<float>, float); REGISTER_COMPLEX_KERNEL(Complex, f32, std::complex<float>, float); -GENERATE_UNARY_KERNEL2(Complex, f64, DT_COMPLEX128, std::complex<double>, - double); +GENERATE_BINARY_KERNEL2(Complex, f64, DT_COMPLEX128, std::complex<double>, + double); REGISTER_COMPLEX_KERNEL(Complex, f64, std::complex<double>, double); } // namespace tensorflow