[MLIR][KernelGen] Add tf.Atanh
kernels
PiperOrigin-RevId: 352393602 Change-Id: I2431e39759a12735241e9efb9ff778bdb287e6d3
This commit is contained in:
parent
6a9c366ae0
commit
0e2545d934
tensorflow
compiler/mlir
hlo
include/mlir-hlo/Dialect/mhlo/IR
lib/Dialect/mhlo/transforms
xla/transforms
core/kernels
cwise_op_atanh.cc
mlir_generated
@ -397,6 +397,20 @@ def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
|
||||
}];
|
||||
}
|
||||
|
||||
def HLOClient_AtanhOp : HLOClient_UnaryElementwiseOp<"atanh", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
let summary = "Atanh operator";
|
||||
|
||||
let description = [{
|
||||
Returns `Atanh(operand)` element-wise.
|
||||
|
||||
$$
|
||||
\atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
|
||||
= nan otherwise
|
||||
$$
|
||||
}];
|
||||
}
|
||||
|
||||
def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
let summary = "Conj operator";
|
||||
|
@ -175,6 +175,29 @@ def : Pat<(HLOClient_AtanOp $input),
|
||||
(HLO_ConstantLike<"1"> $input)
|
||||
)>;
|
||||
|
||||
// Express `atanh` as follows:
|
||||
// atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
|
||||
// atanh(x) = nan otherwise
|
||||
def : Pat<(HLOClient_AtanhOp NonComplexElementType:$input),
|
||||
(HLO_SelectOp
|
||||
(HLO_CompareOp
|
||||
(HLO_AbsOp $input),
|
||||
(HLO_ConstantLike<"1"> $input),
|
||||
HLO_COMPARISON_DIRECTION_GT,
|
||||
(HLO_DEFAULT_COMPARISON_TYPE)
|
||||
),
|
||||
(HLO_ConstantLike<"NAN"> $input),
|
||||
(HLO_MulOp
|
||||
(HLO_SubOp
|
||||
(HLO_Log1pOp $input),
|
||||
(HLO_Log1pOp
|
||||
(HLO_NegOp $input)
|
||||
)
|
||||
),
|
||||
(HLO_ConstantLike<"0.5"> $input)
|
||||
)
|
||||
)>;
|
||||
|
||||
// Express `conj` as
|
||||
// conj(x) = (re(x), -im(x)).
|
||||
def : Pat<(HLOClient_ConjOp $v),
|
||||
|
@ -50,9 +50,10 @@ namespace {
|
||||
sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)
|
||||
|
||||
// TODO(herhut): Generate these out of op definitions.
|
||||
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
||||
fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(ConjOp) \
|
||||
sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp)
|
||||
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
||||
fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(AtanhOp) \
|
||||
sep fn(ConjOp) sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) \
|
||||
sep fn(SinhOp) sep fn(TanOp)
|
||||
|
||||
template <typename OpTy>
|
||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||
|
@ -588,6 +588,7 @@ foreach Mapping = [
|
||||
[TF_AcosOp, HLOClient_AcosOp],
|
||||
[TF_AsinOp, HLOClient_AsinOp],
|
||||
[TF_AtanOp, HLOClient_AtanOp],
|
||||
[TF_AtanhOp, HLOClient_AtanhOp],
|
||||
[TF_CeilOp, HLO_CeilOp],
|
||||
[TF_CoshOp, HLOClient_CoshOp],
|
||||
[TF_ComplexAbsOp, HLO_AbsOp],
|
||||
|
@ -20,8 +20,11 @@ namespace tensorflow {
|
||||
REGISTER4(UnaryOp, CPU, "Atanh", functor::atanh, float, double, complex64,
|
||||
complex128);
|
||||
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
|
||||
REGISTER2(UnaryOp, GPU, "Atanh", functor::atanh, float, double);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -50,6 +50,7 @@ filegroup(
|
||||
"gpu_op_asin.cc",
|
||||
"gpu_op_asinh.cc",
|
||||
"gpu_op_atan.cc",
|
||||
"gpu_op_atanh.cc",
|
||||
"gpu_op_ceil.cc",
|
||||
"gpu_op_complex.cc",
|
||||
"gpu_op_complex_abs.cc",
|
||||
@ -118,6 +119,7 @@ tf_kernel_library(
|
||||
":asin_kernels",
|
||||
":asinh_kernels",
|
||||
":atan_kernels",
|
||||
":atanh_kernels",
|
||||
":ceil_kernels",
|
||||
":complex_abs_kernels",
|
||||
":complex_kernels",
|
||||
@ -349,6 +351,16 @@ gen_kernel_library(
|
||||
unroll_factors = "4",
|
||||
)
|
||||
|
||||
gen_kernel_library(
|
||||
name = "atanh",
|
||||
tile_size = "256",
|
||||
types = [
|
||||
"f32",
|
||||
"f64",
|
||||
],
|
||||
unroll_factors = "4",
|
||||
)
|
||||
|
||||
gen_kernel_library(
|
||||
name = "conj",
|
||||
tile_size = "256",
|
||||
|
24
tensorflow/core/kernels/mlir_generated/gpu_op_atanh.cc
Normal file
24
tensorflow/core/kernels/mlir_generated/gpu_op_atanh.cc
Normal file
@ -0,0 +1,24 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
GENERATE_AND_REGISTER_UNARY_KERNEL(Atanh, f32, DT_FLOAT, float);
|
||||
GENERATE_AND_REGISTER_UNARY_KERNEL(Atanh, f64, DT_DOUBLE, double);
|
||||
|
||||
} // namespace tensorflow
|
@ -207,6 +207,16 @@ GENERATE_DEFAULT_TEST(Atan, DT_FLOAT, DT_FLOAT, std::atan,
|
||||
GENERATE_DEFAULT_TEST(Atan, DT_DOUBLE, DT_DOUBLE, std::atan,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
/// Test `tf.Atanh`.
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Atanh, DT_FLOAT, DT_FLOAT, test::DefaultInputBetweenZeroAndOne<float>(),
|
||||
std::atanh, test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Atanh, DT_DOUBLE, DT_DOUBLE, test::DefaultInputBetweenZeroAndOne<double>(),
|
||||
std::atanh, test::GpuOpsTestConfig())
|
||||
|
||||
/// Test `tf.Ceil`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Ceil, DT_FLOAT, DT_FLOAT, std::ceil,
|
||||
|
@ -0,0 +1,5 @@
|
||||
func @Atanh_elem_type(%arg0: tensor<*xelem_type>)
|
||||
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
|
||||
%0 = "tf.Atanh"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
|
||||
return %0 : tensor<*xelem_type>
|
||||
}
|
Loading…
Reference in New Issue
Block a user