[MLIR][KernelGen] Add tf.Atanh kernels

PiperOrigin-RevId: 352393602
Change-Id: I2431e39759a12735241e9efb9ff778bdb287e6d3
This commit is contained in:
A. Unique TensorFlower 2021-01-18 05:13:02 -08:00 committed by TensorFlower Gardener
parent 6a9c366ae0
commit 0e2545d934
9 changed files with 97 additions and 4 deletions
tensorflow
compiler/mlir
hlo
include/mlir-hlo/Dialect/mhlo/IR
lib/Dialect/mhlo/transforms
xla/transforms
core/kernels

View File

@ -397,6 +397,20 @@ def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
}];
}
def HLOClient_AtanhOp : HLOClient_UnaryElementwiseOp<"atanh", [],
HLO_FpOrComplexTensor> {
let summary = "Atanh operator";
let description = [{
Returns `Atanh(operand)` element-wise.
$$
\atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
= nan otherwise
$$
}];
}
def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", [],
HLO_FpOrComplexTensor> {
let summary = "Conj operator";

View File

@ -175,6 +175,29 @@ def : Pat<(HLOClient_AtanOp $input),
(HLO_ConstantLike<"1"> $input)
)>;
// Express `atanh` as follows:
// atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
// atanh(x) = nan otherwise
def : Pat<(HLOClient_AtanhOp NonComplexElementType:$input),
(HLO_SelectOp
(HLO_CompareOp
(HLO_AbsOp $input),
(HLO_ConstantLike<"1"> $input),
HLO_COMPARISON_DIRECTION_GT,
(HLO_DEFAULT_COMPARISON_TYPE)
),
(HLO_ConstantLike<"NAN"> $input),
(HLO_MulOp
(HLO_SubOp
(HLO_Log1pOp $input),
(HLO_Log1pOp
(HLO_NegOp $input)
)
),
(HLO_ConstantLike<"0.5"> $input)
)
)>;
// Express `conj` as
// conj(x) = (re(x), -im(x)).
def : Pat<(HLOClient_ConjOp $v),

View File

@ -50,9 +50,10 @@ namespace {
sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)
// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(ConjOp) \
sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp)
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(AtanhOp) \
sep fn(ConjOp) sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) \
sep fn(SinhOp) sep fn(TanOp)
template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {

View File

@ -588,6 +588,7 @@ foreach Mapping = [
[TF_AcosOp, HLOClient_AcosOp],
[TF_AsinOp, HLOClient_AsinOp],
[TF_AtanOp, HLOClient_AtanOp],
[TF_AtanhOp, HLOClient_AtanhOp],
[TF_CeilOp, HLO_CeilOp],
[TF_CoshOp, HLOClient_CoshOp],
[TF_ComplexAbsOp, HLO_AbsOp],

View File

@ -20,8 +20,11 @@ namespace tensorflow {
REGISTER4(UnaryOp, CPU, "Atanh", functor::atanh, float, double, complex64,
complex128);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
REGISTER2(UnaryOp, GPU, "Atanh", functor::atanh, float, double);
#endif
#endif
} // namespace tensorflow

View File

@ -50,6 +50,7 @@ filegroup(
"gpu_op_asin.cc",
"gpu_op_asinh.cc",
"gpu_op_atan.cc",
"gpu_op_atanh.cc",
"gpu_op_ceil.cc",
"gpu_op_complex.cc",
"gpu_op_complex_abs.cc",
@ -118,6 +119,7 @@ tf_kernel_library(
":asin_kernels",
":asinh_kernels",
":atan_kernels",
":atanh_kernels",
":ceil_kernels",
":complex_abs_kernels",
":complex_kernels",
@ -349,6 +351,16 @@ gen_kernel_library(
unroll_factors = "4",
)
gen_kernel_library(
name = "atanh",
tile_size = "256",
types = [
"f32",
"f64",
],
unroll_factors = "4",
)
gen_kernel_library(
name = "conj",
tile_size = "256",

View File

@ -0,0 +1,24 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Atanh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Atanh, f64, DT_DOUBLE, double);
} // namespace tensorflow

View File

@ -207,6 +207,16 @@ GENERATE_DEFAULT_TEST(Atan, DT_FLOAT, DT_FLOAT, std::atan,
GENERATE_DEFAULT_TEST(Atan, DT_DOUBLE, DT_DOUBLE, std::atan,
test::GpuOpsTestConfig())
/// Test `tf.Atanh`.
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
Atanh, DT_FLOAT, DT_FLOAT, test::DefaultInputBetweenZeroAndOne<float>(),
std::atanh, test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
Atanh, DT_DOUBLE, DT_DOUBLE, test::DefaultInputBetweenZeroAndOne<double>(),
std::atanh, test::GpuOpsTestConfig())
/// Test `tf.Ceil`.
GENERATE_DEFAULT_TEST(Ceil, DT_FLOAT, DT_FLOAT, std::ceil,

View File

@ -0,0 +1,5 @@
func @Atanh_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Atanh"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>
}