Modify assertAllClose to handle dicts.
Change: 153402607
This commit is contained in:
parent
4b046fb5aa
commit
ae84106edc
@ -564,15 +564,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
a = np.array(a)
|
||||
return a
|
||||
|
||||
def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6):
|
||||
"""Asserts that two numpy arrays have near values.
|
||||
|
||||
Args:
|
||||
a: a numpy ndarray or anything can be converted to one.
|
||||
b: a numpy ndarray or anything can be converted to one.
|
||||
rtol: relative tolerance.
|
||||
atol: absolute tolerance.
|
||||
"""
|
||||
def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
|
||||
a = self._GetNdArray(a)
|
||||
b = self._GetNdArray(b)
|
||||
self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s." %
|
||||
@ -600,7 +592,37 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
print("not close dif = ", np.abs(x - y))
|
||||
print("not close tol = ", atol + rtol * np.abs(y))
|
||||
print("dtype = %s, shape = %s" % (a.dtype, a.shape))
|
||||
np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
|
||||
np.testing.assert_allclose(a, b, rtol=rtol, atol=atol, err_msg=msg)
|
||||
|
||||
def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6):
|
||||
"""Asserts that two numpy arrays, or dicts of same, have near values.
|
||||
|
||||
This does not support nested dicts.
|
||||
|
||||
Args:
|
||||
a: A numpy ndarray (or anything can be converted to one), or dict of same.
|
||||
Must be a dict iff `b` is a dict.
|
||||
b: A numpy ndarray (or anything can be converted to one), or dict of same.
|
||||
Must be a dict iff `a` is a dict.
|
||||
rtol: relative tolerance.
|
||||
atol: absolute tolerance.
|
||||
|
||||
Raises:
|
||||
ValueError: if only one of `a` and `b` is a dict.
|
||||
"""
|
||||
is_a_dict = isinstance(a, dict)
|
||||
if is_a_dict != isinstance(b, dict):
|
||||
raise ValueError("Can't compare dict to non-dict, %s vs %s." % (a, b))
|
||||
if is_a_dict:
|
||||
self.assertItemsEqual(
|
||||
a.keys(), b.keys(),
|
||||
msg="mismatched keys, expected %s, got %s" % (a.keys(), b.keys()))
|
||||
for k in a:
|
||||
self._assertArrayLikeAllClose(
|
||||
a[k], b[k], rtol=rtol, atol=atol,
|
||||
msg="%s: expected %s, got %s." % (k, a, b))
|
||||
else:
|
||||
self._assertArrayLikeAllClose(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
def assertAllCloseAccordingToType(self,
|
||||
a,
|
||||
|
@ -196,7 +196,47 @@ class TestUtilTest(test_util.TensorFlowTestCase):
|
||||
def testAllCloseScalars(self):
|
||||
self.assertAllClose(7, 7 + 1e-8)
|
||||
with self.assertRaisesRegexp(AssertionError, r"Not equal to tolerance"):
|
||||
self.assertAllClose(7, 8)
|
||||
self.assertAllClose(7, 7 + 1e-5)
|
||||
|
||||
def testAllCloseDictToNonDict(self):
|
||||
with self.assertRaisesRegexp(ValueError, r"Can't compare dict to non-dict"):
|
||||
self.assertAllClose(1, {"a": 1})
|
||||
with self.assertRaisesRegexp(ValueError, r"Can't compare dict to non-dict"):
|
||||
self.assertAllClose({"a": 1}, 1)
|
||||
|
||||
def testAllCloseDicts(self):
|
||||
a = 7
|
||||
b = (2., 3.)
|
||||
c = np.ones((3, 2, 4)) * 7.
|
||||
expected = {"a": a, "b": b, "c": c}
|
||||
|
||||
# Identity.
|
||||
self.assertAllClose(expected, expected)
|
||||
self.assertAllClose(expected, dict(expected))
|
||||
|
||||
# With each item removed.
|
||||
for k in expected:
|
||||
actual = dict(expected)
|
||||
del actual[k]
|
||||
with self.assertRaisesRegexp(AssertionError, r"mismatched keys"):
|
||||
self.assertAllClose(expected, actual)
|
||||
|
||||
# With each item changed.
|
||||
with self.assertRaisesRegexp(AssertionError, r"Not equal to tolerance"):
|
||||
self.assertAllClose(expected, {"a": a + 1e-5, "b": b, "c": c})
|
||||
with self.assertRaisesRegexp(AssertionError, r"Shape mismatch"):
|
||||
self.assertAllClose(expected, {"a": a, "b": b + (4.,), "c": c})
|
||||
c_copy = np.array(c)
|
||||
c_copy[1, 1, 1] += 1e-5
|
||||
with self.assertRaisesRegexp(AssertionError, r"Not equal to tolerance"):
|
||||
self.assertAllClose(expected, {"a": a, "b": b, "c": c_copy})
|
||||
|
||||
def testAllCloseNestedDicts(self):
|
||||
a = {"a": 1, "b": 2, "nested": {"d": 3, "e": 4}}
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError,
|
||||
r"inputs could not be safely coerced to any supported types"):
|
||||
self.assertAllClose(a, a)
|
||||
|
||||
def testArrayNear(self):
|
||||
a = [1, 2]
|
||||
|
Loading…
Reference in New Issue
Block a user