Support automatic differentiation in fill_lower_triangular when input has no batch dims.

Change: 137973576
Joshua V. Dillon 2016-11-02 10:48:54 -08:00 committed by TensorFlower Gardener
parent 65b1677070
commit 60d5d28ae2
2 changed files with 13 additions and 5 deletions
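
For orientation, fill_lower_triangular packs the last dimension of its input
into a lower-triangular matrix, zero-filling above the diagonal. A minimal
usage sketch (not part of the commit), assuming the contrib-era import path
and the values from the test below:

import numpy as np
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops import distribution_util

with tf.Session():
  x = tf.convert_to_tensor(np.arange(9, dtype=np.float32))
  y = distribution_util.fill_lower_triangular(x)
  print(y.eval())
  # [[ 0.  0.  0.]
  #  [ 1.  2.  0.]
  #  [ 3.  4.  5.]]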

@@ -325,7 +325,7 @@ class FillLowerTriangularTest(tf.test.TestCase):
 
   def testCorrectlyMakesNoBatchLowerTril(self):
     with self.test_session():
-      x = np.arange(9)
+      x = tf.convert_to_tensor(np.arange(9, dtype=np.float32))
       expected = np.array(
           [[0., 0., 0.],
            [1., 2., 0.],
@@ -333,6 +333,10 @@ class FillLowerTriangularTest(tf.test.TestCase):
       actual = distribution_util.fill_lower_triangular(x)
       self.assertAllEqual(expected.shape, actual.get_shape())
       self.assertAllEqual(expected, actual.eval())
+      self.assertAllEqual(
+          np.concatenate([np.ones(6, dtype=np.float32),
+                          np.zeros(3, dtype=np.float32)]),
+          tf.gradients(distribution_util.fill_lower_triangular(x), x)[0].eval())
 
   def testCorrectlyMakesBatchLowerTril(self):
     with self.test_session():
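
The new assertion exercises the gradient end to end. Note the input is now a
float32 Tensor rather than a raw np.arange(9) (which would be integer-typed);
tf.gradients differentiates with respect to floating-point Tensors. With
d = 9, only the first n*(n+1)/2 = 6 entries reach the output, so their
partial derivatives are 1 and the trailing 3 are 0. A standalone NumPy check
of that arithmetic (illustrative, not from the commit):

import math
import numpy as np

d = 9
n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))  # same inversion the op uses; n == 3
used = n * (n + 1) // 2                       # 6 entries fill the lower triangle
expected_grad = np.concatenate([np.ones(used, dtype=np.float32),
                                np.zeros(d - used, dtype=np.float32)])
print(expected_grad)  # [ 1.  1.  1.  1.  1.  1.  0.  0.  0.]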

@@ -435,15 +435,14 @@ def fill_lower_triangular(x, name="fill_lower_triangular"):
   """
   with ops.name_scope(name, values=(x,)):
     x = ops.convert_to_tensor(x, name="x")
-    ndims = x.get_shape().ndims
-    if ndims is not None and x.get_shape()[-1].value is not None:
+    if (x.get_shape().ndims is not None and
+        x.get_shape()[-1].value is not None):
       d = x.get_shape()[-1].value
       # d = n^2/2 + n/2 implies n is:
       n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
       final_shape = x.get_shape()[:-1].concatenate(
           tensor_shape.TensorShape([n, n]))
     else:
-      ndims = array_ops.rank(x)
       d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32)
       # d = n^2/2 + n/2 implies n is:
       n = math_ops.cast(0.5 * (math_ops.sqrt(1. + 8. * d) - 1.),
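
Both branches invert the same relation: d = n^2/2 + n/2 gives
n^2 + n - 2d = 0, whose positive root is n = (sqrt(1 + 8d) - 1) / 2; the
int()/int32 cast floors the result when d is not an exact triangular number.
A quick sanity check of the formula (illustrative, not from the commit):

import math

for n in range(1, 8):
  d = n * (n + 1) // 2  # the triangular number for this n
  assert int(0.5 * (math.sqrt(1. + 8. * d) - 1.)) == n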
@@ -494,7 +493,12 @@ def fill_lower_triangular(x, name="fill_lower_triangular"):
           array_ops.tile([tril_ids], [m, 1])])
     idx = array_ops.transpose(idx, [1, 2, 0])
-    y = array_ops.gather_nd(y, idx)
+    if x.get_shape().ndims == 1:
+      # Prefer using gather because it has a gradient.
+      # We wrap the result in a list so downstream logic "just works."
+      y = [array_ops.gather(y[0, :], tril_ids)]
+    else:
+      y = array_ops.gather_nd(y, idx)
     y = array_ops.reshape(y, array_ops.concat(0, [batch_shape, [n, n]]))
     y.set_shape(y.get_shape().merge_with(final_shape))
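
The special case exists because, in TensorFlow at the time of this commit,
array_ops.gather had a registered gradient while gather_nd did not (hence
the "Prefer using gather" comment). A sketch of the differentiable 1-D path,
not the commit's code:

import numpy as np
import tensorflow as tf

with tf.Session():
  x = tf.constant(np.arange(9, dtype=np.float32))
  y = tf.gather(x, [0, 1, 2, 3, 4, 5])  # differentiable; picks the used entries
  grad = tf.gradients(y, x)[0]          # gather's gradient is an IndexedSlices
  print(tf.convert_to_tensor(grad).eval())  # [ 1.  1.  1.  1.  1.  1.  0.  0.  0.]

Wrapping the gather result in a one-element list, as the new branch does,
lets the downstream reshape (which converts its argument to a tensor) and
set_shape calls run unchanged for both the batched and unbatched cases.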