Add unranked kernel definition for bitwise or operation.
PiperOrigin-RevId: 346545955 Change-Id: I287e5dcc8bfe7ec7d1bc3e1fbf11816759fb5332
This commit is contained in:
parent
1a0dd99a6a
commit
84a2d80287
@ -21,8 +21,16 @@ REGISTER8(BinaryOp, CPU, "BitwiseOr", functor::bitwise_or, int8, int16, int32,
|
|||||||
|
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||||
|
!defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
|
||||||
REGISTER8(BinaryOp, GPU, "BitwiseOr", functor::bitwise_or, int8, int16, int32,
|
REGISTER8(BinaryOp, GPU, "BitwiseOr", functor::bitwise_or, int8, int16, int32,
|
||||||
int64, uint8, uint16, uint32, uint64);
|
int64, uint8, uint16, uint32, uint64);
|
||||||
|
#else
|
||||||
|
// TODO(b/172804967): We do not generate unsigned kernels for GPU via mlir.
|
||||||
|
REGISTER4(BinaryOp, GPU, "BitwiseOr", functor::bitwise_or, uint8, uint16,
|
||||||
|
uint32, uint64);
|
||||||
|
#endif // !MLIR_GENERATED_GPU_KERNELS_ENABLED ||
|
||||||
|
// !MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED
|
||||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -159,6 +159,7 @@ tf_kernel_library(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"unranked_op_gpu_add.cc",
|
"unranked_op_gpu_add.cc",
|
||||||
"unranked_op_gpu_bitwise_and.cc",
|
"unranked_op_gpu_bitwise_and.cc",
|
||||||
|
"unranked_op_gpu_bitwise_or.cc",
|
||||||
"unranked_op_gpu_equal.cc",
|
"unranked_op_gpu_equal.cc",
|
||||||
],
|
],
|
||||||
tags = [
|
tags = [
|
||||||
@ -167,6 +168,7 @@ tf_kernel_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":add_v2_unranked_kernels",
|
":add_v2_unranked_kernels",
|
||||||
":bitwise_and_unranked_kernels",
|
":bitwise_and_unranked_kernels",
|
||||||
|
":bitwise_or_unranked_kernels",
|
||||||
":equal_unranked_kernels",
|
":equal_unranked_kernels",
|
||||||
":greater_equal_unranked_kernels",
|
":greater_equal_unranked_kernels",
|
||||||
":greater_unranked_kernels",
|
":greater_unranked_kernels",
|
||||||
@ -375,25 +377,32 @@ gen_kernel_library(
|
|||||||
# unroll_factors = "4",
|
# unroll_factors = "4",
|
||||||
)
|
)
|
||||||
|
|
||||||
gen_kernel_library(
|
# Bitwise operations.
|
||||||
name = "bitwise_and",
|
[
|
||||||
generate_ranked = False,
|
gen_kernel_library(
|
||||||
generate_unranked = True,
|
name = name,
|
||||||
tile_size = "256,1,1",
|
generate_ranked = False,
|
||||||
types = [
|
generate_unranked = True,
|
||||||
"i8",
|
tile_size = "256,1,1",
|
||||||
"i16",
|
types = [
|
||||||
"i32",
|
"i8",
|
||||||
"i64",
|
"i16",
|
||||||
# TODO(b/172804967): Enable once fixed.
|
"i32",
|
||||||
# "ui8",
|
"i64",
|
||||||
# "ui16",
|
# TODO(b/172804967): Enable once fixed.
|
||||||
# "ui32",
|
# "ui8",
|
||||||
# "ui64",
|
# "ui16",
|
||||||
],
|
# "ui32",
|
||||||
# TODO(b/174543802): Enable once fusion heuristics are better.
|
# "ui64",
|
||||||
# unroll_factors = "4",
|
],
|
||||||
)
|
# TODO(b/174543802): Enable once fusion heuristics are better.
|
||||||
|
# unroll_factors = "4",
|
||||||
|
)
|
||||||
|
for name in [
|
||||||
|
"bitwise_and",
|
||||||
|
"bitwise_or",
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
[
|
[
|
||||||
gen_kernel_library(
|
gen_kernel_library(
|
||||||
|
@ -58,6 +58,15 @@ template <typename T,
|
|||||||
absl::optional<T> BitwiseAnd(T /*lhs*/, T /*rhs*/) {
|
absl::optional<T> BitwiseAnd(T /*lhs*/, T /*rhs*/) {
|
||||||
return absl::nullopt;
|
return absl::nullopt;
|
||||||
}
|
}
|
||||||
|
// Baseline for the "BitwiseOr" op on integral types: this overload is only
// enabled when std::is_integral<T> holds and yields the bitwise OR of the two
// operands wrapped in an absl::optional.
template <typename T, std::enable_if_t<std::is_integral<T>::value, bool> = true>
|
||||||
|
absl::optional<T> BitwiseOr(T lhs, T rhs) {
|
||||||
|
return lhs | rhs;
|
||||||
|
}
|
||||||
|
// Fallback overload selected when T is not an integral type: both arguments
// are ignored and absl::nullopt is returned, so a caller can detect that no
// bitwise-OR baseline value exists for this type.
template <typename T,
|
||||||
|
std::enable_if_t<!std::is_integral<T>::value, bool> = true>
|
||||||
|
absl::optional<T> BitwiseOr(T /*lhs*/, T /*rhs*/) {
|
||||||
|
return absl::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
class ParametricGpuBinaryOpsTest
|
class ParametricGpuBinaryOpsTest
|
||||||
: public OpsTestBase,
|
: public OpsTestBase,
|
||||||
@ -316,6 +325,11 @@ class ParametricGpuBinaryOpsTest
|
|||||||
return static_cast<BaselineOutT>(val.value());
|
return static_cast<BaselineOutT>(val.value());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (GetParam().op_name == "BitwiseOr") {
|
||||||
|
if (auto val = BitwiseOr(lhs, rhs)) {
|
||||||
|
return static_cast<BaselineOutT>(val.value());
|
||||||
|
}
|
||||||
|
}
|
||||||
if (GetParam().op_name == "Equal") {
|
if (GetParam().op_name == "Equal") {
|
||||||
return static_cast<BaselineOutT>(lhs == rhs);
|
return static_cast<BaselineOutT>(lhs == rhs);
|
||||||
}
|
}
|
||||||
@ -339,6 +353,7 @@ std::vector<BinaryTestParam> GetBinaryTestParameters() {
|
|||||||
for (DataType dt :
|
for (DataType dt :
|
||||||
std::vector<DataType>{DT_INT8, DT_INT16, DT_INT32, DT_INT64}) {
|
std::vector<DataType>{DT_INT8, DT_INT16, DT_INT32, DT_INT64}) {
|
||||||
parameters.emplace_back("BitwiseAnd", dt, dt);
|
parameters.emplace_back("BitwiseAnd", dt, dt);
|
||||||
|
parameters.emplace_back("BitwiseOr", dt, dt);
|
||||||
}
|
}
|
||||||
for (DataType dt :
|
for (DataType dt :
|
||||||
std::vector<DataType>{DT_FLOAT, DT_DOUBLE, DT_HALF, DT_BOOL, DT_INT8,
|
std::vector<DataType>{DT_FLOAT, DT_DOUBLE, DT_HALF, DT_BOOL, DT_INT8,
|
||||||
|
@ -0,0 +1,6 @@
|
|||||||
|
// Entry point taking two unranked tensors, applying "tf.BitwiseOr" to them and
// returning the result as an unranked tensor of the same element type.
// NOTE(review): `elem_type` looks like a placeholder substituted per generated
// kernel type - confirm against the kernel generator.
func @BitwiseOr_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
|
||||||
|
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
|
||||||
|
%0 = "tf.BitwiseOr"(%arg0, %arg1)
|
||||||
|
: (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
|
||||||
|
return %0 : tensor<*xelem_type>
|
||||||
|
}
|
@ -0,0 +1,31 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Invokes GENERATE_AND_REGISTER_BINARY_KERNEL for BitwiseOr once per signed
// integer type (i8, i16, i32, i64). The unsigned variants further down stay
// commented out until b/172804967 is resolved.
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i8, DT_INT8, int8);
|
||||||
|
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i16, DT_INT16, int16);
|
||||||
|
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i32, DT_INT32, int32);
|
||||||
|
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i64, DT_INT64, int64);
|
||||||
|
|
||||||
|
// TODO(b/172804967): Enable once fixed.
|
||||||
|
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui8, DT_UINT8, uint8);
|
||||||
|
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui16, DT_UINT16, uint16);
|
||||||
|
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui32, DT_UINT32, uint32);
|
||||||
|
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui64, DT_UINT64, uint64);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user