[MLIR][KernelGen] Add cosh kernels and tests
Allow for relative tolerance in unary kernel tests. In the case of the cosh kernels, this makes it possible to accept an observed difference of 5.6e-8 between the kernel and the `std::cosh` reference (32829984.568665262 vs. 32829984.568665318) in one of the test cases. PiperOrigin-RevId: 351983698 Change-Id: I4b0f345cda2925a10619cba24d4427797db49599
This commit is contained in:
parent
489133d42a
commit
1af6b4e6a5
@ -398,6 +398,19 @@ def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", [],
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def HLOClient_CoshOp : HLOClient_UnaryElementwiseOp<"cosh", [],
|
||||||
|
HLO_FpOrComplexTensor> {
|
||||||
|
let summary = "Cosh operator";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Returns `Cosh(operand)` element-wise.
|
||||||
|
|
||||||
|
$$
|
||||||
|
\cosh(x) = (e^x + e^-x) / 2
|
||||||
|
$$
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
|
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
|
||||||
HLO_FpOrComplexTensor> {
|
HLO_FpOrComplexTensor> {
|
||||||
let summary = "Sinh operation";
|
let summary = "Sinh operation";
|
||||||
|
@ -62,7 +62,7 @@ def : Pat<(HLOClient_AcosOp NonComplexElementType:$input),
|
|||||||
|
|
||||||
// Expand asin to MHLO dialect as follows:
|
// Expand asin to MHLO dialect as follows:
|
||||||
// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
|
// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
|
||||||
def : Pat<(HLOClient_AsinOp $input),
|
def : Pat<(HLOClient_AsinOp NonComplexElementType:$input),
|
||||||
(HLO_MulOp
|
(HLO_MulOp
|
||||||
(HLO_ConstantLike<"2"> $input),
|
(HLO_ConstantLike<"2"> $input),
|
||||||
(HLO_Atan2Op
|
(HLO_Atan2Op
|
||||||
@ -92,6 +92,36 @@ def : Pat<(HLOClient_AtanOp $input),
|
|||||||
def : Pat<(HLOClient_ConjOp $v),
|
def : Pat<(HLOClient_ConjOp $v),
|
||||||
(HLO_ComplexOp (HLO_RealOp $v), (HLO_NegOp (HLO_ImagOp $v)))>;
|
(HLO_ComplexOp (HLO_RealOp $v), (HLO_NegOp (HLO_ImagOp $v)))>;
|
||||||
|
|
||||||
|
// Express `cosh` as
|
||||||
|
// cosh(x) = (e^x + e^-x) / 2
|
||||||
|
// = e^(x + log(1/2)) + e^(-x + log(1/2))
|
||||||
|
//
|
||||||
|
// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not.
|
||||||
|
//
|
||||||
|
// This incorrectly overflows to inf for two f32 input values, namely
|
||||||
|
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
|
||||||
|
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
|
||||||
|
// we deem this acceptable.
|
||||||
|
def : Pat<(HLOClient_CoshOp NonComplexElementType:$input),
|
||||||
|
(HLO_AddOp
|
||||||
|
(HLO_ExpOp
|
||||||
|
(HLO_AddOp
|
||||||
|
$input,
|
||||||
|
(HLO_LogOp
|
||||||
|
(HLO_ConstantLike<"0.5"> $input)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(HLO_ExpOp
|
||||||
|
(HLO_AddOp
|
||||||
|
(HLO_NegOp $input),
|
||||||
|
(HLO_LogOp
|
||||||
|
(HLO_ConstantLike<"0.5"> $input)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)>;
|
||||||
|
|
||||||
// Express `sinh` as
|
// Express `sinh` as
|
||||||
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
|
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
|
||||||
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
|
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
|
||||||
@ -136,7 +166,7 @@ def : Pat<(HLOClient_SinhOp NonComplexElementType:$input),
|
|||||||
|
|
||||||
// Express tan in MHLO dialect as
|
// Express tan in MHLO dialect as
|
||||||
// tan(x) = sin(x) / cos(x).
|
// tan(x) = sin(x) / cos(x).
|
||||||
def : Pat<(HLOClient_TanOp $input),
|
def : Pat<(HLOClient_TanOp NonComplexElementType:$input),
|
||||||
(HLO_DivOp
|
(HLO_DivOp
|
||||||
(HLO_SinOp $input),
|
(HLO_SinOp $input),
|
||||||
(HLO_CosOp $input)
|
(HLO_CosOp $input)
|
||||||
|
@ -50,9 +50,9 @@ namespace {
|
|||||||
sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)
|
sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)
|
||||||
|
|
||||||
// TODO(herhut): Generate these out of op definitions.
|
// TODO(herhut): Generate these out of op definitions.
|
||||||
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
||||||
fn(AcosOp) sep fn(AsinOp) sep fn(AtanOp) sep fn(ConjOp) sep fn(ErfOp) \
|
fn(AcosOp) sep fn(AsinOp) sep fn(AtanOp) sep fn(ConjOp) sep fn(CoshOp) \
|
||||||
sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp)
|
sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp)
|
||||||
|
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||||
|
@ -588,6 +588,7 @@ foreach Mapping = [
|
|||||||
[TF_AsinOp, HLOClient_AsinOp],
|
[TF_AsinOp, HLOClient_AsinOp],
|
||||||
[TF_AtanOp, HLOClient_AtanOp],
|
[TF_AtanOp, HLOClient_AtanOp],
|
||||||
[TF_CeilOp, HLO_CeilOp],
|
[TF_CeilOp, HLO_CeilOp],
|
||||||
|
[TF_CoshOp, HLOClient_CoshOp],
|
||||||
[TF_ComplexAbsOp, HLO_AbsOp],
|
[TF_ComplexAbsOp, HLO_AbsOp],
|
||||||
[TF_ConjOp, HLOClient_ConjOp],
|
[TF_ConjOp, HLOClient_ConjOp],
|
||||||
[TF_CosOp, HLO_CosOp],
|
[TF_CosOp, HLO_CosOp],
|
||||||
|
@ -19,8 +19,11 @@ namespace tensorflow {
|
|||||||
REGISTER5(UnaryOp, CPU, "Cosh", functor::cosh, float, double, bfloat16,
|
REGISTER5(UnaryOp, CPU, "Cosh", functor::cosh, float, double, bfloat16,
|
||||||
complex64, complex128);
|
complex64, complex128);
|
||||||
|
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||||
|
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
|
||||||
REGISTER2(UnaryOp, GPU, "Cosh", functor::cosh, float, double);
|
REGISTER2(UnaryOp, GPU, "Cosh", functor::cosh, float, double);
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -53,6 +53,7 @@ filegroup(
|
|||||||
"gpu_op_complex.cc",
|
"gpu_op_complex.cc",
|
||||||
"gpu_op_conj.cc",
|
"gpu_op_conj.cc",
|
||||||
"gpu_op_cos.cc",
|
"gpu_op_cos.cc",
|
||||||
|
"gpu_op_cosh.cc",
|
||||||
"gpu_op_exp.cc",
|
"gpu_op_exp.cc",
|
||||||
"gpu_op_expm1.cc",
|
"gpu_op_expm1.cc",
|
||||||
"gpu_op_floor.cc",
|
"gpu_op_floor.cc",
|
||||||
@ -117,6 +118,7 @@ tf_kernel_library(
|
|||||||
":complex_kernels",
|
":complex_kernels",
|
||||||
":conj_kernels",
|
":conj_kernels",
|
||||||
":cos_kernels",
|
":cos_kernels",
|
||||||
|
":cosh_kernels",
|
||||||
":exp_kernels",
|
":exp_kernels",
|
||||||
":expm1_kernels",
|
":expm1_kernels",
|
||||||
":floor_kernels",
|
":floor_kernels",
|
||||||
@ -341,6 +343,16 @@ gen_kernel_library(
|
|||||||
unroll_factors = "2",
|
unroll_factors = "2",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
gen_kernel_library(
|
||||||
|
name = "cosh",
|
||||||
|
tile_size = "256",
|
||||||
|
types = [
|
||||||
|
"f32",
|
||||||
|
"f64",
|
||||||
|
],
|
||||||
|
unroll_factors = "4",
|
||||||
|
)
|
||||||
|
|
||||||
gen_kernel_library(
|
gen_kernel_library(
|
||||||
name = "imag",
|
name = "imag",
|
||||||
tile_size = "256",
|
tile_size = "256",
|
||||||
|
24
tensorflow/core/kernels/mlir_generated/gpu_op_cosh.cc
Normal file
24
tensorflow/core/kernels/mlir_generated/gpu_op_cosh.cc
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
GENERATE_AND_REGISTER_UNARY_KERNEL(Cosh, f32, DT_FLOAT, float);
|
||||||
|
GENERATE_AND_REGISTER_UNARY_KERNEL(Cosh, f64, DT_DOUBLE, double);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -84,7 +84,8 @@ class GpuUnaryOpTest : public OpsTestBase {
|
|||||||
if (config.expect_strictly_equal) {
|
if (config.expect_strictly_equal) {
|
||||||
test::ExpectEqual(expected_tensor, *GetOutput(0));
|
test::ExpectEqual(expected_tensor, *GetOutput(0));
|
||||||
} else {
|
} else {
|
||||||
test::ExpectClose(expected_tensor, *GetOutput(0));
|
test::ExpectClose(expected_tensor, *GetOutput(0), kAbsoluteTolerance,
|
||||||
|
kRelativeTolerance);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,6 +107,9 @@ class GpuUnaryOpTest : public OpsTestBase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
constexpr static double kAbsoluteTolerance = 0.001;
|
||||||
|
constexpr static double kRelativeTolerance = 0.001;
|
||||||
|
|
||||||
template <typename T, typename BaselineT, typename OutT,
|
template <typename T, typename BaselineT, typename OutT,
|
||||||
typename BaselineOutT>
|
typename BaselineOutT>
|
||||||
absl::InlinedVector<OutT, 10> ComputeExpectedOutput(
|
absl::InlinedVector<OutT, 10> ComputeExpectedOutput(
|
||||||
@ -229,6 +233,14 @@ GENERATE_DEFAULT_TEST(Cos, DT_DOUBLE, DT_DOUBLE, std::cos,
|
|||||||
GENERATE_DEFAULT_TEST_2(Cos, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::cos,
|
GENERATE_DEFAULT_TEST_2(Cos, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::cos,
|
||||||
test::GpuOpsTestConfig())
|
test::GpuOpsTestConfig())
|
||||||
|
|
||||||
|
/// Test `tf.Cosh`.
|
||||||
|
|
||||||
|
GENERATE_DEFAULT_TEST(Cosh, DT_FLOAT, DT_FLOAT, std::cosh,
|
||||||
|
test::GpuOpsTestConfig())
|
||||||
|
|
||||||
|
GENERATE_DEFAULT_TEST(Cosh, DT_DOUBLE, DT_DOUBLE, std::cosh,
|
||||||
|
test::GpuOpsTestConfig())
|
||||||
|
|
||||||
/// Test `tf.Exp`.
|
/// Test `tf.Exp`.
|
||||||
|
|
||||||
GENERATE_DEFAULT_TEST(Exp, DT_FLOAT, DT_FLOAT, std::exp,
|
GENERATE_DEFAULT_TEST(Exp, DT_FLOAT, DT_FLOAT, std::exp,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user