Add Expm1 Kernel and tests.
Lower Expm1 kernel via tf2tf. PiperOrigin-RevId: 351970465 Change-Id: Ic3fc9ac6b49a58b91997a70e1b67ebce75ead3e8
This commit is contained in:
parent
df87bbd3bb
commit
7cdea3889f
@ -969,3 +969,14 @@ func @size_to_prod_shape_i64(%arg0 : tensor<1x?x2x3xf32>) -> tensor<i64> {
|
||||
// CHECK: %[[PROD:.*]] = "tf.Prod"(%[[SHAPE]], %[[CONSTANT]]) {keep_dims = false} : (tensor<4xi64>, tensor<i64>) -> tensor<i64>
|
||||
// CHECK: return %[[PROD]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @expm1
|
||||
// CHECK-SAME: (%[[X:.*]]: tensor<*xf32>)
|
||||
func @expm1(%x: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[EXP:.*]] = "tf.Exp"(%[[X]]) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: %[[RESULT:.*]] = "tf.Sub"(%[[EXP]], %[[ONE]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
|
||||
%0 = "tf.Expm1" (%x) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
return %0: tensor<*xf32>
|
||||
}
|
||||
|
@ -160,6 +160,14 @@ foreach fromToBinPair = [[TF_DivNoNanOp, TF_DivOp],
|
||||
[TF_MulNoNanOp, TF_MulOp]] in
|
||||
def : BinaryNoNanPat<fromToBinPair[0], fromToBinPair[1]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Expm1 op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def LowerExpm1Op : Pat<(TF_Expm1Op $x),
|
||||
(TF_SubOp (TF_ExpOp $x),
|
||||
(TF_ConstOp (GetScalarOfType<1> $x)))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Fill op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -134,7 +134,6 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
||||
TypeID::get<TF::ErfcOp>(),
|
||||
TypeID::get<TF::ErfinvOp>(),
|
||||
TypeID::get<TF::ErfOp>(),
|
||||
TypeID::get<TF::Expm1Op>(),
|
||||
TypeID::get<TF::ExtractImagePatchesOp>(),
|
||||
TypeID::get<TF::FFT2DOp>(),
|
||||
TypeID::get<TF::FFT3DOp>(),
|
||||
|
@ -19,6 +19,9 @@ namespace tensorflow {
|
||||
REGISTER6(UnaryOp, CPU, "Expm1", functor::expm1, float, Eigen::half, bfloat16,
|
||||
double, complex64, complex128);
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
|
||||
REGISTER3(UnaryOp, GPU, "Expm1", functor::expm1, float, Eigen::half, double);
|
||||
#endif
|
||||
#endif
|
||||
} // namespace tensorflow
|
||||
|
@ -54,6 +54,7 @@ filegroup(
|
||||
"gpu_op_conj.cc",
|
||||
"gpu_op_cos.cc",
|
||||
"gpu_op_exp.cc",
|
||||
"gpu_op_expm1.cc",
|
||||
"gpu_op_floor.cc",
|
||||
"gpu_op_imag.cc",
|
||||
"gpu_op_is_inf.cc",
|
||||
@ -117,6 +118,7 @@ tf_kernel_library(
|
||||
":conj_kernels",
|
||||
":cos_kernels",
|
||||
":exp_kernels",
|
||||
":expm1_kernels",
|
||||
":floor_kernels",
|
||||
":imag_kernels",
|
||||
":is_inf_kernels",
|
||||
@ -622,6 +624,7 @@ gen_kernel_library(
|
||||
for name in [
|
||||
"ceil",
|
||||
"exp",
|
||||
"expm1",
|
||||
"floor",
|
||||
"is_finite",
|
||||
"log",
|
||||
|
25
tensorflow/core/kernels/mlir_generated/gpu_op_expm1.cc
Normal file
25
tensorflow/core/kernels/mlir_generated/gpu_op_expm1.cc
Normal file
@ -0,0 +1,25 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
GENERATE_AND_REGISTER_UNARY_KERNEL(Expm1, f16, DT_HALF, Eigen::half);
|
||||
GENERATE_AND_REGISTER_UNARY_KERNEL(Expm1, f32, DT_FLOAT, float);
|
||||
GENERATE_AND_REGISTER_UNARY_KERNEL(Expm1, f64, DT_DOUBLE, double);
|
||||
|
||||
} // namespace tensorflow
|
@ -240,6 +240,17 @@ GENERATE_DEFAULT_TEST(Exp, DT_DOUBLE, DT_DOUBLE, std::exp,
|
||||
GENERATE_DEFAULT_TEST_2(Exp, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::exp,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
/// Test `tf.Expm1`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Expm1, DT_FLOAT, DT_FLOAT, std::expm1,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Expm1, DT_DOUBLE, DT_DOUBLE, std::expm1,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_2(Expm1, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::expm1,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
/// Test `tf.Floor`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Floor, DT_FLOAT, DT_FLOAT, std::floor,
|
||||
|
Loading…
Reference in New Issue
Block a user