KernelGen for Pow
PiperOrigin-RevId: 351844384 Change-Id: I8d8bf2216c6e0dcc91105af8e8c8da760dfd91de
This commit is contained in:
parent
1203ea6aad
commit
be1c9cdfd9
tensorflow/core/kernels
@ -19,11 +19,8 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
|
||||
DEFINE_BINARY3(pow, Eigen::half, float, double);
|
||||
DEFINE_BINARY1(safe_pow_ignore_error, int64);
|
||||
#endif
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -21,10 +21,7 @@ REGISTER6(BinaryOp, CPU, "Pow", functor::pow, float, Eigen::half, bfloat16,
|
||||
REGISTER2(BinaryOp, CPU, "Pow", functor::safe_pow, int32, int64);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
|
||||
REGISTER3(BinaryOp, GPU, "Pow", functor::pow, float, Eigen::half, double);
|
||||
REGISTER(BinaryOp, GPU, "Pow", functor::safe_pow_ignore_error, int64);
|
||||
#endif
|
||||
#endif
|
||||
} // namespace tensorflow
|
||||
|
@ -153,7 +153,6 @@ tf_kernel_library(
|
||||
"gpu_op_logical_or.cc",
|
||||
"gpu_op_mul.cc",
|
||||
"gpu_op_not_equal.cc",
|
||||
"gpu_op_pow.cc",
|
||||
"gpu_op_right_shift.cc",
|
||||
"gpu_op_sub.cc",
|
||||
],
|
||||
@ -181,7 +180,6 @@ tf_kernel_library(
|
||||
":minimum_kernels",
|
||||
":mul_kernels",
|
||||
":not_equal_kernels",
|
||||
":pow_kernels",
|
||||
":right_shift_kernels",
|
||||
":sub_kernels",
|
||||
"//third_party/eigen3",
|
||||
@ -230,7 +228,7 @@ tf_cuda_cc_test(
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "gpu_binary_ops_test",
|
||||
size = "medium",
|
||||
size = "small",
|
||||
srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_binary_ops_test.cc"]),
|
||||
tags = tf_cuda_tests_tags() + [
|
||||
"no_cuda_asan", # b/173033461
|
||||
@ -648,15 +646,3 @@ gen_kernel_library(
|
||||
"sin",
|
||||
]
|
||||
]
|
||||
|
||||
gen_kernel_library(
|
||||
name = "pow",
|
||||
tile_size = "256,1,1",
|
||||
types = [
|
||||
"f16",
|
||||
"f32",
|
||||
"f64",
|
||||
"i64",
|
||||
],
|
||||
unroll_factors = "4",
|
||||
)
|
||||
|
@ -734,54 +734,5 @@ GENERATE_DEFAULT_TESTS(Sub,
|
||||
GENERATE_DEFAULT_TESTS(Sub,
|
||||
/*test_name=*/Int64, int64, int64, baseline_sub)
|
||||
|
||||
/// Test `tf.Pow`.
|
||||
|
||||
template <typename T>
|
||||
T baseline_pow(T lhs, T rhs) {
|
||||
return std::pow(lhs, rhs);
|
||||
}
|
||||
|
||||
template <typename T, std::enable_if_t<
|
||||
llvm::is_one_of<T, Eigen::half, float, double>::value,
|
||||
bool> = true>
|
||||
absl::InlinedVector<T, 10> PowInput() {
|
||||
return test::InputAsVector<T, double>({0.0, 0.1, 0.2, 0.3, 1.0, 2.0, 3.0});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
|
||||
bool> = true>
|
||||
absl::InlinedVector<T, 10> PowInput() {
|
||||
return test::InputAsVector<T, double>({0, 1, 3});
|
||||
}
|
||||
|
||||
template <>
|
||||
Eigen::half baseline_pow(Eigen::half lhs, Eigen::half rhs) {
|
||||
return static_cast<Eigen::half>(
|
||||
std::pow(static_cast<float>(lhs), static_cast<float>(rhs)));
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(Pow,
|
||||
/*test_name=*/Half,
|
||||
Eigen::half, Eigen::half,
|
||||
PowInput<Eigen::half>(),
|
||||
PowInput<Eigen::half>(),
|
||||
baseline_pow)
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(Pow,
|
||||
/*test_name=*/Float, float,
|
||||
float, PowInput<float>(),
|
||||
PowInput<float>(),
|
||||
baseline_pow)
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(Pow,
|
||||
/*test_name=*/Double, double,
|
||||
double, PowInput<double>(),
|
||||
PowInput<double>(),
|
||||
baseline_pow)
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(Pow,
|
||||
/*test_name=*/Int64, int64,
|
||||
int64, PowInput<int64>(),
|
||||
PowInput<int64>(),
|
||||
baseline_pow)
|
||||
|
||||
} // namespace
|
||||
} // end namespace tensorflow
|
||||
|
@ -1,25 +0,0 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
GENERATE_AND_REGISTER_BINARY_KERNEL(Pow, f16, DT_HALF, Eigen::half);
|
||||
GENERATE_AND_REGISTER_BINARY_KERNEL(Pow, f32, DT_FLOAT, float);
|
||||
GENERATE_AND_REGISTER_BINARY_KERNEL(Pow, f64, DT_DOUBLE, double);
|
||||
GENERATE_AND_REGISTER_BINARY_KERNEL(Pow, i64, DT_INT64, int64);
|
||||
|
||||
} // namespace tensorflow
|
@ -1,6 +0,0 @@
|
||||
func @Pow_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
|
||||
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
|
||||
%0 = "tf.Pow"(%arg0, %arg1) {T = elem_type, device = ""}
|
||||
: (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
|
||||
return %0 : tensor<*xelem_type>
|
||||
}
|
Loading…
Reference in New Issue
Block a user