Support automatic differentiation in fill_lower_triangular when input has no batch dims.
Change: 137973576
parent 65b1677070
commit 60d5d28ae2

tensorflow/contrib/distributions/python
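For context (an illustrative aside, not part of the commit): fill_lower_triangular packs the trailing dimension of x, of length d = n(n+1)/2, row by row into an n x n lower-triangular matrix. A minimal numpy sketch of the unbatched mapping, assuming row-major fill as the unit test below expects:

    import numpy as np

    def fill_lower_triangular_np(x):
      """Numpy sketch: pack a length-d vector into an n x n lower triangle."""
      d = x.shape[-1]
      n = int(0.5 * (np.sqrt(1. + 8. * d) - 1.))
      y = np.zeros((n, n), dtype=x.dtype)
      y[np.tril_indices(n)] = x[:n * (n + 1) // 2]  # row-major lower triangle
      return y

    print(fill_lower_triangular_np(np.arange(6.)))
    # [[0. 0. 0.]
    #  [1. 2. 0.]
    #  [3. 4. 5.]]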
@@ -325,7 +325,7 @@ class FillLowerTriangularTest(tf.test.TestCase):
 
   def testCorrectlyMakesNoBatchLowerTril(self):
     with self.test_session():
-      x = np.arange(9)
+      x = tf.convert_to_tensor(np.arange(9, dtype=np.float32))
       expected = np.array(
           [[0., 0., 0.],
            [1., 2., 0.],
@@ -333,6 +333,10 @@ class FillLowerTriangularTest(tf.test.TestCase):
       actual = distribution_util.fill_lower_triangular(x)
       self.assertAllEqual(expected.shape, actual.get_shape())
       self.assertAllEqual(expected, actual.eval())
+      self.assertAllEqual(
+          np.concatenate([np.ones(6, dtype=np.float32),
+                          np.zeros(3, dtype=np.float32)]),
+          tf.gradients(distribution_util.fill_lower_triangular(x), x)[0].eval())
 
   def testCorrectlyMakesBatchLowerTril(self):
     with self.test_session():
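Why the asserted gradient is six ones followed by three zeros: with d = 9 the function builds a 3 x 3 matrix, so only the first n(n+1)/2 = 6 entries of x reach the output, each exactly once, and tf.gradients sums over the output, giving 1 for each used entry and 0 for the rest. The input also changes from np.arange(9) to a float32 Tensor so that tf.gradients has a floating-point Tensor to differentiate with respect to. A standalone numpy check of those numbers (illustrative, not part of the commit):

    import numpy as np

    d = 9
    n = int(0.5 * (np.sqrt(1. + 8. * d) - 1.))  # n = 3
    used = n * (n + 1) // 2                     # first 6 entries feed the output
    grad = np.concatenate([np.ones(used, dtype=np.float32),
                           np.zeros(d - used, dtype=np.float32)])
    print(grad)  # [1. 1. 1. 1. 1. 1. 0. 0. 0.]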
@@ -435,15 +435,14 @@ def fill_lower_triangular(x, name="fill_lower_triangular"):
   """
   with ops.name_scope(name, values=(x,)):
     x = ops.convert_to_tensor(x, name="x")
-    ndims = x.get_shape().ndims
-    if ndims is not None and x.get_shape()[-1].value is not None:
+    if (x.get_shape().ndims is not None and
+        x.get_shape()[-1].value is not None):
       d = x.get_shape()[-1].value
       # d = n^2/2 + n/2 implies n is:
       n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
       final_shape = x.get_shape()[:-1].concatenate(
           tensor_shape.TensorShape([n, n]))
     else:
-      ndims = array_ops.rank(x)
       d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32)
       # d = n^2/2 + n/2 implies n is:
       n = math_ops.cast(0.5 * (math_ops.sqrt(1. + 8. * d) - 1.),
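The "# d = n^2/2 + n/2 implies n is:" comment compresses a small derivation: d = n(n+1)/2 rearranges to n^2 + n - 2d = 0, whose positive root is n = (sqrt(1 + 8d) - 1) / 2, computed with Python math when the shape is static and with TF ops when it is only known at graph run time. Truncation via int() is also why the d = 9 test above yields n = 3 (sqrt(73) is about 8.54, so the root is about 3.77). A quick illustrative check (implied_n is a hypothetical helper):

    import math

    def implied_n(d):
      """Positive root of n^2 + n - 2d = 0, truncated to int."""
      return int(0.5 * (math.sqrt(1. + 8. * d) - 1.))

    assert implied_n(6) == 3    # a 3x3 lower triangle holds 6 entries
    assert implied_n(10) == 4   # a 4x4 holds 10
    assert implied_n(9) == 3    # entries beyond the first 6 are ignored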
@@ -494,7 +493,12 @@ def fill_lower_triangular(x, name="fill_lower_triangular"):
                             array_ops.tile([tril_ids], [m, 1])])
       idx = array_ops.transpose(idx, [1, 2, 0])
 
-      y = array_ops.gather_nd(y, idx)
+      if x.get_shape().ndims == 1:
+        # Prefer using gather because it has a gradient.
+        # We wrap the result in a list so downstream logic "just works."
+        y = [array_ops.gather(y[0, :], tril_ids)]
+      else:
+        y = array_ops.gather_nd(y, idx)
       y = array_ops.reshape(y, array_ops.concat(0, [batch_shape, [n, n]]))
 
     y.set_shape(y.get_shape().merge_with(final_shape))
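This is the core of the fix: as the new comment notes, gather has a gradient while gather_nd (at this point in TensorFlow's history) did not, so the unbatched path switches to a plain gather over tril_ids. The gradient of a gather is a scatter-add of the upstream gradient back onto the gathered positions. A minimal numpy sketch of that vector-Jacobian product (gather_vjp and the sample indices are hypothetical, for illustration only):

    import numpy as np

    def gather_vjp(params_len, indices, grad_out):
      """Backprop for y = params[indices]: scatter-add grad_out into params."""
      grad_params = np.zeros(params_len, dtype=grad_out.dtype)
      np.add.at(grad_params, indices, grad_out)  # accumulates repeated indices
      return grad_params

    # If the first six positions of a length-9 input are each gathered once:
    print(gather_vjp(9, np.arange(6), np.ones(6, np.float32)))
    # [1. 1. 1. 1. 1. 1. 0. 0. 0.] -- the pattern the new unit test asserts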