Change linear_operator_test_util.random_postive_definite_matrix to return nicer conditioned matrix by default (a Wishart matrix). Enable tests in oss again.
PiperOrigin-RevId: 294817440 Change-Id: Ib9a976e9c78fba41863639d7c158e2a3b25523e8
This commit is contained in:
parent
cec24afe2d
commit
4d5741f1d4
@ -260,7 +260,6 @@ cuda_py_test(
|
|||||||
srcs = ["linear_operator_full_matrix_test.py"],
|
srcs = ["linear_operator_full_matrix_test.py"],
|
||||||
shard_count = 5,
|
shard_count = 5,
|
||||||
tags = [
|
tags = [
|
||||||
"no_oss", # b/149346156
|
|
||||||
"noasan",
|
"noasan",
|
||||||
"optonly",
|
"optonly",
|
||||||
],
|
],
|
||||||
|
@ -825,15 +825,23 @@ class NonSquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest):
|
|||||||
return 2
|
return 2
|
||||||
|
|
||||||
|
|
||||||
def random_positive_definite_matrix(shape, dtype, force_well_conditioned=False):
|
def random_positive_definite_matrix(shape,
|
||||||
"""[batch] positive definite matrix.
|
dtype,
|
||||||
|
oversampling_ratio=4,
|
||||||
|
force_well_conditioned=False):
|
||||||
|
"""[batch] positive definite Wisart matrix.
|
||||||
|
|
||||||
|
A Wishart(N, S) matrix is the S sample covariance matrix of an N-variate
|
||||||
|
(standard) Normal random variable.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
shape: `TensorShape` or Python list. Shape of the returned matrix.
|
shape: `TensorShape` or Python list. Shape of the returned matrix.
|
||||||
dtype: `TensorFlow` `dtype` or Python dtype.
|
dtype: `TensorFlow` `dtype` or Python dtype.
|
||||||
force_well_conditioned: Python bool. If `True`, returned matrix has
|
oversampling_ratio: S / N in the above. If S < N, the matrix will be
|
||||||
eigenvalues with modulus in `(1, 4)`. Otherwise, eigenvalues are
|
singular (unless `force_well_conditioned is True`).
|
||||||
chi-squared random variables.
|
force_well_conditioned: Python bool. If `True`, add `1` to the diagonal
|
||||||
|
of the Wishart matrix, then divide by 2, ensuring most eigenvalues are
|
||||||
|
close to 1.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`Tensor` with desired shape and dtype.
|
`Tensor` with desired shape and dtype.
|
||||||
@ -843,11 +851,21 @@ def random_positive_definite_matrix(shape, dtype, force_well_conditioned=False):
|
|||||||
shape = tensor_shape.TensorShape(shape)
|
shape = tensor_shape.TensorShape(shape)
|
||||||
# Matrix must be square.
|
# Matrix must be square.
|
||||||
shape.dims[-1].assert_is_compatible_with(shape.dims[-2])
|
shape.dims[-1].assert_is_compatible_with(shape.dims[-2])
|
||||||
|
shape = shape.as_list()
|
||||||
|
n = shape[-2]
|
||||||
|
s = oversampling_ratio * shape[-1]
|
||||||
|
wigner_shape = shape[:-2] + [n, s]
|
||||||
|
|
||||||
with ops.name_scope("random_positive_definite_matrix"):
|
with ops.name_scope("random_positive_definite_matrix"):
|
||||||
tril = random_tril_matrix(
|
wigner = random_normal(
|
||||||
shape, dtype, force_well_conditioned=force_well_conditioned)
|
wigner_shape,
|
||||||
return math_ops.matmul(tril, tril, adjoint_b=True)
|
dtype=dtype,
|
||||||
|
stddev=math_ops.cast(1 / np.sqrt(s), dtype.real_dtype))
|
||||||
|
wishart = math_ops.matmul(wigner, wigner, adjoint_b=True)
|
||||||
|
if force_well_conditioned:
|
||||||
|
wishart += linalg_ops.eye(n, dtype=dtype)
|
||||||
|
wishart /= math_ops.cast(2, dtype)
|
||||||
|
return wishart
|
||||||
|
|
||||||
|
|
||||||
def random_tril_matrix(shape,
|
def random_tril_matrix(shape,
|
||||||
|
Loading…
Reference in New Issue
Block a user