Change linear_operator_test_util.random_postive_definite_matrix to return nicer conditioned matrix by default (a Wishart matrix). Enable tests in oss again.
PiperOrigin-RevId: 294817440 Change-Id: Ib9a976e9c78fba41863639d7c158e2a3b25523e8
This commit is contained in:
parent
cec24afe2d
commit
4d5741f1d4
@ -260,7 +260,6 @@ cuda_py_test(
|
||||
srcs = ["linear_operator_full_matrix_test.py"],
|
||||
shard_count = 5,
|
||||
tags = [
|
||||
"no_oss", # b/149346156
|
||||
"noasan",
|
||||
"optonly",
|
||||
],
|
||||
|
@ -825,15 +825,23 @@ class NonSquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest):
|
||||
return 2
|
||||
|
||||
|
||||
def random_positive_definite_matrix(shape, dtype, force_well_conditioned=False):
|
||||
"""[batch] positive definite matrix.
|
||||
def random_positive_definite_matrix(shape,
|
||||
dtype,
|
||||
oversampling_ratio=4,
|
||||
force_well_conditioned=False):
|
||||
"""[batch] positive definite Wisart matrix.
|
||||
|
||||
A Wishart(N, S) matrix is the S sample covariance matrix of an N-variate
|
||||
(standard) Normal random variable.
|
||||
|
||||
Args:
|
||||
shape: `TensorShape` or Python list. Shape of the returned matrix.
|
||||
dtype: `TensorFlow` `dtype` or Python dtype.
|
||||
force_well_conditioned: Python bool. If `True`, returned matrix has
|
||||
eigenvalues with modulus in `(1, 4)`. Otherwise, eigenvalues are
|
||||
chi-squared random variables.
|
||||
oversampling_ratio: S / N in the above. If S < N, the matrix will be
|
||||
singular (unless `force_well_conditioned is True`).
|
||||
force_well_conditioned: Python bool. If `True`, add `1` to the diagonal
|
||||
of the Wishart matrix, then divide by 2, ensuring most eigenvalues are
|
||||
close to 1.
|
||||
|
||||
Returns:
|
||||
`Tensor` with desired shape and dtype.
|
||||
@ -843,11 +851,21 @@ def random_positive_definite_matrix(shape, dtype, force_well_conditioned=False):
|
||||
shape = tensor_shape.TensorShape(shape)
|
||||
# Matrix must be square.
|
||||
shape.dims[-1].assert_is_compatible_with(shape.dims[-2])
|
||||
shape = shape.as_list()
|
||||
n = shape[-2]
|
||||
s = oversampling_ratio * shape[-1]
|
||||
wigner_shape = shape[:-2] + [n, s]
|
||||
|
||||
with ops.name_scope("random_positive_definite_matrix"):
|
||||
tril = random_tril_matrix(
|
||||
shape, dtype, force_well_conditioned=force_well_conditioned)
|
||||
return math_ops.matmul(tril, tril, adjoint_b=True)
|
||||
wigner = random_normal(
|
||||
wigner_shape,
|
||||
dtype=dtype,
|
||||
stddev=math_ops.cast(1 / np.sqrt(s), dtype.real_dtype))
|
||||
wishart = math_ops.matmul(wigner, wigner, adjoint_b=True)
|
||||
if force_well_conditioned:
|
||||
wishart += linalg_ops.eye(n, dtype=dtype)
|
||||
wishart /= math_ops.cast(2, dtype)
|
||||
return wishart
|
||||
|
||||
|
||||
def random_tril_matrix(shape,
|
||||
|
Loading…
Reference in New Issue
Block a user