[MLIR][KernelGen] Add asin kernels and tests
PiperOrigin-RevId: 351381423 Change-Id: Idb6bb42a153c472da2b495bc66fd0b0202531e29
This commit is contained in:
parent
0423a4a075
commit
b8ac43493a
tensorflow
compiler/mlir
hlo
include/mlir-hlo/Dialect/mhlo/IR
lib/Dialect/mhlo/transforms
xla/transforms
core/kernels
@ -359,6 +359,19 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin", [],
|
||||||
|
HLO_FpOrComplexTensor> {
|
||||||
|
let summary = "Asin operator";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Returns `Asin(operand)` element-wise.
|
||||||
|
|
||||||
|
$$
|
||||||
|
\asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
|
||||||
|
$$
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
|
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
|
||||||
HLO_FpOrComplexTensor> {
|
HLO_FpOrComplexTensor> {
|
||||||
let summary = "Atan operator";
|
let summary = "Atan operator";
|
||||||
|
@ -60,6 +60,25 @@ def : Pat<(HLOClient_AcosOp NonComplexElementType:$input),
|
|||||||
(HLO_ConstantLike<"M_PI"> $input)
|
(HLO_ConstantLike<"M_PI"> $input)
|
||||||
)>;
|
)>;
|
||||||
|
|
||||||
|
// Expand asin to MHLO dialect as follows:
|
||||||
|
// asin(x) = 2 * atan2(x, 1 + sqrt(1 - x^2))
|
||||||
|
def : Pat<(HLOClient_AsinOp $input),
|
||||||
|
(HLO_MulOp
|
||||||
|
(HLO_ConstantLike<"2"> $input),
|
||||||
|
(HLO_Atan2Op
|
||||||
|
$input,
|
||||||
|
(HLO_AddOp
|
||||||
|
(HLO_ConstantLike<"1"> $input),
|
||||||
|
(HLO_SqrtOp
|
||||||
|
(HLO_SubOp
|
||||||
|
(HLO_ConstantLike<"1"> $input),
|
||||||
|
(HLO_MulOp $input, $input)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)>;
|
||||||
|
|
||||||
// Express `atan` as
|
// Express `atan` as
|
||||||
// atan(x) = atan2(x, 1)
|
// atan(x) = atan2(x, 1)
|
||||||
def : Pat<(HLOClient_AtanOp $input),
|
def : Pat<(HLOClient_AtanOp $input),
|
||||||
|
@ -51,8 +51,8 @@ namespace {
|
|||||||
|
|
||||||
// TODO(herhut): Generate these out of op definitions.
|
// TODO(herhut): Generate these out of op definitions.
|
||||||
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
||||||
fn(AcosOp) sep fn(AtanOp) sep fn(ConjOp) sep fn(ErfOp) sep fn(ErfcOp) \
|
fn(AcosOp) sep fn(AsinOp) sep fn(AtanOp) sep fn(ConjOp) sep fn(ErfOp) \
|
||||||
sep fn(SinhOp) sep fn(TanOp)
|
sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp)
|
||||||
|
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||||
|
@ -586,6 +586,7 @@ def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (ConstantLikeMatcher Elements
|
|||||||
foreach Mapping = [
|
foreach Mapping = [
|
||||||
[TF_AbsOp, HLO_AbsOp],
|
[TF_AbsOp, HLO_AbsOp],
|
||||||
[TF_AcosOp, HLOClient_AcosOp],
|
[TF_AcosOp, HLOClient_AcosOp],
|
||||||
|
[TF_AsinOp, HLOClient_AsinOp],
|
||||||
[TF_AtanOp, HLOClient_AtanOp],
|
[TF_AtanOp, HLOClient_AtanOp],
|
||||||
[TF_CeilOp, HLO_CeilOp],
|
[TF_CeilOp, HLO_CeilOp],
|
||||||
[TF_ComplexAbsOp, HLO_AbsOp],
|
[TF_ComplexAbsOp, HLO_AbsOp],
|
||||||
|
@ -19,7 +19,10 @@ namespace tensorflow {
|
|||||||
REGISTER2(UnaryOp, CPU, "Asin", functor::asin, float, double);
|
REGISTER2(UnaryOp, CPU, "Asin", functor::asin, float, double);
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||||
|
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
|
||||||
REGISTER2(UnaryOp, GPU, "Asin", functor::asin, float, double);
|
REGISTER2(UnaryOp, GPU, "Asin", functor::asin, float, double);
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -47,6 +47,7 @@ filegroup(
|
|||||||
filegroup(
|
filegroup(
|
||||||
name = "experimental_unary_kernel_srcs",
|
name = "experimental_unary_kernel_srcs",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
"gpu_op_asin.cc",
|
||||||
"gpu_op_atan.cc",
|
"gpu_op_atan.cc",
|
||||||
"gpu_op_ceil.cc",
|
"gpu_op_ceil.cc",
|
||||||
"gpu_op_conj.cc",
|
"gpu_op_conj.cc",
|
||||||
@ -107,6 +108,7 @@ tf_kernel_library(
|
|||||||
# sure that those targets can be built, so it should not hurt to
|
# sure that those targets can be built, so it should not hurt to
|
||||||
# link them in even if they are currently not needed yet.
|
# link them in even if they are currently not needed yet.
|
||||||
":abs_kernels",
|
":abs_kernels",
|
||||||
|
":asin_kernels",
|
||||||
":atan_kernels",
|
":atan_kernels",
|
||||||
":ceil_kernels",
|
":ceil_kernels",
|
||||||
":conj_kernels",
|
":conj_kernels",
|
||||||
@ -303,6 +305,16 @@ gen_kernel_library(
|
|||||||
unroll_factors = "4",
|
unroll_factors = "4",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
gen_kernel_library(
|
||||||
|
name = "asin",
|
||||||
|
tile_size = "256",
|
||||||
|
types = [
|
||||||
|
"f32",
|
||||||
|
"f64",
|
||||||
|
],
|
||||||
|
unroll_factors = "4",
|
||||||
|
)
|
||||||
|
|
||||||
gen_kernel_library(
|
gen_kernel_library(
|
||||||
name = "atan",
|
name = "atan",
|
||||||
tile_size = "256",
|
tile_size = "256",
|
||||||
|
24
tensorflow/core/kernels/mlir_generated/gpu_op_asin.cc
Normal file
24
tensorflow/core/kernels/mlir_generated/gpu_op_asin.cc
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
GENERATE_AND_REGISTER_UNARY_KERNEL(Asin, f32, DT_FLOAT, float);
|
||||||
|
GENERATE_AND_REGISTER_UNARY_KERNEL(Asin, f64, DT_DOUBLE, double);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -134,7 +134,14 @@ absl::InlinedVector<T, 10> DefaultInputNonZero() {
|
|||||||
{-18, -9, -1, 1, 3, 4, 5, 7, 9, 10, 18});
|
{-18, -9, -1, 1, 3, 4, 5, 7, 9, 10, 18});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper functions to get default input data.
|
template <typename T, std::enable_if_t<
|
||||||
|
llvm::is_one_of<T, Eigen::half, float, double>::value,
|
||||||
|
bool> = true>
|
||||||
|
absl::InlinedVector<T, 10> DefaultInputBetweenZeroAndOne() {
|
||||||
|
return test::InputAsVector<T, double>({-0.999, -0.9, -0.8, -0.5, -0.1, -0.001,
|
||||||
|
-0, 0, 0.001, 0.1, 0.5, 0.8, 0.9,
|
||||||
|
0.999});
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T,
|
template <typename T,
|
||||||
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
|
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
|
||||||
@ -146,6 +153,8 @@ absl::InlinedVector<T, 10> DefaultInputLessThanBitwidth() {
|
|||||||
return v;
|
return v;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Helper functions to get default input data.
|
||||||
|
|
||||||
template <typename T,
|
template <typename T,
|
||||||
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
|
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
|
||||||
bool> = true>
|
bool> = true>
|
||||||
|
@ -175,6 +175,17 @@ GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
|||||||
Abs, DT_INT64, DT_INT64, test::NearZeroAndExtremeInput<int64>(), std::abs,
|
Abs, DT_INT64, DT_INT64, test::NearZeroAndExtremeInput<int64>(), std::abs,
|
||||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||||
|
|
||||||
|
/// Test `tf.Asin`.
|
||||||
|
|
||||||
|
// Test only values in the function domain. The otherwise returned nan value
|
||||||
|
// fails comparison for equality.
|
||||||
|
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||||
|
Asin, DT_FLOAT, DT_FLOAT, test::DefaultInputBetweenZeroAndOne<float>(),
|
||||||
|
std::asin, test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||||
|
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||||
|
Asin, DT_DOUBLE, DT_DOUBLE, test::DefaultInputBetweenZeroAndOne<double>(),
|
||||||
|
std::asin, test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||||
|
|
||||||
/// Test `tf.Atan`.
|
/// Test `tf.Atan`.
|
||||||
|
|
||||||
GENERATE_DEFAULT_TEST(Atan, DT_FLOAT, DT_FLOAT, std::atan,
|
GENERATE_DEFAULT_TEST(Atan, DT_FLOAT, DT_FLOAT, std::atan,
|
||||||
|
Loading…
Reference in New Issue
Block a user