diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 9b3de75dd4e..758dba1262e 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -459,48 +459,30 @@ class NearComparator { // For complex types, we compare real and imaginary parts individually. void CompareValues(complex64 expected, complex64 actual, int64 linear_index) { - bool mismatch = false; + const auto both_parts_mismatch = num_mismatches_ + 2; CompareValues(expected.real(), actual.real(), linear_index); - if (mismatches_.data()[linear_index] == true) { - mismatch = true; - // Delay the mismatch count increase for real part, instead increase - // mismatch by 1 for the entire complex number. - num_mismatches_--; - } CompareValues(expected.imag(), actual.imag(), linear_index); - if (mismatches_.data()[linear_index] == true) { - mismatch = true; - // Delay the mismatch count increase for imag part, instead increase - // mismatch by 1 for the entire complex number. + if (num_mismatches_ == both_parts_mismatch) { + // The mismatch counter had been incremented by each CompareValues() call, + // which means that both real and imaginary parts of the passed-in complex + // values are different. However, the counter should reflect a single + // mismatch between these complex values. num_mismatches_--; } - if (mismatch == true) { - num_mismatches_++; - } - mismatches_.data()[linear_index] = mismatch; } void CompareValues(complex128 expected, complex128 actual, int64 linear_index) { - bool mismatch = false; + const auto both_parts_mismatch = num_mismatches_ + 2; CompareValues(expected.real(), actual.real(), linear_index); - if (mismatches_.data()[linear_index] == true) { - mismatch = true; - // Delay the mismatch count increase for real part, instead increase - // mismatch by 1 for the entire complex number. - num_mismatches_--; - } CompareValues(expected.imag(), actual.imag(), linear_index); - if (mismatches_.data()[linear_index] == true) { - mismatch = true; - // Delay the mismatch count increase for imag part, instead increase - // mismatch by 1 for the entire complex number. + if (num_mismatches_ == both_parts_mismatch) { + // The mismatch counter had been incremented by each CompareValues() call, + // which means that both real and imaginary parts of the passed-in complex + // values are different. However, the counter should reflect a single + // mismatch between these complex values. num_mismatches_--; } - if (mismatch == true) { - num_mismatches_++; - } - mismatches_.data()[linear_index] = mismatch; } // Compares the two literals elementwise. diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index ea9b3037cf4..cc55d1c7405 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -38,6 +38,68 @@ TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal)); } +TEST(LiteralTestUtilTest, ComparesEqualComplex64TuplesEqual) { + Literal literal = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({42.0, 64.0}), + LiteralUtil::CreateR0({64.0, 42.0}), + }); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal)); +} + +TEST(LiteralTestUtilTest, ComparesEqualComplex128TuplesEqual) { + Literal literal = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({42.0, 64.0}), + LiteralUtil::CreateR0({64.0, 42.0}), + }); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal)); +} + +TEST(LiteralTestUtilTest, ComparesUnequalComplex64TuplesUnequal) { + Literal literal0 = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({42.0, 64.0}), + LiteralUtil::CreateR0({64.0, 42.0}), + }); + Literal literal1 = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({64.0, 42.0}), + LiteralUtil::CreateR0({42.0, 64.0}), + }); + Literal literal2 = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({42.42, 64.0}), + LiteralUtil::CreateR0({64.0, 42.0}), + }); + Literal literal3 = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({42.0, 64.0}), + LiteralUtil::CreateR0({64.0, 42.42}), + }); + EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal1)); + EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal2)); + EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal3)); + EXPECT_FALSE(LiteralTestUtil::Equal(literal2, literal3)); +} + +TEST(LiteralTestUtilTest, ComparesUnequalComplex128TuplesUnequal) { + Literal literal0 = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({42.0, 64.0}), + LiteralUtil::CreateR0({64.0, 42.0}), + }); + Literal literal1 = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({64.0, 42.0}), + LiteralUtil::CreateR0({42.0, 64.0}), + }); + Literal literal2 = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({42.42, 64.0}), + LiteralUtil::CreateR0({64.0, 42.0}), + }); + Literal literal3 = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({42.0, 64.0}), + LiteralUtil::CreateR0({64.0, 42.42}), + }); + EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal1)); + EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal2)); + EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal3)); + EXPECT_FALSE(LiteralTestUtil::Equal(literal2, literal3)); +} + TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { // Implementation note: we have to use a death test here, because you can't // un-fail an assertion failure. The CHECK-failure is death, so we can make a @@ -118,6 +180,92 @@ TEST(LiteralTestUtilTest, NearComparatorR1) { EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); } +TEST(LiteralTestUtilTest, NearComparatorR1Complex64) { + auto a = LiteralUtil::CreateR1({{0.0, 1.0}, + {0.1, 1.1}, + {0.2, 1.2}, + {0.3, 1.3}, + {0.4, 1.4}, + {0.5, 1.5}, + {0.6, 1.6}, + {0.7, 1.7}, + {0.8, 1.8}}); + auto b = LiteralUtil::CreateR1({{0.0, 1.0}, + {0.1, 1.1}, + {0.2, 1.2}, + {0.3, 1.3}, + {0.4, 1.4}, + {0.5, 1.5}, + {0.6, 1.6}, + {0.7, 1.7}, + {0.8, 1.8}}); + auto c = LiteralUtil::CreateR1({{0.0, 1.0}, + {0.1, 1.1}, + {0.2, 1.2}, + {0.3, 1.3}, + {0.4, 1.4}, + {0.5, 1.5}, + {0.6, 1.6}, + {0.7, 1.7}, + {0.9, 1.8}}); + auto d = LiteralUtil::CreateR1({{0.0, 1.0}, + {0.1, 1.1}, + {0.2, 1.2}, + {0.3, 1.3}, + {0.4, 1.4}, + {0.5, 1.5}, + {0.6, 1.6}, + {0.7, 1.7}, + {0.8, 1.9}}); + EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(a, c, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(a, d, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(c, d, ErrorSpec{0.0001})); +} + +TEST(LiteralTestUtilTest, NearComparatorR1Complex128) { + auto a = LiteralUtil::CreateR1({{0.0, 1.0}, + {0.1, 1.1}, + {0.2, 1.2}, + {0.3, 1.3}, + {0.4, 1.4}, + {0.5, 1.5}, + {0.6, 1.6}, + {0.7, 1.7}, + {0.8, 1.8}}); + auto b = LiteralUtil::CreateR1({{0.0, 1.0}, + {0.1, 1.1}, + {0.2, 1.2}, + {0.3, 1.3}, + {0.4, 1.4}, + {0.5, 1.5}, + {0.6, 1.6}, + {0.7, 1.7}, + {0.8, 1.8}}); + auto c = LiteralUtil::CreateR1({{0.0, 1.0}, + {0.1, 1.1}, + {0.2, 1.2}, + {0.3, 1.3}, + {0.4, 1.4}, + {0.5, 1.5}, + {0.6, 1.6}, + {0.7, 1.7}, + {0.9, 1.8}}); + auto d = LiteralUtil::CreateR1({{0.0, 1.0}, + {0.1, 1.1}, + {0.2, 1.2}, + {0.3, 1.3}, + {0.4, 1.4}, + {0.5, 1.5}, + {0.6, 1.6}, + {0.7, 1.7}, + {0.8, 1.9}}); + EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(a, c, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(a, d, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(c, d, ErrorSpec{0.0001})); +} + TEST(LiteralTestUtilTest, NearComparatorR1Nan) { auto a = LiteralUtil::CreateR1( {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});