Change contrib/bayesflow/custom_grad tests to assertAllClose.
PiperOrigin-RevId: 162681035
This commit is contained in:
parent
d2f8e98650
commit
e4d6cb9744
@ -48,8 +48,8 @@ class CustomGradientTest(test.TestCase):
|
|||||||
gx = gradients_impl.gradients(fx, x)[0]
|
gx = gradients_impl.gradients(fx, x)[0]
|
||||||
[fx_, gx_] = sess.run([fx, gx])
|
[fx_, gx_] = sess.run([fx, gx])
|
||||||
|
|
||||||
self.assertAllEqual(f(x_), fx_)
|
self.assertAllClose(f(x_), fx_)
|
||||||
self.assertAllEqual(g(x_), gx_)
|
self.assertAllClose(g(x_), gx_)
|
||||||
|
|
||||||
def test_works_correctly_both_f_g_zero(self):
|
def test_works_correctly_both_f_g_zero(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
@ -62,8 +62,8 @@ class CustomGradientTest(test.TestCase):
|
|||||||
gx = gradients_impl.gradients(fx, x)[0]
|
gx = gradients_impl.gradients(fx, x)[0]
|
||||||
[fx_, gx_] = sess.run([fx, gx])
|
[fx_, gx_] = sess.run([fx, gx])
|
||||||
|
|
||||||
self.assertAllEqual(f(x_), fx_)
|
self.assertAllClose(f(x_), fx_)
|
||||||
self.assertAllEqual(g(x_), gx_)
|
self.assertAllClose(g(x_), gx_)
|
||||||
|
|
||||||
def test_works_correctly_vector_of_vars(self):
|
def test_works_correctly_vector_of_vars(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
|
Loading…
Reference in New Issue
Block a user