Add fill_triangular_inverse
, which flattens a triangular matrix in a way such that:
# Lower triangular matrix x = tf.matrix_band_part(x, -1, 0) x == fill_triangular(fill_triangular_inverse(x)) Code by srvasude@ which I'm submitting on his behalf. PiperOrigin-RevId: 198623887
This commit is contained in:
parent
5c751fe8d7
commit
176754d6cc
@ -32,6 +32,7 @@ from tensorflow.contrib.distributions.python.ops.conditional_distribution import
|
||||
from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import *
|
||||
from tensorflow.contrib.distributions.python.ops.deterministic import *
|
||||
from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular
|
||||
from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular_inverse
|
||||
from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform
|
||||
from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp
|
||||
from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse
|
||||
@ -156,6 +157,7 @@ _allowed_symbols = [
|
||||
'kl_divergence',
|
||||
'RegisterKL',
|
||||
'fill_triangular',
|
||||
'fill_triangular_inverse',
|
||||
'matrix_diag_transform',
|
||||
'reduce_weighted_logsumexp',
|
||||
'softplus_inverse',
|
||||
|
@ -814,6 +814,30 @@ class FillTriangularTest(test.TestCase):
|
||||
self._run_test(self._rng.randn(2, 3, int(7*8/2)), upper=True)
|
||||
|
||||
|
||||
class FillTriangularInverseTest(FillTriangularTest):
|
||||
|
||||
def _run_test(self, x_, use_deferred_shape=False, **kwargs):
|
||||
x_ = np.asarray(x_)
|
||||
with self.test_session() as sess:
|
||||
static_shape = None if use_deferred_shape else x_.shape
|
||||
x_pl = array_ops.placeholder_with_default(x_, shape=static_shape)
|
||||
zeros_like_x_pl = (x_pl * array_ops.stop_gradient(x_pl - 1.)
|
||||
- array_ops.stop_gradient(x_pl * (x_pl - 1.)))
|
||||
x = x_pl + zeros_like_x_pl
|
||||
actual = du.fill_triangular(x, **kwargs)
|
||||
inverse_actual = du.fill_triangular_inverse(actual, **kwargs)
|
||||
|
||||
inverse_actual_ = sess.run(
|
||||
inverse_actual,
|
||||
feed_dict={x_pl: x_})
|
||||
|
||||
if use_deferred_shape:
|
||||
self.assertEqual(None, inverse_actual.shape)
|
||||
else:
|
||||
self.assertAllEqual(x_.shape, inverse_actual.shape)
|
||||
self.assertAllEqual(x_, inverse_actual_)
|
||||
|
||||
|
||||
class ReduceWeightedLogSumExp(test.TestCase):
|
||||
|
||||
def _reduce_weighted_logsumexp(self, logx, w, axis, keep_dims=False):
|
||||
|
@ -824,8 +824,8 @@ def fill_triangular(x, upper=False, name=None):
|
||||
Triangular matrix elements are filled in a clockwise spiral. See example,
|
||||
below.
|
||||
|
||||
If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1,
|
||||
b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
|
||||
If `x.get_shape()` is `[b1, b2, ..., bB, d]` then the output shape is
|
||||
`[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
|
||||
`n = int(np.sqrt(0.25 + 2. * m) - 0.5)`.
|
||||
|
||||
Example:
|
||||
@ -914,7 +914,7 @@ def fill_triangular(x, upper=False, name=None):
|
||||
# = 2 (n**2 / 2 + n / 2) - n**2
|
||||
# = n**2 + n - n**2
|
||||
# = n
|
||||
ndims = array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims
|
||||
ndims = prefer_static_rank(x)
|
||||
if upper:
|
||||
x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])]
|
||||
else:
|
||||
@ -932,6 +932,74 @@ def fill_triangular(x, upper=False, name=None):
|
||||
return x
|
||||
|
||||
|
||||
def fill_triangular_inverse(x, upper=False, name=None):
|
||||
"""Creates a vector from a (batch of) triangular matrix.
|
||||
|
||||
The vector is created from the lower-triangular or upper-triangular portion
|
||||
depending on the value of the parameter `upper`.
|
||||
|
||||
If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
|
||||
`[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
fill_triangular_inverse(
|
||||
[[4, 0, 0],
|
||||
[6, 5, 0],
|
||||
[3, 2, 1]])
|
||||
|
||||
# ==> [1, 2, 3, 4, 5, 6]
|
||||
|
||||
fill_triangular_inverse(
|
||||
[[1, 2, 3],
|
||||
[0, 5, 6],
|
||||
[0, 0, 4]], upper=True)
|
||||
|
||||
# ==> [1, 2, 3, 4, 5, 6]
|
||||
```
|
||||
|
||||
Args:
|
||||
x: `Tensor` representing lower (or upper) triangular elements.
|
||||
upper: Python `bool` representing whether output matrix should be upper
|
||||
triangular (`True`) or lower triangular (`False`, default).
|
||||
name: Python `str`. The name to give this op.
|
||||
|
||||
Returns:
|
||||
flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
|
||||
(or upper) triangular elements from `x`.
|
||||
"""
|
||||
|
||||
with ops.name_scope(name, "fill_triangular_inverse", values=[x]):
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
if x.shape.with_rank_at_least(2)[-1].value is not None:
|
||||
n = np.int32(x.shape[-1].value)
|
||||
m = np.int32((n * (n + 1)) // 2)
|
||||
static_final_shape = x.shape[:-2].concatenate([m])
|
||||
else:
|
||||
n = array_ops.shape(x)[-1]
|
||||
m = (n * (n + 1)) // 2
|
||||
static_final_shape = x.shape.with_rank_at_least(2)[:-2].concatenate(
|
||||
[None])
|
||||
ndims = prefer_static_rank(x)
|
||||
if upper:
|
||||
initial_elements = x[..., 0, :]
|
||||
triangular_portion = x[..., 1:, :]
|
||||
else:
|
||||
initial_elements = array_ops.reverse(x[..., -1, :], axis=[ndims - 2])
|
||||
triangular_portion = x[..., :-1, :]
|
||||
rotated_triangular_portion = array_ops.reverse(
|
||||
array_ops.reverse(triangular_portion, axis=[ndims - 1]),
|
||||
axis=[ndims - 2])
|
||||
consolidated_matrix = triangular_portion + rotated_triangular_portion
|
||||
end_sequence = array_ops.reshape(
|
||||
consolidated_matrix,
|
||||
array_ops.concat([array_ops.shape(x)[:-2], [n * (n - 1)]], axis=0))
|
||||
y = array_ops.concat([initial_elements, end_sequence[..., :m - n]], axis=-1)
|
||||
y.set_shape(static_final_shape)
|
||||
return y
|
||||
|
||||
|
||||
def tridiag(below=None, diag=None, above=None, name=None):
|
||||
"""Creates a matrix with values set above, below, and on the diagonal.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user