[XLA] Fix bug in functions comparing complex values: avoid double decrement of the number of mismatches when only the real parts of the values are different.

PiperOrigin-RevId: 239736734
This commit is contained in:
A. Unique TensorFlower 2019-03-21 22:04:35 -07:00 committed by TensorFlower Gardener
parent 0ec5db4091
commit 1f6e793504
2 changed files with 160 additions and 30 deletions

View File

@ -459,48 +459,30 @@ class NearComparator {
// For complex types, we compare real and imaginary parts individually.
void CompareValues(complex64 expected, complex64 actual, int64 linear_index) {
bool mismatch = false;
const auto both_parts_mismatch = num_mismatches_ + 2;
CompareValues<float>(expected.real(), actual.real(), linear_index);
if (mismatches_.data<bool>()[linear_index] == true) {
mismatch = true;
// Delay the mismatch count increase for real part, instead increase
// mismatch by 1 for the entire complex number.
num_mismatches_--;
}
CompareValues<float>(expected.imag(), actual.imag(), linear_index);
if (mismatches_.data<bool>()[linear_index] == true) {
mismatch = true;
// Delay the mismatch count increase for imag part, instead increase
// mismatch by 1 for the entire complex number.
if (num_mismatches_ == both_parts_mismatch) {
// The mismatch counter had been incremented by each CompareValues() call,
// which means that both real and imaginary parts of the passed-in complex
// values are different. However, the counter should reflect a single
// mismatch between these complex values.
num_mismatches_--;
}
if (mismatch == true) {
num_mismatches_++;
}
mismatches_.data<bool>()[linear_index] = mismatch;
}
void CompareValues(complex128 expected, complex128 actual,
int64 linear_index) {
bool mismatch = false;
const auto both_parts_mismatch = num_mismatches_ + 2;
CompareValues<double>(expected.real(), actual.real(), linear_index);
if (mismatches_.data<bool>()[linear_index] == true) {
mismatch = true;
// Delay the mismatch count increase for real part, instead increase
// mismatch by 1 for the entire complex number.
num_mismatches_--;
}
CompareValues<double>(expected.imag(), actual.imag(), linear_index);
if (mismatches_.data<bool>()[linear_index] == true) {
mismatch = true;
// Delay the mismatch count increase for imag part, instead increase
// mismatch by 1 for the entire complex number.
if (num_mismatches_ == both_parts_mismatch) {
// The mismatch counter had been incremented by each CompareValues() call,
// which means that both real and imaginary parts of the passed-in complex
// values are different. However, the counter should reflect a single
// mismatch between these complex values.
num_mismatches_--;
}
if (mismatch == true) {
num_mismatches_++;
}
mismatches_.data<bool>()[linear_index] = mismatch;
}
// Compares the two literals elementwise.

View File

@ -38,6 +38,68 @@ TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal));
}
TEST(LiteralTestUtilTest, ComparesEqualComplex64TuplesEqual) {
Literal literal = LiteralUtil::MakeTupleFromSlices({
LiteralUtil::CreateR0<complex64>({42.0, 64.0}),
LiteralUtil::CreateR0<complex64>({64.0, 42.0}),
});
EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal));
}
TEST(LiteralTestUtilTest, ComparesEqualComplex128TuplesEqual) {
Literal literal = LiteralUtil::MakeTupleFromSlices({
LiteralUtil::CreateR0<complex128>({42.0, 64.0}),
LiteralUtil::CreateR0<complex128>({64.0, 42.0}),
});
EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal));
}
TEST(LiteralTestUtilTest, ComparesUnequalComplex64TuplesUnequal) {
Literal literal0 = LiteralUtil::MakeTupleFromSlices({
LiteralUtil::CreateR0<complex64>({42.0, 64.0}),
LiteralUtil::CreateR0<complex64>({64.0, 42.0}),
});
Literal literal1 = LiteralUtil::MakeTupleFromSlices({
LiteralUtil::CreateR0<complex64>({64.0, 42.0}),
LiteralUtil::CreateR0<complex64>({42.0, 64.0}),
});
Literal literal2 = LiteralUtil::MakeTupleFromSlices({
LiteralUtil::CreateR0<complex64>({42.42, 64.0}),
LiteralUtil::CreateR0<complex64>({64.0, 42.0}),
});
Literal literal3 = LiteralUtil::MakeTupleFromSlices({
LiteralUtil::CreateR0<complex64>({42.0, 64.0}),
LiteralUtil::CreateR0<complex64>({64.0, 42.42}),
});
EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal1));
EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal2));
EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal3));
EXPECT_FALSE(LiteralTestUtil::Equal(literal2, literal3));
}
TEST(LiteralTestUtilTest, ComparesUnequalComplex128TuplesUnequal) {
Literal literal0 = LiteralUtil::MakeTupleFromSlices({
LiteralUtil::CreateR0<complex128>({42.0, 64.0}),
LiteralUtil::CreateR0<complex128>({64.0, 42.0}),
});
Literal literal1 = LiteralUtil::MakeTupleFromSlices({
LiteralUtil::CreateR0<complex128>({64.0, 42.0}),
LiteralUtil::CreateR0<complex128>({42.0, 64.0}),
});
Literal literal2 = LiteralUtil::MakeTupleFromSlices({
LiteralUtil::CreateR0<complex128>({42.42, 64.0}),
LiteralUtil::CreateR0<complex128>({64.0, 42.0}),
});
Literal literal3 = LiteralUtil::MakeTupleFromSlices({
LiteralUtil::CreateR0<complex128>({42.0, 64.0}),
LiteralUtil::CreateR0<complex128>({64.0, 42.42}),
});
EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal1));
EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal2));
EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal3));
EXPECT_FALSE(LiteralTestUtil::Equal(literal2, literal3));
}
TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
// Implementation note: we have to use a death test here, because you can't
// un-fail an assertion failure. The CHECK-failure is death, so we can make a
@ -118,6 +180,92 @@ TEST(LiteralTestUtilTest, NearComparatorR1) {
EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
}
TEST(LiteralTestUtilTest, NearComparatorR1Complex64) {
auto a = LiteralUtil::CreateR1<complex64>({{0.0, 1.0},
{0.1, 1.1},
{0.2, 1.2},
{0.3, 1.3},
{0.4, 1.4},
{0.5, 1.5},
{0.6, 1.6},
{0.7, 1.7},
{0.8, 1.8}});
auto b = LiteralUtil::CreateR1<complex64>({{0.0, 1.0},
{0.1, 1.1},
{0.2, 1.2},
{0.3, 1.3},
{0.4, 1.4},
{0.5, 1.5},
{0.6, 1.6},
{0.7, 1.7},
{0.8, 1.8}});
auto c = LiteralUtil::CreateR1<complex64>({{0.0, 1.0},
{0.1, 1.1},
{0.2, 1.2},
{0.3, 1.3},
{0.4, 1.4},
{0.5, 1.5},
{0.6, 1.6},
{0.7, 1.7},
{0.9, 1.8}});
auto d = LiteralUtil::CreateR1<complex64>({{0.0, 1.0},
{0.1, 1.1},
{0.2, 1.2},
{0.3, 1.3},
{0.4, 1.4},
{0.5, 1.5},
{0.6, 1.6},
{0.7, 1.7},
{0.8, 1.9}});
EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
EXPECT_FALSE(LiteralTestUtil::Near(a, c, ErrorSpec{0.0001}));
EXPECT_FALSE(LiteralTestUtil::Near(a, d, ErrorSpec{0.0001}));
EXPECT_FALSE(LiteralTestUtil::Near(c, d, ErrorSpec{0.0001}));
}
TEST(LiteralTestUtilTest, NearComparatorR1Complex128) {
auto a = LiteralUtil::CreateR1<complex128>({{0.0, 1.0},
{0.1, 1.1},
{0.2, 1.2},
{0.3, 1.3},
{0.4, 1.4},
{0.5, 1.5},
{0.6, 1.6},
{0.7, 1.7},
{0.8, 1.8}});
auto b = LiteralUtil::CreateR1<complex128>({{0.0, 1.0},
{0.1, 1.1},
{0.2, 1.2},
{0.3, 1.3},
{0.4, 1.4},
{0.5, 1.5},
{0.6, 1.6},
{0.7, 1.7},
{0.8, 1.8}});
auto c = LiteralUtil::CreateR1<complex128>({{0.0, 1.0},
{0.1, 1.1},
{0.2, 1.2},
{0.3, 1.3},
{0.4, 1.4},
{0.5, 1.5},
{0.6, 1.6},
{0.7, 1.7},
{0.9, 1.8}});
auto d = LiteralUtil::CreateR1<complex128>({{0.0, 1.0},
{0.1, 1.1},
{0.2, 1.2},
{0.3, 1.3},
{0.4, 1.4},
{0.5, 1.5},
{0.6, 1.6},
{0.7, 1.7},
{0.8, 1.9}});
EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
EXPECT_FALSE(LiteralTestUtil::Near(a, c, ErrorSpec{0.0001}));
EXPECT_FALSE(LiteralTestUtil::Near(a, d, ErrorSpec{0.0001}));
EXPECT_FALSE(LiteralTestUtil::Near(c, d, ErrorSpec{0.0001}));
}
TEST(LiteralTestUtilTest, NearComparatorR1Nan) {
auto a = LiteralUtil::CreateR1<float>(
{0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});