diff --git a/tensorflow/core/kernels/cwise_op_gpu_pow.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_pow.cu.cc index 19f915d08da..1a4acb78f69 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_pow.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_pow.cu.cc @@ -19,8 +19,11 @@ limitations under the License. namespace tensorflow { namespace functor { +#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \ + !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED) DEFINE_BINARY3(pow, Eigen::half, float, double); DEFINE_BINARY1(safe_pow_ignore_error, int64); +#endif } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_pow.cc b/tensorflow/core/kernels/cwise_op_pow.cc index f601781a5e6..e03ce437363 100644 --- a/tensorflow/core/kernels/cwise_op_pow.cc +++ b/tensorflow/core/kernels/cwise_op_pow.cc @@ -21,7 +21,10 @@ REGISTER6(BinaryOp, CPU, "Pow", functor::pow, float, Eigen::half, bfloat16, REGISTER2(BinaryOp, CPU, "Pow", functor::safe_pow, int32, int64); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \ + !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED) REGISTER3(BinaryOp, GPU, "Pow", functor::pow, float, Eigen::half, double); REGISTER(BinaryOp, GPU, "Pow", functor::safe_pow_ignore_error, int64); #endif +#endif } // namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index dc7aec9bfc3..88732fd4885 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -179,6 +179,7 @@ tf_kernel_library( "gpu_op_logical_or.cc", "gpu_op_mul.cc", "gpu_op_not_equal.cc", + "gpu_op_pow.cc", "gpu_op_right_shift.cc", "gpu_op_sub.cc", ], @@ -206,6 +207,7 @@ tf_kernel_library( ":minimum_kernels", ":mul_kernels", ":not_equal_kernels", + ":pow_kernels", ":right_shift_kernels", ":sub_kernels", "//third_party/eigen3", @@ -254,7 +256,7 @@ tf_cuda_cc_test( tf_cuda_cc_test( name = "gpu_binary_ops_test", - size = "small", + size = "medium", srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_binary_ops_test.cc"]), tags = tf_cuda_tests_tags() + [ "no_cuda_asan", # b/173033461 @@ -745,3 +747,15 @@ gen_kernel_library( "tan", ] ] + +gen_kernel_library( + name = "pow", + tile_size = "256,1,1", + types = [ + "f16", + "f32", + "f64", + "i64", + ], + unroll_factors = "4", +) diff --git a/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc index facc5fcd608..3109f042c2a 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc @@ -778,5 +778,54 @@ GENERATE_DEFAULT_TESTS(Sub, GENERATE_DEFAULT_TESTS(Sub, /*test_name=*/Int64, int64, int64, baseline_sub) +/// Test `tf.Pow`. + +template +T baseline_pow(T lhs, T rhs) { + return std::pow(lhs, rhs); +} + +template ::value, + bool> = true> +absl::InlinedVector PowInput() { + return test::InputAsVector({0.0, 0.1, 0.2, 0.3, 1.0, 2.0, 3.0}); +} + +template ::value, + bool> = true> +absl::InlinedVector PowInput() { + return test::InputAsVector({-2, -1, -1, 1, 1, 3}); +} + +template <> +Eigen::half baseline_pow(Eigen::half lhs, Eigen::half rhs) { + return static_cast( + std::pow(static_cast(lhs), static_cast(rhs))); +} + +GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(Pow, + /*test_name=*/Half, + Eigen::half, Eigen::half, + PowInput(), + PowInput(), + baseline_pow) +GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(Pow, + /*test_name=*/Float, float, + float, PowInput(), + PowInput(), + baseline_pow) +GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(Pow, + /*test_name=*/Double, double, + double, PowInput(), + PowInput(), + baseline_pow) +GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(Pow, + /*test_name=*/Int64, int64, + int64, PowInput(), + PowInput(), + baseline_pow) + } // namespace } // end namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_pow.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_pow.cc new file mode 100644 index 00000000000..2a167caa2f0 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_pow.cc @@ -0,0 +1,25 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h" + +namespace tensorflow { + +GENERATE_AND_REGISTER_BINARY_KERNEL(Pow, f16, DT_HALF, Eigen::half); +GENERATE_AND_REGISTER_BINARY_KERNEL(Pow, f32, DT_FLOAT, float); +GENERATE_AND_REGISTER_BINARY_KERNEL(Pow, f64, DT_DOUBLE, double); +GENERATE_AND_REGISTER_BINARY_KERNEL(Pow, i64, DT_INT64, int64); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/pow.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/pow.mlir.tmpl new file mode 100644 index 00000000000..1e1e9a9759b --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/pow.mlir.tmpl @@ -0,0 +1,6 @@ +func @Pow_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Pow"(%arg0, %arg1) {T = elem_type, device = ""} + : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +}