Simplify tf.contrib.distributions.fill_triangular indexing calculation and add more comments explaining some of the magic.

PiperOrigin-RevId: 169557899
This commit is contained in:
Joshua V. Dillon 2017-09-21 10:00:54 -07:00 committed by TensorFlower Gardener
parent 5cc7b86a0d
commit 2679dcfbaa

View File

@ -769,13 +769,45 @@ def fill_triangular(x, upper=False, name=None):
dtype=dtypes.int32)
static_final_shape = x.shape.with_rank_at_least(1)[:-1].concatenate(
[None, None])
# We can't do: `x[..., -(n**2-m):]` because this doesn't correctly handle
# `m == n == 1`. Hence, we do nonnegative indexing.
x_tail = x[..., (m - (n**2 - m)):]
# We now concatenate the "tail" of `x` to `x` (and reverse one of them).
#
# We do this based on the insight that the input `x` provides `ceil(n/2)`
# rows of an `n x n` matrix, some of which will get zeroed out being on the
# wrong side of the diagonal. The first row will not get zeroed out at all,
# and we need `floor(n/2)` more rows, so the first is what we omit from
# `x_tail`. If we then stack those `ceil(n/2)` rows with the `floor(n/2)`
# rows provided by a reversed tail, it is exactly the other set of elements
# of the reversed tail which will be zeroed out for being on the wrong side
# of the diagonal further up/down the matrix. And, in doing-so, we've filled
# the triangular matrix in a clock-wise spiral pattern. Neat!
#
# Try it out in numpy:
# n = 3
# x = np.arange(n * (n + 1) / 2)
# m = x.shape[0]
# n = np.int32(np.sqrt(.25 + 2 * m) - .5)
# x_tail = x[(m - (n**2 - m)):]
# np.concatenate([x_tail, x[::-1]], 0).reshape(n, n) # lower
# # ==> array([[3, 4, 5],
# [5, 4, 3],
# [2, 1, 0]])
# np.concatenate([x, x_tail[::-1]], 0).reshape(n, n) # upper
# # ==> array([[0, 1, 2],
# [3, 4, 5],
# [5, 4, 3]])
#
# Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
# correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
# Furthermore observe that:
# m - (n**2 - m)
# = n**2 / 2 + n / 2 - (n**2 - n**2 / 2 + n / 2)
# = 2 (n**2 / 2 + n / 2) - n**2
# = n**2 + n - n**2
# = n
if upper:
x_list = [x, array_ops.reverse(x_tail, axis=[-1])]
x_list = [x, array_ops.reverse(x[..., n:], axis=[-1])]
else:
x_list = [x_tail, array_ops.reverse(x, axis=[-1])]
x_list = [x[..., n:], array_ops.reverse(x, axis=[-1])]
new_shape = (
static_final_shape.as_list()
if static_final_shape.is_fully_defined()