diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h index c1d7ffcc9db..3f28937db58 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h @@ -66,6 +66,8 @@ static Value getConstantLike(OpBuilder& b, Location loc, T constant, return b.create<ConstantLikeOp>(loc, getAttr(), val); } +Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val); + } // namespace chlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index 37e6727dfb3..558da58fa08 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -372,6 +372,18 @@ def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin", [], }]; } +def HLOClient_AsinhOp : HLOClient_UnaryElementwiseOp<"asinh", [], + HLO_FpOrComplexTensor> { + let summary = "Asinh operation"; + + let description = [{ + Returns `Asinh(operand)` element-wise. 
+ + $$ + \asinh(x) = log(x + sqrt(x^2 + 1)) + $$ + }]; +} def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [], HLO_FpOrComplexTensor> { let summary = "Atan operator"; diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td index 461527f3740..84df35362c5 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td @@ -30,6 +30,9 @@ class ConstantSplat<string value> : NativeCodeCall< class HLO_ConstantLike<string value> : NativeCodeCall< "chlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; +def HLO_ConstantLikeMaxFiniteValue : NativeCodeCall< + "chlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">; + def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; def BinBroadcastDimensions : NativeCodeCall< diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc index 9761e6abb0a..fa6cc019501 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc @@ -32,6 +32,20 @@ static LogicalResult Verify(T op) { return success(); } +static constexpr float kF16MaxFiniteValue = 0x1.ffcP15; + +Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val) { + Type ty = getElementTypeOrSelf(val.getType()); + if (ty.isF16()) { + return getConstantLike(b, loc, kF16MaxFiniteValue, val); + } else if (ty.isF32()) { + return getConstantLike(b, loc, std::numeric_limits<float>::max(), val); + } else if (ty.isF64()) { + return getConstantLike(b, loc, std::numeric_limits<double>::max(), val); + } + llvm_unreachable("unhandled type"); +} + //===----------------------------------------------------------------------===// // BinaryOps //===----------------------------------------------------------------------===// 
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td index 3d89f926c89..b8b6abb72d3 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td @@ -79,6 +79,94 @@ def : Pat<(HLOClient_AsinOp NonComplexElementType:$input), ) )>; +// Expand asinh to MHLO dialect as +// asinh(x) = log(x + sqrt(x^2 + 1)) +// +// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1) +// as 2*x and return log(2) + log(x). +// +// For small x, sqrt(x^2 + 1) will evaluate to 1 due to floating point +// arithmetic. However, we would like to retain the low order term of this, +// which is around 0.5 * x^2 using a binomial expansion. +// Let z = sqrt(a^2 + 1) +// The following rewrite retains the lower order term. +// log(a + sqrt(a^2 + 1)) +// = log((a + sqrt(a^2 + 1)) * (1 + sqrt(a^2 + 1)) / (1 + sqrt(a^2 + 1))) +// = log((a + a^2 + 1 + a * z + z) / (1 + z)) +// = log(1 + a + a^2 / (1 + z)) +// = log(1 + a + a^2 / (1 + sqrt(a^2 + 1))) +// +// If x is negative, the above would give us some trouble; we can't approximate +// the result as x + abs(x) = 0 but we are saved by the fact that asinh(-x) = +// -asinh(x). 
+def : Pat<(HLOClient_AsinhOp NonComplexElementType:$input), + (HLO_MulOp + (HLO_SignOp $input), + (HLO_SelectOp + (HLO_CompareOp + (HLO_AbsOp $input), + (HLO_SqrtOp + (HLO_ConstantLikeMaxFiniteValue $input) + ), + HLO_COMPARISON_DIRECTION_GE, + (HLO_DEFAULT_COMPARISON_TYPE) + ), + (HLO_AddOp + (HLO_LogOp + (HLO_AbsOp $input) + ), + (HLO_LogOp + (HLO_ConstantLike<"2"> $input) + ) + ), + (HLO_SelectOp + (HLO_CompareOp + (HLO_AbsOp $input), + (HLO_ConstantLike<"1"> $input), + HLO_COMPARISON_DIRECTION_LE, + (HLO_DEFAULT_COMPARISON_TYPE) + ), + (HLO_Log1pOp + (HLO_AddOp + (HLO_AbsOp $input), + (HLO_MulOp + (HLO_AbsOp $input), + (HLO_DivOp + (HLO_AbsOp $input), + (HLO_AddOp + (HLO_ConstantLike<"1"> $input), + (HLO_SqrtOp + (HLO_AddOp + (HLO_MulOp + (HLO_AbsOp $input), + (HLO_AbsOp $input) + ), + (HLO_ConstantLike<"1"> $input) + ) + ) + ) + ) + ) + ) + ), + (HLO_LogOp + (HLO_AddOp + (HLO_AbsOp $input), + (HLO_SqrtOp + (HLO_AddOp + (HLO_MulOp + (HLO_AbsOp $input), + (HLO_AbsOp $input) + ), + (HLO_ConstantLike<"1"> $input) + ) + ) + ) + ) + ) + ) + )>; + // Express `atan` as // atan(x) = atan2(x, 1) def : Pat<(HLOClient_AtanOp $input), diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 152be2176ff..bd6d8918d4f 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -50,9 +50,9 @@ namespace { sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp) // TODO(herhut): Generate these out of op definitions. 
-#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \ - fn(AcosOp) sep fn(AsinOp) sep fn(AtanOp) sep fn(ConjOp) sep fn(CoshOp) \ - sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp) +#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \ + fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(ConjOp) \ + sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp) template <typename OpTy> inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 068e202410f..7707b4f8081 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -584,6 +584,7 @@ def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (ConstantLikeMatcher Elements foreach Mapping = [ [TF_AbsOp, HLO_AbsOp], + [TF_AsinhOp, HLOClient_AsinhOp], [TF_AcosOp, HLOClient_AcosOp], [TF_AsinOp, HLOClient_AsinOp], [TF_AtanOp, HLOClient_AtanOp], diff --git a/tensorflow/core/kernels/cwise_op_asinh.cc b/tensorflow/core/kernels/cwise_op_asinh.cc index d096debca2e..6d162c93b42 100644 --- a/tensorflow/core/kernels/cwise_op_asinh.cc +++ b/tensorflow/core/kernels/cwise_op_asinh.cc @@ -20,8 +20,11 @@ namespace tensorflow { REGISTER4(UnaryOp, CPU, "Asinh", functor::asinh, float, double, complex64, complex128); - #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \ + !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED) REGISTER2(UnaryOp, GPU, "Asinh", functor::asinh, float, double); #endif +#endif + } // namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index fbafb45398e..a93f5a49099 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -48,6 +48,7 @@ filegroup( name = "experimental_unary_kernel_srcs", 
srcs = [ "gpu_op_asin.cc", + "gpu_op_asinh.cc", "gpu_op_atan.cc", "gpu_op_ceil.cc", "gpu_op_complex.cc", @@ -113,6 +114,7 @@ tf_kernel_library( # link them in even if they are currently not needed yet. ":abs_kernels", ":asin_kernels", + ":asinh_kernels", ":atan_kernels", ":ceil_kernels", ":complex_kernels", @@ -323,6 +325,16 @@ gen_kernel_library( unroll_factors = "4", ) +gen_kernel_library( + name = "asinh", + tile_size = "256", + types = [ + "f32", + "f64", + ], + unroll_factors = "4", +) + gen_kernel_library( name = "atan", tile_size = "256", diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_asinh.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_asinh.cc new file mode 100644 index 00000000000..e84f739d8d2 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_asinh.cc @@ -0,0 +1,24 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h" + +namespace tensorflow { + +GENERATE_AND_REGISTER_UNARY_KERNEL(Asinh, f32, DT_FLOAT, float); +GENERATE_AND_REGISTER_UNARY_KERNEL(Asinh, f64, DT_DOUBLE, double); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc index 528cbbbac4a..60a8298f1d6 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc @@ -190,6 +190,14 @@ GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES( Asin, DT_DOUBLE, DT_DOUBLE, test::DefaultInputBetweenZeroAndOne(), std::asin, test::GpuOpsTestConfig().ExpectStrictlyEqual()) +/// Test `tf.Asinh`. + +GENERATE_DEFAULT_TEST(Asinh, DT_FLOAT, DT_FLOAT, std::asinh, + test::GpuOpsTestConfig()) + +GENERATE_DEFAULT_TEST(Asinh, DT_DOUBLE, DT_DOUBLE, std::asinh, + test::GpuOpsTestConfig()) + /// Test `tf.Atan`. GENERATE_DEFAULT_TEST(Atan, DT_FLOAT, DT_FLOAT, std::atan, diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/asinh.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/asinh.mlir.tmpl new file mode 100644 index 00000000000..0d23ce756bf --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/asinh.mlir.tmpl @@ -0,0 +1,5 @@ +func @Asinh_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Asinh"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +}