[KERNEL_GEN] Add kernel generation for Sub.
PiperOrigin-RevId: 350111344 Change-Id: I86aba6f297bded0b69887ff1750c05036ee5e87c
This commit is contained in:
parent
b08f45a554
commit
031c7398b1
tensorflow/core/kernels
@ -30,8 +30,13 @@ REGISTER(BinaryOp, CPU, "Sub", functor::sub, int32);
|
|||||||
#endif // __ANDROID_TYPES_SLIM__
|
#endif // __ANDROID_TYPES_SLIM__
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||||
|
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
|
||||||
REGISTER7(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double, int64,
|
REGISTER7(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double, int64,
|
||||||
complex64, complex128, uint32);
|
complex64, complex128, uint32);
|
||||||
|
#else
|
||||||
|
REGISTER3(BinaryOp, GPU, "Sub", functor::sub, complex64, complex128, uint32);
|
||||||
|
#endif
|
||||||
|
|
||||||
// A special GPU kernel for int32.
|
// A special GPU kernel for int32.
|
||||||
// TODO(b/25387198): Also enable int32 in device memory. This kernel
|
// TODO(b/25387198): Also enable int32 in device memory. This kernel
|
||||||
|
@ -146,6 +146,7 @@ tf_kernel_library(
|
|||||||
"gpu_op_mul.cc",
|
"gpu_op_mul.cc",
|
||||||
"gpu_op_not_equal.cc",
|
"gpu_op_not_equal.cc",
|
||||||
"gpu_op_right_shift.cc",
|
"gpu_op_right_shift.cc",
|
||||||
|
"gpu_op_sub.cc",
|
||||||
],
|
],
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
@ -170,6 +171,7 @@ tf_kernel_library(
|
|||||||
":mul_kernels",
|
":mul_kernels",
|
||||||
":not_equal_kernels",
|
":not_equal_kernels",
|
||||||
":right_shift_kernels",
|
":right_shift_kernels",
|
||||||
|
":sub_kernels",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -366,18 +368,24 @@ gen_kernel_library(
|
|||||||
unroll_factors = "4",
|
unroll_factors = "4",
|
||||||
)
|
)
|
||||||
|
|
||||||
gen_kernel_library(
|
[
|
||||||
name = "add_v2",
|
gen_kernel_library(
|
||||||
tile_size = "256,1,1",
|
name = name,
|
||||||
types = [
|
tile_size = "256,1,1",
|
||||||
"f16",
|
types = [
|
||||||
"f32",
|
"f16",
|
||||||
"f64",
|
"f32",
|
||||||
"i64",
|
"f64",
|
||||||
],
|
"i64",
|
||||||
# TODO(b/174543802): Enable once fusion heuristics is better.
|
],
|
||||||
# unroll_factors = "4",
|
# TODO(b/174543802): Enable once fusion heuristics is better.
|
||||||
)
|
# unroll_factors = "4",
|
||||||
|
)
|
||||||
|
for name in [
|
||||||
|
"add_v2",
|
||||||
|
"sub",
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
gen_kernel_library(
|
gen_kernel_library(
|
||||||
name = "complex",
|
name = "complex",
|
||||||
|
@ -385,6 +385,23 @@ GENERATE_DEFAULT_TESTS(AddV2,
|
|||||||
GENERATE_DEFAULT_TESTS(AddV2,
|
GENERATE_DEFAULT_TESTS(AddV2,
|
||||||
/*test_name=*/Int64, int64, int64, baseline_add)
|
/*test_name=*/Int64, int64, int64, baseline_add)
|
||||||
|
|
||||||
|
/// Test `tf.Sub`.
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T baseline_sub(T lhs, T rhs) {
|
||||||
|
return lhs - rhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
GENERATE_DEFAULT_TESTS(Sub,
|
||||||
|
/*test_name=*/Half, Eigen::half, Eigen::half,
|
||||||
|
baseline_sub)
|
||||||
|
GENERATE_DEFAULT_TESTS(Sub,
|
||||||
|
/*test_name=*/Float, float, float, baseline_sub)
|
||||||
|
GENERATE_DEFAULT_TESTS(Sub,
|
||||||
|
/*test_name=*/Double, double, double, baseline_sub)
|
||||||
|
GENERATE_DEFAULT_TESTS(Sub,
|
||||||
|
/*test_name=*/Int64, int64, int64, baseline_sub)
|
||||||
|
|
||||||
/// Test `tf.BitwiseAnd`.
|
/// Test `tf.BitwiseAnd`.
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
26
tensorflow/core/kernels/mlir_generated/gpu_op_sub.cc
Normal file
26
tensorflow/core/kernels/mlir_generated/gpu_op_sub.cc
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f16, DT_HALF, Eigen::half);
|
||||||
|
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f32, DT_FLOAT, float);
|
||||||
|
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f64, DT_DOUBLE, double);
|
||||||
|
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, i64, DT_INT64, int64);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -0,0 +1,6 @@
|
|||||||
|
func @Sub_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
|
||||||
|
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
|
||||||
|
%0 = "tf.Sub"(%arg0, %arg1) {T = elem_type, device = ""}
|
||||||
|
: (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
|
||||||
|
return %0 : tensor<*xelem_type>
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user