[XLA] Fix bug in functions comparing complex values: avoid double decrement of the number of mismatches when only the real parts of the values are different.
PiperOrigin-RevId: 239736734
This commit is contained in:
parent
0ec5db4091
commit
1f6e793504
@ -459,48 +459,30 @@ class NearComparator {
|
||||
|
||||
// For complex types, we compare real and imaginary parts individually.
|
||||
void CompareValues(complex64 expected, complex64 actual, int64 linear_index) {
|
||||
bool mismatch = false;
|
||||
const auto both_parts_mismatch = num_mismatches_ + 2;
|
||||
CompareValues<float>(expected.real(), actual.real(), linear_index);
|
||||
if (mismatches_.data<bool>()[linear_index] == true) {
|
||||
mismatch = true;
|
||||
// Delay the mismatch count increase for real part, instead increase
|
||||
// mismatch by 1 for the entire complex number.
|
||||
num_mismatches_--;
|
||||
}
|
||||
CompareValues<float>(expected.imag(), actual.imag(), linear_index);
|
||||
if (mismatches_.data<bool>()[linear_index] == true) {
|
||||
mismatch = true;
|
||||
// Delay the mismatch count increase for imag part, instead increase
|
||||
// mismatch by 1 for the entire complex number.
|
||||
if (num_mismatches_ == both_parts_mismatch) {
|
||||
// The mismatch counter had been incremented by each CompareValues() call,
|
||||
// which means that both real and imaginary parts of the passed-in complex
|
||||
// values are different. However, the counter should reflect a single
|
||||
// mismatch between these complex values.
|
||||
num_mismatches_--;
|
||||
}
|
||||
if (mismatch == true) {
|
||||
num_mismatches_++;
|
||||
}
|
||||
mismatches_.data<bool>()[linear_index] = mismatch;
|
||||
}
|
||||
|
||||
void CompareValues(complex128 expected, complex128 actual,
|
||||
int64 linear_index) {
|
||||
bool mismatch = false;
|
||||
const auto both_parts_mismatch = num_mismatches_ + 2;
|
||||
CompareValues<double>(expected.real(), actual.real(), linear_index);
|
||||
if (mismatches_.data<bool>()[linear_index] == true) {
|
||||
mismatch = true;
|
||||
// Delay the mismatch count increase for real part, instead increase
|
||||
// mismatch by 1 for the entire complex number.
|
||||
num_mismatches_--;
|
||||
}
|
||||
CompareValues<double>(expected.imag(), actual.imag(), linear_index);
|
||||
if (mismatches_.data<bool>()[linear_index] == true) {
|
||||
mismatch = true;
|
||||
// Delay the mismatch count increase for imag part, instead increase
|
||||
// mismatch by 1 for the entire complex number.
|
||||
if (num_mismatches_ == both_parts_mismatch) {
|
||||
// The mismatch counter had been incremented by each CompareValues() call,
|
||||
// which means that both real and imaginary parts of the passed-in complex
|
||||
// values are different. However, the counter should reflect a single
|
||||
// mismatch between these complex values.
|
||||
num_mismatches_--;
|
||||
}
|
||||
if (mismatch == true) {
|
||||
num_mismatches_++;
|
||||
}
|
||||
mismatches_.data<bool>()[linear_index] = mismatch;
|
||||
}
|
||||
|
||||
// Compares the two literals elementwise.
|
||||
|
@ -38,6 +38,68 @@ TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal));
|
||||
}
|
||||
|
||||
TEST(LiteralTestUtilTest, ComparesEqualComplex64TuplesEqual) {
|
||||
Literal literal = LiteralUtil::MakeTupleFromSlices({
|
||||
LiteralUtil::CreateR0<complex64>({42.0, 64.0}),
|
||||
LiteralUtil::CreateR0<complex64>({64.0, 42.0}),
|
||||
});
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal));
|
||||
}
|
||||
|
||||
TEST(LiteralTestUtilTest, ComparesEqualComplex128TuplesEqual) {
|
||||
Literal literal = LiteralUtil::MakeTupleFromSlices({
|
||||
LiteralUtil::CreateR0<complex128>({42.0, 64.0}),
|
||||
LiteralUtil::CreateR0<complex128>({64.0, 42.0}),
|
||||
});
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal));
|
||||
}
|
||||
|
||||
TEST(LiteralTestUtilTest, ComparesUnequalComplex64TuplesUnequal) {
|
||||
Literal literal0 = LiteralUtil::MakeTupleFromSlices({
|
||||
LiteralUtil::CreateR0<complex64>({42.0, 64.0}),
|
||||
LiteralUtil::CreateR0<complex64>({64.0, 42.0}),
|
||||
});
|
||||
Literal literal1 = LiteralUtil::MakeTupleFromSlices({
|
||||
LiteralUtil::CreateR0<complex64>({64.0, 42.0}),
|
||||
LiteralUtil::CreateR0<complex64>({42.0, 64.0}),
|
||||
});
|
||||
Literal literal2 = LiteralUtil::MakeTupleFromSlices({
|
||||
LiteralUtil::CreateR0<complex64>({42.42, 64.0}),
|
||||
LiteralUtil::CreateR0<complex64>({64.0, 42.0}),
|
||||
});
|
||||
Literal literal3 = LiteralUtil::MakeTupleFromSlices({
|
||||
LiteralUtil::CreateR0<complex64>({42.0, 64.0}),
|
||||
LiteralUtil::CreateR0<complex64>({64.0, 42.42}),
|
||||
});
|
||||
EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal1));
|
||||
EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal2));
|
||||
EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal3));
|
||||
EXPECT_FALSE(LiteralTestUtil::Equal(literal2, literal3));
|
||||
}
|
||||
|
||||
TEST(LiteralTestUtilTest, ComparesUnequalComplex128TuplesUnequal) {
|
||||
Literal literal0 = LiteralUtil::MakeTupleFromSlices({
|
||||
LiteralUtil::CreateR0<complex128>({42.0, 64.0}),
|
||||
LiteralUtil::CreateR0<complex128>({64.0, 42.0}),
|
||||
});
|
||||
Literal literal1 = LiteralUtil::MakeTupleFromSlices({
|
||||
LiteralUtil::CreateR0<complex128>({64.0, 42.0}),
|
||||
LiteralUtil::CreateR0<complex128>({42.0, 64.0}),
|
||||
});
|
||||
Literal literal2 = LiteralUtil::MakeTupleFromSlices({
|
||||
LiteralUtil::CreateR0<complex128>({42.42, 64.0}),
|
||||
LiteralUtil::CreateR0<complex128>({64.0, 42.0}),
|
||||
});
|
||||
Literal literal3 = LiteralUtil::MakeTupleFromSlices({
|
||||
LiteralUtil::CreateR0<complex128>({42.0, 64.0}),
|
||||
LiteralUtil::CreateR0<complex128>({64.0, 42.42}),
|
||||
});
|
||||
EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal1));
|
||||
EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal2));
|
||||
EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal3));
|
||||
EXPECT_FALSE(LiteralTestUtil::Equal(literal2, literal3));
|
||||
}
|
||||
|
||||
TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
|
||||
// Implementation note: we have to use a death test here, because you can't
|
||||
// un-fail an assertion failure. The CHECK-failure is death, so we can make a
|
||||
@ -118,6 +180,92 @@ TEST(LiteralTestUtilTest, NearComparatorR1) {
|
||||
EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
|
||||
}
|
||||
|
||||
TEST(LiteralTestUtilTest, NearComparatorR1Complex64) {
|
||||
auto a = LiteralUtil::CreateR1<complex64>({{0.0, 1.0},
|
||||
{0.1, 1.1},
|
||||
{0.2, 1.2},
|
||||
{0.3, 1.3},
|
||||
{0.4, 1.4},
|
||||
{0.5, 1.5},
|
||||
{0.6, 1.6},
|
||||
{0.7, 1.7},
|
||||
{0.8, 1.8}});
|
||||
auto b = LiteralUtil::CreateR1<complex64>({{0.0, 1.0},
|
||||
{0.1, 1.1},
|
||||
{0.2, 1.2},
|
||||
{0.3, 1.3},
|
||||
{0.4, 1.4},
|
||||
{0.5, 1.5},
|
||||
{0.6, 1.6},
|
||||
{0.7, 1.7},
|
||||
{0.8, 1.8}});
|
||||
auto c = LiteralUtil::CreateR1<complex64>({{0.0, 1.0},
|
||||
{0.1, 1.1},
|
||||
{0.2, 1.2},
|
||||
{0.3, 1.3},
|
||||
{0.4, 1.4},
|
||||
{0.5, 1.5},
|
||||
{0.6, 1.6},
|
||||
{0.7, 1.7},
|
||||
{0.9, 1.8}});
|
||||
auto d = LiteralUtil::CreateR1<complex64>({{0.0, 1.0},
|
||||
{0.1, 1.1},
|
||||
{0.2, 1.2},
|
||||
{0.3, 1.3},
|
||||
{0.4, 1.4},
|
||||
{0.5, 1.5},
|
||||
{0.6, 1.6},
|
||||
{0.7, 1.7},
|
||||
{0.8, 1.9}});
|
||||
EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
|
||||
EXPECT_FALSE(LiteralTestUtil::Near(a, c, ErrorSpec{0.0001}));
|
||||
EXPECT_FALSE(LiteralTestUtil::Near(a, d, ErrorSpec{0.0001}));
|
||||
EXPECT_FALSE(LiteralTestUtil::Near(c, d, ErrorSpec{0.0001}));
|
||||
}
|
||||
|
||||
TEST(LiteralTestUtilTest, NearComparatorR1Complex128) {
|
||||
auto a = LiteralUtil::CreateR1<complex128>({{0.0, 1.0},
|
||||
{0.1, 1.1},
|
||||
{0.2, 1.2},
|
||||
{0.3, 1.3},
|
||||
{0.4, 1.4},
|
||||
{0.5, 1.5},
|
||||
{0.6, 1.6},
|
||||
{0.7, 1.7},
|
||||
{0.8, 1.8}});
|
||||
auto b = LiteralUtil::CreateR1<complex128>({{0.0, 1.0},
|
||||
{0.1, 1.1},
|
||||
{0.2, 1.2},
|
||||
{0.3, 1.3},
|
||||
{0.4, 1.4},
|
||||
{0.5, 1.5},
|
||||
{0.6, 1.6},
|
||||
{0.7, 1.7},
|
||||
{0.8, 1.8}});
|
||||
auto c = LiteralUtil::CreateR1<complex128>({{0.0, 1.0},
|
||||
{0.1, 1.1},
|
||||
{0.2, 1.2},
|
||||
{0.3, 1.3},
|
||||
{0.4, 1.4},
|
||||
{0.5, 1.5},
|
||||
{0.6, 1.6},
|
||||
{0.7, 1.7},
|
||||
{0.9, 1.8}});
|
||||
auto d = LiteralUtil::CreateR1<complex128>({{0.0, 1.0},
|
||||
{0.1, 1.1},
|
||||
{0.2, 1.2},
|
||||
{0.3, 1.3},
|
||||
{0.4, 1.4},
|
||||
{0.5, 1.5},
|
||||
{0.6, 1.6},
|
||||
{0.7, 1.7},
|
||||
{0.8, 1.9}});
|
||||
EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
|
||||
EXPECT_FALSE(LiteralTestUtil::Near(a, c, ErrorSpec{0.0001}));
|
||||
EXPECT_FALSE(LiteralTestUtil::Near(a, d, ErrorSpec{0.0001}));
|
||||
EXPECT_FALSE(LiteralTestUtil::Near(c, d, ErrorSpec{0.0001}));
|
||||
}
|
||||
|
||||
TEST(LiteralTestUtilTest, NearComparatorR1Nan) {
|
||||
auto a = LiteralUtil::CreateR1<float>(
|
||||
{0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
|
||||
|
Loading…
Reference in New Issue
Block a user