kernel_gen for RightShift
PiperOrigin-RevId: 347001402 Change-Id: Ie2766f4c8402e842584a0f3d890aaf33f76a6204
This commit is contained in:
parent
a95d4c82a5
commit
241d3bd07b
@ -101,7 +101,7 @@ foreach fromToBinPair = [[TF_AddV2Op, HLOClient_BroadcastAddOp],
|
||||
def : DirectBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
|
||||
|
||||
def LowerRightShiftSigned :
|
||||
Pat<(TF_RightShiftOp AnyRankedTensor:$l, AnyRankedTensor:$r),
|
||||
Pat<(TF_RightShiftOp AnyTensor:$l, AnyTensor:$r),
|
||||
(HLOClient_BroadcastShiftRightArithmeticOp $l, $r,
|
||||
(BinBroadcastDimensions $l, $r)),
|
||||
[(SignedIntTensor $r)]>;
|
||||
|
@ -19,8 +19,13 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||
!defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
|
||||
DEFINE_BINARY8(right_shift, int8, int16, int32, int64, uint8, uint16, uint32,
|
||||
uint64);
|
||||
#else
|
||||
DEFINE_BINARY4(right_shift, uint8, uint16, uint32, uint64);
|
||||
#endif
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -21,8 +21,14 @@ REGISTER8(BinaryOp, CPU, "RightShift", functor::right_shift, int8, int16, int32,
|
||||
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||
!defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
|
||||
REGISTER8(BinaryOp, GPU, "RightShift", functor::right_shift, int8, int16, int32,
|
||||
int64, uint8, uint16, uint32, uint64);
|
||||
#else
|
||||
REGISTER4(BinaryOp, GPU, "RightShift", functor::right_shift, uint8, uint16,
|
||||
uint32, uint64);
|
||||
#endif
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -172,6 +172,7 @@ tf_kernel_library(
|
||||
"unranked_op_gpu_logical_and.cc",
|
||||
"unranked_op_gpu_logical_or.cc",
|
||||
"unranked_op_gpu_not_equal.cc",
|
||||
"unranked_op_gpu_right_shift.cc",
|
||||
],
|
||||
tags = [
|
||||
"manual",
|
||||
@ -192,6 +193,7 @@ tf_kernel_library(
|
||||
":maximum_unranked_kernels",
|
||||
":minimum_unranked_kernels",
|
||||
":not_equal_unranked_kernels",
|
||||
":right_shift_unranked_kernels",
|
||||
":unranked_op_gpu_base",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
@ -407,6 +409,21 @@ gen_kernel_library(
|
||||
# unroll_factors = "4",
|
||||
)
|
||||
|
||||
gen_kernel_library(
|
||||
name = "right_shift",
|
||||
generate_ranked = False,
|
||||
generate_unranked = True,
|
||||
tile_size = "256,1,1",
|
||||
types = [
|
||||
"i8",
|
||||
"i16",
|
||||
"i32",
|
||||
"i64",
|
||||
],
|
||||
# TODO(b/174543802): Enable once fusion heuristics is better.
|
||||
# unroll_factors = "4",
|
||||
)
|
||||
|
||||
# Bitwise operations.
|
||||
[
|
||||
gen_kernel_library(
|
||||
|
@ -75,6 +75,15 @@ absl::optional<T> BitwiseAnd(T /*lhs*/, T /*rhs*/) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
template <typename T, std::enable_if_t<std::is_integral<T>::value, bool> = true>
|
||||
absl::optional<T> RightShift(T lhs, T rhs) {
|
||||
return lhs >> rhs;
|
||||
}
|
||||
template <typename T,
|
||||
std::enable_if_t<!std::is_integral<T>::value, bool> = true>
|
||||
absl::optional<T> RightShift(T /*lhs*/, T /*rhs*/) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
template <typename T, std::enable_if_t<std::is_integral<T>::value, bool> = true>
|
||||
absl::optional<T> BitwiseOr(T lhs, T rhs) {
|
||||
return lhs | rhs;
|
||||
}
|
||||
@ -396,6 +405,11 @@ class ParametricGpuBinaryOpsTest
|
||||
if (GetParam().op_name == "NotEqual") {
|
||||
return static_cast<BaselineOutT>(lhs != rhs);
|
||||
}
|
||||
if (GetParam().op_name == "RightShift") {
|
||||
if (auto val = RightShift(lhs, rhs)) {
|
||||
return static_cast<BaselineOutT>(val.value());
|
||||
}
|
||||
}
|
||||
// Add the logic for creating expected values for the kernel you want to
|
||||
// test here.
|
||||
// <PLACEHOLDER>
|
||||
@ -417,6 +431,7 @@ std::vector<BinaryTestParam> GetBinaryTestParameters() {
|
||||
parameters.emplace_back("BitwiseOr", dt, dt);
|
||||
parameters.emplace_back("BitwiseXor", dt, dt);
|
||||
parameters.emplace_back("LeftShift", dt, dt);
|
||||
parameters.emplace_back("RightShift", dt, dt);
|
||||
}
|
||||
for (DataType dt :
|
||||
{DT_FLOAT, DT_DOUBLE, DT_HALF, DT_BOOL, DT_INT8, DT_INT16, DT_INT64}) {
|
||||
|
@ -0,0 +1,6 @@
|
||||
func @RightShift_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
|
||||
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
|
||||
%0 = "tf.RightShift"(%arg0, %arg1) {T = elem_type, device = ""}
|
||||
: (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
|
||||
return %0 : tensor<*xelem_type>
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i8, DT_INT8, int8);
|
||||
GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i16, DT_INT16, int16);
|
||||
GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i32, DT_INT32, int32);
|
||||
GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i64, DT_INT64, int64);
|
||||
|
||||
} // namespace tensorflow
|
Loading…
x
Reference in New Issue
Block a user