Add uint16, uint32, uint64 support for tf.math.equal
This PR tries to address the issue raised in 26069 where tf.math.equal does not suport basic data types such as uint16, uint32, and uint64. While there might be some restrictions on comparision (e.g. >, <, etc) for certain data types due to CPU or GPU, the comparision of basic data types such as uint16, uint32, uint64 are very much simple operation across the board. They are important in many ops as well. For that reason, it makes sense to make sure at least all basic data types support `equal`. This PR adds the missing uint16, uint32, uint64 support for tf.math.equal Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
7d0a91fc48
commit
928ff3e27b
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||
namespace tensorflow {
|
||||
REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, double,
|
||||
uint8, int8, int16, bfloat16);
|
||||
REGISTER3(BinaryOp, CPU, "Equal", functor::equal_to, uint16, uint32, uint64);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ApproximateEqual").Device(DEVICE_CPU).TypeConstraint<float>("T"),
|
||||
ApproximateEqualOp<CPUDevice, float>);
|
||||
|
|
|
@ -717,8 +717,8 @@ REGISTER_OP("GreaterEqual").COMPARISON();
|
|||
.SetIsCommutative() \
|
||||
.Attr( \
|
||||
"T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \
|
||||
"int64, complex64, quint8, qint8, qint32, string, bool, " \
|
||||
"complex128}") \
|
||||
"int64, uint16, uint32, uint64, complex64, " \
|
||||
"quint8, qint8, qint32, string, bool, complex128}") \
|
||||
.Attr("incompatible_shape_error: bool = true") \
|
||||
.SetShapeFn([](InferenceContext* c) { \
|
||||
ShapeHandle x = c->input(0); \
|
||||
|
|
Loading…
Reference in New Issue