[KERNEL_GEN] Add kernel generation for Sub.

PiperOrigin-RevId: 350111344
Change-Id: I86aba6f297bded0b69887ff1750c05036ee5e87c
Alexander Belyaev 2021-01-05 03:51:52 -08:00 committed by TensorFlower Gardener
parent b08f45a554
commit 031c7398b1
5 changed files with 74 additions and 12 deletions
tensorflow/core/kernels


@@ -30,8 +30,13 @@ REGISTER(BinaryOp, CPU, "Sub", functor::sub, int32);
#endif // __ANDROID_TYPES_SLIM__

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
REGISTER7(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double, int64,
          complex64, complex128, uint32);
#else
REGISTER3(BinaryOp, GPU, "Sub", functor::sub, complex64, complex128, uint32);
#endif

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel


@@ -146,6 +146,7 @@ tf_kernel_library(
        "gpu_op_mul.cc",
        "gpu_op_not_equal.cc",
        "gpu_op_right_shift.cc",
        "gpu_op_sub.cc",
    ],
    tags = [
        "manual",
@@ -170,6 +171,7 @@ tf_kernel_library(
        ":mul_kernels",
        ":not_equal_kernels",
        ":right_shift_kernels",
        ":sub_kernels",
        "//third_party/eigen3",
    ],
)
@@ -366,18 +368,24 @@ gen_kernel_library(
    unroll_factors = "4",
)

[
    gen_kernel_library(
        name = name,
        tile_size = "256,1,1",
        types = [
            "f16",
            "f32",
            "f64",
            "i64",
        ],
        # TODO(b/174543802): Enable once fusion heuristics is better.
        # unroll_factors = "4",
    )
    for name in [
        "add_v2",
        "sub",
    ]
]

gen_kernel_library(
    name = "complex",


@@ -385,6 +385,23 @@ GENERATE_DEFAULT_TESTS(AddV2,
GENERATE_DEFAULT_TESTS(AddV2,
                       /*test_name=*/Int64, int64, int64, baseline_add)

/// Test `tf.Sub`.
template <typename T>
T baseline_sub(T lhs, T rhs) {
  return lhs - rhs;
}

GENERATE_DEFAULT_TESTS(Sub,
                       /*test_name=*/Half, Eigen::half, Eigen::half,
                       baseline_sub)
GENERATE_DEFAULT_TESTS(Sub,
                       /*test_name=*/Float, float, float, baseline_sub)
GENERATE_DEFAULT_TESTS(Sub,
                       /*test_name=*/Double, double, double, baseline_sub)
GENERATE_DEFAULT_TESTS(Sub,
                       /*test_name=*/Int64, int64, int64, baseline_sub)

/// Test `tf.BitwiseAnd`.
template <typename T>


@@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, i64, DT_INT64, int64);
} // namespace tensorflow


@@ -0,0 +1,6 @@
func @Sub_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
  %0 = "tf.Sub"(%arg0, %arg1) {T = elem_type, device = ""}
      : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
  return %0 : tensor<*xelem_type>
}
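For clarity, this is how the template reads once the kernel generator substitutes a concrete element type. The instantiation below is an illustrative sketch for f32, assuming a plain textual replacement of elem_type:

func @Sub_f32(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>)
    -> tensor<*xf32> attributes {tf_entry, llvm.emit_c_interface} {
  %0 = "tf.Sub"(%arg0, %arg1) {T = f32, device = ""}
      : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}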