Simplify tf.contrib.distributions.fill_triangular indexing calculation and add more comments explaining some of the magic.
PiperOrigin-RevId: 169557899
This commit is contained in:
parent
5cc7b86a0d
commit
2679dcfbaa
@ -769,13 +769,45 @@ def fill_triangular(x, upper=False, name=None):
|
||||
dtype=dtypes.int32)
|
||||
static_final_shape = x.shape.with_rank_at_least(1)[:-1].concatenate(
|
||||
[None, None])
|
||||
# We can't do: `x[..., -(n**2-m):]` because this doesn't correctly handle
|
||||
# `m == n == 1`. Hence, we do nonnegative indexing.
|
||||
x_tail = x[..., (m - (n**2 - m)):]
|
||||
# We now concatenate the "tail" of `x` to `x` (and reverse one of them).
|
||||
#
|
||||
# We do this based on the insight that the input `x` provides `ceil(n/2)`
|
||||
# rows of an `n x n` matrix, some of which will get zeroed out being on the
|
||||
# wrong side of the diagonal. The first row will not get zeroed out at all,
|
||||
# and we need `floor(n/2)` more rows, so the first is what we omit from
|
||||
# `x_tail`. If we then stack those `ceil(n/2)` rows with the `floor(n/2)`
|
||||
# rows provided by a reversed tail, it is exactly the other set of elements
|
||||
# of the reversed tail which will be zeroed out for being on the wrong side
|
||||
# of the diagonal further up/down the matrix. And, in doing so, we've filled
|
||||
# the triangular matrix in a clockwise spiral pattern. Neat!
|
||||
#
|
||||
# Try it out in numpy:
|
||||
# n = 3
|
||||
# x = np.arange(n * (n + 1) // 2)
|
||||
# m = x.shape[0]
|
||||
# n = np.int32(np.sqrt(.25 + 2 * m) - .5)
|
||||
# x_tail = x[(m - (n**2 - m)):]
|
||||
# np.concatenate([x_tail, x[::-1]], 0).reshape(n, n) # lower
|
||||
# # ==> array([[3, 4, 5],
|
||||
# [5, 4, 3],
|
||||
# [2, 1, 0]])
|
||||
# np.concatenate([x, x_tail[::-1]], 0).reshape(n, n) # upper
|
||||
# # ==> array([[0, 1, 2],
|
||||
# [3, 4, 5],
|
||||
# [5, 4, 3]])
|
||||
#
|
||||
# Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
|
||||
# correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
|
||||
# Furthermore observe that:
|
||||
# m - (n**2 - m)
|
||||
# = n**2 / 2 + n / 2 - (n**2 - n**2 / 2 - n / 2)
|
||||
# = 2 (n**2 / 2 + n / 2) - n**2
|
||||
# = n**2 + n - n**2
|
||||
# = n
|
||||
if upper:
|
||||
x_list = [x, array_ops.reverse(x_tail, axis=[-1])]
|
||||
x_list = [x, array_ops.reverse(x[..., n:], axis=[-1])]
|
||||
else:
|
||||
x_list = [x_tail, array_ops.reverse(x, axis=[-1])]
|
||||
x_list = [x[..., n:], array_ops.reverse(x, axis=[-1])]
|
||||
new_shape = (
|
||||
static_final_shape.as_list()
|
||||
if static_final_shape.is_fully_defined()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user