[MLIR][KernelGen] Add atan2 kernels
PiperOrigin-RevId: 351120975 Change-Id: I07bc4518bdaf66fcc4065c42a4cb393e742fba1e
This commit is contained in:
parent
dff3c8a47a
commit
091856315b
@ -37,14 +37,4 @@ func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
|
||||
return %1 : tensor<4x7xcomplex<f64>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: atan2
|
||||
func @atan2(%arg0: tensor<4x1xf32>, %arg1: tensor<4x1x4xf32>) -> tensor<4x4x4xf32> {
|
||||
// NO_FALLBACK: tf.Atan2
|
||||
// SUPPORTED_FALLBACK_DEVICE-NOT: tf.Atan2
|
||||
// UNSPECIFIED_FALLBACK_DEVICE: tf.Atan2
|
||||
// UNSUPPORTED_FALLBACK_DEVICE: tf.Atan2
|
||||
%0 = "tf.Atan2"(%arg0, %arg1) : (tensor<4x1xf32>, tensor<4x1x4xf32>) -> tensor<4x4x4xf32>
|
||||
return %0: tensor<4x4x4xf32>
|
||||
}
|
||||
|
||||
}
|
@ -91,6 +91,7 @@ class DirectBinaryPat<Op FromOp, Op ToOp>
|
||||
(ToOp $l, $r, (BinBroadcastDimensions $l, $r))>;
|
||||
|
||||
foreach fromToBinPair = [[TF_AddV2Op, HLOClient_BroadcastAddOp],
|
||||
[TF_Atan2Op, HLOClient_BroadcastAtan2Op],
|
||||
[TF_DivOp, HLOClient_BroadcastDivOp],
|
||||
[TF_LeftShiftOp, HLOClient_BroadcastShiftLeftOp],
|
||||
[TF_MaximumOp, HLOClient_BroadcastMaxOp],
|
||||
|
@ -16,8 +16,14 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER2(BinaryOp, CPU, "Atan2", functor::atan2, float, double);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
|
||||
REGISTER2(BinaryOp, GPU, "Atan2", functor::atan2, float, double);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -131,6 +131,7 @@ tf_kernel_library(
|
||||
name = "cwise_binary_op",
|
||||
srcs = [
|
||||
"gpu_op_add.cc",
|
||||
"gpu_op_atan2.cc",
|
||||
"gpu_op_bitwise_and.cc",
|
||||
"gpu_op_bitwise_or.cc",
|
||||
"gpu_op_bitwise_xor.cc",
|
||||
@ -154,6 +155,7 @@ tf_kernel_library(
|
||||
],
|
||||
deps = [
|
||||
":add_v2_kernels",
|
||||
":atan2_kernels",
|
||||
":bitwise_and_kernels",
|
||||
":bitwise_or_kernels",
|
||||
":bitwise_xor_kernels",
|
||||
@ -457,6 +459,17 @@ gen_kernel_library(
|
||||
]
|
||||
]
|
||||
|
||||
gen_kernel_library(
|
||||
name = "atan2",
|
||||
tile_size = "256,1,1",
|
||||
types = [
|
||||
"f32",
|
||||
"f64",
|
||||
],
|
||||
# TODO(b/174543802): Enable once fusion heursitics is better.
|
||||
# unroll_factors = "4",
|
||||
)
|
||||
|
||||
# Logical operations.
|
||||
[
|
||||
gen_kernel_library(
|
||||
|
@ -192,7 +192,7 @@ def gen_kernel_library(name, types, tile_size, tags = [], unroll_factors = None,
|
||||
"$(location //tensorflow/compiler/mlir/tools/kernel_gen:tf_to_kernel)",
|
||||
"$(location {name}_{type}.mlir)".format(name = name, type = type),
|
||||
],
|
||||
size = "small",
|
||||
size = "medium",
|
||||
data = [
|
||||
":{name}_{type}.mlir".format(name = name, type = type),
|
||||
"//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_kernel",
|
||||
|
@ -373,15 +373,53 @@ T baseline_add(T lhs, T rhs) {
|
||||
return lhs + rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(AddV2,
|
||||
/*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
GENERATE_DEFAULT_TESTS(AddV2, /*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
baseline_add)
|
||||
GENERATE_DEFAULT_TESTS(AddV2,
|
||||
/*test_name=*/Float, float, float, baseline_add)
|
||||
GENERATE_DEFAULT_TESTS(AddV2,
|
||||
/*test_name=*/Double, double, double, baseline_add)
|
||||
GENERATE_DEFAULT_TESTS(AddV2,
|
||||
/*test_name=*/Int64, int64, int64, baseline_add)
|
||||
GENERATE_DEFAULT_TESTS(AddV2, /*test_name=*/Float, float, float, baseline_add)
|
||||
GENERATE_DEFAULT_TESTS(AddV2, /*test_name=*/Double, double, double,
|
||||
baseline_add)
|
||||
GENERATE_DEFAULT_TESTS(AddV2, /*test_name=*/Int64, int64, int64, baseline_add)
|
||||
|
||||
/// Test `tf.Atan2`.
|
||||
|
||||
// Prevent the undefined case (0, 0) with non-zero rhs values.
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Atan2,
|
||||
/*test_name=*/FloatRhsNonZero, float, float, test::DefaultInput<float>(),
|
||||
test::DefaultInputNonZero<float>(), std::atan2);
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Atan2,
|
||||
/*test_name=*/DoubleRhsNonZero, double, double,
|
||||
test::DefaultInput<double>(), test::DefaultInputNonZero<double>(),
|
||||
std::atan2);
|
||||
|
||||
// Prevent the undefined case (0, 0) with non-zero lhs values.
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Atan2,
|
||||
/*test_name=*/FloatLhsNonZero, float, float,
|
||||
test::DefaultInputNonZero<float>(), test::DefaultInput<float>(),
|
||||
std::atan2);
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Atan2,
|
||||
/*test_name=*/DoubleLhsNonZero, double, double,
|
||||
test::DefaultInputNonZero<double>(), test::DefaultInput<double>(),
|
||||
std::atan2);
|
||||
|
||||
// Test some particularly interesting cases.
|
||||
TEST_F(GpuBinaryOpTest, Atan2FloatSpecialCases) {
|
||||
TestEqualShapes<float, float, float, float>(
|
||||
"Atan2", /*shape=*/{20},
|
||||
test::InputAsVector<float>({1, 1, 1, 0, -1, -1, -1, 0}),
|
||||
test::InputAsVector<float>({1, 0, -1, -1, -1, 0, 1, 1}), std::atan2,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual());
|
||||
}
|
||||
TEST_F(GpuBinaryOpTest, Atan2DoubleSpecialCases) {
|
||||
TestEqualShapes<double, double, double, double>(
|
||||
"Atan2", /*shape=*/{20},
|
||||
test::InputAsVector<double>({1, 1, 1, 0, -1, -1, -1, 0}),
|
||||
test::InputAsVector<double>({1, 0, -1, -1, -1, 0, 1, 1}), std::atan2,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual());
|
||||
}
|
||||
|
||||
/// Test `tf.BitwiseAnd`.
|
||||
|
||||
|
23
tensorflow/core/kernels/mlir_generated/gpu_op_atan2.cc
Normal file
23
tensorflow/core/kernels/mlir_generated/gpu_op_atan2.cc
Normal file
@ -0,0 +1,23 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
GENERATE_AND_REGISTER_BINARY_KERNEL(Atan2, f32, DT_FLOAT, float);
|
||||
GENERATE_AND_REGISTER_BINARY_KERNEL(Atan2, f64, DT_DOUBLE, double);
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,6 @@
|
||||
func @Atan2_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
|
||||
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
|
||||
%0 = "tf.Atan2"(%arg0, %arg1)
|
||||
: (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
|
||||
return %0 : tensor<*xelem_type>
|
||||
}
|
Loading…
Reference in New Issue
Block a user