Add fill_triangular_inverse, which flattens a triangular matrix such that:

    # Lower triangular matrix
    x = tf.matrix_band_part(x, -1, 0)
    x == fill_triangular(fill_triangular_inverse(x))

Code by srvasude@, which I'm submitting on his behalf.

PiperOrigin-RevId: 198623887
This commit is contained in:
parent
5c751fe8d7
commit
176754d6cc
@ -32,6 +32,7 @@ from tensorflow.contrib.distributions.python.ops.conditional_distribution import
|
|||||||
from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import *
|
from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import *
|
||||||
from tensorflow.contrib.distributions.python.ops.deterministic import *
|
from tensorflow.contrib.distributions.python.ops.deterministic import *
|
||||||
from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular
|
from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular
|
||||||
|
from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular_inverse
|
||||||
from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform
|
from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform
|
||||||
from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp
|
from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp
|
||||||
from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse
|
from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse
|
||||||
@ -156,6 +157,7 @@ _allowed_symbols = [
|
|||||||
'kl_divergence',
|
'kl_divergence',
|
||||||
'RegisterKL',
|
'RegisterKL',
|
||||||
'fill_triangular',
|
'fill_triangular',
|
||||||
|
'fill_triangular_inverse',
|
||||||
'matrix_diag_transform',
|
'matrix_diag_transform',
|
||||||
'reduce_weighted_logsumexp',
|
'reduce_weighted_logsumexp',
|
||||||
'softplus_inverse',
|
'softplus_inverse',
|
||||||
|
@ -814,6 +814,30 @@ class FillTriangularTest(test.TestCase):
|
|||||||
self._run_test(self._rng.randn(2, 3, int(7*8/2)), upper=True)
|
self._run_test(self._rng.randn(2, 3, int(7*8/2)), upper=True)
|
||||||
|
|
||||||
|
|
||||||
|
class FillTriangularInverseTest(FillTriangularTest):
  """Round-trip checks for `du.fill_triangular_inverse`.

  Overrides `_run_test` so that each input vector is expanded with
  `du.fill_triangular` and then flattened again with
  `du.fill_triangular_inverse`; the result must equal the original input.
  """

  def _run_test(self, x_, use_deferred_shape=False, **kwargs):
    """Asserts `fill_triangular_inverse(fill_triangular(x_)) == x_`.

    Args:
      x_: Array-like input; converted with `np.asarray`.
      use_deferred_shape: Python `bool`. If `True` the placeholder is created
        with `shape=None`, so no static shape is available at graph
        construction time.
      **kwargs: Forwarded unchanged to both `du.fill_triangular` and
        `du.fill_triangular_inverse`.
    """
    x_ = np.asarray(x_)
    with self.test_session() as sess:
      placeholder_shape = None if use_deferred_shape else x_.shape
      x_pl = array_ops.placeholder_with_default(x_, shape=placeholder_shape)

      # `lhs - rhs` evaluates to `x * (x - 1) - x * (x - 1)`, i.e. zeros, while
      # the `stop_gradient` calls leave `x_pl - 1` as its gradient w.r.t.
      # `x_pl`. Adding it therefore leaves the value of `x_pl` unchanged.
      lhs = x_pl * array_ops.stop_gradient(x_pl - 1.)
      rhs = array_ops.stop_gradient(x_pl * (x_pl - 1.))
      zeros_like_x_pl = lhs - rhs
      x = x_pl + zeros_like_x_pl

      filled = du.fill_triangular(x, **kwargs)
      inverse_actual = du.fill_triangular_inverse(filled, **kwargs)

      inverse_actual_ = sess.run(inverse_actual, feed_dict={x_pl: x_})

    if not use_deferred_shape:
      self.assertAllEqual(x_.shape, inverse_actual.shape)
    else:
      self.assertEqual(None, inverse_actual.shape)
    self.assertAllEqual(x_, inverse_actual_)
|
||||||
|
|
||||||
|
|
||||||
class ReduceWeightedLogSumExp(test.TestCase):
|
class ReduceWeightedLogSumExp(test.TestCase):
|
||||||
|
|
||||||
def _reduce_weighted_logsumexp(self, logx, w, axis, keep_dims=False):
|
def _reduce_weighted_logsumexp(self, logx, w, axis, keep_dims=False):
|
||||||
|
@ -824,8 +824,8 @@ def fill_triangular(x, upper=False, name=None):
|
|||||||
Triangular matrix elements are filled in a clockwise spiral. See example,
|
Triangular matrix elements are filled in a clockwise spiral. See example,
|
||||||
below.
|
below.
|
||||||
|
|
||||||
If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1,
|
If `x.get_shape()` is `[b1, b2, ..., bB, d]` then the output shape is
|
||||||
b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
|
`[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
|
||||||
`n = int(np.sqrt(0.25 + 2. * m) - 0.5)`.
|
`n = int(np.sqrt(0.25 + 2. * m) - 0.5)`.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@ -914,7 +914,7 @@ def fill_triangular(x, upper=False, name=None):
|
|||||||
# = 2 (n**2 / 2 + n / 2) - n**2
|
# = 2 (n**2 / 2 + n / 2) - n**2
|
||||||
# = n**2 + n - n**2
|
# = n**2 + n - n**2
|
||||||
# = n
|
# = n
|
||||||
ndims = array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims
|
ndims = prefer_static_rank(x)
|
||||||
if upper:
|
if upper:
|
||||||
x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])]
|
x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])]
|
||||||
else:
|
else:
|
||||||
@ -932,6 +932,74 @@ def fill_triangular(x, upper=False, name=None):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def fill_triangular_inverse(x, upper=False, name=None):
  """Creates a vector from a (batch of) triangular matrix.

  The vector is created from the lower-triangular or upper-triangular portion
  depending on the value of the parameter `upper`. The elements are read in
  the order that makes this function the inverse of `fill_triangular` called
  with the same `upper` argument.

  If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
  `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.

  Example:

  ```python
  fill_triangular_inverse(
    [[4, 0, 0],
     [6, 5, 0],
     [3, 2, 1]])

  # ==> [1, 2, 3, 4, 5, 6]

  fill_triangular_inverse(
    [[1, 2, 3],
     [0, 5, 6],
     [0, 0, 4]], upper=True)

  # ==> [1, 2, 3, 4, 5, 6]
  ```

  Args:
    x: `Tensor` of rank at least 2 representing a (batch of) lower (or upper)
      triangular matrix. Entries outside the selected triangle must be zero:
      the implementation adds the matrix to a rotated copy of itself, so any
      nonzero entry there is summed into the output.
    upper: Python `bool` representing whether the input matrix `x` is upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
      (or upper) triangular elements from `x`.
  """

  with ops.name_scope(name, "fill_triangular_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    # Validates (statically, when possible) that `x` is at least a matrix.
    static_shape = x.shape.with_rank_at_least(2)
    if static_shape[-1].value is not None:
      # Matrix size known at graph-construction time: compute sizes in numpy.
      n = np.int32(static_shape[-1].value)
      m = np.int32((n * (n + 1)) // 2)
      static_final_shape = static_shape[:-2].concatenate([m])
    else:
      # Matrix size only known at run time: compute sizes as tensors.
      n = array_ops.shape(x)[-1]
      m = (n * (n + 1)) // 2
      static_final_shape = static_shape[:-2].concatenate([None])
    ndims = prefer_static_rank(x)
    if upper:
      # The first row is the only full row; it supplies the first `n` outputs.
      initial_elements = x[..., 0, :]
      triangular_portion = x[..., 1:, :]
    else:
      # The last row is the only full row; reversed, it supplies the first `n`
      # outputs. After slicing out the row axis, the column axis is the last
      # axis of the slice, i.e. axis `ndims - 2`.
      initial_elements = array_ops.reverse(x[..., -1, :], axis=[ndims - 2])
      triangular_portion = x[..., :-1, :]
    # Rotate the remaining `(n - 1) x n` block by 180 degrees. Because `x` is
    # triangular, the nonzero entries of the block and of its rotation do not
    # overlap, so their sum packs both into one dense block.
    rotated_triangular_portion = array_ops.reverse(
        array_ops.reverse(triangular_portion, axis=[ndims - 1]),
        axis=[ndims - 2])
    consolidated_matrix = triangular_portion + rotated_triangular_portion
    # Flatten the dense block; only its first `m - n` entries are needed, the
    # rest repeat the same values in reverse order.
    end_sequence = array_ops.reshape(
        consolidated_matrix,
        array_ops.concat([array_ops.shape(x)[:-2], [n * (n - 1)]], axis=0))
    y = array_ops.concat([initial_elements, end_sequence[..., :m - n]], axis=-1)
    y.set_shape(static_final_shape)
    return y
|
||||||
|
|
||||||
|
|
||||||
def tridiag(below=None, diag=None, above=None, name=None):
|
def tridiag(below=None, diag=None, above=None, name=None):
|
||||||
"""Creates a matrix with values set above, below, and on the diagonal.
|
"""Creates a matrix with values set above, below, and on the diagonal.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user