[MLIR][KernelGen] Add tf.Asinh kernels and complete their lowerings

PiperOrigin-RevId: 351989552
Change-Id: Ib86198983dcb4e8ec8b7edcef09af3b3787c09af
This commit is contained in:
A. Unique TensorFlower 2021-01-15 05:25:49 -08:00 committed by TensorFlower Gardener
parent bc464afaba
commit 7449c255cb
12 changed files with 176 additions and 4 deletions

View File

@ -66,6 +66,8 @@ static Value getConstantLike(OpBuilder& b, Location loc, T constant,
return b.create<ConstantLikeOp>(loc, getAttr(), val);
}
Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val);
} // namespace chlo
} // namespace mlir

View File

@ -372,6 +372,18 @@ def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin", [],
}];
}
def HLOClient_AsinhOp : HLOClient_UnaryElementwiseOp<"asinh", [],
HLO_FpOrComplexTensor> {
let summary = "Asinh operation";
let description = [{
Returns `Asinh(operand)` element-wise.
$$
\asinh(x) = log(x + sqrt(x^2 + 1))
$$
}];
}
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
HLO_FpOrComplexTensor> {
let summary = "Atan operator";

View File

@ -30,6 +30,9 @@ class ConstantSplat<string value> : NativeCodeCall<
class HLO_ConstantLike<string value> : NativeCodeCall<
"chlo::getConstantLike($_builder, $_loc, " # value # ", $0)">;
def HLO_ConstantLikeMaxFiniteValue : NativeCodeCall<
"chlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">;
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
def BinBroadcastDimensions : NativeCodeCall<

View File

@ -32,6 +32,20 @@ static LogicalResult Verify(T op) {
return success();
}
static constexpr float kF16MaxFiniteValue = 0x1.ffcP15;
Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val) {
Type ty = getElementTypeOrSelf(val.getType());
if (ty.isF16()) {
return getConstantLike(b, loc, kF16MaxFiniteValue, val);
} else if (ty.isF32()) {
return getConstantLike(b, loc, std::numeric_limits<float>::max(), val);
} else if (ty.isF64()) {
return getConstantLike(b, loc, std::numeric_limits<double>::max(), val);
}
llvm_unreachable("unhandled type");
}
//===----------------------------------------------------------------------===//
// BinaryOps
//===----------------------------------------------------------------------===//

View File

@ -79,6 +79,94 @@ def : Pat<(HLOClient_AsinOp NonComplexElementType:$input),
)
)>;
// Expand asinh to MHLO dialect as
// asinh(x) = log(x + sqrt(x^2 + 1))
//
// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1)
// as 2*x and return log(2) + log(x).
//
// For small x, sqrt(x^2 + 1) will evaluate to 1 due to floating point
// arithmetic. However, we would like to retain the low order term of this,
// which is around 0.5 * x^2 using a binomial expansion.
// Let z = sqrt(a^2 + 1)
// The following rewrite retains the lower order term.
// log(a + sqrt(a^2 + 1))
// = log((a + sqrt(a^2 + 1)) * (1 + sqrt(a^2 + 1)) / (1 + sqrt(a^2 + 1)))
// = log((a + a^2 + 1 + a * z + z) / (1 + z))
// = log(1 + a + a^2 / (1 + z))
// = log(1 + a + a^2 / (1 + sqrt(a^2 + 1)))
//
// If x is negative, the above would give us some trouble; we can't approximate
// the result as x + abs(x) = 0 but we are saved by the fact that asinh(-x) =
// -asinh(x).
def : Pat<(HLOClient_AsinhOp NonComplexElementType:$input),
(HLO_MulOp
(HLO_SignOp $input),
(HLO_SelectOp
(HLO_CompareOp
(HLO_AbsOp $input),
(HLO_SqrtOp
(HLO_ConstantLikeMaxFiniteValue $input)
),
HLO_COMPARISON_DIRECTION_GE,
(HLO_DEFAULT_COMPARISON_TYPE)
),
(HLO_AddOp
(HLO_LogOp
(HLO_AbsOp $input)
),
(HLO_LogOp
(HLO_ConstantLike<"2"> $input)
)
),
(HLO_SelectOp
(HLO_CompareOp
(HLO_AbsOp $input),
(HLO_ConstantLike<"1"> $input),
HLO_COMPARISON_DIRECTION_LE,
(HLO_DEFAULT_COMPARISON_TYPE)
),
(HLO_Log1pOp
(HLO_AddOp
(HLO_AbsOp $input),
(HLO_MulOp
(HLO_AbsOp $input),
(HLO_DivOp
(HLO_AbsOp $input),
(HLO_AddOp
(HLO_ConstantLike<"1"> $input),
(HLO_SqrtOp
(HLO_AddOp
(HLO_MulOp
(HLO_AbsOp $input),
(HLO_AbsOp $input)
),
(HLO_ConstantLike<"1"> $input)
)
)
)
)
)
)
),
(HLO_LogOp
(HLO_AddOp
(HLO_AbsOp $input),
(HLO_SqrtOp
(HLO_AddOp
(HLO_MulOp
(HLO_AbsOp $input),
(HLO_AbsOp $input)
),
(HLO_ConstantLike<"1"> $input)
)
)
)
)
)
)
)>;
// Express `atan` as
// atan(x) = atan2(x, 1)
def : Pat<(HLOClient_AtanOp $input),

View File

@ -50,9 +50,9 @@ namespace {
sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)
// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
fn(AcosOp) sep fn(AsinOp) sep fn(AtanOp) sep fn(ConjOp) sep fn(CoshOp) \
sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp)
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(ConjOp) \
sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp)
template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {

View File

@ -584,6 +584,7 @@ def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (ConstantLikeMatcher Elements
foreach Mapping = [
[TF_AbsOp, HLO_AbsOp],
[TF_AsinhOp, HLOClient_AsinhOp],
[TF_AcosOp, HLOClient_AcosOp],
[TF_AsinOp, HLOClient_AsinOp],
[TF_AtanOp, HLOClient_AtanOp],

View File

@ -20,8 +20,11 @@ namespace tensorflow {
REGISTER4(UnaryOp, CPU, "Asinh", functor::asinh, float, double, complex64,
complex128);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
REGISTER2(UnaryOp, GPU, "Asinh", functor::asinh, float, double);
#endif
#endif
} // namespace tensorflow

View File

@ -48,6 +48,7 @@ filegroup(
name = "experimental_unary_kernel_srcs",
srcs = [
"gpu_op_asin.cc",
"gpu_op_asinh.cc",
"gpu_op_atan.cc",
"gpu_op_ceil.cc",
"gpu_op_complex.cc",
@ -113,6 +114,7 @@ tf_kernel_library(
# link them in even if they are currently not needed yet.
":abs_kernels",
":asin_kernels",
":asinh_kernels",
":atan_kernels",
":ceil_kernels",
":complex_kernels",
@ -323,6 +325,16 @@ gen_kernel_library(
unroll_factors = "4",
)
gen_kernel_library(
name = "asinh",
tile_size = "256",
types = [
"f32",
"f64",
],
unroll_factors = "4",
)
gen_kernel_library(
name = "atan",
tile_size = "256",

View File

@ -0,0 +1,24 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Asinh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Asinh, f64, DT_DOUBLE, double);
} // namespace tensorflow

View File

@ -190,6 +190,14 @@ GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
Asin, DT_DOUBLE, DT_DOUBLE, test::DefaultInputBetweenZeroAndOne<double>(),
std::asin, test::GpuOpsTestConfig().ExpectStrictlyEqual())
/// Test `tf.Asinh`.
GENERATE_DEFAULT_TEST(Asinh, DT_FLOAT, DT_FLOAT, std::asinh,
test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST(Asinh, DT_DOUBLE, DT_DOUBLE, std::asinh,
test::GpuOpsTestConfig())
/// Test `tf.Atan`.
GENERATE_DEFAULT_TEST(Atan, DT_FLOAT, DT_FLOAT, std::atan,

View File

@ -0,0 +1,5 @@
func @Asinh_elem_type(%arg0: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Asinh"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>
}