[MLIR][KernelGen] Add tf.Asinh
kernels and complete their lowerings
PiperOrigin-RevId: 351989552 Change-Id: Ib86198983dcb4e8ec8b7edcef09af3b3787c09af
This commit is contained in:
parent
bc464afaba
commit
7449c255cb
@ -66,6 +66,8 @@ static Value getConstantLike(OpBuilder& b, Location loc, T constant,
|
||||
return b.create<ConstantLikeOp>(loc, getAttr(), val);
|
||||
}
|
||||
|
||||
Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val);
|
||||
|
||||
} // namespace chlo
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -372,6 +372,18 @@ def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin", [],
|
||||
}];
|
||||
}
|
||||
|
||||
// CHLO client op for the inverse hyperbolic sine, applied element-wise.
// It is declared as a unary element-wise op with mnemonic "asinh", an empty
// trait list, and HLO_FpOrComplexTensor as its tensor type constraint.
def HLOClient_AsinhOp : HLOClient_UnaryElementwiseOp<"asinh", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
let summary = "Asinh operation";
|
||||
|
||||
let description = [{
|
||||
Returns `Asinh(operand)` element-wise.
|
||||
|
||||
$$
|
||||
\asinh(x) = log(x + sqrt(x^2 + 1))
|
||||
$$
|
||||
}];
|
||||
}
|
||||
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
let summary = "Atan operator";
|
||||
|
@ -30,6 +30,9 @@ class ConstantSplat<string value> : NativeCodeCall<
|
||||
class HLO_ConstantLike<string value> : NativeCodeCall<
|
||||
"chlo::getConstantLike($_builder, $_loc, " # value # ", $0)">;
|
||||
|
||||
// Native-code helper for rewrite patterns: expands to a call to
// chlo::getConstantLikeMaxFiniteValue with the pattern's builder ($_builder),
// its location ($_loc) and the helper's first argument ($0).
def HLO_ConstantLikeMaxFiniteValue : NativeCodeCall<
|
||||
"chlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">;
|
||||
|
||||
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
|
||||
|
||||
def BinBroadcastDimensions : NativeCodeCall<
|
||||
|
@ -32,6 +32,20 @@ static LogicalResult Verify(T op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
// Maximum finite value used for f16 element types. The hex-float literal
// 0x1.ffcP15 equals (1 + 4092/4096) * 2^15 = 65504; it is stored as a float.
static constexpr float kF16MaxFiniteValue = 0x1.ffcP15;
|
||||
|
||||
// Returns getConstantLike(b, loc, <max>, val), where <max> is the largest
// finite value of `val`'s element type (or of `val`'s own type when
// getElementTypeOrSelf returns it unchanged).
//
// Handled element types:
//   f16 -> kF16MaxFiniteValue
//   f32 -> std::numeric_limits<float>::max()
//   f64 -> std::numeric_limits<double>::max()
// Any other type reaches llvm_unreachable("unhandled type").
Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val) {
  const Type elementType = getElementTypeOrSelf(val.getType());

  if (elementType.isF16())
    return getConstantLike(b, loc, kF16MaxFiniteValue, val);

  if (elementType.isF32())
    return getConstantLike(b, loc, std::numeric_limits<float>::max(), val);

  if (elementType.isF64())
    return getConstantLike(b, loc, std::numeric_limits<double>::max(), val);

  llvm_unreachable("unhandled type");
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BinaryOps
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -79,6 +79,94 @@ def : Pat<(HLOClient_AsinOp NonComplexElementType:$input),
|
||||
)
|
||||
)>;
|
||||
|
||||
// Expand asinh to MHLO dialect as
|
||||
// asinh(x) = log(x + sqrt(x^2 + 1))
|
||||
//
|
||||
// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1)
|
||||
// as 2*x and return log(2) + log(x).
|
||||
//
|
||||
// For small x, sqrt(x^2 + 1) will evaluate to 1 due to floating point
|
||||
// arithmetic. However, we would like to retain the low order term of this,
|
||||
// which is around 0.5 * x^2 using a binomial expansion.
|
||||
// Let z = sqrt(a^2 + 1)
|
||||
// The following rewrite retains the lower order term.
|
||||
// log(a + sqrt(a^2 + 1))
|
||||
// = log((a + sqrt(a^2 + 1)) * (1 + sqrt(a^2 + 1)) / (1 + sqrt(a^2 + 1)))
|
||||
// = log((a + a^2 + 1 + a * z + z) / (1 + z))
|
||||
// = log(1 + a + a^2 / (1 + z))
|
||||
// = log(1 + a + a^2 / (1 + sqrt(a^2 + 1)))
|
||||
//
|
||||
// If x is negative, the above would give us some trouble; we can't approximate
|
||||
// the result as x + abs(x) = 0 but we are saved by the fact that asinh(-x) =
|
||||
// -asinh(x).
|
||||
def : Pat<(HLOClient_AsinhOp NonComplexElementType:$input),
  // Overall shape: sign(x) * f(|x|), where f is picked from three cases by the
  // two nested selects below.
  (HLO_MulOp
    (HLO_SignOp $input),
    (HLO_SelectOp
      // Case 1: |x| >= sqrt(max finite value), so |x| * |x| would overflow.
      (HLO_CompareOp
        (HLO_AbsOp $input),
        (HLO_SqrtOp (HLO_ConstantLikeMaxFiniteValue $input)),
        HLO_COMPARISON_DIRECTION_GE,
        (HLO_DEFAULT_COMPARISON_TYPE)),
      // log(|x|) + log(2)
      (HLO_AddOp
        (HLO_LogOp (HLO_AbsOp $input)),
        (HLO_LogOp (HLO_ConstantLike<"2"> $input))),
      (HLO_SelectOp
        // Case 2: |x| <= 1.
        (HLO_CompareOp
          (HLO_AbsOp $input),
          (HLO_ConstantLike<"1"> $input),
          HLO_COMPARISON_DIRECTION_LE,
          (HLO_DEFAULT_COMPARISON_TYPE)),
        // log1p(|x| + |x| * (|x| / (1 + sqrt(|x| * |x| + 1))))
        (HLO_Log1pOp
          (HLO_AddOp
            (HLO_AbsOp $input),
            (HLO_MulOp
              (HLO_AbsOp $input),
              (HLO_DivOp
                (HLO_AbsOp $input),
                (HLO_AddOp
                  (HLO_ConstantLike<"1"> $input),
                  (HLO_SqrtOp
                    (HLO_AddOp
                      (HLO_MulOp (HLO_AbsOp $input), (HLO_AbsOp $input)),
                      (HLO_ConstantLike<"1"> $input)
                    )
                  )
                )
              )
            )
          )
        ),
        // Case 3 (all remaining values): log(|x| + sqrt(|x| * |x| + 1))
        (HLO_LogOp
          (HLO_AddOp
            (HLO_AbsOp $input),
            (HLO_SqrtOp
              (HLO_AddOp
                (HLO_MulOp (HLO_AbsOp $input), (HLO_AbsOp $input)),
                (HLO_ConstantLike<"1"> $input)
              )
            )
          )
        )
      )
    )
  )>;
|
||||
|
||||
// Express `atan` as
|
||||
// atan(x) = atan2(x, 1)
|
||||
def : Pat<(HLOClient_AtanOp $input),
|
||||
|
@ -50,9 +50,9 @@ namespace {
|
||||
sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)
|
||||
|
||||
// TODO(herhut): Generate these out of op definitions.
|
||||
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
||||
fn(AcosOp) sep fn(AsinOp) sep fn(AtanOp) sep fn(ConjOp) sep fn(CoshOp) \
|
||||
sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp)
|
||||
// Applies `fn` to each listed CHLO unary element-wise op (Acos, Asin, Asinh,
// Atan, Conj, Cosh, Erf, Erfc, Sinh, Tan), placing `sep` between consecutive
// applications.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
||||
fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(ConjOp) \
|
||||
sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp)
|
||||
|
||||
template <typename OpTy>
|
||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||
|
@ -584,6 +584,7 @@ def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (ConstantLikeMatcher Elements
|
||||
|
||||
foreach Mapping = [
|
||||
[TF_AbsOp, HLO_AbsOp],
|
||||
[TF_AsinhOp, HLOClient_AsinhOp],
|
||||
[TF_AcosOp, HLOClient_AcosOp],
|
||||
[TF_AsinOp, HLOClient_AsinOp],
|
||||
[TF_AtanOp, HLOClient_AtanOp],
|
||||
|
@ -20,8 +20,11 @@ namespace tensorflow {
|
||||
REGISTER4(UnaryOp, CPU, "Asinh", functor::asinh, float, double, complex64,
|
||||
complex128);
|
||||
|
||||
|
||||
// GPU builds only (CUDA or ROCm): register the functor::asinh GPU kernel for
// float and double, unless BOTH MLIR_GENERATED_GPU_KERNELS_ENABLED and
// MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED are defined.
// NOTE(review): presumably the MLIR-generated Asinh kernel takes over in that
// case — confirm against the mlir_generated kernel registration.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
|
||||
REGISTER2(UnaryOp, GPU, "Asinh", functor::asinh, float, double);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -48,6 +48,7 @@ filegroup(
|
||||
name = "experimental_unary_kernel_srcs",
|
||||
srcs = [
|
||||
"gpu_op_asin.cc",
|
||||
"gpu_op_asinh.cc",
|
||||
"gpu_op_atan.cc",
|
||||
"gpu_op_ceil.cc",
|
||||
"gpu_op_complex.cc",
|
||||
@ -113,6 +114,7 @@ tf_kernel_library(
|
||||
# link them in even if they are currently not needed yet.
|
||||
":abs_kernels",
|
||||
":asin_kernels",
|
||||
":asinh_kernels",
|
||||
":atan_kernels",
|
||||
":ceil_kernels",
|
||||
":complex_kernels",
|
||||
@ -323,6 +325,16 @@ gen_kernel_library(
|
||||
unroll_factors = "4",
|
||||
)
|
||||
|
||||
# Kernel library for "asinh", generated for the f32 and f64 element types with
# tile size 256 and unroll factor 4.
gen_kernel_library(
|
||||
name = "asinh",
|
||||
tile_size = "256",
|
||||
types = [
|
||||
"f32",
|
||||
|
||||
"f64",
|
||||
],
|
||||
unroll_factors = "4",
|
||||
)
|
||||
|
||||
gen_kernel_library(
|
||||
name = "atan",
|
||||
tile_size = "256",
|
||||
|
24
tensorflow/core/kernels/mlir_generated/gpu_op_asinh.cc
Normal file
24
tensorflow/core/kernels/mlir_generated/gpu_op_asinh.cc
Normal file
@ -0,0 +1,24 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Asinh unary kernels for f32 (DT_FLOAT / float) and f64 (DT_DOUBLE / double).
// NOTE(review): the macro is presumably provided by gpu_ops_base.h and both
// defines and registers the kernel — confirm there.
GENERATE_AND_REGISTER_UNARY_KERNEL(Asinh, f32, DT_FLOAT, float);
|
||||
GENERATE_AND_REGISTER_UNARY_KERNEL(Asinh, f64, DT_DOUBLE, double);
|
||||
|
||||
} // namespace tensorflow
|
@ -190,6 +190,14 @@ GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Asin, DT_DOUBLE, DT_DOUBLE, test::DefaultInputBetweenZeroAndOne<double>(),
|
||||
std::asin, test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
/// Test `tf.Asinh`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Asinh, DT_FLOAT, DT_FLOAT, std::asinh,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Asinh, DT_DOUBLE, DT_DOUBLE, std::asinh,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
/// Test `tf.Atan`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Atan, DT_FLOAT, DT_FLOAT, std::atan,
|
||||
|
@ -0,0 +1,5 @@
|
||||
// Entry function (tf_entry, llvm.emit_c_interface) that applies a single
// "tf.Asinh" to an unranked tensor and returns the result.
// NOTE(review): `elem_type` looks like a placeholder that is substituted with
// a concrete element type before compilation — confirm against the kernel
// generator.
func @Asinh_elem_type(%arg0: tensor<*xelem_type>)
|
||||
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
|
||||
%0 = "tf.Asinh"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
|
||||
return %0 : tensor<*xelem_type>
|
||||
}
|
Loading…
Reference in New Issue
Block a user