Add fill_triangular_inverse, which flattens a (batch of) triangular matrix into a vector such that the round trip through fill_triangular recovers the matrix:

# Lower triangular matrix
x = tf.matrix_band_part(x, -1, 0)
x == fill_triangular(fill_triangular_inverse(x))
Code by srvasude@ which I'm submitting on his behalf.

PiperOrigin-RevId: 198623887
This commit is contained in:
Joshua V. Dillon 2018-05-30 14:52:57 -07:00 committed by TensorFlower Gardener
parent 5c751fe8d7
commit 176754d6cc
3 changed files with 97 additions and 3 deletions

View File

@ -32,6 +32,7 @@ from tensorflow.contrib.distributions.python.ops.conditional_distribution import
from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import *
from tensorflow.contrib.distributions.python.ops.deterministic import *
from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular
from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular_inverse
from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform
from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp
from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse
@ -156,6 +157,7 @@ _allowed_symbols = [
'kl_divergence',
'RegisterKL',
'fill_triangular',
'fill_triangular_inverse',
'matrix_diag_transform',
'reduce_weighted_logsumexp',
'softplus_inverse',

View File

@ -814,6 +814,30 @@ class FillTriangularTest(test.TestCase):
self._run_test(self._rng.randn(2, 3, int(7*8/2)), upper=True)
class FillTriangularInverseTest(FillTriangularTest):
  """Round-trip checks for `du.fill_triangular_inverse`.

  Subclasses `FillTriangularTest` and overrides only `_run_test`, so every
  inherited case that calls `self._run_test(...)` verifies that
  `fill_triangular_inverse(fill_triangular(x)) == x` instead.
  """

  def _run_test(self, x_, use_deferred_shape=False, **kwargs):
    """Asserts that inverting `fill_triangular(x_)` reproduces `x_`.

    Args:
      x_: Array-like input vector(s); converted with `np.asarray`.
      use_deferred_shape: Python `bool`. When `True` the placeholder is built
        with `shape=None`, so no static shape information is available.
      **kwargs: Forwarded unchanged to both `du.fill_triangular` and
        `du.fill_triangular_inverse` (e.g. `upper=True`).
    """
    x_ = np.asarray(x_)
    with self.test_session() as sess:
      placeholder_shape = x_.shape if not use_deferred_shape else None
      x_pl = array_ops.placeholder_with_default(x_, shape=placeholder_shape)
      # `x * sg(x - 1) - sg(x * (x - 1))` evaluates to zero, so adding it to
      # the placeholder leaves the values untouched.
      # NOTE(review): presumably kept to mirror the parent test's graph
      # construction -- confirm against `FillTriangularTest._run_test`.
      live_term = x_pl * array_ops.stop_gradient(x_pl - 1.)
      frozen_term = array_ops.stop_gradient(x_pl * (x_pl - 1.))
      x = x_pl + (live_term - frozen_term)

      matrix = du.fill_triangular(x, **kwargs)
      round_trip = du.fill_triangular_inverse(matrix, **kwargs)
      round_trip_ = sess.run(round_trip, feed_dict={x_pl: x_})

      if not use_deferred_shape:
        # Static shape must survive the round trip when it was provided.
        self.assertAllEqual(x_.shape, round_trip.shape)
      else:
        # Nothing is known statically, so the result shape is unknown too.
        self.assertEqual(None, round_trip.shape)
      self.assertAllEqual(x_, round_trip_)
class ReduceWeightedLogSumExp(test.TestCase):
def _reduce_weighted_logsumexp(self, logx, w, axis, keep_dims=False):

View File

@ -824,8 +824,8 @@ def fill_triangular(x, upper=False, name=None):
Triangular matrix elements are filled in a clockwise spiral. See example,
below.
If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1,
b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
If `x.get_shape()` is `[b1, b2, ..., bB, d]` then the output shape is
`[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
`n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.
Example:
@ -914,7 +914,7 @@ def fill_triangular(x, upper=False, name=None):
# = 2 (n**2 / 2 + n / 2) - n**2
# = n**2 + n - n**2
# = n
ndims = array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims
ndims = prefer_static_rank(x)
if upper:
x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])]
else:
@ -932,6 +932,74 @@ def fill_triangular(x, upper=False, name=None):
return x
def fill_triangular_inverse(x, upper=False, name=None):
  """Creates a vector from a (batch of) triangular matrix.

  The vector is created from the lower-triangular or upper-triangular portion
  depending on the value of the parameter `upper`. This is the inverse of
  `fill_triangular`, i.e., for a triangular `x`,
  `fill_triangular(fill_triangular_inverse(x, upper), upper) == x`.

  Entries outside the selected triangle are expected to be zero: the
  implementation adds the matrix to a rotated copy of itself, so non-zero
  off-triangle entries would be mixed into the result.

  If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
  `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.

  Example:

  ```python
  fill_triangular_inverse(
    [[4, 0, 0],
     [6, 5, 0],
     [3, 2, 1]])

  # ==> [1, 2, 3, 4, 5, 6]

  fill_triangular_inverse(
    [[1, 2, 3],
     [0, 5, 6],
     [0, 0, 4]], upper=True)

  # ==> [1, 2, 3, 4, 5, 6]
  ```

  Args:
    x: `Tensor` of shape `[..., n, n]` holding a (batch of) lower (or upper)
      triangular matrix.
    upper: Python `bool` representing whether the input matrix `x` is upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
      (or upper) triangular elements from `x`.

  Raises:
    ValueError: if the static rank of `x` is known and less than 2, or if the
      two innermost dimensions of `x` are statically known and differ.
  """

  with ops.name_scope(name, "fill_triangular_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    # `with_rank_at_least` raises `ValueError` for a statically known rank < 2.
    static_shape = x.shape.with_rank_at_least(2)
    static_rows = static_shape[-2].value
    static_cols = static_shape[-1].value
    # Fail early with a clear message; a non-square input would otherwise only
    # surface as a size mismatch in the reshape below.
    if (static_rows is not None and static_cols is not None
        and static_rows != static_cols):
      raise ValueError(
          "Innermost two dimensions of `x` must be equal (square matrices); "
          "got shape {}.".format(x.shape))
    if static_cols is not None:
      # Matrix size is known statically: compute the output length in numpy.
      n = np.int32(static_cols)
      m = np.int32((n * (n + 1)) // 2)
      static_final_shape = x.shape[:-2].concatenate([m])
    else:
      # Matrix size is only known at run time: compute it in the graph.
      n = array_ops.shape(x)[-1]
      m = (n * (n + 1)) // 2
      static_final_shape = x.shape.with_rank_at_least(2)[:-2].concatenate(
          [None])
    ndims = prefer_static_rank(x)
    if upper:
      # The first row of an upper-triangular matrix is dense and supplies the
      # first `n` output elements as-is.
      initial_elements = x[..., 0, :]
      triangular_portion = x[..., 1:, :]
    else:
      # The last row of a lower-triangular matrix is dense; reversed, it
      # supplies the first `n` output elements. After slicing out one row the
      # rank is `ndims - 1`, so the last axis is `ndims - 2`.
      initial_elements = array_ops.reverse(x[..., -1, :], axis=[ndims - 2])
      triangular_portion = x[..., :-1, :]
    # Reversing both matrix axes rotates the remaining `n - 1` rows by 180
    # degrees. For a truly triangular `x` the rotated copy's non-zeros land on
    # the block's zeros, so the sum packs the remaining elements densely.
    rotated_triangular_portion = array_ops.reverse(
        array_ops.reverse(triangular_portion, axis=[ndims - 1]),
        axis=[ndims - 2])
    consolidated_matrix = triangular_portion + rotated_triangular_portion
    # Flatten the `(n - 1) x n` block row-major, keeping the batch dimensions.
    end_sequence = array_ops.reshape(
        consolidated_matrix,
        array_ops.concat([array_ops.shape(x)[:-2], [n * (n - 1)]], axis=0))
    # Only the first `m - n = n (n - 1) / 2` flattened entries are needed; the
    # rest repeat the same elements.
    y = array_ops.concat([initial_elements, end_sequence[..., :m - n]], axis=-1)
    y.set_shape(static_final_shape)
    return y
def tridiag(below=None, diag=None, above=None, name=None):
"""Creates a matrix with values set above, below, and on the diagonal.