Merge pull request #38288 from yongtang:26069-equal-dtype
PiperOrigin-RevId: 307638517 Change-Id: I57e294157e5ead1a285009994ecdb90b7577a232
This commit is contained in:
commit
9c925a52e8
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, double,
|
REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, double,
|
||||||
uint8, int8, int16, bfloat16);
|
uint8, int8, int16, bfloat16);
|
||||||
|
REGISTER3(BinaryOp, CPU, "Equal", functor::equal_to, uint16, uint32, uint64);
|
||||||
REGISTER_KERNEL_BUILDER(
|
REGISTER_KERNEL_BUILDER(
|
||||||
Name("ApproximateEqual").Device(DEVICE_CPU).TypeConstraint<float>("T"),
|
Name("ApproximateEqual").Device(DEVICE_CPU).TypeConstraint<float>("T"),
|
||||||
ApproximateEqualOp<CPUDevice, float>);
|
ApproximateEqualOp<CPUDevice, float>);
|
||||||
|
|
|
@ -18,6 +18,8 @@ limitations under the License.
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
|
REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
|
||||||
double, uint8, int8, int16, bfloat16);
|
double, uint8, int8, int16, bfloat16);
|
||||||
|
REGISTER3(BinaryOp, CPU, "NotEqual", functor::not_equal_to, uint16, uint32,
|
||||||
|
uint64);
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
REGISTER4(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
|
REGISTER4(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
|
||||||
double, uint8);
|
double, uint8);
|
||||||
|
|
|
@ -717,8 +717,8 @@ REGISTER_OP("GreaterEqual").COMPARISON();
|
||||||
.SetIsCommutative() \
|
.SetIsCommutative() \
|
||||||
.Attr( \
|
.Attr( \
|
||||||
"T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \
|
"T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \
|
||||||
"int64, complex64, quint8, qint8, qint32, string, bool, " \
|
"int64, uint16, uint32, uint64, complex64, " \
|
||||||
"complex128}") \
|
"quint8, qint8, qint32, string, bool, complex128}") \
|
||||||
.Attr("incompatible_shape_error: bool = true") \
|
.Attr("incompatible_shape_error: bool = true") \
|
||||||
.SetShapeFn([](InferenceContext* c) { \
|
.SetShapeFn([](InferenceContext* c) { \
|
||||||
ShapeHandle x = c->input(0); \
|
ShapeHandle x = c->input(0); \
|
||||||
|
|
|
@ -953,6 +953,44 @@ class ComparisonOpTest(test.TestCase):
|
||||||
"Incompatible shapes|Dimensions must be equal"):
|
"Incompatible shapes|Dimensions must be equal"):
|
||||||
f(x.astype(t), y.astype(t))
|
f(x.astype(t), y.astype(t))
|
||||||
|
|
||||||
|
def testEqualDType(self):
|
||||||
|
dtypes = [
|
||||||
|
np.float16,
|
||||||
|
np.float32,
|
||||||
|
np.float64,
|
||||||
|
np.int8,
|
||||||
|
np.int16,
|
||||||
|
np.int32,
|
||||||
|
np.int64,
|
||||||
|
np.uint8,
|
||||||
|
np.uint16,
|
||||||
|
np.uint32,
|
||||||
|
np.uint64,
|
||||||
|
np.bool,
|
||||||
|
]
|
||||||
|
x = np.asarray([0, 1, 2, 3, 4])
|
||||||
|
y = np.asarray([0, 1, 2, 3, 4])
|
||||||
|
for dtype in dtypes:
|
||||||
|
xt = x.astype(dtype)
|
||||||
|
yt = y.astype(dtype)
|
||||||
|
cmp_eq = math_ops.equal(xt, yt)
|
||||||
|
cmp_ne = math_ops.not_equal(xt, yt)
|
||||||
|
values = self.evaluate([cmp_eq, cmp_ne])
|
||||||
|
self.assertAllEqual(
|
||||||
|
[[True, True, True, True, True], [False, False, False, False, False]],
|
||||||
|
values)
|
||||||
|
for dtype in [np.complex64, np.complex128]:
|
||||||
|
xt = x.astype(dtype)
|
||||||
|
xt -= 1j * xt
|
||||||
|
yt = y.astype(dtype)
|
||||||
|
yt -= 1j * yt
|
||||||
|
cmp_eq = math_ops.equal(xt, yt)
|
||||||
|
cmp_ne = math_ops.not_equal(xt, yt)
|
||||||
|
values = self.evaluate([cmp_eq, cmp_ne])
|
||||||
|
self.assertAllEqual(
|
||||||
|
[[True, True, True, True, True], [False, False, False, False, False]],
|
||||||
|
values)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
|
Loading…
Reference in New Issue