Add ComplexAbs kernels and tests.

Also add the LowerComplexPass to KernelCreator. This is needed for
lowering ComplexAbs.

PiperOrigin-RevId: 352388803
Change-Id: I45ddefd239c99f36d62af2f6bcf62f1b062df511
This commit is contained in:
Adrian Kuegel 2021-01-18 04:31:57 -08:00 committed by TensorFlower Gardener
parent 92bc3920b9
commit e19e3d081a
6 changed files with 63 additions and 0 deletions
tensorflow
compiler/mlir/tools/kernel_gen
core/kernels

View File

@ -115,6 +115,7 @@ Status LowerTFtoGPU(mlir::ModuleOp module, llvm::ArrayRef<uint32_t> tile_sizes,
/*allow_partial_conversion=*/false, /*legalize_chlo=*/false));
pm.addNestedPass<mlir::FuncOp>(mlir::createTransformUnrankedHloPass());
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createChloLegalizeToHloPass());
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLowerComplexPass());
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());

View File

@ -24,7 +24,10 @@ REGISTER2(UnaryOp, CPU, "ComplexAbs", functor::abs, complex64, complex128);
#ifndef MLIR_GENERATED_GPU_KERNELS_ENABLED
REGISTER4(UnaryOp, GPU, "Abs", functor::abs, Eigen::half, float, double, int64);
#endif
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
REGISTER2(UnaryOp, GPU, "ComplexAbs", functor::abs, complex64, complex128);
#endif
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel

View File

@ -52,6 +52,7 @@ filegroup(
"gpu_op_atan.cc",
"gpu_op_ceil.cc",
"gpu_op_complex.cc",
"gpu_op_complex_abs.cc",
"gpu_op_conj.cc",
"gpu_op_cos.cc",
"gpu_op_cosh.cc",
@ -118,6 +119,7 @@ tf_kernel_library(
":asinh_kernels",
":atan_kernels",
":ceil_kernels",
":complex_abs_kernels",
":complex_kernels",
":conj_kernels",
":cos_kernels",
@ -465,6 +467,15 @@ gen_kernel_library(
unroll_factors = "2",
)
gen_kernel_library(
name = "complex_abs",
tile_size = "256",
types = [
"f32",
"f64",
],
)
gen_kernel_library(
name = "div",
tile_size = "256,1,1",

View File

@ -0,0 +1,29 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
namespace tensorflow {
GENERATE_UNARY_KERNEL2(ComplexAbs, f32, DT_FLOAT, float, std::complex<float>);
REGISTER_COMPLEX_KERNEL(ComplexAbs, f32, float, std::complex<float>);
GENERATE_UNARY_KERNEL2(ComplexAbs, f64, DT_DOUBLE, double,
std::complex<double>);
REGISTER_COMPLEX_KERNEL(ComplexAbs, f64, double, std::complex<double>);
} // namespace tensorflow

View File

@ -218,6 +218,20 @@ GENERATE_DEFAULT_TEST(Ceil, DT_DOUBLE, DT_DOUBLE, std::ceil,
GENERATE_DEFAULT_TEST_2(Ceil, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::ceil,
test::GpuOpsTestConfig().ExpectStrictlyEqual())
/// Test `tf.ComplexAbs`.
template <typename T>
typename T::value_type baseline_complex_abs(T x) {
return std::abs(x);
}
GENERATE_DEFAULT_TEST(ComplexAbs, DT_COMPLEX64, DT_FLOAT, baseline_complex_abs,
test::GpuOpsTestConfig().AddTout().NoBufferReuse())
GENERATE_DEFAULT_TEST(ComplexAbs, DT_COMPLEX128, DT_DOUBLE,
baseline_complex_abs,
test::GpuOpsTestConfig().AddTout().NoBufferReuse())
/// Test `tf.Conj`.
template <typename T>

View File

@ -0,0 +1,5 @@
func @ComplexAbs_elem_type(%arg0: tensor<*xcomplex<elem_type>>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.ComplexAbs"(%arg0) : (tensor<*xcomplex<elem_type>>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>
}