Support automatic differentiation in fill_lower_triangular when input has no batch dims.
Change: 137973576
This commit is contained in:
parent
65b1677070
commit
60d5d28ae2
@ -325,7 +325,7 @@ class FillLowerTriangularTest(tf.test.TestCase):
|
||||
|
||||
def testCorrectlyMakesNoBatchLowerTril(self):
|
||||
with self.test_session():
|
||||
x = np.arange(9)
|
||||
x = tf.convert_to_tensor(np.arange(9, dtype=np.float32))
|
||||
expected = np.array(
|
||||
[[0., 0., 0.],
|
||||
[1., 2., 0.],
|
||||
@ -333,6 +333,10 @@ class FillLowerTriangularTest(tf.test.TestCase):
|
||||
actual = distribution_util.fill_lower_triangular(x)
|
||||
self.assertAllEqual(expected.shape, actual.get_shape())
|
||||
self.assertAllEqual(expected, actual.eval())
|
||||
self.assertAllEqual(
|
||||
np.concatenate([np.ones(6, dtype=np.float32),
|
||||
np.zeros(3, dtype=np.float32)]),
|
||||
tf.gradients(distribution_util.fill_lower_triangular(x), x)[0].eval())
|
||||
|
||||
def testCorrectlyMakesBatchLowerTril(self):
|
||||
with self.test_session():
|
||||
|
@ -435,15 +435,14 @@ def fill_lower_triangular(x, name="fill_lower_triangular"):
|
||||
"""
|
||||
with ops.name_scope(name, values=(x,)):
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
ndims = x.get_shape().ndims
|
||||
if ndims is not None and x.get_shape()[-1].value is not None:
|
||||
if (x.get_shape().ndims is not None and
|
||||
x.get_shape()[-1].value is not None):
|
||||
d = x.get_shape()[-1].value
|
||||
# d = n^2/2 + n/2 implies n is:
|
||||
n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
|
||||
final_shape = x.get_shape()[:-1].concatenate(
|
||||
tensor_shape.TensorShape([n, n]))
|
||||
else:
|
||||
ndims = array_ops.rank(x)
|
||||
d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32)
|
||||
# d = n^2/2 + n/2 implies n is:
|
||||
n = math_ops.cast(0.5 * (dtypes.sqrt(1. + 8. * d) - 1.),
|
||||
@ -494,7 +493,12 @@ def fill_lower_triangular(x, name="fill_lower_triangular"):
|
||||
array_ops.tile([tril_ids], [m, 1])])
|
||||
idx = array_ops.transpose(idx, [1, 2, 0])
|
||||
|
||||
y = array_ops.gather_nd(y, idx)
|
||||
if x.get_shape().ndims == 1:
|
||||
# Prefer using gather because it has a gradient.
|
||||
# We wrap the result in a list so downstream logic "just works."
|
||||
y = [array_ops.gather(y[0, :], tril_ids)]
|
||||
else:
|
||||
y = array_ops.gather_nd(y, idx)
|
||||
y = array_ops.reshape(y, array_ops.concat(0, [batch_shape, [n, n]]))
|
||||
|
||||
y.set_shape(y.get_shape().merge_with(final_shape))
|
||||
|
Loading…
Reference in New Issue
Block a user