kernel_gen for RightShift

PiperOrigin-RevId: 347001402
Change-Id: Ie2766f4c8402e842584a0f3d890aaf33f76a6204
This commit is contained in:
Tres Popp 2020-12-11 08:15:12 -08:00 committed by TensorFlower Gardener
parent a95d4c82a5
commit 241d3bd07b
7 changed files with 75 additions and 1 deletion

View File

@@ -101,7 +101,7 @@ foreach fromToBinPair = [[TF_AddV2Op, HLOClient_BroadcastAddOp],
def : DirectBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
// Lower signed right-shift to the CHLO broadcasting arithmetic-shift op.
// The operand constraint is AnyTensor (not AnyRankedTensor) so the pattern
// also applies to unranked tensors, as required by the MLIR-generated
// unranked kernels; the SignedIntTensor constraint restricts this pattern
// to signed integer element types.
// NOTE: the diff rendering showed both the removed (AnyRankedTensor) and
// added (AnyTensor) lines back to back; per the one-line hunk header only
// the AnyTensor form belongs in the final source.
def LowerRightShiftSigned :
  Pat<(TF_RightShiftOp AnyTensor:$l, AnyTensor:$r),
      (HLOClient_BroadcastShiftRightArithmeticOp $l, $r,
          (BinBroadcastDimensions $l, $r)),
      [(SignedIntTensor $r)]>;

View File

@@ -19,8 +19,13 @@ limitations under the License.
namespace tensorflow {
namespace functor {
// When both MLIR-generated-GPU-kernel macros are defined, the MLIR path
// supplies the signed-integer RightShift kernels (this change generates
// i8/i16/i32/i64 — see the gen_kernel_library "right_shift" target), so only
// the unsigned variants need Eigen-based functor definitions here. Otherwise
// all eight integer types are defined via Eigen.
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
DEFINE_BINARY8(right_shift, int8, int16, int32, int64, uint8, uint16, uint32,
uint64);
#else
// Unsigned-only set: the signed types come from the MLIR-generated kernels.
DEFINE_BINARY4(right_shift, uint8, uint16, uint32, uint64);
#endif
} // namespace functor
} // namespace tensorflow

View File

@@ -21,8 +21,14 @@ REGISTER8(BinaryOp, CPU, "RightShift", functor::right_shift, int8, int16, int32,
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Register Eigen-based GPU kernels for all eight integer types unless the
// MLIR-generated unranked GPU kernels are enabled; in that case the MLIR path
// registers the signed types (i8..i64) and only the unsigned variants are
// registered here.
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
REGISTER8(BinaryOp, GPU, "RightShift", functor::right_shift, int8, int16, int32,
int64, uint8, uint16, uint32, uint64);
#else
REGISTER4(BinaryOp, GPU, "RightShift", functor::right_shift, uint8, uint16,
uint32, uint64);
#endif
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow

View File

@@ -172,6 +172,7 @@ tf_kernel_library(
"unranked_op_gpu_logical_and.cc",
"unranked_op_gpu_logical_or.cc",
"unranked_op_gpu_not_equal.cc",
"unranked_op_gpu_right_shift.cc",
],
tags = [
"manual",
@@ -192,6 +193,7 @@ tf_kernel_library(
":maximum_unranked_kernels",
":minimum_unranked_kernels",
":not_equal_unranked_kernels",
":right_shift_unranked_kernels",
":unranked_op_gpu_base",
"//third_party/eigen3",
],
@@ -407,6 +409,21 @@ gen_kernel_library(
# unroll_factors = "4",
)
# MLIR kernel generation for RightShift: unranked kernels only, signed
# integer element types only (the unsigned variants keep their Eigen-based
# registrations in cwise_op_right_shift).
gen_kernel_library(
name = "right_shift",
generate_ranked = False,
generate_unranked = True,
tile_size = "256,1,1",
types = [
"i8",
"i16",
"i32",
"i64",
],
# TODO(b/174543802): Enable once fusion heuristics is better.
# unroll_factors = "4",
)
# Bitwise operations.
[
gen_kernel_library(

View File

@@ -75,6 +75,15 @@ absl::optional<T> BitwiseAnd(T /*lhs*/, T /*rhs*/) {
return absl::nullopt;
}
// Baseline RightShift for integral types: the plain C++ shift result is used
// as the expected value when testing the GPU kernel.
// NOTE(review): `lhs >> rhs` is undefined behavior in C++ when rhs is
// negative or >= the bit width of T — presumably the generated test inputs
// avoid those ranges; confirm against the parameter generator.
template <typename T, std::enable_if_t<std::is_integral<T>::value, bool> = true>
absl::optional<T> RightShift(T lhs, T rhs) {
return lhs >> rhs;
}
// Non-integral overload: there is no RightShift baseline for these types, so
// return nullopt to signal "no expected value".
template <typename T,
std::enable_if_t<!std::is_integral<T>::value, bool> = true>
absl::optional<T> RightShift(T /*lhs*/, T /*rhs*/) {
return absl::nullopt;
}
// Baseline bitwise OR for integral types; used as the expected value for the
// BitwiseOr kernel tests.
template <typename T, std::enable_if_t<std::is_integral<T>::value, bool> = true>
absl::optional<T> BitwiseOr(T lhs, T rhs) {
return lhs | rhs;
}
@@ -396,6 +405,11 @@ class ParametricGpuBinaryOpsTest
if (GetParam().op_name == "NotEqual") {
return static_cast<BaselineOutT>(lhs != rhs);
}
if (GetParam().op_name == "RightShift") {
if (auto val = RightShift(lhs, rhs)) {
return static_cast<BaselineOutT>(val.value());
}
}
// Add the logic for creating expected values for the kernel you want to
// test here.
// <PLACEHOLDER>
@@ -417,6 +431,7 @@ std::vector<BinaryTestParam> GetBinaryTestParameters() {
parameters.emplace_back("BitwiseOr", dt, dt);
parameters.emplace_back("BitwiseXor", dt, dt);
parameters.emplace_back("LeftShift", dt, dt);
parameters.emplace_back("RightShift", dt, dt);
}
for (DataType dt :
{DT_FLOAT, DT_DOUBLE, DT_HALF, DT_BOOL, DT_INT8, DT_INT16, DT_INT64}) {

View File

@@ -0,0 +1,6 @@
// Template for the unranked RightShift GPU kernel; the kernel generator
// substitutes `elem_type` with each concrete element type (i8/i16/i32/i64
// per the BUILD target). Operands and result are unranked (tensor<*x...>).
func @RightShift_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.RightShift"(%arg0, %arg1) {T = elem_type, device = ""}
: (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>
}

View File

@@ -0,0 +1,25 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
namespace tensorflow {
// Instantiate and register the MLIR-generated unranked GPU RightShift kernels
// for the signed integer types. Each invocation pairs the MLIR type suffix
// (i8..i64) with the matching TF DataType enum and C++ type. Unsigned types
// remain on the Eigen-based registration path.
GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i64, DT_INT64, int64);
} // namespace tensorflow