Modify assertAllClose to handle dicts.

Change: 153402607
This commit is contained in:
A. Unique TensorFlower 2017-04-17 15:16:31 -08:00 committed by TensorFlower Gardener
parent 4b046fb5aa
commit ae84106edc
2 changed files with 73 additions and 11 deletions

View File

@ -564,15 +564,7 @@ class TensorFlowTestCase(googletest.TestCase):
a = np.array(a)
return a
def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6):
"""Asserts that two numpy arrays have near values.
Args:
a: a numpy ndarray or anything can be converted to one.
b: a numpy ndarray or anything can be converted to one.
rtol: relative tolerance.
atol: absolute tolerance.
"""
def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
a = self._GetNdArray(a)
b = self._GetNdArray(b)
self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s." %
@ -600,7 +592,37 @@ class TensorFlowTestCase(googletest.TestCase):
print("not close dif = ", np.abs(x - y))
print("not close tol = ", atol + rtol * np.abs(y))
print("dtype = %s, shape = %s" % (a.dtype, a.shape))
np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
np.testing.assert_allclose(a, b, rtol=rtol, atol=atol, err_msg=msg)
def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6):
"""Asserts that two numpy arrays, or dicts of same, have near values.
This does not support nested dicts.
Args:
a: A numpy ndarray (or anything can be converted to one), or dict of same.
Must be a dict iff `b` is a dict.
b: A numpy ndarray (or anything can be converted to one), or dict of same.
Must be a dict iff `a` is a dict.
rtol: relative tolerance.
atol: absolute tolerance.
Raises:
ValueError: if only one of `a` and `b` is a dict.
"""
is_a_dict = isinstance(a, dict)
if is_a_dict != isinstance(b, dict):
raise ValueError("Can't compare dict to non-dict, %s vs %s." % (a, b))
if is_a_dict:
self.assertItemsEqual(
a.keys(), b.keys(),
msg="mismatched keys, expected %s, got %s" % (a.keys(), b.keys()))
for k in a:
self._assertArrayLikeAllClose(
a[k], b[k], rtol=rtol, atol=atol,
msg="%s: expected %s, got %s." % (k, a, b))
else:
self._assertArrayLikeAllClose(a, b, rtol=rtol, atol=atol)
def assertAllCloseAccordingToType(self,
a,

View File

@ -196,7 +196,47 @@ class TestUtilTest(test_util.TensorFlowTestCase):
def testAllCloseScalars(self):
self.assertAllClose(7, 7 + 1e-8)
with self.assertRaisesRegexp(AssertionError, r"Not equal to tolerance"):
self.assertAllClose(7, 8)
self.assertAllClose(7, 7 + 1e-5)
def testAllCloseDictToNonDict(self):
with self.assertRaisesRegexp(ValueError, r"Can't compare dict to non-dict"):
self.assertAllClose(1, {"a": 1})
with self.assertRaisesRegexp(ValueError, r"Can't compare dict to non-dict"):
self.assertAllClose({"a": 1}, 1)
def testAllCloseDicts(self):
a = 7
b = (2., 3.)
c = np.ones((3, 2, 4)) * 7.
expected = {"a": a, "b": b, "c": c}
# Identity.
self.assertAllClose(expected, expected)
self.assertAllClose(expected, dict(expected))
# With each item removed.
for k in expected:
actual = dict(expected)
del actual[k]
with self.assertRaisesRegexp(AssertionError, r"mismatched keys"):
self.assertAllClose(expected, actual)
# With each item changed.
with self.assertRaisesRegexp(AssertionError, r"Not equal to tolerance"):
self.assertAllClose(expected, {"a": a + 1e-5, "b": b, "c": c})
with self.assertRaisesRegexp(AssertionError, r"Shape mismatch"):
self.assertAllClose(expected, {"a": a, "b": b + (4.,), "c": c})
c_copy = np.array(c)
c_copy[1, 1, 1] += 1e-5
with self.assertRaisesRegexp(AssertionError, r"Not equal to tolerance"):
self.assertAllClose(expected, {"a": a, "b": b, "c": c_copy})
def testAllCloseNestedDicts(self):
a = {"a": 1, "b": 2, "nested": {"d": 3, "e": 4}}
with self.assertRaisesRegexp(
TypeError,
r"inputs could not be safely coerced to any supported types"):
self.assertAllClose(a, a)
def testArrayNear(self):
a = [1, 2]