Add unranked kernel definition for bitwise or operation.

PiperOrigin-RevId: 346545955
Change-Id: I287e5dcc8bfe7ec7d1bc3e1fbf11816759fb5332
This commit is contained in:
Stephan Herhut 2020-12-09 07:26:39 -08:00 committed by TensorFlower Gardener
parent 1a0dd99a6a
commit 84a2d80287
5 changed files with 88 additions and 19 deletions

View File

@ -21,8 +21,16 @@ REGISTER8(BinaryOp, CPU, "BitwiseOr", functor::bitwise_or, int8, int16, int32,
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
    !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
REGISTER8(BinaryOp, GPU, "BitwiseOr", functor::bitwise_or, int8, int16, int32,
          int64, uint8, uint16, uint32, uint64);
#else
// TODO(b/172804967): We do not generate unsigned kernels for GPU via mlir.
REGISTER4(BinaryOp, GPU, "BitwiseOr", functor::bitwise_or, uint8, uint16,
          uint32, uint64);
#endif  // !MLIR_GENERATED_GPU_KERNELS_ENABLED ||
        // !MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}  // namespace tensorflow

View File

@ -159,6 +159,7 @@ tf_kernel_library(
srcs = [
    "unranked_op_gpu_add.cc",
    "unranked_op_gpu_bitwise_and.cc",
    "unranked_op_gpu_bitwise_or.cc",
    "unranked_op_gpu_equal.cc",
],
tags = [
@ -167,6 +168,7 @@ tf_kernel_library(
deps = [
    ":add_v2_unranked_kernels",
    ":bitwise_and_unranked_kernels",
    ":bitwise_or_unranked_kernels",
    ":equal_unranked_kernels",
    ":greater_equal_unranked_kernels",
    ":greater_unranked_kernels",
@ -375,25 +377,32 @@ gen_kernel_library(
#     unroll_factors = "4",
# )

# Bitwise operations.
[
    gen_kernel_library(
        name = name,
        generate_ranked = False,
        generate_unranked = True,
        tile_size = "256,1,1",
        types = [
            "i8",
            "i16",
            "i32",
            "i64",
            # TODO(b/172804967): Enable once fixed.
            # "ui8",
            # "ui16",
            # "ui32",
            # "ui64",
        ],
        # TODO(b/174543802): Enable once fusion heuristics is better.
        # unroll_factors = "4",
    )
    for name in [
        "bitwise_and",
        "bitwise_or",
    ]
]

[
    gen_kernel_library(

View File

@ -58,6 +58,15 @@ template <typename T,
          absl::optional<T> BitwiseAnd(T /*lhs*/, T /*rhs*/) {
  return absl::nullopt;
}
template <typename T, std::enable_if_t<std::is_integral<T>::value, bool> = true>
absl::optional<T> BitwiseOr(T lhs, T rhs) {
return lhs | rhs;
}
template <typename T,
std::enable_if_t<!std::is_integral<T>::value, bool> = true>
absl::optional<T> BitwiseOr(T /*lhs*/, T /*rhs*/) {
return absl::nullopt;
}
class ParametricGpuBinaryOpsTest
    : public OpsTestBase,
@ -316,6 +325,11 @@ class ParametricGpuBinaryOpsTest
      return static_cast<BaselineOutT>(val.value());
    }
  }
if (GetParam().op_name == "BitwiseOr") {
if (auto val = BitwiseOr(lhs, rhs)) {
return static_cast<BaselineOutT>(val.value());
}
}
  if (GetParam().op_name == "Equal") {
    return static_cast<BaselineOutT>(lhs == rhs);
  }
@ -339,6 +353,7 @@ std::vector<BinaryTestParam> GetBinaryTestParameters() {
  for (DataType dt :
       std::vector<DataType>{DT_INT8, DT_INT16, DT_INT32, DT_INT64}) {
    parameters.emplace_back("BitwiseAnd", dt, dt);
    parameters.emplace_back("BitwiseOr", dt, dt);
  }
  for (DataType dt :
       std::vector<DataType>{DT_FLOAT, DT_DOUBLE, DT_HALF, DT_BOOL, DT_INT8,

View File

@ -0,0 +1,6 @@
// Unranked (tensor<*x...>) kernel template for tf.BitwiseOr. The `elem_type`
// placeholder is substituted with each concrete element type listed in the
// BUILD rule (i8, i16, i32, i64) by the kernel generator — presumably via
// textual substitution; confirm against gen_kernel_library.
// `tf_entry` marks this as a TF kernel entry point and `llvm.emit_c_interface`
// requests a C-callable wrapper for the lowered function.
func @BitwiseOr_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.BitwiseOr"(%arg0, %arg1)
: (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>
}

View File

@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
namespace tensorflow {
// Registers the MLIR-generated unranked GPU kernels for tf.BitwiseOr for the
// signed integer types. Each invocation maps (op name, MLIR type suffix,
// TensorFlow DataType enum, C++ element type).
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i64, DT_INT64, int64);
// Unsigned types are registered via the old-style REGISTER4 path instead
// (see cwise_op_bitwise_or.cc) until the generator supports them.
// TODO(b/172804967): Enable once fixed.
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui8, DT_UINT8, uint8);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui16, DT_UINT16, uint16);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui32, DT_UINT32, uint32);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui64, DT_UINT64, uint64);
}  // namespace tensorflow