Decrease tolerances on ScaleTriL tests to prevent flakiness.
PiperOrigin-RevId: 246044331
This commit is contained in:
parent
a198279b08
commit
b962b1f193
@ -41,10 +41,10 @@ class ScaleTriLBijectorTest(test.TestCase):
|
|||||||
diag_shift=shift)
|
diag_shift=shift)
|
||||||
|
|
||||||
y_ = self.evaluate(b.forward(x))
|
y_ = self.evaluate(b.forward(x))
|
||||||
self.assertAllClose(y, y_)
|
self.assertAllClose(y, y_, rtol=1e-4)
|
||||||
|
|
||||||
x_ = self.evaluate(b.inverse(y))
|
x_ = self.evaluate(b.inverse(y))
|
||||||
self.assertAllClose(x, x_)
|
self.assertAllClose(x, x_, rtol=1e-4)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testInvertible(self):
|
def testInvertible(self):
|
||||||
@ -52,18 +52,18 @@ class ScaleTriLBijectorTest(test.TestCase):
|
|||||||
# Generate random inputs from an unconstrained space, with
|
# Generate random inputs from an unconstrained space, with
|
||||||
# event size 6 to specify 3x3 triangular matrices.
|
# event size 6 to specify 3x3 triangular matrices.
|
||||||
batch_shape = [2, 1]
|
batch_shape = [2, 1]
|
||||||
x = np.float32(np.random.randn(*(batch_shape + [6])))
|
x = np.float32(self._rng.randn(*(batch_shape + [6])))
|
||||||
b = bijectors.ScaleTriL(diag_bijector=bijectors.Softplus(),
|
b = bijectors.ScaleTriL(diag_bijector=bijectors.Softplus(),
|
||||||
diag_shift=3.14159)
|
diag_shift=3.14159)
|
||||||
y = self.evaluate(b.forward(x))
|
y = self.evaluate(b.forward(x))
|
||||||
self.assertAllEqual(y.shape, batch_shape + [3, 3])
|
self.assertAllEqual(y.shape, batch_shape + [3, 3])
|
||||||
|
|
||||||
x_ = self.evaluate(b.inverse(y))
|
x_ = self.evaluate(b.inverse(y))
|
||||||
self.assertAllClose(x, x_)
|
self.assertAllClose(x, x_, rtol=1e-4)
|
||||||
|
|
||||||
fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1))
|
fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1))
|
||||||
ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2))
|
ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2))
|
||||||
self.assertAllClose(fldj, -ildj)
|
self.assertAllClose(fldj, -ildj, rtol=1e-4)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user