Change contrib/bayesflow/custom_grad tests to assertAllClose.
PiperOrigin-RevId: 162681035
This commit is contained in:
parent
d2f8e98650
commit
e4d6cb9744
@ -48,8 +48,8 @@ class CustomGradientTest(test.TestCase):
|
||||
gx = gradients_impl.gradients(fx, x)[0]
|
||||
[fx_, gx_] = sess.run([fx, gx])
|
||||
|
||||
self.assertAllEqual(f(x_), fx_)
|
||||
self.assertAllEqual(g(x_), gx_)
|
||||
self.assertAllClose(f(x_), fx_)
|
||||
self.assertAllClose(g(x_), gx_)
|
||||
|
||||
def test_works_correctly_both_f_g_zero(self):
|
||||
with self.test_session() as sess:
|
||||
@ -62,8 +62,8 @@ class CustomGradientTest(test.TestCase):
|
||||
gx = gradients_impl.gradients(fx, x)[0]
|
||||
[fx_, gx_] = sess.run([fx, gx])
|
||||
|
||||
self.assertAllEqual(f(x_), fx_)
|
||||
self.assertAllEqual(g(x_), gx_)
|
||||
self.assertAllClose(f(x_), fx_)
|
||||
self.assertAllClose(g(x_), gx_)
|
||||
|
||||
def test_works_correctly_vector_of_vars(self):
|
||||
with self.test_session() as sess:
|
||||
|
Loading…
Reference in New Issue
Block a user