[MLIR][KernelGen] Add atan2 kernels

PiperOrigin-RevId: 351120975
Change-Id: I07bc4518bdaf66fcc4065c42a4cb393e742fba1e
This commit is contained in:
A. Unique TensorFlower 2021-01-11 03:41:19 -08:00 committed by TensorFlower Gardener
parent dff3c8a47a
commit 091856315b
8 changed files with 96 additions and 19 deletions

View File

@ -37,14 +37,4 @@ func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
return %1 : tensor<4x7xcomplex<f64>>
}
// Atan2 on operands of different shapes (4x1 and 4x1x4) producing a 4x4x4
// result. The directives inside the function expect tf.Atan2 to be absent
// only under the SUPPORTED_FALLBACK_DEVICE prefix and to still be present
// under the other three prefixes.
// CHECK-LABEL: atan2
func @atan2(%arg0: tensor<4x1xf32>, %arg1: tensor<4x1x4xf32>) -> tensor<4x4x4xf32> {
// NO_FALLBACK: tf.Atan2
// SUPPORTED_FALLBACK_DEVICE-NOT: tf.Atan2
// UNSPECIFIED_FALLBACK_DEVICE: tf.Atan2
// UNSUPPORTED_FALLBACK_DEVICE: tf.Atan2
%0 = "tf.Atan2"(%arg0, %arg1) : (tensor<4x1xf32>, tensor<4x1x4xf32>) -> tensor<4x4x4xf32>
return %0: tensor<4x4x4xf32>
}
}

View File

@ -91,6 +91,7 @@ class DirectBinaryPat<Op FromOp, Op ToOp>
(ToOp $l, $r, (BinBroadcastDimensions $l, $r))>;
foreach fromToBinPair = [[TF_AddV2Op, HLOClient_BroadcastAddOp],
[TF_Atan2Op, HLOClient_BroadcastAtan2Op],
[TF_DivOp, HLOClient_BroadcastDivOp],
[TF_LeftShiftOp, HLOClient_BroadcastShiftLeftOp],
[TF_MaximumOp, HLOClient_BroadcastMaxOp],

View File

@ -16,8 +16,14 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
// Atan2 is registered unconditionally for CPU with float and double.
REGISTER2(BinaryOp, CPU, "Atan2", functor::atan2, float, double);
// The GPU registration is only considered in CUDA or ROCm builds.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Skip this registration only when BOTH MLIR macros are defined; if either
// one is missing it stays active.
// NOTE(review): presumably the MLIR-generated Atan2 kernel is registered
// elsewhere when both macros are set - confirm, otherwise GPU would end up
// with no Atan2 kernel in that configuration.
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
REGISTER2(BinaryOp, GPU, "Atan2", functor::atan2, float, double);
#endif
#endif
} // namespace tensorflow

View File

@ -131,6 +131,7 @@ tf_kernel_library(
name = "cwise_binary_op",
srcs = [
"gpu_op_add.cc",
"gpu_op_atan2.cc",
"gpu_op_bitwise_and.cc",
"gpu_op_bitwise_or.cc",
"gpu_op_bitwise_xor.cc",
@ -154,6 +155,7 @@ tf_kernel_library(
],
deps = [
":add_v2_kernels",
":atan2_kernels",
":bitwise_and_kernels",
":bitwise_or_kernels",
":bitwise_xor_kernels",
@ -457,6 +459,17 @@ gen_kernel_library(
]
]
# Declares the atan2 kernel library for the f32 and f64 element types.
# NOTE(review): presumably this is what provides the ":atan2_kernels" target
# listed in the deps of "cwise_binary_op" - confirm against the macro.
gen_kernel_library(
name = "atan2",
tile_size = "256,1,1",
types = [
"f32",
"f64",
],
# TODO(b/174543802): Enable once fusion heuristics are better.
# unroll_factors = "4",
)
# Logical operations.
[
gen_kernel_library(

View File

@ -192,7 +192,7 @@ def gen_kernel_library(name, types, tile_size, tags = [], unroll_factors = None,
"$(location //tensorflow/compiler/mlir/tools/kernel_gen:tf_to_kernel)",
"$(location {name}_{type}.mlir)".format(name = name, type = type),
],
size = "small",
size = "medium",
data = [
":{name}_{type}.mlir".format(name = name, type = type),
"//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_kernel",

View File

@ -373,15 +373,53 @@ T baseline_add(T lhs, T rhs) {
return lhs + rhs;
}
GENERATE_DEFAULT_TESTS(AddV2,
/*test_name=*/Half, Eigen::half, Eigen::half,
GENERATE_DEFAULT_TESTS(AddV2, /*test_name=*/Half, Eigen::half, Eigen::half,
baseline_add)
GENERATE_DEFAULT_TESTS(AddV2,
/*test_name=*/Float, float, float, baseline_add)
GENERATE_DEFAULT_TESTS(AddV2,
/*test_name=*/Double, double, double, baseline_add)
GENERATE_DEFAULT_TESTS(AddV2,
/*test_name=*/Int64, int64, int64, baseline_add)
GENERATE_DEFAULT_TESTS(AddV2, /*test_name=*/Float, float, float, baseline_add)
GENERATE_DEFAULT_TESTS(AddV2, /*test_name=*/Double, double, double,
baseline_add)
GENERATE_DEFAULT_TESTS(AddV2, /*test_name=*/Int64, int64, int64, baseline_add)
/// Test `tf.Atan2`.
// All four instantiations below use std::atan2 as the baseline. Each one
// draws one operand from DefaultInputNonZero and the other from DefaultInput,
// for float and for double, so that the pair (0, 0) is never generated.
// Prevent the undefined case (0, 0) with non-zero rhs values.
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
Atan2,
/*test_name=*/FloatRhsNonZero, float, float, test::DefaultInput<float>(),
test::DefaultInputNonZero<float>(), std::atan2);
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
Atan2,
/*test_name=*/DoubleRhsNonZero, double, double,
test::DefaultInput<double>(), test::DefaultInputNonZero<double>(),
std::atan2);
// Prevent the undefined case (0, 0) with non-zero lhs values.
// Same as above with the operand roles swapped: the non-zero values are now
// passed first and the unrestricted default values second.
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
Atan2,
/*test_name=*/FloatLhsNonZero, float, float,
test::DefaultInputNonZero<float>(), test::DefaultInput<float>(),
std::atan2);
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
Atan2,
/*test_name=*/DoubleLhsNonZero, double, double,
test::DefaultInputNonZero<double>(), test::DefaultInput<double>(),
std::atan2);
// Test some particularly interesting cases.
// The (lhs, rhs) pairs are (1,1), (1,0), (1,-1), (0,-1), (-1,-1), (-1,0),
// (-1,1), (0,1): the eight axis and diagonal directions, never (0, 0).
// The float results are checked against std::atan2 using the
// ExpectStrictlyEqual() test configuration.
// NOTE(review): the shape has 20 elements but only 8 values are listed per
// operand; presumably TestEqualShapes repeats the inputs to fill the shape -
// confirm.
TEST_F(GpuBinaryOpTest, Atan2FloatSpecialCases) {
TestEqualShapes<float, float, float, float>(
"Atan2", /*shape=*/{20},
test::InputAsVector<float>({1, 1, 1, 0, -1, -1, -1, 0}),
test::InputAsVector<float>({1, 0, -1, -1, -1, 0, 1, 1}), std::atan2,
test::GpuOpsTestConfig().ExpectStrictlyEqual());
}
// Double-precision special cases for Atan2. The (lhs, rhs) pairs are (1,1),
// (1,0), (1,-1), (0,-1), (-1,-1), (-1,0), (-1,1), (0,1): the eight axis and
// diagonal directions, never (0, 0). Results are checked against std::atan2
// using the ExpectStrictlyEqual() test configuration.
// NOTE(review): the shape has 20 elements but only 8 values are listed per
// operand; presumably TestEqualShapes repeats the inputs to fill the shape -
// confirm.
TEST_F(GpuBinaryOpTest, Atan2DoubleSpecialCases) {
TestEqualShapes<double, double, double, double>(
"Atan2", /*shape=*/{20},
test::InputAsVector<double>({1, 1, 1, 0, -1, -1, -1, 0}),
test::InputAsVector<double>({1, 0, -1, -1, -1, 0, 1, 1}), std::atan2,
test::GpuOpsTestConfig().ExpectStrictlyEqual());
}
/// Test `tf.BitwiseAnd`.

View File

@ -0,0 +1,23 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
namespace tensorflow {
// Instantiate the Atan2 binary kernel once per supported element type:
// f32 paired with DT_FLOAT / float, and f64 paired with DT_DOUBLE / double.
// NOTE(review): judging by its name, the macro both generates and registers
// the kernel; its definition is expected in gpu_ops_base.h - confirm.
GENERATE_AND_REGISTER_BINARY_KERNEL(Atan2, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Atan2, f64, DT_DOUBLE, double);
} // namespace tensorflow

View File

@ -0,0 +1,6 @@
// Kernel entry point wrapping a single tf.Atan2 on two unranked tensors and
// returning the unranked result. The function carries the tf_entry and
// llvm.emit_c_interface attributes.
// NOTE(review): the element-type token in the function name and tensor types
// looks like a placeholder substituted per type by the build rule - confirm.
func @Atan2_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Atan2"(%arg0, %arg1)
: (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>
}