Add uint16, uint32, uint64 support for tf.math.equal

This PR tries to address the issue raised in 26069 where
tf.math.equal does not suport basic data types such as
uint16, uint32, and uint64.

While there might be some restrictions on comparision (e.g. >, <, etc)
for certain data types due to CPU or GPU, the comparision
of basic data types such as uint16, uint32, uint64 are very much
simple operation across the board. They are important in many
ops as well.

For that reason, it makes sense to make sure at least all basic
data types support `equal`.

This PR adds the missing uint16, uint32, uint64 support for tf.math.equal

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2020-04-06 21:07:14 +00:00
parent 7d0a91fc48
commit 928ff3e27b
2 changed files with 3 additions and 2 deletions

View File

@ -18,6 +18,7 @@ limitations under the License.
namespace tensorflow {
REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, double,
uint8, int8, int16, bfloat16);
REGISTER3(BinaryOp, CPU, "Equal", functor::equal_to, uint16, uint32, uint64);
REGISTER_KERNEL_BUILDER(
Name("ApproximateEqual").Device(DEVICE_CPU).TypeConstraint<float>("T"),
ApproximateEqualOp<CPUDevice, float>);

View File

@ -717,8 +717,8 @@ REGISTER_OP("GreaterEqual").COMPARISON();
.SetIsCommutative() \
.Attr( \
"T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \
"int64, complex64, quint8, qint8, qint32, string, bool, " \
"complex128}") \
"int64, uint16, uint32, uint64, complex64, " \
"quint8, qint8, qint32, string, bool, complex128}") \
.Attr("incompatible_shape_error: bool = true") \
.SetShapeFn([](InferenceContext* c) { \
ShapeHandle x = c->input(0); \