[KERNEL_GEN] Add kernel generation for Sub.
PiperOrigin-RevId: 350111344 Change-Id: I86aba6f297bded0b69887ff1750c05036ee5e87c
This commit is contained in:
parent
b08f45a554
commit
031c7398b1
tensorflow/core/kernels
@ -30,8 +30,13 @@ REGISTER(BinaryOp, CPU, "Sub", functor::sub, int32);
|
||||
#endif // __ANDROID_TYPES_SLIM__
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
|
||||
REGISTER7(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double, int64,
|
||||
complex64, complex128, uint32);
|
||||
#else
|
||||
REGISTER3(BinaryOp, GPU, "Sub", functor::sub, complex64, complex128, uint32);
|
||||
#endif
|
||||
|
||||
// A special GPU kernel for int32.
|
||||
// TODO(b/25387198): Also enable int32 in device memory. This kernel
|
||||
|
@ -146,6 +146,7 @@ tf_kernel_library(
|
||||
"gpu_op_mul.cc",
|
||||
"gpu_op_not_equal.cc",
|
||||
"gpu_op_right_shift.cc",
|
||||
"gpu_op_sub.cc",
|
||||
],
|
||||
tags = [
|
||||
"manual",
|
||||
@ -170,6 +171,7 @@ tf_kernel_library(
|
||||
":mul_kernels",
|
||||
":not_equal_kernels",
|
||||
":right_shift_kernels",
|
||||
":sub_kernels",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
@ -366,18 +368,24 @@ gen_kernel_library(
|
||||
unroll_factors = "4",
|
||||
)
|
||||
|
||||
gen_kernel_library(
|
||||
name = "add_v2",
|
||||
tile_size = "256,1,1",
|
||||
types = [
|
||||
"f16",
|
||||
"f32",
|
||||
"f64",
|
||||
"i64",
|
||||
],
|
||||
# TODO(b/174543802): Enable once fusion heuristics is better.
|
||||
# unroll_factors = "4",
|
||||
)
|
||||
[
|
||||
gen_kernel_library(
|
||||
name = name,
|
||||
tile_size = "256,1,1",
|
||||
types = [
|
||||
"f16",
|
||||
"f32",
|
||||
"f64",
|
||||
"i64",
|
||||
],
|
||||
# TODO(b/174543802): Enable once fusion heuristics is better.
|
||||
# unroll_factors = "4",
|
||||
)
|
||||
for name in [
|
||||
"add_v2",
|
||||
"sub",
|
||||
]
|
||||
]
|
||||
|
||||
gen_kernel_library(
|
||||
name = "complex",
|
||||
|
@ -385,6 +385,23 @@ GENERATE_DEFAULT_TESTS(AddV2,
|
||||
GENERATE_DEFAULT_TESTS(AddV2,
|
||||
/*test_name=*/Int64, int64, int64, baseline_add)
|
||||
|
||||
/// Test `tf.Sub`.
|
||||
|
||||
template <typename T>
|
||||
T baseline_sub(T lhs, T rhs) {
|
||||
return lhs - rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(Sub,
|
||||
/*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
baseline_sub)
|
||||
GENERATE_DEFAULT_TESTS(Sub,
|
||||
/*test_name=*/Float, float, float, baseline_sub)
|
||||
GENERATE_DEFAULT_TESTS(Sub,
|
||||
/*test_name=*/Double, double, double, baseline_sub)
|
||||
GENERATE_DEFAULT_TESTS(Sub,
|
||||
/*test_name=*/Int64, int64, int64, baseline_sub)
|
||||
|
||||
/// Test `tf.BitwiseAnd`.
|
||||
|
||||
template <typename T>
|
||||
|
26
tensorflow/core/kernels/mlir_generated/gpu_op_sub.cc
Normal file
26
tensorflow/core/kernels/mlir_generated/gpu_op_sub.cc
Normal file
@ -0,0 +1,26 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f16, DT_HALF, Eigen::half);
|
||||
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f32, DT_FLOAT, float);
|
||||
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f64, DT_DOUBLE, double);
|
||||
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, i64, DT_INT64, int64);
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,6 @@
|
||||
func @Sub_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
|
||||
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
|
||||
%0 = "tf.Sub"(%arg0, %arg1) {T = elem_type, device = ""}
|
||||
: (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
|
||||
return %0 : tensor<*xelem_type>
|
||||
}
|
Loading…
Reference in New Issue
Block a user