[MLIR][KernelGen] Add tf.Atanh kernels
				
					
				
			PiperOrigin-RevId: 352393602 Change-Id: I2431e39759a12735241e9efb9ff778bdb287e6d3
This commit is contained in:
		
							parent
							
								
									6a9c366ae0
								
							
						
					
					
						commit
						0e2545d934
					
				| @ -397,6 +397,20 @@ def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [], | |||||||
|   }]; |   }]; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | def HLOClient_AtanhOp : HLOClient_UnaryElementwiseOp<"atanh", [], | ||||||
|  |     HLO_FpOrComplexTensor> { | ||||||
|  |   let summary = "Atanh operator"; | ||||||
|  | 
 | ||||||
|  |   let description = [{ | ||||||
|  |     Returns `Atanh(operand)` element-wise. | ||||||
|  | 
 | ||||||
|  |     $$ | ||||||
|  |     \atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1 | ||||||
|  |               = nan                          otherwise | ||||||
|  |     $$ | ||||||
|  |   }]; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", [], | def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", [], | ||||||
|     HLO_FpOrComplexTensor> { |     HLO_FpOrComplexTensor> { | ||||||
|   let summary = "Conj operator"; |   let summary = "Conj operator"; | ||||||
|  | |||||||
| @ -175,6 +175,29 @@ def : Pat<(HLOClient_AtanOp $input), | |||||||
|     (HLO_ConstantLike<"1"> $input) |     (HLO_ConstantLike<"1"> $input) | ||||||
|   )>; |   )>; | ||||||
| 
 | 
 | ||||||
|  | // Express `atanh` as follows: | ||||||
|  | //   atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1 | ||||||
|  | //   atanh(x) = nan                          otherwise | ||||||
|  | def : Pat<(HLOClient_AtanhOp NonComplexElementType:$input), | ||||||
|  |   (HLO_SelectOp | ||||||
|  |     (HLO_CompareOp | ||||||
|  |       (HLO_AbsOp $input), | ||||||
|  |       (HLO_ConstantLike<"1"> $input), | ||||||
|  |       HLO_COMPARISON_DIRECTION_GT, | ||||||
|  |       (HLO_DEFAULT_COMPARISON_TYPE) | ||||||
|  |     ), | ||||||
|  |     (HLO_ConstantLike<"NAN"> $input), | ||||||
|  |     (HLO_MulOp | ||||||
|  |       (HLO_SubOp | ||||||
|  |         (HLO_Log1pOp $input), | ||||||
|  |         (HLO_Log1pOp | ||||||
|  |           (HLO_NegOp $input) | ||||||
|  |         ) | ||||||
|  |       ), | ||||||
|  |       (HLO_ConstantLike<"0.5"> $input) | ||||||
|  |     ) | ||||||
|  |   )>; | ||||||
|  | 
 | ||||||
| // Express `conj` as | // Express `conj` as | ||||||
| //   conj(x) = (re(x), -im(x)). | //   conj(x) = (re(x), -im(x)). | ||||||
| def : Pat<(HLOClient_ConjOp $v), | def : Pat<(HLOClient_ConjOp $v), | ||||||
|  | |||||||
| @ -51,8 +51,9 @@ namespace { | |||||||
| 
 | 
 | ||||||
| // TODO(herhut): Generate these out of op definitions.
 | // TODO(herhut): Generate these out of op definitions.
 | ||||||
| #define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep)                            \ | #define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep)                            \ | ||||||
|   fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(ConjOp) \ |   fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(AtanhOp) \ | ||||||
|       sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp) |       sep fn(ConjOp) sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp)           \ | ||||||
|  |           sep fn(SinhOp) sep fn(TanOp) | ||||||
| 
 | 
 | ||||||
| template <typename OpTy> | template <typename OpTy> | ||||||
| inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { | inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { | ||||||
|  | |||||||
| @ -588,6 +588,7 @@ foreach Mapping = [ | |||||||
|                    [TF_AcosOp, HLOClient_AcosOp], |                    [TF_AcosOp, HLOClient_AcosOp], | ||||||
|                    [TF_AsinOp, HLOClient_AsinOp], |                    [TF_AsinOp, HLOClient_AsinOp], | ||||||
|                    [TF_AtanOp, HLOClient_AtanOp], |                    [TF_AtanOp, HLOClient_AtanOp], | ||||||
|  |                    [TF_AtanhOp, HLOClient_AtanhOp], | ||||||
|                    [TF_CeilOp, HLO_CeilOp], |                    [TF_CeilOp, HLO_CeilOp], | ||||||
|                    [TF_CoshOp, HLOClient_CoshOp], |                    [TF_CoshOp, HLOClient_CoshOp], | ||||||
|                    [TF_ComplexAbsOp, HLO_AbsOp], |                    [TF_ComplexAbsOp, HLO_AbsOp], | ||||||
|  | |||||||
| @ -20,8 +20,11 @@ namespace tensorflow { | |||||||
| REGISTER4(UnaryOp, CPU, "Atanh", functor::atanh, float, double, complex64, | REGISTER4(UnaryOp, CPU, "Atanh", functor::atanh, float, double, complex64, | ||||||
|           complex128); |           complex128); | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM | ||||||
|  | #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \ | ||||||
|  |     !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED) | ||||||
| REGISTER2(UnaryOp, GPU, "Atanh", functor::atanh, float, double); | REGISTER2(UnaryOp, GPU, "Atanh", functor::atanh, float, double); | ||||||
| #endif | #endif | ||||||
|  | #endif | ||||||
|  | 
 | ||||||
| }  // namespace tensorflow
 | }  // namespace tensorflow
 | ||||||
|  | |||||||
| @ -50,6 +50,7 @@ filegroup( | |||||||
|         "gpu_op_asin.cc", |         "gpu_op_asin.cc", | ||||||
|         "gpu_op_asinh.cc", |         "gpu_op_asinh.cc", | ||||||
|         "gpu_op_atan.cc", |         "gpu_op_atan.cc", | ||||||
|  |         "gpu_op_atanh.cc", | ||||||
|         "gpu_op_ceil.cc", |         "gpu_op_ceil.cc", | ||||||
|         "gpu_op_complex.cc", |         "gpu_op_complex.cc", | ||||||
|         "gpu_op_complex_abs.cc", |         "gpu_op_complex_abs.cc", | ||||||
| @ -118,6 +119,7 @@ tf_kernel_library( | |||||||
|         ":asin_kernels", |         ":asin_kernels", | ||||||
|         ":asinh_kernels", |         ":asinh_kernels", | ||||||
|         ":atan_kernels", |         ":atan_kernels", | ||||||
|  |         ":atanh_kernels", | ||||||
|         ":ceil_kernels", |         ":ceil_kernels", | ||||||
|         ":complex_abs_kernels", |         ":complex_abs_kernels", | ||||||
|         ":complex_kernels", |         ":complex_kernels", | ||||||
| @ -349,6 +351,16 @@ gen_kernel_library( | |||||||
|     unroll_factors = "4", |     unroll_factors = "4", | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | gen_kernel_library( | ||||||
|  |     name = "atanh", | ||||||
|  |     tile_size = "256", | ||||||
|  |     types = [ | ||||||
|  |         "f32", | ||||||
|  |         "f64", | ||||||
|  |     ], | ||||||
|  |     unroll_factors = "4", | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| gen_kernel_library( | gen_kernel_library( | ||||||
|     name = "conj", |     name = "conj", | ||||||
|     tile_size = "256", |     tile_size = "256", | ||||||
|  | |||||||
							
								
								
									
										24
									
								
								tensorflow/core/kernels/mlir_generated/gpu_op_atanh.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								tensorflow/core/kernels/mlir_generated/gpu_op_atanh.cc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,24 @@ | |||||||
|  | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" | ||||||
|  | #include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h" | ||||||
|  | 
 | ||||||
|  | namespace tensorflow { | ||||||
|  | 
 | ||||||
|  | GENERATE_AND_REGISTER_UNARY_KERNEL(Atanh, f32, DT_FLOAT, float); | ||||||
|  | GENERATE_AND_REGISTER_UNARY_KERNEL(Atanh, f64, DT_DOUBLE, double); | ||||||
|  | 
 | ||||||
|  | }  // namespace tensorflow
 | ||||||
| @ -207,6 +207,16 @@ GENERATE_DEFAULT_TEST(Atan, DT_FLOAT, DT_FLOAT, std::atan, | |||||||
| GENERATE_DEFAULT_TEST(Atan, DT_DOUBLE, DT_DOUBLE, std::atan, | GENERATE_DEFAULT_TEST(Atan, DT_DOUBLE, DT_DOUBLE, std::atan, | ||||||
|                       test::GpuOpsTestConfig()) |                       test::GpuOpsTestConfig()) | ||||||
| 
 | 
 | ||||||
|  | /// Test `tf.Atanh`.
 | ||||||
|  | 
 | ||||||
|  | GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES( | ||||||
|  |     Atanh, DT_FLOAT, DT_FLOAT, test::DefaultInputBetweenZeroAndOne<float>(), | ||||||
|  |     std::atanh, test::GpuOpsTestConfig()) | ||||||
|  | 
 | ||||||
|  | GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES( | ||||||
|  |     Atanh, DT_DOUBLE, DT_DOUBLE, test::DefaultInputBetweenZeroAndOne<double>(), | ||||||
|  |     std::atanh, test::GpuOpsTestConfig()) | ||||||
|  | 
 | ||||||
| /// Test `tf.Ceil`.
 | /// Test `tf.Ceil`.
 | ||||||
| 
 | 
 | ||||||
| GENERATE_DEFAULT_TEST(Ceil, DT_FLOAT, DT_FLOAT, std::ceil, | GENERATE_DEFAULT_TEST(Ceil, DT_FLOAT, DT_FLOAT, std::ceil, | ||||||
|  | |||||||
| @ -0,0 +1,5 @@ | |||||||
|  | func @Atanh_elem_type(%arg0: tensor<*xelem_type>) | ||||||
|  |     -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { | ||||||
|  |   %0 = "tf.Atanh"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> | ||||||
|  |   return %0 : tensor<*xelem_type> | ||||||
|  | } | ||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user